transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (835):
  1. transformers/__init__.py +49 -3
  2. transformers/activations.py +1 -1
  3. transformers/audio_utils.py +0 -1
  4. transformers/cache_utils.py +17 -15
  5. transformers/cli/serve.py +47 -17
  6. transformers/configuration_utils.py +114 -70
  7. transformers/conversion_mapping.py +83 -7
  8. transformers/convert_slow_tokenizer.py +225 -10
  9. transformers/core_model_loading.py +374 -147
  10. transformers/data/data_collator.py +12 -4
  11. transformers/dependency_versions_table.py +2 -3
  12. transformers/dynamic_module_utils.py +1 -2
  13. transformers/feature_extraction_utils.py +55 -24
  14. transformers/file_utils.py +0 -1
  15. transformers/generation/__init__.py +11 -1
  16. transformers/generation/candidate_generator.py +79 -31
  17. transformers/generation/configuration_utils.py +165 -124
  18. transformers/generation/continuous_batching/__init__.py +4 -0
  19. transformers/generation/continuous_batching/cache.py +47 -18
  20. transformers/generation/continuous_batching/cache_manager.py +131 -34
  21. transformers/generation/continuous_batching/continuous_api.py +228 -136
  22. transformers/generation/continuous_batching/requests.py +28 -1
  23. transformers/generation/continuous_batching/scheduler.py +11 -4
  24. transformers/generation/stopping_criteria.py +1 -1
  25. transformers/generation/utils.py +108 -110
  26. transformers/generation/watermarking.py +8 -5
  27. transformers/image_processing_base.py +3 -14
  28. transformers/image_processing_utils_fast.py +15 -4
  29. transformers/initialization.py +37 -0
  30. transformers/integrations/__init__.py +16 -2
  31. transformers/integrations/accelerate.py +58 -113
  32. transformers/integrations/aqlm.py +36 -66
  33. transformers/integrations/awq.py +46 -515
  34. transformers/integrations/bitnet.py +47 -105
  35. transformers/integrations/bitsandbytes.py +91 -202
  36. transformers/integrations/deepspeed.py +18 -2
  37. transformers/integrations/eetq.py +84 -81
  38. transformers/integrations/fbgemm_fp8.py +191 -145
  39. transformers/integrations/finegrained_fp8.py +241 -208
  40. transformers/integrations/flash_attention.py +2 -2
  41. transformers/integrations/fp_quant.py +92 -0
  42. transformers/integrations/ggml.py +11 -1
  43. transformers/integrations/higgs.py +37 -62
  44. transformers/integrations/hub_kernels.py +65 -8
  45. transformers/integrations/integration_utils.py +45 -0
  46. transformers/integrations/mistral.py +12 -0
  47. transformers/integrations/moe.py +240 -0
  48. transformers/integrations/mxfp4.py +28 -74
  49. transformers/integrations/peft.py +12 -29
  50. transformers/integrations/quanto.py +77 -56
  51. transformers/integrations/quark.py +55 -0
  52. transformers/integrations/spqr.py +42 -90
  53. transformers/integrations/tensor_parallel.py +167 -221
  54. transformers/integrations/torchao.py +32 -38
  55. transformers/integrations/vptq.py +40 -59
  56. transformers/modelcard.py +1 -2
  57. transformers/modeling_gguf_pytorch_utils.py +74 -19
  58. transformers/modeling_rope_utils.py +107 -86
  59. transformers/modeling_utils.py +611 -527
  60. transformers/models/__init__.py +22 -0
  61. transformers/models/afmoe/modeling_afmoe.py +10 -19
  62. transformers/models/afmoe/modular_afmoe.py +5 -13
  63. transformers/models/aimv2/modeling_aimv2.py +4 -0
  64. transformers/models/aimv2/modular_aimv2.py +4 -0
  65. transformers/models/albert/modeling_albert.py +3 -0
  66. transformers/models/albert/tokenization_albert.py +6 -12
  67. transformers/models/align/modeling_align.py +14 -6
  68. transformers/models/altclip/modeling_altclip.py +11 -3
  69. transformers/models/apertus/modeling_apertus.py +8 -6
  70. transformers/models/apertus/modular_apertus.py +4 -1
  71. transformers/models/arcee/modeling_arcee.py +5 -5
  72. transformers/models/aria/modeling_aria.py +12 -8
  73. transformers/models/aria/modular_aria.py +7 -3
  74. transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
  75. transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
  76. transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
  77. transformers/models/auto/auto_factory.py +1 -1
  78. transformers/models/auto/configuration_auto.py +38 -0
  79. transformers/models/auto/feature_extraction_auto.py +9 -3
  80. transformers/models/auto/image_processing_auto.py +5 -2
  81. transformers/models/auto/modeling_auto.py +37 -0
  82. transformers/models/auto/processing_auto.py +22 -10
  83. transformers/models/auto/tokenization_auto.py +147 -566
  84. transformers/models/auto/video_processing_auto.py +5 -2
  85. transformers/models/autoformer/modeling_autoformer.py +4 -0
  86. transformers/models/aya_vision/modeling_aya_vision.py +7 -3
  87. transformers/models/bamba/modeling_bamba.py +21 -21
  88. transformers/models/bamba/modular_bamba.py +17 -16
  89. transformers/models/bark/modeling_bark.py +11 -0
  90. transformers/models/bart/configuration_bart.py +0 -1
  91. transformers/models/bart/modeling_bart.py +14 -0
  92. transformers/models/barthez/tokenization_barthez.py +5 -10
  93. transformers/models/beit/image_processing_beit_fast.py +0 -1
  94. transformers/models/beit/modeling_beit.py +6 -1
  95. transformers/models/bert/modeling_bert.py +3 -0
  96. transformers/models/bert/tokenization_bert.py +8 -21
  97. transformers/models/bert_generation/modeling_bert_generation.py +2 -0
  98. transformers/models/big_bird/modeling_big_bird.py +9 -0
  99. transformers/models/big_bird/tokenization_big_bird.py +18 -42
  100. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +15 -2
  101. transformers/models/biogpt/modeling_biogpt.py +2 -0
  102. transformers/models/biogpt/modular_biogpt.py +2 -0
  103. transformers/models/bit/modeling_bit.py +16 -3
  104. transformers/models/bitnet/modeling_bitnet.py +5 -5
  105. transformers/models/blenderbot/modeling_blenderbot.py +12 -0
  106. transformers/models/blenderbot/tokenization_blenderbot.py +18 -23
  107. transformers/models/blenderbot_small/modeling_blenderbot_small.py +12 -0
  108. transformers/models/blip/modeling_blip.py +2 -0
  109. transformers/models/blip/modeling_blip_text.py +10 -0
  110. transformers/models/blip_2/modeling_blip_2.py +4 -1
  111. transformers/models/bloom/modeling_bloom.py +17 -44
  112. transformers/models/blt/modeling_blt.py +164 -4
  113. transformers/models/blt/modular_blt.py +170 -5
  114. transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
  115. transformers/models/bridgetower/modeling_bridgetower.py +11 -1
  116. transformers/models/bros/modeling_bros.py +12 -0
  117. transformers/models/camembert/modeling_camembert.py +109 -106
  118. transformers/models/camembert/tokenization_camembert.py +8 -12
  119. transformers/models/canine/modeling_canine.py +11 -0
  120. transformers/models/canine/tokenization_canine.py +2 -0
  121. transformers/models/chameleon/modeling_chameleon.py +11 -5
  122. transformers/models/chinese_clip/modeling_chinese_clip.py +9 -3
  123. transformers/models/clap/feature_extraction_clap.py +2 -2
  124. transformers/models/clap/modeling_clap.py +30 -15
  125. transformers/models/clip/modeling_clip.py +2 -0
  126. transformers/models/clip/tokenization_clip.py +22 -44
  127. transformers/models/clipseg/modeling_clipseg.py +9 -0
  128. transformers/models/clvp/modeling_clvp.py +19 -3
  129. transformers/models/clvp/tokenization_clvp.py +1 -63
  130. transformers/models/code_llama/tokenization_code_llama.py +20 -43
  131. transformers/models/codegen/modeling_codegen.py +13 -4
  132. transformers/models/codegen/tokenization_codegen.py +14 -43
  133. transformers/models/cohere/modeling_cohere.py +5 -4
  134. transformers/models/cohere/modular_cohere.py +2 -1
  135. transformers/models/cohere/tokenization_cohere.py +12 -42
  136. transformers/models/cohere2/modeling_cohere2.py +8 -7
  137. transformers/models/cohere2/modular_cohere2.py +5 -5
  138. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -4
  139. transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
  140. transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
  141. transformers/models/colqwen2/modeling_colqwen2.py +1 -0
  142. transformers/models/colqwen2/modular_colqwen2.py +1 -0
  143. transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
  144. transformers/models/conditional_detr/modeling_conditional_detr.py +9 -1
  145. transformers/models/convbert/modeling_convbert.py +9 -0
  146. transformers/models/convnext/image_processing_convnext.py +2 -2
  147. transformers/models/convnext/image_processing_convnext_fast.py +9 -13
  148. transformers/models/convnext/modeling_convnext.py +2 -4
  149. transformers/models/convnextv2/modeling_convnextv2.py +2 -4
  150. transformers/models/csm/generation_csm.py +19 -22
  151. transformers/models/csm/modeling_csm.py +7 -4
  152. transformers/models/csm/modular_csm.py +2 -0
  153. transformers/models/ctrl/modeling_ctrl.py +15 -2
  154. transformers/models/cvt/modeling_cvt.py +7 -1
  155. transformers/models/cwm/modeling_cwm.py +5 -5
  156. transformers/models/d_fine/configuration_d_fine.py +3 -4
  157. transformers/models/d_fine/modeling_d_fine.py +48 -39
  158. transformers/models/d_fine/modular_d_fine.py +16 -4
  159. transformers/models/dab_detr/configuration_dab_detr.py +2 -2
  160. transformers/models/dab_detr/modeling_dab_detr.py +5 -1
  161. transformers/models/dac/modeling_dac.py +6 -6
  162. transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
  163. transformers/models/data2vec/modeling_data2vec_text.py +7 -0
  164. transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
  165. transformers/models/data2vec/modular_data2vec_text.py +7 -0
  166. transformers/models/dbrx/configuration_dbrx.py +9 -1
  167. transformers/models/dbrx/modeling_dbrx.py +3 -3
  168. transformers/models/deberta/modeling_deberta.py +7 -0
  169. transformers/models/deberta/tokenization_deberta.py +11 -20
  170. transformers/models/deberta_v2/modeling_deberta_v2.py +8 -0
  171. transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
  172. transformers/models/decision_transformer/modeling_decision_transformer.py +12 -6
  173. transformers/models/deepseek_v2/modeling_deepseek_v2.py +9 -7
  174. transformers/models/deepseek_v2/modular_deepseek_v2.py +6 -4
  175. transformers/models/deepseek_v3/modeling_deepseek_v3.py +12 -7
  176. transformers/models/deepseek_v3/modular_deepseek_v3.py +7 -2
  177. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
  178. transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
  179. transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
  180. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
  181. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
  182. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
  183. transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
  184. transformers/models/deformable_detr/modeling_deformable_detr.py +5 -1
  185. transformers/models/depth_anything/configuration_depth_anything.py +2 -3
  186. transformers/models/depth_anything/modeling_depth_anything.py +1 -0
  187. transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
  188. transformers/models/depth_pro/modeling_depth_pro.py +2 -0
  189. transformers/models/detr/configuration_detr.py +1 -1
  190. transformers/models/detr/modeling_detr.py +13 -1
  191. transformers/models/dia/generation_dia.py +3 -10
  192. transformers/models/dia/modeling_dia.py +16 -4
  193. transformers/models/dia/modular_dia.py +11 -1
  194. transformers/models/dia/processing_dia.py +1 -1
  195. transformers/models/diffllama/modeling_diffllama.py +5 -5
  196. transformers/models/diffllama/modular_diffllama.py +2 -2
  197. transformers/models/dinat/modeling_dinat.py +3 -0
  198. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
  199. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
  200. transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -2
  201. transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -2
  202. transformers/models/distilbert/modeling_distilbert.py +11 -9
  203. transformers/models/distilbert/tokenization_distilbert.py +13 -0
  204. transformers/models/doge/modeling_doge.py +3 -4
  205. transformers/models/doge/modular_doge.py +0 -1
  206. transformers/models/donut/image_processing_donut_fast.py +0 -1
  207. transformers/models/donut/modeling_donut_swin.py +18 -12
  208. transformers/models/dots1/modeling_dots1.py +23 -11
  209. transformers/models/dots1/modular_dots1.py +5 -3
  210. transformers/models/dpr/modeling_dpr.py +5 -0
  211. transformers/models/dpr/tokenization_dpr.py +12 -0
  212. transformers/models/dpt/configuration_dpt.py +1 -1
  213. transformers/models/dpt/image_processing_dpt_fast.py +1 -2
  214. transformers/models/dpt/modular_dpt.py +1 -2
  215. transformers/models/edgetam/configuration_edgetam.py +1 -1
  216. transformers/models/edgetam/modeling_edgetam.py +6 -3
  217. transformers/models/edgetam/modular_edgetam.py +15 -14
  218. transformers/models/edgetam_video/modeling_edgetam_video.py +56 -43
  219. transformers/models/edgetam_video/modular_edgetam_video.py +14 -19
  220. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
  221. transformers/models/efficientloftr/modeling_efficientloftr.py +16 -3
  222. transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
  223. transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
  224. transformers/models/efficientnet/modeling_efficientnet.py +7 -1
  225. transformers/models/electra/modeling_electra.py +7 -0
  226. transformers/models/emu3/modeling_emu3.py +12 -6
  227. transformers/models/emu3/modular_emu3.py +7 -1
  228. transformers/models/encodec/modeling_encodec.py +14 -0
  229. transformers/models/eomt/image_processing_eomt.py +13 -1
  230. transformers/models/eomt/image_processing_eomt_fast.py +60 -16
  231. transformers/models/eomt/modeling_eomt.py +7 -0
  232. transformers/models/eomt/modular_eomt.py +7 -0
  233. transformers/models/ernie/modeling_ernie.py +6 -0
  234. transformers/models/ernie/modular_ernie.py +6 -0
  235. transformers/models/ernie4_5/modeling_ernie4_5.py +5 -5
  236. transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
  237. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +20 -17
  238. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +11 -37
  239. transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
  240. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
  241. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
  242. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
  243. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
  244. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
  245. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
  246. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
  247. transformers/models/esm/modeling_esm.py +6 -0
  248. transformers/models/esm/modeling_esmfold.py +11 -5
  249. transformers/models/evolla/modeling_evolla.py +13 -5
  250. transformers/models/evolla/modular_evolla.py +8 -0
  251. transformers/models/exaone4/modeling_exaone4.py +3 -3
  252. transformers/models/exaone4/modular_exaone4.py +0 -1
  253. transformers/models/falcon/modeling_falcon.py +9 -4
  254. transformers/models/falcon_h1/modeling_falcon_h1.py +32 -26
  255. transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
  256. transformers/models/falcon_mamba/modeling_falcon_mamba.py +31 -37
  257. transformers/models/falcon_mamba/modular_falcon_mamba.py +19 -33
  258. transformers/models/fast_vlm/__init__.py +27 -0
  259. transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
  260. transformers/models/fast_vlm/modeling_fast_vlm.py +459 -0
  261. transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
  262. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +31 -13
  263. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
  264. transformers/models/flaubert/modeling_flaubert.py +21 -15
  265. transformers/models/flava/image_processing_flava_fast.py +0 -2
  266. transformers/models/flava/modeling_flava.py +10 -2
  267. transformers/models/flex_olmo/modeling_flex_olmo.py +10 -8
  268. transformers/models/florence2/modeling_florence2.py +22 -4
  269. transformers/models/florence2/modular_florence2.py +15 -1
  270. transformers/models/fnet/modeling_fnet.py +14 -0
  271. transformers/models/focalnet/modeling_focalnet.py +4 -0
  272. transformers/models/fsmt/modeling_fsmt.py +2 -0
  273. transformers/models/funnel/modeling_funnel.py +8 -0
  274. transformers/models/funnel/tokenization_funnel.py +17 -24
  275. transformers/models/fuyu/image_processing_fuyu.py +1 -1
  276. transformers/models/fuyu/modeling_fuyu.py +3 -1
  277. transformers/models/fuyu/processing_fuyu.py +19 -3
  278. transformers/models/gemma/modeling_gemma.py +14 -16
  279. transformers/models/gemma/modular_gemma.py +9 -11
  280. transformers/models/gemma/tokenization_gemma.py +10 -27
  281. transformers/models/gemma2/modeling_gemma2.py +5 -5
  282. transformers/models/gemma2/modular_gemma2.py +3 -2
  283. transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
  284. transformers/models/gemma3/modeling_gemma3.py +42 -91
  285. transformers/models/gemma3/modular_gemma3.py +38 -87
  286. transformers/models/gemma3n/configuration_gemma3n.py +3 -0
  287. transformers/models/gemma3n/modeling_gemma3n.py +65 -218
  288. transformers/models/gemma3n/modular_gemma3n.py +68 -68
  289. transformers/models/git/modeling_git.py +183 -126
  290. transformers/models/glm/modeling_glm.py +5 -5
  291. transformers/models/glm4/modeling_glm4.py +5 -5
  292. transformers/models/glm46v/image_processing_glm46v.py +0 -4
  293. transformers/models/glm46v/modeling_glm46v.py +3 -1
  294. transformers/models/glm46v/modular_glm46v.py +3 -0
  295. transformers/models/glm4_moe/modeling_glm4_moe.py +13 -7
  296. transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
  297. transformers/models/glm4v/configuration_glm4v.py +3 -1
  298. transformers/models/glm4v/image_processing_glm4v.py +0 -4
  299. transformers/models/glm4v/modeling_glm4v.py +18 -8
  300. transformers/models/glm4v/modular_glm4v.py +17 -7
  301. transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
  302. transformers/models/glm4v_moe/modeling_glm4v_moe.py +44 -27
  303. transformers/models/glm4v_moe/modular_glm4v_moe.py +13 -1
  304. transformers/models/glmasr/__init__.py +30 -0
  305. transformers/models/glmasr/configuration_glmasr.py +197 -0
  306. transformers/models/glmasr/modeling_glmasr.py +512 -0
  307. transformers/models/glmasr/modular_glmasr.py +433 -0
  308. transformers/models/glmasr/processing_glmasr.py +332 -0
  309. transformers/models/glpn/image_processing_glpn_fast.py +0 -1
  310. transformers/models/glpn/modeling_glpn.py +2 -0
  311. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
  312. transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
  313. transformers/models/gpt2/modeling_gpt2.py +13 -6
  314. transformers/models/gpt2/tokenization_gpt2.py +16 -44
  315. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +4 -8
  316. transformers/models/gpt_neo/modeling_gpt_neo.py +19 -3
  317. transformers/models/gpt_neox/modeling_gpt_neox.py +6 -3
  318. transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
  319. transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
  320. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +4 -2
  321. transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
  322. transformers/models/gpt_oss/modeling_gpt_oss.py +10 -14
  323. transformers/models/gpt_oss/modular_gpt_oss.py +8 -12
  324. transformers/models/gptj/modeling_gptj.py +18 -6
  325. transformers/models/granite/modeling_granite.py +5 -5
  326. transformers/models/granite_speech/modeling_granite_speech.py +15 -1
  327. transformers/models/granitemoe/modeling_granitemoe.py +6 -9
  328. transformers/models/granitemoe/modular_granitemoe.py +1 -4
  329. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
  330. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +36 -28
  331. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
  332. transformers/models/granitemoeshared/modeling_granitemoeshared.py +6 -9
  333. transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
  334. transformers/models/grounding_dino/modeling_grounding_dino.py +8 -4
  335. transformers/models/groupvit/modeling_groupvit.py +9 -1
  336. transformers/models/helium/modeling_helium.py +5 -4
  337. transformers/models/herbert/tokenization_herbert.py +9 -25
  338. transformers/models/hgnet_v2/modeling_hgnet_v2.py +16 -1
  339. transformers/models/hgnet_v2/modular_hgnet_v2.py +16 -1
  340. transformers/models/hiera/modeling_hiera.py +4 -0
  341. transformers/models/hubert/modeling_hubert.py +7 -0
  342. transformers/models/hubert/modular_hubert.py +5 -0
  343. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +5 -5
  344. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
  345. transformers/models/hunyuan_v1_moe/__init__.py +1 -1
  346. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +15 -7
  347. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
  348. transformers/models/ibert/modeling_ibert.py +22 -0
  349. transformers/models/idefics/modeling_idefics.py +15 -21
  350. transformers/models/idefics2/modeling_idefics2.py +7 -1
  351. transformers/models/idefics3/modeling_idefics3.py +5 -1
  352. transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
  353. transformers/models/imagegpt/modeling_imagegpt.py +11 -3
  354. transformers/models/informer/modeling_informer.py +4 -0
  355. transformers/models/informer/modular_informer.py +1 -0
  356. transformers/models/instructblip/modeling_instructblip.py +2 -0
  357. transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
  358. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
  359. transformers/models/internvl/modeling_internvl.py +13 -12
  360. transformers/models/internvl/modular_internvl.py +7 -13
  361. transformers/models/internvl/video_processing_internvl.py +0 -1
  362. transformers/models/jais2/__init__.py +27 -0
  363. transformers/models/jais2/configuration_jais2.py +152 -0
  364. transformers/models/jais2/modeling_jais2.py +486 -0
  365. transformers/models/jais2/modular_jais2.py +196 -0
  366. transformers/models/jamba/modeling_jamba.py +25 -20
  367. transformers/models/jamba/modular_jamba.py +17 -17
  368. transformers/models/janus/image_processing_janus_fast.py +0 -1
  369. transformers/models/janus/modeling_janus.py +16 -7
  370. transformers/models/janus/modular_janus.py +17 -7
  371. transformers/models/jetmoe/modeling_jetmoe.py +4 -4
  372. transformers/models/jetmoe/modular_jetmoe.py +1 -0
  373. transformers/models/kosmos2/modeling_kosmos2.py +15 -2
  374. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
  375. transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
  376. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +12 -4
  377. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
  378. transformers/models/lasr/__init__.py +29 -0
  379. transformers/models/lasr/configuration_lasr.py +248 -0
  380. transformers/models/lasr/feature_extraction_lasr.py +277 -0
  381. transformers/models/lasr/modeling_lasr.py +730 -0
  382. transformers/models/lasr/modular_lasr.py +576 -0
  383. transformers/models/lasr/processing_lasr.py +94 -0
  384. transformers/models/lasr/tokenization_lasr.py +186 -0
  385. transformers/models/layoutlm/modeling_layoutlm.py +10 -3
  386. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
  387. transformers/models/layoutlmv2/modeling_layoutlmv2.py +16 -0
  388. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +11 -53
  389. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
  390. transformers/models/layoutlmv3/modeling_layoutlmv3.py +33 -5
  391. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
  392. transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
  393. transformers/models/led/modeling_led.py +12 -0
  394. transformers/models/levit/modeling_levit.py +21 -0
  395. transformers/models/lfm2/modeling_lfm2.py +5 -6
  396. transformers/models/lfm2/modular_lfm2.py +0 -1
  397. transformers/models/lfm2_moe/modeling_lfm2_moe.py +17 -8
  398. transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
  399. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  400. transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
  401. transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
  402. transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
  403. transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
  404. transformers/models/lightglue/modeling_lightglue.py +3 -1
  405. transformers/models/lightglue/modular_lightglue.py +1 -0
  406. transformers/models/lilt/modeling_lilt.py +23 -15
  407. transformers/models/llama/modeling_llama.py +5 -5
  408. transformers/models/llama/tokenization_llama.py +15 -43
  409. transformers/models/llama4/image_processing_llama4_fast.py +1 -2
  410. transformers/models/llama4/modeling_llama4.py +11 -6
  411. transformers/models/llava/image_processing_llava_fast.py +0 -1
  412. transformers/models/llava/modeling_llava.py +12 -7
  413. transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
  414. transformers/models/llava_next/modeling_llava_next.py +7 -3
  415. transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
  416. transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
  417. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
  418. transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
  419. transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
  420. transformers/models/longcat_flash/modeling_longcat_flash.py +6 -5
  421. transformers/models/longcat_flash/modular_longcat_flash.py +3 -2
  422. transformers/models/longformer/modeling_longformer.py +6 -0
  423. transformers/models/longt5/modeling_longt5.py +4 -4
  424. transformers/models/luke/modeling_luke.py +9 -0
  425. transformers/models/luke/tokenization_luke.py +11 -38
  426. transformers/models/lxmert/modeling_lxmert.py +2 -0
  427. transformers/models/m2m_100/modeling_m2m_100.py +14 -0
  428. transformers/models/mamba/modeling_mamba.py +16 -23
  429. transformers/models/mamba2/modeling_mamba2.py +24 -23
  430. transformers/models/marian/configuration_marian.py +1 -1
  431. transformers/models/marian/modeling_marian.py +8 -0
  432. transformers/models/markuplm/modeling_markuplm.py +9 -8
  433. transformers/models/markuplm/tokenization_markuplm.py +28 -61
  434. transformers/models/mask2former/configuration_mask2former.py +3 -3
  435. transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
  436. transformers/models/mask2former/modeling_mask2former.py +11 -0
  437. transformers/models/maskformer/configuration_maskformer.py +3 -3
  438. transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
  439. transformers/models/maskformer/modeling_maskformer.py +11 -1
  440. transformers/models/maskformer/modeling_maskformer_swin.py +21 -15
  441. transformers/models/mbart/configuration_mbart.py +1 -0
  442. transformers/models/mbart/modeling_mbart.py +14 -0
  443. transformers/models/mbart/tokenization_mbart.py +11 -52
  444. transformers/models/mbart50/tokenization_mbart50.py +7 -10
  445. transformers/models/megatron_bert/modeling_megatron_bert.py +9 -0
  446. transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
  447. transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
  448. transformers/models/mgp_str/modeling_mgp_str.py +2 -0
  449. transformers/models/mimi/modeling_mimi.py +28 -5
  450. transformers/models/minimax/modeling_minimax.py +19 -6
  451. transformers/models/minimax/modular_minimax.py +12 -1
  452. transformers/models/ministral/modeling_ministral.py +5 -5
  453. transformers/models/ministral3/configuration_ministral3.py +1 -1
  454. transformers/models/ministral3/modeling_ministral3.py +5 -4
  455. transformers/models/mistral/modeling_mistral.py +5 -4
  456. transformers/models/mistral3/modeling_mistral3.py +10 -4
  457. transformers/models/mistral3/modular_mistral3.py +3 -1
  458. transformers/models/mixtral/modeling_mixtral.py +15 -7
  459. transformers/models/mixtral/modular_mixtral.py +6 -2
  460. transformers/models/mlcd/modeling_mlcd.py +6 -0
  461. transformers/models/mlcd/modular_mlcd.py +4 -0
  462. transformers/models/mllama/modeling_mllama.py +15 -4
  463. transformers/models/mluke/tokenization_mluke.py +6 -6
  464. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
  465. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +8 -4
  466. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
  467. transformers/models/mobilebert/modeling_mobilebert.py +2 -0
  468. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
  469. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
  470. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
  471. transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
  472. transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
  473. transformers/models/mobilevit/modeling_mobilevit.py +7 -0
  474. transformers/models/mobilevitv2/modeling_mobilevitv2.py +7 -0
  475. transformers/models/modernbert/modeling_modernbert.py +16 -2
  476. transformers/models/modernbert/modular_modernbert.py +14 -1
  477. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +17 -10
  478. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +15 -8
  479. transformers/models/moonshine/modeling_moonshine.py +5 -3
  480. transformers/models/moshi/modeling_moshi.py +26 -53
  481. transformers/models/mpnet/modeling_mpnet.py +7 -0
  482. transformers/models/mpnet/tokenization_mpnet.py +5 -13
  483. transformers/models/mpt/modeling_mpt.py +2 -0
  484. transformers/models/mra/modeling_mra.py +10 -1
  485. transformers/models/mt5/configuration_mt5.py +2 -3
  486. transformers/models/mt5/modeling_mt5.py +7 -10
  487. transformers/models/musicgen/modeling_musicgen.py +7 -9
  488. transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -0
  489. transformers/models/mvp/modeling_mvp.py +14 -0
  490. transformers/models/nanochat/modeling_nanochat.py +5 -5
  491. transformers/models/nemotron/modeling_nemotron.py +7 -5
  492. transformers/models/nllb/tokenization_nllb.py +8 -22
  493. transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
  494. transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
  495. transformers/models/nougat/image_processing_nougat_fast.py +0 -1
  496. transformers/models/nougat/tokenization_nougat.py +15 -68
  497. transformers/models/nystromformer/modeling_nystromformer.py +13 -0
  498. transformers/models/olmo/modeling_olmo.py +5 -5
  499. transformers/models/olmo/modular_olmo.py +2 -2
  500. transformers/models/olmo2/modeling_olmo2.py +5 -6
  501. transformers/models/olmo2/modular_olmo2.py +0 -1
  502. transformers/models/olmo3/modeling_olmo3.py +5 -5
  503. transformers/models/olmoe/modeling_olmoe.py +15 -7
  504. transformers/models/olmoe/modular_olmoe.py +4 -2
  505. transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
  506. transformers/models/omdet_turbo/modeling_omdet_turbo.py +6 -0
  507. transformers/models/oneformer/configuration_oneformer.py +3 -3
  508. transformers/models/oneformer/modeling_oneformer.py +11 -39
  509. transformers/models/openai/modeling_openai.py +15 -0
  510. transformers/models/openai/tokenization_openai.py +10 -46
  511. transformers/models/opt/modeling_opt.py +2 -0
  512. transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
  513. transformers/models/ovis2/modeling_ovis2.py +15 -3
  514. transformers/models/ovis2/modular_ovis2.py +8 -0
  515. transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
  516. transformers/models/owlv2/modeling_owlv2.py +11 -3
  517. transformers/models/owlv2/modular_owlv2.py +0 -2
  518. transformers/models/owlvit/modeling_owlvit.py +11 -3
  519. transformers/models/paddleocr_vl/__init__.py +32 -0
  520. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
  521. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +504 -0
  522. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
  523. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1682 -0
  524. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1359 -0
  525. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
  526. transformers/models/paligemma/modeling_paligemma.py +25 -17
  527. transformers/models/parakeet/configuration_parakeet.py +4 -6
  528. transformers/models/parakeet/modeling_parakeet.py +14 -6
  529. transformers/models/parakeet/modular_parakeet.py +7 -2
  530. transformers/models/parakeet/processing_parakeet.py +1 -0
  531. transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
  532. transformers/models/patchtsmixer/modeling_patchtsmixer.py +10 -0
  533. transformers/models/patchtst/modeling_patchtst.py +25 -6
  534. transformers/models/pe_audio/__init__.py +30 -0
  535. transformers/models/pe_audio/configuration_pe_audio.py +206 -0
  536. transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
  537. transformers/models/pe_audio/modeling_pe_audio.py +820 -0
  538. transformers/models/pe_audio/modular_pe_audio.py +299 -0
  539. transformers/{kernels/falcon_mamba/__init__.py → models/pe_audio/processing_pe_audio.py} +11 -2
  540. transformers/models/pe_audio_video/__init__.py +29 -0
  541. transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
  542. transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
  543. transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
  544. transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
  545. transformers/models/pe_video/__init__.py +30 -0
  546. transformers/models/pe_video/configuration_pe_video.py +211 -0
  547. transformers/models/pe_video/modeling_pe_video.py +636 -0
  548. transformers/models/pe_video/modular_pe_video.py +219 -0
  549. transformers/models/pe_video/processing_pe_video.py +10 -0
  550. transformers/models/pe_video/video_processing_pe_video.py +66 -0
  551. transformers/models/pegasus/configuration_pegasus.py +1 -0
  552. transformers/models/pegasus/modeling_pegasus.py +8 -0
  553. transformers/models/pegasus/tokenization_pegasus.py +17 -44
  554. transformers/models/pegasus_x/modeling_pegasus_x.py +5 -0
  555. transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
  556. transformers/models/perceiver/modeling_perceiver.py +13 -1
  557. transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
  558. transformers/models/perception_lm/modeling_perception_lm.py +7 -3
  559. transformers/models/perception_lm/modular_perception_lm.py +7 -3
  560. transformers/models/persimmon/modeling_persimmon.py +3 -2
  561. transformers/models/phi/modeling_phi.py +5 -6
  562. transformers/models/phi/modular_phi.py +0 -1
  563. transformers/models/phi3/modeling_phi3.py +3 -2
  564. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +9 -6
  565. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +7 -4
  566. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
  567. transformers/models/phimoe/modeling_phimoe.py +15 -7
  568. transformers/models/phimoe/modular_phimoe.py +3 -3
  569. transformers/models/pix2struct/modeling_pix2struct.py +2 -0
  570. transformers/models/pix2struct/processing_pix2struct.py +0 -4
  571. transformers/models/pixio/__init__.py +30 -0
  572. transformers/models/pixio/configuration_pixio.py +151 -0
  573. transformers/models/pixio/modeling_pixio.py +507 -0
  574. transformers/models/pixio/modular_pixio.py +404 -0
  575. transformers/models/pixtral/modeling_pixtral.py +3 -2
  576. transformers/models/pixtral/processing_pixtral.py +3 -1
  577. transformers/models/plbart/configuration_plbart.py +1 -0
  578. transformers/models/plbart/modeling_plbart.py +13 -0
  579. transformers/models/plbart/modular_plbart.py +8 -0
  580. transformers/models/plbart/tokenization_plbart.py +0 -2
  581. transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
  582. transformers/models/poolformer/modeling_poolformer.py +13 -1
  583. transformers/models/pop2piano/configuration_pop2piano.py +0 -1
  584. transformers/models/pop2piano/modeling_pop2piano.py +2 -0
  585. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
  586. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
  587. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
  588. transformers/models/prophetnet/modeling_prophetnet.py +5 -1
  589. transformers/models/pvt/modeling_pvt.py +2 -0
  590. transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
  591. transformers/models/qwen2/modeling_qwen2.py +5 -5
  592. transformers/models/qwen2/tokenization_qwen2.py +14 -18
  593. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
  594. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +116 -79
  595. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +71 -33
  596. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
  597. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +23 -11
  598. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +29 -27
  599. transformers/models/qwen2_audio/modeling_qwen2_audio.py +4 -2
  600. transformers/models/qwen2_moe/modeling_qwen2_moe.py +15 -7
  601. transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
  602. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
  603. transformers/models/qwen2_vl/modeling_qwen2_vl.py +23 -20
  604. transformers/models/qwen3/modeling_qwen3.py +5 -5
  605. transformers/models/qwen3_moe/modeling_qwen3_moe.py +15 -7
  606. transformers/models/qwen3_next/modeling_qwen3_next.py +7 -8
  607. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
  608. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +112 -68
  609. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +62 -20
  610. transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
  611. transformers/models/qwen3_vl/modeling_qwen3_vl.py +57 -42
  612. transformers/models/qwen3_vl/modular_qwen3_vl.py +59 -46
  613. transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
  614. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +132 -148
  615. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +36 -82
  616. transformers/models/rag/configuration_rag.py +0 -8
  617. transformers/models/rag/modeling_rag.py +8 -9
  618. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +18 -3
  619. transformers/models/reformer/modeling_reformer.py +13 -1
  620. transformers/models/reformer/tokenization_reformer.py +11 -28
  621. transformers/models/regnet/modeling_regnet.py +10 -1
  622. transformers/models/rembert/modeling_rembert.py +13 -1
  623. transformers/models/rembert/tokenization_rembert.py +3 -10
  624. transformers/models/resnet/modeling_resnet.py +19 -5
  625. transformers/models/roberta/modeling_roberta.py +3 -0
  626. transformers/models/roberta/modular_roberta.py +3 -0
  627. transformers/models/roberta/tokenization_roberta.py +18 -27
  628. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
  629. transformers/models/roc_bert/modeling_roc_bert.py +3 -0
  630. transformers/models/roformer/modeling_roformer.py +6 -0
  631. transformers/models/roformer/tokenization_roformer.py +77 -412
  632. transformers/models/rt_detr/configuration_rt_detr.py +1 -1
  633. transformers/models/rt_detr/modeling_rt_detr.py +6 -0
  634. transformers/models/rt_detr/modeling_rt_detr_resnet.py +13 -4
  635. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
  636. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +9 -0
  637. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
  638. transformers/models/rwkv/modeling_rwkv.py +2 -1
  639. transformers/models/sam/configuration_sam.py +1 -0
  640. transformers/models/sam/image_processing_sam_fast.py +0 -1
  641. transformers/models/sam/modeling_sam.py +4 -1
  642. transformers/models/sam2/configuration_sam2.py +1 -1
  643. transformers/models/sam2/modeling_sam2.py +7 -3
  644. transformers/models/sam2/modular_sam2.py +7 -3
  645. transformers/models/sam2_video/modeling_sam2_video.py +52 -43
  646. transformers/models/sam2_video/modular_sam2_video.py +32 -18
  647. transformers/models/sam3/configuration_sam3.py +21 -1
  648. transformers/models/sam3/modeling_sam3.py +100 -80
  649. transformers/models/sam3_tracker/modeling_sam3_tracker.py +8 -1
  650. transformers/models/sam3_tracker/modular_sam3_tracker.py +8 -1
  651. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
  652. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +27 -15
  653. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
  654. transformers/models/sam3_video/configuration_sam3_video.py +14 -0
  655. transformers/models/sam3_video/modeling_sam3_video.py +4 -3
  656. transformers/models/sam3_video/processing_sam3_video.py +1 -1
  657. transformers/models/sam_hq/configuration_sam_hq.py +1 -0
  658. transformers/models/sam_hq/modeling_sam_hq.py +26 -23
  659. transformers/models/seamless_m4t/modeling_seamless_m4t.py +32 -12
  660. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
  661. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +11 -1
  662. transformers/models/seed_oss/modeling_seed_oss.py +3 -3
  663. transformers/models/segformer/image_processing_segformer_fast.py +0 -1
  664. transformers/models/segformer/modeling_segformer.py +6 -3
  665. transformers/models/segformer/modular_segformer.py +0 -1
  666. transformers/models/seggpt/modeling_seggpt.py +2 -0
  667. transformers/models/sew/modeling_sew.py +3 -0
  668. transformers/models/sew/modular_sew.py +1 -0
  669. transformers/models/sew_d/modeling_sew_d.py +3 -0
  670. transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
  671. transformers/models/siglip/modeling_siglip.py +24 -2
  672. transformers/models/siglip2/modeling_siglip2.py +67 -41
  673. transformers/models/siglip2/modular_siglip2.py +4 -0
  674. transformers/models/smollm3/modeling_smollm3.py +5 -5
  675. transformers/models/smolvlm/modeling_smolvlm.py +5 -1
  676. transformers/models/smolvlm/processing_smolvlm.py +0 -7
  677. transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
  678. transformers/models/speech_to_text/modeling_speech_to_text.py +14 -0
  679. transformers/models/speecht5/modeling_speecht5.py +41 -1
  680. transformers/models/splinter/modeling_splinter.py +12 -3
  681. transformers/models/splinter/tokenization_splinter.py +9 -28
  682. transformers/models/squeezebert/modeling_squeezebert.py +8 -0
  683. transformers/models/stablelm/modeling_stablelm.py +4 -2
  684. transformers/models/starcoder2/modeling_starcoder2.py +5 -4
  685. transformers/models/superglue/image_processing_superglue_fast.py +1 -2
  686. transformers/models/superglue/modeling_superglue.py +1 -0
  687. transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
  688. transformers/models/superpoint/modeling_superpoint.py +1 -0
  689. transformers/models/swiftformer/modeling_swiftformer.py +6 -0
  690. transformers/models/swin/modeling_swin.py +20 -12
  691. transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
  692. transformers/models/swin2sr/modeling_swin2sr.py +51 -33
  693. transformers/models/swinv2/modeling_swinv2.py +45 -33
  694. transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
  695. transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
  696. transformers/models/t5/configuration_t5.py +7 -1
  697. transformers/models/t5/modeling_t5.py +8 -7
  698. transformers/models/t5/tokenization_t5.py +4 -8
  699. transformers/models/t5gemma/modeling_t5gemma.py +6 -6
  700. transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
  701. transformers/models/t5gemma2/modeling_t5gemma2.py +19 -10
  702. transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
  703. transformers/models/table_transformer/configuration_table_transformer.py +1 -1
  704. transformers/models/table_transformer/modeling_table_transformer.py +5 -1
  705. transformers/models/tapas/modeling_tapas.py +3 -0
  706. transformers/models/textnet/image_processing_textnet_fast.py +0 -1
  707. transformers/models/textnet/modeling_textnet.py +11 -2
  708. transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
  709. transformers/models/timesfm/modeling_timesfm.py +14 -0
  710. transformers/models/timesfm/modular_timesfm.py +14 -0
  711. transformers/models/timesformer/modeling_timesformer.py +2 -0
  712. transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
  713. transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
  714. transformers/models/timm_wrapper/modeling_timm_wrapper.py +20 -14
  715. transformers/models/trocr/modeling_trocr.py +3 -2
  716. transformers/models/tvp/configuration_tvp.py +5 -1
  717. transformers/models/tvp/modeling_tvp.py +6 -4
  718. transformers/models/udop/configuration_udop.py +1 -0
  719. transformers/models/udop/modeling_udop.py +7 -7
  720. transformers/models/udop/tokenization_udop.py +5 -13
  721. transformers/models/umt5/configuration_umt5.py +2 -2
  722. transformers/models/umt5/modeling_umt5.py +7 -6
  723. transformers/models/unispeech/modeling_unispeech.py +4 -0
  724. transformers/models/unispeech/modular_unispeech.py +2 -0
  725. transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
  726. transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
  727. transformers/models/univnet/modeling_univnet.py +1 -0
  728. transformers/models/upernet/modeling_upernet.py +1 -0
  729. transformers/models/vaultgemma/modeling_vaultgemma.py +5 -5
  730. transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
  731. transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
  732. transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
  733. transformers/models/video_llava/modeling_video_llava.py +7 -3
  734. transformers/models/vilt/configuration_vilt.py +2 -2
  735. transformers/models/vilt/modeling_vilt.py +13 -0
  736. transformers/models/vipllava/modeling_vipllava.py +7 -3
  737. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
  738. transformers/models/visual_bert/modeling_visual_bert.py +8 -0
  739. transformers/models/vitdet/modeling_vitdet.py +2 -0
  740. transformers/models/vitmatte/configuration_vitmatte.py +1 -1
  741. transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
  742. transformers/models/vitmatte/modeling_vitmatte.py +5 -0
  743. transformers/models/vitpose/configuration_vitpose.py +1 -1
  744. transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
  745. transformers/models/vits/modeling_vits.py +1 -0
  746. transformers/models/vjepa2/modeling_vjepa2.py +1 -0
  747. transformers/models/voxtral/modeling_voxtral.py +2 -2
  748. transformers/models/voxtral/modular_voxtral.py +2 -2
  749. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
  750. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +21 -10
  751. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +12 -0
  752. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +27 -11
  753. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
  754. transformers/models/wavlm/modeling_wavlm.py +5 -0
  755. transformers/models/whisper/generation_whisper.py +1 -0
  756. transformers/models/whisper/modeling_whisper.py +11 -3
  757. transformers/models/whisper/tokenization_whisper.py +4 -15
  758. transformers/models/x_clip/modeling_x_clip.py +5 -0
  759. transformers/models/xcodec/modeling_xcodec.py +5 -0
  760. transformers/models/xglm/modeling_xglm.py +11 -0
  761. transformers/models/xglm/tokenization_xglm.py +4 -9
  762. transformers/models/xlm/modeling_xlm.py +18 -14
  763. transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
  764. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
  765. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
  766. transformers/models/xlnet/modeling_xlnet.py +3 -1
  767. transformers/models/xlnet/tokenization_xlnet.py +3 -7
  768. transformers/models/xmod/modeling_xmod.py +3 -0
  769. transformers/models/yoso/modeling_yoso.py +10 -1
  770. transformers/models/zamba/modeling_zamba.py +4 -1
  771. transformers/models/zamba2/modeling_zamba2.py +7 -4
  772. transformers/models/zamba2/modular_zamba2.py +1 -1
  773. transformers/models/zoedepth/configuration_zoedepth.py +1 -1
  774. transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
  775. transformers/models/zoedepth/modeling_zoedepth.py +8 -0
  776. transformers/pipelines/__init__.py +11 -9
  777. transformers/pipelines/automatic_speech_recognition.py +20 -12
  778. transformers/pipelines/base.py +2 -10
  779. transformers/pipelines/document_question_answering.py +4 -2
  780. transformers/pipelines/question_answering.py +1 -1
  781. transformers/pipelines/text_generation.py +1 -1
  782. transformers/pipelines/text_to_audio.py +2 -2
  783. transformers/processing_utils.py +133 -50
  784. transformers/quantizers/auto.py +2 -4
  785. transformers/quantizers/base.py +44 -174
  786. transformers/quantizers/quantizer_aqlm.py +2 -23
  787. transformers/quantizers/quantizer_auto_round.py +2 -12
  788. transformers/quantizers/quantizer_awq.py +20 -89
  789. transformers/quantizers/quantizer_bitnet.py +4 -14
  790. transformers/quantizers/quantizer_bnb_4bit.py +18 -155
  791. transformers/quantizers/quantizer_bnb_8bit.py +24 -110
  792. transformers/quantizers/quantizer_compressed_tensors.py +2 -9
  793. transformers/quantizers/quantizer_eetq.py +16 -74
  794. transformers/quantizers/quantizer_fbgemm_fp8.py +38 -138
  795. transformers/quantizers/quantizer_finegrained_fp8.py +26 -113
  796. transformers/quantizers/quantizer_fp_quant.py +52 -82
  797. transformers/quantizers/quantizer_gptq.py +8 -28
  798. transformers/quantizers/quantizer_higgs.py +42 -60
  799. transformers/quantizers/quantizer_hqq.py +144 -153
  800. transformers/quantizers/quantizer_mxfp4.py +14 -194
  801. transformers/quantizers/quantizer_quanto.py +35 -79
  802. transformers/quantizers/quantizer_quark.py +36 -17
  803. transformers/quantizers/quantizer_spqr.py +4 -12
  804. transformers/quantizers/quantizer_torchao.py +50 -325
  805. transformers/quantizers/quantizer_vptq.py +4 -27
  806. transformers/quantizers/quantizers_utils.py +20 -0
  807. transformers/testing_utils.py +324 -47
  808. transformers/tokenization_mistral_common.py +7 -2
  809. transformers/tokenization_utils_base.py +116 -224
  810. transformers/tokenization_utils_tokenizers.py +190 -106
  811. transformers/trainer.py +51 -32
  812. transformers/trainer_callback.py +8 -0
  813. transformers/trainer_jit_checkpoint.py +126 -0
  814. transformers/trainer_seq2seq.py +4 -0
  815. transformers/trainer_utils.py +1 -1
  816. transformers/training_args.py +74 -38
  817. transformers/utils/__init__.py +7 -4
  818. transformers/utils/attention_visualizer.py +4 -4
  819. transformers/utils/auto_docstring.py +35 -25
  820. transformers/utils/generic.py +47 -1
  821. transformers/utils/hub.py +5 -15
  822. transformers/utils/import_utils.py +112 -25
  823. transformers/utils/kernel_config.py +74 -19
  824. transformers/utils/loading_report.py +19 -10
  825. transformers/utils/quantization_config.py +78 -245
  826. transformers/video_processing_utils.py +17 -14
  827. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +275 -229
  828. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +832 -777
  829. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +1 -1
  830. transformers/kernels/__init__.py +0 -0
  831. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
  832. transformers/models/roformer/tokenization_roformer_fast.py +0 -160
  833. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
  834. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info/licenses}/LICENSE +0 -0
  835. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
