transformers 5.0.0rc1__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (671)
  1. transformers/__init__.py +20 -1
  2. transformers/activations.py +1 -1
  3. transformers/audio_utils.py +0 -1
  4. transformers/cache_utils.py +17 -15
  5. transformers/configuration_utils.py +114 -70
  6. transformers/conversion_mapping.py +68 -5
  7. transformers/core_model_loading.py +201 -35
  8. transformers/dependency_versions_table.py +1 -1
  9. transformers/feature_extraction_utils.py +54 -22
  10. transformers/generation/candidate_generator.py +79 -31
  11. transformers/generation/configuration_utils.py +162 -122
  12. transformers/generation/continuous_batching/cache.py +47 -18
  13. transformers/generation/continuous_batching/cache_manager.py +131 -34
  14. transformers/generation/continuous_batching/continuous_api.py +101 -64
  15. transformers/generation/continuous_batching/requests.py +28 -1
  16. transformers/generation/continuous_batching/scheduler.py +11 -4
  17. transformers/generation/stopping_criteria.py +1 -1
  18. transformers/generation/utils.py +108 -110
  19. transformers/generation/watermarking.py +8 -5
  20. transformers/image_processing_base.py +2 -12
  21. transformers/image_processing_utils_fast.py +15 -4
  22. transformers/initialization.py +37 -0
  23. transformers/integrations/__init__.py +12 -0
  24. transformers/integrations/accelerate.py +44 -111
  25. transformers/integrations/aqlm.py +3 -5
  26. transformers/integrations/awq.py +2 -5
  27. transformers/integrations/bitnet.py +5 -8
  28. transformers/integrations/bitsandbytes.py +16 -15
  29. transformers/integrations/deepspeed.py +18 -3
  30. transformers/integrations/eetq.py +3 -5
  31. transformers/integrations/fbgemm_fp8.py +1 -1
  32. transformers/integrations/finegrained_fp8.py +6 -16
  33. transformers/integrations/flash_attention.py +2 -2
  34. transformers/integrations/higgs.py +2 -5
  35. transformers/integrations/hub_kernels.py +23 -5
  36. transformers/integrations/integration_utils.py +35 -0
  37. transformers/integrations/mistral.py +12 -0
  38. transformers/integrations/moe.py +240 -0
  39. transformers/integrations/mxfp4.py +4 -10
  40. transformers/integrations/peft.py +5 -0
  41. transformers/integrations/quanto.py +5 -2
  42. transformers/integrations/spqr.py +3 -5
  43. transformers/integrations/tensor_parallel.py +167 -221
  44. transformers/integrations/vptq.py +3 -5
  45. transformers/modeling_gguf_pytorch_utils.py +66 -19
  46. transformers/modeling_rope_utils.py +78 -81
  47. transformers/modeling_utils.py +583 -503
  48. transformers/models/__init__.py +19 -0
  49. transformers/models/afmoe/modeling_afmoe.py +7 -16
  50. transformers/models/afmoe/modular_afmoe.py +5 -13
  51. transformers/models/aimv2/modeling_aimv2.py +4 -0
  52. transformers/models/aimv2/modular_aimv2.py +4 -0
  53. transformers/models/albert/modeling_albert.py +3 -0
  54. transformers/models/align/modeling_align.py +12 -6
  55. transformers/models/altclip/modeling_altclip.py +7 -3
  56. transformers/models/apertus/modeling_apertus.py +4 -2
  57. transformers/models/apertus/modular_apertus.py +4 -1
  58. transformers/models/arcee/modeling_arcee.py +1 -1
  59. transformers/models/aria/modeling_aria.py +8 -4
  60. transformers/models/aria/modular_aria.py +7 -3
  61. transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
  62. transformers/models/auto/auto_factory.py +1 -1
  63. transformers/models/auto/configuration_auto.py +27 -0
  64. transformers/models/auto/feature_extraction_auto.py +7 -3
  65. transformers/models/auto/image_processing_auto.py +4 -2
  66. transformers/models/auto/modeling_auto.py +31 -0
  67. transformers/models/auto/processing_auto.py +4 -0
  68. transformers/models/auto/tokenization_auto.py +132 -153
  69. transformers/models/auto/video_processing_auto.py +5 -2
  70. transformers/models/aya_vision/modeling_aya_vision.py +7 -3
  71. transformers/models/bamba/modeling_bamba.py +18 -19
  72. transformers/models/bamba/modular_bamba.py +17 -16
  73. transformers/models/bark/modeling_bark.py +9 -0
  74. transformers/models/bart/configuration_bart.py +0 -1
  75. transformers/models/bart/modeling_bart.py +7 -0
  76. transformers/models/beit/image_processing_beit_fast.py +0 -1
  77. transformers/models/bert/modeling_bert.py +3 -0
  78. transformers/models/bert_generation/modeling_bert_generation.py +2 -0
  79. transformers/models/big_bird/modeling_big_bird.py +3 -0
  80. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +7 -0
  81. transformers/models/bit/modeling_bit.py +5 -1
  82. transformers/models/bitnet/modeling_bitnet.py +1 -1
  83. transformers/models/blenderbot/modeling_blenderbot.py +7 -0
  84. transformers/models/blenderbot/tokenization_blenderbot.py +6 -7
  85. transformers/models/blenderbot_small/modeling_blenderbot_small.py +7 -0
  86. transformers/models/blip/modeling_blip.py +2 -0
  87. transformers/models/blip/modeling_blip_text.py +8 -0
  88. transformers/models/blip_2/modeling_blip_2.py +2 -0
  89. transformers/models/bloom/modeling_bloom.py +13 -44
  90. transformers/models/blt/modeling_blt.py +162 -2
  91. transformers/models/blt/modular_blt.py +168 -3
  92. transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
  93. transformers/models/bridgetower/modeling_bridgetower.py +6 -0
  94. transformers/models/bros/modeling_bros.py +8 -0
  95. transformers/models/camembert/modeling_camembert.py +109 -106
  96. transformers/models/canine/modeling_canine.py +6 -0
  97. transformers/models/canine/tokenization_canine.py +2 -0
  98. transformers/models/chameleon/modeling_chameleon.py +9 -4
  99. transformers/models/chinese_clip/modeling_chinese_clip.py +6 -3
  100. transformers/models/clap/feature_extraction_clap.py +2 -2
  101. transformers/models/clap/modeling_clap.py +25 -15
  102. transformers/models/clip/modeling_clip.py +2 -0
  103. transformers/models/clipseg/modeling_clipseg.py +4 -0
  104. transformers/models/clvp/modeling_clvp.py +14 -3
  105. transformers/models/code_llama/tokenization_code_llama.py +1 -1
  106. transformers/models/codegen/modeling_codegen.py +13 -4
  107. transformers/models/cohere/modeling_cohere.py +1 -1
  108. transformers/models/cohere2/modeling_cohere2.py +1 -1
  109. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +0 -1
  110. transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
  111. transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
  112. transformers/models/conditional_detr/modeling_conditional_detr.py +4 -1
  113. transformers/models/convbert/modeling_convbert.py +3 -0
  114. transformers/models/convnext/image_processing_convnext.py +2 -2
  115. transformers/models/convnext/image_processing_convnext_fast.py +9 -13
  116. transformers/models/csm/generation_csm.py +19 -22
  117. transformers/models/csm/modeling_csm.py +3 -1
  118. transformers/models/csm/modular_csm.py +2 -0
  119. transformers/models/ctrl/modeling_ctrl.py +14 -2
  120. transformers/models/cvt/modeling_cvt.py +5 -1
  121. transformers/models/cwm/modeling_cwm.py +1 -1
  122. transformers/models/d_fine/configuration_d_fine.py +3 -4
  123. transformers/models/d_fine/modeling_d_fine.py +46 -39
  124. transformers/models/d_fine/modular_d_fine.py +15 -4
  125. transformers/models/dab_detr/configuration_dab_detr.py +2 -2
  126. transformers/models/dab_detr/modeling_dab_detr.py +1 -1
  127. transformers/models/dac/modeling_dac.py +4 -4
  128. transformers/models/data2vec/modeling_data2vec_text.py +7 -0
  129. transformers/models/data2vec/modular_data2vec_text.py +7 -0
  130. transformers/models/dbrx/configuration_dbrx.py +9 -1
  131. transformers/models/dbrx/modeling_dbrx.py +1 -1
  132. transformers/models/deberta/modeling_deberta.py +2 -0
  133. transformers/models/deberta_v2/modeling_deberta_v2.py +2 -0
  134. transformers/models/decision_transformer/modeling_decision_transformer.py +8 -5
  135. transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -4
  136. transformers/models/deepseek_v2/modular_deepseek_v2.py +4 -2
  137. transformers/models/deepseek_v3/modeling_deepseek_v3.py +9 -5
  138. transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -2
  139. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
  140. transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
  141. transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
  142. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
  143. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
  144. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
  145. transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
  146. transformers/models/deformable_detr/modeling_deformable_detr.py +1 -1
  147. transformers/models/depth_anything/configuration_depth_anything.py +2 -3
  148. transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
  149. transformers/models/detr/configuration_detr.py +1 -1
  150. transformers/models/detr/modeling_detr.py +8 -1
  151. transformers/models/dia/generation_dia.py +3 -10
  152. transformers/models/dia/modeling_dia.py +12 -1
  153. transformers/models/dia/modular_dia.py +11 -0
  154. transformers/models/dia/processing_dia.py +1 -1
  155. transformers/models/diffllama/modeling_diffllama.py +3 -3
  156. transformers/models/diffllama/modular_diffllama.py +2 -2
  157. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
  158. transformers/models/dinov3_vit/modeling_dinov3_vit.py +3 -0
  159. transformers/models/dinov3_vit/modular_dinov3_vit.py +3 -0
  160. transformers/models/distilbert/modeling_distilbert.py +11 -9
  161. transformers/models/doge/modeling_doge.py +1 -1
  162. transformers/models/donut/image_processing_donut_fast.py +0 -1
  163. transformers/models/donut/modeling_donut_swin.py +16 -12
  164. transformers/models/dots1/modeling_dots1.py +14 -5
  165. transformers/models/dpt/configuration_dpt.py +1 -1
  166. transformers/models/dpt/image_processing_dpt_fast.py +1 -2
  167. transformers/models/dpt/modular_dpt.py +1 -2
  168. transformers/models/edgetam/configuration_edgetam.py +1 -1
  169. transformers/models/edgetam/modeling_edgetam.py +5 -2
  170. transformers/models/edgetam/modular_edgetam.py +15 -14
  171. transformers/models/edgetam_video/modeling_edgetam_video.py +55 -43
  172. transformers/models/edgetam_video/modular_edgetam_video.py +13 -19
  173. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
  174. transformers/models/efficientloftr/modeling_efficientloftr.py +14 -1
  175. transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
  176. transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
  177. transformers/models/efficientnet/modeling_efficientnet.py +5 -1
  178. transformers/models/electra/modeling_electra.py +7 -0
  179. transformers/models/emu3/modeling_emu3.py +8 -2
  180. transformers/models/emu3/modular_emu3.py +7 -1
  181. transformers/models/encodec/modeling_encodec.py +14 -0
  182. transformers/models/eomt/image_processing_eomt_fast.py +46 -14
  183. transformers/models/eomt/modeling_eomt.py +7 -0
  184. transformers/models/eomt/modular_eomt.py +7 -0
  185. transformers/models/ernie/modeling_ernie.py +6 -0
  186. transformers/models/ernie/modular_ernie.py +6 -0
  187. transformers/models/ernie4_5/modeling_ernie4_5.py +1 -1
  188. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +16 -13
  189. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +9 -35
  190. transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
  191. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
  192. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
  193. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
  194. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
  195. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
  196. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
  197. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
  198. transformers/models/esm/modeling_esm.py +6 -0
  199. transformers/models/esm/modeling_esmfold.py +6 -1
  200. transformers/models/evolla/modeling_evolla.py +9 -1
  201. transformers/models/evolla/modular_evolla.py +8 -0
  202. transformers/models/exaone4/modeling_exaone4.py +1 -1
  203. transformers/models/falcon/modeling_falcon.py +3 -3
  204. transformers/models/falcon_h1/modeling_falcon_h1.py +28 -23
  205. transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
  206. transformers/models/falcon_mamba/modeling_falcon_mamba.py +6 -2
  207. transformers/models/falcon_mamba/modular_falcon_mamba.py +7 -2
  208. transformers/models/fast_vlm/modeling_fast_vlm.py +7 -3
  209. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +23 -10
  210. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
  211. transformers/models/flaubert/modeling_flaubert.py +14 -15
  212. transformers/models/flava/image_processing_flava_fast.py +0 -2
  213. transformers/models/flava/modeling_flava.py +4 -1
  214. transformers/models/flex_olmo/modeling_flex_olmo.py +7 -4
  215. transformers/models/florence2/modeling_florence2.py +20 -3
  216. transformers/models/florence2/modular_florence2.py +13 -0
  217. transformers/models/fnet/modeling_fnet.py +7 -0
  218. transformers/models/fuyu/image_processing_fuyu.py +1 -1
  219. transformers/models/fuyu/modeling_fuyu.py +3 -1
  220. transformers/models/fuyu/processing_fuyu.py +16 -0
  221. transformers/models/gemma/modeling_gemma.py +10 -12
  222. transformers/models/gemma/modular_gemma.py +9 -11
  223. transformers/models/gemma2/modeling_gemma2.py +1 -1
  224. transformers/models/gemma2/modular_gemma2.py +1 -1
  225. transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
  226. transformers/models/gemma3/modeling_gemma3.py +28 -7
  227. transformers/models/gemma3/modular_gemma3.py +26 -6
  228. transformers/models/gemma3n/configuration_gemma3n.py +3 -0
  229. transformers/models/gemma3n/modeling_gemma3n.py +47 -9
  230. transformers/models/gemma3n/modular_gemma3n.py +51 -9
  231. transformers/models/git/modeling_git.py +181 -126
  232. transformers/models/glm/modeling_glm.py +1 -1
  233. transformers/models/glm4/modeling_glm4.py +1 -1
  234. transformers/models/glm46v/image_processing_glm46v.py +0 -4
  235. transformers/models/glm46v/modeling_glm46v.py +3 -1
  236. transformers/models/glm46v/modular_glm46v.py +3 -0
  237. transformers/models/glm4_moe/modeling_glm4_moe.py +9 -5
  238. transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
  239. transformers/models/glm4v/image_processing_glm4v.py +0 -4
  240. transformers/models/glm4v/modeling_glm4v.py +15 -5
  241. transformers/models/glm4v/modular_glm4v.py +11 -3
  242. transformers/models/glm4v_moe/modeling_glm4v_moe.py +39 -23
  243. transformers/models/glm4v_moe/modular_glm4v_moe.py +12 -0
  244. transformers/models/glmasr/__init__.py +30 -0
  245. transformers/models/glmasr/configuration_glmasr.py +197 -0
  246. transformers/models/glmasr/modeling_glmasr.py +512 -0
  247. transformers/models/glmasr/modular_glmasr.py +433 -0
  248. transformers/models/glmasr/processing_glmasr.py +332 -0
  249. transformers/models/glpn/image_processing_glpn_fast.py +0 -1
  250. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
  251. transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
  252. transformers/models/gpt2/modeling_gpt2.py +8 -5
  253. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +3 -8
  254. transformers/models/gpt_neo/modeling_gpt_neo.py +15 -3
  255. transformers/models/gpt_neox/modeling_gpt_neox.py +1 -1
  256. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +1 -1
  257. transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
  258. transformers/models/gpt_oss/modeling_gpt_oss.py +6 -9
  259. transformers/models/gpt_oss/modular_gpt_oss.py +5 -7
  260. transformers/models/gptj/modeling_gptj.py +15 -6
  261. transformers/models/granite/modeling_granite.py +1 -1
  262. transformers/models/granite_speech/modeling_granite_speech.py +15 -1
  263. transformers/models/granitemoe/modeling_granitemoe.py +2 -3
  264. transformers/models/granitemoe/modular_granitemoe.py +1 -2
  265. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
  266. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +33 -23
  267. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
  268. transformers/models/granitemoeshared/modeling_granitemoeshared.py +2 -3
  269. transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
  270. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -4
  271. transformers/models/groupvit/modeling_groupvit.py +6 -1
  272. transformers/models/helium/modeling_helium.py +1 -1
  273. transformers/models/hgnet_v2/modeling_hgnet_v2.py +10 -0
  274. transformers/models/hgnet_v2/modular_hgnet_v2.py +10 -0
  275. transformers/models/hubert/modeling_hubert.py +4 -0
  276. transformers/models/hubert/modular_hubert.py +4 -0
  277. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +1 -1
  278. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
  279. transformers/models/hunyuan_v1_moe/__init__.py +1 -1
  280. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +12 -4
  281. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
  282. transformers/models/ibert/modeling_ibert.py +16 -0
  283. transformers/models/idefics/modeling_idefics.py +10 -0
  284. transformers/models/idefics2/modeling_idefics2.py +7 -1
  285. transformers/models/idefics3/modeling_idefics3.py +5 -1
  286. transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
  287. transformers/models/imagegpt/modeling_imagegpt.py +9 -2
  288. transformers/models/instructblip/modeling_instructblip.py +2 -0
  289. transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
  290. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
  291. transformers/models/internvl/modeling_internvl.py +11 -8
  292. transformers/models/internvl/modular_internvl.py +5 -9
  293. transformers/models/internvl/video_processing_internvl.py +0 -1
  294. transformers/models/jais2/__init__.py +27 -0
  295. transformers/models/jais2/configuration_jais2.py +152 -0
  296. transformers/models/jais2/modeling_jais2.py +486 -0
  297. transformers/models/jais2/modular_jais2.py +196 -0
  298. transformers/models/jamba/modeling_jamba.py +24 -19
  299. transformers/models/jamba/modular_jamba.py +17 -17
  300. transformers/models/janus/image_processing_janus_fast.py +0 -1
  301. transformers/models/janus/modeling_janus.py +15 -7
  302. transformers/models/janus/modular_janus.py +16 -7
  303. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  304. transformers/models/jetmoe/modular_jetmoe.py +1 -0
  305. transformers/models/kosmos2/modeling_kosmos2.py +14 -2
  306. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
  307. transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
  308. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +9 -3
  309. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
  310. transformers/models/lasr/configuration_lasr.py +4 -0
  311. transformers/models/lasr/modeling_lasr.py +3 -2
  312. transformers/models/lasr/modular_lasr.py +8 -1
  313. transformers/models/lasr/processing_lasr.py +0 -2
  314. transformers/models/layoutlm/modeling_layoutlm.py +5 -3
  315. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
  316. transformers/models/layoutlmv2/modeling_layoutlmv2.py +12 -0
  317. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +1 -0
  318. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
  319. transformers/models/layoutlmv3/modeling_layoutlmv3.py +29 -5
  320. transformers/models/led/modeling_led.py +6 -0
  321. transformers/models/levit/modeling_levit.py +18 -0
  322. transformers/models/lfm2/modeling_lfm2.py +1 -1
  323. transformers/models/lfm2_moe/modeling_lfm2_moe.py +14 -4
  324. transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
  325. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  326. transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
  327. transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
  328. transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
  329. transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
  330. transformers/models/lilt/modeling_lilt.py +19 -15
  331. transformers/models/llama/modeling_llama.py +1 -1
  332. transformers/models/llama4/image_processing_llama4_fast.py +1 -2
  333. transformers/models/llama4/modeling_llama4.py +8 -4
  334. transformers/models/llava/image_processing_llava_fast.py +0 -1
  335. transformers/models/llava/modeling_llava.py +12 -7
  336. transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
  337. transformers/models/llava_next/modeling_llava_next.py +7 -3
  338. transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
  339. transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
  340. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
  341. transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
  342. transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
  343. transformers/models/longcat_flash/modeling_longcat_flash.py +2 -1
  344. transformers/models/longcat_flash/modular_longcat_flash.py +1 -0
  345. transformers/models/longt5/modeling_longt5.py +0 -4
  346. transformers/models/m2m_100/modeling_m2m_100.py +10 -0
  347. transformers/models/mamba/modeling_mamba.py +2 -1
  348. transformers/models/mamba2/modeling_mamba2.py +24 -23
  349. transformers/models/marian/configuration_marian.py +1 -1
  350. transformers/models/marian/modeling_marian.py +3 -0
  351. transformers/models/markuplm/modeling_markuplm.py +5 -8
  352. transformers/models/mask2former/configuration_mask2former.py +3 -3
  353. transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
  354. transformers/models/mask2former/modeling_mask2former.py +9 -0
  355. transformers/models/maskformer/configuration_maskformer.py +3 -3
  356. transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
  357. transformers/models/maskformer/modeling_maskformer.py +9 -1
  358. transformers/models/maskformer/modeling_maskformer_swin.py +19 -15
  359. transformers/models/mbart/configuration_mbart.py +1 -0
  360. transformers/models/mbart/modeling_mbart.py +7 -0
  361. transformers/models/megatron_bert/modeling_megatron_bert.py +2 -0
  362. transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
  363. transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
  364. transformers/models/mimi/modeling_mimi.py +25 -4
  365. transformers/models/minimax/modeling_minimax.py +16 -3
  366. transformers/models/minimax/modular_minimax.py +12 -1
  367. transformers/models/ministral/modeling_ministral.py +1 -1
  368. transformers/models/ministral3/modeling_ministral3.py +1 -1
  369. transformers/models/mistral/modeling_mistral.py +1 -1
  370. transformers/models/mistral3/modeling_mistral3.py +10 -4
  371. transformers/models/mistral3/modular_mistral3.py +3 -1
  372. transformers/models/mixtral/modeling_mixtral.py +12 -4
  373. transformers/models/mixtral/modular_mixtral.py +6 -2
  374. transformers/models/mlcd/modeling_mlcd.py +6 -0
  375. transformers/models/mlcd/modular_mlcd.py +4 -0
  376. transformers/models/mllama/modeling_mllama.py +13 -2
  377. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
  378. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -4
  379. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
  380. transformers/models/mobilebert/modeling_mobilebert.py +2 -0
  381. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
  382. transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
  383. transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
  384. transformers/models/mobilevit/modeling_mobilevit.py +4 -0
  385. transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -0
  386. transformers/models/modernbert/modeling_modernbert.py +12 -1
  387. transformers/models/modernbert/modular_modernbert.py +12 -1
  388. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -1
  389. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +9 -1
  390. transformers/models/moonshine/modeling_moonshine.py +1 -1
  391. transformers/models/moshi/modeling_moshi.py +21 -51
  392. transformers/models/mpnet/modeling_mpnet.py +2 -0
  393. transformers/models/mra/modeling_mra.py +4 -1
  394. transformers/models/mt5/configuration_mt5.py +2 -3
  395. transformers/models/mt5/modeling_mt5.py +0 -10
  396. transformers/models/musicgen/modeling_musicgen.py +5 -9
  397. transformers/models/musicgen_melody/modeling_musicgen_melody.py +4 -0
  398. transformers/models/mvp/modeling_mvp.py +7 -0
  399. transformers/models/nanochat/modeling_nanochat.py +1 -1
  400. transformers/models/nemotron/modeling_nemotron.py +3 -3
  401. transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
  402. transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
  403. transformers/models/nougat/image_processing_nougat_fast.py +0 -1
  404. transformers/models/nougat/tokenization_nougat.py +11 -16
  405. transformers/models/nystromformer/modeling_nystromformer.py +7 -0
  406. transformers/models/olmo/modeling_olmo.py +1 -1
  407. transformers/models/olmo2/modeling_olmo2.py +1 -1
  408. transformers/models/olmo3/modeling_olmo3.py +1 -1
  409. transformers/models/olmoe/modeling_olmoe.py +12 -4
  410. transformers/models/olmoe/modular_olmoe.py +4 -2
  411. transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
  412. transformers/models/omdet_turbo/modeling_omdet_turbo.py +4 -0
  413. transformers/models/oneformer/configuration_oneformer.py +3 -3
  414. transformers/models/oneformer/modeling_oneformer.py +7 -38
  415. transformers/models/openai/modeling_openai.py +12 -0
  416. transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
  417. transformers/models/ovis2/modeling_ovis2.py +15 -3
  418. transformers/models/ovis2/modular_ovis2.py +8 -0
  419. transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
  420. transformers/models/owlv2/modeling_owlv2.py +7 -3
  421. transformers/models/owlv2/modular_owlv2.py +0 -2
  422. transformers/models/owlvit/modeling_owlvit.py +7 -3
  423. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +3 -2
  424. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +28 -14
  425. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +22 -12
  426. transformers/models/paligemma/modeling_paligemma.py +25 -17
  427. transformers/models/parakeet/modeling_parakeet.py +5 -0
  428. transformers/models/parakeet/modular_parakeet.py +5 -0
  429. transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
  430. transformers/models/patchtsmixer/modeling_patchtsmixer.py +4 -0
  431. transformers/models/patchtst/modeling_patchtst.py +5 -4
  432. transformers/models/pe_audio/__init__.py +30 -0
  433. transformers/models/pe_audio/configuration_pe_audio.py +206 -0
  434. transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
  435. transformers/models/pe_audio/modeling_pe_audio.py +820 -0
  436. transformers/models/pe_audio/modular_pe_audio.py +299 -0
  437. transformers/models/pe_audio/processing_pe_audio.py +24 -0
  438. transformers/models/pe_audio_video/__init__.py +29 -0
  439. transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
  440. transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
  441. transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
  442. transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
  443. transformers/models/pe_video/__init__.py +30 -0
  444. transformers/models/pe_video/configuration_pe_video.py +211 -0
  445. transformers/models/pe_video/modeling_pe_video.py +636 -0
  446. transformers/models/pe_video/modular_pe_video.py +219 -0
  447. transformers/models/pe_video/processing_pe_video.py +10 -0
  448. transformers/models/pe_video/video_processing_pe_video.py +66 -0
  449. transformers/models/pegasus/configuration_pegasus.py +1 -0
  450. transformers/models/pegasus/modeling_pegasus.py +3 -0
  451. transformers/models/pegasus_x/modeling_pegasus_x.py +1 -0
  452. transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
  453. transformers/models/perceiver/modeling_perceiver.py +5 -1
  454. transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
  455. transformers/models/perception_lm/modeling_perception_lm.py +7 -3
  456. transformers/models/perception_lm/modular_perception_lm.py +7 -3
  457. transformers/models/persimmon/modeling_persimmon.py +1 -1
  458. transformers/models/phi/modeling_phi.py +1 -1
  459. transformers/models/phi3/modeling_phi3.py +1 -1
  460. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +4 -1
  461. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +3 -0
  462. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
  463. transformers/models/phimoe/modeling_phimoe.py +12 -4
  464. transformers/models/phimoe/modular_phimoe.py +1 -1
  465. transformers/models/pix2struct/processing_pix2struct.py +0 -4
  466. transformers/models/pixio/__init__.py +30 -0
  467. transformers/models/pixio/configuration_pixio.py +151 -0
  468. transformers/models/pixio/modeling_pixio.py +507 -0
  469. transformers/models/pixio/modular_pixio.py +404 -0
  470. transformers/models/pixtral/modeling_pixtral.py +1 -1
  471. transformers/models/pixtral/processing_pixtral.py +3 -1
  472. transformers/models/plbart/configuration_plbart.py +1 -0
  473. transformers/models/plbart/modeling_plbart.py +7 -0
  474. transformers/models/plbart/modular_plbart.py +6 -0
  475. transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
  476. transformers/models/poolformer/modeling_poolformer.py +11 -1
  477. transformers/models/pop2piano/configuration_pop2piano.py +0 -1
  478. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
  479. transformers/models/prophetnet/modeling_prophetnet.py +2 -1
  480. transformers/models/qwen2/modeling_qwen2.py +1 -1
  481. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +104 -64
  482. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +58 -18
  483. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +18 -5
  484. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +26 -22
  485. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -2
  486. transformers/models/qwen2_moe/modeling_qwen2_moe.py +12 -4
  487. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
  488. transformers/models/qwen2_vl/modeling_qwen2_vl.py +17 -4
  489. transformers/models/qwen3/modeling_qwen3.py +1 -1
  490. transformers/models/qwen3_moe/modeling_qwen3_moe.py +12 -4
  491. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -6
  492. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
  493. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +92 -46
  494. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +48 -4
  495. transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
  496. transformers/models/qwen3_vl/modeling_qwen3_vl.py +17 -4
  497. transformers/models/qwen3_vl/modular_qwen3_vl.py +21 -10
  498. transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
  499. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +94 -112
  500. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +32 -81
  501. transformers/models/rag/configuration_rag.py +0 -8
  502. transformers/models/rag/modeling_rag.py +7 -9
  503. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +3 -2
  504. transformers/models/reformer/modeling_reformer.py +9 -1
  505. transformers/models/regnet/modeling_regnet.py +4 -0
  506. transformers/models/rembert/modeling_rembert.py +7 -1
  507. transformers/models/resnet/modeling_resnet.py +8 -3
  508. transformers/models/roberta/modeling_roberta.py +3 -0
  509. transformers/models/roberta/modular_roberta.py +3 -0
  510. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
  511. transformers/models/roc_bert/modeling_roc_bert.py +3 -0
  512. transformers/models/rt_detr/configuration_rt_detr.py +1 -1
  513. transformers/models/rt_detr/modeling_rt_detr.py +4 -0
  514. transformers/models/rt_detr/modeling_rt_detr_resnet.py +8 -3
  515. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
  516. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +7 -0
  517. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
  518. transformers/models/rwkv/modeling_rwkv.py +1 -1
  519. transformers/models/sam/configuration_sam.py +1 -0
  520. transformers/models/sam/image_processing_sam_fast.py +0 -1
  521. transformers/models/sam/modeling_sam.py +4 -1
  522. transformers/models/sam2/configuration_sam2.py +1 -1
  523. transformers/models/sam2/modeling_sam2.py +5 -1
  524. transformers/models/sam2/modular_sam2.py +5 -1
  525. transformers/models/sam2_video/modeling_sam2_video.py +51 -43
  526. transformers/models/sam2_video/modular_sam2_video.py +31 -18
  527. transformers/models/sam3/configuration_sam3.py +21 -1
  528. transformers/models/sam3/modeling_sam3.py +23 -0
  529. transformers/models/sam3_tracker/modeling_sam3_tracker.py +2 -0
  530. transformers/models/sam3_tracker/modular_sam3_tracker.py +2 -0
  531. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
  532. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +26 -15
  533. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
  534. transformers/models/sam3_video/configuration_sam3_video.py +14 -0
  535. transformers/models/sam3_video/modeling_sam3_video.py +3 -3
  536. transformers/models/sam3_video/processing_sam3_video.py +1 -1
  537. transformers/models/sam_hq/configuration_sam_hq.py +1 -0
  538. transformers/models/sam_hq/modeling_sam_hq.py +26 -23
  539. transformers/models/seamless_m4t/modeling_seamless_m4t.py +27 -11
  540. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +6 -0
  541. transformers/models/seed_oss/modeling_seed_oss.py +1 -1
  542. transformers/models/segformer/image_processing_segformer_fast.py +0 -1
  543. transformers/models/segformer/modeling_segformer.py +2 -2
  544. transformers/models/segformer/modular_segformer.py +0 -1
  545. transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
  546. transformers/models/siglip/modeling_siglip.py +24 -2
  547. transformers/models/siglip2/modeling_siglip2.py +63 -41
  548. transformers/models/smollm3/modeling_smollm3.py +1 -1
  549. transformers/models/smolvlm/modeling_smolvlm.py +5 -1
  550. transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
  551. transformers/models/speech_to_text/modeling_speech_to_text.py +10 -0
  552. transformers/models/speecht5/modeling_speecht5.py +28 -0
  553. transformers/models/splinter/modeling_splinter.py +9 -3
  554. transformers/models/squeezebert/modeling_squeezebert.py +2 -0
  555. transformers/models/stablelm/modeling_stablelm.py +1 -1
  556. transformers/models/starcoder2/modeling_starcoder2.py +1 -1
  557. transformers/models/superglue/image_processing_superglue_fast.py +1 -2
  558. transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
  559. transformers/models/swiftformer/modeling_swiftformer.py +4 -0
  560. transformers/models/swin/modeling_swin.py +16 -12
  561. transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
  562. transformers/models/swin2sr/modeling_swin2sr.py +49 -33
  563. transformers/models/swinv2/modeling_swinv2.py +41 -33
  564. transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
  565. transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
  566. transformers/models/t5/configuration_t5.py +7 -1
  567. transformers/models/t5/modeling_t5.py +1 -7
  568. transformers/models/t5gemma/modeling_t5gemma.py +1 -1
  569. transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
  570. transformers/models/t5gemma2/modeling_t5gemma2.py +13 -4
  571. transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
  572. transformers/models/table_transformer/configuration_table_transformer.py +1 -1
  573. transformers/models/table_transformer/modeling_table_transformer.py +1 -1
  574. transformers/models/textnet/image_processing_textnet_fast.py +0 -1
  575. transformers/models/timesfm/modeling_timesfm.py +12 -0
  576. transformers/models/timesfm/modular_timesfm.py +12 -0
  577. transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
  578. transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
  579. transformers/models/timm_wrapper/modeling_timm_wrapper.py +19 -13
  580. transformers/models/trocr/modeling_trocr.py +1 -2
  581. transformers/models/tvp/configuration_tvp.py +5 -1
  582. transformers/models/tvp/modeling_tvp.py +4 -4
  583. transformers/models/udop/configuration_udop.py +1 -0
  584. transformers/models/udop/modeling_udop.py +3 -7
  585. transformers/models/umt5/configuration_umt5.py +2 -2
  586. transformers/models/umt5/modeling_umt5.py +0 -6
  587. transformers/models/vaultgemma/modeling_vaultgemma.py +1 -1
  588. transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
  589. transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
  590. transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
  591. transformers/models/video_llava/modeling_video_llava.py +7 -3
  592. transformers/models/vilt/configuration_vilt.py +2 -2
  593. transformers/models/vilt/modeling_vilt.py +7 -0
  594. transformers/models/vipllava/modeling_vipllava.py +7 -3
  595. transformers/models/visual_bert/modeling_visual_bert.py +2 -0
  596. transformers/models/vitmatte/configuration_vitmatte.py +1 -1
  597. transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
  598. transformers/models/vitmatte/modeling_vitmatte.py +4 -0
  599. transformers/models/vitpose/configuration_vitpose.py +1 -1
  600. transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
  601. transformers/models/voxtral/modeling_voxtral.py +2 -2
  602. transformers/models/voxtral/modular_voxtral.py +2 -2
  603. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +16 -10
  604. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +7 -0
  605. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +21 -11
  606. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
  607. transformers/models/whisper/generation_whisper.py +1 -0
  608. transformers/models/whisper/modeling_whisper.py +5 -3
  609. transformers/models/x_clip/modeling_x_clip.py +2 -0
  610. transformers/models/xcodec/modeling_xcodec.py +5 -0
  611. transformers/models/xglm/modeling_xglm.py +10 -0
  612. transformers/models/xlm/modeling_xlm.py +13 -14
  613. transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
  614. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
  615. transformers/models/xlnet/modeling_xlnet.py +3 -1
  616. transformers/models/xmod/modeling_xmod.py +3 -0
  617. transformers/models/yoso/modeling_yoso.py +4 -1
  618. transformers/models/zamba/modeling_zamba.py +2 -1
  619. transformers/models/zamba2/modeling_zamba2.py +3 -2
  620. transformers/models/zoedepth/configuration_zoedepth.py +1 -1
  621. transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
  622. transformers/models/zoedepth/modeling_zoedepth.py +7 -0
  623. transformers/pipelines/__init__.py +9 -6
  624. transformers/pipelines/automatic_speech_recognition.py +20 -12
  625. transformers/pipelines/base.py +1 -1
  626. transformers/pipelines/document_question_answering.py +1 -1
  627. transformers/pipelines/question_answering.py +1 -1
  628. transformers/pipelines/text_to_audio.py +2 -2
  629. transformers/processing_utils.py +127 -56
  630. transformers/quantizers/auto.py +2 -4
  631. transformers/quantizers/base.py +9 -64
  632. transformers/quantizers/quantizer_aqlm.py +1 -18
  633. transformers/quantizers/quantizer_auto_round.py +1 -10
  634. transformers/quantizers/quantizer_awq.py +3 -8
  635. transformers/quantizers/quantizer_bitnet.py +1 -6
  636. transformers/quantizers/quantizer_bnb_4bit.py +9 -49
  637. transformers/quantizers/quantizer_bnb_8bit.py +9 -19
  638. transformers/quantizers/quantizer_compressed_tensors.py +1 -4
  639. transformers/quantizers/quantizer_eetq.py +2 -12
  640. transformers/quantizers/quantizer_fbgemm_fp8.py +5 -14
  641. transformers/quantizers/quantizer_finegrained_fp8.py +15 -10
  642. transformers/quantizers/quantizer_fp_quant.py +4 -4
  643. transformers/quantizers/quantizer_gptq.py +1 -4
  644. transformers/quantizers/quantizer_higgs.py +2 -6
  645. transformers/quantizers/quantizer_mxfp4.py +2 -28
  646. transformers/quantizers/quantizer_quanto.py +14 -14
  647. transformers/quantizers/quantizer_spqr.py +3 -8
  648. transformers/quantizers/quantizer_torchao.py +28 -124
  649. transformers/quantizers/quantizer_vptq.py +1 -10
  650. transformers/testing_utils.py +28 -12
  651. transformers/tokenization_mistral_common.py +3 -2
  652. transformers/tokenization_utils_base.py +3 -2
  653. transformers/tokenization_utils_tokenizers.py +25 -2
  654. transformers/trainer.py +24 -2
  655. transformers/trainer_callback.py +8 -0
  656. transformers/trainer_seq2seq.py +4 -0
  657. transformers/training_args.py +8 -10
  658. transformers/utils/__init__.py +4 -0
  659. transformers/utils/attention_visualizer.py +4 -4
  660. transformers/utils/auto_docstring.py +34 -25
  661. transformers/utils/generic.py +20 -0
  662. transformers/utils/import_utils.py +51 -9
  663. transformers/utils/kernel_config.py +71 -18
  664. transformers/utils/quantization_config.py +8 -8
  665. transformers/video_processing_utils.py +16 -12
  666. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +5 -6
  667. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +671 -632
  668. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +0 -0
  669. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
  670. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/licenses/LICENSE +0 -0
  671. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,820 @@
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/pe_audio/modular_pe_audio.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_pe_audio.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ # coding=utf-8
8
+ # Copyright 2025 the HuggingFace Inc. team. All rights reserved.
9
+ #
10
+ # Licensed under the Apache License, Version 2.0 (the "License");
11
+ # you may not use this file except in compliance with the License.
12
+ # You may obtain a copy of the License at
13
+ #
14
+ # http://www.apache.org/licenses/LICENSE-2.0
15
+ #
16
+ # Unless required by applicable law or agreed to in writing, software
17
+ # distributed under the License is distributed on an "AS IS" BASIS,
18
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
19
+ # See the License for the specific language governing permissions and
20
+ # limitations under the License.
21
+ import math
22
+ from collections.abc import Callable
23
+ from dataclasses import dataclass
24
+ from typing import Any, Optional
25
+
26
+ import torch
27
+ import torch.nn as nn
28
+ import torch.nn.functional as F
29
+
30
+ from ... import initialization as init
31
+ from ...activations import ACT2FN
32
+ from ...cache_utils import Cache
33
+ from ...configuration_utils import PreTrainedConfig
34
+ from ...integrations import use_kernel_forward_from_hub, use_kernelized_func
35
+ from ...modeling_attn_mask_utils import _prepare_4d_attention_mask
36
+ from ...modeling_layers import GradientCheckpointingLayer
37
+ from ...modeling_outputs import BaseModelOutputWithPooling, MaskedLMOutput
38
+ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
39
+ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
40
+ from ...processing_utils import Unpack
41
+ from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple
42
+ from ...utils.generic import check_model_inputs, maybe_autocast
43
+ from ..auto import AutoModel
44
+ from .configuration_pe_audio import PeAudioConfig, PeAudioEncoderConfig
45
+
46
+
47
+ class Snake1d(nn.Module):
48
+ """
49
+ A 1-dimensional Snake activation function module.
50
+ """
51
+
52
+ def __init__(self, hidden_dim):
53
+ super().__init__()
54
+ self.alpha = nn.Parameter(torch.ones(1, hidden_dim, 1))
55
+
56
+ def forward(self, hidden_states):
57
+ shape = hidden_states.shape
58
+ hidden_states = hidden_states.reshape(shape[0], shape[1], -1)
59
+ hidden_states = hidden_states + (self.alpha + 1e-9).reciprocal() * torch.sin(self.alpha * hidden_states).pow(2)
60
+ hidden_states = hidden_states.reshape(shape)
61
+ return hidden_states
62
+
63
+
64
+ class PeAudioDacResidualUnit(nn.Module):
65
+ """
66
+ A residual unit composed of Snake1d and weight-normalized Conv1d layers with dilations.
67
+ """
68
+
69
+ def __init__(self, dimension: int = 16, dilation: int = 1):
70
+ super().__init__()
71
+ pad = ((7 - 1) * dilation) // 2
72
+
73
+ self.snake1 = Snake1d(dimension)
74
+ self.conv1 = nn.Conv1d(dimension, dimension, kernel_size=7, dilation=dilation, padding=pad)
75
+ self.snake2 = Snake1d(dimension)
76
+ self.conv2 = nn.Conv1d(dimension, dimension, kernel_size=1)
77
+
78
+ def forward(self, hidden_state):
79
+ """
80
+ Forward pass through the residual unit.
81
+
82
+ Args:
83
+ hidden_state (`torch.Tensor` of shape `(batch_size, channels, time_steps)`):
84
+ Input tensor .
85
+
86
+ Returns:
87
+ output_tensor (`torch.Tensor` of shape `(batch_size, channels, time_steps)`):
88
+ Input tensor after passing through the residual unit.
89
+ """
90
+ output_tensor = hidden_state
91
+ output_tensor = self.conv1(self.snake1(output_tensor))
92
+ output_tensor = self.conv2(self.snake2(output_tensor))
93
+
94
+ padding = (hidden_state.shape[-1] - output_tensor.shape[-1]) // 2
95
+ if padding > 0:
96
+ hidden_state = hidden_state[..., padding:-padding]
97
+ output_tensor = hidden_state + output_tensor
98
+ return output_tensor
99
+
100
+
101
+ class PeAudioDacEncoderBlock(nn.Module):
102
+ """Encoder block used in PE_AUDIO_DAC encoder."""
103
+
104
+ def __init__(self, config: PreTrainedConfig, stride: int = 1, stride_index: int = 1):
105
+ super().__init__()
106
+
107
+ dimension = config.encoder_hidden_size * 2**stride_index
108
+ self.res_unit1 = PeAudioDacResidualUnit(dimension // 2, dilation=1)
109
+ self.res_unit2 = PeAudioDacResidualUnit(dimension // 2, dilation=3)
110
+ self.res_unit3 = PeAudioDacResidualUnit(dimension // 2, dilation=9)
111
+ self.snake1 = Snake1d(dimension // 2)
112
+ self.conv1 = nn.Conv1d(
113
+ dimension // 2, dimension, kernel_size=2 * stride, stride=stride, padding=math.ceil(stride / 2)
114
+ )
115
+
116
+ def forward(self, hidden_state):
117
+ hidden_state = self.res_unit1(hidden_state)
118
+ hidden_state = self.res_unit2(hidden_state)
119
+ hidden_state = self.snake1(self.res_unit3(hidden_state))
120
+ hidden_state = self.conv1(hidden_state)
121
+
122
+ return hidden_state
123
+
124
+
125
+ class PeAudioDacEncoder(nn.Module):
126
+ """PE_AUDIO_DAC Encoder"""
127
+
128
+ def __init__(self, config: PreTrainedConfig):
129
+ super().__init__()
130
+
131
+ strides = config.downsampling_ratios
132
+ # Create first convolution
133
+ self.conv1 = nn.Conv1d(1, config.encoder_hidden_size, kernel_size=7, padding=3)
134
+
135
+ self.block = []
136
+ # Create EncoderBlocks that double channels as they downsample by `stride`
137
+ for stride_index, stride in enumerate(strides):
138
+ stride_index = stride_index + 1
139
+ self.block += [PeAudioDacEncoderBlock(config, stride=stride, stride_index=stride_index)]
140
+
141
+ self.block = nn.ModuleList(self.block)
142
+ d_model = config.encoder_hidden_size * 2**stride_index
143
+ self.snake1 = Snake1d(d_model)
144
+ self.conv2 = nn.Conv1d(d_model, config.hidden_size, kernel_size=3, padding=1)
145
+
146
+ def forward(self, hidden_state):
147
+ hidden_state = self.conv1(hidden_state)
148
+
149
+ for module in self.block:
150
+ hidden_state = module(hidden_state)
151
+
152
+ hidden_state = self.snake1(hidden_state)
153
+ hidden_state = self.conv2(hidden_state)
154
+
155
+ return hidden_state
156
+
157
+
158
+ class PeAudioEncoderEmbedder(nn.Module):
159
+ def __init__(self, config: PeAudioEncoderConfig):
160
+ super().__init__()
161
+ self.dac_encoder = PeAudioDacEncoder(config.dac_config)
162
+ self.bottleneck = nn.Conv1d(config.dac_config.hidden_size, config.dac_config.codebook_dim, 1)
163
+ self.data_proj = nn.Linear(config.dac_config.codebook_dim, config.hidden_size)
164
+ self.config = config
165
+
166
+ def forward(
167
+ self,
168
+ input_values: torch.Tensor,
169
+ padding_mask: Optional[torch.Tensor] = None,
170
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
171
+ with torch.no_grad(), torch.backends.cudnn.flags(enabled=False):
172
+ hidden_states = self.dac_encoder(input_values)
173
+ hidden_states = self.bottleneck(hidden_states)
174
+
175
+ codec_features = hidden_states.transpose(1, 2)
176
+ inputs_embeds = self.data_proj(codec_features)
177
+
178
+ if padding_mask is not None:
179
+ padding_mask = padding_mask[:, :: self.config.dac_config.hop_length]
180
+
181
+ return inputs_embeds, padding_mask
182
+
183
+
184
+ class PeAudioContrastiveHead(nn.Module):
185
+ def __init__(
186
+ self,
187
+ in_dim: int,
188
+ out_dim: int,
189
+ ) -> None:
190
+ super().__init__()
191
+ self.layer_norm = nn.LayerNorm(normalized_shape=in_dim, eps=1e-6)
192
+ self.proj = nn.Linear(in_dim, out_dim, bias=False)
193
+
194
+ def forward(self, x: torch.Tensor) -> torch.FloatTensor:
195
+ return self.proj(self.layer_norm(x))
196
+
197
+
198
+ class PeAudioMaskedGroupNorm(nn.GroupNorm):
199
+ def forward(self, x, padding_mask=None):
200
+ if padding_mask is None:
201
+ return super().forward(x)
202
+
203
+ batch_size, hidden_size, seq_len = x.shape
204
+ group_size = hidden_size // self.num_groups
205
+ grouped_shape = (batch_size, -1, group_size, seq_len)
206
+
207
+ x_grouped = x.view(grouped_shape)
208
+ padding_mask_grouped = padding_mask.reshape(grouped_shape).bool()
209
+
210
+ mean = torch.masked.mean(x_grouped, mask=padding_mask_grouped, dim=(2, 3), keepdim=True)
211
+ var = torch.masked.var(x_grouped, mask=padding_mask_grouped, dim=(2, 3), keepdim=True, unbiased=False)
212
+
213
+ x_norm = (x_grouped - mean) / torch.sqrt(var + self.eps)
214
+ x_norm = x_norm.view(x.shape)
215
+
216
+ if self.affine:
217
+ x_norm = x_norm * self.weight.view(1, -1, 1) + self.bias.view(1, -1, 1)
218
+
219
+ return x_norm * padding_mask
220
+
221
+
222
+ class PeAudioConvBlock1d(nn.Module):
223
+ def __init__(self, config):
224
+ super().__init__()
225
+ self.groupnorm = PeAudioMaskedGroupNorm(num_groups=1, num_channels=config.hidden_size)
226
+ self.activation = nn.SiLU()
227
+ self.project = nn.Conv1d(
228
+ in_channels=config.hidden_size,
229
+ out_channels=config.hidden_size,
230
+ kernel_size=3,
231
+ padding="same",
232
+ )
233
+
234
+ def forward(self, x, padding_mask=None):
235
+ x = self.groupnorm(x, padding_mask=padding_mask)
236
+ x = self.activation(x)
237
+ return self.project(x)
238
+
239
+
240
+ class PeAudioResnetBlock1d(nn.Module):
241
+ def __init__(self, config):
242
+ super().__init__()
243
+ self.block1 = PeAudioConvBlock1d(config)
244
+ self.block2 = PeAudioConvBlock1d(config)
245
+
246
+ def forward(self, hidden_states, padding_mask=None):
247
+ """
248
+ Args:
249
+ hidden_states: (batch_size, seq_len, hidden_size)
250
+ padding_mask: (batch_size, seq_len)
251
+ Returns:
252
+ hidden_states: (batch_size, seq_len, hidden_size)
253
+ """
254
+ # transpose for convolutions
255
+ # (batch_size, seq_len, hidden_size) -> (batch_size, hidden_size, seq_len)
256
+ hidden_states = hidden_states.transpose(1, 2)
257
+
258
+ if padding_mask is not None:
259
+ padding_mask = padding_mask.unsqueeze(1).expand_as(hidden_states)
260
+
261
+ residual = hidden_states
262
+ hidden_states = self.block1(hidden_states, padding_mask=padding_mask)
263
+ hidden_states = self.block2(hidden_states, padding_mask=padding_mask)
264
+ hidden_states = residual + hidden_states
265
+
266
+ return hidden_states.transpose(1, 2)
267
+
268
+
269
+ class PeAudioEncoderPatchEmbedder(nn.Module):
270
+ def __init__(self, config):
271
+ super().__init__()
272
+ self.resnet_block = PeAudioResnetBlock1d(config)
273
+ self.class_embedding = nn.Parameter(torch.randn(1, 1, config.hidden_size))
274
+
275
+ def forward(self, inputs_embeds, padding_mask=None):
276
+ # Embedding step: prepend class token and run the ResNet block.
277
+ hidden_states = torch.cat(
278
+ [self.class_embedding.expand(inputs_embeds.size(0), -1, -1), inputs_embeds],
279
+ dim=1,
280
+ )
281
+
282
+ if padding_mask is not None:
283
+ # TODO: any reason why we take padding_mask[0] and not just 1?
284
+ padding_mask = torch.cat([padding_mask[:, [0]], padding_mask], dim=1)
285
+
286
+ hidden_states = self.resnet_block(hidden_states, padding_mask=padding_mask)
287
+ return hidden_states, padding_mask
288
+
289
+
290
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
291
+ """
292
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
293
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
294
+ """
295
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
296
+ if n_rep == 1:
297
+ return hidden_states
298
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
299
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
300
+
301
+
302
+ def eager_attention_forward(
303
+ module: nn.Module,
304
+ query: torch.Tensor,
305
+ key: torch.Tensor,
306
+ value: torch.Tensor,
307
+ attention_mask: Optional[torch.Tensor],
308
+ scaling: float,
309
+ dropout: float = 0.0,
310
+ **kwargs: Unpack[TransformersKwargs],
311
+ ):
312
+ key_states = repeat_kv(key, module.num_key_value_groups)
313
+ value_states = repeat_kv(value, module.num_key_value_groups)
314
+
315
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
316
+ if attention_mask is not None:
317
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
318
+ attn_weights = attn_weights + causal_mask
319
+
320
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
321
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
322
+ attn_output = torch.matmul(attn_weights, value_states)
323
+ attn_output = attn_output.transpose(1, 2).contiguous()
324
+
325
+ return attn_output, attn_weights
326
+
327
+
328
+ def stack_freqs(cos: torch.Tensor, sin: torch.Tensor):
329
+ dim = cos.size(-1)
330
+ cos = cos.narrow(-1, 0, dim // 2)
331
+ sin = sin.narrow(-1, 0, dim // 2)
332
+ freqs_cis = torch.stack((cos, -sin, sin, cos), dim=-1).view(*cos.size(), 2, 2)
333
+ return freqs_cis
334
+
335
+
336
+ def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
337
+ freqs_cis = stack_freqs(cos, sin)
338
+ freqs_cis = freqs_cis.unsqueeze(unsqueeze_dim)
339
+ q_ = q.reshape(*q.shape[:-1], -1, 1, 2)
340
+ k_ = k.reshape(*k.shape[:-1], -1, 1, 2)
341
+ return (q_ * freqs_cis).sum(5).flatten(3), (k_ * freqs_cis).sum(5).flatten(3)
342
+
343
+
344
+ @use_kernel_forward_from_hub("RMSNorm")
345
+ class PeAudioEncoderRMSNorm(nn.Module):
346
+ def __init__(self, hidden_size, eps: float = 1e-6) -> None:
347
+ """
348
+ PeAudioEncoderRMSNorm is equivalent to T5LayerNorm
349
+ """
350
+ super().__init__()
351
+ self.weight = nn.Parameter(torch.ones(hidden_size))
352
+ self.variance_epsilon = eps
353
+
354
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
355
+ input_dtype = hidden_states.dtype
356
+ hidden_states = hidden_states.to(torch.float32)
357
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
358
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
359
+ return self.weight * hidden_states.to(input_dtype)
360
+
361
+ def extra_repr(self):
362
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
363
+
364
+
365
+ @use_kernelized_func(apply_rotary_pos_emb)
366
+ class PeAudioEncoderAttention(nn.Module):
367
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
368
+
369
+ def __init__(self, config, layer_idx):
370
+ super().__init__()
371
+ self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
372
+ self.config = config
373
+ self.layer_idx = layer_idx
374
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
375
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
376
+ self.scaling = self.head_dim**-0.5
377
+ self.attention_dropout = config.attention_dropout
378
+ self.is_causal = False
379
+
380
+ self.q_proj = nn.Linear(
381
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
382
+ )
383
+ self.k_proj = nn.Linear(
384
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
385
+ )
386
+ self.v_proj = nn.Linear(
387
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
388
+ )
389
+ self.o_proj = nn.Linear(
390
+ config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
391
+ )
392
+ self.q_norm = PeAudioEncoderRMSNorm(
393
+ self.head_dim, eps=config.rms_norm_eps
394
+ ) # unlike olmo, only on the head dim!
395
+ self.k_norm = PeAudioEncoderRMSNorm(
396
+ self.head_dim, eps=config.rms_norm_eps
397
+ ) # thus post q_norm does not need reshape
398
+
399
+ def forward(
400
+ self,
401
+ hidden_states: torch.Tensor,
402
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
403
+ attention_mask: Optional[torch.Tensor] = None,
404
+ **kwargs: Unpack[TransformersKwargs],
405
+ ) -> tuple[torch.Tensor, torch.Tensor]:
406
+ input_shape = hidden_states.shape[:-1]
407
+ hidden_shape = (*input_shape, -1, self.head_dim)
408
+
409
+ query_states = self.q_norm(self.q_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
410
+ key_states = self.k_norm(self.k_proj(hidden_states).view(hidden_shape)).transpose(1, 2)
411
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
412
+
413
+ cos, sin = position_embeddings
414
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
415
+
416
+ attention_interface: Callable = eager_attention_forward
417
+ if self.config._attn_implementation != "eager":
418
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
419
+
420
+ attn_output, attn_weights = attention_interface(
421
+ self,
422
+ query_states,
423
+ key_states,
424
+ value_states,
425
+ attention_mask,
426
+ dropout=0.0 if not self.training else self.attention_dropout,
427
+ scaling=self.scaling,
428
+ **kwargs,
429
+ )
430
+
431
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
432
+ attn_output = self.o_proj(attn_output)
433
+ return attn_output, attn_weights
434
+
435
+
436
+ class PeAudioEncoderMLP(nn.Module):
437
+ def __init__(self, config):
438
+ super().__init__()
439
+ self.config = config
440
+ self.hidden_size = config.hidden_size
441
+ self.intermediate_size = config.intermediate_size
442
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
443
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
444
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
445
+ self.act_fn = ACT2FN[config.hidden_act]
446
+
447
+ def forward(self, x):
448
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
449
+ return down_proj
450
+
451
+
452
+ class PeAudioEncoderLayer(GradientCheckpointingLayer):
453
+ def __init__(self, config, layer_idx):
454
+ super().__init__()
455
+ self.hidden_size = config.hidden_size
456
+
457
+ self.self_attn = PeAudioEncoderAttention(config=config, layer_idx=layer_idx)
458
+
459
+ self.mlp = PeAudioEncoderMLP(config)
460
+ self.input_layernorm = PeAudioEncoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
461
+ self.post_attention_layernorm = PeAudioEncoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
462
+
463
+ def forward(
464
+ self,
465
+ hidden_states: torch.Tensor,
466
+ attention_mask: Optional[torch.Tensor] = None,
467
+ position_ids: Optional[torch.LongTensor] = None,
468
+ past_key_values: Optional[Cache] = None,
469
+ use_cache: Optional[bool] = False,
470
+ cache_position: Optional[torch.LongTensor] = None,
471
+ position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
472
+ **kwargs: Unpack[TransformersKwargs],
473
+ ) -> torch.Tensor:
474
+ residual = hidden_states
475
+ hidden_states = self.input_layernorm(hidden_states)
476
+ # Self Attention
477
+ hidden_states, _ = self.self_attn(
478
+ hidden_states=hidden_states,
479
+ attention_mask=attention_mask,
480
+ position_ids=position_ids,
481
+ past_key_values=past_key_values,
482
+ use_cache=use_cache,
483
+ cache_position=cache_position,
484
+ position_embeddings=position_embeddings,
485
+ **kwargs,
486
+ )
487
+ hidden_states = residual + hidden_states
488
+
489
+ # Fully Connected
490
+ residual = hidden_states
491
+ hidden_states = self.post_attention_layernorm(hidden_states)
492
+ hidden_states = self.mlp(hidden_states)
493
+ hidden_states = residual + hidden_states
494
+ return hidden_states
495
+
496
+
497
@auto_docstring
class PeAudioPreTrainedModel(PreTrainedModel):
    """Base class for all PeAudio models: weight initialization and backend support flags."""

    config: PeAudioConfig
    base_model_prefix = "audio_model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["PeAudioEncoderLayer"]
    _skip_keys_device_placement = ["past_key_values"]
    _supports_flash_attn = True
    _supports_sdpa = True
    _supports_flex_attn = True

    _can_compile_fullgraph = True
    _supports_attention_backend = True
    _can_record_outputs = {
        "hidden_states": PeAudioEncoderLayer,
        "attentions": PeAudioEncoderAttention,
    }

    @torch.no_grad()
    def _init_weights(self, module):
        """Initialize weights for PeAudio-specific modules; defers everything else to the base class.

        Args:
            module (`nn.Module`): The submodule being initialized.
        """
        super()._init_weights(module)

        if hasattr(self.config, "initializer_range"):
            std = self.config.initializer_range
        else:
            # 0.02 is the standard default value across the library
            std = getattr(self.config.get_text_config(), "initializer_range", 0.02)

        if isinstance(module, PeAudioEncoderPatchEmbedder):
            # Scale by 1/sqrt(embed_dim), mirroring common CLS-token init schemes.
            embed_dim = module.class_embedding.shape[-1]
            init.normal_(module.class_embedding, mean=0.0, std=embed_dim**-0.5 * std)
        if isinstance(module, nn.Conv1d):
            init.trunc_normal_(module.weight, std=0.02)
            # Conv1d layers may be constructed with bias=False; guard before zeroing.
            if module.bias is not None:
                init.constant_(module.bias, 0)
        elif isinstance(module, Snake1d):
            init.ones_(module.alpha)
        elif isinstance(module, nn.ConvTranspose1d):
            module.reset_parameters()
        elif isinstance(module, nn.Embedding):
            init.normal_(module.weight, mean=0.0, std=0.02)
537
+
538
+
539
@dataclass
@auto_docstring(
    custom_intro="""
    Class for outputs of [`PeAudioEncoder`].
    """
)
class PeAudioEncoderOutput(BaseModelOutputWithPooling):
    # Optional codec-branch features; not populated by `PeAudioEncoder.forward` in this file.
    codec_features: Optional[torch.FloatTensor] = None
    # Padding mask propagated from the embedder (the `padding_mask` returned by
    # `PeAudioEncoder.forward`); NOTE(review): the annotation says tuple but the
    # encoder assigns a tensor here — confirm intended type.
    output_mask: Optional[tuple[torch.FloatTensor]] = None
548
+
549
+
550
class PeAudioEncoderRotaryEmbedding(nn.Module):
    """Rotary position embedding (RoPE) provider for the PeAudio encoder.

    Computes per-position cos/sin tables from inverse frequencies; supports the
    library's alternative RoPE parameterizations via `ROPE_INIT_FUNCTIONS`.
    """

    inv_freq: torch.Tensor  # fix linting for `register_buffer`

    def __init__(self, config: PeAudioEncoderConfig, device=None):
        super().__init__()
        self.max_seq_len_cached = config.max_position_embeddings
        self.original_max_seq_len = config.max_position_embeddings

        self.config = config

        self.rope_type = self.config.rope_parameters["rope_type"]
        rope_init_fn: Callable = self.compute_default_rope_parameters
        if self.rope_type != "default":
            rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
        # Keep the init function on the instance: the `@dynamic_rope_update`
        # decorator on `forward` re-invokes `self.rope_init_fn` to recompute
        # `inv_freq` when the sequence length changes for dynamic RoPE types.
        self.rope_init_fn = rope_init_fn
        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)

        self.register_buffer("inv_freq", inv_freq, persistent=False)
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    @staticmethod
    def compute_default_rope_parameters(
        config: Optional[PeAudioEncoderConfig] = None,
        device: Optional["torch.device"] = None,
        seq_len: Optional[int] = None,
    ) -> tuple["torch.Tensor", float]:
        """
        Computes the inverse frequencies according to the original RoPE implementation
        Args:
            config ([`~transformers.PreTrainedConfig`]):
                The model configuration.
            device (`torch.device`):
                The device to use for initialization of the inverse frequencies.
            seq_len (`int`, *optional*):
                The current sequence length. Unused for this type of RoPE.
        Returns:
            Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
            post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
        """
        base = config.rope_parameters["rope_theta"]
        dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads

        attention_factor = 1.0  # Unused in this type of RoPE

        # Compute the inverse frequencies
        inv_freq = 1.0 / (
            base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
        )
        return inv_freq, attention_factor

    @torch.no_grad()
    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
    def forward(self, x, position_ids):
        """Return (cos, sin) tables of shape (batch, seq_len, dim) in `x`'s dtype."""
        inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1).to(x.device)
        position_ids_expanded = position_ids[:, None, :].float()

        # Force float32 outside of autocast: cos/sin in low precision degrades RoPE quality.
        device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
            freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
            emb = torch.cat((freqs, freqs), dim=-1)
            cos = emb.cos() * self.attention_scaling
            sin = emb.sin() * self.attention_scaling

        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
613
+
614
+
615
@auto_docstring(
    custom_intro="""
    The PeAudio Encoder model.
    """
)
class PeAudioEncoder(PeAudioPreTrainedModel):
    config: PeAudioEncoderConfig
    main_input_name = "input_values"
    base_model_prefix = "audio_model.audio_encoder"

    def __init__(self, config: PeAudioEncoderConfig):
        super().__init__(config)
        # Waveform -> frame embeddings -> patch tokens (with prepended class token).
        self.embedder = PeAudioEncoderEmbedder(config)
        self.patch_embedder = PeAudioEncoderPatchEmbedder(config)
        encoder_layers = [PeAudioEncoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        self.layers = nn.ModuleList(encoder_layers)
        self.norm = PeAudioEncoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.rotary_emb = PeAudioEncoderRotaryEmbedding(config=config)
        self.output = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
        self.gradient_checkpointing = False

        self.post_init()

    @can_return_tuple
    @check_model_inputs
    def forward(
        self,
        input_values: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> BaseModelOutputWithPooling:
        """Encode raw audio into per-patch hidden states plus a pooled class-token output."""
        hidden_states, padding_mask = self.embedder(input_values, padding_mask=padding_mask)
        hidden_states, attention_mask = self.patch_embedder(hidden_states, padding_mask=padding_mask)

        if attention_mask is not None:
            attention_mask = _prepare_4d_attention_mask(attention_mask, hidden_states.dtype)

        seq_len = hidden_states.shape[1]
        position_ids = torch.arange(seq_len, device=hidden_states.device).unsqueeze(0)
        rope_cos_sin = self.rotary_emb(hidden_states, position_ids)

        for layer in self.layers[: self.config.num_hidden_layers]:
            hidden_states = layer(
                hidden_states,
                attention_mask=attention_mask,
                position_embeddings=rope_cos_sin,
                **kwargs,
            )

        hidden_states = self.output(self.norm(hidden_states))

        # Position 0 is the class token: pooled output; the rest are patch states.
        return PeAudioEncoderOutput(
            last_hidden_state=hidden_states[:, 1:],
            pooler_output=hidden_states[:, 0],
            output_mask=padding_mask,
        )
673
+
674
+
675
# TODO: not sure about the typing for text_model_output
@dataclass
# @auto_docstring
class PeAudioOutput(ModelOutput):
    loss: Optional[torch.FloatTensor] = None
    logits_audio_text: Optional[torch.FloatTensor] = None
    text_audio_embeds: Optional[torch.FloatTensor] = None
    audio_embeds: Optional[torch.FloatTensor] = None
    text_outputs: BaseModelOutputWithPooling = None
    audio_outputs: BaseModelOutputWithPooling = None

    def to_tuple(self) -> tuple[Any]:
        """Convert to a tuple, recursively flattening the nested sub-model outputs."""
        nested_keys = ("text_outputs", "audio_outputs")
        values = []
        for key in self.keys():
            if key in nested_keys:
                values.append(getattr(self, key).to_tuple())
            else:
                values.append(self[key])
        return tuple(values)
690
+
691
+
692
class PeAudioModel(PeAudioPreTrainedModel):
    """Audio-text contrastive model: a text backbone paired with the PeAudio encoder,
    trained with a SigLIP-style pairwise sigmoid loss."""

    def __init__(self, config: PeAudioConfig):
        super().__init__(config)
        self.text_model = AutoModel.from_config(config.text_config)
        self.audio_encoder = PeAudioEncoder(config.audio_config)

        # Project both modalities into the shared (text-sized) contrastive space.
        self.text_audio_head = PeAudioContrastiveHead(config.text_config.hidden_size, config.text_config.hidden_size)
        self.audio_head = PeAudioContrastiveHead(config.audio_config.hidden_size, config.text_config.hidden_size)

        self.text_audio_logit_scale = nn.Parameter(torch.zeros(1))
        self.text_audio_logit_bias = nn.Parameter(torch.zeros(1))

        self.post_init()

    def get_text_audio_embeds(self, input_ids, attention_mask=None):
        """Return projected text embeddings (CLS token of the last hidden state)."""
        # TODO: naming can be improved here...
        text_outputs: MaskedLMOutput = self.text_model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            # hidden_states[-1] is read below, so hidden states must be requested
            # explicitly (forward() sets the same flag via kwargs).
            output_hidden_states=True,
            return_dict=True,
        )
        text_audio_embeds = text_outputs.hidden_states[-1][:, 0]
        return self.text_audio_head(text_audio_embeds)

    def get_audio_embeds(self, input_values, padding_mask=None):
        """Return projected pooled audio embeddings."""
        audio_outputs: BaseModelOutputWithPooling = self.audio_encoder(
            input_values=input_values,
            padding_mask=padding_mask,
            return_dict=True,
        )
        audio_embeds = audio_outputs.pooler_output
        return self.audio_head(audio_embeds)

    @can_return_tuple
    def forward(
        self,
        input_ids: torch.Tensor,
        input_values: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        padding_mask: Optional[torch.Tensor] = None,
        return_loss: Optional[bool] = None,
        **kwargs,
    ) -> PeAudioOutput:
        """Compute audio-text similarity logits and, optionally, the sigmoid contrastive loss."""
        audio_outputs: BaseModelOutputWithPooling = self.audio_encoder(
            input_values=input_values, padding_mask=padding_mask, **kwargs
        )

        kwargs["output_hidden_states"] = True
        text_outputs: MaskedLMOutput = self.text_model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)

        audio_embeds = audio_outputs.pooler_output
        audio_embeds = self.audio_head(audio_embeds)

        text_audio_embeds = text_outputs.hidden_states[-1][:, 0]
        text_audio_embeds = self.text_audio_head(text_audio_embeds)

        # (n_audio, n_text) similarity logits, with learned scale and bias.
        logits_audio_text = audio_embeds @ text_audio_embeds.T
        logits_audio_text = logits_audio_text * self.text_audio_logit_scale + self.text_audio_logit_bias

        loss = None
        if return_loss:
            # SigLIP-style labels: +1 on matched (diagonal) pairs, -1 on mismatched
            # pairs. A plain identity matrix would feed 0 * logits into logsigmoid
            # for every negative pair — a constant term with zero gradient.
            eye = torch.eye(logits_audio_text.shape[0], device=logits_audio_text.device)
            labels = 2.0 * eye - 1.0
            loss = -F.logsigmoid(labels * logits_audio_text).sum() / logits_audio_text.shape[0]

        return PeAudioOutput(
            logits_audio_text=logits_audio_text,
            text_audio_embeds=text_audio_embeds,
            audio_embeds=audio_embeds,
            text_outputs=text_outputs,
            audio_outputs=audio_outputs,
            loss=loss,
        )
764
+
765
+
766
# TODO: underline in documentation that logits output shape is
# 1. Model: (n_audio, n_text)
# 2. Frame-level: (n_audio, n_text, n_frames)
class PeAudioFrameLevelModel(PeAudioModel):
    """Variant of `PeAudioModel` that scores text against every audio frame
    (per-patch hidden states) instead of the pooled audio embedding."""

    def get_audio_embeds(self, input_values, padding_mask=None):
        """Return projected per-frame audio embeddings of shape (n_audio, n_frames, dim)."""
        audio_outputs: BaseModelOutputWithPooling = self.audio_encoder(
            input_values=input_values,
            padding_mask=padding_mask,
            return_dict=True,
        )
        audio_embeds = audio_outputs.last_hidden_state
        audio_embeds = self.audio_head(audio_embeds)
        return audio_embeds

    @can_return_tuple
    def forward(
        self,
        input_ids: torch.Tensor,
        input_values: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        padding_mask: Optional[torch.Tensor] = None,
        return_loss: Optional[bool] = None,
        **kwargs,
    ) -> PeAudioOutput:
        """Compute frame-level audio-text logits of shape (n_audio, n_text, n_frames)."""
        audio_outputs: BaseModelOutputWithPooling = self.audio_encoder(
            input_values=input_values, padding_mask=padding_mask, **kwargs
        )
        kwargs["output_hidden_states"] = True
        text_outputs: MaskedLMOutput = self.text_model(input_ids=input_ids, attention_mask=attention_mask, **kwargs)

        audio_embeds = audio_outputs.last_hidden_state
        audio_embeds = self.audio_head(audio_embeds)

        text_audio_embeds = text_outputs.hidden_states[-1][:, 0]
        text_audio_embeds = self.text_audio_head(text_audio_embeds)

        # (n_audio, n_frames, n_text) -> (n_audio, n_text, n_frames)
        logits_audio_text = (audio_embeds @ text_audio_embeds.T).transpose(1, 2)
        logits_audio_text = logits_audio_text * self.text_audio_logit_scale + self.text_audio_logit_bias

        loss = None
        if return_loss:
            # SigLIP-style labels: +1 for matched pairs, -1 for mismatched pairs
            # (a plain identity matrix gives negatives zero gradient through
            # logsigmoid). NOTE(review): labels are (n, n) while the frame-level
            # logits are 3-D (n_audio, n_text, n_frames) — broadcasting only
            # works for particular shapes; verify the intended frame-level loss.
            eye = torch.eye(logits_audio_text.shape[0], device=logits_audio_text.device)
            labels = 2.0 * eye - 1.0
            loss = -F.logsigmoid(labels * logits_audio_text).sum() / logits_audio_text.shape[0]

        return PeAudioOutput(
            logits_audio_text=logits_audio_text,
            text_audio_embeds=text_audio_embeds,
            audio_embeds=audio_embeds,
            text_outputs=text_outputs,
            audio_outputs=audio_outputs,
            loss=loss,
        )
818
+
819
+
820
# Public API of this module (picked up by the library's lazy-import machinery).
__all__ = ["PeAudioFrameLevelModel", "PeAudioModel", "PeAudioEncoder"]