transformers-5.0.0rc1-py3-none-any.whl → transformers-5.0.0rc2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (671)
  1. transformers/__init__.py +20 -1
  2. transformers/activations.py +1 -1
  3. transformers/audio_utils.py +0 -1
  4. transformers/cache_utils.py +17 -15
  5. transformers/configuration_utils.py +114 -70
  6. transformers/conversion_mapping.py +68 -5
  7. transformers/core_model_loading.py +201 -35
  8. transformers/dependency_versions_table.py +1 -1
  9. transformers/feature_extraction_utils.py +54 -22
  10. transformers/generation/candidate_generator.py +79 -31
  11. transformers/generation/configuration_utils.py +162 -122
  12. transformers/generation/continuous_batching/cache.py +47 -18
  13. transformers/generation/continuous_batching/cache_manager.py +131 -34
  14. transformers/generation/continuous_batching/continuous_api.py +101 -64
  15. transformers/generation/continuous_batching/requests.py +28 -1
  16. transformers/generation/continuous_batching/scheduler.py +11 -4
  17. transformers/generation/stopping_criteria.py +1 -1
  18. transformers/generation/utils.py +108 -110
  19. transformers/generation/watermarking.py +8 -5
  20. transformers/image_processing_base.py +2 -12
  21. transformers/image_processing_utils_fast.py +15 -4
  22. transformers/initialization.py +37 -0
  23. transformers/integrations/__init__.py +12 -0
  24. transformers/integrations/accelerate.py +44 -111
  25. transformers/integrations/aqlm.py +3 -5
  26. transformers/integrations/awq.py +2 -5
  27. transformers/integrations/bitnet.py +5 -8
  28. transformers/integrations/bitsandbytes.py +16 -15
  29. transformers/integrations/deepspeed.py +18 -3
  30. transformers/integrations/eetq.py +3 -5
  31. transformers/integrations/fbgemm_fp8.py +1 -1
  32. transformers/integrations/finegrained_fp8.py +6 -16
  33. transformers/integrations/flash_attention.py +2 -2
  34. transformers/integrations/higgs.py +2 -5
  35. transformers/integrations/hub_kernels.py +23 -5
  36. transformers/integrations/integration_utils.py +35 -0
  37. transformers/integrations/mistral.py +12 -0
  38. transformers/integrations/moe.py +240 -0
  39. transformers/integrations/mxfp4.py +4 -10
  40. transformers/integrations/peft.py +5 -0
  41. transformers/integrations/quanto.py +5 -2
  42. transformers/integrations/spqr.py +3 -5
  43. transformers/integrations/tensor_parallel.py +167 -221
  44. transformers/integrations/vptq.py +3 -5
  45. transformers/modeling_gguf_pytorch_utils.py +66 -19
  46. transformers/modeling_rope_utils.py +78 -81
  47. transformers/modeling_utils.py +583 -503
  48. transformers/models/__init__.py +19 -0
  49. transformers/models/afmoe/modeling_afmoe.py +7 -16
  50. transformers/models/afmoe/modular_afmoe.py +5 -13
  51. transformers/models/aimv2/modeling_aimv2.py +4 -0
  52. transformers/models/aimv2/modular_aimv2.py +4 -0
  53. transformers/models/albert/modeling_albert.py +3 -0
  54. transformers/models/align/modeling_align.py +12 -6
  55. transformers/models/altclip/modeling_altclip.py +7 -3
  56. transformers/models/apertus/modeling_apertus.py +4 -2
  57. transformers/models/apertus/modular_apertus.py +4 -1
  58. transformers/models/arcee/modeling_arcee.py +1 -1
  59. transformers/models/aria/modeling_aria.py +8 -4
  60. transformers/models/aria/modular_aria.py +7 -3
  61. transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
  62. transformers/models/auto/auto_factory.py +1 -1
  63. transformers/models/auto/configuration_auto.py +27 -0
  64. transformers/models/auto/feature_extraction_auto.py +7 -3
  65. transformers/models/auto/image_processing_auto.py +4 -2
  66. transformers/models/auto/modeling_auto.py +31 -0
  67. transformers/models/auto/processing_auto.py +4 -0
  68. transformers/models/auto/tokenization_auto.py +132 -153
  69. transformers/models/auto/video_processing_auto.py +5 -2
  70. transformers/models/aya_vision/modeling_aya_vision.py +7 -3
  71. transformers/models/bamba/modeling_bamba.py +18 -19
  72. transformers/models/bamba/modular_bamba.py +17 -16
  73. transformers/models/bark/modeling_bark.py +9 -0
  74. transformers/models/bart/configuration_bart.py +0 -1
  75. transformers/models/bart/modeling_bart.py +7 -0
  76. transformers/models/beit/image_processing_beit_fast.py +0 -1
  77. transformers/models/bert/modeling_bert.py +3 -0
  78. transformers/models/bert_generation/modeling_bert_generation.py +2 -0
  79. transformers/models/big_bird/modeling_big_bird.py +3 -0
  80. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +7 -0
  81. transformers/models/bit/modeling_bit.py +5 -1
  82. transformers/models/bitnet/modeling_bitnet.py +1 -1
  83. transformers/models/blenderbot/modeling_blenderbot.py +7 -0
  84. transformers/models/blenderbot/tokenization_blenderbot.py +6 -7
  85. transformers/models/blenderbot_small/modeling_blenderbot_small.py +7 -0
  86. transformers/models/blip/modeling_blip.py +2 -0
  87. transformers/models/blip/modeling_blip_text.py +8 -0
  88. transformers/models/blip_2/modeling_blip_2.py +2 -0
  89. transformers/models/bloom/modeling_bloom.py +13 -44
  90. transformers/models/blt/modeling_blt.py +162 -2
  91. transformers/models/blt/modular_blt.py +168 -3
  92. transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
  93. transformers/models/bridgetower/modeling_bridgetower.py +6 -0
  94. transformers/models/bros/modeling_bros.py +8 -0
  95. transformers/models/camembert/modeling_camembert.py +109 -106
  96. transformers/models/canine/modeling_canine.py +6 -0
  97. transformers/models/canine/tokenization_canine.py +2 -0
  98. transformers/models/chameleon/modeling_chameleon.py +9 -4
  99. transformers/models/chinese_clip/modeling_chinese_clip.py +6 -3
  100. transformers/models/clap/feature_extraction_clap.py +2 -2
  101. transformers/models/clap/modeling_clap.py +25 -15
  102. transformers/models/clip/modeling_clip.py +2 -0
  103. transformers/models/clipseg/modeling_clipseg.py +4 -0
  104. transformers/models/clvp/modeling_clvp.py +14 -3
  105. transformers/models/code_llama/tokenization_code_llama.py +1 -1
  106. transformers/models/codegen/modeling_codegen.py +13 -4
  107. transformers/models/cohere/modeling_cohere.py +1 -1
  108. transformers/models/cohere2/modeling_cohere2.py +1 -1
  109. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +0 -1
  110. transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
  111. transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
  112. transformers/models/conditional_detr/modeling_conditional_detr.py +4 -1
  113. transformers/models/convbert/modeling_convbert.py +3 -0
  114. transformers/models/convnext/image_processing_convnext.py +2 -2
  115. transformers/models/convnext/image_processing_convnext_fast.py +9 -13
  116. transformers/models/csm/generation_csm.py +19 -22
  117. transformers/models/csm/modeling_csm.py +3 -1
  118. transformers/models/csm/modular_csm.py +2 -0
  119. transformers/models/ctrl/modeling_ctrl.py +14 -2
  120. transformers/models/cvt/modeling_cvt.py +5 -1
  121. transformers/models/cwm/modeling_cwm.py +1 -1
  122. transformers/models/d_fine/configuration_d_fine.py +3 -4
  123. transformers/models/d_fine/modeling_d_fine.py +46 -39
  124. transformers/models/d_fine/modular_d_fine.py +15 -4
  125. transformers/models/dab_detr/configuration_dab_detr.py +2 -2
  126. transformers/models/dab_detr/modeling_dab_detr.py +1 -1
  127. transformers/models/dac/modeling_dac.py +4 -4
  128. transformers/models/data2vec/modeling_data2vec_text.py +7 -0
  129. transformers/models/data2vec/modular_data2vec_text.py +7 -0
  130. transformers/models/dbrx/configuration_dbrx.py +9 -1
  131. transformers/models/dbrx/modeling_dbrx.py +1 -1
  132. transformers/models/deberta/modeling_deberta.py +2 -0
  133. transformers/models/deberta_v2/modeling_deberta_v2.py +2 -0
  134. transformers/models/decision_transformer/modeling_decision_transformer.py +8 -5
  135. transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -4
  136. transformers/models/deepseek_v2/modular_deepseek_v2.py +4 -2
  137. transformers/models/deepseek_v3/modeling_deepseek_v3.py +9 -5
  138. transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -2
  139. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
  140. transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
  141. transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
  142. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
  143. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
  144. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
  145. transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
  146. transformers/models/deformable_detr/modeling_deformable_detr.py +1 -1
  147. transformers/models/depth_anything/configuration_depth_anything.py +2 -3
  148. transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
  149. transformers/models/detr/configuration_detr.py +1 -1
  150. transformers/models/detr/modeling_detr.py +8 -1
  151. transformers/models/dia/generation_dia.py +3 -10
  152. transformers/models/dia/modeling_dia.py +12 -1
  153. transformers/models/dia/modular_dia.py +11 -0
  154. transformers/models/dia/processing_dia.py +1 -1
  155. transformers/models/diffllama/modeling_diffllama.py +3 -3
  156. transformers/models/diffllama/modular_diffllama.py +2 -2
  157. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
  158. transformers/models/dinov3_vit/modeling_dinov3_vit.py +3 -0
  159. transformers/models/dinov3_vit/modular_dinov3_vit.py +3 -0
  160. transformers/models/distilbert/modeling_distilbert.py +11 -9
  161. transformers/models/doge/modeling_doge.py +1 -1
  162. transformers/models/donut/image_processing_donut_fast.py +0 -1
  163. transformers/models/donut/modeling_donut_swin.py +16 -12
  164. transformers/models/dots1/modeling_dots1.py +14 -5
  165. transformers/models/dpt/configuration_dpt.py +1 -1
  166. transformers/models/dpt/image_processing_dpt_fast.py +1 -2
  167. transformers/models/dpt/modular_dpt.py +1 -2
  168. transformers/models/edgetam/configuration_edgetam.py +1 -1
  169. transformers/models/edgetam/modeling_edgetam.py +5 -2
  170. transformers/models/edgetam/modular_edgetam.py +15 -14
  171. transformers/models/edgetam_video/modeling_edgetam_video.py +55 -43
  172. transformers/models/edgetam_video/modular_edgetam_video.py +13 -19
  173. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
  174. transformers/models/efficientloftr/modeling_efficientloftr.py +14 -1
  175. transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
  176. transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
  177. transformers/models/efficientnet/modeling_efficientnet.py +5 -1
  178. transformers/models/electra/modeling_electra.py +7 -0
  179. transformers/models/emu3/modeling_emu3.py +8 -2
  180. transformers/models/emu3/modular_emu3.py +7 -1
  181. transformers/models/encodec/modeling_encodec.py +14 -0
  182. transformers/models/eomt/image_processing_eomt_fast.py +46 -14
  183. transformers/models/eomt/modeling_eomt.py +7 -0
  184. transformers/models/eomt/modular_eomt.py +7 -0
  185. transformers/models/ernie/modeling_ernie.py +6 -0
  186. transformers/models/ernie/modular_ernie.py +6 -0
  187. transformers/models/ernie4_5/modeling_ernie4_5.py +1 -1
  188. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +16 -13
  189. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +9 -35
  190. transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
  191. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
  192. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
  193. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
  194. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
  195. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
  196. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
  197. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
  198. transformers/models/esm/modeling_esm.py +6 -0
  199. transformers/models/esm/modeling_esmfold.py +6 -1
  200. transformers/models/evolla/modeling_evolla.py +9 -1
  201. transformers/models/evolla/modular_evolla.py +8 -0
  202. transformers/models/exaone4/modeling_exaone4.py +1 -1
  203. transformers/models/falcon/modeling_falcon.py +3 -3
  204. transformers/models/falcon_h1/modeling_falcon_h1.py +28 -23
  205. transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
  206. transformers/models/falcon_mamba/modeling_falcon_mamba.py +6 -2
  207. transformers/models/falcon_mamba/modular_falcon_mamba.py +7 -2
  208. transformers/models/fast_vlm/modeling_fast_vlm.py +7 -3
  209. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +23 -10
  210. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
  211. transformers/models/flaubert/modeling_flaubert.py +14 -15
  212. transformers/models/flava/image_processing_flava_fast.py +0 -2
  213. transformers/models/flava/modeling_flava.py +4 -1
  214. transformers/models/flex_olmo/modeling_flex_olmo.py +7 -4
  215. transformers/models/florence2/modeling_florence2.py +20 -3
  216. transformers/models/florence2/modular_florence2.py +13 -0
  217. transformers/models/fnet/modeling_fnet.py +7 -0
  218. transformers/models/fuyu/image_processing_fuyu.py +1 -1
  219. transformers/models/fuyu/modeling_fuyu.py +3 -1
  220. transformers/models/fuyu/processing_fuyu.py +16 -0
  221. transformers/models/gemma/modeling_gemma.py +10 -12
  222. transformers/models/gemma/modular_gemma.py +9 -11
  223. transformers/models/gemma2/modeling_gemma2.py +1 -1
  224. transformers/models/gemma2/modular_gemma2.py +1 -1
  225. transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
  226. transformers/models/gemma3/modeling_gemma3.py +28 -7
  227. transformers/models/gemma3/modular_gemma3.py +26 -6
  228. transformers/models/gemma3n/configuration_gemma3n.py +3 -0
  229. transformers/models/gemma3n/modeling_gemma3n.py +47 -9
  230. transformers/models/gemma3n/modular_gemma3n.py +51 -9
  231. transformers/models/git/modeling_git.py +181 -126
  232. transformers/models/glm/modeling_glm.py +1 -1
  233. transformers/models/glm4/modeling_glm4.py +1 -1
  234. transformers/models/glm46v/image_processing_glm46v.py +0 -4
  235. transformers/models/glm46v/modeling_glm46v.py +3 -1
  236. transformers/models/glm46v/modular_glm46v.py +3 -0
  237. transformers/models/glm4_moe/modeling_glm4_moe.py +9 -5
  238. transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
  239. transformers/models/glm4v/image_processing_glm4v.py +0 -4
  240. transformers/models/glm4v/modeling_glm4v.py +15 -5
  241. transformers/models/glm4v/modular_glm4v.py +11 -3
  242. transformers/models/glm4v_moe/modeling_glm4v_moe.py +39 -23
  243. transformers/models/glm4v_moe/modular_glm4v_moe.py +12 -0
  244. transformers/models/glmasr/__init__.py +30 -0
  245. transformers/models/glmasr/configuration_glmasr.py +197 -0
  246. transformers/models/glmasr/modeling_glmasr.py +512 -0
  247. transformers/models/glmasr/modular_glmasr.py +433 -0
  248. transformers/models/glmasr/processing_glmasr.py +332 -0
  249. transformers/models/glpn/image_processing_glpn_fast.py +0 -1
  250. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
  251. transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
  252. transformers/models/gpt2/modeling_gpt2.py +8 -5
  253. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +3 -8
  254. transformers/models/gpt_neo/modeling_gpt_neo.py +15 -3
  255. transformers/models/gpt_neox/modeling_gpt_neox.py +1 -1
  256. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +1 -1
  257. transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
  258. transformers/models/gpt_oss/modeling_gpt_oss.py +6 -9
  259. transformers/models/gpt_oss/modular_gpt_oss.py +5 -7
  260. transformers/models/gptj/modeling_gptj.py +15 -6
  261. transformers/models/granite/modeling_granite.py +1 -1
  262. transformers/models/granite_speech/modeling_granite_speech.py +15 -1
  263. transformers/models/granitemoe/modeling_granitemoe.py +2 -3
  264. transformers/models/granitemoe/modular_granitemoe.py +1 -2
  265. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
  266. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +33 -23
  267. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
  268. transformers/models/granitemoeshared/modeling_granitemoeshared.py +2 -3
  269. transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
  270. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -4
  271. transformers/models/groupvit/modeling_groupvit.py +6 -1
  272. transformers/models/helium/modeling_helium.py +1 -1
  273. transformers/models/hgnet_v2/modeling_hgnet_v2.py +10 -0
  274. transformers/models/hgnet_v2/modular_hgnet_v2.py +10 -0
  275. transformers/models/hubert/modeling_hubert.py +4 -0
  276. transformers/models/hubert/modular_hubert.py +4 -0
  277. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +1 -1
  278. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
  279. transformers/models/hunyuan_v1_moe/__init__.py +1 -1
  280. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +12 -4
  281. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
  282. transformers/models/ibert/modeling_ibert.py +16 -0
  283. transformers/models/idefics/modeling_idefics.py +10 -0
  284. transformers/models/idefics2/modeling_idefics2.py +7 -1
  285. transformers/models/idefics3/modeling_idefics3.py +5 -1
  286. transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
  287. transformers/models/imagegpt/modeling_imagegpt.py +9 -2
  288. transformers/models/instructblip/modeling_instructblip.py +2 -0
  289. transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
  290. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
  291. transformers/models/internvl/modeling_internvl.py +11 -8
  292. transformers/models/internvl/modular_internvl.py +5 -9
  293. transformers/models/internvl/video_processing_internvl.py +0 -1
  294. transformers/models/jais2/__init__.py +27 -0
  295. transformers/models/jais2/configuration_jais2.py +152 -0
  296. transformers/models/jais2/modeling_jais2.py +486 -0
  297. transformers/models/jais2/modular_jais2.py +196 -0
  298. transformers/models/jamba/modeling_jamba.py +24 -19
  299. transformers/models/jamba/modular_jamba.py +17 -17
  300. transformers/models/janus/image_processing_janus_fast.py +0 -1
  301. transformers/models/janus/modeling_janus.py +15 -7
  302. transformers/models/janus/modular_janus.py +16 -7
  303. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  304. transformers/models/jetmoe/modular_jetmoe.py +1 -0
  305. transformers/models/kosmos2/modeling_kosmos2.py +14 -2
  306. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
  307. transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
  308. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +9 -3
  309. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
  310. transformers/models/lasr/configuration_lasr.py +4 -0
  311. transformers/models/lasr/modeling_lasr.py +3 -2
  312. transformers/models/lasr/modular_lasr.py +8 -1
  313. transformers/models/lasr/processing_lasr.py +0 -2
  314. transformers/models/layoutlm/modeling_layoutlm.py +5 -3
  315. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
  316. transformers/models/layoutlmv2/modeling_layoutlmv2.py +12 -0
  317. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +1 -0
  318. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
  319. transformers/models/layoutlmv3/modeling_layoutlmv3.py +29 -5
  320. transformers/models/led/modeling_led.py +6 -0
  321. transformers/models/levit/modeling_levit.py +18 -0
  322. transformers/models/lfm2/modeling_lfm2.py +1 -1
  323. transformers/models/lfm2_moe/modeling_lfm2_moe.py +14 -4
  324. transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
  325. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  326. transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
  327. transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
  328. transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
  329. transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
  330. transformers/models/lilt/modeling_lilt.py +19 -15
  331. transformers/models/llama/modeling_llama.py +1 -1
  332. transformers/models/llama4/image_processing_llama4_fast.py +1 -2
  333. transformers/models/llama4/modeling_llama4.py +8 -4
  334. transformers/models/llava/image_processing_llava_fast.py +0 -1
  335. transformers/models/llava/modeling_llava.py +12 -7
  336. transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
  337. transformers/models/llava_next/modeling_llava_next.py +7 -3
  338. transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
  339. transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
  340. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
  341. transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
  342. transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
  343. transformers/models/longcat_flash/modeling_longcat_flash.py +2 -1
  344. transformers/models/longcat_flash/modular_longcat_flash.py +1 -0
  345. transformers/models/longt5/modeling_longt5.py +0 -4
  346. transformers/models/m2m_100/modeling_m2m_100.py +10 -0
  347. transformers/models/mamba/modeling_mamba.py +2 -1
  348. transformers/models/mamba2/modeling_mamba2.py +24 -23
  349. transformers/models/marian/configuration_marian.py +1 -1
  350. transformers/models/marian/modeling_marian.py +3 -0
  351. transformers/models/markuplm/modeling_markuplm.py +5 -8
  352. transformers/models/mask2former/configuration_mask2former.py +3 -3
  353. transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
  354. transformers/models/mask2former/modeling_mask2former.py +9 -0
  355. transformers/models/maskformer/configuration_maskformer.py +3 -3
  356. transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
  357. transformers/models/maskformer/modeling_maskformer.py +9 -1
  358. transformers/models/maskformer/modeling_maskformer_swin.py +19 -15
  359. transformers/models/mbart/configuration_mbart.py +1 -0
  360. transformers/models/mbart/modeling_mbart.py +7 -0
  361. transformers/models/megatron_bert/modeling_megatron_bert.py +2 -0
  362. transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
  363. transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
  364. transformers/models/mimi/modeling_mimi.py +25 -4
  365. transformers/models/minimax/modeling_minimax.py +16 -3
  366. transformers/models/minimax/modular_minimax.py +12 -1
  367. transformers/models/ministral/modeling_ministral.py +1 -1
  368. transformers/models/ministral3/modeling_ministral3.py +1 -1
  369. transformers/models/mistral/modeling_mistral.py +1 -1
  370. transformers/models/mistral3/modeling_mistral3.py +10 -4
  371. transformers/models/mistral3/modular_mistral3.py +3 -1
  372. transformers/models/mixtral/modeling_mixtral.py +12 -4
  373. transformers/models/mixtral/modular_mixtral.py +6 -2
  374. transformers/models/mlcd/modeling_mlcd.py +6 -0
  375. transformers/models/mlcd/modular_mlcd.py +4 -0
  376. transformers/models/mllama/modeling_mllama.py +13 -2
  377. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
  378. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -4
  379. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
  380. transformers/models/mobilebert/modeling_mobilebert.py +2 -0
  381. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
  382. transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
  383. transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
  384. transformers/models/mobilevit/modeling_mobilevit.py +4 -0
  385. transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -0
  386. transformers/models/modernbert/modeling_modernbert.py +12 -1
  387. transformers/models/modernbert/modular_modernbert.py +12 -1
  388. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -1
  389. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +9 -1
  390. transformers/models/moonshine/modeling_moonshine.py +1 -1
  391. transformers/models/moshi/modeling_moshi.py +21 -51
  392. transformers/models/mpnet/modeling_mpnet.py +2 -0
  393. transformers/models/mra/modeling_mra.py +4 -1
  394. transformers/models/mt5/configuration_mt5.py +2 -3
  395. transformers/models/mt5/modeling_mt5.py +0 -10
  396. transformers/models/musicgen/modeling_musicgen.py +5 -9
  397. transformers/models/musicgen_melody/modeling_musicgen_melody.py +4 -0
  398. transformers/models/mvp/modeling_mvp.py +7 -0
  399. transformers/models/nanochat/modeling_nanochat.py +1 -1
  400. transformers/models/nemotron/modeling_nemotron.py +3 -3
  401. transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
  402. transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
  403. transformers/models/nougat/image_processing_nougat_fast.py +0 -1
  404. transformers/models/nougat/tokenization_nougat.py +11 -16
  405. transformers/models/nystromformer/modeling_nystromformer.py +7 -0
  406. transformers/models/olmo/modeling_olmo.py +1 -1
  407. transformers/models/olmo2/modeling_olmo2.py +1 -1
  408. transformers/models/olmo3/modeling_olmo3.py +1 -1
  409. transformers/models/olmoe/modeling_olmoe.py +12 -4
  410. transformers/models/olmoe/modular_olmoe.py +4 -2
  411. transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
  412. transformers/models/omdet_turbo/modeling_omdet_turbo.py +4 -0
  413. transformers/models/oneformer/configuration_oneformer.py +3 -3
  414. transformers/models/oneformer/modeling_oneformer.py +7 -38
  415. transformers/models/openai/modeling_openai.py +12 -0
  416. transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
  417. transformers/models/ovis2/modeling_ovis2.py +15 -3
  418. transformers/models/ovis2/modular_ovis2.py +8 -0
  419. transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
  420. transformers/models/owlv2/modeling_owlv2.py +7 -3
  421. transformers/models/owlv2/modular_owlv2.py +0 -2
  422. transformers/models/owlvit/modeling_owlvit.py +7 -3
  423. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +3 -2
  424. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +28 -14
  425. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +22 -12
  426. transformers/models/paligemma/modeling_paligemma.py +25 -17
  427. transformers/models/parakeet/modeling_parakeet.py +5 -0
  428. transformers/models/parakeet/modular_parakeet.py +5 -0
  429. transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
  430. transformers/models/patchtsmixer/modeling_patchtsmixer.py +4 -0
  431. transformers/models/patchtst/modeling_patchtst.py +5 -4
  432. transformers/models/pe_audio/__init__.py +30 -0
  433. transformers/models/pe_audio/configuration_pe_audio.py +206 -0
  434. transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
  435. transformers/models/pe_audio/modeling_pe_audio.py +820 -0
  436. transformers/models/pe_audio/modular_pe_audio.py +299 -0
  437. transformers/models/pe_audio/processing_pe_audio.py +24 -0
  438. transformers/models/pe_audio_video/__init__.py +29 -0
  439. transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
  440. transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
  441. transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
  442. transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
  443. transformers/models/pe_video/__init__.py +30 -0
  444. transformers/models/pe_video/configuration_pe_video.py +211 -0
  445. transformers/models/pe_video/modeling_pe_video.py +636 -0
  446. transformers/models/pe_video/modular_pe_video.py +219 -0
  447. transformers/models/pe_video/processing_pe_video.py +10 -0
  448. transformers/models/pe_video/video_processing_pe_video.py +66 -0
  449. transformers/models/pegasus/configuration_pegasus.py +1 -0
  450. transformers/models/pegasus/modeling_pegasus.py +3 -0
  451. transformers/models/pegasus_x/modeling_pegasus_x.py +1 -0
  452. transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
  453. transformers/models/perceiver/modeling_perceiver.py +5 -1
  454. transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
  455. transformers/models/perception_lm/modeling_perception_lm.py +7 -3
  456. transformers/models/perception_lm/modular_perception_lm.py +7 -3
  457. transformers/models/persimmon/modeling_persimmon.py +1 -1
  458. transformers/models/phi/modeling_phi.py +1 -1
  459. transformers/models/phi3/modeling_phi3.py +1 -1
  460. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +4 -1
  461. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +3 -0
  462. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
  463. transformers/models/phimoe/modeling_phimoe.py +12 -4
  464. transformers/models/phimoe/modular_phimoe.py +1 -1
  465. transformers/models/pix2struct/processing_pix2struct.py +0 -4
  466. transformers/models/pixio/__init__.py +30 -0
  467. transformers/models/pixio/configuration_pixio.py +151 -0
  468. transformers/models/pixio/modeling_pixio.py +507 -0
  469. transformers/models/pixio/modular_pixio.py +404 -0
  470. transformers/models/pixtral/modeling_pixtral.py +1 -1
  471. transformers/models/pixtral/processing_pixtral.py +3 -1
  472. transformers/models/plbart/configuration_plbart.py +1 -0
  473. transformers/models/plbart/modeling_plbart.py +7 -0
  474. transformers/models/plbart/modular_plbart.py +6 -0
  475. transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
  476. transformers/models/poolformer/modeling_poolformer.py +11 -1
  477. transformers/models/pop2piano/configuration_pop2piano.py +0 -1
  478. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
  479. transformers/models/prophetnet/modeling_prophetnet.py +2 -1
  480. transformers/models/qwen2/modeling_qwen2.py +1 -1
  481. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +104 -64
  482. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +58 -18
  483. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +18 -5
  484. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +26 -22
  485. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -2
  486. transformers/models/qwen2_moe/modeling_qwen2_moe.py +12 -4
  487. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
  488. transformers/models/qwen2_vl/modeling_qwen2_vl.py +17 -4
  489. transformers/models/qwen3/modeling_qwen3.py +1 -1
  490. transformers/models/qwen3_moe/modeling_qwen3_moe.py +12 -4
  491. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -6
  492. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
  493. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +92 -46
  494. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +48 -4
  495. transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
  496. transformers/models/qwen3_vl/modeling_qwen3_vl.py +17 -4
  497. transformers/models/qwen3_vl/modular_qwen3_vl.py +21 -10
  498. transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
  499. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +94 -112
  500. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +32 -81
  501. transformers/models/rag/configuration_rag.py +0 -8
  502. transformers/models/rag/modeling_rag.py +7 -9
  503. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +3 -2
  504. transformers/models/reformer/modeling_reformer.py +9 -1
  505. transformers/models/regnet/modeling_regnet.py +4 -0
  506. transformers/models/rembert/modeling_rembert.py +7 -1
  507. transformers/models/resnet/modeling_resnet.py +8 -3
  508. transformers/models/roberta/modeling_roberta.py +3 -0
  509. transformers/models/roberta/modular_roberta.py +3 -0
  510. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
  511. transformers/models/roc_bert/modeling_roc_bert.py +3 -0
  512. transformers/models/rt_detr/configuration_rt_detr.py +1 -1
  513. transformers/models/rt_detr/modeling_rt_detr.py +4 -0
  514. transformers/models/rt_detr/modeling_rt_detr_resnet.py +8 -3
  515. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
  516. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +7 -0
  517. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
  518. transformers/models/rwkv/modeling_rwkv.py +1 -1
  519. transformers/models/sam/configuration_sam.py +1 -0
  520. transformers/models/sam/image_processing_sam_fast.py +0 -1
  521. transformers/models/sam/modeling_sam.py +4 -1
  522. transformers/models/sam2/configuration_sam2.py +1 -1
  523. transformers/models/sam2/modeling_sam2.py +5 -1
  524. transformers/models/sam2/modular_sam2.py +5 -1
  525. transformers/models/sam2_video/modeling_sam2_video.py +51 -43
  526. transformers/models/sam2_video/modular_sam2_video.py +31 -18
  527. transformers/models/sam3/configuration_sam3.py +21 -1
  528. transformers/models/sam3/modeling_sam3.py +23 -0
  529. transformers/models/sam3_tracker/modeling_sam3_tracker.py +2 -0
  530. transformers/models/sam3_tracker/modular_sam3_tracker.py +2 -0
  531. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
  532. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +26 -15
  533. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
  534. transformers/models/sam3_video/configuration_sam3_video.py +14 -0
  535. transformers/models/sam3_video/modeling_sam3_video.py +3 -3
  536. transformers/models/sam3_video/processing_sam3_video.py +1 -1
  537. transformers/models/sam_hq/configuration_sam_hq.py +1 -0
  538. transformers/models/sam_hq/modeling_sam_hq.py +26 -23
  539. transformers/models/seamless_m4t/modeling_seamless_m4t.py +27 -11
  540. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +6 -0
  541. transformers/models/seed_oss/modeling_seed_oss.py +1 -1
  542. transformers/models/segformer/image_processing_segformer_fast.py +0 -1
  543. transformers/models/segformer/modeling_segformer.py +2 -2
  544. transformers/models/segformer/modular_segformer.py +0 -1
  545. transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
  546. transformers/models/siglip/modeling_siglip.py +24 -2
  547. transformers/models/siglip2/modeling_siglip2.py +63 -41
  548. transformers/models/smollm3/modeling_smollm3.py +1 -1
  549. transformers/models/smolvlm/modeling_smolvlm.py +5 -1
  550. transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
  551. transformers/models/speech_to_text/modeling_speech_to_text.py +10 -0
  552. transformers/models/speecht5/modeling_speecht5.py +28 -0
  553. transformers/models/splinter/modeling_splinter.py +9 -3
  554. transformers/models/squeezebert/modeling_squeezebert.py +2 -0
  555. transformers/models/stablelm/modeling_stablelm.py +1 -1
  556. transformers/models/starcoder2/modeling_starcoder2.py +1 -1
  557. transformers/models/superglue/image_processing_superglue_fast.py +1 -2
  558. transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
  559. transformers/models/swiftformer/modeling_swiftformer.py +4 -0
  560. transformers/models/swin/modeling_swin.py +16 -12
  561. transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
  562. transformers/models/swin2sr/modeling_swin2sr.py +49 -33
  563. transformers/models/swinv2/modeling_swinv2.py +41 -33
  564. transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
  565. transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
  566. transformers/models/t5/configuration_t5.py +7 -1
  567. transformers/models/t5/modeling_t5.py +1 -7
  568. transformers/models/t5gemma/modeling_t5gemma.py +1 -1
  569. transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
  570. transformers/models/t5gemma2/modeling_t5gemma2.py +13 -4
  571. transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
  572. transformers/models/table_transformer/configuration_table_transformer.py +1 -1
  573. transformers/models/table_transformer/modeling_table_transformer.py +1 -1
  574. transformers/models/textnet/image_processing_textnet_fast.py +0 -1
  575. transformers/models/timesfm/modeling_timesfm.py +12 -0
  576. transformers/models/timesfm/modular_timesfm.py +12 -0
  577. transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
  578. transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
  579. transformers/models/timm_wrapper/modeling_timm_wrapper.py +19 -13
  580. transformers/models/trocr/modeling_trocr.py +1 -2
  581. transformers/models/tvp/configuration_tvp.py +5 -1
  582. transformers/models/tvp/modeling_tvp.py +4 -4
  583. transformers/models/udop/configuration_udop.py +1 -0
  584. transformers/models/udop/modeling_udop.py +3 -7
  585. transformers/models/umt5/configuration_umt5.py +2 -2
  586. transformers/models/umt5/modeling_umt5.py +0 -6
  587. transformers/models/vaultgemma/modeling_vaultgemma.py +1 -1
  588. transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
  589. transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
  590. transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
  591. transformers/models/video_llava/modeling_video_llava.py +7 -3
  592. transformers/models/vilt/configuration_vilt.py +2 -2
  593. transformers/models/vilt/modeling_vilt.py +7 -0
  594. transformers/models/vipllava/modeling_vipllava.py +7 -3
  595. transformers/models/visual_bert/modeling_visual_bert.py +2 -0
  596. transformers/models/vitmatte/configuration_vitmatte.py +1 -1
  597. transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
  598. transformers/models/vitmatte/modeling_vitmatte.py +4 -0
  599. transformers/models/vitpose/configuration_vitpose.py +1 -1
  600. transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
  601. transformers/models/voxtral/modeling_voxtral.py +2 -2
  602. transformers/models/voxtral/modular_voxtral.py +2 -2
  603. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +16 -10
  604. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +7 -0
  605. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +21 -11
  606. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
  607. transformers/models/whisper/generation_whisper.py +1 -0
  608. transformers/models/whisper/modeling_whisper.py +5 -3
  609. transformers/models/x_clip/modeling_x_clip.py +2 -0
  610. transformers/models/xcodec/modeling_xcodec.py +5 -0
  611. transformers/models/xglm/modeling_xglm.py +10 -0
  612. transformers/models/xlm/modeling_xlm.py +13 -14
  613. transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
  614. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
  615. transformers/models/xlnet/modeling_xlnet.py +3 -1
  616. transformers/models/xmod/modeling_xmod.py +3 -0
  617. transformers/models/yoso/modeling_yoso.py +4 -1
  618. transformers/models/zamba/modeling_zamba.py +2 -1
  619. transformers/models/zamba2/modeling_zamba2.py +3 -2
  620. transformers/models/zoedepth/configuration_zoedepth.py +1 -1
  621. transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
  622. transformers/models/zoedepth/modeling_zoedepth.py +7 -0
  623. transformers/pipelines/__init__.py +9 -6
  624. transformers/pipelines/automatic_speech_recognition.py +20 -12
  625. transformers/pipelines/base.py +1 -1
  626. transformers/pipelines/document_question_answering.py +1 -1
  627. transformers/pipelines/question_answering.py +1 -1
  628. transformers/pipelines/text_to_audio.py +2 -2
  629. transformers/processing_utils.py +127 -56
  630. transformers/quantizers/auto.py +2 -4
  631. transformers/quantizers/base.py +9 -64
  632. transformers/quantizers/quantizer_aqlm.py +1 -18
  633. transformers/quantizers/quantizer_auto_round.py +1 -10
  634. transformers/quantizers/quantizer_awq.py +3 -8
  635. transformers/quantizers/quantizer_bitnet.py +1 -6
  636. transformers/quantizers/quantizer_bnb_4bit.py +9 -49
  637. transformers/quantizers/quantizer_bnb_8bit.py +9 -19
  638. transformers/quantizers/quantizer_compressed_tensors.py +1 -4
  639. transformers/quantizers/quantizer_eetq.py +2 -12
  640. transformers/quantizers/quantizer_fbgemm_fp8.py +5 -14
  641. transformers/quantizers/quantizer_finegrained_fp8.py +15 -10
  642. transformers/quantizers/quantizer_fp_quant.py +4 -4
  643. transformers/quantizers/quantizer_gptq.py +1 -4
  644. transformers/quantizers/quantizer_higgs.py +2 -6
  645. transformers/quantizers/quantizer_mxfp4.py +2 -28
  646. transformers/quantizers/quantizer_quanto.py +14 -14
  647. transformers/quantizers/quantizer_spqr.py +3 -8
  648. transformers/quantizers/quantizer_torchao.py +28 -124
  649. transformers/quantizers/quantizer_vptq.py +1 -10
  650. transformers/testing_utils.py +28 -12
  651. transformers/tokenization_mistral_common.py +3 -2
  652. transformers/tokenization_utils_base.py +3 -2
  653. transformers/tokenization_utils_tokenizers.py +25 -2
  654. transformers/trainer.py +24 -2
  655. transformers/trainer_callback.py +8 -0
  656. transformers/trainer_seq2seq.py +4 -0
  657. transformers/training_args.py +8 -10
  658. transformers/utils/__init__.py +4 -0
  659. transformers/utils/attention_visualizer.py +4 -4
  660. transformers/utils/auto_docstring.py +34 -25
  661. transformers/utils/generic.py +20 -0
  662. transformers/utils/import_utils.py +51 -9
  663. transformers/utils/kernel_config.py +71 -18
  664. transformers/utils/quantization_config.py +8 -8
  665. transformers/video_processing_utils.py +16 -12
  666. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +5 -6
  667. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +671 -632
  668. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +0 -0
  669. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
  670. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/licenses/LICENSE +0 -0
  671. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
transformers/models/canine/modeling_canine.py
@@ -23,6 +23,7 @@ import torch
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
+from ... import initialization as init
 from ...activations import ACT2FN
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
@@ -719,6 +720,11 @@ class CaninePreTrainedModel(PreTrainedModel):
     base_model_prefix = "canine"
     supports_gradient_checkpointing = True
 
+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, CanineEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
+
 
 @auto_docstring
 class CanineModel(CaninePreTrainedModel):
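The two canine hunks above are the template for a pattern repeated across this release (chinese_clip, clap, clip, and clipseg below): `position_ids`-style buffers are registered as non-persistent, never reach the checkpoint, and therefore have to be rebuilt inside `_init_weights`. A minimal sketch of the idea, assuming a toy module rather than the real `CanineEmbeddings`, and using plain `Tensor.copy_` in place of transformers' internal `init.copy_` helper:

```python
import torch
from torch import nn

class ToyEmbeddings(nn.Module):
    def __init__(self, max_positions: int = 512):
        super().__init__()
        # Non-persistent: excluded from state_dict(), so from_pretrained()
        # cannot restore it -- _init_weights has to recreate the values.
        self.register_buffer(
            "position_ids", torch.zeros(1, max_positions, dtype=torch.long), persistent=False
        )

def _init_weights(module):
    if isinstance(module, ToyEmbeddings):
        with torch.no_grad():
            # Same expression as the diff: arange over the last dim, broadcast to (1, -1).
            module.position_ids.copy_(
                torch.arange(module.position_ids.shape[-1]).expand((1, -1))
            )

emb = ToyEmbeddings()
_init_weights(emb)
assert emb.position_ids[0, :3].tolist() == [0, 1, 2]
```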
transformers/models/canine/tokenization_canine.py
@@ -67,6 +67,8 @@ class CanineTokenizer(PreTrainedTokenizer):
             The maximum sentence length the model accepts.
     """
 
+    model_input_names = ["input_ids", "attention_mask", "token_type_ids"]
+
     def __init__(
         self,
         bos_token=chr(CLS),
transformers/models/chameleon/modeling_chameleon.py
@@ -84,7 +84,7 @@ class ChameleonRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
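A short sketch of what this one-line change buys, assuming an illustrative module rather than the real `ChameleonRotaryEmbedding`: a bare tensor attribute is invisible to `Module.to()`, whereas a `persistent=False` buffer moves with the module yet stays out of `state_dict()`, so older checkpoints without the key still load:

```python
import torch
from torch import nn

class Rope(nn.Module):
    def __init__(self, dim: int = 8):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Before: `self.original_inv_freq = inv_freq` -- a plain attribute, left
        # behind by .to()/.cuda() and aliasing the live buffer. After: an
        # independent, device-tracked copy that is still absent from checkpoints.
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

m = Rope()
assert "original_inv_freq" not in m.state_dict()  # persistent=False: not serialized
m.inv_freq.mul_(2.0)                              # .clone() keeps the original intact
assert not torch.equal(m.inv_freq, m.original_inv_freq)
```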
@@ -809,6 +809,7 @@ class ChameleonVQVAE(ChameleonPreTrainedModel):
         self.quant_conv = torch.nn.Conv2d(config.latent_channels, config.embed_dim, 1)
         self.post_quant_conv = torch.nn.Conv2d(config.embed_dim, config.latent_channels, 1)
         self.eval()  # Chameleon's VQ model is frozen
+        self.post_init()
 
     def encode(self, pixel_values: torch.LongTensor):
         hidden_states = self.encoder(pixel_values)
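`ChameleonVQVAE` (and `CLIPSegDecoder` further down) previously never called `post_init()`, so the standard end-of-`__init__` hook was skipped. A toy sketch of the convention, assuming a simplified stand-in for `PreTrainedModel.post_init()` (the real method also handles concerns like weight tying):

```python
import torch
from torch import nn

class ToyPreTrainedModel(nn.Module):
    def post_init(self):
        # Simplified: the real PreTrainedModel.post_init() runs the model's
        # initialization scheme over the finished module tree.
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Conv2d):
            nn.init.normal_(module.weight, std=0.02)
            nn.init.zeros_(module.bias)

class ToyVQVAE(ToyPreTrainedModel):
    def __init__(self):
        super().__init__()
        self.quant_conv = nn.Conv2d(3, 8, 1)
        self.eval()       # frozen, like Chameleon's VQ model
        self.post_init()  # the call the hunk adds: without it, submodules keep
                          # PyTorch's default init instead of the model's scheme

model = ToyVQVAE()
assert float(model.quant_conv.bias.abs().sum()) == 0.0
```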
@@ -1122,6 +1123,7 @@ class ChameleonForConditionalGeneration(ChameleonPreTrainedModel, GenerationMixin):
         cache_position=None,
         position_ids=None,
         use_cache=True,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
@@ -1135,12 +1137,15 @@ class ChameleonForConditionalGeneration(ChameleonPreTrainedModel, GenerationMixin):
             cache_position=cache_position,
             position_ids=position_ids,
             use_cache=use_cache,
+            is_first_iteration=is_first_iteration,
             **kwargs,
         )
 
-        if cache_position[0] != 0:
-            # If we're in cached decoding stage, pixel values should be `None` because input ids do not contain special image token anymore
-            # Otherwise we need pixel values to be passed to model
+        if not is_first_iteration and use_cache:
+            # Pixel values are used only in the first iteration if available
+            # In subsequent iterations, they are already merged with text and cached
+            # NOTE: first iteration doesn't have to be prefill, it can be the first
+            # iteration with a question and cached system prompt (continue generation from cache)
             model_inputs["pixel_values"] = None
 
         return model_inputs
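Replacing `cache_position[0] != 0` with an explicit `is_first_iteration` flag matters because the first generation step is not necessarily a prefill from position 0: with a cached system prompt, `cache_position[0]` is already nonzero on the very first call, and the old test would wrongly drop `pixel_values`. A condensed sketch of the new predicate (the real method also builds `model_inputs` via `super().prepare_inputs_for_generation`):

```python
def should_drop_pixel_values(is_first_iteration: bool, use_cache: bool) -> bool:
    # Image features are consumed once, on the first iteration; afterwards they
    # are already merged with the text tokens and live in the KV cache.
    return not is_first_iteration and use_cache

# First call on top of a cached system prompt: cache_position[0] != 0, yet pixel
# values must still be forwarded -- exactly the case the old test mishandled.
assert not should_drop_pixel_values(is_first_iteration=True, use_cache=True)
assert should_drop_pixel_values(is_first_iteration=False, use_cache=True)
```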
transformers/models/chinese_clip/modeling_chinese_clip.py
@@ -572,10 +572,13 @@ class ChineseCLIPPreTrainedModel(PreTrainedModel):
             init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
             init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
             init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
+            init.copy_(module.position_ids, torch.arange(module.num_positions).expand((1, -1)))
         elif isinstance(module, ChineseCLIPTextEmbeddings):
             init.normal_(module.word_embeddings.weight, mean=0.0, std=self.config.initializer_range)
             init.normal_(module.position_embeddings.weight, mean=0.0, std=self.config.initializer_range)
             init.normal_(module.token_type_embeddings.weight, mean=0.0, std=self.config.initializer_range)
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
+            init.zeros_(module.token_type_ids)
             for embedding in [module.word_embeddings, module.position_embeddings, module.token_type_embeddings]:
                 if embedding.padding_idx is not None:
                     init.zeros_(embedding.weight[embedding.padding_idx])
@@ -638,9 +641,9 @@ class ChineseCLIPTextEncoder(nn.Module):
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
             layer_outputs = layer_module(
-                hidden_states=hidden_states,
-                attention_mask=attention_mask,
-                output_attentions=output_attentions,
+                hidden_states,
+                attention_mask,
+                output_attentions,
                 **kwargs,
             )
 
transformers/models/clap/feature_extraction_clap.py
@@ -71,7 +71,7 @@ class ClapFeatureExtractor(SequenceFeatureExtractor):
             Truncation pattern for long audio inputs. Two patterns are available:
                 - `fusion` will use `_random_mel_fusion`, which stacks 3 random crops from the mel spectrogram and a
                   downsampled version of the entire mel spectrogram.
-            If `config.fusion` is set to True, shorter audios also need to to return 4 mels, which will just be a copy
+            If `config.fusion` is set to True, shorter audios also need to return 4 mels, which will just be a copy
             of the original mel obtained from the padded audio.
                 - `rand_trunc` will select a random crop of the mel spectrogram.
         padding (`str`, *optional*, defaults to `"repeatpad"`):
@@ -279,7 +279,7 @@ class ClapFeatureExtractor(SequenceFeatureExtractor):
             Truncation pattern for long audio inputs. Two patterns are available:
                 - `fusion` will use `_random_mel_fusion`, which stacks 3 random crops from the mel spectrogram and
                   a downsampled version of the entire mel spectrogram.
-            If `config.fusion` is set to True, shorter audios also need to to return 4 mels, which will just be a
+            If `config.fusion` is set to True, shorter audios also need to return 4 mels, which will just be a
             copy of the original mel obtained from the padded audio.
                 - `rand_trunc` will select a random crop of the mel spectrogram.
         padding (`str`, *optional*):
transformers/models/clap/modeling_clap.py
@@ -365,18 +365,7 @@ class ClapAudioSelfAttention(nn.Module):
             torch.zeros((2 * self.window_size[0] - 1) * (2 * self.window_size[1] - 1), num_heads)
         )
 
-        # get pair-wise relative position index for each token inside the window
-        coords_h = torch.arange(self.window_size[0])
-        coords_w = torch.arange(self.window_size[1])
-        coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
-        coords_flatten = torch.flatten(coords, 1)
-        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
-        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
-        relative_coords[:, :, 0] += self.window_size[0] - 1
-        relative_coords[:, :, 1] += self.window_size[1] - 1
-        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
-        relative_position_index = relative_coords.sum(-1)
-        self.register_buffer("relative_position_index", relative_position_index)
+        self.register_buffer("relative_position_index", self.create_relative_position_index())
 
         self.query = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
         self.key = nn.Linear(self.all_head_size, self.all_head_size, bias=config.qkv_bias)
@@ -435,6 +424,20 @@ class ClapAudioSelfAttention(nn.Module):
 
         return outputs
 
+    def create_relative_position_index(self):
+        # get pair-wise relative position index for each token inside the window
+        coords_h = torch.arange(self.window_size[0])
+        coords_w = torch.arange(self.window_size[1])
+        coords = torch.stack(meshgrid([coords_h, coords_w], indexing="ij"))
+        coords_flatten = torch.flatten(coords, 1)
+        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
+        relative_coords = relative_coords.permute(1, 2, 0).contiguous()
+        relative_coords[:, :, 0] += self.window_size[0] - 1
+        relative_coords[:, :, 1] += self.window_size[1] - 1
+        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+        relative_position_index = relative_coords.sum(-1)
+        return relative_position_index
+
 
 # Copied from transformers.models.swin.modeling_swin.SwinSelfOutput with Swin->ClapAudio
 class ClapAudioSelfOutput(nn.Module):
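Since `create_relative_position_index` is now called both at construction time and again from `_init_weights` (see the `ClapPreTrainedModel` hunk below), it helps to see what it computes. A worked example for a 2x2 window, using `torch.meshgrid` directly in place of the module-level `meshgrid` shim the source imports:

```python
import torch

window_size = (2, 2)
coords_h = torch.arange(window_size[0])
coords_w = torch.arange(window_size[1])
coords = torch.stack(torch.meshgrid([coords_h, coords_w], indexing="ij"))
coords_flatten = torch.flatten(coords, 1)                    # (2, 4): one column per token
relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]
relative_coords = relative_coords.permute(1, 2, 0).contiguous()
relative_coords[:, :, 0] += window_size[0] - 1               # shift offsets to start at 0
relative_coords[:, :, 1] += window_size[1] - 1
relative_coords[:, :, 0] *= 2 * window_size[1] - 1           # flatten (dh, dw) into one id
relative_position_index = relative_coords.sum(-1)
print(relative_position_index)
# tensor([[4, 3, 1, 0],
#         [5, 4, 2, 1],
#         [7, 6, 4, 3],
#         [8, 7, 5, 4]])
```

Each entry indexes into `relative_position_bias_table`, giving every (query, key) pair inside the window the bias for its relative offset.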
@@ -1266,9 +1269,9 @@ class ClapTextEncoder(nn.Module):
                 all_hidden_states = all_hidden_states + (hidden_states,)
 
             layer_outputs = layer_module(
-                hidden_states=hidden_states,
-                attention_mask=attention_mask,
-                output_attentions=output_attentions,
+                hidden_states,
+                attention_mask,
+                output_attentions,
                 **kwargs,
             )
 
@@ -1317,6 +1320,8 @@ class ClapPreTrainedModel(PreTrainedModel):
         if isinstance(module, ClapTextEmbeddings):
             init.normal_(module.position_embeddings.weight, mean=0.0, std=factor * 0.02)
             init.normal_(module.token_type_embeddings.weight, mean=0.0, std=factor * 0.02)
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
+            init.zeros_(module.token_type_ids)
         elif isinstance(module, ClapModel):
             init.constant_(module.logit_scale_a, math.log(self.config.logit_scale_init_value))
             init.constant_(module.logit_scale_t, math.log(self.config.logit_scale_init_value))
@@ -1325,6 +1330,10 @@ class ClapPreTrainedModel(PreTrainedModel):
         elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d)):
             init.zeros_(module.bias)
             init.ones_(module.weight)
+            if getattr(module, "running_mean", None) is not None:
+                init.zeros_(module.running_mean)
+                init.ones_(module.running_var)
+                init.zeros_(module.num_batches_tracked)
         elif isinstance(module, (nn.Conv2d, nn.Linear)):
             in_proj_std = (self.config.hidden_size**-0.5) * ((2 * self.config.num_hidden_layers) ** -0.5) * factor
             init.normal_(module.weight, std=in_proj_std)
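The new branch covers a subtlety: `running_mean`, `running_var`, and `num_batches_tracked` are buffers, not parameters, so resetting `weight` and `bias` alone leaves stale statistics behind. The `getattr` guard keeps the branch safe for `nn.LayerNorm` (no running stats at all) and for `BatchNorm2d(track_running_stats=False)` (stats are `None`). A sketch in plain PyTorch, with `nn.init` standing in for transformers' `init` helpers:

```python
import torch
from torch import nn

def reset_norm(module):
    nn.init.zeros_(module.bias)
    nn.init.ones_(module.weight)
    # LayerNorm has no running stats, and BatchNorm2d(track_running_stats=False)
    # sets them to None -- the getattr guard covers both cases.
    if getattr(module, "running_mean", None) is not None:
        module.running_mean.zero_()
        module.running_var.fill_(1.0)
        module.num_batches_tracked.zero_()

bn = nn.BatchNorm2d(4)
bn.running_mean.fill_(3.0)  # pretend stale stats survived from a checkpoint
reset_norm(bn)
assert bn.running_mean.abs().sum() == 0
```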
@@ -1332,6 +1341,7 @@ class ClapPreTrainedModel(PreTrainedModel):
             init.zeros_(module.bias)
         elif isinstance(module, ClapAudioSelfAttention):
             init.zeros_(module.relative_position_bias_table)
+            init.copy_(module.relative_position_index, module.create_relative_position_index())
 
 
 class ClapAudioModel(ClapPreTrainedModel):
transformers/models/clip/modeling_clip.py
@@ -416,11 +416,13 @@ class CLIPPreTrainedModel(PreTrainedModel):
         if isinstance(module, CLIPTextEmbeddings):
             init.normal_(module.token_embedding.weight, mean=0.0, std=factor * 0.02)
             init.normal_(module.position_embedding.weight, mean=0.0, std=factor * 0.02)
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
         elif isinstance(module, CLIPVisionEmbeddings):
             factor = self.config.initializer_factor
             init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
             init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
             init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
+            init.copy_(module.position_ids, torch.arange(module.num_positions).expand((1, -1)))
         elif isinstance(module, CLIPAttention):
             factor = self.config.initializer_factor
             in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
transformers/models/clipseg/modeling_clipseg.py
@@ -435,11 +435,13 @@ class CLIPSegPreTrainedModel(PreTrainedModel):
         if isinstance(module, CLIPSegTextEmbeddings):
             init.normal_(module.token_embedding.weight, mean=0.0, std=factor * 0.02)
             init.normal_(module.position_embedding.weight, mean=0.0, std=factor * 0.02)
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
         elif isinstance(module, CLIPSegVisionEmbeddings):
             factor = self.config.initializer_factor
             init.normal_(module.class_embedding, mean=0.0, std=module.embed_dim**-0.5 * factor)
             init.normal_(module.patch_embedding.weight, std=module.config.initializer_range * factor)
             init.normal_(module.position_embedding.weight, std=module.config.initializer_range * factor)
+            init.copy_(module.position_ids, torch.arange(module.num_positions).expand((1, -1)))
         elif isinstance(module, CLIPSegAttention):
             factor = self.config.initializer_factor
             in_proj_std = (module.embed_dim**-0.5) * ((2 * module.config.num_hidden_layers) ** -0.5) * factor
@@ -1121,6 +1123,8 @@ class CLIPSegDecoder(CLIPSegPreTrainedModel):
         decoder_config.hidden_act = "relu"
         self.layers = nn.ModuleList([CLIPSegDecoderLayer(decoder_config) for _ in range(len(config.extract_layers))])
 
+        self.post_init()
+
     def forward(
         self,
         hidden_states: tuple[torch.Tensor],
@@ -238,7 +238,7 @@ class ClvpRMSNorm(nn.Module):
 class ClvpRotaryPositionalEmbedding(nn.Module):
     """
     Rotary Position Embedding Class for CLVP. It was proposed in the paper 'ROFORMER: ENHANCED TRANSFORMER WITH ROTARY
-    POSITION EMBEDDING', Please see https://huggingface.co/papers/2104.09864v1.pdf .
+    POSITION EMBEDDING', Please see https://huggingface.co/papers/2104.09864.
     """

     def __init__(self, config):
@@ -814,7 +814,16 @@ class ClvpPreTrainedModel(PreTrainedModel):
             )
         elif isinstance(module, ClvpModelForConditionalGeneration):
             init.constant_(module.logit_scale, self.config.logit_scale_init_value)
-
+        elif isinstance(module, ClvpSelfAttention):
+            if hasattr(module.config, "max_position_embeddings"):
+                max_positions = module.config.max_position_embeddings
+                bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))
+                bias = bias.view(1, 1, max_positions, max_positions)
+                init.copy_(module.bias, bias)
+        elif isinstance(module, ClvpRotaryPositionalEmbedding):
+            dim = max(self.config.projection_dim // (self.config.num_attention_heads * 2), 32)
+            inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
+            init.copy_(module.inv_freq, inv_freq)
         if isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
             init.zeros_(module.bias)
             init.ones_(module.weight)
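The new `ClvpSelfAttention` branch rebuilds the causal mask buffer, and the rotary branch recomputes the inverse frequencies, so neither depends on checkpoint contents. A self-contained sketch of both computations (function names are illustrative):

```python
import torch

def causal_mask(max_positions: int) -> torch.Tensor:
    # Lower-triangular boolean mask, broadcastable over (batch, heads).
    bias = torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool))
    return bias.view(1, 1, max_positions, max_positions)

def rotary_inv_freq(dim: int, base: float = 10000.0) -> torch.Tensor:
    # Standard RoPE frequencies: one per pair of channels.
    return 1.0 / (base ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))

print(causal_mask(4)[0, 0].int())  # 4x4 lower triangle of ones
print(rotary_inv_freq(8))          # 4 frequencies for an 8-dim rotary head
```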
@@ -1309,6 +1318,7 @@ class ClvpForCausalLM(ClvpPreTrainedModel, GenerationMixin):
         inputs_embeds=None,
         conditioning_embeds=None,
         cache_position=None,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten: has `conditioning_embeds`-related logic
@@ -1320,9 +1330,10 @@ class ClvpForCausalLM(ClvpPreTrainedModel, GenerationMixin):
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
             cache_position=cache_position,
+            is_first_iteration=is_first_iteration,
             **kwargs,
         )
-        if conditioning_embeds is not None and cache_position[0] != 0:
+        if conditioning_embeds is not None and not is_first_iteration:
             model_inputs["position_ids"] = torch.tensor([input_ids_length], dtype=torch.long, device=input_ids.device)

         return model_inputs
@@ -158,7 +158,7 @@ class CodeLlamaTokenizer(TokenizersBackend):
                 unk_token=str(unk_token),
             )
         )
-        prepend_scheme = "first" if self.add_prefix_space else "none"
+        prepend_scheme = "first" if self.add_prefix_space else "never"
         self._tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(
             replacement="▁", prepend_scheme=prepend_scheme, split=False
        )
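This hunk fixes an invalid enum value: `Metaspace` in the `tokenizers` library accepts `prepend_scheme` values of `"always"`, `"never"`, and `"first"`, not `"none"`. A quick demonstration (the exact offsets are indicative):

```python
from tokenizers import pre_tokenizers

# "first" prepends the replacement only before the first word; "never" disables it.
pre_tok = pre_tokenizers.Metaspace(replacement="▁", prepend_scheme="never", split=False)
print(pre_tok.pre_tokenize_str("def foo(): pass"))
# roughly: [('def▁foo():▁pass', (0, 15))] -- split=False keeps one token, spaces become ▁
```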
@@ -14,11 +14,13 @@
 # limitations under the License.
 """PyTorch CodeGen model."""

+import math
 from typing import Optional, Union

 import torch
 from torch import nn

+from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
@@ -69,7 +71,7 @@ class CodeGenAttention(nn.Module):
     def __init__(self, config, layer_idx=None):
         super().__init__()

-        max_positions = config.max_position_embeddings
+        self.max_positions = config.max_position_embeddings
         self.attn_dropout = nn.Dropout(config.attn_pdrop)
         self.resid_dropout = nn.Dropout(config.resid_pdrop)
         self.layer_idx = layer_idx
@@ -88,13 +90,15 @@ class CodeGenAttention(nn.Module):
                 f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and"
                 f" `num_attention_heads`: {self.num_attention_heads})."
             )
-        self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32)).to(torch.get_default_dtype())
+        self.scale_attn = math.sqrt(self.head_dim)
         self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim * 3, bias=False)

         self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
         self.rotary_dim = config.rotary_dim
-        pos_embd_dim = self.rotary_dim or self.embed_dim
-        self.embed_positions = create_sinusoidal_positions(max_positions, pos_embd_dim)
+        self.pos_embd_dim = self.rotary_dim or self.embed_dim
+        self.register_buffer(
+            "embed_positions", create_sinusoidal_positions(self.max_positions, self.pos_embd_dim), persistent=False
+        )

     def _split_heads(self, x, n_head, dim_head, mp_num):
         reshaped = x.reshape(x.shape[:-1] + (n_head // mp_num, dim_head))
@@ -279,6 +283,11 @@ class CodeGenPreTrainedModel(PreTrainedModel):
     _skip_keys_device_placement = "past_key_values"
     _can_compile_fullgraph = True

+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, CodeGenAttention):
+            init.copy_(module.embed_positions, create_sinusoidal_positions(module.max_positions, module.pos_embd_dim))
+

 @auto_docstring
 class CodeGenModel(CodeGenPreTrainedModel):
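Taken together, the CodeGen hunks promote `embed_positions` from a plain attribute to a non-persistent buffer and re-materialize it in `_init_weights`, so it follows `.to()`/`.cuda()` moves without ever bloating checkpoints. A minimal sketch of the pattern, with a simplified sinusoidal helper standing in for `create_sinusoidal_positions`:

```python
import torch
from torch import nn

def sinusoidal_positions(num_pos: int, dim: int) -> torch.Tensor:
    # dim must be even: half sin channels, half cos channels.
    inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
    angles = torch.einsum("i,j->ij", torch.arange(num_pos).float(), inv_freq)
    return torch.cat([torch.sin(angles), torch.cos(angles)], dim=-1)

class ToyAttention(nn.Module):
    def __init__(self, max_positions: int = 128, dim: int = 32):
        super().__init__()
        self.max_positions, self.pos_embd_dim = max_positions, dim
        # persistent=False: moved with the module, excluded from state_dict.
        self.register_buffer(
            "embed_positions", sinusoidal_positions(max_positions, dim), persistent=False
        )

m = ToyAttention()
assert "embed_positions" not in m.state_dict()  # not serialized
```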
@@ -83,7 +83,7 @@ class CohereRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     def compute_default_rope_parameters(
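Storing `original_inv_freq` as a plain attribute left it aliasing the very tensor registered as the `inv_freq` buffer, so any in-place update to one silently changed the other, and the attribute also didn't follow device moves. Registering a `.clone()` as its own non-persistent buffer fixes both; a sketch:

```python
import torch
from torch import nn

class Rope(nn.Module):
    def __init__(self):
        super().__init__()
        inv_freq = torch.tensor([1.0, 0.1, 0.01])
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Clone: later in-place rescaling of inv_freq must not corrupt the original.
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

m = Rope()
m.inv_freq.mul_(2.0)        # e.g. dynamic RoPE rescaling at long context
print(m.original_inv_freq)  # still tensor([1.0000, 0.1000, 0.0100])
```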
@@ -57,7 +57,7 @@ class Cohere2RotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     def compute_default_rope_parameters(
@@ -263,7 +263,6 @@ class Cohere2VisionImageProcessorFast(BaseImageProcessorFast):
             processed_images_grouped[shape] = stacked_images

         processed_images = reorder_images(processed_images_grouped, grouped_images_index)
-        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images

         return BatchFeature(
             data={"pixel_values": processed_images, "num_patches": num_patches}, tensor_type=return_tensors
@@ -376,6 +376,7 @@ class Cohere2VisionForConditionalGeneration(Cohere2VisionPreTrainedModel, Genera
         attention_mask=None,
         cache_position=None,
         logits_to_keep=None,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
@@ -387,12 +388,15 @@ class Cohere2VisionForConditionalGeneration(Cohere2VisionPreTrainedModel, Genera
             attention_mask=attention_mask,
             cache_position=cache_position,
             logits_to_keep=logits_to_keep,
+            is_first_iteration=is_first_iteration,
             **kwargs,
         )

-        if cache_position[0] == 0:
-            # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
-            # Otherwise we need pixel values to be passed to model
+        if is_first_iteration or not kwargs.get("use_cache", True):
+            # Pixel values are used only in the first iteration if available.
+            # In subsequent iterations, they are already merged with text and cached.
+            # NOTE: the first iteration doesn't have to be prefill; it can be the first
+            # iteration with a question and a cached system prompt (continuing generation from cache).
             model_inputs["pixel_values"] = pixel_values

         return model_inputs
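The switch from `cache_position[0] == 0` to an explicit `is_first_iteration` flag decouples "first call of this generate loop" from "cache is empty", which matters when generation resumes from a pre-filled cache (e.g. a cached system prompt). A schematic sketch of the gating logic with toy names, not the library API:

```python
def prepare_inputs(pixel_values, is_first_iteration: bool, use_cache: bool = True) -> dict:
    model_inputs = {}
    # Vision features are merged into the cache on the first pass; afterwards
    # only new text tokens are fed, so pixel_values must be dropped.
    if is_first_iteration or not use_cache:
        model_inputs["pixel_values"] = pixel_values
    return model_inputs

assert "pixel_values" in prepare_inputs("img", is_first_iteration=True)
assert "pixel_values" not in prepare_inputs("img", is_first_iteration=False)
```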
@@ -37,7 +37,7 @@ class ConditionalDetrConfig(PreTrainedConfig):
     use_timm_backbone (`bool`, *optional*, defaults to `True`):
         Whether or not to use the `timm` library for the backbone. If set to `False`, will use the [`AutoBackbone`]
         API.
-    backbone_config (`PreTrainedConfig` or `dict`, *optional*):
+    backbone_config (`Union[dict, "PreTrainedConfig"]`, *optional*, defaults to `ResNetConfig()`):
         The configuration of the backbone model. Only used in case `use_timm_backbone` is set to `False` in which
         case it will default to `ResNetConfig()`.
     num_channels (`int`, *optional*, defaults to 3):
@@ -984,7 +984,7 @@ class ConditionalDetrPreTrainedModel(PreTrainedModel):
         elif isinstance(module, ConditionalDetrLearnedPositionEmbedding):
             init.uniform_(module.row_embeddings.weight)
             init.uniform_(module.column_embeddings.weight)
-        if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
+        if isinstance(module, (nn.Linear, nn.Conv2d)):
             init.normal_(module.weight, mean=0.0, std=std)
             if module.bias is not None:
                 init.zeros_(module.bias)
@@ -993,6 +993,9 @@ class ConditionalDetrPreTrainedModel(PreTrainedModel):
             # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it loses the flag
             if module.padding_idx is not None and not getattr(module.weight, "_is_hf_initialized", False):
                 init.zeros_(module.weight[module.padding_idx])
+        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
+            init.ones_(module.weight)
+            init.zeros_(module.bias)


 # Copied from transformers.models.detr.modeling_detr.DetrEncoder with Detr->ConditionalDetr,DETR->ConditionalDETR
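Dropping `nn.BatchNorm2d` from the `normal_` branch is a correctness fix: normalization layers want identity initialization (weight 1, bias 0), and Gaussian scales drawn near zero would crush activations at the start of training. The added LayerNorm/GroupNorm branch applies exactly that; a small comparison:

```python
import torch
from torch import nn

x = torch.randn(2, 4)
ln = nn.LayerNorm(4)

# Wrong for norm layers: scales drawn around 0 nearly zero out the output.
nn.init.normal_(ln.weight, mean=0.0, std=0.02)
print(ln(x).abs().mean())  # tiny values

# Identity init, as in the new branch: output keeps a sensible scale.
nn.init.ones_(ln.weight)
nn.init.zeros_(ln.bias)
print(ln(x).abs().mean())  # ~1, normalized but unscaled
```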
@@ -118,6 +118,9 @@ class ConvBertPreTrainedModel(PreTrainedModel):
         elif isinstance(module, GroupedLinearLayer):
             init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
             init.zeros_(module.bias)
+        elif isinstance(module, ConvBertEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
+            init.zeros_(module.token_type_ids)


 class SeparableConv1D(nn.Module):
@@ -78,7 +78,7 @@ class ConvNextImageProcessor(BaseImageProcessor):
     crop_pct (`float` *optional*, defaults to 224 / 256):
         Percentage of the image to crop. Only has an effect if `do_resize` is `True` and size < 384. Can be
         overridden by `crop_pct` in the `preprocess` method.
-    resample (`PILImageResampling`, *optional*, defaults to `Resampling.BILINEAR`):
+    resample (`PILImageResampling`, *optional*, defaults to `Resampling.BICUBIC`):
         Resampling filter to use if resizing the image. Can be overridden by `resample` in the `preprocess` method.
     do_rescale (`bool`, *optional*, defaults to `True`):
         Whether to rescale the image by the specified scale `rescale_factor`. Can be overridden by `do_rescale` in
@@ -105,7 +105,7 @@ class ConvNextImageProcessor(BaseImageProcessor):
         do_resize: bool = True,
         size: Optional[dict[str, int]] = None,
         crop_pct: Optional[float] = None,
-        resample: PILImageResampling = PILImageResampling.BILINEAR,
+        resample: PILImageResampling = PILImageResampling.BICUBIC,
         do_rescale: bool = True,
         rescale_factor: Union[int, float] = 1 / 255,
         do_normalize: bool = True,
@@ -20,11 +20,7 @@ import torch
 from torchvision.transforms.v2 import functional as F

 from ...image_processing_utils import BatchFeature
-from ...image_processing_utils_fast import (
-    BaseImageProcessorFast,
-    group_images_by_shape,
-    reorder_images,
-)
+from ...image_processing_utils_fast import BaseImageProcessorFast, group_images_by_shape, reorder_images
 from ...image_transforms import get_resize_output_image_size
 from ...image_utils import (
     IMAGENET_STANDARD_MEAN,
@@ -32,6 +28,7 @@ from ...image_utils import (
     ChannelDimension,
     ImageInput,
     PILImageResampling,
+    SizeDict,
 )
 from ...processing_utils import Unpack
 from ...utils import (
@@ -43,7 +40,7 @@ from .image_processing_convnext import ConvNextImageProcessorKwargs

 @auto_docstring
 class ConvNextImageProcessorFast(BaseImageProcessorFast):
-    resample = PILImageResampling.BILINEAR
+    resample = PILImageResampling.BICUBIC
     image_mean = IMAGENET_STANDARD_MEAN
     image_std = IMAGENET_STANDARD_STD
     size = {"shortest_edge": 384}
@@ -98,23 +95,23 @@ class ConvNextImageProcessorFast(BaseImageProcessorFast):
             resize_size = get_resize_output_image_size(
                 image, size=resize_shortest_edge, default_to_square=False, input_data_format=ChannelDimension.FIRST
             )
-            image = F.resize(
+            image = super().resize(
                 image,
-                resize_size,
+                SizeDict(height=resize_size[0], width=resize_size[1]),
                 interpolation=interpolation,
                 **kwargs,
             )
             # then crop to (shortest_edge, shortest_edge)
-            return F.center_crop(
+            return self.center_crop(
                 image,
-                (shortest_edge, shortest_edge),
+                SizeDict(height=shortest_edge, width=shortest_edge),
                 **kwargs,
             )
         else:
             # warping (no cropping) when evaluated at 384 or larger
-            return F.resize(
+            return super().resize(
                 image,
-                (shortest_edge, shortest_edge),
+                SizeDict(height=shortest_edge, width=shortest_edge),
                 interpolation=interpolation,
                 **kwargs,
             )
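Routing through `super().resize`/`self.center_crop` with `SizeDict` keeps all geometry changes on the shared fast-processor code path instead of calling torchvision's functional API directly. A sketch of the `SizeDict` shape, assuming the `transformers.image_utils.SizeDict` dataclass this diff imports (values are illustrative):

```python
from transformers.image_utils import SizeDict

# SizeDict is a lightweight dataclass the fast processors use in place of raw
# (height, width) tuples; unset fields stay None.
size = SizeDict(height=224, width=224)
print(size.height, size.width)  # 224 224

# In the diff, a plain tuple such as resize_size = (256, 256) becomes:
resize_size = (256, 256)
size = SizeDict(height=resize_size[0], width=resize_size[1])
```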
@@ -162,7 +159,6 @@ class ConvNextImageProcessorFast(BaseImageProcessorFast):
             processed_images_grouped[shape] = stacked_images

         processed_images = reorder_images(processed_images_grouped, grouped_images_index)
-        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images

         return BatchFeature(data={"pixel_values": processed_images}, tensor_type=return_tensors)

@@ -89,7 +89,7 @@ class CsmGenerationMixin(GenerationMixin):
         return kept_criteria

     def _prepare_generation_config(
-        self, generation_config: Optional[GenerationConfig], use_model_defaults: Optional[bool] = None, **kwargs: Any
+        self, generation_config: Optional[GenerationConfig], **kwargs: Any
     ) -> tuple[GenerationConfig, dict]:
         """
         This method overrides [~generation.utils.GenerationMixin._prepare_generation_config].
@@ -104,9 +104,7 @@ class CsmGenerationMixin(GenerationMixin):
         kwargs = {k: v for k, v in kwargs.items() if not k.startswith("depth_decoder_")}

         # initialize the generation config
-        generation_config, model_kwargs = super()._prepare_generation_config(
-            generation_config, use_model_defaults, **kwargs
-        )
+        generation_config, model_kwargs = super()._prepare_generation_config(generation_config, **kwargs)
         self.depth_decoder.generation_config.update(**depth_decoder_kwargs)

         # ensure the depth decoder generation config is valid
@@ -209,26 +207,25 @@ class CsmGenerationMixin(GenerationMixin):
             else self.__call__
         )

-        is_prefill = True
-        while self._has_unfinished_sequences(
-            this_peer_finished,
-            synced_gpus,
-            device=input_ids.device,
-        ):
-            # prepare model inputs
-            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
-
-            # prepare variable output controls (note: some models won't accept all output controls)
-            model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
-            # *************** Csm specific ***************
-            model_inputs.update({"output_hidden_states": True})
-            # ============================================
+        # *************** Csm specific ***************
+        model_kwargs.update({"output_hidden_states": True})

-            if is_prefill:
-                outputs = self(**model_inputs, return_dict=True)
-                is_prefill = False
-            else:
+        # Assisted generation completes the prefill stage in the candidate generator so that
+        # we don't have several `prefill` calls in one generation loop. Skip `_prefill` for assistants.
+        if not generation_config.is_assistant:
+            outputs = self._prefill(input_ids, generation_config, model_kwargs)
+            prefill_consumed = False
+        else:
+            model_kwargs = self._get_initial_cache_position(input_ids.shape[1], input_ids.device, model_kwargs)
+            prefill_consumed = True
+
+        while self._has_unfinished_sequences(this_peer_finished, synced_gpus, device=input_ids.device):
+            if prefill_consumed:
+                model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+                # prepare variable output controls (note: some models won't accept all output controls)
+                model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
                 outputs = model_forward(**model_inputs, return_dict=True)
+            prefill_consumed = True

             # synced_gpus: don't waste resources running the code we don't need; kwargs must be updated before skipping
             model_kwargs = self._update_model_kwargs_for_generation(
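The refactor hoists prefill out of the decoding loop: the full forward pass over the prompt runs once via `_prefill`, and `prefill_consumed` tells the loop body whether it still holds unconsumed prefill outputs or must run a fresh forward. A toy rendering of the control flow (names are illustrative, not the library API):

```python
def generate(steps: int, assistant: bool = False) -> list[str]:
    trace = []
    if not assistant:
        trace.append("prefill")      # one-off full forward over the prompt
        prefill_consumed = False
    else:
        prefill_consumed = True      # assistants skip _prefill entirely

    for _ in range(steps):
        if prefill_consumed:
            trace.append("decode")   # regular single-step forward
        prefill_consumed = True      # first loop turn reuses the prefill outputs
    return trace

assert generate(3) == ["prefill", "decode", "decode"]
assert generate(3, assistant=True) == ["decode", "decode", "decode"]
```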
@@ -136,7 +136,7 @@ class CsmRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

     @staticmethod
     def compute_default_rope_parameters(
@@ -421,6 +421,8 @@ class CsmPreTrainedModel(PreTrainedModel):
             num_codebooks = module.num_codebooks
             for i in range(num_codebooks - 1):
                 init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
+        elif isinstance(module, CsmBackboneModelEmbeddings):
+            init.copy_(module.audio_tokens_offsets, torch.arange(self.config.num_codebooks) * self.config.vocab_size)


 @auto_docstring
@@ -149,6 +149,8 @@ class CsmPreTrainedModel(PreTrainedModel):
             num_codebooks = module.num_codebooks
             for i in range(num_codebooks - 1):
                 init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
+        elif isinstance(module, CsmBackboneModelEmbeddings):
+            init.copy_(module.audio_tokens_offsets, torch.arange(self.config.num_codebooks) * self.config.vocab_size)


 @auto_docstring
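Both copies of this hunk (the generated and modular Csm files) reinitialize `audio_tokens_offsets`, the buffer that shifts each codebook's token ids into its own slice of a shared flat embedding table. The arithmetic is simply:

```python
import torch

num_codebooks, vocab_size = 4, 10
offsets = torch.arange(num_codebooks) * vocab_size
print(offsets)           # tensor([ 0, 10, 20, 30])

# Token 7 in codebook 2 lands at row 27 of the flat embedding table.
tokens = torch.tensor([7, 7, 7, 7])
print(tokens + offsets)  # tensor([ 7, 17, 27, 37])
```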