@@ -12,12 +12,11 @@
12
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
- from typing import Optional
15
+ from typing import Optional, Union
16
16
 
17
17
  from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers
18
18
  from tokenizers.models import BPE
19
19
 
20
- from ...tokenization_utils_base import generate_merges
21
20
  from ...tokenization_utils_tokenizers import TokenizersBackend
22
21
  from ...utils import logging
23
22
 
@@ -30,7 +29,7 @@ class GemmaTokenizer(TokenizersBackend):
30
29
  """
31
30
  Construct a fast Gemma tokenizer (backed by HuggingFace's tokenizers library).
32
31
 
33
- This tokenizer uses a Unigram model with ByteFallback, no prefix space, and a normalizer that replaces
32
+ This tokenizer uses a BPE model with byte fallback, no prefix space, and a normalizer that replaces
34
33
  spaces with "▁".
35
34
 
36
35
  Args:
@@ -50,48 +49,37 @@ class GemmaTokenizer(TokenizersBackend):
50
49
  Whether or not to add a `bos_token` at the start of sequences.
51
50
  add_eos_token (`bool`, optional, defaults to False):
52
51
  Whether or not to add an `eos_token` at the end of sequences.
53
- vocab (`dict`, optional):
52
+ vocab (`str` or `dict[str, int]`, optional):
54
53
  Custom vocabulary dict. If not provided, a minimal vocabulary is created using the special tokens.
55
54
  """
