transformers 5.0.0rc1-py3-none-any.whl → 5.0.0rc2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries, and is provided for informational purposes only.
Files changed (671)
  1. transformers/__init__.py +20 -1
  2. transformers/activations.py +1 -1
  3. transformers/audio_utils.py +0 -1
  4. transformers/cache_utils.py +17 -15
  5. transformers/configuration_utils.py +114 -70
  6. transformers/conversion_mapping.py +68 -5
  7. transformers/core_model_loading.py +201 -35
  8. transformers/dependency_versions_table.py +1 -1
  9. transformers/feature_extraction_utils.py +54 -22
  10. transformers/generation/candidate_generator.py +79 -31
  11. transformers/generation/configuration_utils.py +162 -122
  12. transformers/generation/continuous_batching/cache.py +47 -18
  13. transformers/generation/continuous_batching/cache_manager.py +131 -34
  14. transformers/generation/continuous_batching/continuous_api.py +101 -64
  15. transformers/generation/continuous_batching/requests.py +28 -1
  16. transformers/generation/continuous_batching/scheduler.py +11 -4
  17. transformers/generation/stopping_criteria.py +1 -1
  18. transformers/generation/utils.py +108 -110
  19. transformers/generation/watermarking.py +8 -5
  20. transformers/image_processing_base.py +2 -12
  21. transformers/image_processing_utils_fast.py +15 -4
  22. transformers/initialization.py +37 -0
  23. transformers/integrations/__init__.py +12 -0
  24. transformers/integrations/accelerate.py +44 -111
  25. transformers/integrations/aqlm.py +3 -5
  26. transformers/integrations/awq.py +2 -5
  27. transformers/integrations/bitnet.py +5 -8
  28. transformers/integrations/bitsandbytes.py +16 -15
  29. transformers/integrations/deepspeed.py +18 -3
  30. transformers/integrations/eetq.py +3 -5
  31. transformers/integrations/fbgemm_fp8.py +1 -1
  32. transformers/integrations/finegrained_fp8.py +6 -16
  33. transformers/integrations/flash_attention.py +2 -2
  34. transformers/integrations/higgs.py +2 -5
  35. transformers/integrations/hub_kernels.py +23 -5
  36. transformers/integrations/integration_utils.py +35 -0
  37. transformers/integrations/mistral.py +12 -0
  38. transformers/integrations/moe.py +240 -0
  39. transformers/integrations/mxfp4.py +4 -10
  40. transformers/integrations/peft.py +5 -0
  41. transformers/integrations/quanto.py +5 -2
  42. transformers/integrations/spqr.py +3 -5
  43. transformers/integrations/tensor_parallel.py +167 -221
  44. transformers/integrations/vptq.py +3 -5
  45. transformers/modeling_gguf_pytorch_utils.py +66 -19
  46. transformers/modeling_rope_utils.py +78 -81
  47. transformers/modeling_utils.py +583 -503
  48. transformers/models/__init__.py +19 -0
  49. transformers/models/afmoe/modeling_afmoe.py +7 -16
  50. transformers/models/afmoe/modular_afmoe.py +5 -13
  51. transformers/models/aimv2/modeling_aimv2.py +4 -0
  52. transformers/models/aimv2/modular_aimv2.py +4 -0
  53. transformers/models/albert/modeling_albert.py +3 -0
  54. transformers/models/align/modeling_align.py +12 -6
  55. transformers/models/altclip/modeling_altclip.py +7 -3
  56. transformers/models/apertus/modeling_apertus.py +4 -2
  57. transformers/models/apertus/modular_apertus.py +4 -1
  58. transformers/models/arcee/modeling_arcee.py +1 -1
  59. transformers/models/aria/modeling_aria.py +8 -4
  60. transformers/models/aria/modular_aria.py +7 -3
  61. transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
  62. transformers/models/auto/auto_factory.py +1 -1
  63. transformers/models/auto/configuration_auto.py +27 -0
  64. transformers/models/auto/feature_extraction_auto.py +7 -3
  65. transformers/models/auto/image_processing_auto.py +4 -2
  66. transformers/models/auto/modeling_auto.py +31 -0
  67. transformers/models/auto/processing_auto.py +4 -0
  68. transformers/models/auto/tokenization_auto.py +132 -153
  69. transformers/models/auto/video_processing_auto.py +5 -2
  70. transformers/models/aya_vision/modeling_aya_vision.py +7 -3
  71. transformers/models/bamba/modeling_bamba.py +18 -19
  72. transformers/models/bamba/modular_bamba.py +17 -16
  73. transformers/models/bark/modeling_bark.py +9 -0
  74. transformers/models/bart/configuration_bart.py +0 -1
  75. transformers/models/bart/modeling_bart.py +7 -0
  76. transformers/models/beit/image_processing_beit_fast.py +0 -1
  77. transformers/models/bert/modeling_bert.py +3 -0
  78. transformers/models/bert_generation/modeling_bert_generation.py +2 -0
  79. transformers/models/big_bird/modeling_big_bird.py +3 -0
  80. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +7 -0
  81. transformers/models/bit/modeling_bit.py +5 -1
  82. transformers/models/bitnet/modeling_bitnet.py +1 -1
  83. transformers/models/blenderbot/modeling_blenderbot.py +7 -0
  84. transformers/models/blenderbot/tokenization_blenderbot.py +6 -7
  85. transformers/models/blenderbot_small/modeling_blenderbot_small.py +7 -0
  86. transformers/models/blip/modeling_blip.py +2 -0
  87. transformers/models/blip/modeling_blip_text.py +8 -0
  88. transformers/models/blip_2/modeling_blip_2.py +2 -0
  89. transformers/models/bloom/modeling_bloom.py +13 -44
  90. transformers/models/blt/modeling_blt.py +162 -2
  91. transformers/models/blt/modular_blt.py +168 -3
  92. transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
  93. transformers/models/bridgetower/modeling_bridgetower.py +6 -0
  94. transformers/models/bros/modeling_bros.py +8 -0
  95. transformers/models/camembert/modeling_camembert.py +109 -106
  96. transformers/models/canine/modeling_canine.py +6 -0
  97. transformers/models/canine/tokenization_canine.py +2 -0
  98. transformers/models/chameleon/modeling_chameleon.py +9 -4
  99. transformers/models/chinese_clip/modeling_chinese_clip.py +6 -3
  100. transformers/models/clap/feature_extraction_clap.py +2 -2
  101. transformers/models/clap/modeling_clap.py +25 -15
  102. transformers/models/clip/modeling_clip.py +2 -0
  103. transformers/models/clipseg/modeling_clipseg.py +4 -0
  104. transformers/models/clvp/modeling_clvp.py +14 -3
  105. transformers/models/code_llama/tokenization_code_llama.py +1 -1
  106. transformers/models/codegen/modeling_codegen.py +13 -4
  107. transformers/models/cohere/modeling_cohere.py +1 -1
  108. transformers/models/cohere2/modeling_cohere2.py +1 -1
  109. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +0 -1
  110. transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
  111. transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
  112. transformers/models/conditional_detr/modeling_conditional_detr.py +4 -1
  113. transformers/models/convbert/modeling_convbert.py +3 -0
  114. transformers/models/convnext/image_processing_convnext.py +2 -2
  115. transformers/models/convnext/image_processing_convnext_fast.py +9 -13
  116. transformers/models/csm/generation_csm.py +19 -22
  117. transformers/models/csm/modeling_csm.py +3 -1
  118. transformers/models/csm/modular_csm.py +2 -0
  119. transformers/models/ctrl/modeling_ctrl.py +14 -2
  120. transformers/models/cvt/modeling_cvt.py +5 -1
  121. transformers/models/cwm/modeling_cwm.py +1 -1
  122. transformers/models/d_fine/configuration_d_fine.py +3 -4
  123. transformers/models/d_fine/modeling_d_fine.py +46 -39
  124. transformers/models/d_fine/modular_d_fine.py +15 -4
  125. transformers/models/dab_detr/configuration_dab_detr.py +2 -2
  126. transformers/models/dab_detr/modeling_dab_detr.py +1 -1
  127. transformers/models/dac/modeling_dac.py +4 -4
  128. transformers/models/data2vec/modeling_data2vec_text.py +7 -0
  129. transformers/models/data2vec/modular_data2vec_text.py +7 -0
  130. transformers/models/dbrx/configuration_dbrx.py +9 -1
  131. transformers/models/dbrx/modeling_dbrx.py +1 -1
  132. transformers/models/deberta/modeling_deberta.py +2 -0
  133. transformers/models/deberta_v2/modeling_deberta_v2.py +2 -0
  134. transformers/models/decision_transformer/modeling_decision_transformer.py +8 -5
  135. transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -4
  136. transformers/models/deepseek_v2/modular_deepseek_v2.py +4 -2
  137. transformers/models/deepseek_v3/modeling_deepseek_v3.py +9 -5
  138. transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -2
  139. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
  140. transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
  141. transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
  142. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
  143. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
  144. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
  145. transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
  146. transformers/models/deformable_detr/modeling_deformable_detr.py +1 -1
  147. transformers/models/depth_anything/configuration_depth_anything.py +2 -3
  148. transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
  149. transformers/models/detr/configuration_detr.py +1 -1
  150. transformers/models/detr/modeling_detr.py +8 -1
  151. transformers/models/dia/generation_dia.py +3 -10
  152. transformers/models/dia/modeling_dia.py +12 -1
  153. transformers/models/dia/modular_dia.py +11 -0
  154. transformers/models/dia/processing_dia.py +1 -1
  155. transformers/models/diffllama/modeling_diffllama.py +3 -3
  156. transformers/models/diffllama/modular_diffllama.py +2 -2
  157. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
  158. transformers/models/dinov3_vit/modeling_dinov3_vit.py +3 -0
  159. transformers/models/dinov3_vit/modular_dinov3_vit.py +3 -0
  160. transformers/models/distilbert/modeling_distilbert.py +11 -9
  161. transformers/models/doge/modeling_doge.py +1 -1
  162. transformers/models/donut/image_processing_donut_fast.py +0 -1
  163. transformers/models/donut/modeling_donut_swin.py +16 -12
  164. transformers/models/dots1/modeling_dots1.py +14 -5
  165. transformers/models/dpt/configuration_dpt.py +1 -1
  166. transformers/models/dpt/image_processing_dpt_fast.py +1 -2
  167. transformers/models/dpt/modular_dpt.py +1 -2
  168. transformers/models/edgetam/configuration_edgetam.py +1 -1
  169. transformers/models/edgetam/modeling_edgetam.py +5 -2
  170. transformers/models/edgetam/modular_edgetam.py +15 -14
  171. transformers/models/edgetam_video/modeling_edgetam_video.py +55 -43
  172. transformers/models/edgetam_video/modular_edgetam_video.py +13 -19
  173. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
  174. transformers/models/efficientloftr/modeling_efficientloftr.py +14 -1
  175. transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
  176. transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
  177. transformers/models/efficientnet/modeling_efficientnet.py +5 -1
  178. transformers/models/electra/modeling_electra.py +7 -0
  179. transformers/models/emu3/modeling_emu3.py +8 -2
  180. transformers/models/emu3/modular_emu3.py +7 -1
  181. transformers/models/encodec/modeling_encodec.py +14 -0
  182. transformers/models/eomt/image_processing_eomt_fast.py +46 -14
  183. transformers/models/eomt/modeling_eomt.py +7 -0
  184. transformers/models/eomt/modular_eomt.py +7 -0
  185. transformers/models/ernie/modeling_ernie.py +6 -0
  186. transformers/models/ernie/modular_ernie.py +6 -0
  187. transformers/models/ernie4_5/modeling_ernie4_5.py +1 -1
  188. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +16 -13
  189. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +9 -35
  190. transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
  191. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
  192. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
  193. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
  194. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
  195. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
  196. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
  197. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
  198. transformers/models/esm/modeling_esm.py +6 -0
  199. transformers/models/esm/modeling_esmfold.py +6 -1
  200. transformers/models/evolla/modeling_evolla.py +9 -1
  201. transformers/models/evolla/modular_evolla.py +8 -0
  202. transformers/models/exaone4/modeling_exaone4.py +1 -1
  203. transformers/models/falcon/modeling_falcon.py +3 -3
  204. transformers/models/falcon_h1/modeling_falcon_h1.py +28 -23
  205. transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
  206. transformers/models/falcon_mamba/modeling_falcon_mamba.py +6 -2
  207. transformers/models/falcon_mamba/modular_falcon_mamba.py +7 -2
  208. transformers/models/fast_vlm/modeling_fast_vlm.py +7 -3
  209. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +23 -10
  210. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
  211. transformers/models/flaubert/modeling_flaubert.py +14 -15
  212. transformers/models/flava/image_processing_flava_fast.py +0 -2
  213. transformers/models/flava/modeling_flava.py +4 -1
  214. transformers/models/flex_olmo/modeling_flex_olmo.py +7 -4
  215. transformers/models/florence2/modeling_florence2.py +20 -3
  216. transformers/models/florence2/modular_florence2.py +13 -0
  217. transformers/models/fnet/modeling_fnet.py +7 -0
  218. transformers/models/fuyu/image_processing_fuyu.py +1 -1
  219. transformers/models/fuyu/modeling_fuyu.py +3 -1
  220. transformers/models/fuyu/processing_fuyu.py +16 -0
  221. transformers/models/gemma/modeling_gemma.py +10 -12
  222. transformers/models/gemma/modular_gemma.py +9 -11
  223. transformers/models/gemma2/modeling_gemma2.py +1 -1
  224. transformers/models/gemma2/modular_gemma2.py +1 -1
  225. transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
  226. transformers/models/gemma3/modeling_gemma3.py +28 -7
  227. transformers/models/gemma3/modular_gemma3.py +26 -6
  228. transformers/models/gemma3n/configuration_gemma3n.py +3 -0
  229. transformers/models/gemma3n/modeling_gemma3n.py +47 -9
  230. transformers/models/gemma3n/modular_gemma3n.py +51 -9
  231. transformers/models/git/modeling_git.py +181 -126
  232. transformers/models/glm/modeling_glm.py +1 -1
  233. transformers/models/glm4/modeling_glm4.py +1 -1
  234. transformers/models/glm46v/image_processing_glm46v.py +0 -4
  235. transformers/models/glm46v/modeling_glm46v.py +3 -1
  236. transformers/models/glm46v/modular_glm46v.py +3 -0
  237. transformers/models/glm4_moe/modeling_glm4_moe.py +9 -5
  238. transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
  239. transformers/models/glm4v/image_processing_glm4v.py +0 -4
  240. transformers/models/glm4v/modeling_glm4v.py +15 -5
  241. transformers/models/glm4v/modular_glm4v.py +11 -3
  242. transformers/models/glm4v_moe/modeling_glm4v_moe.py +39 -23
  243. transformers/models/glm4v_moe/modular_glm4v_moe.py +12 -0
  244. transformers/models/glmasr/__init__.py +30 -0
  245. transformers/models/glmasr/configuration_glmasr.py +197 -0
  246. transformers/models/glmasr/modeling_glmasr.py +512 -0
  247. transformers/models/glmasr/modular_glmasr.py +433 -0
  248. transformers/models/glmasr/processing_glmasr.py +332 -0
  249. transformers/models/glpn/image_processing_glpn_fast.py +0 -1
  250. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
  251. transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
  252. transformers/models/gpt2/modeling_gpt2.py +8 -5
  253. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +3 -8
  254. transformers/models/gpt_neo/modeling_gpt_neo.py +15 -3
  255. transformers/models/gpt_neox/modeling_gpt_neox.py +1 -1
  256. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +1 -1
  257. transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
  258. transformers/models/gpt_oss/modeling_gpt_oss.py +6 -9
  259. transformers/models/gpt_oss/modular_gpt_oss.py +5 -7
  260. transformers/models/gptj/modeling_gptj.py +15 -6
  261. transformers/models/granite/modeling_granite.py +1 -1
  262. transformers/models/granite_speech/modeling_granite_speech.py +15 -1
  263. transformers/models/granitemoe/modeling_granitemoe.py +2 -3
  264. transformers/models/granitemoe/modular_granitemoe.py +1 -2
  265. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
  266. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +33 -23
  267. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
  268. transformers/models/granitemoeshared/modeling_granitemoeshared.py +2 -3
  269. transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
  270. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -4
  271. transformers/models/groupvit/modeling_groupvit.py +6 -1
  272. transformers/models/helium/modeling_helium.py +1 -1
  273. transformers/models/hgnet_v2/modeling_hgnet_v2.py +10 -0
  274. transformers/models/hgnet_v2/modular_hgnet_v2.py +10 -0
  275. transformers/models/hubert/modeling_hubert.py +4 -0
  276. transformers/models/hubert/modular_hubert.py +4 -0
  277. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +1 -1
  278. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
  279. transformers/models/hunyuan_v1_moe/__init__.py +1 -1
  280. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +12 -4
  281. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
  282. transformers/models/ibert/modeling_ibert.py +16 -0
  283. transformers/models/idefics/modeling_idefics.py +10 -0
  284. transformers/models/idefics2/modeling_idefics2.py +7 -1
  285. transformers/models/idefics3/modeling_idefics3.py +5 -1
  286. transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
  287. transformers/models/imagegpt/modeling_imagegpt.py +9 -2
  288. transformers/models/instructblip/modeling_instructblip.py +2 -0
  289. transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
  290. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
  291. transformers/models/internvl/modeling_internvl.py +11 -8
  292. transformers/models/internvl/modular_internvl.py +5 -9
  293. transformers/models/internvl/video_processing_internvl.py +0 -1
  294. transformers/models/jais2/__init__.py +27 -0
  295. transformers/models/jais2/configuration_jais2.py +152 -0
  296. transformers/models/jais2/modeling_jais2.py +486 -0
  297. transformers/models/jais2/modular_jais2.py +196 -0
  298. transformers/models/jamba/modeling_jamba.py +24 -19
  299. transformers/models/jamba/modular_jamba.py +17 -17
  300. transformers/models/janus/image_processing_janus_fast.py +0 -1
  301. transformers/models/janus/modeling_janus.py +15 -7
  302. transformers/models/janus/modular_janus.py +16 -7
  303. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  304. transformers/models/jetmoe/modular_jetmoe.py +1 -0
  305. transformers/models/kosmos2/modeling_kosmos2.py +14 -2
  306. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
  307. transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
  308. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +9 -3
  309. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
  310. transformers/models/lasr/configuration_lasr.py +4 -0
  311. transformers/models/lasr/modeling_lasr.py +3 -2
  312. transformers/models/lasr/modular_lasr.py +8 -1
  313. transformers/models/lasr/processing_lasr.py +0 -2
  314. transformers/models/layoutlm/modeling_layoutlm.py +5 -3
  315. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
  316. transformers/models/layoutlmv2/modeling_layoutlmv2.py +12 -0
  317. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +1 -0
  318. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
  319. transformers/models/layoutlmv3/modeling_layoutlmv3.py +29 -5
  320. transformers/models/led/modeling_led.py +6 -0
  321. transformers/models/levit/modeling_levit.py +18 -0
  322. transformers/models/lfm2/modeling_lfm2.py +1 -1
  323. transformers/models/lfm2_moe/modeling_lfm2_moe.py +14 -4
  324. transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
  325. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  326. transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
  327. transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
  328. transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
  329. transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
  330. transformers/models/lilt/modeling_lilt.py +19 -15
  331. transformers/models/llama/modeling_llama.py +1 -1
  332. transformers/models/llama4/image_processing_llama4_fast.py +1 -2
  333. transformers/models/llama4/modeling_llama4.py +8 -4
  334. transformers/models/llava/image_processing_llava_fast.py +0 -1
  335. transformers/models/llava/modeling_llava.py +12 -7
  336. transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
  337. transformers/models/llava_next/modeling_llava_next.py +7 -3
  338. transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
  339. transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
  340. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
  341. transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
  342. transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
  343. transformers/models/longcat_flash/modeling_longcat_flash.py +2 -1
  344. transformers/models/longcat_flash/modular_longcat_flash.py +1 -0
  345. transformers/models/longt5/modeling_longt5.py +0 -4
  346. transformers/models/m2m_100/modeling_m2m_100.py +10 -0
  347. transformers/models/mamba/modeling_mamba.py +2 -1
  348. transformers/models/mamba2/modeling_mamba2.py +24 -23
  349. transformers/models/marian/configuration_marian.py +1 -1
  350. transformers/models/marian/modeling_marian.py +3 -0
  351. transformers/models/markuplm/modeling_markuplm.py +5 -8
  352. transformers/models/mask2former/configuration_mask2former.py +3 -3
  353. transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
  354. transformers/models/mask2former/modeling_mask2former.py +9 -0
  355. transformers/models/maskformer/configuration_maskformer.py +3 -3
  356. transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
  357. transformers/models/maskformer/modeling_maskformer.py +9 -1
  358. transformers/models/maskformer/modeling_maskformer_swin.py +19 -15
  359. transformers/models/mbart/configuration_mbart.py +1 -0
  360. transformers/models/mbart/modeling_mbart.py +7 -0
  361. transformers/models/megatron_bert/modeling_megatron_bert.py +2 -0
  362. transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
  363. transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
  364. transformers/models/mimi/modeling_mimi.py +25 -4
  365. transformers/models/minimax/modeling_minimax.py +16 -3
  366. transformers/models/minimax/modular_minimax.py +12 -1
  367. transformers/models/ministral/modeling_ministral.py +1 -1
  368. transformers/models/ministral3/modeling_ministral3.py +1 -1
  369. transformers/models/mistral/modeling_mistral.py +1 -1
  370. transformers/models/mistral3/modeling_mistral3.py +10 -4
  371. transformers/models/mistral3/modular_mistral3.py +3 -1
  372. transformers/models/mixtral/modeling_mixtral.py +12 -4
  373. transformers/models/mixtral/modular_mixtral.py +6 -2
  374. transformers/models/mlcd/modeling_mlcd.py +6 -0
  375. transformers/models/mlcd/modular_mlcd.py +4 -0
  376. transformers/models/mllama/modeling_mllama.py +13 -2
  377. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
  378. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -4
  379. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
  380. transformers/models/mobilebert/modeling_mobilebert.py +2 -0
  381. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
  382. transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
  383. transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
  384. transformers/models/mobilevit/modeling_mobilevit.py +4 -0
  385. transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -0
  386. transformers/models/modernbert/modeling_modernbert.py +12 -1
  387. transformers/models/modernbert/modular_modernbert.py +12 -1
  388. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -1
  389. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +9 -1
  390. transformers/models/moonshine/modeling_moonshine.py +1 -1
  391. transformers/models/moshi/modeling_moshi.py +21 -51
  392. transformers/models/mpnet/modeling_mpnet.py +2 -0
  393. transformers/models/mra/modeling_mra.py +4 -1
  394. transformers/models/mt5/configuration_mt5.py +2 -3
  395. transformers/models/mt5/modeling_mt5.py +0 -10
  396. transformers/models/musicgen/modeling_musicgen.py +5 -9
  397. transformers/models/musicgen_melody/modeling_musicgen_melody.py +4 -0
  398. transformers/models/mvp/modeling_mvp.py +7 -0
  399. transformers/models/nanochat/modeling_nanochat.py +1 -1
  400. transformers/models/nemotron/modeling_nemotron.py +3 -3
  401. transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
  402. transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
  403. transformers/models/nougat/image_processing_nougat_fast.py +0 -1
  404. transformers/models/nougat/tokenization_nougat.py +11 -16
  405. transformers/models/nystromformer/modeling_nystromformer.py +7 -0
  406. transformers/models/olmo/modeling_olmo.py +1 -1
  407. transformers/models/olmo2/modeling_olmo2.py +1 -1
  408. transformers/models/olmo3/modeling_olmo3.py +1 -1
  409. transformers/models/olmoe/modeling_olmoe.py +12 -4
  410. transformers/models/olmoe/modular_olmoe.py +4 -2
  411. transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
  412. transformers/models/omdet_turbo/modeling_omdet_turbo.py +4 -0
  413. transformers/models/oneformer/configuration_oneformer.py +3 -3
  414. transformers/models/oneformer/modeling_oneformer.py +7 -38
  415. transformers/models/openai/modeling_openai.py +12 -0
  416. transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
  417. transformers/models/ovis2/modeling_ovis2.py +15 -3
  418. transformers/models/ovis2/modular_ovis2.py +8 -0
  419. transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
  420. transformers/models/owlv2/modeling_owlv2.py +7 -3
  421. transformers/models/owlv2/modular_owlv2.py +0 -2
  422. transformers/models/owlvit/modeling_owlvit.py +7 -3
  423. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +3 -2
  424. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +28 -14
  425. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +22 -12
  426. transformers/models/paligemma/modeling_paligemma.py +25 -17
  427. transformers/models/parakeet/modeling_parakeet.py +5 -0
  428. transformers/models/parakeet/modular_parakeet.py +5 -0
  429. transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
  430. transformers/models/patchtsmixer/modeling_patchtsmixer.py +4 -0
  431. transformers/models/patchtst/modeling_patchtst.py +5 -4
  432. transformers/models/pe_audio/__init__.py +30 -0
  433. transformers/models/pe_audio/configuration_pe_audio.py +206 -0
  434. transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
  435. transformers/models/pe_audio/modeling_pe_audio.py +820 -0
  436. transformers/models/pe_audio/modular_pe_audio.py +299 -0
  437. transformers/models/pe_audio/processing_pe_audio.py +24 -0
  438. transformers/models/pe_audio_video/__init__.py +29 -0
  439. transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
  440. transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
  441. transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
  442. transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
  443. transformers/models/pe_video/__init__.py +30 -0
  444. transformers/models/pe_video/configuration_pe_video.py +211 -0
  445. transformers/models/pe_video/modeling_pe_video.py +636 -0
  446. transformers/models/pe_video/modular_pe_video.py +219 -0
  447. transformers/models/pe_video/processing_pe_video.py +10 -0
  448. transformers/models/pe_video/video_processing_pe_video.py +66 -0
  449. transformers/models/pegasus/configuration_pegasus.py +1 -0
  450. transformers/models/pegasus/modeling_pegasus.py +3 -0
  451. transformers/models/pegasus_x/modeling_pegasus_x.py +1 -0
  452. transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
  453. transformers/models/perceiver/modeling_perceiver.py +5 -1
  454. transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
  455. transformers/models/perception_lm/modeling_perception_lm.py +7 -3
  456. transformers/models/perception_lm/modular_perception_lm.py +7 -3
  457. transformers/models/persimmon/modeling_persimmon.py +1 -1
  458. transformers/models/phi/modeling_phi.py +1 -1
  459. transformers/models/phi3/modeling_phi3.py +1 -1
  460. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +4 -1
  461. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +3 -0
  462. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
  463. transformers/models/phimoe/modeling_phimoe.py +12 -4
  464. transformers/models/phimoe/modular_phimoe.py +1 -1
  465. transformers/models/pix2struct/processing_pix2struct.py +0 -4
  466. transformers/models/pixio/__init__.py +30 -0
  467. transformers/models/pixio/configuration_pixio.py +151 -0
  468. transformers/models/pixio/modeling_pixio.py +507 -0
  469. transformers/models/pixio/modular_pixio.py +404 -0
  470. transformers/models/pixtral/modeling_pixtral.py +1 -1
  471. transformers/models/pixtral/processing_pixtral.py +3 -1
  472. transformers/models/plbart/configuration_plbart.py +1 -0
  473. transformers/models/plbart/modeling_plbart.py +7 -0
  474. transformers/models/plbart/modular_plbart.py +6 -0
  475. transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
  476. transformers/models/poolformer/modeling_poolformer.py +11 -1
  477. transformers/models/pop2piano/configuration_pop2piano.py +0 -1
  478. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
  479. transformers/models/prophetnet/modeling_prophetnet.py +2 -1
  480. transformers/models/qwen2/modeling_qwen2.py +1 -1
  481. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +104 -64
  482. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +58 -18
  483. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +18 -5
  484. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +26 -22
  485. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -2
  486. transformers/models/qwen2_moe/modeling_qwen2_moe.py +12 -4
  487. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
  488. transformers/models/qwen2_vl/modeling_qwen2_vl.py +17 -4
  489. transformers/models/qwen3/modeling_qwen3.py +1 -1
  490. transformers/models/qwen3_moe/modeling_qwen3_moe.py +12 -4
  491. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -6
  492. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
  493. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +92 -46
  494. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +48 -4
  495. transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
  496. transformers/models/qwen3_vl/modeling_qwen3_vl.py +17 -4
  497. transformers/models/qwen3_vl/modular_qwen3_vl.py +21 -10
  498. transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
  499. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +94 -112
  500. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +32 -81
  501. transformers/models/rag/configuration_rag.py +0 -8
  502. transformers/models/rag/modeling_rag.py +7 -9
  503. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +3 -2
  504. transformers/models/reformer/modeling_reformer.py +9 -1
  505. transformers/models/regnet/modeling_regnet.py +4 -0
  506. transformers/models/rembert/modeling_rembert.py +7 -1
  507. transformers/models/resnet/modeling_resnet.py +8 -3
  508. transformers/models/roberta/modeling_roberta.py +3 -0
  509. transformers/models/roberta/modular_roberta.py +3 -0
  510. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
  511. transformers/models/roc_bert/modeling_roc_bert.py +3 -0
  512. transformers/models/rt_detr/configuration_rt_detr.py +1 -1
  513. transformers/models/rt_detr/modeling_rt_detr.py +4 -0
  514. transformers/models/rt_detr/modeling_rt_detr_resnet.py +8 -3
  515. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
  516. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +7 -0
  517. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
  518. transformers/models/rwkv/modeling_rwkv.py +1 -1
  519. transformers/models/sam/configuration_sam.py +1 -0
  520. transformers/models/sam/image_processing_sam_fast.py +0 -1
  521. transformers/models/sam/modeling_sam.py +4 -1
  522. transformers/models/sam2/configuration_sam2.py +1 -1
  523. transformers/models/sam2/modeling_sam2.py +5 -1
  524. transformers/models/sam2/modular_sam2.py +5 -1
  525. transformers/models/sam2_video/modeling_sam2_video.py +51 -43
  526. transformers/models/sam2_video/modular_sam2_video.py +31 -18
  527. transformers/models/sam3/configuration_sam3.py +21 -1
  528. transformers/models/sam3/modeling_sam3.py +23 -0
  529. transformers/models/sam3_tracker/modeling_sam3_tracker.py +2 -0
  530. transformers/models/sam3_tracker/modular_sam3_tracker.py +2 -0
  531. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
  532. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +26 -15
  533. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
  534. transformers/models/sam3_video/configuration_sam3_video.py +14 -0
  535. transformers/models/sam3_video/modeling_sam3_video.py +3 -3
  536. transformers/models/sam3_video/processing_sam3_video.py +1 -1
  537. transformers/models/sam_hq/configuration_sam_hq.py +1 -0
  538. transformers/models/sam_hq/modeling_sam_hq.py +26 -23
  539. transformers/models/seamless_m4t/modeling_seamless_m4t.py +27 -11
  540. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +6 -0
  541. transformers/models/seed_oss/modeling_seed_oss.py +1 -1
  542. transformers/models/segformer/image_processing_segformer_fast.py +0 -1
  543. transformers/models/segformer/modeling_segformer.py +2 -2
  544. transformers/models/segformer/modular_segformer.py +0 -1
  545. transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
  546. transformers/models/siglip/modeling_siglip.py +24 -2
  547. transformers/models/siglip2/modeling_siglip2.py +63 -41
  548. transformers/models/smollm3/modeling_smollm3.py +1 -1
  549. transformers/models/smolvlm/modeling_smolvlm.py +5 -1
  550. transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
  551. transformers/models/speech_to_text/modeling_speech_to_text.py +10 -0
  552. transformers/models/speecht5/modeling_speecht5.py +28 -0
  553. transformers/models/splinter/modeling_splinter.py +9 -3
  554. transformers/models/squeezebert/modeling_squeezebert.py +2 -0
  555. transformers/models/stablelm/modeling_stablelm.py +1 -1
  556. transformers/models/starcoder2/modeling_starcoder2.py +1 -1
  557. transformers/models/superglue/image_processing_superglue_fast.py +1 -2
  558. transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
  559. transformers/models/swiftformer/modeling_swiftformer.py +4 -0
  560. transformers/models/swin/modeling_swin.py +16 -12
  561. transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
  562. transformers/models/swin2sr/modeling_swin2sr.py +49 -33
  563. transformers/models/swinv2/modeling_swinv2.py +41 -33
  564. transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
  565. transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
  566. transformers/models/t5/configuration_t5.py +7 -1
  567. transformers/models/t5/modeling_t5.py +1 -7
  568. transformers/models/t5gemma/modeling_t5gemma.py +1 -1
  569. transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
  570. transformers/models/t5gemma2/modeling_t5gemma2.py +13 -4
  571. transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
  572. transformers/models/table_transformer/configuration_table_transformer.py +1 -1
  573. transformers/models/table_transformer/modeling_table_transformer.py +1 -1
  574. transformers/models/textnet/image_processing_textnet_fast.py +0 -1
  575. transformers/models/timesfm/modeling_timesfm.py +12 -0
  576. transformers/models/timesfm/modular_timesfm.py +12 -0
  577. transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
  578. transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
  579. transformers/models/timm_wrapper/modeling_timm_wrapper.py +19 -13
  580. transformers/models/trocr/modeling_trocr.py +1 -2
  581. transformers/models/tvp/configuration_tvp.py +5 -1
  582. transformers/models/tvp/modeling_tvp.py +4 -4
  583. transformers/models/udop/configuration_udop.py +1 -0
  584. transformers/models/udop/modeling_udop.py +3 -7
  585. transformers/models/umt5/configuration_umt5.py +2 -2
  586. transformers/models/umt5/modeling_umt5.py +0 -6
  587. transformers/models/vaultgemma/modeling_vaultgemma.py +1 -1
  588. transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
  589. transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
  590. transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
  591. transformers/models/video_llava/modeling_video_llava.py +7 -3
  592. transformers/models/vilt/configuration_vilt.py +2 -2
  593. transformers/models/vilt/modeling_vilt.py +7 -0
  594. transformers/models/vipllava/modeling_vipllava.py +7 -3
  595. transformers/models/visual_bert/modeling_visual_bert.py +2 -0
  596. transformers/models/vitmatte/configuration_vitmatte.py +1 -1
  597. transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
  598. transformers/models/vitmatte/modeling_vitmatte.py +4 -0
  599. transformers/models/vitpose/configuration_vitpose.py +1 -1
  600. transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
  601. transformers/models/voxtral/modeling_voxtral.py +2 -2
  602. transformers/models/voxtral/modular_voxtral.py +2 -2
  603. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +16 -10
  604. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +7 -0
  605. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +21 -11
  606. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
  607. transformers/models/whisper/generation_whisper.py +1 -0
  608. transformers/models/whisper/modeling_whisper.py +5 -3
  609. transformers/models/x_clip/modeling_x_clip.py +2 -0
  610. transformers/models/xcodec/modeling_xcodec.py +5 -0
  611. transformers/models/xglm/modeling_xglm.py +10 -0
  612. transformers/models/xlm/modeling_xlm.py +13 -14
  613. transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
  614. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
  615. transformers/models/xlnet/modeling_xlnet.py +3 -1
  616. transformers/models/xmod/modeling_xmod.py +3 -0
  617. transformers/models/yoso/modeling_yoso.py +4 -1
  618. transformers/models/zamba/modeling_zamba.py +2 -1
  619. transformers/models/zamba2/modeling_zamba2.py +3 -2
  620. transformers/models/zoedepth/configuration_zoedepth.py +1 -1
  621. transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
  622. transformers/models/zoedepth/modeling_zoedepth.py +7 -0
  623. transformers/pipelines/__init__.py +9 -6
  624. transformers/pipelines/automatic_speech_recognition.py +20 -12
  625. transformers/pipelines/base.py +1 -1
  626. transformers/pipelines/document_question_answering.py +1 -1
  627. transformers/pipelines/question_answering.py +1 -1
  628. transformers/pipelines/text_to_audio.py +2 -2
  629. transformers/processing_utils.py +127 -56
  630. transformers/quantizers/auto.py +2 -4
  631. transformers/quantizers/base.py +9 -64
  632. transformers/quantizers/quantizer_aqlm.py +1 -18
  633. transformers/quantizers/quantizer_auto_round.py +1 -10
  634. transformers/quantizers/quantizer_awq.py +3 -8
  635. transformers/quantizers/quantizer_bitnet.py +1 -6
  636. transformers/quantizers/quantizer_bnb_4bit.py +9 -49
  637. transformers/quantizers/quantizer_bnb_8bit.py +9 -19
  638. transformers/quantizers/quantizer_compressed_tensors.py +1 -4
  639. transformers/quantizers/quantizer_eetq.py +2 -12
  640. transformers/quantizers/quantizer_fbgemm_fp8.py +5 -14
  641. transformers/quantizers/quantizer_finegrained_fp8.py +15 -10
  642. transformers/quantizers/quantizer_fp_quant.py +4 -4
  643. transformers/quantizers/quantizer_gptq.py +1 -4
  644. transformers/quantizers/quantizer_higgs.py +2 -6
  645. transformers/quantizers/quantizer_mxfp4.py +2 -28
  646. transformers/quantizers/quantizer_quanto.py +14 -14
  647. transformers/quantizers/quantizer_spqr.py +3 -8
  648. transformers/quantizers/quantizer_torchao.py +28 -124
  649. transformers/quantizers/quantizer_vptq.py +1 -10
  650. transformers/testing_utils.py +28 -12
  651. transformers/tokenization_mistral_common.py +3 -2
  652. transformers/tokenization_utils_base.py +3 -2
  653. transformers/tokenization_utils_tokenizers.py +25 -2
  654. transformers/trainer.py +24 -2
  655. transformers/trainer_callback.py +8 -0
  656. transformers/trainer_seq2seq.py +4 -0
  657. transformers/training_args.py +8 -10
  658. transformers/utils/__init__.py +4 -0
  659. transformers/utils/attention_visualizer.py +4 -4
  660. transformers/utils/auto_docstring.py +34 -25
  661. transformers/utils/generic.py +20 -0
  662. transformers/utils/import_utils.py +51 -9
  663. transformers/utils/kernel_config.py +71 -18
  664. transformers/utils/quantization_config.py +8 -8
  665. transformers/video_processing_utils.py +16 -12
  666. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +5 -6
  667. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +671 -632
  668. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +0 -0
  669. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
  670. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/licenses/LICENSE +0 -0
  671. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
@@ -91,6 +91,7 @@ class EvollaSaProtRotaryEmbedding(nn.Module):
 
     def __init__(self, dim: int):
         super().__init__()
+        self.dim = dim
         # Generate and save the inverse frequency buffer (non trainable)
         inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
         self.register_buffer("inv_freq", inv_freq)
@@ -203,12 +204,19 @@ class EvollaSaProtPreTrainedModel(PreTrainedModel):
         ],
     }
 
+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, EvollaSaProtRotaryEmbedding):
+            inv_freq = 1.0 / (10000 ** (torch.arange(0, module.dim, 2, dtype=torch.int64).float() / module.dim))
+            init.copy_(module.inv_freq, inv_freq)
+
 
 class EvollaSaProtProteinEncoder(EvollaSaProtPreTrainedModel):
     def __init__(self, config: SaProtConfig):
         super().__init__(config)
         self.embeddings = EvollaSaProtEmbeddings(config)
         self.encoder = EvollaSaProtEncoder(config)
+        self.post_init()
 
     def get_input_embeddings(self):
         return self.embeddings.word_embeddings
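Note: the two Evolla hunks above follow a pattern that recurs throughout this diff: anything a module builds in __init__ must also be reproducible from _init_weights, so the rotary module now records dim and the protein encoder calls post_init(). A minimal sketch of the idea, with illustrative names (ToyRotary, _make_inv_freq, reinit are not from the diff); transformers' own init.copy_ additionally handles buffers still on the meta device, which plain Tensor.copy_ does not:

    import torch
    from torch import nn

    class ToyRotary(nn.Module):
        def __init__(self, dim: int):
            super().__init__()
            self.dim = dim  # stored so the buffer below can be rebuilt later
            self.register_buffer("inv_freq", self._make_inv_freq(dim))

        @staticmethod
        def _make_inv_freq(dim: int) -> torch.Tensor:
            return 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))

        @torch.no_grad()
        def reinit(self):
            # stand-in for the _init_weights branch above: recompute from the
            # stored dim and copy into the existing buffer in place
            self.inv_freq.copy_(self._make_inv_freq(self.dim))

    rot = ToyRotary(dim=8)
    rot.reinit()  # idempotent: recomputes exactly the values it was built with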
@@ -86,7 +86,7 @@ class Exaone4RotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
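Note: this one-line change (repeated verbatim for FalconRotaryEmbedding and FalconH1RotaryEmbedding below) turns original_inv_freq from a plain Python attribute into a registered buffer. Buffers move and cast with the module while plain attributes do not, and the .clone() gives the second buffer its own storage instead of aliasing inv_freq; persistent=False keeps both out of the state dict, so checkpoints are unchanged. A small self-contained check of both properties:

    import torch
    from torch import nn

    class Demo(nn.Module):
        def __init__(self):
            super().__init__()
            t = torch.arange(4.0)
            self.register_buffer("inv_freq", t, persistent=False)
            self.register_buffer("original_inv_freq", t.clone(), persistent=False)

    m = Demo()
    print(m.inv_freq.data_ptr() == m.original_inv_freq.data_ptr())  # False -- no aliasing
    m = m.to(torch.float64)
    print(m.original_inv_freq.dtype)  # torch.float64 -- buffers are cast with the module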
@@ -122,7 +122,7 @@ class FalconRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
@@ -521,8 +521,8 @@ class FalconFlashAttention2(FalconAttention):
                 else torch.get_autocast_gpu_dtype()
             )
             # Handle the case where the model is quantized
-            elif hasattr(self.config, "_pre_quantization_dtype"):
-                target_dtype = self.config._pre_quantization_dtype
+            elif hasattr(self.config, "quantization_config"):
+                target_dtype = self.config.dtype
             else:
                 target_dtype = self.query_key_value.weight.dtype
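Note: rc2 stops relying on the private _pre_quantization_dtype attribute; quantized models are now detected via the public quantization_config, and the compute dtype is read from config.dtype. A hedged sketch of the resulting selection order as a standalone function (the function name is illustrative; the branch logic mirrors the hunk above):

    import torch

    def pick_flash_attn_dtype(config, weight: torch.Tensor) -> torch.dtype:
        # autocast dtype wins; quantized weights may be int8/uint8, so fall
        # back to the config-level dtype; otherwise trust the weight itself
        if torch.is_autocast_enabled():
            return torch.get_autocast_gpu_dtype()
        if hasattr(config, "quantization_config"):
            return config.dtype
        return weight.dtype

The same hasattr(config, "quantization_config") test replaces the old check in the two FalconMamba mixer hunks further down.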
@@ -241,7 +241,7 @@ class FalconH1RotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
@@ -1187,26 +1187,6 @@ class FalconH1DecoderLayer(GradientCheckpointingLayer):
         return outputs
 
 
-@auto_docstring
-class FalconH1PreTrainedModel(PreTrainedModel):
-    config: FalconH1Config
-    base_model_prefix = "model"
-    supports_gradient_checkpointing = True
-    _no_split_modules = ["FalconH1DecoderLayer"]
-    _skip_keys_device_placement = "past_key_values"
-    _supports_flash_attn = True
-    _supports_sdpa = True
-    _is_stateful = True
-
-    @torch.no_grad()
-    def _init_weights(self, module):
-        super()._init_weights(module)
-        if isinstance(module, FalconH1Mixer):
-            init.ones_(module.dt_bias)
-            init.copy_(module.A_log, torch.log(torch.arange(1, module.num_heads + 1)))
-            init.ones_(module.D)
-
-
 def compute_mup_vector(config):
     """
     Computes the MuP vector based on model configuration.
@@ -1244,6 +1224,30 @@ def compute_mup_vector(config):
     return mup_vector
 
 
+@auto_docstring
+class FalconH1PreTrainedModel(PreTrainedModel):
+    config: FalconH1Config
+    base_model_prefix = "model"
+    supports_gradient_checkpointing = True
+    _no_split_modules = ["FalconH1DecoderLayer"]
+    _skip_keys_device_placement = "past_key_values"
+    _supports_flash_attn = True
+    _supports_sdpa = True
+    _is_stateful = True
+
+    @torch.no_grad()
+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, FalconH1Mixer):
+            init.ones_(module.dt_bias)
+            init.copy_(module.A_log, torch.log(torch.arange(1, module.num_heads + 1)))
+            init.ones_(module.D)
+        elif isinstance(module, FalconH1Model):
+            mup_vector = compute_mup_vector(module.config)
+            for layer in module.layers:
+                init.copy_(layer.mamba.mup_vector, mup_vector)
+
+
 @auto_docstring
 # Adapted from transformers.models.jamba.modeling_jamba.JambaModel
 class FalconH1Model(FalconH1PreTrainedModel):
@@ -1269,7 +1273,7 @@ class FalconH1Model(FalconH1PreTrainedModel):
         # Compute the MuP vector once and register it for all layers
         mup_vector = compute_mup_vector(config)
         for layer in self.layers:
-            layer.mamba.register_buffer("mup_vector", mup_vector, persistent=False)
+            layer.mamba.register_buffer("mup_vector", mup_vector.clone(), persistent=False)
 
         # Initialize weights and apply final processing
         self.post_init()
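Note: two related FalconH1 changes land above. FalconH1PreTrainedModel moves below compute_mup_vector and gains an _init_weights branch that fills every layer's mup_vector, and the buffer registration now clones the shared tensor. Without the clone, all layers would alias one storage and a single in-place init.copy_ would write through to every layer. A self-contained demonstration of that hazard:

    import torch
    from torch import nn

    vec = torch.ones(3)
    a, b = nn.Module(), nn.Module()
    a.register_buffer("v", vec)  # the same tensor object on both modules
    b.register_buffer("v", vec)
    a.v.mul_(2)
    print(b.v)  # tensor([2., 2., 2.]) -- the in-place write leaked into b

    vec2 = torch.ones(3)
    c, d = nn.Module(), nn.Module()
    c.register_buffer("v", vec2.clone())  # clone() gives each module its own copy
    d.register_buffer("v", vec2.clone())
    c.v.mul_(2)
    print(d.v)  # tensor([1., 1., 1.])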
@@ -1591,6 +1595,7 @@ class FalconH1ForCausalLM(FalconH1PreTrainedModel, GenerationMixin):
         cache_position=None,
         position_ids=None,
         use_cache=True,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- has a unique cache type, `FalconHybridMambaAttentionDynamicCache`
@@ -1628,7 +1633,7 @@ class FalconH1ForCausalLM(FalconH1PreTrainedModel, GenerationMixin):
             position_ids = position_ids[:, -input_ids.shape[1] :]
 
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and empty_past_kv:
+        if inputs_embeds is not None and is_first_iteration:
             model_inputs = {"inputs_embeds": inputs_embeds}
         else:
             model_inputs = {"input_ids": input_ids.contiguous()}  # `contiguous()` needed for compilation use cases
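Note: prepare_inputs_for_generation previously inferred "first generation step" from an empty cache (empty_past_kv). That heuristic breaks when generate() resumes from a pre-filled cache such as a cached system prompt, so rc2 threads an explicit is_first_iteration flag through instead; the same change appears in the modular FalconH1 file, FalconMamba, and FastVlm below. A toy reproduction of the decision this hunk changes (names illustrative, not library code):

    def select_inputs(input_ids, inputs_embeds, cache_len, is_first_iteration):
        # old: `inputs_embeds is not None and cache_len == 0` wrongly picks
        # input_ids when step 0 starts from a non-empty, pre-filled cache
        # new: trust the explicit flag, cached prefix or not
        if inputs_embeds is not None and is_first_iteration:
            return {"inputs_embeds": inputs_embeds}
        return {"input_ids": input_ids}

    print(select_inputs("ids", "embeds", cache_len=128, is_first_iteration=True))
    # {'inputs_embeds': 'embeds'} -- the old empty-cache check would pick input_ids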
@@ -928,6 +928,10 @@ class FalconH1PreTrainedModel(PreTrainedModel):
             init.ones_(module.dt_bias)
             init.copy_(module.A_log, torch.log(torch.arange(1, module.num_heads + 1)))
             init.ones_(module.D)
+        elif isinstance(module, FalconH1Model):
+            mup_vector = compute_mup_vector(module.config)
+            for layer in module.layers:
+                init.copy_(layer.mamba.mup_vector, mup_vector)
 
 
 def compute_mup_vector(config):
@@ -992,7 +996,7 @@ class FalconH1Model(FalconH1PreTrainedModel):
         # Compute the MuP vector once and register it for all layers
         mup_vector = compute_mup_vector(config)
         for layer in self.layers:
-            layer.mamba.register_buffer("mup_vector", mup_vector, persistent=False)
+            layer.mamba.register_buffer("mup_vector", mup_vector.clone(), persistent=False)
 
         # Initialize weights and apply final processing
         self.post_init()
@@ -1298,6 +1302,7 @@ class FalconH1ForCausalLM(LlamaForCausalLM):
         cache_position=None,
         position_ids=None,
         use_cache=True,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- has a unique cache type, `FalconHybridMambaAttentionDynamicCache`
@@ -1335,7 +1340,7 @@ class FalconH1ForCausalLM(LlamaForCausalLM):
             position_ids = position_ids[:, -input_ids.shape[1] :]
 
         # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and empty_past_kv:
+        if inputs_embeds is not None and is_first_iteration:
             model_inputs = {"inputs_embeds": inputs_embeds}
         else:
             model_inputs = {"input_ids": input_ids.contiguous()}  # `contiguous()` needed for compilation use cases
@@ -31,7 +31,7 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...configuration_utils import PreTrainedConfig
 from ...generation import GenerationMixin
-from ...integrations.hub_kernels import lazy_load_kernel
+from ...integrations import lazy_load_kernel
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_utils import PreTrainedModel
 from ...utils import ModelOutput, auto_docstring, logging
@@ -345,7 +345,7 @@ class FalconMambaMixer(nn.Module):
 
         # In case the model has been quantized, we need a hack to properly call the `nn.Linear` module
         # at the price of a small overhead.
-        if hasattr(self.config, "_pre_quantization_dtype"):
+        if hasattr(self.config, "quantization_config"):
             discrete_time_step = (self.dt_proj(time_step) - self.dt_proj.bias).transpose(1, 2)
         else:
             discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)
@@ -613,6 +613,9 @@ class FalconMambaPreTrainedModel(PreTrainedModel):
             init.ones_(module.weight)
         elif isinstance(module, nn.Embedding):
             init.normal_(module.weight, std=std)
+        if isinstance(module, FalconMambaMixer):
+            init.ones_(module.b_c_rms)
+            init.ones_(module.dt_rms)
 
 
 @dataclass
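Note: the FalconMamba hunks extend the same initialization contract seen in the Evolla and FalconH1 changes: _init_weights must cover every tensor a module owns, because parameters missing from a checkpoint are materialized through it. Here the b_c_rms and dt_rms RMS-norm scales are set to ones, like other norm weights. A toy mirror of the structure, using torch.nn.init in place of transformers' initialization helpers:

    import torch
    from torch import nn

    class ToyMixer(nn.Module):
        def __init__(self, hidden: int):
            super().__init__()
            # RMS-norm style scales, analogous to b_c_rms / dt_rms in the hunk
            self.b_c_rms = nn.Parameter(torch.empty(hidden))
            self.dt_rms = nn.Parameter(torch.empty(hidden))

    def _init_weights(module):
        # any tensor skipped here would keep whatever bytes torch.empty left behind
        if isinstance(module, ToyMixer):
            nn.init.ones_(module.b_c_rms)
            nn.init.ones_(module.dt_rms)

    model = ToyMixer(4)
    model.apply(_init_weights)
    print(model.b_c_rms)  # tensor([1., 1., 1., 1.], requires_grad=True)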
@@ -811,6 +814,7 @@ class FalconMambaForCausalLM(FalconMambaPreTrainedModel, GenerationMixin):
         cache_params: Optional[FalconMambaCache] = None,
         cache_position: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.LongTensor] = None,
+        is_first_iteration: Optional[bool] = False,
         **kwargs,
     ):
         # Overwritten -- uses `cache_params` as opposed to `past_key_values`
@@ -19,6 +19,7 @@ from typing import Optional
 import torch
 from torch import nn
 
+from ... import initialization as init
 from ...utils import auto_docstring, logging
 from ...utils.import_utils import is_mambapy_available, is_torchdynamo_compiling
 from ..mamba.configuration_mamba import MambaConfig
@@ -357,7 +358,7 @@ class FalconMambaMixer(MambaMixer):
 
         # In case the model has been quantized, we need a hack to properly call the `nn.Linear` module
         # at the price of a small overhead.
-        if hasattr(self.config, "_pre_quantization_dtype"):
+        if hasattr(self.config, "quantization_config"):
             discrete_time_step = (self.dt_proj(time_step) - self.dt_proj.bias).transpose(1, 2)
         else:
             discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)
@@ -529,7 +530,11 @@ class FalconMambaBlock(MambaBlock):
 
 @auto_docstring
 class FalconMambaPreTrainedModel(MambaPreTrainedModel):
-    pass
+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, FalconMambaMixer):
+            init.ones_(module.b_c_rms)
+            init.ones_(module.dt_rms)
 
 
 class FalconMambaOutput(MambaOutput):
@@ -430,6 +430,7 @@ class FastVlmForConditionalGeneration(FastVlmPreTrainedModel, GenerationMixin):
         attention_mask=None,
         cache_position=None,
         logits_to_keep=None,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
@@ -441,12 +442,15 @@ class FastVlmForConditionalGeneration(FastVlmPreTrainedModel, GenerationMixin):
             attention_mask=attention_mask,
             cache_position=cache_position,
             logits_to_keep=logits_to_keep,
+            is_first_iteration=is_first_iteration,
             **kwargs,
         )
 
-        if cache_position[0] == 0:
-            # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
-            # Otherwise we need pixel values to be passed to model
+        if is_first_iteration or not kwargs.get("use_cache", True):
+            # Pixel values are used only in the first iteration if available
+            # In subsequent iterations, they are already merged with text and cached
+            # NOTE: first iteration doesn't have to be prefill, it can be the first
+            # iteration with a question and cached system prompt (continue generate from cache)
             model_inputs["pixel_values"] = pixel_values
 
         return model_inputs
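Note: with the flag available, FastVlm gates pixel_values on the first iteration of the current generate() call rather than on cache_position[0] == 0, which mislabeled continue-from-cache prefills as decode steps. The new condition, extracted here as a runnable predicate for clarity:

    def should_pass_pixel_values(is_first_iteration: bool, kwargs: dict) -> bool:
        # images are forwarded on the first iteration (plain prefill or
        # continue-from-cache); with caching disabled nothing is ever merged
        # and stored, so they must be passed on every step
        return is_first_iteration or not kwargs.get("use_cache", True)

    print(should_pass_pixel_values(True, {"use_cache": True}))    # True  -- first step
    print(should_pass_pixel_values(False, {"use_cache": True}))   # False -- decode steps
    print(should_pass_pixel_values(False, {"use_cache": False}))  # True  -- no cache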
@@ -727,19 +727,20 @@ class FastSpeech2ConformerRelPositionalEncoding(nn.Module):
         self.embed_dim = config.hidden_size
         self.input_scale = math.sqrt(self.embed_dim)
         self.dropout = nn.Dropout(p=module_config["positional_dropout_rate"])
-        self.pos_enc = None
         self.max_len = 5000
-        self.extend_pos_enc(torch.tensor(0.0).expand(1, self.max_len))
+        self.register_buffer(
+            "pos_enc", self.extend_pos_enc(torch.tensor(0.0).expand(1, self.max_len)), persistent=False
+        )
 
-    def extend_pos_enc(self, x):
+    def extend_pos_enc(self, x, pos_enc=None):
         """Reset the positional encodings."""
-        if self.pos_enc is not None:
+        if pos_enc is not None:
             # self.pos_enc contains both positive and negative parts
             # the length of self.pos_enc is 2 * input_len - 1
-            if self.pos_enc.size(1) >= x.size(1) * 2 - 1:
-                if self.pos_enc.dtype != x.dtype or self.pos_enc.device != x.device:
-                    self.pos_enc = self.pos_enc.to(dtype=x.dtype, device=x.device)
-                return
+            if pos_enc.size(1) >= x.size(1) * 2 - 1:
+                if pos_enc.dtype != x.dtype or pos_enc.device != x.device:
+                    pos_enc = pos_enc.to(dtype=x.dtype, device=x.device)
+                return pos_enc
         # Suppose `i` means the position of the query vector and `j` means the
         # position of the key vector. We use positive relative positions when keys
         # are to the left (i>j) and negative relative positions otherwise (i<j).
@@ -760,7 +761,7 @@ class FastSpeech2ConformerRelPositionalEncoding(nn.Module):
         pos_enc_positive = torch.flip(pos_enc_positive, [0]).unsqueeze(0)
         pos_enc_negative = pos_enc_negative[1:].unsqueeze(0)
         pos_enc = torch.cat([pos_enc_positive, pos_enc_negative], dim=1)
-        self.pos_enc = pos_enc.to(device=x.device, dtype=x.dtype)
+        return pos_enc.to(device=x.device, dtype=x.dtype)
 
     def forward(self, feature_representation):
         """
@@ -771,7 +772,7 @@ class FastSpeech2ConformerRelPositionalEncoding(nn.Module):
         Returns:
             `torch.Tensor`: Encoded tensor (batch_size, time, `*`).
         """
-        self.extend_pos_enc(feature_representation)
+        self.pos_enc = self.extend_pos_enc(feature_representation, self.pos_enc)
         hidden_states = feature_representation * self.input_scale
         center_idx = self.pos_enc.size(1) // 2
         pos_emb = self.pos_enc[:, center_idx - hidden_states.size(1) + 1 : center_idx + hidden_states.size(1)]
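Making `pos_enc` a registered, non-persistent buffer means it follows the module across devices and dtypes (and can be re-filled by `_init_weights`), while `extend_pos_enc` becomes a pure function returning the possibly regrown table. A minimal sketch of the same pattern, with placeholder table contents:

```python
import torch
from torch import nn

class RelPosEnc(nn.Module):
    def __init__(self, max_len=16):
        super().__init__()
        # Non-persistent: moves with .to(), is not written to checkpoints
        self.register_buffer("table", self._build(max_len), persistent=False)

    @staticmethod
    def _build(length):
        # Placeholder content; the real module builds sinusoidal encodings
        return torch.arange(2 * length - 1, dtype=torch.float32).unsqueeze(0)

    def forward(self, x):
        needed = 2 * x.size(1) - 1
        if self.table.size(1) < needed:  # regrow only when the input is longer
            self.table = self._build(x.size(1)).to(x.device, x.dtype)
        return self.table[:, :needed]
```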
@@ -1010,6 +1011,10 @@ class FastSpeech2ConformerPreTrainedModel(PreTrainedModel):
         elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d)):
             init.zeros_(module.bias)
             init.ones_(module.weight)
+            if getattr(module, "running_mean", None) is not None:
+                init.zeros_(module.running_mean)
+                init.ones_(module.running_var)
+                init.zeros_(module.num_batches_tracked)
         elif isinstance(module, nn.Embedding):
             init.normal_(module.weight)
             # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it loses the flag
@@ -1018,6 +1023,8 @@ class FastSpeech2ConformerPreTrainedModel(PreTrainedModel):
         elif isinstance(module, FastSpeech2ConformerAttention):
             init.xavier_uniform_(module.pos_bias_u)
             init.xavier_uniform_(module.pos_bias_v)
+        elif isinstance(module, FastSpeech2ConformerRelPositionalEncoding):
+            init.copy_(module.pos_enc, module.extend_pos_enc(torch.tensor(0.0).expand(1, module.max_len)))
 
     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, FastSpeech2ConformerEncoder):
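The new `running_mean` branch is needed because BatchNorm running statistics are buffers, not parameters, so generic weight initialization does not touch them (and `nn.LayerNorm` has no such buffers, hence the `getattr` guard). The equivalent plain-PyTorch behavior for a single `nn.BatchNorm1d`:

```python
import torch
from torch import nn

bn = nn.BatchNorm1d(8)
nn.init.zeros_(bn.bias)
nn.init.ones_(bn.weight)
if bn.running_mean is not None:   # None when track_running_stats=False
    bn.running_mean.zero_()
    bn.running_var.fill_(1.0)
    bn.num_batches_tracked.zero_()
```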
@@ -1410,6 +1417,12 @@ class FastSpeech2ConformerHifiGan(PreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()
 
+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, FastSpeech2ConformerHifiGan):
+            init.zeros_(module.mean)
+            init.ones_(module.scale)
+
     def apply_weight_norm(self):
         weight_norm = nn.utils.weight_norm
         if hasattr(nn.utils.parametrizations, "weight_norm"):
@@ -79,6 +79,7 @@ class FastSpeech2ConformerTokenizer(PreTrainedTokenizer):
             unk_token=unk_token,
             pad_token=pad_token,
             should_strip_spaces=should_strip_spaces,
+            special_tokens_pattern="none",
             **kwargs,
         )
 
@@ -660,9 +660,6 @@ class FlaubertPreTrainedModel(PreTrainedModel):
     config: FlaubertConfig
     base_model_prefix = "transformer"
 
-    def __init__(self, *inputs, **kwargs):
-        super().__init__(*inputs, **kwargs)
-
     @property
     def dummy_inputs(self):
         inputs_list = torch.tensor([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
@@ -690,15 +687,17 @@ class FlaubertPreTrainedModel(PreTrainedModel):
         if isinstance(module, nn.LayerNorm):
             init.zeros_(module.bias)
             init.ones_(module.weight)
-        if isinstance(module, FlaubertModel) and self.config.sinusoidal_embeddings:
-            init.copy_(
-                module.position_embeddings.weight,
-                create_sinusoidal_embeddings(
-                    self.config.max_position_embeddings,
-                    self.config.emb_dim,
-                    out=torch.empty_like(module.position_embeddings.weight),
-                ),
-            )
+        if isinstance(module, FlaubertModel):
+            if self.config.sinusoidal_embeddings:
+                init.copy_(
+                    module.position_embeddings.weight,
+                    create_sinusoidal_embeddings(
+                        self.config.max_position_embeddings,
+                        self.config.emb_dim,
+                        out=torch.empty_like(module.position_embeddings.weight),
+                    ),
+                )
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
 
 
 @auto_docstring
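For reference, a hedged sketch of what a `create_sinusoidal_embeddings`-style helper computes: the standard fixed sin/cos position table (the library's exact helper has a different signature; an even `dim` is assumed here):

```python
import torch

def sinusoidal_table(n_pos, dim):
    # table[p, 2i] = sin(p / 10000^(2i/dim)), table[p, 2i+1] = cos(p / 10000^(2i/dim))
    position = torch.arange(n_pos, dtype=torch.float32).unsqueeze(1)
    div_term = torch.pow(10000.0, torch.arange(0, dim, 2, dtype=torch.float32) / dim)
    table = torch.empty(n_pos, dim)
    table[:, 0::2] = torch.sin(position / div_term)
    table[:, 1::2] = torch.cos(position / div_term)
    return table
```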
@@ -760,15 +759,15 @@ class FlaubertModel(FlaubertPreTrainedModel):
             self.ffns.append(TransformerFFN(self.dim, self.hidden_dim, self.dim, config=config))
             self.layer_norm2.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))
 
-        # Initialize weights and apply final processing
-        self.post_init()
-
         self.layerdrop = getattr(config, "layerdrop", 0.0)
         self.pre_norm = getattr(config, "pre_norm", False)
         self.register_buffer(
             "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
         )
 
+        # Initialize weights and apply final processing
+        self.post_init()
+
     # Copied from transformers.models.xlm.modeling_xlm.XLMModel.get_input_embeddings
     def get_input_embeddings(self):
         return self.embeddings
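Moving `post_init()` below the `register_buffer` call matters because the init pass walks the module and fills parameters and buffers; anything registered after it runs would be missed (notably under meta-device loading, where buffers start out empty). A toy illustration of the ordering, with a local `_init_all` standing in for `PreTrainedModel.post_init()`:

```python
import torch
from torch import nn

class TinyModel(nn.Module):
    def __init__(self, max_positions=512):
        super().__init__()
        self.embeddings = nn.Embedding(max_positions, 32)
        # Buffers must be registered BEFORE the init pass...
        self.register_buffer(
            "position_ids", torch.arange(max_positions).expand((1, -1)), persistent=False
        )
        self._init_all()  # ...so the init pass sees and fills them

    def _init_all(self):
        nn.init.normal_(self.embeddings.weight, std=0.02)
        self.position_ids.copy_(torch.arange(self.position_ids.shape[-1]).expand((1, -1)))
```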
@@ -306,7 +306,6 @@ class FlavaImageProcessorFast(BaseImageProcessorFast):
             processed_images_grouped[shape] = stacked_images
 
         processed_images = reorder_images(processed_images_grouped, grouped_images_index)
-        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
 
         return processed_images
 
@@ -397,7 +396,6 @@ class FlavaImageProcessorFast(BaseImageProcessorFast):
                 mask_group_max_aspect_ratio=mask_group_max_aspect_ratio,
             )
             masks = [mask_generator() for _ in range(len(images))]
-            masks = torch.stack(masks, dim=0) if return_tensors else masks
             data["bool_masked_pos"] = masks
 
         return BatchFeature(data=data, tensor_type=return_tensors)
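Both removed `torch.stack` calls were redundant: constructing `BatchFeature` with a `tensor_type` already runs the tensor conversion, which is what dropping the explicit stacking relies on. A hedged usage sketch:

```python
# Assumes BatchFeature's "pt" conversion handles lists of same-shaped tensors,
# which is the behavior the removal of the explicit torch.stack depends on.
import torch
from transformers import BatchFeature

masks = [torch.zeros(4, 4), torch.ones(4, 4)]
features = BatchFeature(data={"bool_masked_pos": masks}, tensor_type="pt")
# features["bool_masked_pos"] is expected to come back as a single (2, 4, 4) tensor
```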
@@ -677,6 +677,9 @@ class FlavaPreTrainedModel(PreTrainedModel):
             init.zeros_(module.position_embeddings)
             if module.mask_token is not None:
                 init.zeros_(module.mask_token)
+        elif isinstance(module, FlavaTextEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
+            init.zeros_(module.token_type_ids)
         elif isinstance(module, FlavaMultimodalModel):
             if module.use_cls_token:
                 init.zeros_(module.cls_token)
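This `position_ids`/`token_type_ids` branch recurs across several models in this release (Flaubert above, FNet below, Flava here). These are non-persistent buffers, so they are absent from checkpoints and, under low-memory loading, are presumably materialized empty; `_init_weights` rebuilds their deterministic contents. What the two `init` calls compute, in plain PyTorch:

```python
import torch

max_len = 512
position_ids = torch.empty(1, max_len, dtype=torch.long)    # stand-in for the buffer
position_ids.copy_(torch.arange(max_len).expand((1, -1)))   # 0, 1, 2, ... per row

token_type_ids = torch.empty(1, max_len, dtype=torch.long)  # stand-in
token_type_ids.zero_()                                      # segment ids default to 0
```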
@@ -1107,7 +1110,7 @@ class FlavaModel(FlavaPreTrainedModel):
         output_hidden_states: bool = True,
         return_dict: Optional[bool] = None,
         **kwargs,
-    ) -> Union[tuple, FlavaOutput]:
+    ) -> Union[tuple, FlavaModelOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, image_num_patches + text_seq_len)`):
             Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
@@ -30,14 +30,14 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub, use_kernelized_func
+from ...integrations import use_experts_implementation, use_kernel_forward_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_grouped_mm_available
 from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
 from .configuration_flex_olmo import FlexOlmoConfig
 
@@ -80,7 +80,7 @@ class FlexOlmoRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
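Registering `original_inv_freq` as its own cloned, non-persistent buffer does two things: the pristine copy cannot be clobbered when dynamic RoPE scaling mutates `inv_freq`, and as a buffer it follows the module across devices. A small sketch of the failure the clone prevents:

```python
import torch
from torch import nn

class Rope(nn.Module):
    def __init__(self, dim=8):
        super().__init__()
        inv_freq = 1.0 / (10000.0 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # clone() breaks the aliasing; a plain attribute would also be skipped by .to()
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

    def rescale(self, factor):
        self.inv_freq.div_(factor)  # dynamic scaling mutates inv_freq in place

    def reset(self):
        self.inv_freq.copy_(self.original_inv_freq)  # still intact thanks to the clone
```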
@@ -293,6 +293,7 @@ class FlexOlmoAttention(nn.Module):
         return attn_output, attn_weights
 
 
+@use_experts_implementation
 class FlexOlmoExperts(nn.Module):
     """Collection of expert weights stored as 3D tensors."""
 
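`use_experts_implementation` is imported from `transformers.integrations` above; its internals are not shown in this diff. As a rough mental model only, such a class decorator can wrap `forward` and dispatch to a selected experts kernel. Everything below is hypothetical, including the `_experts_impl` selector attribute:

```python
import torch
from torch import nn

def use_experts_implementation(cls):
    eager_forward = cls.forward  # keep the original implementation

    def forward(self, *args, **kwargs):
        impl = getattr(self, "_experts_impl", "eager")  # hypothetical selector
        if impl == "eager":
            return eager_forward(self, *args, **kwargs)
        raise NotImplementedError(f"unknown experts implementation: {impl}")

    cls.forward = forward
    return cls

@use_experts_implementation
class ToyExperts(nn.Module):
    def forward(self, x):
        return x
```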
@@ -421,7 +422,9 @@ class FlexOlmoPreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _supports_sdpa = True
     _supports_flex_attn = True
-    _can_compile_fullgraph = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
+    _can_compile_fullgraph = (
+        is_grouped_mm_available()
+    )  # https://huggingface.co/docs/transformers/experts_interface#torchcompile
     _supports_attention_backend = True
     _can_record_outputs = {
         "router_logits": OutputRecorder(nn.Linear, layer_name="mlp.gate", index=0),
@@ -26,6 +26,7 @@ from typing import Any, Optional, Union
 import torch.nn as nn
 import torch.nn.functional as F
 
+from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache
 from ...generation import GenerationMixin
@@ -629,6 +630,18 @@ class Florence2PreTrainedModel(PreTrainedModel):
     _supports_attention_backend = False
     config_class = Florence2Config
 
+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, Florence2VisionPositionalEmbeddingCosine1D):
+            pos_idx_to_embed = torch.empty((module.max_seq_len, module.embed_dim))
+            sine, cosine = module.get_sinusoid_embeddings(
+                max_positions=module.max_seq_len,
+                embed_dim=module.embed_dim,
+            )
+            pos_idx_to_embed[:, 0::2] = sine
+            pos_idx_to_embed[:, 1::2] = cosine
+            init.copy_(module.pos_idx_to_embed, pos_idx_to_embed)
+
 
 @auto_docstring(
     custom_intro="""
@@ -937,6 +950,7 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel, GenerationMixin):
         attention_mask=None,
         cache_position=None,
         logits_to_keep=None,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
@@ -948,12 +962,15 @@ class Florence2ForConditionalGeneration(Florence2PreTrainedModel, GenerationMixin):
             attention_mask=attention_mask,
             cache_position=cache_position,
             logits_to_keep=logits_to_keep,
+            is_first_iteration=is_first_iteration,
             **kwargs,
         )
 
-        if cache_position[0] == 0:
-            # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
-            # Otherwise we need pixel values to be passed to model
+        if is_first_iteration or not kwargs.get("use_cache", True):
+            # Pixel values are used only in the first iteration if available
+            # In subsequent iterations, they are already merged with text and cached
+            # NOTE: the first iteration doesn't have to be prefill; it can be the first
+            # iteration with a question and a cached system prompt (continuing generation from cache)
             model_inputs["pixel_values"] = pixel_values
 
         return model_inputs
@@ -22,6 +22,7 @@ import numpy as np
 import torch.nn as nn
 import torch.nn.functional as F
 
+from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache
 from ...configuration_utils import PreTrainedConfig
@@ -1500,6 +1501,18 @@ class Florence2PreTrainedModel(LlavaPreTrainedModel):
 
     _supports_attention_backend = False
 
+    def _init_weights(self, module):
+        PreTrainedModel._init_weights(self, module)
+        if isinstance(module, Florence2VisionPositionalEmbeddingCosine1D):
+            pos_idx_to_embed = torch.empty((module.max_seq_len, module.embed_dim))
+            sine, cosine = module.get_sinusoid_embeddings(
+                max_positions=module.max_seq_len,
+                embed_dim=module.embed_dim,
+            )
+            pos_idx_to_embed[:, 0::2] = sine
+            pos_idx_to_embed[:, 1::2] = cosine
+            init.copy_(module.pos_idx_to_embed, pos_idx_to_embed)
+
 
 @auto_docstring(
     custom_intro="""
@@ -23,6 +23,7 @@ import torch
 from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 
+from ... import initialization as init
 from ...utils import auto_docstring, is_scipy_available
 
 
@@ -374,6 +375,12 @@ class FNetPreTrainedModel(PreTrainedModel):
     base_model_prefix = "fnet"
     supports_gradient_checkpointing = True
 
+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, FNetEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
+            init.zeros_(module.token_type_ids)
+
 
 @dataclass
 @auto_docstring(
@@ -94,7 +94,7 @@ class FuyuBatchFeature(BatchFeature):
     The outputs dictionary from the processors contains a mix of tensors and lists of tensors.
     """
 
-    def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None):
+    def convert_to_tensors(self, tensor_type: Optional[Union[str, TensorType]] = None, **kwargs):
         """
         Convert the inner content to tensors.
 
@@ -359,6 +359,7 @@ class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin):
         image_patches=None,
         image_patches_indices=None,
         cache_position=None,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
@@ -371,10 +372,11 @@ class FuyuForCausalLM(FuyuPreTrainedModel, GenerationMixin):
             image_patches=image_patches,
             image_patches_indices=image_patches_indices,
             cache_position=cache_position,
+            is_first_iteration=is_first_iteration,
             **kwargs,
         )
 
-        if cache_position[0] != 0:
+        if not is_first_iteration and kwargs.get("use_cache", True):
             # set image_patches and image_patches_indices to `None` for decoding stage
             model_inputs["image_patches_indices"] = None
             model_inputs["image_patches"] = None