transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (539)
  1. transformers/__init__.py +30 -3
  2. transformers/cli/serve.py +47 -17
  3. transformers/conversion_mapping.py +15 -2
  4. transformers/convert_slow_tokenizer.py +225 -10
  5. transformers/core_model_loading.py +196 -135
  6. transformers/data/data_collator.py +12 -4
  7. transformers/dependency_versions_table.py +1 -2
  8. transformers/dynamic_module_utils.py +1 -2
  9. transformers/feature_extraction_utils.py +1 -2
  10. transformers/file_utils.py +0 -1
  11. transformers/generation/__init__.py +11 -1
  12. transformers/generation/configuration_utils.py +3 -2
  13. transformers/generation/continuous_batching/__init__.py +4 -0
  14. transformers/generation/continuous_batching/continuous_api.py +134 -79
  15. transformers/image_processing_base.py +1 -2
  16. transformers/integrations/__init__.py +4 -2
  17. transformers/integrations/accelerate.py +15 -3
  18. transformers/integrations/aqlm.py +38 -66
  19. transformers/integrations/awq.py +48 -514
  20. transformers/integrations/bitnet.py +45 -100
  21. transformers/integrations/bitsandbytes.py +79 -191
  22. transformers/integrations/deepspeed.py +1 -0
  23. transformers/integrations/eetq.py +84 -79
  24. transformers/integrations/fbgemm_fp8.py +191 -145
  25. transformers/integrations/finegrained_fp8.py +236 -193
  26. transformers/integrations/fp_quant.py +92 -0
  27. transformers/integrations/ggml.py +11 -1
  28. transformers/integrations/higgs.py +40 -62
  29. transformers/integrations/hub_kernels.py +42 -3
  30. transformers/integrations/integration_utils.py +10 -0
  31. transformers/integrations/mxfp4.py +25 -65
  32. transformers/integrations/peft.py +7 -29
  33. transformers/integrations/quanto.py +73 -55
  34. transformers/integrations/quark.py +55 -0
  35. transformers/integrations/spqr.py +44 -90
  36. transformers/integrations/torchao.py +32 -38
  37. transformers/integrations/vptq.py +42 -59
  38. transformers/modelcard.py +1 -2
  39. transformers/modeling_gguf_pytorch_utils.py +8 -0
  40. transformers/modeling_rope_utils.py +30 -6
  41. transformers/modeling_utils.py +116 -112
  42. transformers/models/__init__.py +3 -0
  43. transformers/models/afmoe/modeling_afmoe.py +4 -4
  44. transformers/models/albert/tokenization_albert.py +6 -12
  45. transformers/models/align/modeling_align.py +2 -0
  46. transformers/models/altclip/modeling_altclip.py +4 -0
  47. transformers/models/apertus/modeling_apertus.py +4 -4
  48. transformers/models/arcee/modeling_arcee.py +4 -4
  49. transformers/models/aria/modeling_aria.py +4 -4
  50. transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
  51. transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
  52. transformers/models/auto/configuration_auto.py +11 -0
  53. transformers/models/auto/feature_extraction_auto.py +2 -0
  54. transformers/models/auto/image_processing_auto.py +1 -0
  55. transformers/models/auto/modeling_auto.py +6 -0
  56. transformers/models/auto/processing_auto.py +18 -10
  57. transformers/models/auto/tokenization_auto.py +74 -472
  58. transformers/models/autoformer/modeling_autoformer.py +4 -0
  59. transformers/models/bamba/modeling_bamba.py +4 -3
  60. transformers/models/bark/modeling_bark.py +2 -0
  61. transformers/models/bart/modeling_bart.py +7 -0
  62. transformers/models/barthez/tokenization_barthez.py +5 -10
  63. transformers/models/beit/modeling_beit.py +6 -1
  64. transformers/models/bert/tokenization_bert.py +8 -21
  65. transformers/models/big_bird/modeling_big_bird.py +6 -0
  66. transformers/models/big_bird/tokenization_big_bird.py +18 -42
  67. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +8 -2
  68. transformers/models/biogpt/modeling_biogpt.py +2 -0
  69. transformers/models/biogpt/modular_biogpt.py +2 -0
  70. transformers/models/bit/modeling_bit.py +11 -2
  71. transformers/models/bitnet/modeling_bitnet.py +4 -4
  72. transformers/models/blenderbot/modeling_blenderbot.py +5 -0
  73. transformers/models/blenderbot/tokenization_blenderbot.py +12 -16
  74. transformers/models/blenderbot_small/modeling_blenderbot_small.py +5 -0
  75. transformers/models/blip/modeling_blip_text.py +2 -0
  76. transformers/models/blip_2/modeling_blip_2.py +2 -1
  77. transformers/models/bloom/modeling_bloom.py +4 -0
  78. transformers/models/blt/modeling_blt.py +2 -2
  79. transformers/models/blt/modular_blt.py +2 -2
  80. transformers/models/bridgetower/modeling_bridgetower.py +5 -1
  81. transformers/models/bros/modeling_bros.py +4 -0
  82. transformers/models/camembert/tokenization_camembert.py +8 -12
  83. transformers/models/canine/modeling_canine.py +5 -0
  84. transformers/models/chameleon/modeling_chameleon.py +2 -1
  85. transformers/models/chinese_clip/modeling_chinese_clip.py +3 -0
  86. transformers/models/clap/modeling_clap.py +5 -0
  87. transformers/models/clip/tokenization_clip.py +22 -44
  88. transformers/models/clipseg/modeling_clipseg.py +5 -0
  89. transformers/models/clvp/modeling_clvp.py +5 -0
  90. transformers/models/clvp/tokenization_clvp.py +1 -63
  91. transformers/models/code_llama/tokenization_code_llama.py +20 -43
  92. transformers/models/codegen/tokenization_codegen.py +14 -43
  93. transformers/models/cohere/modeling_cohere.py +4 -3
  94. transformers/models/cohere/modular_cohere.py +2 -1
  95. transformers/models/cohere/tokenization_cohere.py +12 -42
  96. transformers/models/cohere2/modeling_cohere2.py +7 -6
  97. transformers/models/cohere2/modular_cohere2.py +5 -5
  98. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -3
  99. transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
  100. transformers/models/colqwen2/modeling_colqwen2.py +1 -0
  101. transformers/models/colqwen2/modular_colqwen2.py +1 -0
  102. transformers/models/conditional_detr/modeling_conditional_detr.py +5 -0
  103. transformers/models/convbert/modeling_convbert.py +6 -0
  104. transformers/models/convnext/modeling_convnext.py +2 -4
  105. transformers/models/convnextv2/modeling_convnextv2.py +2 -4
  106. transformers/models/csm/modeling_csm.py +4 -3
  107. transformers/models/ctrl/modeling_ctrl.py +1 -0
  108. transformers/models/cvt/modeling_cvt.py +2 -0
  109. transformers/models/cwm/modeling_cwm.py +4 -4
  110. transformers/models/d_fine/modeling_d_fine.py +2 -0
  111. transformers/models/d_fine/modular_d_fine.py +1 -0
  112. transformers/models/dab_detr/modeling_dab_detr.py +4 -0
  113. transformers/models/dac/modeling_dac.py +2 -2
  114. transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
  115. transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
  116. transformers/models/dbrx/modeling_dbrx.py +2 -2
  117. transformers/models/deberta/modeling_deberta.py +5 -0
  118. transformers/models/deberta/tokenization_deberta.py +11 -20
  119. transformers/models/deberta_v2/modeling_deberta_v2.py +6 -0
  120. transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
  121. transformers/models/decision_transformer/modeling_decision_transformer.py +4 -1
  122. transformers/models/deepseek_v2/modeling_deepseek_v2.py +2 -3
  123. transformers/models/deepseek_v2/modular_deepseek_v2.py +2 -2
  124. transformers/models/deepseek_v3/modeling_deepseek_v3.py +3 -2
  125. transformers/models/deepseek_v3/modular_deepseek_v3.py +1 -0
  126. transformers/models/deformable_detr/modeling_deformable_detr.py +4 -0
  127. transformers/models/depth_anything/modeling_depth_anything.py +1 -0
  128. transformers/models/depth_pro/modeling_depth_pro.py +2 -0
  129. transformers/models/detr/modeling_detr.py +5 -0
  130. transformers/models/dia/modeling_dia.py +4 -3
  131. transformers/models/dia/modular_dia.py +0 -1
  132. transformers/models/diffllama/modeling_diffllama.py +2 -2
  133. transformers/models/dinat/modeling_dinat.py +3 -0
  134. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
  135. transformers/models/dinov3_vit/modeling_dinov3_vit.py +2 -2
  136. transformers/models/dinov3_vit/modular_dinov3_vit.py +2 -2
  137. transformers/models/distilbert/tokenization_distilbert.py +13 -0
  138. transformers/models/doge/modeling_doge.py +2 -3
  139. transformers/models/doge/modular_doge.py +0 -1
  140. transformers/models/donut/modeling_donut_swin.py +2 -0
  141. transformers/models/dots1/modeling_dots1.py +10 -7
  142. transformers/models/dots1/modular_dots1.py +5 -3
  143. transformers/models/dpr/modeling_dpr.py +5 -0
  144. transformers/models/dpr/tokenization_dpr.py +12 -0
  145. transformers/models/edgetam/modeling_edgetam.py +1 -1
  146. transformers/models/edgetam_video/modeling_edgetam_video.py +1 -0
  147. transformers/models/edgetam_video/modular_edgetam_video.py +1 -0
  148. transformers/models/efficientloftr/modeling_efficientloftr.py +2 -2
  149. transformers/models/efficientnet/modeling_efficientnet.py +2 -0
  150. transformers/models/emu3/modeling_emu3.py +4 -4
  151. transformers/models/eomt/image_processing_eomt.py +13 -1
  152. transformers/models/eomt/image_processing_eomt_fast.py +14 -2
  153. transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
  154. transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
  155. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +5 -5
  156. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +2 -2
  157. transformers/models/esm/modeling_esmfold.py +5 -4
  158. transformers/models/evolla/modeling_evolla.py +4 -4
  159. transformers/models/exaone4/modeling_exaone4.py +2 -2
  160. transformers/models/exaone4/modular_exaone4.py +0 -1
  161. transformers/models/falcon/modeling_falcon.py +6 -1
  162. transformers/models/falcon_h1/modeling_falcon_h1.py +4 -3
  163. transformers/models/falcon_mamba/modeling_falcon_mamba.py +25 -35
  164. transformers/models/falcon_mamba/modular_falcon_mamba.py +12 -31
  165. transformers/{kernels/falcon_mamba → models/fast_vlm}/__init__.py +15 -3
  166. transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
  167. transformers/models/fast_vlm/modeling_fast_vlm.py +455 -0
  168. transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
  169. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +8 -3
  170. transformers/models/flaubert/modeling_flaubert.py +7 -0
  171. transformers/models/flava/modeling_flava.py +6 -1
  172. transformers/models/flex_olmo/modeling_flex_olmo.py +4 -5
  173. transformers/models/florence2/modeling_florence2.py +2 -1
  174. transformers/models/florence2/modular_florence2.py +2 -1
  175. transformers/models/fnet/modeling_fnet.py +7 -0
  176. transformers/models/focalnet/modeling_focalnet.py +4 -0
  177. transformers/models/fsmt/modeling_fsmt.py +2 -0
  178. transformers/models/funnel/modeling_funnel.py +8 -0
  179. transformers/models/funnel/tokenization_funnel.py +17 -24
  180. transformers/models/fuyu/processing_fuyu.py +3 -3
  181. transformers/models/gemma/modeling_gemma.py +4 -4
  182. transformers/models/gemma/tokenization_gemma.py +10 -27
  183. transformers/models/gemma2/modeling_gemma2.py +4 -4
  184. transformers/models/gemma2/modular_gemma2.py +2 -1
  185. transformers/models/gemma3/modeling_gemma3.py +14 -84
  186. transformers/models/gemma3/modular_gemma3.py +12 -81
  187. transformers/models/gemma3n/modeling_gemma3n.py +18 -209
  188. transformers/models/gemma3n/modular_gemma3n.py +17 -59
  189. transformers/models/git/modeling_git.py +2 -0
  190. transformers/models/glm/modeling_glm.py +4 -4
  191. transformers/models/glm4/modeling_glm4.py +4 -4
  192. transformers/models/glm4_moe/modeling_glm4_moe.py +5 -3
  193. transformers/models/glm4v/configuration_glm4v.py +3 -1
  194. transformers/models/glm4v/modeling_glm4v.py +3 -3
  195. transformers/models/glm4v/modular_glm4v.py +6 -4
  196. transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
  197. transformers/models/glm4v_moe/modeling_glm4v_moe.py +6 -5
  198. transformers/models/glm4v_moe/modular_glm4v_moe.py +1 -1
  199. transformers/models/glpn/modeling_glpn.py +2 -0
  200. transformers/models/gpt2/modeling_gpt2.py +5 -1
  201. transformers/models/gpt2/tokenization_gpt2.py +16 -44
  202. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +1 -0
  203. transformers/models/gpt_neo/modeling_gpt_neo.py +4 -0
  204. transformers/models/gpt_neox/modeling_gpt_neox.py +5 -2
  205. transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
  206. transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
  207. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +3 -1
  208. transformers/models/gpt_oss/modeling_gpt_oss.py +5 -6
  209. transformers/models/gpt_oss/modular_gpt_oss.py +3 -5
  210. transformers/models/gptj/modeling_gptj.py +3 -0
  211. transformers/models/granite/modeling_granite.py +4 -4
  212. transformers/models/granitemoe/modeling_granitemoe.py +4 -6
  213. transformers/models/granitemoe/modular_granitemoe.py +0 -2
  214. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +4 -6
  215. transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -6
  216. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -0
  217. transformers/models/groupvit/modeling_groupvit.py +3 -0
  218. transformers/models/helium/modeling_helium.py +4 -3
  219. transformers/models/herbert/tokenization_herbert.py +9 -25
  220. transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -1
  221. transformers/models/hgnet_v2/modular_hgnet_v2.py +6 -1
  222. transformers/models/hiera/modeling_hiera.py +4 -0
  223. transformers/models/hubert/modeling_hubert.py +3 -0
  224. transformers/models/hubert/modular_hubert.py +1 -0
  225. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +4 -4
  226. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +4 -4
  227. transformers/models/ibert/modeling_ibert.py +6 -0
  228. transformers/models/idefics/modeling_idefics.py +5 -21
  229. transformers/models/imagegpt/modeling_imagegpt.py +2 -1
  230. transformers/models/informer/modeling_informer.py +4 -0
  231. transformers/models/informer/modular_informer.py +1 -0
  232. transformers/models/internvl/modeling_internvl.py +2 -4
  233. transformers/models/internvl/modular_internvl.py +2 -4
  234. transformers/models/jamba/modeling_jamba.py +2 -2
  235. transformers/models/janus/modeling_janus.py +1 -0
  236. transformers/models/janus/modular_janus.py +1 -0
  237. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  238. transformers/models/kosmos2/modeling_kosmos2.py +1 -0
  239. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +3 -1
  240. transformers/models/lasr/__init__.py +29 -0
  241. transformers/models/lasr/configuration_lasr.py +244 -0
  242. transformers/models/lasr/feature_extraction_lasr.py +277 -0
  243. transformers/models/lasr/modeling_lasr.py +729 -0
  244. transformers/models/lasr/modular_lasr.py +569 -0
  245. transformers/models/lasr/processing_lasr.py +96 -0
  246. transformers/models/lasr/tokenization_lasr.py +186 -0
  247. transformers/models/layoutlm/modeling_layoutlm.py +5 -0
  248. transformers/models/layoutlmv2/modeling_layoutlmv2.py +4 -0
  249. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +10 -53
  250. transformers/models/layoutlmv3/modeling_layoutlmv3.py +4 -0
  251. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
  252. transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
  253. transformers/models/led/modeling_led.py +6 -0
  254. transformers/models/levit/modeling_levit.py +3 -0
  255. transformers/models/lfm2/modeling_lfm2.py +4 -5
  256. transformers/models/lfm2/modular_lfm2.py +0 -1
  257. transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -5
  258. transformers/models/lightglue/modeling_lightglue.py +3 -1
  259. transformers/models/lightglue/modular_lightglue.py +1 -0
  260. transformers/models/lilt/modeling_lilt.py +4 -0
  261. transformers/models/llama/modeling_llama.py +4 -4
  262. transformers/models/llama/tokenization_llama.py +15 -43
  263. transformers/models/llama4/modeling_llama4.py +3 -2
  264. transformers/models/longcat_flash/modeling_longcat_flash.py +4 -4
  265. transformers/models/longcat_flash/modular_longcat_flash.py +2 -2
  266. transformers/models/longformer/modeling_longformer.py +6 -0
  267. transformers/models/longt5/modeling_longt5.py +4 -0
  268. transformers/models/luke/modeling_luke.py +9 -0
  269. transformers/models/luke/tokenization_luke.py +11 -38
  270. transformers/models/lxmert/modeling_lxmert.py +2 -0
  271. transformers/models/m2m_100/modeling_m2m_100.py +4 -0
  272. transformers/models/mamba/modeling_mamba.py +14 -22
  273. transformers/models/marian/modeling_marian.py +5 -0
  274. transformers/models/markuplm/modeling_markuplm.py +4 -0
  275. transformers/models/markuplm/tokenization_markuplm.py +28 -61
  276. transformers/models/mask2former/modeling_mask2former.py +2 -0
  277. transformers/models/maskformer/modeling_maskformer.py +2 -0
  278. transformers/models/maskformer/modeling_maskformer_swin.py +2 -0
  279. transformers/models/mbart/modeling_mbart.py +7 -0
  280. transformers/models/mbart/tokenization_mbart.py +11 -52
  281. transformers/models/mbart50/tokenization_mbart50.py +7 -10
  282. transformers/models/megatron_bert/modeling_megatron_bert.py +7 -0
  283. transformers/models/mgp_str/modeling_mgp_str.py +2 -0
  284. transformers/models/mimi/modeling_mimi.py +3 -1
  285. transformers/models/minimax/modeling_minimax.py +4 -4
  286. transformers/models/ministral/modeling_ministral.py +4 -4
  287. transformers/models/ministral3/configuration_ministral3.py +1 -1
  288. transformers/models/ministral3/modeling_ministral3.py +4 -3
  289. transformers/models/mistral/modeling_mistral.py +4 -3
  290. transformers/models/mixtral/modeling_mixtral.py +4 -4
  291. transformers/models/mllama/modeling_mllama.py +2 -2
  292. transformers/models/mluke/tokenization_mluke.py +6 -6
  293. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -0
  294. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
  295. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
  296. transformers/models/mobilevit/modeling_mobilevit.py +3 -0
  297. transformers/models/mobilevitv2/modeling_mobilevitv2.py +3 -0
  298. transformers/models/modernbert/modeling_modernbert.py +4 -1
  299. transformers/models/modernbert/modular_modernbert.py +2 -0
  300. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +8 -9
  301. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +6 -7
  302. transformers/models/moonshine/modeling_moonshine.py +4 -2
  303. transformers/models/moshi/modeling_moshi.py +5 -2
  304. transformers/models/mpnet/modeling_mpnet.py +5 -0
  305. transformers/models/mpnet/tokenization_mpnet.py +5 -13
  306. transformers/models/mpt/modeling_mpt.py +2 -0
  307. transformers/models/mra/modeling_mra.py +6 -0
  308. transformers/models/mt5/modeling_mt5.py +7 -0
  309. transformers/models/musicgen/modeling_musicgen.py +2 -0
  310. transformers/models/musicgen_melody/modeling_musicgen_melody.py +3 -0
  311. transformers/models/mvp/modeling_mvp.py +7 -0
  312. transformers/models/nanochat/modeling_nanochat.py +4 -4
  313. transformers/models/nemotron/modeling_nemotron.py +4 -2
  314. transformers/models/nllb/tokenization_nllb.py +8 -22
  315. transformers/models/nougat/tokenization_nougat.py +11 -59
  316. transformers/models/nystromformer/modeling_nystromformer.py +6 -0
  317. transformers/models/olmo/modeling_olmo.py +4 -4
  318. transformers/models/olmo/modular_olmo.py +2 -2
  319. transformers/models/olmo2/modeling_olmo2.py +4 -5
  320. transformers/models/olmo2/modular_olmo2.py +0 -1
  321. transformers/models/olmo3/modeling_olmo3.py +4 -4
  322. transformers/models/olmoe/modeling_olmoe.py +4 -4
  323. transformers/models/omdet_turbo/modeling_omdet_turbo.py +2 -0
  324. transformers/models/oneformer/modeling_oneformer.py +4 -1
  325. transformers/models/openai/modeling_openai.py +3 -0
  326. transformers/models/openai/tokenization_openai.py +10 -46
  327. transformers/models/opt/modeling_opt.py +2 -0
  328. transformers/models/owlv2/modeling_owlv2.py +4 -0
  329. transformers/models/owlvit/modeling_owlvit.py +4 -0
  330. transformers/models/paddleocr_vl/__init__.py +32 -0
  331. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
  332. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +503 -0
  333. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
  334. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1668 -0
  335. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1349 -0
  336. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
  337. transformers/models/parakeet/configuration_parakeet.py +4 -6
  338. transformers/models/parakeet/modeling_parakeet.py +9 -6
  339. transformers/models/parakeet/modular_parakeet.py +2 -2
  340. transformers/models/parakeet/processing_parakeet.py +1 -0
  341. transformers/models/patchtsmixer/modeling_patchtsmixer.py +6 -0
  342. transformers/models/patchtst/modeling_patchtst.py +20 -2
  343. transformers/models/pegasus/modeling_pegasus.py +5 -0
  344. transformers/models/pegasus/tokenization_pegasus.py +17 -44
  345. transformers/models/pegasus_x/modeling_pegasus_x.py +4 -0
  346. transformers/models/perceiver/modeling_perceiver.py +8 -0
  347. transformers/models/persimmon/modeling_persimmon.py +2 -1
  348. transformers/models/phi/modeling_phi.py +4 -5
  349. transformers/models/phi/modular_phi.py +0 -1
  350. transformers/models/phi3/modeling_phi3.py +2 -1
  351. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +5 -5
  352. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +4 -4
  353. transformers/models/phimoe/modeling_phimoe.py +4 -4
  354. transformers/models/phimoe/modular_phimoe.py +2 -2
  355. transformers/models/pix2struct/modeling_pix2struct.py +2 -0
  356. transformers/models/pixtral/modeling_pixtral.py +2 -1
  357. transformers/models/plbart/modeling_plbart.py +6 -0
  358. transformers/models/plbart/modular_plbart.py +2 -0
  359. transformers/models/plbart/tokenization_plbart.py +0 -2
  360. transformers/models/poolformer/modeling_poolformer.py +2 -0
  361. transformers/models/pop2piano/modeling_pop2piano.py +2 -0
  362. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
  363. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
  364. transformers/models/prophetnet/modeling_prophetnet.py +3 -0
  365. transformers/models/pvt/modeling_pvt.py +2 -0
  366. transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
  367. transformers/models/qwen2/modeling_qwen2.py +4 -4
  368. transformers/models/qwen2/tokenization_qwen2.py +14 -18
  369. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
  370. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +13 -16
  371. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +14 -16
  372. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
  373. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -6
  374. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +3 -5
  375. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -0
  376. transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
  377. transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
  378. transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -16
  379. transformers/models/qwen3/modeling_qwen3.py +4 -4
  380. transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
  381. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -3
  382. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +21 -23
  383. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +14 -16
  384. transformers/models/qwen3_vl/modeling_qwen3_vl.py +39 -37
  385. transformers/models/qwen3_vl/modular_qwen3_vl.py +37 -35
  386. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +39 -37
  387. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +4 -1
  388. transformers/models/rag/modeling_rag.py +1 -0
  389. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +15 -1
  390. transformers/models/reformer/modeling_reformer.py +4 -0
  391. transformers/models/reformer/tokenization_reformer.py +11 -28
  392. transformers/models/regnet/modeling_regnet.py +6 -1
  393. transformers/models/rembert/modeling_rembert.py +6 -0
  394. transformers/models/rembert/tokenization_rembert.py +3 -10
  395. transformers/models/resnet/modeling_resnet.py +11 -2
  396. transformers/models/roberta/tokenization_roberta.py +18 -27
  397. transformers/models/roformer/modeling_roformer.py +6 -0
  398. transformers/models/roformer/tokenization_roformer.py +77 -412
  399. transformers/models/rt_detr/modeling_rt_detr.py +2 -0
  400. transformers/models/rt_detr/modeling_rt_detr_resnet.py +5 -1
  401. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +2 -0
  402. transformers/models/rwkv/modeling_rwkv.py +1 -0
  403. transformers/models/sam2/modeling_sam2.py +2 -2
  404. transformers/models/sam2/modular_sam2.py +2 -2
  405. transformers/models/sam2_video/modeling_sam2_video.py +1 -0
  406. transformers/models/sam2_video/modular_sam2_video.py +1 -0
  407. transformers/models/sam3/modeling_sam3.py +77 -80
  408. transformers/models/sam3_tracker/modeling_sam3_tracker.py +6 -1
  409. transformers/models/sam3_tracker/modular_sam3_tracker.py +6 -1
  410. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +1 -0
  411. transformers/models/sam3_video/modeling_sam3_video.py +1 -0
  412. transformers/models/seamless_m4t/modeling_seamless_m4t.py +5 -1
  413. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
  414. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +5 -1
  415. transformers/models/seed_oss/modeling_seed_oss.py +2 -2
  416. transformers/models/segformer/modeling_segformer.py +4 -1
  417. transformers/models/seggpt/modeling_seggpt.py +2 -0
  418. transformers/models/sew/modeling_sew.py +3 -0
  419. transformers/models/sew/modular_sew.py +1 -0
  420. transformers/models/sew_d/modeling_sew_d.py +3 -0
  421. transformers/models/siglip2/modeling_siglip2.py +4 -0
  422. transformers/models/siglip2/modular_siglip2.py +4 -0
  423. transformers/models/smollm3/modeling_smollm3.py +4 -4
  424. transformers/models/smolvlm/processing_smolvlm.py +0 -7
  425. transformers/models/speech_to_text/modeling_speech_to_text.py +4 -0
  426. transformers/models/speecht5/modeling_speecht5.py +13 -1
  427. transformers/models/splinter/modeling_splinter.py +3 -0
  428. transformers/models/splinter/tokenization_splinter.py +9 -28
  429. transformers/models/squeezebert/modeling_squeezebert.py +6 -0
  430. transformers/models/stablelm/modeling_stablelm.py +3 -1
  431. transformers/models/starcoder2/modeling_starcoder2.py +4 -3
  432. transformers/models/superglue/modeling_superglue.py +1 -0
  433. transformers/models/superpoint/modeling_superpoint.py +1 -0
  434. transformers/models/swiftformer/modeling_swiftformer.py +2 -0
  435. transformers/models/swin/modeling_swin.py +4 -0
  436. transformers/models/swin2sr/modeling_swin2sr.py +2 -0
  437. transformers/models/swinv2/modeling_swinv2.py +4 -0
  438. transformers/models/t5/modeling_t5.py +7 -0
  439. transformers/models/t5/tokenization_t5.py +4 -8
  440. transformers/models/t5gemma/modeling_t5gemma.py +5 -5
  441. transformers/models/t5gemma2/modeling_t5gemma2.py +6 -6
  442. transformers/models/table_transformer/modeling_table_transformer.py +4 -0
  443. transformers/models/tapas/modeling_tapas.py +3 -0
  444. transformers/models/textnet/modeling_textnet.py +11 -2
  445. transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
  446. transformers/models/timesfm/modeling_timesfm.py +2 -0
  447. transformers/models/timesfm/modular_timesfm.py +2 -0
  448. transformers/models/timesformer/modeling_timesformer.py +2 -0
  449. transformers/models/timm_wrapper/modeling_timm_wrapper.py +1 -1
  450. transformers/models/trocr/modeling_trocr.py +2 -0
  451. transformers/models/tvp/modeling_tvp.py +2 -0
  452. transformers/models/udop/modeling_udop.py +4 -0
  453. transformers/models/udop/tokenization_udop.py +5 -13
  454. transformers/models/umt5/modeling_umt5.py +7 -0
  455. transformers/models/unispeech/modeling_unispeech.py +4 -0
  456. transformers/models/unispeech/modular_unispeech.py +2 -0
  457. transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
  458. transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
  459. transformers/models/univnet/modeling_univnet.py +1 -0
  460. transformers/models/upernet/modeling_upernet.py +1 -0
  461. transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
  462. transformers/models/vilt/modeling_vilt.py +6 -0
  463. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
  464. transformers/models/visual_bert/modeling_visual_bert.py +6 -0
  465. transformers/models/vitdet/modeling_vitdet.py +2 -0
  466. transformers/models/vitmatte/modeling_vitmatte.py +1 -0
  467. transformers/models/vits/modeling_vits.py +1 -0
  468. transformers/models/vjepa2/modeling_vjepa2.py +1 -0
  469. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
  470. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +5 -0
  471. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +5 -0
  472. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +6 -0
  473. transformers/models/wavlm/modeling_wavlm.py +5 -0
  474. transformers/models/whisper/modeling_whisper.py +6 -0
  475. transformers/models/whisper/tokenization_whisper.py +4 -15
  476. transformers/models/x_clip/modeling_x_clip.py +3 -0
  477. transformers/models/xglm/modeling_xglm.py +1 -0
  478. transformers/models/xglm/tokenization_xglm.py +4 -9
  479. transformers/models/xlm/modeling_xlm.py +5 -0
  480. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
  481. transformers/models/xlnet/tokenization_xlnet.py +3 -7
  482. transformers/models/yoso/modeling_yoso.py +6 -0
  483. transformers/models/zamba/modeling_zamba.py +2 -0
  484. transformers/models/zamba2/modeling_zamba2.py +4 -2
  485. transformers/models/zamba2/modular_zamba2.py +1 -1
  486. transformers/models/zoedepth/modeling_zoedepth.py +1 -0
  487. transformers/pipelines/__init__.py +2 -3
  488. transformers/pipelines/base.py +1 -9
  489. transformers/pipelines/document_question_answering.py +3 -1
  490. transformers/pipelines/text_generation.py +1 -1
  491. transformers/processing_utils.py +23 -11
  492. transformers/quantizers/base.py +35 -110
  493. transformers/quantizers/quantizer_aqlm.py +1 -5
  494. transformers/quantizers/quantizer_auto_round.py +1 -2
  495. transformers/quantizers/quantizer_awq.py +17 -81
  496. transformers/quantizers/quantizer_bitnet.py +3 -8
  497. transformers/quantizers/quantizer_bnb_4bit.py +13 -110
  498. transformers/quantizers/quantizer_bnb_8bit.py +16 -92
  499. transformers/quantizers/quantizer_compressed_tensors.py +1 -5
  500. transformers/quantizers/quantizer_eetq.py +14 -62
  501. transformers/quantizers/quantizer_fbgemm_fp8.py +34 -125
  502. transformers/quantizers/quantizer_finegrained_fp8.py +13 -105
  503. transformers/quantizers/quantizer_fp_quant.py +48 -78
  504. transformers/quantizers/quantizer_gptq.py +7 -24
  505. transformers/quantizers/quantizer_higgs.py +40 -54
  506. transformers/quantizers/quantizer_hqq.py +144 -153
  507. transformers/quantizers/quantizer_mxfp4.py +13 -167
  508. transformers/quantizers/quantizer_quanto.py +20 -64
  509. transformers/quantizers/quantizer_quark.py +36 -17
  510. transformers/quantizers/quantizer_spqr.py +1 -4
  511. transformers/quantizers/quantizer_torchao.py +23 -202
  512. transformers/quantizers/quantizer_vptq.py +8 -22
  513. transformers/quantizers/quantizers_utils.py +20 -0
  514. transformers/testing_utils.py +297 -36
  515. transformers/tokenization_mistral_common.py +4 -0
  516. transformers/tokenization_utils_base.py +113 -222
  517. transformers/tokenization_utils_tokenizers.py +168 -107
  518. transformers/trainer.py +28 -31
  519. transformers/trainer_jit_checkpoint.py +126 -0
  520. transformers/trainer_utils.py +1 -1
  521. transformers/training_args.py +66 -28
  522. transformers/utils/__init__.py +3 -4
  523. transformers/utils/auto_docstring.py +1 -0
  524. transformers/utils/generic.py +27 -1
  525. transformers/utils/hub.py +5 -15
  526. transformers/utils/import_utils.py +61 -16
  527. transformers/utils/kernel_config.py +4 -2
  528. transformers/utils/loading_report.py +19 -10
  529. transformers/utils/quantization_config.py +75 -242
  530. transformers/video_processing_utils.py +1 -2
  531. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/METADATA +274 -227
  532. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/RECORD +536 -520
  533. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/WHEEL +1 -1
  534. transformers/kernels/__init__.py +0 -0
  535. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
  536. transformers/models/roformer/tokenization_roformer_fast.py +0 -160
  537. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/entry_points.txt +0 -0
  538. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info/licenses}/LICENSE +0 -0
  539. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/top_level.txt +0 -0
@@ -23,7 +23,6 @@ from ... import initialization as init
  from ...cache_utils import Cache, DynamicCache
  from ...configuration_utils import PreTrainedConfig, layer_type_validation
  from ...masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask
- from ...modeling_flash_attention_utils import FlashAttentionKwargs
  from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
  from ...modeling_outputs import BaseModelOutputWithPast, SequenceClassifierOutputWithPast
  from ...modeling_rope_utils import (
@@ -34,6 +33,7 @@ from ...modeling_rope_utils import (
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  from ...processing_utils import Unpack
  from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, logging
+ from ...utils.generic import maybe_autocast
  from ..gemma2.configuration_gemma2 import Gemma2Config
  from ..gemma2.modeling_gemma2 import (
  Gemma2Attention,
@@ -438,7 +438,7 @@ class Gemma3RotaryEmbedding(Gemma2RotaryEmbedding):
  position_ids_expanded = position_ids[:, None, :].float()

  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  emb = torch.cat((freqs, freqs), dim=-1)
  cos = emb.cos() * attention_scaling
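Both Gemma3 rotary embeddings in this diff swap `torch.autocast(..., enabled=False)` for a `maybe_autocast` helper imported from `transformers.utils.generic`. The helper's body is not part of this diff, so the following is only a hedged sketch of the likely idea (guard autocast behind device support), not the library's implementation:

```python
from contextlib import nullcontext

import torch


def maybe_autocast_sketch(device_type: str, enabled: bool = True):
    """Hypothetical sketch, not the transformers implementation: defer to
    torch.autocast when the device type is supported, otherwise no-op."""
    try:
        return torch.autocast(device_type=device_type, enabled=enabled)
    except RuntimeError:  # unsupported device type
        return nullcontext()


# Usage mirroring the hunk above (force float32 math inside the block):
# with maybe_autocast_sketch(device_type="cuda", enabled=False):
#     freqs = inv_freq.float() @ position_ids.float()
```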
@@ -465,7 +465,7 @@ class Gemma3Attention(Gemma2Attention):
  attention_mask: Optional[torch.Tensor] = None,
  past_key_values: Optional[Cache] = None,
  cache_position: Optional[torch.LongTensor] = None,
- **kwargs: Unpack[FlashAttentionKwargs],
+ **kwargs: Unpack[TransformersKwargs],
  ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
  input_shape = hidden_states.shape[:-1]
  hidden_shape = (*input_shape, -1, self.head_dim)
@@ -527,23 +527,19 @@ class Gemma3DecoderLayer(GradientCheckpointingLayer):
  attention_mask: Optional[torch.Tensor] = None,
  position_ids: Optional[torch.LongTensor] = None,
  past_key_values: Optional[Cache] = None,
- output_attentions: Optional[bool] = False,
- use_cache: Optional[bool] = False,
  cache_position: Optional[torch.LongTensor] = None,
- **kwargs,
+ **kwargs: Unpack[TransformersKwargs],
  ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
  residual = hidden_states

  hidden_states = self.input_layernorm(hidden_states)

- hidden_states, self_attn_weights = self.self_attn(
+ hidden_states, _ = self.self_attn(
  hidden_states=hidden_states,
  position_embeddings=position_embeddings,
  attention_mask=attention_mask,
  position_ids=position_ids,
  past_key_values=past_key_values,
- output_attentions=output_attentions,
- use_cache=use_cache,
  cache_position=cache_position,
  **kwargs,
  )
@@ -556,12 +552,7 @@ class Gemma3DecoderLayer(GradientCheckpointingLayer):
  hidden_states = self.post_feedforward_layernorm(hidden_states)
  hidden_states = residual + hidden_states

- outputs = (hidden_states,)
-
- if output_attentions:
- outputs += (self_attn_weights,)
-
- return outputs
+ return hidden_states


  GEMMA3_START_DOCSTRING = None
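The two hunks above change `Gemma3DecoderLayer.forward` to return the hidden-states tensor directly instead of a tuple, and drop the `output_attentions` / `use_cache` arguments. The toy classes below (stand-ins, not the Gemma3 modules) illustrate the call-site difference between the two contracts:

```python
import torch
from torch import nn


class ToyLayerRC0(nn.Module):
    """Toy stand-in for the rc0 contract: forward returns a tuple."""

    def forward(self, hidden_states, output_attentions=False):
        outputs = (hidden_states,)
        if output_attentions:
            outputs += (torch.zeros(1),)  # placeholder for attention weights
        return outputs


class ToyLayerRC1(nn.Module):
    """Toy stand-in for the rc1 contract: forward returns the tensor directly."""

    def forward(self, hidden_states):
        return hidden_states


x = torch.randn(2, 4, 8)
hidden = ToyLayerRC0()(x, output_attentions=True)[0]  # rc0: unpack the tuple
hidden = ToyLayerRC1()(x)                             # rc1: use the return value as-is
```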
@@ -620,30 +611,16 @@ class Gemma3TextModel(Gemma2Model):
  past_key_values: Optional[Cache] = None,
  inputs_embeds: Optional[torch.FloatTensor] = None,
  use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
  cache_position: Optional[torch.LongTensor] = None,
  **kwargs: Unpack[TransformersKwargs],
  ) -> BaseModelOutputWithPast:
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- use_cache = use_cache if use_cache is not None else self.config.use_cache
-
  if (input_ids is None) ^ (inputs_embeds is not None):
  raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

- if self.gradient_checkpointing and self.training and use_cache:
- logger.warning_once(
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
- )
- use_cache = False
-
  if inputs_embeds is None:
  inputs_embeds = self.embed_tokens(input_ids)

- if use_cache and past_key_values is None and not self.training:
+ if use_cache and past_key_values is None:
  past_key_values = DynamicCache(config=self.config)

  if cache_position is None:
@@ -684,41 +661,22 @@ class Gemma3TextModel(Gemma2Model):
  for layer_type in self.config.layer_types:
  position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)

- # decoder layers
- all_hidden_states = () if output_hidden_states else None
- all_self_attns = () if output_attentions else None
-
  for decoder_layer in self.layers[: self.config.num_hidden_layers]:
- if output_hidden_states:
- all_hidden_states += (hidden_states,)
-
- layer_outputs = decoder_layer(
+ hidden_states = decoder_layer(
  hidden_states,
  attention_mask=causal_mask_mapping[decoder_layer.attention_type],
  position_embeddings=position_embeddings[decoder_layer.attention_type],
  position_ids=position_ids,
  past_key_values=past_key_values,
- output_attentions=output_attentions,
- use_cache=use_cache,
  cache_position=cache_position,
  **kwargs,
  )

- hidden_states = layer_outputs[0]
-
- if output_attentions:
- all_self_attns += (layer_outputs[1],)
-
  hidden_states = self.norm(hidden_states)

- if output_hidden_states:
- all_hidden_states += (hidden_states,)
-
  return BaseModelOutputWithPast(
  last_hidden_state=hidden_states,
  past_key_values=past_key_values,
- hidden_states=all_hidden_states,
- attentions=all_self_attns,
  )

@@ -853,20 +811,11 @@ class Gemma3Model(PaliGemmaModel):
  inputs_embeds: Optional[torch.FloatTensor] = None,
  labels: Optional[torch.LongTensor] = None,
  use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
- **lm_kwargs,
+ **lm_kwargs: Unpack[TransformersKwargs],
  ) -> Union[tuple, Gemma3ModelOutputWithPast]:
  if (input_ids is None) ^ (inputs_embeds is not None):
  raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
  # Replace image id with PAD if the image token if OOV, to avoid index-errors
  if input_ids is not None and self.config.image_token_id >= self.vocab_size:
  special_image_mask = input_ids == self.config.image_token_id
@@ -913,8 +862,6 @@ class Gemma3Model(PaliGemmaModel):
  past_key_values=past_key_values,
  inputs_embeds=inputs_embeds,
  use_cache=use_cache,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
  return_dict=True,
  cache_position=cache_position,
  **lm_kwargs,
@@ -922,7 +869,7 @@

  return Gemma3ModelOutputWithPast(
  last_hidden_state=outputs.last_hidden_state,
- past_key_values=outputs.past_key_values if use_cache else None,
+ past_key_values=outputs.past_key_values,
  hidden_states=outputs.hidden_states,
  attentions=outputs.attentions,
  image_hidden_states=image_features if pixel_values is not None else None,
@@ -934,6 +881,7 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration):
  # Fix: https://github.com/huggingface/transformers/issues/40564
  accepts_loss_kwargs = False

+ @can_return_tuple
  @auto_docstring
  def forward(
  self,
@@ -947,11 +895,8 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration):
  inputs_embeds: Optional[torch.FloatTensor] = None,
  labels: Optional[torch.LongTensor] = None,
  use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
- return_dict: Optional[bool] = None,
  logits_to_keep: Union[int, torch.Tensor] = 0,
- **lm_kwargs,
+ **lm_kwargs: Unpack[TransformersKwargs],
  ) -> Union[tuple, Gemma3CausalLMOutputWithPast]:
  r"""
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -997,13 +942,6 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration):
  "user\nYou are a helpful assistant.\n\n\n\n\n\nWhere is the cat standing?\nmodel\nBased on the image, the cat is standing in a snowy area, likely outdoors. It appears to"
  ```
  """
-
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
  outputs = self.model(
  input_ids=input_ids,
  pixel_values=pixel_values,
@@ -1014,9 +952,6 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration):
  inputs_embeds=inputs_embeds,
  use_cache=use_cache,
  labels=labels,
- output_attentions=output_attentions,
- output_hidden_states=output_hidden_states,
- return_dict=return_dict,
  cache_position=cache_position,
  **lm_kwargs,
  )
@@ -1048,10 +983,6 @@ class Gemma3ForConditionalGeneration(PaliGemmaForConditionalGeneration):
  flat_labels = shift_labels.view(-1).to(shift_logits.device)
  loss = loss_fct(flat_logits, flat_labels)

- if not return_dict:
- output = (logits,) + outputs[1:]
- return (loss,) + output if loss is not None else output
-
  return Gemma3CausalLMOutputWithPast(
  loss=loss,
  logits=logits,
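With `@can_return_tuple` added to `Gemma3ForConditionalGeneration.forward`, the manual `return_dict` defaulting and the trailing `if not return_dict:` tuple conversion disappear. A minimal sketch of the decorator's idea, assuming the output object exposes `to_tuple()` as transformers `ModelOutput` does (this is not the library's exact implementation):

```python
from functools import wraps


def can_return_tuple_sketch(forward):
    """Minimal sketch: run forward, then convert the ModelOutput-style result
    to a plain tuple when the caller asks for return_dict=False."""

    @wraps(forward)
    def wrapper(self, *args, return_dict=None, **kwargs):
        output = forward(self, *args, **kwargs)
        use_dict = return_dict if return_dict is not None else getattr(self.config, "use_return_dict", True)
        return output if use_dict else output.to_tuple()

    return wrapper
```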
@@ -32,21 +32,19 @@ from ... import initialization as init
  from ...activations import ACT2FN
  from ...cache_utils import Cache, DynamicCache
  from ...generation import GenerationMixin
+ from ...integrations import use_kernelized_func
  from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
- from ...modeling_flash_attention_utils import FlashAttentionKwargs
  from ...modeling_layers import GradientCheckpointingLayer
  from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
  from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  from ...processing_utils import Unpack
- from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging
+ from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple
+ from ...utils.generic import check_model_inputs, maybe_autocast
  from ..auto import AutoModel
  from .configuration_gemma3n import Gemma3nAudioConfig, Gemma3nConfig, Gemma3nTextConfig, Gemma3nVisionConfig


- logger = logging.get_logger(__name__)
-
-
  @dataclass
  @auto_docstring(
  custom_intro="""
@@ -923,7 +921,7 @@ class Gemma3nAudioEncoder(PreTrainedModel):
  )

  def forward(
- self, audio_mel: torch.Tensor, audio_mel_mask: torch.BoolTensor
+ self, audio_mel: torch.Tensor, audio_mel_mask: torch.BoolTensor, **kwargs
  ) -> tuple[torch.Tensor, torch.BoolTensor]:
  """Encodes a batch of MELs.

@@ -1228,6 +1226,7 @@ def apply_rotary_pos_emb(
  return (x * cos) + (rotate_half(x) * sin)


+ @use_kernelized_func(apply_rotary_pos_emb)
  class Gemma3nTextAttention(nn.Module):
  """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -1254,7 +1253,6 @@ class Gemma3nTextAttention(nn.Module):
  self.o_proj = nn.Linear(
  config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  )
- self.rotary_fn = apply_rotary_pos_emb
  self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
  self.is_sliding = self.layer_type == "sliding_attention"

@@ -1283,7 +1281,7 @@ class Gemma3nTextAttention(nn.Module):
  attention_mask: Optional[torch.Tensor] = None,
  past_key_values: Optional[Cache] = None,
  cache_position: Optional[torch.LongTensor] = None,
- **kwargs: Unpack[FlashAttentionKwargs],
+ **kwargs: Unpack[TransformersKwargs],
  ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
  input_shape = hidden_states.shape[:-1]
  hidden_shape = (*input_shape, -1, self.config.head_dim)
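`@use_kernelized_func(apply_rotary_pos_emb)` (imported from `transformers.integrations` earlier in this diff) replaces the per-instance `self.rotary_fn = apply_rotary_pos_emb` assignment. The decorator's body is not shown here; the sketch below is a hypothetical illustration of the pattern, attaching the rotary function at class-decoration time where an accelerated kernel could be substituted:

```python
def use_kernelized_func_sketch(func):
    """Hypothetical sketch, not the transformers implementation: expose `func`
    on the decorated module class as `rotary_fn`, leaving a single place where
    an accelerated kernel could replace the pure-PyTorch reference function."""

    def decorator(cls):
        cls.rotary_fn = staticmethod(func)
        return cls

    return decorator
```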
@@ -1379,10 +1377,8 @@ class Gemma3nTextDecoderLayer(GradientCheckpointingLayer):
  attention_mask: Optional[torch.Tensor] = None,
  position_ids: Optional[torch.LongTensor] = None,
  past_key_values: Optional[Cache] = None,
- output_attentions: Optional[bool] = False,
- use_cache: Optional[bool] = False,
  cache_position: Optional[torch.LongTensor] = None,
- **kwargs,
+ **kwargs: Unpack[TransformersKwargs],
  ) -> tuple[torch.Tensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
  predictions = self.altup.predict(hidden_states)
  active_prediction = predictions[self.config.altup_active_idx]
@@ -1390,14 +1386,12 @@ class Gemma3nTextDecoderLayer(GradientCheckpointingLayer):
  active_prediction_normed = self.input_layernorm(active_prediction)
  laurel_output = self.laurel(active_prediction_normed)

- attn, self_attn_weights = self.self_attn(
+ attn, _ = self.self_attn(
  hidden_states=active_prediction_normed,
  attention_mask=attention_mask,
  position_ids=position_ids,
  position_embeddings=position_embeddings,
  past_key_values=past_key_values,
- output_attentions=output_attentions,
- use_cache=use_cache,
  cache_position=cache_position,
  **kwargs,
  )
@@ -1426,154 +1420,7 @@ class Gemma3nTextDecoderLayer(GradientCheckpointingLayer):
  first_prediction = self.post_per_layer_input_norm(first_prediction)
  corrected_predictions[1:] += first_prediction

- outputs = (corrected_predictions,)
-
- if output_attentions:
- outputs += (self_attn_weights,)
-
- return outputs
-
-
- class Gemma3nMLP(nn.Module):
- def __init__(self, config):
- super().__init__()
- self.config = config
- self.hidden_size = config.hidden_size
- self.intermediate_size = config.intermediate_size
- self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
- self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
- self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
- self.act_fn = ACT2FN[config.hidden_activation]
-
- def forward(self, x):
- down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
- return down_proj
-
-
- class Gemma3nAttention(nn.Module):
- """Multi-headed attention from 'Attention Is All You Need' paper"""
-
- def __init__(self, config: Gemma3nConfig, layer_idx: int):
- super().__init__()
- self.layer_type = config.layer_types[layer_idx] if hasattr(config, "layer_types") else None
- self.config = config
- self.layer_idx = layer_idx
- self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
- self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
- self.scaling = config.query_pre_attn_scalar**-0.5
- self.attention_dropout = self.config.attention_dropout
- self.is_causal = not getattr(config, "use_bidirectional_attention", False)
-
- self.q_proj = nn.Linear(
- config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
- )
- self.k_proj = nn.Linear(
- config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
- )
- self.v_proj = nn.Linear(
- config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
- )
- self.o_proj = nn.Linear(
- config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
- )
- self.rotary_fn = apply_rotary_pos_emb
- self.attn_logit_softcapping = self.config.attn_logit_softcapping
- self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
- attention_mask: Optional[torch.Tensor] = None,
- past_key_values: Optional[Cache] = None,
- cache_position: Optional[torch.LongTensor] = None,
- **kwargs: Unpack[FlashAttentionKwargs],
- ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
- input_shape = hidden_states.shape[:-1]
- hidden_shape = (*input_shape, -1, self.head_dim)
-
- query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
- key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
- value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
-
- cos, sin = position_embeddings
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
-
- if past_key_values is not None:
- # sin and cos are specific to RoPE models; cache_position needed for the static cache
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
- key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
-
- attention_interface: Callable = eager_attention_forward
- if self.config._attn_implementation != "eager":
- attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
-
- attn_output, attn_weights = attention_interface(
- self,
- query_states,
- key_states,
- value_states,
- attention_mask,
- dropout=self.attention_dropout if self.training else 0.0,
- scaling=self.scaling,
- sliding_window=self.sliding_window,
- softcap=self.attn_logit_softcapping,
- **kwargs,
- )
-
- attn_output = attn_output.reshape(*input_shape, -1).contiguous()
- attn_output = self.o_proj(attn_output)
- return attn_output, attn_weights
-
-
- class Gemma3nDecoderLayer(GradientCheckpointingLayer):
- def __init__(self, config: Gemma3nConfig, layer_idx: int):
- super().__init__()
- self.hidden_size = config.hidden_size
- self.config = config
- self.attention_type = config.layer_types[layer_idx]
- self.self_attn = Gemma3nAttention(config=config, layer_idx=layer_idx)
- self.mlp = Gemma3nMLP(config)
- self.input_layernorm = Gemma3nRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
- self.post_attention_layernorm = Gemma3nRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-
- self.pre_feedforward_layernorm = Gemma3nRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
- self.post_feedforward_layernorm = Gemma3nRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
- attention_mask: Optional[torch.Tensor] = None,
- position_ids: Optional[torch.LongTensor] = None,
- past_key_values: Optional[Cache] = None,
- cache_position: Optional[torch.LongTensor] = None,
- **kwargs,
- ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
- residual = hidden_states
-
- hidden_states = self.input_layernorm(hidden_states)
-
- # Self Attention
- hidden_states, _ = self.self_attn(
- hidden_states=hidden_states,
- position_embeddings=position_embeddings,
- attention_mask=attention_mask,
- position_ids=position_ids,
- past_key_values=past_key_values,
- cache_position=cache_position,
- **kwargs,
- )
- hidden_states = self.post_attention_layernorm(hidden_states)
- hidden_states = residual + hidden_states
-
- residual = hidden_states
- hidden_states = self.pre_feedforward_layernorm(hidden_states)
- hidden_states = self.mlp(hidden_states)
- hidden_states = self.post_feedforward_layernorm(hidden_states)
- hidden_states = residual + hidden_states
-
- return hidden_states
+ return corrected_predictions


  @auto_docstring
@@ -1590,8 +1437,8 @@ class Gemma3nPreTrainedModel(PreTrainedModel):
  _can_compile_fullgraph = True
  _supports_attention_backend = True
  _can_record_outputs = {
- "hidden_states": Gemma3nDecoderLayer,
- "attentions": Gemma3nAttention,
+ "hidden_states": Gemma3nTextDecoderLayer,
+ "attentions": Gemma3nTextAttention,
  }
  input_modalities = ("image", "text", "audio")

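`_can_record_outputs` is repointed at the classes that remain after this release (`Gemma3nTextDecoderLayer`, `Gemma3nTextAttention`). Together with the `@check_model_inputs` decorator applied to `Gemma3nTextModel.forward` below, this appears to replace the deleted `output_attentions` / `output_hidden_states` plumbing by capturing intermediate tensors from the named module classes. A rough hook-based sketch of that idea (an assumption about the mechanism, not the library's code):

```python
from torch import nn


def record_outputs_sketch(model: nn.Module, classes: tuple[type, ...]) -> tuple[list, list]:
    """Register forward hooks on every submodule of the given classes and collect
    their outputs, mimicking what a `_can_record_outputs`-style mapping enables."""
    recorded, handles = [], []
    for module in model.modules():
        if isinstance(module, classes):
            handles.append(module.register_forward_hook(lambda _m, _inp, out: recorded.append(out)))
    return recorded, handles


# Usage: recorded, handles = record_outputs_sketch(model, (nn.Linear,))
# ... run model(inputs), read `recorded`, then call h.remove() for h in handles.
```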
@@ -1678,7 +1525,7 @@ class Gemma3nRotaryEmbedding(nn.Module):
  position_ids_expanded = position_ids[:, None, :].float()

  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
  emb = torch.cat((freqs, freqs), dim=-1)
  cos = emb.cos() * attention_scaling
@@ -1741,7 +1588,7 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel):
  # Initialize weights and apply final processing
  self.post_init()

- @can_return_tuple
+ @check_model_inputs(tie_last_hidden_states=False)
  @auto_docstring
  def forward(
  self,
@@ -1752,8 +1599,6 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel):
  past_key_values: Optional[Cache] = None,
  inputs_embeds: Optional[torch.FloatTensor] = None,
  use_cache: Optional[bool] = None,
- output_attentions: Optional[bool] = None,
- output_hidden_states: Optional[bool] = None,
  cache_position: Optional[torch.LongTensor] = None,
  **kwargs: Unpack[TransformersKwargs],
  ) -> BaseModelOutputWithPast:
@@ -1761,37 +1606,21 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel):
  per_layer_inputs (torch.Tensor, *optional*, defaults to None):
  Pre-computed per-layer embeddings. If None, they are derived from input_ids if provided.
  """
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
- output_hidden_states = (
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
- )
- use_cache = use_cache if use_cache is not None else self.config.use_cache
-
  if (input_ids is None) ^ (inputs_embeds is not None):
  raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

- if self.gradient_checkpointing and self.training and use_cache:
- logger.warning_once(
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
- )
- use_cache = False
-
  if input_ids is not None:
  inputs_embeds = self.embed_tokens(input_ids)
  per_layer_inputs = self.get_per_layer_inputs(input_ids)

  per_layer_inputs = self.project_per_layer_inputs(inputs_embeds, per_layer_inputs)

- if use_cache and past_key_values is None and not self.training:
+ if use_cache and past_key_values is None:
  past_key_values = DynamicCache(config=self.config)

  if cache_position is None:
  past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
- cache_position = torch.arange(
- past_seen_tokens,
- past_seen_tokens + inputs_embeds.shape[1],
- device=inputs_embeds.device,
- )
+ cache_position = torch.arange(inputs_embeds.shape[1], device=inputs_embeds.device) + past_seen_tokens

  if position_ids is None:
  position_ids = cache_position.unsqueeze(0)
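The `cache_position` construction collapses from a multi-line `torch.arange(start, end, ...)` call to `torch.arange(seq_len, ...) + past_seen_tokens`; the two forms produce identical tensors:

```python
import torch

past_seen_tokens, seq_len = 7, 4
old_style = torch.arange(past_seen_tokens, past_seen_tokens + seq_len)
new_style = torch.arange(seq_len) + past_seen_tokens
assert torch.equal(old_style, new_style)  # both are tensor([7, 8, 9, 10])
```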
@@ -1835,39 +1664,21 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel):
  for layer_type in self.config.layer_types:
  position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)

- # decoder layers
- all_hidden_states = () if output_hidden_states else None
- all_self_attns = () if output_attentions else None
-
  for decoder_layer in self.layers[: self.config.num_hidden_layers]:
- if output_hidden_states:
- all_hidden_states += (hidden_states,)
-
  causal_mask = causal_mask_mapping[decoder_layer.attention_type]
  per_layer_input = per_layer_inputs[:, :, decoder_layer.layer_idx, :]

- layer_outputs = decoder_layer(
+ hidden_states = decoder_layer(
  hidden_states,
  position_embeddings[decoder_layer.attention_type],
  per_layer_input,
  attention_mask=causal_mask,
  position_ids=position_ids,
  past_key_values=past_key_values,
- output_attentions=output_attentions,
- use_cache=use_cache,
  cache_position=cache_position,
  **kwargs,
  )

- hidden_states = layer_outputs[0]
-
- if output_attentions:
- all_self_attns += (layer_outputs[1],)
-
- # add hidden states from the last decoder layer (but before reprojecting to stay consistent with layer output)
- if output_hidden_states:
- all_hidden_states += (hidden_states,)
-
  # Per-layer inputs to single output
  target_magnitude = torch.mean(hidden_states[0] ** 2, dim=-1, keepdim=True) ** 0.5
  temp_hidden_states = [hidden_states[0]]
@@ -1887,8 +1698,6 @@ class Gemma3nTextModel(Gemma3nPreTrainedModel):
  return BaseModelOutputWithPast(
  last_hidden_state=hidden_states,
  past_key_values=past_key_values,
- hidden_states=all_hidden_states,
- attentions=all_self_attns,
  )

  def get_per_layer_inputs(self, input_ids: torch.LongTensor) -> torch.Tensor:
@@ -2175,7 +1984,7 @@ class Gemma3nModel(Gemma3nPreTrainedModel):
  use_cache: Optional[bool] = None,
  output_attentions: Optional[bool] = None,
  output_hidden_states: Optional[bool] = None,
- **lm_kwargs,
+ **lm_kwargs: Unpack[TransformersKwargs],
  ) -> Gemma3nCausalLMOutputWithPast:
  r"""
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -2363,7 +2172,7 @@ class Gemma3nForConditionalGeneration(Gemma3nPreTrainedModel, GenerationMixin):
  output_attentions: Optional[bool] = None,
  output_hidden_states: Optional[bool] = None,
  logits_to_keep: Union[int, torch.Tensor] = 0,
- **lm_kwargs,
+ **lm_kwargs: Unpack[TransformersKwargs],
  ) -> Gemma3nCausalLMOutputWithPast:
  r"""
  input_features_mask (torch.Tensor, *optional*, defaults to None):