56
55
 
57
56
  vocab_files_names = VOCAB_FILES_NAMES
58
- slow_tokenizer_class = None
59
57
  padding_side = "left"
60
58
  model_input_names = ["input_ids", "attention_mask"]
59
+ model = BPE
61
60
 
62
61
  def __init__(
63
62
  self,
63
+ vocab: Optional[Union[str, dict[str, int]]] = None,
64
+ merges: Optional[Union[str, list[str]]] = None,
64
65
  unk_token: str = "<unk>",
65
66
  bos_token: str = "<bos>",
66
67
  eos_token: str = "<eos>",
67
68
  pad_token: str = "<pad>",
68
69
  mask_token: str = "<mask>",
69
- add_bos_token: bool = True,
70
- add_eos_token: bool = False,
71
- vocab: Optional[dict] = None,
72
- merges: Optional[list[tuple[str, str]]] = None,
73
70
  **kwargs,
74
71
  ):
75
- self._add_bos_token = add_bos_token
76
- self._add_eos_token = add_eos_token
77
-
78
- special_tokens = {str(pad_token), str(eos_token), str(bos_token), str(unk_token)}
79
-
80
- if vocab is not None:
81
- self._vocab = (
82
- {token: idx for idx, (token, _score) in enumerate(vocab)} if isinstance(vocab, list) else vocab
83
- )
84
- else:
85
- self._vocab = {
72
+ if vocab is None:
73
+ vocab = {
86
74
  str(pad_token): 0,
87
75
  str(eos_token): 1,
88
76
  str(bos_token): 2,
89
77
  str(unk_token): 3,
90
78
  str(mask_token): 4,
91
79
  }
80
+ self._vocab = vocab
81
+ self._merges = merges or []
92
82
 
93
- filtered_vocab = {t: i for t, i in self._vocab.items() if t not in special_tokens}
94
- self._merges = merges if merges is not None else generate_merges(filtered_vocab)
95
83
  self._tokenizer = Tokenizer(
96
84
  BPE(
97
85
  vocab=self._vocab,
@@ -108,17 +96,12 @@ class GemmaTokenizer(TokenizersBackend):
108
96
  )
109
97
  self._tokenizer.normalizer = normalizers.Replace(" ", "▁")
110
98
  self._tokenizer.pre_tokenizer = pre_tokenizers.Split(" ", "merged_with_previous")
111
- tokenizer_object = self._tokenizer
112
-
113
99
  super().__init__(
114
- tokenizer_object=tokenizer_object,
115
100
  unk_token=unk_token,
116
101
  bos_token=bos_token,
117
102
  eos_token=eos_token,
118
103
  pad_token=pad_token,
119
104
  mask_token=mask_token,
120
- add_bos_token=add_bos_token,
121
- add_eos_token=add_eos_token,
122
105
  **kwargs,
123
106
  )
124
107
 
@@ -29,7 +29,7 @@ from ... import initialization as init
29
29
  from ...activations import ACT2FN
30
30
  from ...cache_utils import Cache, DynamicCache
31
31
  from ...generation import GenerationMixin
32
- from ...integrations import use_kernel_func_from_hub
32
+ from ...integrations import use_kernel_func_from_hub, use_kernelized_func
33
33
  from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
34
34
  from ...modeling_flash_attention_utils import FlashAttentionKwargs
35
35
  from ...modeling_layers import (
@@ -42,7 +42,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
42
42
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
43
43
  from ...processing_utils import Unpack
44
44
  from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
45
- from ...utils.generic import check_model_inputs
45
+ from ...utils.generic import check_model_inputs, maybe_autocast
46
46
  from .configuration_gemma2 import Gemma2Config
47
47
 
48
48
 
@@ -99,7 +99,7 @@ class Gemma2RotaryEmbedding(nn.Module):
99
99
  inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
100
100
 
101
101
  self.register_buffer("inv_freq", inv_freq, persistent=False)
102
- self.original_inv_freq = inv_freq
102
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
103
103
 
104
104
  @staticmethod
105
105
  def compute_default_rope_parameters(
@@ -138,7 +138,7 @@ class Gemma2RotaryEmbedding(nn.Module):
138
138
  position_ids_expanded = position_ids[:, None, :].float()
139
139
 
140
140
  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
141
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
141
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
142
142
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
143
143
  emb = torch.cat((freqs, freqs), dim=-1)
144
144
  cos = emb.cos() * self.attention_scaling
@@ -229,6 +229,7 @@ def eager_attention_forward(
229
229
  return attn_output, attn_weights
230
230
 
231
231
 
232
+ @use_kernelized_func(apply_rotary_pos_emb)
232
233
  class Gemma2Attention(nn.Module):
233
234
  """Multi-headed attention from 'Attention Is All You Need' paper"""
234
235
 
@@ -255,7 +256,6 @@ class Gemma2Attention(nn.Module):
255
256
  self.o_proj = nn.Linear(
256
257
  config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
257
258
  )
258
- self.rotary_fn = apply_rotary_pos_emb
259
259
  self.attn_logit_softcapping = self.config.attn_logit_softcapping
260
260
  self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
261
261
 
@@ -34,6 +34,7 @@ from ...modeling_rope_utils import (
34
34
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
35
35
  from ...processing_utils import Unpack
36
36
  from ...utils import TransformersKwargs, logging
37
+ from ...utils.generic import maybe_autocast
37
38
  from ..gemma.modeling_gemma import (
38
39
  GemmaAttention,
39
40
  GemmaForCausalLM,
@@ -243,7 +244,7 @@ class Gemma2RotaryEmbedding(GemmaRotaryEmbedding):
243
244
  inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
244
245
 
245
246
  self.register_buffer("inv_freq", inv_freq, persistent=False)
246
- self.original_inv_freq = inv_freq
247
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
247
248
 
248
249
  @torch.no_grad()
249
250
  @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
@@ -252,7 +253,7 @@ class Gemma2RotaryEmbedding(GemmaRotaryEmbedding):
252
253
  position_ids_expanded = position_ids[:, None, :].float()
253
254
 
254
255
  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
255
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
256
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
256
257
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
257
258
  emb = torch.cat((freqs, freqs), dim=-1)
258
259
  cos = emb.cos() * self.attention_scaling
@@ -231,7 +231,6 @@ class Gemma3ImageProcessorFast(BaseImageProcessorFast):
231
231
  processed_images_grouped[shape] = stacked_images
232
232
 
233
233
  processed_images = reorder_images(processed_images_grouped, grouped_images_index)
234
- processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
235
234
  return BatchFeature(
236
235
  data={"pixel_values": processed_images, "num_crops": num_crops}, tensor_type=return_tensors
237
236
  )
@@ -31,16 +31,15 @@ from ...activations import ACT2FN
31
31
  from ...cache_utils import Cache, DynamicCache
32
32
  from ...configuration_utils import PreTrainedConfig
33
33
  from ...generation import GenerationMixin
34
- from ...integrations import use_kernel_func_from_hub
34
+ from ...integrations import use_kernel_func_from_hub, use_kernelized_func
35
35
  from ...masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask
36
- from ...modeling_flash_attention_utils import FlashAttentionKwargs
37
36
  from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
38
37
  from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
39
38
  from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
40
39
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
41
40
  from ...processing_utils import Unpack
42
41
  from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging
43
- from ...utils.generic import check_model_inputs
42
+ from ...utils.generic import check_model_inputs, maybe_autocast
44
43
  from ..auto import AutoModel
45
44
  from .configuration_gemma3 import Gemma3Config, Gemma3TextConfig
46
45
 
@@ -101,6 +100,7 @@ class Gemma3TextScaledWordEmbedding(nn.Embedding):
101
100
 
102
101
  def __init__(self, num_embeddings: int, embedding_dim: int, padding_idx: int, embed_scale: float = 1.0):
103
102
  super().__init__(num_embeddings, embedding_dim, padding_idx)
103
+ self.scalar_embed_scale = embed_scale
104
104
  self.register_buffer("embed_scale", torch.tensor(embed_scale), persistent=False)
105
105
 
106
106
  def forward(self, input_ids: torch.Tensor):
@@ -166,7 +166,7 @@ class Gemma3RotaryEmbedding(nn.Module):
166
166
  rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type[layer_type]]
167
167
  curr_inv_freq, curr_attention_scaling = rope_init_fn(self.config, device, layer_type=layer_type)
168
168
  self.register_buffer(f"{layer_type}_inv_freq", curr_inv_freq, persistent=False)
169
- setattr(self, f"{layer_type}_original_inv_freq", curr_inv_freq)
169
+ self.register_buffer(f"{layer_type}_original_inv_freq", curr_inv_freq.clone(), persistent=False)
170
170
  setattr(self, f"{layer_type}_attention_scaling", curr_attention_scaling)
171
171
 
172
172
  @staticmethod
@@ -215,7 +215,7 @@ class Gemma3RotaryEmbedding(nn.Module):
215
215
  position_ids_expanded = position_ids[:, None, :].float()
216
216
 
217
217
  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
218
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
218
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
219
219
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
220
220
  emb = torch.cat((freqs, freqs), dim=-1)
221
221
  cos = emb.cos() * attention_scaling
@@ -306,6 +306,7 @@ def eager_attention_forward(
306
306
  return attn_output, attn_weights
307
307
 
308
308
 
309
+ @use_kernelized_func(apply_rotary_pos_emb)
309
310
  class Gemma3Attention(nn.Module):
310
311
  """Multi-headed attention from 'Attention Is All You Need' paper"""
311
312
 
@@ -332,7 +333,6 @@ class Gemma3Attention(nn.Module):
332
333
  self.o_proj = nn.Linear(
333
334
  config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
334
335
  )
335
- self.rotary_fn = apply_rotary_pos_emb
336
336
  self.attn_logit_softcapping = self.config.attn_logit_softcapping
337
337
  self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
338
338
  self.is_sliding = self.layer_type == "sliding_attention"
@@ -347,7 +347,7 @@ class Gemma3Attention(nn.Module):
347
347
  attention_mask: Optional[torch.Tensor] = None,
348
348
  past_key_values: Optional[Cache] = None,
349
349
  cache_position: Optional[torch.LongTensor] = None,
350
- **kwargs: Unpack[FlashAttentionKwargs],
350
+ **kwargs: Unpack[TransformersKwargs],
351
351
  ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
352
352
  input_shape = hidden_states.shape[:-1]
353
353
  hidden_shape = (*input_shape, -1, self.head_dim)
@@ -409,23 +409,19 @@ class Gemma3DecoderLayer(GradientCheckpointingLayer):
409
409
  attention_mask: Optional[torch.Tensor] = None,
410
410
  position_ids: Optional[torch.LongTensor] = None,
411
411
  past_key_values: Optional[Cache] = None,
412
- output_attentions: Optional[bool] = False,
413
- use_cache: Optional[bool] = False,
414
412
  cache_position: Optional[torch.LongTensor] = None,
415
- **kwargs,
413
+ **kwargs: Unpack[TransformersKwargs],
416
414
  ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
417
415
  residual = hidden_states
418
416
 
419
417
  hidden_states = self.input_layernorm(hidden_states)
420
418
 
421
- hidden_states, self_attn_weights = self.self_attn(
419
+ hidden_states, _ = self.self_attn(
422
420
  hidden_states=hidden_states,
423
421
  position_embeddings=position_embeddings,
424
422
  attention_mask=attention_mask,
425
423
  position_ids=position_ids,
426
424
  past_key_values=past_key_values,
427
- output_attentions=output_attentions,
428
- use_cache=use_cache,
429
425
  cache_position=cache_position,
430
426
  **kwargs,
431
427
  )
@@ -438,12 +434,7 @@ class Gemma3DecoderLayer(GradientCheckpointingLayer):
438
434
  hidden_states = self.post_feedforward_layernorm(hidden_states)
439
435
  hidden_states = residual + hidden_states
440
436
 
441
- outputs = (hidden_states,)
442
-
443
- if output_attentions:
444
- outputs += (self_attn_weights,)
445
-
446
- return outputs
437
+ return hidden_states
447
438
 
448
439
 
449
440
  @auto_docstring
@@ -478,6 +469,16 @@ class Gemma3PreTrainedModel(PreTrainedModel):
478
469
  # We initialize with 0s to be 1 centered as the RMSNorm here does (1 + weight)
479
470
  elif "RMSNorm" in module.__class__.__name__:
480
471
  init.zeros_(module.weight)
472
+ elif isinstance(module, Gemma3TextScaledWordEmbedding):
473
+ init.constant_(module.embed_scale, module.scalar_embed_scale)
474
+ elif isinstance(module, Gemma3RotaryEmbedding):
475
+ for layer_type in module.layer_types:
476
+ rope_init_fn = module.compute_default_rope_parameters
477
+ if module.rope_type[layer_type] != "default":
478
+ rope_init_fn = ROPE_INIT_FUNCTIONS[module.rope_type[layer_type]]
479
+ curr_inv_freq, _ = rope_init_fn(module.config, layer_type=layer_type)
480
+ init.copy_(getattr(module, f"{layer_type}_inv_freq"), curr_inv_freq)
481
+ init.copy_(getattr(module, f"{layer_type}_original_inv_freq"), curr_inv_freq)
481
482
 
482
483
 
483
484
  def _bidirectional_window_overlay(sliding_window: int) -> Callable[[int, int, int, int], bool]:
@@ -527,30 +528,16 @@ class Gemma3TextModel(Gemma3PreTrainedModel):
527
528
  past_key_values: Optional[Cache] = None,
528
529
  inputs_embeds: Optional[torch.FloatTensor] = None,
529
530
  use_cache: Optional[bool] = None,
530
- output_attentions: Optional[bool] = None,
531
- output_hidden_states: Optional[bool] = None,
532
531
  cache_position: Optional[torch.LongTensor] = None,
533
532
  **kwargs: Unpack[TransformersKwargs],
534
533
  ) -> BaseModelOutputWithPast:
535
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
536
- output_hidden_states = (
537
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
538
- )
539
- use_cache = use_cache if use_cache is not None else self.config.use_cache
540
-
541
534
  if (input_ids is None) ^ (inputs_embeds is not None):
542
535
  raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
543
536
 
544
- if self.gradient_checkpointing and self.training and use_cache:
545
- logger.warning_once(
546
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
547
- )
548
- use_cache = False
549
-
550
537
  if inputs_embeds is None:
551
538
  inputs_embeds = self.embed_tokens(input_ids)
552
539
 
553
- if use_cache and past_key_values is None and not self.training:
540
+ if use_cache and past_key_values is None:
554
541
  past_key_values = DynamicCache(config=self.config)
555
542
 
556
543
  if cache_position is None:
@@ -591,41 +578,22 @@ class Gemma3TextModel(Gemma3PreTrainedModel):
591
578
  for layer_type in self.config.layer_types:
592
579
  position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)
593
580
 
594
- # decoder layers
595
- all_hidden_states = () if output_hidden_states else None
596
- all_self_attns = () if output_attentions else None
597
-
598
581
  for decoder_layer in self.layers[: self.config.num_hidden_layers]:
599
- if output_hidden_states:
600
- all_hidden_states += (hidden_states,)
601
-
602
- layer_outputs = decoder_layer(
582
+ hidden_states = decoder_layer(
603
583
  hidden_states,
604
584
  attention_mask=causal_mask_mapping[decoder_layer.attention_type],
605
585
  position_embeddings=position_embeddings[decoder_layer.attention_type],
606
586
  position_ids=position_ids,
607
587
  past_key_values=past_key_values,
608
- output_attentions=output_attentions,
609
- use_cache=use_cache,
610
588
  cache_position=cache_position,
611
589
  **kwargs,
612
590
  )
613
591
 
614
- hidden_states = layer_outputs[0]
615
-
616
- if output_attentions:
617
- all_self_attns += (layer_outputs[1],)
618
-
619
592
  hidden_states = self.norm(hidden_states)
620
593
 
621
- if output_hidden_states:
622
- all_hidden_states += (hidden_states,)
623
-
624
594
  return BaseModelOutputWithPast(
625
595
  last_hidden_state=hidden_states,
626
596
  past_key_values=past_key_values,
627
- hidden_states=all_hidden_states,
628
- attentions=all_self_attns,
629
597
  )
630
598
 
631
599
 
@@ -797,6 +765,7 @@ def create_causal_mask_mapping(
797
765
  token_type_ids: Optional[torch.Tensor] = None,
798
766
  pixel_values: Optional[torch.FloatTensor] = None,
799
767
  is_training: bool = False,
768
+ is_first_iteration: Optional[bool] = None,
800
769
  **kwargs,
801
770
  ) -> dict:
802
771
  """
@@ -819,8 +788,12 @@ def create_causal_mask_mapping(
819
788
  # NOTE: this `may_have_image_input` logic is not flawless, it fails when we're using a cache eagerly initialized
820
789
  # (e.g. compiled prefill) AND `pixel_values` are not provided (i.e. the image data is provided through other
821
790
  # means). Determining prefill in that case requires checking data values, which is not compile-compatible.
822
- may_have_image_input = past_key_values is None or not past_key_values.is_initialized or pixel_values is not None
823
- if token_type_ids is not None and may_have_image_input:
791
+ is_first_iteration = (
792
+ is_first_iteration
793
+ if is_first_iteration is not None
794
+ else (past_key_values is None or not past_key_values.is_initialized or pixel_values is not None)
795
+ )
796
+ if token_type_ids is not None and is_first_iteration:
824
797
  # We need to pass an additional mask function to account for token type ids, and it needs to be an `or` (to
825
798
  # undo the causal masking)
826
799
 
@@ -918,10 +891,7 @@ class Gemma3Model(Gemma3PreTrainedModel):
918
891
  inputs_embeds: Optional[torch.FloatTensor] = None,
919
892
  labels: Optional[torch.LongTensor] = None,
920
893
  use_cache: Optional[bool] = None,
921
- output_attentions: Optional[bool] = None,
922
- output_hidden_states: Optional[bool] = None,
923
- return_dict: Optional[bool] = None,
924
- **lm_kwargs,
894
+ **lm_kwargs: Unpack[TransformersKwargs],
925
895
  ) -> Union[tuple, Gemma3ModelOutputWithPast]:
926
896
  r"""
927
897
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -953,12 +923,6 @@ class Gemma3Model(Gemma3PreTrainedModel):
953
923
  if (input_ids is None) ^ (inputs_embeds is not None):
954
924
  raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
955
925
 
956
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
957
- output_hidden_states = (
958
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
959
- )
960
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
961
-
962
926
  # Replace image id with PAD if the image token if OOV, to avoid index-errors
963
927
  if input_ids is not None and self.config.image_token_id >= self.vocab_size:
964
928
  special_image_mask = input_ids == self.config.image_token_id
@@ -1005,8 +969,6 @@ class Gemma3Model(Gemma3PreTrainedModel):
1005
969
  past_key_values=past_key_values,
1006
970
  inputs_embeds=inputs_embeds,
1007
971
  use_cache=use_cache,
1008
- output_attentions=output_attentions,
1009
- output_hidden_states=output_hidden_states,
1010
972
  return_dict=True,
1011
973
  cache_position=cache_position,
1012
974
  **lm_kwargs,
@@ -1014,7 +976,7 @@ class Gemma3Model(Gemma3PreTrainedModel):
1014
976
 
1015
977
  return Gemma3ModelOutputWithPast(
1016
978
  last_hidden_state=outputs.last_hidden_state,
1017
- past_key_values=outputs.past_key_values if use_cache else None,
979
+ past_key_values=outputs.past_key_values,
1018
980
  hidden_states=outputs.hidden_states,
1019
981
  attentions=outputs.attentions,
1020
982
  image_hidden_states=image_features if pixel_values is not None else None,
@@ -1053,6 +1015,7 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
1053
1015
  def get_image_features(self, pixel_values):
1054
1016
  return self.model.get_image_features(pixel_values)
1055
1017
 
1018
+ @can_return_tuple
1056
1019
  @auto_docstring
1057
1020
  def forward(
1058
1021
  self,
@@ -1066,11 +1029,8 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
1066
1029
  inputs_embeds: Optional[torch.FloatTensor] = None,
1067
1030
  labels: Optional[torch.LongTensor] = None,
1068
1031
  use_cache: Optional[bool] = None,
1069
- output_attentions: Optional[bool] = None,
1070
- output_hidden_states: Optional[bool] = None,
1071
- return_dict: Optional[bool] = None,
1072
1032
  logits_to_keep: Union[int, torch.Tensor] = 0,
1073
- **lm_kwargs,
1033
+ **lm_kwargs: Unpack[TransformersKwargs],
1074
1034
  ) -> Union[tuple, Gemma3CausalLMOutputWithPast]:
1075
1035
  r"""
1076
1036
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1116,13 +1076,6 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
1116
1076
  "user\nYou are a helpful assistant.\n\n\n\n\n\nWhere is the cat standing?\nmodel\nBased on the image, the cat is standing in a snowy area, likely outdoors. It appears to"
1117
1077
  ```
1118
1078
  """
1119
-
1120
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1121
- output_hidden_states = (
1122
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1123
- )
1124
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1125
-
1126
1079
  outputs = self.model(
1127
1080
  input_ids=input_ids,
1128
1081
  pixel_values=pixel_values,
@@ -1133,9 +1086,6 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
1133
1086
  inputs_embeds=inputs_embeds,
1134
1087
  use_cache=use_cache,
1135
1088
  labels=labels,
1136
- output_attentions=output_attentions,
1137
- output_hidden_states=output_hidden_states,
1138
- return_dict=return_dict,
1139
1089
  cache_position=cache_position,
1140
1090
  **lm_kwargs,
1141
1091
  )
@@ -1167,10 +1117,6 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
1167
1117
  flat_labels = shift_labels.view(-1).to(shift_logits.device)
1168
1118
  loss = loss_fct(flat_logits, flat_labels)
1169
1119
 
1170
- if not return_dict:
1171
- output = (logits,) + outputs[1:]
1172
- return (loss,) + output if loss is not None else output
1173
-
1174
1120
  return Gemma3CausalLMOutputWithPast(
1175
1121
  loss=loss,
1176
1122
  logits=logits,
@@ -1193,6 +1139,7 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
1193
1139
  use_cache=True,
1194
1140
  logits_to_keep=None,
1195
1141
  labels=None,
1142
+ is_first_iteration=False,
1196
1143
  **kwargs,
1197
1144
  ):
1198
1145
  # Overwritten -- custom `position_ids` and `pixel_values` handling
@@ -1206,12 +1153,15 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
1206
1153
  use_cache=use_cache,
1207
1154
  logits_to_keep=logits_to_keep,
1208
1155
  token_type_ids=token_type_ids,
1156
+ is_first_iteration=is_first_iteration,
1209
1157
  **kwargs,
1210
1158
  )
1211
1159
 
1212
- # If we're in cached decoding stage, pixel values should be None because input ids do not contain special image token anymore
1213
- # Otherwise we need pixel values to be passed to model. NOTE: use_cache=False needs pixel_values always
1214
- if cache_position[0] == 0:
1160
+ # Pixel values are used only in the first iteration if available
1161
+ # In subsquent iterations, they are already merged with text and cached
1162
+ # NOTE: first iteration doesn't have to be prefill, it can be the first
1163
+ # iteration with a question and cached system prompt (continue generate from cache). NOTE: use_cache=False needs pixel_values always
1164
+ if is_first_iteration or not use_cache:
1215
1165
  model_inputs["pixel_values"] = pixel_values
1216
1166
 
1217
1167
  return model_inputs
@@ -1225,6 +1175,7 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
1225
1175
  past_key_values: Optional[Cache],
1226
1176
  position_ids: Optional[torch.Tensor],
1227
1177
  token_type_ids: Optional[torch.Tensor] = None,
1178
+ is_first_iteration: Optional[bool] = False,
1228
1179
  **kwargs,
1229
1180
  ) -> dict:
1230
1181
  # Uses the overwritten `create_masks_for_generate` with `token_type_ids` masking
@@ -1236,7 +1187,7 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
1236
1187
  past_key_values,
1237
1188
  position_ids,
1238
1189
  token_type_ids,
1239
- pixel_values=kwargs.get("pixel_values"),
1190
+ is_first_iteration=is_first_iteration,
1240
1191
  **{k: v for k, v in kwargs.items() if k != "pixel_values"},
1241
1192
  )
1242
1193