transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (835)
  1. transformers/__init__.py +49 -3
  2. transformers/activations.py +1 -1
  3. transformers/audio_utils.py +0 -1
  4. transformers/cache_utils.py +17 -15
  5. transformers/cli/serve.py +47 -17
  6. transformers/configuration_utils.py +114 -70
  7. transformers/conversion_mapping.py +83 -7
  8. transformers/convert_slow_tokenizer.py +225 -10
  9. transformers/core_model_loading.py +374 -147
  10. transformers/data/data_collator.py +12 -4
  11. transformers/dependency_versions_table.py +2 -3
  12. transformers/dynamic_module_utils.py +1 -2
  13. transformers/feature_extraction_utils.py +55 -24
  14. transformers/file_utils.py +0 -1
  15. transformers/generation/__init__.py +11 -1
  16. transformers/generation/candidate_generator.py +79 -31
  17. transformers/generation/configuration_utils.py +165 -124
  18. transformers/generation/continuous_batching/__init__.py +4 -0
  19. transformers/generation/continuous_batching/cache.py +47 -18
  20. transformers/generation/continuous_batching/cache_manager.py +131 -34
  21. transformers/generation/continuous_batching/continuous_api.py +228 -136
  22. transformers/generation/continuous_batching/requests.py +28 -1
  23. transformers/generation/continuous_batching/scheduler.py +11 -4
  24. transformers/generation/stopping_criteria.py +1 -1
  25. transformers/generation/utils.py +108 -110
  26. transformers/generation/watermarking.py +8 -5
  27. transformers/image_processing_base.py +3 -14
  28. transformers/image_processing_utils_fast.py +15 -4
  29. transformers/initialization.py +37 -0
  30. transformers/integrations/__init__.py +16 -2
  31. transformers/integrations/accelerate.py +58 -113
  32. transformers/integrations/aqlm.py +36 -66
  33. transformers/integrations/awq.py +46 -515
  34. transformers/integrations/bitnet.py +47 -105
  35. transformers/integrations/bitsandbytes.py +91 -202
  36. transformers/integrations/deepspeed.py +18 -2
  37. transformers/integrations/eetq.py +84 -81
  38. transformers/integrations/fbgemm_fp8.py +191 -145
  39. transformers/integrations/finegrained_fp8.py +241 -208
  40. transformers/integrations/flash_attention.py +2 -2
  41. transformers/integrations/fp_quant.py +92 -0
  42. transformers/integrations/ggml.py +11 -1
  43. transformers/integrations/higgs.py +37 -62
  44. transformers/integrations/hub_kernels.py +65 -8
  45. transformers/integrations/integration_utils.py +45 -0
  46. transformers/integrations/mistral.py +12 -0
  47. transformers/integrations/moe.py +240 -0
  48. transformers/integrations/mxfp4.py +28 -74
  49. transformers/integrations/peft.py +12 -29
  50. transformers/integrations/quanto.py +77 -56
  51. transformers/integrations/quark.py +55 -0
  52. transformers/integrations/spqr.py +42 -90
  53. transformers/integrations/tensor_parallel.py +167 -221
  54. transformers/integrations/torchao.py +32 -38
  55. transformers/integrations/vptq.py +40 -59
  56. transformers/modelcard.py +1 -2
  57. transformers/modeling_gguf_pytorch_utils.py +74 -19
  58. transformers/modeling_rope_utils.py +107 -86
  59. transformers/modeling_utils.py +611 -527
  60. transformers/models/__init__.py +22 -0
  61. transformers/models/afmoe/modeling_afmoe.py +10 -19
  62. transformers/models/afmoe/modular_afmoe.py +5 -13
  63. transformers/models/aimv2/modeling_aimv2.py +4 -0
  64. transformers/models/aimv2/modular_aimv2.py +4 -0
  65. transformers/models/albert/modeling_albert.py +3 -0
  66. transformers/models/albert/tokenization_albert.py +6 -12
  67. transformers/models/align/modeling_align.py +14 -6
  68. transformers/models/altclip/modeling_altclip.py +11 -3
  69. transformers/models/apertus/modeling_apertus.py +8 -6
  70. transformers/models/apertus/modular_apertus.py +4 -1
  71. transformers/models/arcee/modeling_arcee.py +5 -5
  72. transformers/models/aria/modeling_aria.py +12 -8
  73. transformers/models/aria/modular_aria.py +7 -3
  74. transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
  75. transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
  76. transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
  77. transformers/models/auto/auto_factory.py +1 -1
  78. transformers/models/auto/configuration_auto.py +38 -0
  79. transformers/models/auto/feature_extraction_auto.py +9 -3
  80. transformers/models/auto/image_processing_auto.py +5 -2
  81. transformers/models/auto/modeling_auto.py +37 -0
  82. transformers/models/auto/processing_auto.py +22 -10
  83. transformers/models/auto/tokenization_auto.py +147 -566
  84. transformers/models/auto/video_processing_auto.py +5 -2
  85. transformers/models/autoformer/modeling_autoformer.py +4 -0
  86. transformers/models/aya_vision/modeling_aya_vision.py +7 -3
  87. transformers/models/bamba/modeling_bamba.py +21 -21
  88. transformers/models/bamba/modular_bamba.py +17 -16
  89. transformers/models/bark/modeling_bark.py +11 -0
  90. transformers/models/bart/configuration_bart.py +0 -1
  91. transformers/models/bart/modeling_bart.py +14 -0
  92. transformers/models/barthez/tokenization_barthez.py +5 -10
  93. transformers/models/beit/image_processing_beit_fast.py +0 -1
  94. transformers/models/beit/modeling_beit.py +6 -1
  95. transformers/models/bert/modeling_bert.py +3 -0
  96. transformers/models/bert/tokenization_bert.py +8 -21
  97. transformers/models/bert_generation/modeling_bert_generation.py +2 -0
  98. transformers/models/big_bird/modeling_big_bird.py +9 -0
  99. transformers/models/big_bird/tokenization_big_bird.py +18 -42
  100. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +15 -2
  101. transformers/models/biogpt/modeling_biogpt.py +2 -0
  102. transformers/models/biogpt/modular_biogpt.py +2 -0
  103. transformers/models/bit/modeling_bit.py +16 -3
  104. transformers/models/bitnet/modeling_bitnet.py +5 -5
  105. transformers/models/blenderbot/modeling_blenderbot.py +12 -0
  106. transformers/models/blenderbot/tokenization_blenderbot.py +18 -23
  107. transformers/models/blenderbot_small/modeling_blenderbot_small.py +12 -0
  108. transformers/models/blip/modeling_blip.py +2 -0
  109. transformers/models/blip/modeling_blip_text.py +10 -0
  110. transformers/models/blip_2/modeling_blip_2.py +4 -1
  111. transformers/models/bloom/modeling_bloom.py +17 -44
  112. transformers/models/blt/modeling_blt.py +164 -4
  113. transformers/models/blt/modular_blt.py +170 -5
  114. transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
  115. transformers/models/bridgetower/modeling_bridgetower.py +11 -1
  116. transformers/models/bros/modeling_bros.py +12 -0
  117. transformers/models/camembert/modeling_camembert.py +109 -106
  118. transformers/models/camembert/tokenization_camembert.py +8 -12
  119. transformers/models/canine/modeling_canine.py +11 -0
  120. transformers/models/canine/tokenization_canine.py +2 -0
  121. transformers/models/chameleon/modeling_chameleon.py +11 -5
  122. transformers/models/chinese_clip/modeling_chinese_clip.py +9 -3
  123. transformers/models/clap/feature_extraction_clap.py +2 -2
  124. transformers/models/clap/modeling_clap.py +30 -15
  125. transformers/models/clip/modeling_clip.py +2 -0
  126. transformers/models/clip/tokenization_clip.py +22 -44
  127. transformers/models/clipseg/modeling_clipseg.py +9 -0
  128. transformers/models/clvp/modeling_clvp.py +19 -3
  129. transformers/models/clvp/tokenization_clvp.py +1 -63
  130. transformers/models/code_llama/tokenization_code_llama.py +20 -43
  131. transformers/models/codegen/modeling_codegen.py +13 -4
  132. transformers/models/codegen/tokenization_codegen.py +14 -43
  133. transformers/models/cohere/modeling_cohere.py +5 -4
  134. transformers/models/cohere/modular_cohere.py +2 -1
  135. transformers/models/cohere/tokenization_cohere.py +12 -42
  136. transformers/models/cohere2/modeling_cohere2.py +8 -7
  137. transformers/models/cohere2/modular_cohere2.py +5 -5
  138. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -4
  139. transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
  140. transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
  141. transformers/models/colqwen2/modeling_colqwen2.py +1 -0
  142. transformers/models/colqwen2/modular_colqwen2.py +1 -0
  143. transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
  144. transformers/models/conditional_detr/modeling_conditional_detr.py +9 -1
  145. transformers/models/convbert/modeling_convbert.py +9 -0
  146. transformers/models/convnext/image_processing_convnext.py +2 -2
  147. transformers/models/convnext/image_processing_convnext_fast.py +9 -13
  148. transformers/models/convnext/modeling_convnext.py +2 -4
  149. transformers/models/convnextv2/modeling_convnextv2.py +2 -4
  150. transformers/models/csm/generation_csm.py +19 -22
  151. transformers/models/csm/modeling_csm.py +7 -4
  152. transformers/models/csm/modular_csm.py +2 -0
  153. transformers/models/ctrl/modeling_ctrl.py +15 -2
  154. transformers/models/cvt/modeling_cvt.py +7 -1
  155. transformers/models/cwm/modeling_cwm.py +5 -5
  156. transformers/models/d_fine/configuration_d_fine.py +3 -4
  157. transformers/models/d_fine/modeling_d_fine.py +48 -39
  158. transformers/models/d_fine/modular_d_fine.py +16 -4
  159. transformers/models/dab_detr/configuration_dab_detr.py +2 -2
  160. transformers/models/dab_detr/modeling_dab_detr.py +5 -1
  161. transformers/models/dac/modeling_dac.py +6 -6
  162. transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
  163. transformers/models/data2vec/modeling_data2vec_text.py +7 -0
  164. transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
  165. transformers/models/data2vec/modular_data2vec_text.py +7 -0
  166. transformers/models/dbrx/configuration_dbrx.py +9 -1
  167. transformers/models/dbrx/modeling_dbrx.py +3 -3
  168. transformers/models/deberta/modeling_deberta.py +7 -0
  169. transformers/models/deberta/tokenization_deberta.py +11 -20
  170. transformers/models/deberta_v2/modeling_deberta_v2.py +8 -0
  171. transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
  172. transformers/models/decision_transformer/modeling_decision_transformer.py +12 -6
  173. transformers/models/deepseek_v2/modeling_deepseek_v2.py +9 -7
  174. transformers/models/deepseek_v2/modular_deepseek_v2.py +6 -4
  175. transformers/models/deepseek_v3/modeling_deepseek_v3.py +12 -7
  176. transformers/models/deepseek_v3/modular_deepseek_v3.py +7 -2
  177. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
  178. transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
  179. transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
  180. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
  181. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
  182. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
  183. transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
  184. transformers/models/deformable_detr/modeling_deformable_detr.py +5 -1
  185. transformers/models/depth_anything/configuration_depth_anything.py +2 -3
  186. transformers/models/depth_anything/modeling_depth_anything.py +1 -0
  187. transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
  188. transformers/models/depth_pro/modeling_depth_pro.py +2 -0
  189. transformers/models/detr/configuration_detr.py +1 -1
  190. transformers/models/detr/modeling_detr.py +13 -1
  191. transformers/models/dia/generation_dia.py +3 -10
  192. transformers/models/dia/modeling_dia.py +16 -4
  193. transformers/models/dia/modular_dia.py +11 -1
  194. transformers/models/dia/processing_dia.py +1 -1
  195. transformers/models/diffllama/modeling_diffllama.py +5 -5
  196. transformers/models/diffllama/modular_diffllama.py +2 -2
  197. transformers/models/dinat/modeling_dinat.py +3 -0
  198. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
  199. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
  200. transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -2
  201. transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -2
  202. transformers/models/distilbert/modeling_distilbert.py +11 -9
  203. transformers/models/distilbert/tokenization_distilbert.py +13 -0
  204. transformers/models/doge/modeling_doge.py +3 -4
  205. transformers/models/doge/modular_doge.py +0 -1
  206. transformers/models/donut/image_processing_donut_fast.py +0 -1
  207. transformers/models/donut/modeling_donut_swin.py +18 -12
  208. transformers/models/dots1/modeling_dots1.py +23 -11
  209. transformers/models/dots1/modular_dots1.py +5 -3
  210. transformers/models/dpr/modeling_dpr.py +5 -0
  211. transformers/models/dpr/tokenization_dpr.py +12 -0
  212. transformers/models/dpt/configuration_dpt.py +1 -1
  213. transformers/models/dpt/image_processing_dpt_fast.py +1 -2
  214. transformers/models/dpt/modular_dpt.py +1 -2
  215. transformers/models/edgetam/configuration_edgetam.py +1 -1
  216. transformers/models/edgetam/modeling_edgetam.py +6 -3
  217. transformers/models/edgetam/modular_edgetam.py +15 -14
  218. transformers/models/edgetam_video/modeling_edgetam_video.py +56 -43
  219. transformers/models/edgetam_video/modular_edgetam_video.py +14 -19
  220. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
  221. transformers/models/efficientloftr/modeling_efficientloftr.py +16 -3
  222. transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
  223. transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
  224. transformers/models/efficientnet/modeling_efficientnet.py +7 -1
  225. transformers/models/electra/modeling_electra.py +7 -0
  226. transformers/models/emu3/modeling_emu3.py +12 -6
  227. transformers/models/emu3/modular_emu3.py +7 -1
  228. transformers/models/encodec/modeling_encodec.py +14 -0
  229. transformers/models/eomt/image_processing_eomt.py +13 -1
  230. transformers/models/eomt/image_processing_eomt_fast.py +60 -16
  231. transformers/models/eomt/modeling_eomt.py +7 -0
  232. transformers/models/eomt/modular_eomt.py +7 -0
  233. transformers/models/ernie/modeling_ernie.py +6 -0
  234. transformers/models/ernie/modular_ernie.py +6 -0
  235. transformers/models/ernie4_5/modeling_ernie4_5.py +5 -5
  236. transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
  237. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +20 -17
  238. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +11 -37
  239. transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
  240. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
  241. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
  242. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
  243. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
  244. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
  245. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
  246. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
  247. transformers/models/esm/modeling_esm.py +6 -0
  248. transformers/models/esm/modeling_esmfold.py +11 -5
  249. transformers/models/evolla/modeling_evolla.py +13 -5
  250. transformers/models/evolla/modular_evolla.py +8 -0
  251. transformers/models/exaone4/modeling_exaone4.py +3 -3
  252. transformers/models/exaone4/modular_exaone4.py +0 -1
  253. transformers/models/falcon/modeling_falcon.py +9 -4
  254. transformers/models/falcon_h1/modeling_falcon_h1.py +32 -26
  255. transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
  256. transformers/models/falcon_mamba/modeling_falcon_mamba.py +31 -37
  257. transformers/models/falcon_mamba/modular_falcon_mamba.py +19 -33
  258. transformers/models/fast_vlm/__init__.py +27 -0
  259. transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
  260. transformers/models/fast_vlm/modeling_fast_vlm.py +459 -0
  261. transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
  262. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +31 -13
  263. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
  264. transformers/models/flaubert/modeling_flaubert.py +21 -15
  265. transformers/models/flava/image_processing_flava_fast.py +0 -2
  266. transformers/models/flava/modeling_flava.py +10 -2
  267. transformers/models/flex_olmo/modeling_flex_olmo.py +10 -8
  268. transformers/models/florence2/modeling_florence2.py +22 -4
  269. transformers/models/florence2/modular_florence2.py +15 -1
  270. transformers/models/fnet/modeling_fnet.py +14 -0
  271. transformers/models/focalnet/modeling_focalnet.py +4 -0
  272. transformers/models/fsmt/modeling_fsmt.py +2 -0
  273. transformers/models/funnel/modeling_funnel.py +8 -0
  274. transformers/models/funnel/tokenization_funnel.py +17 -24
  275. transformers/models/fuyu/image_processing_fuyu.py +1 -1
  276. transformers/models/fuyu/modeling_fuyu.py +3 -1
  277. transformers/models/fuyu/processing_fuyu.py +19 -3
  278. transformers/models/gemma/modeling_gemma.py +14 -16
  279. transformers/models/gemma/modular_gemma.py +9 -11
  280. transformers/models/gemma/tokenization_gemma.py +10 -27
  281. transformers/models/gemma2/modeling_gemma2.py +5 -5
  282. transformers/models/gemma2/modular_gemma2.py +3 -2
  283. transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
  284. transformers/models/gemma3/modeling_gemma3.py +42 -91
  285. transformers/models/gemma3/modular_gemma3.py +38 -87
  286. transformers/models/gemma3n/configuration_gemma3n.py +3 -0
  287. transformers/models/gemma3n/modeling_gemma3n.py +65 -218
  288. transformers/models/gemma3n/modular_gemma3n.py +68 -68
  289. transformers/models/git/modeling_git.py +183 -126
  290. transformers/models/glm/modeling_glm.py +5 -5
  291. transformers/models/glm4/modeling_glm4.py +5 -5
  292. transformers/models/glm46v/image_processing_glm46v.py +0 -4
  293. transformers/models/glm46v/modeling_glm46v.py +3 -1
  294. transformers/models/glm46v/modular_glm46v.py +3 -0
  295. transformers/models/glm4_moe/modeling_glm4_moe.py +13 -7
  296. transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
  297. transformers/models/glm4v/configuration_glm4v.py +3 -1
  298. transformers/models/glm4v/image_processing_glm4v.py +0 -4
  299. transformers/models/glm4v/modeling_glm4v.py +18 -8
  300. transformers/models/glm4v/modular_glm4v.py +17 -7
  301. transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
  302. transformers/models/glm4v_moe/modeling_glm4v_moe.py +44 -27
  303. transformers/models/glm4v_moe/modular_glm4v_moe.py +13 -1
  304. transformers/models/glmasr/__init__.py +30 -0
  305. transformers/models/glmasr/configuration_glmasr.py +197 -0
  306. transformers/models/glmasr/modeling_glmasr.py +512 -0
  307. transformers/models/glmasr/modular_glmasr.py +433 -0
  308. transformers/models/glmasr/processing_glmasr.py +332 -0
  309. transformers/models/glpn/image_processing_glpn_fast.py +0 -1
  310. transformers/models/glpn/modeling_glpn.py +2 -0
  311. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
  312. transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
  313. transformers/models/gpt2/modeling_gpt2.py +13 -6
  314. transformers/models/gpt2/tokenization_gpt2.py +16 -44
  315. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +4 -8
  316. transformers/models/gpt_neo/modeling_gpt_neo.py +19 -3
  317. transformers/models/gpt_neox/modeling_gpt_neox.py +6 -3
  318. transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
  319. transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
  320. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +4 -2
  321. transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
  322. transformers/models/gpt_oss/modeling_gpt_oss.py +10 -14
  323. transformers/models/gpt_oss/modular_gpt_oss.py +8 -12
  324. transformers/models/gptj/modeling_gptj.py +18 -6
  325. transformers/models/granite/modeling_granite.py +5 -5
  326. transformers/models/granite_speech/modeling_granite_speech.py +15 -1
  327. transformers/models/granitemoe/modeling_granitemoe.py +6 -9
  328. transformers/models/granitemoe/modular_granitemoe.py +1 -4
  329. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
  330. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +36 -28
  331. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
  332. transformers/models/granitemoeshared/modeling_granitemoeshared.py +6 -9
  333. transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
  334. transformers/models/grounding_dino/modeling_grounding_dino.py +8 -4
  335. transformers/models/groupvit/modeling_groupvit.py +9 -1
  336. transformers/models/helium/modeling_helium.py +5 -4
  337. transformers/models/herbert/tokenization_herbert.py +9 -25
  338. transformers/models/hgnet_v2/modeling_hgnet_v2.py +16 -1
  339. transformers/models/hgnet_v2/modular_hgnet_v2.py +16 -1
  340. transformers/models/hiera/modeling_hiera.py +4 -0
  341. transformers/models/hubert/modeling_hubert.py +7 -0
  342. transformers/models/hubert/modular_hubert.py +5 -0
  343. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +5 -5
  344. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
  345. transformers/models/hunyuan_v1_moe/__init__.py +1 -1
  346. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +15 -7
  347. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
  348. transformers/models/ibert/modeling_ibert.py +22 -0
  349. transformers/models/idefics/modeling_idefics.py +15 -21
  350. transformers/models/idefics2/modeling_idefics2.py +7 -1
  351. transformers/models/idefics3/modeling_idefics3.py +5 -1
  352. transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
  353. transformers/models/imagegpt/modeling_imagegpt.py +11 -3
  354. transformers/models/informer/modeling_informer.py +4 -0
  355. transformers/models/informer/modular_informer.py +1 -0
  356. transformers/models/instructblip/modeling_instructblip.py +2 -0
  357. transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
  358. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
  359. transformers/models/internvl/modeling_internvl.py +13 -12
  360. transformers/models/internvl/modular_internvl.py +7 -13
  361. transformers/models/internvl/video_processing_internvl.py +0 -1
  362. transformers/models/jais2/__init__.py +27 -0
  363. transformers/models/jais2/configuration_jais2.py +152 -0
  364. transformers/models/jais2/modeling_jais2.py +486 -0
  365. transformers/models/jais2/modular_jais2.py +196 -0
  366. transformers/models/jamba/modeling_jamba.py +25 -20
  367. transformers/models/jamba/modular_jamba.py +17 -17
  368. transformers/models/janus/image_processing_janus_fast.py +0 -1
  369. transformers/models/janus/modeling_janus.py +16 -7
  370. transformers/models/janus/modular_janus.py +17 -7
  371. transformers/models/jetmoe/modeling_jetmoe.py +4 -4
  372. transformers/models/jetmoe/modular_jetmoe.py +1 -0
  373. transformers/models/kosmos2/modeling_kosmos2.py +15 -2
  374. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
  375. transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
  376. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +12 -4
  377. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
  378. transformers/models/lasr/__init__.py +29 -0
  379. transformers/models/lasr/configuration_lasr.py +248 -0
  380. transformers/models/lasr/feature_extraction_lasr.py +277 -0
  381. transformers/models/lasr/modeling_lasr.py +730 -0
  382. transformers/models/lasr/modular_lasr.py +576 -0
  383. transformers/models/lasr/processing_lasr.py +94 -0
  384. transformers/models/lasr/tokenization_lasr.py +186 -0
  385. transformers/models/layoutlm/modeling_layoutlm.py +10 -3
  386. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
  387. transformers/models/layoutlmv2/modeling_layoutlmv2.py +16 -0
  388. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +11 -53
  389. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
  390. transformers/models/layoutlmv3/modeling_layoutlmv3.py +33 -5
  391. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
  392. transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
  393. transformers/models/led/modeling_led.py +12 -0
  394. transformers/models/levit/modeling_levit.py +21 -0
  395. transformers/models/lfm2/modeling_lfm2.py +5 -6
  396. transformers/models/lfm2/modular_lfm2.py +0 -1
  397. transformers/models/lfm2_moe/modeling_lfm2_moe.py +17 -8
  398. transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
  399. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  400. transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
  401. transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
  402. transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
  403. transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
  404. transformers/models/lightglue/modeling_lightglue.py +3 -1
  405. transformers/models/lightglue/modular_lightglue.py +1 -0
  406. transformers/models/lilt/modeling_lilt.py +23 -15
  407. transformers/models/llama/modeling_llama.py +5 -5
  408. transformers/models/llama/tokenization_llama.py +15 -43
  409. transformers/models/llama4/image_processing_llama4_fast.py +1 -2
  410. transformers/models/llama4/modeling_llama4.py +11 -6
  411. transformers/models/llava/image_processing_llava_fast.py +0 -1
  412. transformers/models/llava/modeling_llava.py +12 -7
  413. transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
  414. transformers/models/llava_next/modeling_llava_next.py +7 -3
  415. transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
  416. transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
  417. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
  418. transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
  419. transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
  420. transformers/models/longcat_flash/modeling_longcat_flash.py +6 -5
  421. transformers/models/longcat_flash/modular_longcat_flash.py +3 -2
  422. transformers/models/longformer/modeling_longformer.py +6 -0
  423. transformers/models/longt5/modeling_longt5.py +4 -4
  424. transformers/models/luke/modeling_luke.py +9 -0
  425. transformers/models/luke/tokenization_luke.py +11 -38
  426. transformers/models/lxmert/modeling_lxmert.py +2 -0
  427. transformers/models/m2m_100/modeling_m2m_100.py +14 -0
  428. transformers/models/mamba/modeling_mamba.py +16 -23
  429. transformers/models/mamba2/modeling_mamba2.py +24 -23
  430. transformers/models/marian/configuration_marian.py +1 -1
  431. transformers/models/marian/modeling_marian.py +8 -0
  432. transformers/models/markuplm/modeling_markuplm.py +9 -8
  433. transformers/models/markuplm/tokenization_markuplm.py +28 -61
  434. transformers/models/mask2former/configuration_mask2former.py +3 -3
  435. transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
  436. transformers/models/mask2former/modeling_mask2former.py +11 -0
  437. transformers/models/maskformer/configuration_maskformer.py +3 -3
  438. transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
  439. transformers/models/maskformer/modeling_maskformer.py +11 -1
  440. transformers/models/maskformer/modeling_maskformer_swin.py +21 -15
  441. transformers/models/mbart/configuration_mbart.py +1 -0
  442. transformers/models/mbart/modeling_mbart.py +14 -0
  443. transformers/models/mbart/tokenization_mbart.py +11 -52
  444. transformers/models/mbart50/tokenization_mbart50.py +7 -10
  445. transformers/models/megatron_bert/modeling_megatron_bert.py +9 -0
  446. transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
  447. transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
  448. transformers/models/mgp_str/modeling_mgp_str.py +2 -0
  449. transformers/models/mimi/modeling_mimi.py +28 -5
  450. transformers/models/minimax/modeling_minimax.py +19 -6
  451. transformers/models/minimax/modular_minimax.py +12 -1
  452. transformers/models/ministral/modeling_ministral.py +5 -5
  453. transformers/models/ministral3/configuration_ministral3.py +1 -1
  454. transformers/models/ministral3/modeling_ministral3.py +5 -4
  455. transformers/models/mistral/modeling_mistral.py +5 -4
  456. transformers/models/mistral3/modeling_mistral3.py +10 -4
  457. transformers/models/mistral3/modular_mistral3.py +3 -1
  458. transformers/models/mixtral/modeling_mixtral.py +15 -7
  459. transformers/models/mixtral/modular_mixtral.py +6 -2
  460. transformers/models/mlcd/modeling_mlcd.py +6 -0
  461. transformers/models/mlcd/modular_mlcd.py +4 -0
  462. transformers/models/mllama/modeling_mllama.py +15 -4
  463. transformers/models/mluke/tokenization_mluke.py +6 -6
  464. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
  465. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +8 -4
  466. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
  467. transformers/models/mobilebert/modeling_mobilebert.py +2 -0
  468. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
  469. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
  470. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
  471. transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
  472. transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
  473. transformers/models/mobilevit/modeling_mobilevit.py +7 -0
  474. transformers/models/mobilevitv2/modeling_mobilevitv2.py +7 -0
  475. transformers/models/modernbert/modeling_modernbert.py +16 -2
  476. transformers/models/modernbert/modular_modernbert.py +14 -1
  477. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +17 -10
  478. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +15 -8
  479. transformers/models/moonshine/modeling_moonshine.py +5 -3
  480. transformers/models/moshi/modeling_moshi.py +26 -53
  481. transformers/models/mpnet/modeling_mpnet.py +7 -0
  482. transformers/models/mpnet/tokenization_mpnet.py +5 -13
  483. transformers/models/mpt/modeling_mpt.py +2 -0
  484. transformers/models/mra/modeling_mra.py +10 -1
  485. transformers/models/mt5/configuration_mt5.py +2 -3
  486. transformers/models/mt5/modeling_mt5.py +7 -10
  487. transformers/models/musicgen/modeling_musicgen.py +7 -9
  488. transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -0
  489. transformers/models/mvp/modeling_mvp.py +14 -0
  490. transformers/models/nanochat/modeling_nanochat.py +5 -5
  491. transformers/models/nemotron/modeling_nemotron.py +7 -5
  492. transformers/models/nllb/tokenization_nllb.py +8 -22
  493. transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
  494. transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
  495. transformers/models/nougat/image_processing_nougat_fast.py +0 -1
  496. transformers/models/nougat/tokenization_nougat.py +15 -68
  497. transformers/models/nystromformer/modeling_nystromformer.py +13 -0
  498. transformers/models/olmo/modeling_olmo.py +5 -5
  499. transformers/models/olmo/modular_olmo.py +2 -2
  500. transformers/models/olmo2/modeling_olmo2.py +5 -6
  501. transformers/models/olmo2/modular_olmo2.py +0 -1
  502. transformers/models/olmo3/modeling_olmo3.py +5 -5
  503. transformers/models/olmoe/modeling_olmoe.py +15 -7
  504. transformers/models/olmoe/modular_olmoe.py +4 -2
  505. transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
  506. transformers/models/omdet_turbo/modeling_omdet_turbo.py +6 -0
  507. transformers/models/oneformer/configuration_oneformer.py +3 -3
  508. transformers/models/oneformer/modeling_oneformer.py +11 -39
  509. transformers/models/openai/modeling_openai.py +15 -0
  510. transformers/models/openai/tokenization_openai.py +10 -46
  511. transformers/models/opt/modeling_opt.py +2 -0
  512. transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
  513. transformers/models/ovis2/modeling_ovis2.py +15 -3
  514. transformers/models/ovis2/modular_ovis2.py +8 -0
  515. transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
  516. transformers/models/owlv2/modeling_owlv2.py +11 -3
  517. transformers/models/owlv2/modular_owlv2.py +0 -2
  518. transformers/models/owlvit/modeling_owlvit.py +11 -3
  519. transformers/models/paddleocr_vl/__init__.py +32 -0
  520. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
  521. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +504 -0
  522. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
  523. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1682 -0
  524. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1359 -0
  525. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
  526. transformers/models/paligemma/modeling_paligemma.py +25 -17
  527. transformers/models/parakeet/configuration_parakeet.py +4 -6
  528. transformers/models/parakeet/modeling_parakeet.py +14 -6
  529. transformers/models/parakeet/modular_parakeet.py +7 -2
  530. transformers/models/parakeet/processing_parakeet.py +1 -0
  531. transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
  532. transformers/models/patchtsmixer/modeling_patchtsmixer.py +10 -0
  533. transformers/models/patchtst/modeling_patchtst.py +25 -6
  534. transformers/models/pe_audio/__init__.py +30 -0
  535. transformers/models/pe_audio/configuration_pe_audio.py +206 -0
  536. transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
  537. transformers/models/pe_audio/modeling_pe_audio.py +820 -0
  538. transformers/models/pe_audio/modular_pe_audio.py +299 -0
  539. transformers/{kernels/falcon_mamba/__init__.py → models/pe_audio/processing_pe_audio.py} +11 -2
  540. transformers/models/pe_audio_video/__init__.py +29 -0
  541. transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
  542. transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
  543. transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
  544. transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
  545. transformers/models/pe_video/__init__.py +30 -0
  546. transformers/models/pe_video/configuration_pe_video.py +211 -0
  547. transformers/models/pe_video/modeling_pe_video.py +636 -0
  548. transformers/models/pe_video/modular_pe_video.py +219 -0
  549. transformers/models/pe_video/processing_pe_video.py +10 -0
  550. transformers/models/pe_video/video_processing_pe_video.py +66 -0
  551. transformers/models/pegasus/configuration_pegasus.py +1 -0
  552. transformers/models/pegasus/modeling_pegasus.py +8 -0
  553. transformers/models/pegasus/tokenization_pegasus.py +17 -44
  554. transformers/models/pegasus_x/modeling_pegasus_x.py +5 -0
  555. transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
  556. transformers/models/perceiver/modeling_perceiver.py +13 -1
  557. transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
  558. transformers/models/perception_lm/modeling_perception_lm.py +7 -3
  559. transformers/models/perception_lm/modular_perception_lm.py +7 -3
  560. transformers/models/persimmon/modeling_persimmon.py +3 -2
  561. transformers/models/phi/modeling_phi.py +5 -6
  562. transformers/models/phi/modular_phi.py +0 -1
  563. transformers/models/phi3/modeling_phi3.py +3 -2
  564. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +9 -6
  565. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +7 -4
  566. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
  567. transformers/models/phimoe/modeling_phimoe.py +15 -7
  568. transformers/models/phimoe/modular_phimoe.py +3 -3
  569. transformers/models/pix2struct/modeling_pix2struct.py +2 -0
  570. transformers/models/pix2struct/processing_pix2struct.py +0 -4
  571. transformers/models/pixio/__init__.py +30 -0
  572. transformers/models/pixio/configuration_pixio.py +151 -0
  573. transformers/models/pixio/modeling_pixio.py +507 -0
  574. transformers/models/pixio/modular_pixio.py +404 -0
  575. transformers/models/pixtral/modeling_pixtral.py +3 -2
  576. transformers/models/pixtral/processing_pixtral.py +3 -1
  577. transformers/models/plbart/configuration_plbart.py +1 -0
  578. transformers/models/plbart/modeling_plbart.py +13 -0
  579. transformers/models/plbart/modular_plbart.py +8 -0
  580. transformers/models/plbart/tokenization_plbart.py +0 -2
  581. transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
  582. transformers/models/poolformer/modeling_poolformer.py +13 -1
  583. transformers/models/pop2piano/configuration_pop2piano.py +0 -1
  584. transformers/models/pop2piano/modeling_pop2piano.py +2 -0
  585. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
  586. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
  587. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
  588. transformers/models/prophetnet/modeling_prophetnet.py +5 -1
  589. transformers/models/pvt/modeling_pvt.py +2 -0
  590. transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
  591. transformers/models/qwen2/modeling_qwen2.py +5 -5
  592. transformers/models/qwen2/tokenization_qwen2.py +14 -18
  593. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
  594. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +116 -79
  595. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +71 -33
  596. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
  597. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +23 -11
  598. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +29 -27
  599. transformers/models/qwen2_audio/modeling_qwen2_audio.py +4 -2
  600. transformers/models/qwen2_moe/modeling_qwen2_moe.py +15 -7
  601. transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
  602. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
  603. transformers/models/qwen2_vl/modeling_qwen2_vl.py +23 -20
  604. transformers/models/qwen3/modeling_qwen3.py +5 -5
  605. transformers/models/qwen3_moe/modeling_qwen3_moe.py +15 -7
  606. transformers/models/qwen3_next/modeling_qwen3_next.py +7 -8
  607. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
  608. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +112 -68
  609. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +62 -20
  610. transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
  611. transformers/models/qwen3_vl/modeling_qwen3_vl.py +57 -42
  612. transformers/models/qwen3_vl/modular_qwen3_vl.py +59 -46
  613. transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
  614. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +132 -148
  615. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +36 -82
  616. transformers/models/rag/configuration_rag.py +0 -8
  617. transformers/models/rag/modeling_rag.py +8 -9
  618. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +18 -3
  619. transformers/models/reformer/modeling_reformer.py +13 -1
  620. transformers/models/reformer/tokenization_reformer.py +11 -28
  621. transformers/models/regnet/modeling_regnet.py +10 -1
  622. transformers/models/rembert/modeling_rembert.py +13 -1
  623. transformers/models/rembert/tokenization_rembert.py +3 -10
  624. transformers/models/resnet/modeling_resnet.py +19 -5
  625. transformers/models/roberta/modeling_roberta.py +3 -0
  626. transformers/models/roberta/modular_roberta.py +3 -0
  627. transformers/models/roberta/tokenization_roberta.py +18 -27
  628. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
  629. transformers/models/roc_bert/modeling_roc_bert.py +3 -0
  630. transformers/models/roformer/modeling_roformer.py +6 -0
  631. transformers/models/roformer/tokenization_roformer.py +77 -412
  632. transformers/models/rt_detr/configuration_rt_detr.py +1 -1
  633. transformers/models/rt_detr/modeling_rt_detr.py +6 -0
  634. transformers/models/rt_detr/modeling_rt_detr_resnet.py +13 -4
  635. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
  636. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +9 -0
  637. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
  638. transformers/models/rwkv/modeling_rwkv.py +2 -1
  639. transformers/models/sam/configuration_sam.py +1 -0
  640. transformers/models/sam/image_processing_sam_fast.py +0 -1
  641. transformers/models/sam/modeling_sam.py +4 -1
  642. transformers/models/sam2/configuration_sam2.py +1 -1
  643. transformers/models/sam2/modeling_sam2.py +7 -3
  644. transformers/models/sam2/modular_sam2.py +7 -3
  645. transformers/models/sam2_video/modeling_sam2_video.py +52 -43
  646. transformers/models/sam2_video/modular_sam2_video.py +32 -18
  647. transformers/models/sam3/configuration_sam3.py +21 -1
  648. transformers/models/sam3/modeling_sam3.py +100 -80
  649. transformers/models/sam3_tracker/modeling_sam3_tracker.py +8 -1
  650. transformers/models/sam3_tracker/modular_sam3_tracker.py +8 -1
  651. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
  652. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +27 -15
  653. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
  654. transformers/models/sam3_video/configuration_sam3_video.py +14 -0
  655. transformers/models/sam3_video/modeling_sam3_video.py +4 -3
  656. transformers/models/sam3_video/processing_sam3_video.py +1 -1
  657. transformers/models/sam_hq/configuration_sam_hq.py +1 -0
  658. transformers/models/sam_hq/modeling_sam_hq.py +26 -23
  659. transformers/models/seamless_m4t/modeling_seamless_m4t.py +32 -12
  660. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
  661. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +11 -1
  662. transformers/models/seed_oss/modeling_seed_oss.py +3 -3
  663. transformers/models/segformer/image_processing_segformer_fast.py +0 -1
  664. transformers/models/segformer/modeling_segformer.py +6 -3
  665. transformers/models/segformer/modular_segformer.py +0 -1
  666. transformers/models/seggpt/modeling_seggpt.py +2 -0
  667. transformers/models/sew/modeling_sew.py +3 -0
  668. transformers/models/sew/modular_sew.py +1 -0
  669. transformers/models/sew_d/modeling_sew_d.py +3 -0
  670. transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
  671. transformers/models/siglip/modeling_siglip.py +24 -2
  672. transformers/models/siglip2/modeling_siglip2.py +67 -41
  673. transformers/models/siglip2/modular_siglip2.py +4 -0
  674. transformers/models/smollm3/modeling_smollm3.py +5 -5
  675. transformers/models/smolvlm/modeling_smolvlm.py +5 -1
  676. transformers/models/smolvlm/processing_smolvlm.py +0 -7
  677. transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
  678. transformers/models/speech_to_text/modeling_speech_to_text.py +14 -0
  679. transformers/models/speecht5/modeling_speecht5.py +41 -1
  680. transformers/models/splinter/modeling_splinter.py +12 -3
  681. transformers/models/splinter/tokenization_splinter.py +9 -28
  682. transformers/models/squeezebert/modeling_squeezebert.py +8 -0
  683. transformers/models/stablelm/modeling_stablelm.py +4 -2
  684. transformers/models/starcoder2/modeling_starcoder2.py +5 -4
  685. transformers/models/superglue/image_processing_superglue_fast.py +1 -2
  686. transformers/models/superglue/modeling_superglue.py +1 -0
  687. transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
  688. transformers/models/superpoint/modeling_superpoint.py +1 -0
  689. transformers/models/swiftformer/modeling_swiftformer.py +6 -0
  690. transformers/models/swin/modeling_swin.py +20 -12
  691. transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
  692. transformers/models/swin2sr/modeling_swin2sr.py +51 -33
  693. transformers/models/swinv2/modeling_swinv2.py +45 -33
  694. transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
  695. transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
  696. transformers/models/t5/configuration_t5.py +7 -1
  697. transformers/models/t5/modeling_t5.py +8 -7
  698. transformers/models/t5/tokenization_t5.py +4 -8
  699. transformers/models/t5gemma/modeling_t5gemma.py +6 -6
  700. transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
  701. transformers/models/t5gemma2/modeling_t5gemma2.py +19 -10
  702. transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
  703. transformers/models/table_transformer/configuration_table_transformer.py +1 -1
  704. transformers/models/table_transformer/modeling_table_transformer.py +5 -1
  705. transformers/models/tapas/modeling_tapas.py +3 -0
  706. transformers/models/textnet/image_processing_textnet_fast.py +0 -1
  707. transformers/models/textnet/modeling_textnet.py +11 -2
  708. transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
  709. transformers/models/timesfm/modeling_timesfm.py +14 -0
  710. transformers/models/timesfm/modular_timesfm.py +14 -0
  711. transformers/models/timesformer/modeling_timesformer.py +2 -0
  712. transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
  713. transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
  714. transformers/models/timm_wrapper/modeling_timm_wrapper.py +20 -14
  715. transformers/models/trocr/modeling_trocr.py +3 -2
  716. transformers/models/tvp/configuration_tvp.py +5 -1
  717. transformers/models/tvp/modeling_tvp.py +6 -4
  718. transformers/models/udop/configuration_udop.py +1 -0
  719. transformers/models/udop/modeling_udop.py +7 -7
  720. transformers/models/udop/tokenization_udop.py +5 -13
  721. transformers/models/umt5/configuration_umt5.py +2 -2
  722. transformers/models/umt5/modeling_umt5.py +7 -6
  723. transformers/models/unispeech/modeling_unispeech.py +4 -0
  724. transformers/models/unispeech/modular_unispeech.py +2 -0
  725. transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
  726. transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
  727. transformers/models/univnet/modeling_univnet.py +1 -0
  728. transformers/models/upernet/modeling_upernet.py +1 -0
  729. transformers/models/vaultgemma/modeling_vaultgemma.py +5 -5
  730. transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
  731. transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
  732. transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
  733. transformers/models/video_llava/modeling_video_llava.py +7 -3
  734. transformers/models/vilt/configuration_vilt.py +2 -2
  735. transformers/models/vilt/modeling_vilt.py +13 -0
  736. transformers/models/vipllava/modeling_vipllava.py +7 -3
  737. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
  738. transformers/models/visual_bert/modeling_visual_bert.py +8 -0
  739. transformers/models/vitdet/modeling_vitdet.py +2 -0
  740. transformers/models/vitmatte/configuration_vitmatte.py +1 -1
  741. transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
  742. transformers/models/vitmatte/modeling_vitmatte.py +5 -0
  743. transformers/models/vitpose/configuration_vitpose.py +1 -1
  744. transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
  745. transformers/models/vits/modeling_vits.py +1 -0
  746. transformers/models/vjepa2/modeling_vjepa2.py +1 -0
  747. transformers/models/voxtral/modeling_voxtral.py +2 -2
  748. transformers/models/voxtral/modular_voxtral.py +2 -2
  749. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
  750. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +21 -10
  751. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +12 -0
  752. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +27 -11
  753. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
  754. transformers/models/wavlm/modeling_wavlm.py +5 -0
  755. transformers/models/whisper/generation_whisper.py +1 -0
  756. transformers/models/whisper/modeling_whisper.py +11 -3
  757. transformers/models/whisper/tokenization_whisper.py +4 -15
  758. transformers/models/x_clip/modeling_x_clip.py +5 -0
  759. transformers/models/xcodec/modeling_xcodec.py +5 -0
  760. transformers/models/xglm/modeling_xglm.py +11 -0
  761. transformers/models/xglm/tokenization_xglm.py +4 -9
  762. transformers/models/xlm/modeling_xlm.py +18 -14
  763. transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
  764. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
  765. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
  766. transformers/models/xlnet/modeling_xlnet.py +3 -1
  767. transformers/models/xlnet/tokenization_xlnet.py +3 -7
  768. transformers/models/xmod/modeling_xmod.py +3 -0
  769. transformers/models/yoso/modeling_yoso.py +10 -1
  770. transformers/models/zamba/modeling_zamba.py +4 -1
  771. transformers/models/zamba2/modeling_zamba2.py +7 -4
  772. transformers/models/zamba2/modular_zamba2.py +1 -1
  773. transformers/models/zoedepth/configuration_zoedepth.py +1 -1
  774. transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
  775. transformers/models/zoedepth/modeling_zoedepth.py +8 -0
  776. transformers/pipelines/__init__.py +11 -9
  777. transformers/pipelines/automatic_speech_recognition.py +20 -12
  778. transformers/pipelines/base.py +2 -10
  779. transformers/pipelines/document_question_answering.py +4 -2
  780. transformers/pipelines/question_answering.py +1 -1
  781. transformers/pipelines/text_generation.py +1 -1
  782. transformers/pipelines/text_to_audio.py +2 -2
  783. transformers/processing_utils.py +133 -50
  784. transformers/quantizers/auto.py +2 -4
  785. transformers/quantizers/base.py +44 -174
  786. transformers/quantizers/quantizer_aqlm.py +2 -23
  787. transformers/quantizers/quantizer_auto_round.py +2 -12
  788. transformers/quantizers/quantizer_awq.py +20 -89
  789. transformers/quantizers/quantizer_bitnet.py +4 -14
  790. transformers/quantizers/quantizer_bnb_4bit.py +18 -155
  791. transformers/quantizers/quantizer_bnb_8bit.py +24 -110
  792. transformers/quantizers/quantizer_compressed_tensors.py +2 -9
  793. transformers/quantizers/quantizer_eetq.py +16 -74
  794. transformers/quantizers/quantizer_fbgemm_fp8.py +38 -138
  795. transformers/quantizers/quantizer_finegrained_fp8.py +26 -113
  796. transformers/quantizers/quantizer_fp_quant.py +52 -82
  797. transformers/quantizers/quantizer_gptq.py +8 -28
  798. transformers/quantizers/quantizer_higgs.py +42 -60
  799. transformers/quantizers/quantizer_hqq.py +144 -153
  800. transformers/quantizers/quantizer_mxfp4.py +14 -194
  801. transformers/quantizers/quantizer_quanto.py +35 -79
  802. transformers/quantizers/quantizer_quark.py +36 -17
  803. transformers/quantizers/quantizer_spqr.py +4 -12
  804. transformers/quantizers/quantizer_torchao.py +50 -325
  805. transformers/quantizers/quantizer_vptq.py +4 -27
  806. transformers/quantizers/quantizers_utils.py +20 -0
  807. transformers/testing_utils.py +324 -47
  808. transformers/tokenization_mistral_common.py +7 -2
  809. transformers/tokenization_utils_base.py +116 -224
  810. transformers/tokenization_utils_tokenizers.py +190 -106
  811. transformers/trainer.py +51 -32
  812. transformers/trainer_callback.py +8 -0
  813. transformers/trainer_jit_checkpoint.py +126 -0
  814. transformers/trainer_seq2seq.py +4 -0
  815. transformers/trainer_utils.py +1 -1
  816. transformers/training_args.py +74 -38
  817. transformers/utils/__init__.py +7 -4
  818. transformers/utils/attention_visualizer.py +4 -4
  819. transformers/utils/auto_docstring.py +35 -25
  820. transformers/utils/generic.py +47 -1
  821. transformers/utils/hub.py +5 -15
  822. transformers/utils/import_utils.py +112 -25
  823. transformers/utils/kernel_config.py +74 -19
  824. transformers/utils/loading_report.py +19 -10
  825. transformers/utils/quantization_config.py +78 -245
  826. transformers/video_processing_utils.py +17 -14
  827. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +275 -229
  828. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +832 -777
  829. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +1 -1
  830. transformers/kernels/__init__.py +0 -0
  831. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
  832. transformers/models/roformer/tokenization_roformer_fast.py +0 -160
  833. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
  834. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info/licenses}/LICENSE +0 -0
  835. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,273 @@
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Optional, Union
16
+
17
+ import torch
18
+ from torch import nn
19
+
20
+ from ...activations import ACT2FN
21
+ from ...configuration_utils import PreTrainedConfig
22
+ from ...utils import auto_docstring
23
+ from ..auto import CONFIG_MAPPING
24
+ from ..llava.configuration_llava import LlavaConfig
25
+ from ..llava.modeling_llava import (
26
+ LlavaForConditionalGeneration,
27
+ LlavaModel,
28
+ LlavaMultiModalProjector,
29
+ LlavaPreTrainedModel,
30
+ )
31
+
32
+
+class FastVlmConfig(LlavaConfig):
+    r"""
+    This is the configuration class to store the configuration of a [`FastVlmForConditionalGeneration`]. It is used to instantiate a
+    FastVLM model according to the specified arguments, defining the model architecture. Instantiating a configuration
+    with the defaults will yield the same configuration as the one of FastVLM-7B.
+
+    e.g. [KamilaMila/FastVLM-7B](https://huggingface.co/KamilaMila/FastVLM-7B)
+
+    Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
+    documentation from [`PreTrainedConfig`] for more information.
+
+    Args:
+        vision_config (`Union[AutoConfig, dict]`, *optional*, defaults to `TimmWrapperConfig` for `fastvit_mci3`):
+            The config object or dictionary of the vision backbone.
+        text_config (`Union[AutoConfig, dict]`, *optional*, defaults to `Qwen2Config`):
+            The config object or dictionary of the text backbone.
+        image_token_id (`int`, *optional*, defaults to 151646):
+            The image token index to encode the image prompt.
+        projector_hidden_act (`str`, *optional*, defaults to `"gelu"`):
+            The activation function used by the multimodal projector.
+        vision_feature_select_strategy (`str`, *optional*, defaults to `"full"`):
+            The feature selection strategy used to select the vision feature from the vision backbone.
+            Only "full" is supported.
+        vision_feature_layer (`Union[int, list[int]]`, *optional*, defaults to -1):
+            The index of the layer to select the vision feature. If multiple indices are provided,
+            the vision features of the corresponding indices will be concatenated to form the
+            vision features. Only -1 is supported.
+        multimodal_projector_bias (`bool`, *optional*, defaults to `True`):
+            Whether to use bias in the multimodal projector.
+
+    Example:
+
+    ```python
+    >>> from transformers import FastVlmForConditionalGeneration, FastVlmConfig
+
+    >>> # Initializing a FastVLM-7B style configuration
+    >>> configuration = FastVlmConfig()
+
+    >>> # Initializing a model from the FastVLM-7B style configuration
+    >>> model = FastVlmForConditionalGeneration(configuration)
+
+    >>> # Accessing the model configuration
+    >>> configuration = model.config
+    ```"""
+
+    model_type = "fast_vlm"
+
+    def __init__(
+        self,
+        vision_config=None,
+        text_config=None,
+        image_token_id=151646,
+        projector_hidden_act="gelu",
+        vision_feature_select_strategy="full",
+        vision_feature_layer=-1,
+        multimodal_projector_bias=True,
+        **kwargs,
+    ):
+        self.image_token_id = image_token_id
+        self.projector_hidden_act = projector_hidden_act
+
+        if vision_feature_select_strategy != "full":
+            raise ValueError(
+                f"Unexpected select feature strategy: {vision_feature_select_strategy}. Only 'full' is supported in FastVLM."
+            )
+
+        if vision_feature_layer != -1:
+            raise ValueError(
+                f"Unexpected vision feature layer: {vision_feature_layer}. Only -1 is supported in FastVLM."
+            )
+
+        self.vision_feature_select_strategy = vision_feature_select_strategy
+        self.vision_feature_layer = vision_feature_layer
+
+        if isinstance(vision_config, dict):
+            vision_config["model_type"] = vision_config.get("model_type", "timm_wrapper")
+            vision_config = CONFIG_MAPPING[vision_config["model_type"]](**vision_config)
+        elif vision_config is None:
+            vision_config = CONFIG_MAPPING["timm_wrapper"](
+                architecture="fastvit_mci3",
+                do_pooling=True,
+                global_pool="avg",
+                hidden_size=3072,
+                initializer_range=0.02,
+                model_args={"inference_mode": True},
+            )
+
+        self.vision_config = vision_config
+
+        if isinstance(text_config, dict):
+            text_config["model_type"] = text_config.get("model_type", "qwen2")
+            text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
+        elif text_config is None:
+            text_config = CONFIG_MAPPING["qwen2"](
+                hidden_size=3584,
+                vocab_size=152128,
+                intermediate_size=18944,
+                num_attention_heads=28,
+                num_key_value_heads=4,
+                num_hidden_layers=28,
+            )
+
+        self.text_config = text_config
+        self.multimodal_projector_bias = multimodal_projector_bias
+
+        PreTrainedConfig.__init__(self, **kwargs)
+
+
+class FastVlmMultiModalProjector(LlavaMultiModalProjector):
+    def __init__(self, config: FastVlmConfig):
+        nn.Module.__init__(self)
+        self.linear_1 = nn.Linear(
+            config.vision_config.hidden_size,
+            config.text_config.hidden_size,
+            bias=config.multimodal_projector_bias,
+        )
+        self.act = ACT2FN[config.projector_hidden_act]
+        self.linear_2 = nn.Linear(
+            config.text_config.hidden_size, config.text_config.hidden_size, bias=config.multimodal_projector_bias
+        )
+
+
+class FastVlmPreTrainedModel(LlavaPreTrainedModel):
+    pass
+
+
+class FastVlmModel(LlavaModel):
+    _checkpoint_conversion_mapping = {}
+
+    def __init__(self, config: FastVlmConfig):
+        super().__init__(config)
+
+    def get_image_features(
+        self,
+        pixel_values: torch.FloatTensor,
+        vision_feature_layer: Optional[Union[int, list[int]]] = None,
+        vision_feature_select_strategy: Optional[str] = None,
+        **kwargs,
+    ):
+        """
+        Obtains image last hidden states from the vision tower and applies multimodal projection.
+
+        Args:
+            pixel_values (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
+                The tensors corresponding to the input images.
+            vision_feature_layer (`Union[int, list[int]]`, *optional*):
+                The index/indices of the layer to select the vision feature. Only -1 is supported.
+            vision_feature_select_strategy (`str`, *optional*):
+                The feature selection strategy used to select the vision feature from the vision backbone.
+                Only "full" is supported.
+        Returns:
+            image_features (`torch.Tensor`): Image feature tensor of shape `(num_images, image_length, embed_dim)`.
+        """
+        vision_feature_layer = (
+            vision_feature_layer if vision_feature_layer is not None else self.config.vision_feature_layer
+        )
+        vision_feature_select_strategy = (
+            vision_feature_select_strategy
+            if vision_feature_select_strategy is not None
+            else self.config.vision_feature_select_strategy
+        )
+
+        kwargs = {k: v for k, v in kwargs.items() if v is not None}
+        image_outputs = self.vision_tower(pixel_values, **kwargs)
+
+        # since the vision tower is hybrid in FastVLM, its output needs to be handled differently from Llava:
+        # the (batch, channels, height, width) feature map is flattened and transposed into token order
+        selected_image_feature = image_outputs.last_hidden_state
+        selected_image_feature = selected_image_feature.flatten(2).permute(0, 2, 1)
+        image_features = self.multi_modal_projector(selected_image_feature)
+        image_features = list(image_features)
+        return image_features
+
+    def forward(self, **super_kwargs):
+        r"""
+        vision_feature_layer (`Union[int, list[int], NoneType]`, *optional*):
+            The index of the layer to select the vision feature. If multiple indices are provided, the vision feature of the
+            corresponding indices will be concatenated to form the vision features. Only -1 is supported.
+        vision_feature_select_strategy (`str`, *optional*):
+            The feature selection strategy used to select the vision feature from the vision backbone. Only "full" is supported.
+        """
+        return super().forward(**super_kwargs)
+
+
+@auto_docstring(
+    custom_intro="""
+    The FastVlm model which consists of a vision backbone and a language model.
+    """
+)
+class FastVlmForConditionalGeneration(LlavaForConditionalGeneration):
+    _checkpoint_conversion_mapping = {}
+
+    def forward(self, **super_kwargs):
+        r"""
+        vision_feature_layer (`Union[int, list[int], NoneType]`, *optional*):
+            The index of the layer to select the vision feature. If multiple indices are provided, the vision feature of the
+            corresponding indices will be concatenated to form the vision features. Only -1 is supported.
+        vision_feature_select_strategy (`str`, *optional*):
+            The feature selection strategy used to select the vision feature from the vision backbone. Only "full" is supported.
+        labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+            Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+            config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+            (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+
+        Example:
+
+        ```python
+        >>> from PIL import Image
+        >>> import requests
+        >>> from transformers import AutoProcessor, AutoModelForImageTextToText
+        >>> import torch
+
+        >>> device = "cuda" if torch.cuda.is_available() else "cpu"
+
+        >>> model = AutoModelForImageTextToText.from_pretrained("KamilaMila/FastVLM-0.5B").to(device)
+        >>> processor = AutoProcessor.from_pretrained("KamilaMila/FastVLM-0.5B")
+
+        >>> conversation = [
+        ...     {
+        ...         "role": "user",
+        ...         "content": [
+        ...             {"type": "text", "text": "What are these?"},
+        ...             {"type": "image"},
+        ...         ],
+        ...     }
+        ... ]
+
+        >>> prompt = processor.apply_chat_template(conversation, add_generation_prompt=True)
+        >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
+        >>> image = Image.open(requests.get(url, stream=True).raw)
+
+        >>> inputs = processor(images=image, text=prompt, return_tensors="pt").to(device)
+
+        >>> # Generate
+        >>> generated_ids = model.generate(**inputs, max_new_tokens=15)
+        >>> print(processor.batch_decode(generated_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0])
+        system\n You are a helpful assistant.\n user\n What are these?\n assistant\n The image depicts a traditional Chinese street...
+        ```"""
+        return super().forward(**super_kwargs)
+
+
+__all__ = ["FastVlmForConditionalGeneration", "FastVlmModel", "FastVlmPreTrainedModel", "FastVlmConfig"]
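
The notable piece of the new file is `get_image_features`: FastVLM's hybrid convolutional tower returns a `(batch, channels, height, width)` feature map rather than a ViT token sequence, so it is flattened and transposed before the projector. A minimal sketch of that reshape (tensor sizes here are illustrative assumptions, not tied to any real checkpoint):

```python
import torch

# assumed sizes: hidden_size 3072 matches the default fastvit_mci3 config above
batch, channels, height, width = 2, 3072, 16, 16
feature_map = torch.randn(batch, channels, height, width)

# (batch, channels, H, W) -> (batch, channels, H*W) -> (batch, H*W, channels)
tokens = feature_map.flatten(2).permute(0, 2, 1)
assert tokens.shape == (batch, height * width, channels)
```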
@@ -514,7 +514,7 @@ class FastSpeech2ConformerConvolutionModule(nn.Module):
 
         Args:
             hidden_states (`torch.Tensor` of shape `(batch, time, channels)`): Input tensor.
-            attention_mask (`torch.Tensor` of shape `(batch, 1, time)`): Attention mask.
+            attention_mask (`torch.Tensor` of shape `(batch, 1, time, time)`): Attention mask.
 
         Returns:
             `torch.Tensor`: Output tensor of shape `(batch, time, channels)`.
@@ -530,7 +530,10 @@ class FastSpeech2ConformerConvolutionModule(nn.Module):
 
         # Apply padding mask before convolution
         if attention_mask is not None:
-            all_masked_rows = torch.all(~attention_mask, dim=-1)
+            if attention_mask.dtype == torch.bool:
+                all_masked_rows = torch.all(~attention_mask, dim=2)
+            else:
+                all_masked_rows = torch.all(~(attention_mask == 0.0), dim=2)
             hidden_states = hidden_states.masked_fill(all_masked_rows, 0.0)
 
         # 1D Depthwise Conv
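
The new branch above accepts both mask conventions: a boolean mask marks valid positions with `True`, while an additive float mask marks them with `0.0` and masked positions with a large negative value. A standalone sketch showing that the two reductions agree (2-D shapes for brevity; the model applies them to 4-D masks with `dim=2`):

```python
import torch

bool_mask = torch.tensor([[True, True, False]])          # True = attend
float_mask = torch.tensor([[0.0, 0.0, float("-inf")]])   # 0.0 = attend

print(torch.all(~bool_mask, dim=-1))            # tensor([False]): row not fully masked
print(torch.all(~(float_mask == 0.0), dim=-1))  # tensor([False]): same verdict
```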
@@ -724,19 +727,20 @@ class FastSpeech2ConformerRelPositionalEncoding(nn.Module):
         self.embed_dim = config.hidden_size
         self.input_scale = math.sqrt(self.embed_dim)
         self.dropout = nn.Dropout(p=module_config["positional_dropout_rate"])
-        self.pos_enc = None
         self.max_len = 5000
-        self.extend_pos_enc(torch.tensor(0.0).expand(1, self.max_len))
+        self.register_buffer(
+            "pos_enc", self.extend_pos_enc(torch.tensor(0.0).expand(1, self.max_len)), persistent=False
+        )
 
-    def extend_pos_enc(self, x):
+    def extend_pos_enc(self, x, pos_enc=None):
         """Reset the positional encodings."""
-        if self.pos_enc is not None:
+        if pos_enc is not None:
             # self.pos_enc contains both positive and negative parts
             # the length of self.pos_enc is 2 * input_len - 1
-            if self.pos_enc.size(1) >= x.size(1) * 2 - 1:
-                if self.pos_enc.dtype != x.dtype or self.pos_enc.device != x.device:
-                    self.pos_enc = self.pos_enc.to(dtype=x.dtype, device=x.device)
-                return
+            if pos_enc.size(1) >= x.size(1) * 2 - 1:
+                if pos_enc.dtype != x.dtype or pos_enc.device != x.device:
+                    pos_enc = pos_enc.to(dtype=x.dtype, device=x.device)
+                return pos_enc
         # Suppose `i` means to the position of query vector and `j` means the
         # position of key vector. We use positive relative positions when keys
         # are to the left (i>j) and negative relative positions otherwise (i<j).
@@ -757,7 +761,7 @@ class FastSpeech2ConformerRelPositionalEncoding(nn.Module):
         pos_enc_positive = torch.flip(pos_enc_positive, [0]).unsqueeze(0)
         pos_enc_negative = pos_enc_negative[1:].unsqueeze(0)
         pos_enc = torch.cat([pos_enc_positive, pos_enc_negative], dim=1)
-        self.pos_enc = pos_enc.to(device=x.device, dtype=x.dtype)
+        return pos_enc.to(device=x.device, dtype=x.dtype)
 
     def forward(self, feature_representation):
         """
@@ -768,7 +772,7 @@ class FastSpeech2ConformerRelPositionalEncoding(nn.Module):
         Returns:
             `torch.Tensor`: Encoded tensor (batch_size, time, `*`).
         """
-        self.extend_pos_enc(feature_representation)
+        self.pos_enc = self.extend_pos_enc(feature_representation, self.pos_enc)
         hidden_states = feature_representation * self.input_scale
         center_idx = self.pos_enc.size(1) // 2
         pos_emb = self.pos_enc[:, center_idx - hidden_states.size(1) + 1 : center_idx + hidden_states.size(1)]
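
Turning `pos_enc` into a non-persistent buffer keeps it on the module's device and dtype without serializing it, so checkpoints saved before this change still load. A minimal sketch of the `persistent=False` semantics:

```python
import torch
from torch import nn

class WithCache(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("pos_enc", torch.zeros(1, 8), persistent=False)

module = WithCache()
print("pos_enc" in module.state_dict())  # False: excluded from checkpoints
module.to(torch.float64)
print(module.pos_enc.dtype)              # torch.float64: converted with the module
```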
@@ -1007,6 +1011,10 @@ class FastSpeech2ConformerPreTrainedModel(PreTrainedModel):
         elif isinstance(module, (nn.LayerNorm, nn.BatchNorm1d)):
             init.zeros_(module.bias)
             init.ones_(module.weight)
+            if getattr(module, "running_mean", None) is not None:
+                init.zeros_(module.running_mean)
+                init.ones_(module.running_var)
+                init.zeros_(module.num_batches_tracked)
         elif isinstance(module, nn.Embedding):
             init.normal_(module.weight)
             # Here we need the check explicitly, as we slice the weight in the `zeros_` call, so it loses the flag
@@ -1015,6 +1023,8 @@ class FastSpeech2ConformerPreTrainedModel(PreTrainedModel):
         elif isinstance(module, FastSpeech2ConformerAttention):
             init.xavier_uniform_(module.pos_bias_u)
             init.xavier_uniform_(module.pos_bias_v)
+        elif isinstance(module, FastSpeech2ConformerRelPositionalEncoding):
+            init.copy_(module.pos_enc, module.extend_pos_enc(torch.tensor(0.0).expand(1, module.max_len)))
 
     def _set_gradient_checkpointing(self, module, value=False):
         if isinstance(module, FastSpeech2ConformerEncoder):
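
`nn.BatchNorm1d` carries non-trainable running statistics in addition to its affine `weight`/`bias`, so a full re-initialization has to reset those too, which is what the new `running_mean` branch above does. A standalone sketch of the state involved (plain in-place ops instead of the `init` helpers):

```python
import torch
from torch import nn

bn = nn.BatchNorm1d(4)
bn(torch.randn(8, 4))  # one training-mode forward pass updates the running stats

with torch.no_grad():
    bn.running_mean.zero_()          # fresh default: 0
    bn.running_var.fill_(1.0)        # fresh default: 1
    bn.num_batches_tracked.zero_()   # fresh default: 0
```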
@@ -1118,6 +1128,7 @@ class FastSpeech2ConformerModel(FastSpeech2ConformerPreTrainedModel):
         return_dict: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, FastSpeech2ConformerModelOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -1406,6 +1417,12 @@ class FastSpeech2ConformerHifiGan(PreTrainedModel):
         # Initialize weights and apply final processing
         self.post_init()
 
+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, FastSpeech2ConformerHifiGan):
+            init.zeros_(module.mean)
+            init.ones_(module.scale)
+
     def apply_weight_norm(self):
         weight_norm = nn.utils.weight_norm
         if hasattr(nn.utils.parametrizations, "weight_norm"):
@@ -1433,7 +1450,7 @@ class FastSpeech2ConformerHifiGan(PreTrainedModel):
         waveform.
         """
     )
-    def forward(self, spectrogram: torch.FloatTensor) -> torch.FloatTensor:
+    def forward(self, spectrogram: torch.FloatTensor, **kwargs) -> torch.FloatTensor:
         r"""
         spectrogram (`torch.FloatTensor`):
             Tensor containing the log-mel spectrograms. Can be batched and of shape `(batch_size, sequence_length,
@@ -1509,6 +1526,7 @@ class FastSpeech2ConformerWithHifiGan(PreTrainedModel):
         return_dict: Optional[bool] = None,
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, FastSpeech2ConformerModelOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
@@ -79,6 +79,7 @@ class FastSpeech2ConformerTokenizer(PreTrainedTokenizer):
             unk_token=unk_token,
             pad_token=pad_token,
             should_strip_spaces=should_strip_spaces,
+            special_tokens_pattern="none",
             **kwargs,
         )
 
@@ -660,9 +660,6 @@ class FlaubertPreTrainedModel(PreTrainedModel):
     config: FlaubertConfig
     base_model_prefix = "transformer"
 
-    def __init__(self, *inputs, **kwargs):
-        super().__init__(*inputs, **kwargs)
-
     @property
     def dummy_inputs(self):
         inputs_list = torch.tensor([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
@@ -690,15 +687,17 @@ class FlaubertPreTrainedModel(PreTrainedModel):
         if isinstance(module, nn.LayerNorm):
             init.zeros_(module.bias)
             init.ones_(module.weight)
-        if isinstance(module, FlaubertModel) and self.config.sinusoidal_embeddings:
-            init.copy_(
-                module.position_embeddings.weight,
-                create_sinusoidal_embeddings(
-                    self.config.max_position_embeddings,
-                    self.config.emb_dim,
-                    out=torch.empty_like(module.position_embeddings.weight),
-                ),
-            )
+        if isinstance(module, FlaubertModel):
+            if self.config.sinusoidal_embeddings:
+                init.copy_(
+                    module.position_embeddings.weight,
+                    create_sinusoidal_embeddings(
+                        self.config.max_position_embeddings,
+                        self.config.emb_dim,
+                        out=torch.empty_like(module.position_embeddings.weight),
+                    ),
+                )
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
 
 
 @auto_docstring
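
For reference, a hedged sketch of the table a `create_sinusoidal_embeddings`-style helper computes (the exact upstream signature and channel layout are assumptions; this follows the standard layout with sine on even channels and cosine on odd ones):

```python
import torch

def sinusoidal_embeddings(n_pos: int, dim: int) -> torch.Tensor:
    position = torch.arange(n_pos, dtype=torch.float).unsqueeze(1)
    div_term = torch.pow(10000.0, torch.arange(0, dim, 2, dtype=torch.float) / dim)
    out = torch.zeros(n_pos, dim)
    out[:, 0::2] = torch.sin(position / div_term)
    out[:, 1::2] = torch.cos(position / div_term)
    return out

table = sinusoidal_embeddings(512, 64)  # (max_position_embeddings, emb_dim)
```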
@@ -760,15 +759,15 @@ class FlaubertModel(FlaubertPreTrainedModel):
             self.ffns.append(TransformerFFN(self.dim, self.hidden_dim, self.dim, config=config))
             self.layer_norm2.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))
 
-        # Initialize weights and apply final processing
-        self.post_init()
-
         self.layerdrop = getattr(config, "layerdrop", 0.0)
         self.pre_norm = getattr(config, "pre_norm", False)
         self.register_buffer(
             "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
         )
 
+        # Initialize weights and apply final processing
+        self.post_init()
+
     # Copied from transformers.models.xlm.modeling_xlm.XLMModel.get_input_embeddings
     def get_input_embeddings(self):
         return self.embeddings
@@ -792,6 +791,7 @@ class FlaubertModel(FlaubertPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutput]:
         r"""
         langs (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1002,6 +1002,7 @@ class FlaubertWithLMHeadModel(FlaubertPreTrainedModel, GenerationMixin):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, MaskedLMOutput]:
         r"""
         langs (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1090,6 +1091,7 @@ class FlaubertForSequenceClassification(FlaubertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, SequenceClassifierOutput]:
         r"""
         langs (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1195,6 +1197,7 @@ class FlaubertForTokenClassification(FlaubertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, TokenClassifierOutput]:
         r"""
         langs (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1286,6 +1289,7 @@ class FlaubertForQuestionAnsweringSimple(FlaubertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, QuestionAnsweringModelOutput]:
         r"""
         langs (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1423,6 +1427,7 @@ class FlaubertForQuestionAnswering(FlaubertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, FlaubertForQuestionAnsweringOutput]:
         r"""
         langs (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1538,6 +1543,7 @@ class FlaubertForMultipleChoice(FlaubertPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, MultipleChoiceModelOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
@@ -306,7 +306,6 @@ class FlavaImageProcessorFast(BaseImageProcessorFast):
             processed_images_grouped[shape] = stacked_images
 
         processed_images = reorder_images(processed_images_grouped, grouped_images_index)
-        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
 
         return processed_images
 
@@ -397,7 +396,6 @@ class FlavaImageProcessorFast(BaseImageProcessorFast):
             mask_group_max_aspect_ratio=mask_group_max_aspect_ratio,
         )
         masks = [mask_generator() for _ in range(len(images))]
-        masks = torch.stack(masks, dim=0) if return_tensors else masks
         data["bool_masked_pos"] = masks
 
         return BatchFeature(data=data, tensor_type=return_tensors)
@@ -677,6 +677,9 @@ class FlavaPreTrainedModel(PreTrainedModel):
             init.zeros_(module.position_embeddings)
             if module.mask_token is not None:
                 init.zeros_(module.mask_token)
+        elif isinstance(module, FlavaTextEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
+            init.zeros_(module.token_type_ids)
         elif isinstance(module, FlavaMultimodalModel):
             if module.use_cls_token:
                 init.zeros_(module.cls_token)
@@ -725,6 +728,7 @@ class FlavaImageModel(FlavaPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithPooling]:
         r"""
         bool_masked_pos (`torch.BoolTensor` of shape `(batch_size, image_num_patches)`):
@@ -804,6 +808,7 @@ class FlavaTextModel(FlavaPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithPooling]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, text_seq_length)`):
@@ -896,6 +901,7 @@ class FlavaMultimodalModel(FlavaPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, BaseModelOutputWithPooling]:
         r"""
         hidden_states (`torch.FloatTensor` of shape `(batch_size, image_num_patches + text_seq_len, hidden_size)`):
@@ -1103,7 +1109,8 @@ class FlavaModel(FlavaPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: bool = True,
         return_dict: Optional[bool] = None,
-    ) -> Union[tuple, FlavaOutput]:
+        **kwargs,
+    ) -> Union[tuple, FlavaModelOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, image_num_patches + text_seq_len)`):
             Indices of input sequence tokens in the vocabulary. Indices can be obtained using [`AutoTokenizer`]. See
@@ -1380,7 +1387,7 @@ class FlavaImageCodebook(FlavaPreTrainedModel):
         z_logits = self.blocks(pixel_values)
         return nn.Softmax(dim=1)(z_logits)
 
-    def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
+    def forward(self, pixel_values: torch.FloatTensor, **kwargs) -> torch.Tensor:
         f"""
         Args:
             pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
@@ -1575,6 +1582,7 @@ class FlavaForPreTraining(FlavaPreTrainedModel):
         output_hidden_states: bool = True,
         return_dict: Optional[bool] = None,
         return_loss: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], FlavaForPreTrainingOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, text_seq_len)`):
@@ -30,15 +30,15 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import use_kernel_forward_from_hub
+from ...integrations import use_experts_implementation, use_kernel_forward_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import MoeCausalLMOutputWithPast, MoeModelOutputWithPast
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import OutputRecorder, check_model_inputs
+from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_grouped_mm_available
+from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
 from .configuration_flex_olmo import FlexOlmoConfig
 
 
@@ -80,7 +80,7 @@ class FlexOlmoRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
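
The `.clone()` here matters because dynamic RoPE variants update `inv_freq` in place and later restore it from `original_inv_freq`; registering the same tensor under both names would make the backup alias the live buffer. A minimal repro of the aliasing pitfall:

```python
import torch
from torch import nn

m = nn.Module()
inv_freq = torch.tensor([1.0, 0.1, 0.01])
m.register_buffer("inv_freq", inv_freq, persistent=False)
m.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)

m.inv_freq.mul_(2.0)        # e.g. a dynamic-RoPE rescaling step
print(m.original_inv_freq)  # tensor([1.0000, 0.1000, 0.0100]) thanks to the clone
```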
@@ -119,7 +119,7 @@ class FlexOlmoRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()
 
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.cat((freqs, freqs), dim=-1)
             cos = emb.cos() * self.attention_scaling
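
`maybe_autocast` takes over from the raw `torch.autocast` context; the underlying pattern is disabling autocast locally so the frequency product stays in float32 even under mixed precision. A sketch of that pattern with plain `torch.autocast` (how `maybe_autocast` behaves on backends without autocast support is not shown in this diff):

```python
import torch

x = torch.randn(4, 8)
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    y = x @ x.T  # runs in bfloat16 under autocast
    with torch.autocast(device_type="cpu", enabled=False):
        z = x.float() @ x.float().T  # forced back to full float32
print(y.dtype, z.dtype)  # torch.bfloat16 torch.float32
```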
@@ -216,6 +216,7 @@ def rotate_half(x):
     return torch.cat((-x2, x1), dim=-1)
 
 
+@use_kernelized_func(apply_rotary_pos_emb)
 class FlexOlmoAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
@@ -241,7 +242,6 @@ class FlexOlmoAttention(nn.Module):
         self.o_proj = nn.Linear(
             config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
         )
-        self.rotary_fn = apply_rotary_pos_emb
         self.q_norm = FlexOlmoRMSNorm(config.num_attention_heads * self.head_dim, config.rms_norm_eps)
         self.k_norm = FlexOlmoRMSNorm(config.num_key_value_heads * self.head_dim, config.rms_norm_eps)
 
@@ -252,7 +252,6 @@ class FlexOlmoAttention(nn.Module):
         attention_mask: Optional[torch.Tensor],
         past_key_values: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
         **kwargs: Unpack[TransformersKwargs],
     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
         input_shape = hidden_states.shape[:-1]
@@ -294,6 +293,7 @@ class FlexOlmoAttention(nn.Module):
         return attn_output, attn_weights
 
 
+@use_experts_implementation
 class FlexOlmoExperts(nn.Module):
     """Collection of expert weights stored as 3D tensors."""
 
@@ -422,7 +422,9 @@ class FlexOlmoPreTrainedModel(PreTrainedModel):
     _supports_flash_attn = True
     _supports_sdpa = True
     _supports_flex_attn = True
-    _can_compile_fullgraph = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
+    _can_compile_fullgraph = (
+        is_grouped_mm_available()
+    )  # https://huggingface.co/docs/transformers/experts_interface#torchcompile
     _supports_attention_backend = True
     _can_record_outputs = {
         "router_logits": OutputRecorder(nn.Linear, layer_name="mlp.gate", index=0),