transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (539)
  1. transformers/__init__.py +30 -3
  2. transformers/cli/serve.py +47 -17
  3. transformers/conversion_mapping.py +15 -2
  4. transformers/convert_slow_tokenizer.py +225 -10
  5. transformers/core_model_loading.py +196 -135
  6. transformers/data/data_collator.py +12 -4
  7. transformers/dependency_versions_table.py +1 -2
  8. transformers/dynamic_module_utils.py +1 -2
  9. transformers/feature_extraction_utils.py +1 -2
  10. transformers/file_utils.py +0 -1
  11. transformers/generation/__init__.py +11 -1
  12. transformers/generation/configuration_utils.py +3 -2
  13. transformers/generation/continuous_batching/__init__.py +4 -0
  14. transformers/generation/continuous_batching/continuous_api.py +134 -79
  15. transformers/image_processing_base.py +1 -2
  16. transformers/integrations/__init__.py +4 -2
  17. transformers/integrations/accelerate.py +15 -3
  18. transformers/integrations/aqlm.py +38 -66
  19. transformers/integrations/awq.py +48 -514
  20. transformers/integrations/bitnet.py +45 -100
  21. transformers/integrations/bitsandbytes.py +79 -191
  22. transformers/integrations/deepspeed.py +1 -0
  23. transformers/integrations/eetq.py +84 -79
  24. transformers/integrations/fbgemm_fp8.py +191 -145
  25. transformers/integrations/finegrained_fp8.py +236 -193
  26. transformers/integrations/fp_quant.py +92 -0
  27. transformers/integrations/ggml.py +11 -1
  28. transformers/integrations/higgs.py +40 -62
  29. transformers/integrations/hub_kernels.py +42 -3
  30. transformers/integrations/integration_utils.py +10 -0
  31. transformers/integrations/mxfp4.py +25 -65
  32. transformers/integrations/peft.py +7 -29
  33. transformers/integrations/quanto.py +73 -55
  34. transformers/integrations/quark.py +55 -0
  35. transformers/integrations/spqr.py +44 -90
  36. transformers/integrations/torchao.py +32 -38
  37. transformers/integrations/vptq.py +42 -59
  38. transformers/modelcard.py +1 -2
  39. transformers/modeling_gguf_pytorch_utils.py +8 -0
  40. transformers/modeling_rope_utils.py +30 -6
  41. transformers/modeling_utils.py +116 -112
  42. transformers/models/__init__.py +3 -0
  43. transformers/models/afmoe/modeling_afmoe.py +4 -4
  44. transformers/models/albert/tokenization_albert.py +6 -12
  45. transformers/models/align/modeling_align.py +2 -0
  46. transformers/models/altclip/modeling_altclip.py +4 -0
  47. transformers/models/apertus/modeling_apertus.py +4 -4
  48. transformers/models/arcee/modeling_arcee.py +4 -4
  49. transformers/models/aria/modeling_aria.py +4 -4
  50. transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
  51. transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
  52. transformers/models/auto/configuration_auto.py +11 -0
  53. transformers/models/auto/feature_extraction_auto.py +2 -0
  54. transformers/models/auto/image_processing_auto.py +1 -0
  55. transformers/models/auto/modeling_auto.py +6 -0
  56. transformers/models/auto/processing_auto.py +18 -10
  57. transformers/models/auto/tokenization_auto.py +74 -472
  58. transformers/models/autoformer/modeling_autoformer.py +4 -0
  59. transformers/models/bamba/modeling_bamba.py +4 -3
  60. transformers/models/bark/modeling_bark.py +2 -0
  61. transformers/models/bart/modeling_bart.py +7 -0
  62. transformers/models/barthez/tokenization_barthez.py +5 -10
  63. transformers/models/beit/modeling_beit.py +6 -1
  64. transformers/models/bert/tokenization_bert.py +8 -21
  65. transformers/models/big_bird/modeling_big_bird.py +6 -0
  66. transformers/models/big_bird/tokenization_big_bird.py +18 -42
  67. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +8 -2
  68. transformers/models/biogpt/modeling_biogpt.py +2 -0
  69. transformers/models/biogpt/modular_biogpt.py +2 -0
  70. transformers/models/bit/modeling_bit.py +11 -2
  71. transformers/models/bitnet/modeling_bitnet.py +4 -4
  72. transformers/models/blenderbot/modeling_blenderbot.py +5 -0
  73. transformers/models/blenderbot/tokenization_blenderbot.py +12 -16
  74. transformers/models/blenderbot_small/modeling_blenderbot_small.py +5 -0
  75. transformers/models/blip/modeling_blip_text.py +2 -0
  76. transformers/models/blip_2/modeling_blip_2.py +2 -1
  77. transformers/models/bloom/modeling_bloom.py +4 -0
  78. transformers/models/blt/modeling_blt.py +2 -2
  79. transformers/models/blt/modular_blt.py +2 -2
  80. transformers/models/bridgetower/modeling_bridgetower.py +5 -1
  81. transformers/models/bros/modeling_bros.py +4 -0
  82. transformers/models/camembert/tokenization_camembert.py +8 -12
  83. transformers/models/canine/modeling_canine.py +5 -0
  84. transformers/models/chameleon/modeling_chameleon.py +2 -1
  85. transformers/models/chinese_clip/modeling_chinese_clip.py +3 -0
  86. transformers/models/clap/modeling_clap.py +5 -0
  87. transformers/models/clip/tokenization_clip.py +22 -44
  88. transformers/models/clipseg/modeling_clipseg.py +5 -0
  89. transformers/models/clvp/modeling_clvp.py +5 -0
  90. transformers/models/clvp/tokenization_clvp.py +1 -63
  91. transformers/models/code_llama/tokenization_code_llama.py +20 -43
  92. transformers/models/codegen/tokenization_codegen.py +14 -43
  93. transformers/models/cohere/modeling_cohere.py +4 -3
  94. transformers/models/cohere/modular_cohere.py +2 -1
  95. transformers/models/cohere/tokenization_cohere.py +12 -42
  96. transformers/models/cohere2/modeling_cohere2.py +7 -6
  97. transformers/models/cohere2/modular_cohere2.py +5 -5
  98. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -3
  99. transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
  100. transformers/models/colqwen2/modeling_colqwen2.py +1 -0
  101. transformers/models/colqwen2/modular_colqwen2.py +1 -0
  102. transformers/models/conditional_detr/modeling_conditional_detr.py +5 -0
  103. transformers/models/convbert/modeling_convbert.py +6 -0
  104. transformers/models/convnext/modeling_convnext.py +2 -4
  105. transformers/models/convnextv2/modeling_convnextv2.py +2 -4
  106. transformers/models/csm/modeling_csm.py +4 -3
  107. transformers/models/ctrl/modeling_ctrl.py +1 -0
  108. transformers/models/cvt/modeling_cvt.py +2 -0
  109. transformers/models/cwm/modeling_cwm.py +4 -4
  110. transformers/models/d_fine/modeling_d_fine.py +2 -0
  111. transformers/models/d_fine/modular_d_fine.py +1 -0
  112. transformers/models/dab_detr/modeling_dab_detr.py +4 -0
  113. transformers/models/dac/modeling_dac.py +2 -2
  114. transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
  115. transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
  116. transformers/models/dbrx/modeling_dbrx.py +2 -2
  117. transformers/models/deberta/modeling_deberta.py +5 -0
  118. transformers/models/deberta/tokenization_deberta.py +11 -20
  119. transformers/models/deberta_v2/modeling_deberta_v2.py +6 -0
  120. transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
  121. transformers/models/decision_transformer/modeling_decision_transformer.py +4 -1
  122. transformers/models/deepseek_v2/modeling_deepseek_v2.py +2 -3
  123. transformers/models/deepseek_v2/modular_deepseek_v2.py +2 -2
  124. transformers/models/deepseek_v3/modeling_deepseek_v3.py +3 -2
  125. transformers/models/deepseek_v3/modular_deepseek_v3.py +1 -0
  126. transformers/models/deformable_detr/modeling_deformable_detr.py +4 -0
  127. transformers/models/depth_anything/modeling_depth_anything.py +1 -0
  128. transformers/models/depth_pro/modeling_depth_pro.py +2 -0
  129. transformers/models/detr/modeling_detr.py +5 -0
  130. transformers/models/dia/modeling_dia.py +4 -3
  131. transformers/models/dia/modular_dia.py +0 -1
  132. transformers/models/diffllama/modeling_diffllama.py +2 -2
  133. transformers/models/dinat/modeling_dinat.py +3 -0
  134. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
  135. transformers/models/dinov3_vit/modeling_dinov3_vit.py +2 -2
  136. transformers/models/dinov3_vit/modular_dinov3_vit.py +2 -2
  137. transformers/models/distilbert/tokenization_distilbert.py +13 -0
  138. transformers/models/doge/modeling_doge.py +2 -3
  139. transformers/models/doge/modular_doge.py +0 -1
  140. transformers/models/donut/modeling_donut_swin.py +2 -0
  141. transformers/models/dots1/modeling_dots1.py +10 -7
  142. transformers/models/dots1/modular_dots1.py +5 -3
  143. transformers/models/dpr/modeling_dpr.py +5 -0
  144. transformers/models/dpr/tokenization_dpr.py +12 -0
  145. transformers/models/edgetam/modeling_edgetam.py +1 -1
  146. transformers/models/edgetam_video/modeling_edgetam_video.py +1 -0
  147. transformers/models/edgetam_video/modular_edgetam_video.py +1 -0
  148. transformers/models/efficientloftr/modeling_efficientloftr.py +2 -2
  149. transformers/models/efficientnet/modeling_efficientnet.py +2 -0
  150. transformers/models/emu3/modeling_emu3.py +4 -4
  151. transformers/models/eomt/image_processing_eomt.py +13 -1
  152. transformers/models/eomt/image_processing_eomt_fast.py +14 -2
  153. transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
  154. transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
  155. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +5 -5
  156. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +2 -2
  157. transformers/models/esm/modeling_esmfold.py +5 -4
  158. transformers/models/evolla/modeling_evolla.py +4 -4
  159. transformers/models/exaone4/modeling_exaone4.py +2 -2
  160. transformers/models/exaone4/modular_exaone4.py +0 -1
  161. transformers/models/falcon/modeling_falcon.py +6 -1
  162. transformers/models/falcon_h1/modeling_falcon_h1.py +4 -3
  163. transformers/models/falcon_mamba/modeling_falcon_mamba.py +25 -35
  164. transformers/models/falcon_mamba/modular_falcon_mamba.py +12 -31
  165. transformers/{kernels/falcon_mamba → models/fast_vlm}/__init__.py +15 -3
  166. transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
  167. transformers/models/fast_vlm/modeling_fast_vlm.py +455 -0
  168. transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
  169. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +8 -3
  170. transformers/models/flaubert/modeling_flaubert.py +7 -0
  171. transformers/models/flava/modeling_flava.py +6 -1
  172. transformers/models/flex_olmo/modeling_flex_olmo.py +4 -5
  173. transformers/models/florence2/modeling_florence2.py +2 -1
  174. transformers/models/florence2/modular_florence2.py +2 -1
  175. transformers/models/fnet/modeling_fnet.py +7 -0
  176. transformers/models/focalnet/modeling_focalnet.py +4 -0
  177. transformers/models/fsmt/modeling_fsmt.py +2 -0
  178. transformers/models/funnel/modeling_funnel.py +8 -0
  179. transformers/models/funnel/tokenization_funnel.py +17 -24
  180. transformers/models/fuyu/processing_fuyu.py +3 -3
  181. transformers/models/gemma/modeling_gemma.py +4 -4
  182. transformers/models/gemma/tokenization_gemma.py +10 -27
  183. transformers/models/gemma2/modeling_gemma2.py +4 -4
  184. transformers/models/gemma2/modular_gemma2.py +2 -1
  185. transformers/models/gemma3/modeling_gemma3.py +14 -84
  186. transformers/models/gemma3/modular_gemma3.py +12 -81
  187. transformers/models/gemma3n/modeling_gemma3n.py +18 -209
  188. transformers/models/gemma3n/modular_gemma3n.py +17 -59
  189. transformers/models/git/modeling_git.py +2 -0
  190. transformers/models/glm/modeling_glm.py +4 -4
  191. transformers/models/glm4/modeling_glm4.py +4 -4
  192. transformers/models/glm4_moe/modeling_glm4_moe.py +5 -3
  193. transformers/models/glm4v/configuration_glm4v.py +3 -1
  194. transformers/models/glm4v/modeling_glm4v.py +3 -3
  195. transformers/models/glm4v/modular_glm4v.py +6 -4
  196. transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
  197. transformers/models/glm4v_moe/modeling_glm4v_moe.py +6 -5
  198. transformers/models/glm4v_moe/modular_glm4v_moe.py +1 -1
  199. transformers/models/glpn/modeling_glpn.py +2 -0
  200. transformers/models/gpt2/modeling_gpt2.py +5 -1
  201. transformers/models/gpt2/tokenization_gpt2.py +16 -44
  202. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +1 -0
  203. transformers/models/gpt_neo/modeling_gpt_neo.py +4 -0
  204. transformers/models/gpt_neox/modeling_gpt_neox.py +5 -2
  205. transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
  206. transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
  207. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +3 -1
  208. transformers/models/gpt_oss/modeling_gpt_oss.py +5 -6
  209. transformers/models/gpt_oss/modular_gpt_oss.py +3 -5
  210. transformers/models/gptj/modeling_gptj.py +3 -0
  211. transformers/models/granite/modeling_granite.py +4 -4
  212. transformers/models/granitemoe/modeling_granitemoe.py +4 -6
  213. transformers/models/granitemoe/modular_granitemoe.py +0 -2
  214. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +4 -6
  215. transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -6
  216. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -0
  217. transformers/models/groupvit/modeling_groupvit.py +3 -0
  218. transformers/models/helium/modeling_helium.py +4 -3
  219. transformers/models/herbert/tokenization_herbert.py +9 -25
  220. transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -1
  221. transformers/models/hgnet_v2/modular_hgnet_v2.py +6 -1
  222. transformers/models/hiera/modeling_hiera.py +4 -0
  223. transformers/models/hubert/modeling_hubert.py +3 -0
  224. transformers/models/hubert/modular_hubert.py +1 -0
  225. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +4 -4
  226. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +4 -4
  227. transformers/models/ibert/modeling_ibert.py +6 -0
  228. transformers/models/idefics/modeling_idefics.py +5 -21
  229. transformers/models/imagegpt/modeling_imagegpt.py +2 -1
  230. transformers/models/informer/modeling_informer.py +4 -0
  231. transformers/models/informer/modular_informer.py +1 -0
  232. transformers/models/internvl/modeling_internvl.py +2 -4
  233. transformers/models/internvl/modular_internvl.py +2 -4
  234. transformers/models/jamba/modeling_jamba.py +2 -2
  235. transformers/models/janus/modeling_janus.py +1 -0
  236. transformers/models/janus/modular_janus.py +1 -0
  237. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  238. transformers/models/kosmos2/modeling_kosmos2.py +1 -0
  239. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +3 -1
  240. transformers/models/lasr/__init__.py +29 -0
  241. transformers/models/lasr/configuration_lasr.py +244 -0
  242. transformers/models/lasr/feature_extraction_lasr.py +277 -0
  243. transformers/models/lasr/modeling_lasr.py +729 -0
  244. transformers/models/lasr/modular_lasr.py +569 -0
  245. transformers/models/lasr/processing_lasr.py +96 -0
  246. transformers/models/lasr/tokenization_lasr.py +186 -0
  247. transformers/models/layoutlm/modeling_layoutlm.py +5 -0
  248. transformers/models/layoutlmv2/modeling_layoutlmv2.py +4 -0
  249. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +10 -53
  250. transformers/models/layoutlmv3/modeling_layoutlmv3.py +4 -0
  251. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
  252. transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
  253. transformers/models/led/modeling_led.py +6 -0
  254. transformers/models/levit/modeling_levit.py +3 -0
  255. transformers/models/lfm2/modeling_lfm2.py +4 -5
  256. transformers/models/lfm2/modular_lfm2.py +0 -1
  257. transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -5
  258. transformers/models/lightglue/modeling_lightglue.py +3 -1
  259. transformers/models/lightglue/modular_lightglue.py +1 -0
  260. transformers/models/lilt/modeling_lilt.py +4 -0
  261. transformers/models/llama/modeling_llama.py +4 -4
  262. transformers/models/llama/tokenization_llama.py +15 -43
  263. transformers/models/llama4/modeling_llama4.py +3 -2
  264. transformers/models/longcat_flash/modeling_longcat_flash.py +4 -4
  265. transformers/models/longcat_flash/modular_longcat_flash.py +2 -2
  266. transformers/models/longformer/modeling_longformer.py +6 -0
  267. transformers/models/longt5/modeling_longt5.py +4 -0
  268. transformers/models/luke/modeling_luke.py +9 -0
  269. transformers/models/luke/tokenization_luke.py +11 -38
  270. transformers/models/lxmert/modeling_lxmert.py +2 -0
  271. transformers/models/m2m_100/modeling_m2m_100.py +4 -0
  272. transformers/models/mamba/modeling_mamba.py +14 -22
  273. transformers/models/marian/modeling_marian.py +5 -0
  274. transformers/models/markuplm/modeling_markuplm.py +4 -0
  275. transformers/models/markuplm/tokenization_markuplm.py +28 -61
  276. transformers/models/mask2former/modeling_mask2former.py +2 -0
  277. transformers/models/maskformer/modeling_maskformer.py +2 -0
  278. transformers/models/maskformer/modeling_maskformer_swin.py +2 -0
  279. transformers/models/mbart/modeling_mbart.py +7 -0
  280. transformers/models/mbart/tokenization_mbart.py +11 -52
  281. transformers/models/mbart50/tokenization_mbart50.py +7 -10
  282. transformers/models/megatron_bert/modeling_megatron_bert.py +7 -0
  283. transformers/models/mgp_str/modeling_mgp_str.py +2 -0
  284. transformers/models/mimi/modeling_mimi.py +3 -1
  285. transformers/models/minimax/modeling_minimax.py +4 -4
  286. transformers/models/ministral/modeling_ministral.py +4 -4
  287. transformers/models/ministral3/configuration_ministral3.py +1 -1
  288. transformers/models/ministral3/modeling_ministral3.py +4 -3
  289. transformers/models/mistral/modeling_mistral.py +4 -3
  290. transformers/models/mixtral/modeling_mixtral.py +4 -4
  291. transformers/models/mllama/modeling_mllama.py +2 -2
  292. transformers/models/mluke/tokenization_mluke.py +6 -6
  293. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -0
  294. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
  295. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
  296. transformers/models/mobilevit/modeling_mobilevit.py +3 -0
  297. transformers/models/mobilevitv2/modeling_mobilevitv2.py +3 -0
  298. transformers/models/modernbert/modeling_modernbert.py +4 -1
  299. transformers/models/modernbert/modular_modernbert.py +2 -0
  300. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +8 -9
  301. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +6 -7
  302. transformers/models/moonshine/modeling_moonshine.py +4 -2
  303. transformers/models/moshi/modeling_moshi.py +5 -2
  304. transformers/models/mpnet/modeling_mpnet.py +5 -0
  305. transformers/models/mpnet/tokenization_mpnet.py +5 -13
  306. transformers/models/mpt/modeling_mpt.py +2 -0
  307. transformers/models/mra/modeling_mra.py +6 -0
  308. transformers/models/mt5/modeling_mt5.py +7 -0
  309. transformers/models/musicgen/modeling_musicgen.py +2 -0
  310. transformers/models/musicgen_melody/modeling_musicgen_melody.py +3 -0
  311. transformers/models/mvp/modeling_mvp.py +7 -0
  312. transformers/models/nanochat/modeling_nanochat.py +4 -4
  313. transformers/models/nemotron/modeling_nemotron.py +4 -2
  314. transformers/models/nllb/tokenization_nllb.py +8 -22
  315. transformers/models/nougat/tokenization_nougat.py +11 -59
  316. transformers/models/nystromformer/modeling_nystromformer.py +6 -0
  317. transformers/models/olmo/modeling_olmo.py +4 -4
  318. transformers/models/olmo/modular_olmo.py +2 -2
  319. transformers/models/olmo2/modeling_olmo2.py +4 -5
  320. transformers/models/olmo2/modular_olmo2.py +0 -1
  321. transformers/models/olmo3/modeling_olmo3.py +4 -4
  322. transformers/models/olmoe/modeling_olmoe.py +4 -4
  323. transformers/models/omdet_turbo/modeling_omdet_turbo.py +2 -0
  324. transformers/models/oneformer/modeling_oneformer.py +4 -1
  325. transformers/models/openai/modeling_openai.py +3 -0
  326. transformers/models/openai/tokenization_openai.py +10 -46
  327. transformers/models/opt/modeling_opt.py +2 -0
  328. transformers/models/owlv2/modeling_owlv2.py +4 -0
  329. transformers/models/owlvit/modeling_owlvit.py +4 -0
  330. transformers/models/paddleocr_vl/__init__.py +32 -0
  331. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
  332. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +503 -0
  333. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
  334. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1668 -0
  335. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1349 -0
  336. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
  337. transformers/models/parakeet/configuration_parakeet.py +4 -6
  338. transformers/models/parakeet/modeling_parakeet.py +9 -6
  339. transformers/models/parakeet/modular_parakeet.py +2 -2
  340. transformers/models/parakeet/processing_parakeet.py +1 -0
  341. transformers/models/patchtsmixer/modeling_patchtsmixer.py +6 -0
  342. transformers/models/patchtst/modeling_patchtst.py +20 -2
  343. transformers/models/pegasus/modeling_pegasus.py +5 -0
  344. transformers/models/pegasus/tokenization_pegasus.py +17 -44
  345. transformers/models/pegasus_x/modeling_pegasus_x.py +4 -0
  346. transformers/models/perceiver/modeling_perceiver.py +8 -0
  347. transformers/models/persimmon/modeling_persimmon.py +2 -1
  348. transformers/models/phi/modeling_phi.py +4 -5
  349. transformers/models/phi/modular_phi.py +0 -1
  350. transformers/models/phi3/modeling_phi3.py +2 -1
  351. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +5 -5
  352. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +4 -4
  353. transformers/models/phimoe/modeling_phimoe.py +4 -4
  354. transformers/models/phimoe/modular_phimoe.py +2 -2
  355. transformers/models/pix2struct/modeling_pix2struct.py +2 -0
  356. transformers/models/pixtral/modeling_pixtral.py +2 -1
  357. transformers/models/plbart/modeling_plbart.py +6 -0
  358. transformers/models/plbart/modular_plbart.py +2 -0
  359. transformers/models/plbart/tokenization_plbart.py +0 -2
  360. transformers/models/poolformer/modeling_poolformer.py +2 -0
  361. transformers/models/pop2piano/modeling_pop2piano.py +2 -0
  362. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
  363. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
  364. transformers/models/prophetnet/modeling_prophetnet.py +3 -0
  365. transformers/models/pvt/modeling_pvt.py +2 -0
  366. transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
  367. transformers/models/qwen2/modeling_qwen2.py +4 -4
  368. transformers/models/qwen2/tokenization_qwen2.py +14 -18
  369. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
  370. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +13 -16
  371. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +14 -16
  372. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
  373. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -6
  374. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +3 -5
  375. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -0
  376. transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
  377. transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
  378. transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -16
  379. transformers/models/qwen3/modeling_qwen3.py +4 -4
  380. transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
  381. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -3
  382. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +21 -23
  383. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +14 -16
  384. transformers/models/qwen3_vl/modeling_qwen3_vl.py +39 -37
  385. transformers/models/qwen3_vl/modular_qwen3_vl.py +37 -35
  386. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +39 -37
  387. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +4 -1
  388. transformers/models/rag/modeling_rag.py +1 -0
  389. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +15 -1
  390. transformers/models/reformer/modeling_reformer.py +4 -0
  391. transformers/models/reformer/tokenization_reformer.py +11 -28
  392. transformers/models/regnet/modeling_regnet.py +6 -1
  393. transformers/models/rembert/modeling_rembert.py +6 -0
  394. transformers/models/rembert/tokenization_rembert.py +3 -10
  395. transformers/models/resnet/modeling_resnet.py +11 -2
  396. transformers/models/roberta/tokenization_roberta.py +18 -27
  397. transformers/models/roformer/modeling_roformer.py +6 -0
  398. transformers/models/roformer/tokenization_roformer.py +77 -412
  399. transformers/models/rt_detr/modeling_rt_detr.py +2 -0
  400. transformers/models/rt_detr/modeling_rt_detr_resnet.py +5 -1
  401. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +2 -0
  402. transformers/models/rwkv/modeling_rwkv.py +1 -0
  403. transformers/models/sam2/modeling_sam2.py +2 -2
  404. transformers/models/sam2/modular_sam2.py +2 -2
  405. transformers/models/sam2_video/modeling_sam2_video.py +1 -0
  406. transformers/models/sam2_video/modular_sam2_video.py +1 -0
  407. transformers/models/sam3/modeling_sam3.py +77 -80
  408. transformers/models/sam3_tracker/modeling_sam3_tracker.py +6 -1
  409. transformers/models/sam3_tracker/modular_sam3_tracker.py +6 -1
  410. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +1 -0
  411. transformers/models/sam3_video/modeling_sam3_video.py +1 -0
  412. transformers/models/seamless_m4t/modeling_seamless_m4t.py +5 -1
  413. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
  414. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +5 -1
  415. transformers/models/seed_oss/modeling_seed_oss.py +2 -2
  416. transformers/models/segformer/modeling_segformer.py +4 -1
  417. transformers/models/seggpt/modeling_seggpt.py +2 -0
  418. transformers/models/sew/modeling_sew.py +3 -0
  419. transformers/models/sew/modular_sew.py +1 -0
  420. transformers/models/sew_d/modeling_sew_d.py +3 -0
  421. transformers/models/siglip2/modeling_siglip2.py +4 -0
  422. transformers/models/siglip2/modular_siglip2.py +4 -0
  423. transformers/models/smollm3/modeling_smollm3.py +4 -4
  424. transformers/models/smolvlm/processing_smolvlm.py +0 -7
  425. transformers/models/speech_to_text/modeling_speech_to_text.py +4 -0
  426. transformers/models/speecht5/modeling_speecht5.py +13 -1
  427. transformers/models/splinter/modeling_splinter.py +3 -0
  428. transformers/models/splinter/tokenization_splinter.py +9 -28
  429. transformers/models/squeezebert/modeling_squeezebert.py +6 -0
  430. transformers/models/stablelm/modeling_stablelm.py +3 -1
  431. transformers/models/starcoder2/modeling_starcoder2.py +4 -3
  432. transformers/models/superglue/modeling_superglue.py +1 -0
  433. transformers/models/superpoint/modeling_superpoint.py +1 -0
  434. transformers/models/swiftformer/modeling_swiftformer.py +2 -0
  435. transformers/models/swin/modeling_swin.py +4 -0
  436. transformers/models/swin2sr/modeling_swin2sr.py +2 -0
  437. transformers/models/swinv2/modeling_swinv2.py +4 -0
  438. transformers/models/t5/modeling_t5.py +7 -0
  439. transformers/models/t5/tokenization_t5.py +4 -8
  440. transformers/models/t5gemma/modeling_t5gemma.py +5 -5
  441. transformers/models/t5gemma2/modeling_t5gemma2.py +6 -6
  442. transformers/models/table_transformer/modeling_table_transformer.py +4 -0
  443. transformers/models/tapas/modeling_tapas.py +3 -0
  444. transformers/models/textnet/modeling_textnet.py +11 -2
  445. transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
  446. transformers/models/timesfm/modeling_timesfm.py +2 -0
  447. transformers/models/timesfm/modular_timesfm.py +2 -0
  448. transformers/models/timesformer/modeling_timesformer.py +2 -0
  449. transformers/models/timm_wrapper/modeling_timm_wrapper.py +1 -1
  450. transformers/models/trocr/modeling_trocr.py +2 -0
  451. transformers/models/tvp/modeling_tvp.py +2 -0
  452. transformers/models/udop/modeling_udop.py +4 -0
  453. transformers/models/udop/tokenization_udop.py +5 -13
  454. transformers/models/umt5/modeling_umt5.py +7 -0
  455. transformers/models/unispeech/modeling_unispeech.py +4 -0
  456. transformers/models/unispeech/modular_unispeech.py +2 -0
  457. transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
  458. transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
  459. transformers/models/univnet/modeling_univnet.py +1 -0
  460. transformers/models/upernet/modeling_upernet.py +1 -0
  461. transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
  462. transformers/models/vilt/modeling_vilt.py +6 -0
  463. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
  464. transformers/models/visual_bert/modeling_visual_bert.py +6 -0
  465. transformers/models/vitdet/modeling_vitdet.py +2 -0
  466. transformers/models/vitmatte/modeling_vitmatte.py +1 -0
  467. transformers/models/vits/modeling_vits.py +1 -0
  468. transformers/models/vjepa2/modeling_vjepa2.py +1 -0
  469. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
  470. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +5 -0
  471. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +5 -0
  472. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +6 -0
  473. transformers/models/wavlm/modeling_wavlm.py +5 -0
  474. transformers/models/whisper/modeling_whisper.py +6 -0
  475. transformers/models/whisper/tokenization_whisper.py +4 -15
  476. transformers/models/x_clip/modeling_x_clip.py +3 -0
  477. transformers/models/xglm/modeling_xglm.py +1 -0
  478. transformers/models/xglm/tokenization_xglm.py +4 -9
  479. transformers/models/xlm/modeling_xlm.py +5 -0
  480. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
  481. transformers/models/xlnet/tokenization_xlnet.py +3 -7
  482. transformers/models/yoso/modeling_yoso.py +6 -0
  483. transformers/models/zamba/modeling_zamba.py +2 -0
  484. transformers/models/zamba2/modeling_zamba2.py +4 -2
  485. transformers/models/zamba2/modular_zamba2.py +1 -1
  486. transformers/models/zoedepth/modeling_zoedepth.py +1 -0
  487. transformers/pipelines/__init__.py +2 -3
  488. transformers/pipelines/base.py +1 -9
  489. transformers/pipelines/document_question_answering.py +3 -1
  490. transformers/pipelines/text_generation.py +1 -1
  491. transformers/processing_utils.py +23 -11
  492. transformers/quantizers/base.py +35 -110
  493. transformers/quantizers/quantizer_aqlm.py +1 -5
  494. transformers/quantizers/quantizer_auto_round.py +1 -2
  495. transformers/quantizers/quantizer_awq.py +17 -81
  496. transformers/quantizers/quantizer_bitnet.py +3 -8
  497. transformers/quantizers/quantizer_bnb_4bit.py +13 -110
  498. transformers/quantizers/quantizer_bnb_8bit.py +16 -92
  499. transformers/quantizers/quantizer_compressed_tensors.py +1 -5
  500. transformers/quantizers/quantizer_eetq.py +14 -62
  501. transformers/quantizers/quantizer_fbgemm_fp8.py +34 -125
  502. transformers/quantizers/quantizer_finegrained_fp8.py +13 -105
  503. transformers/quantizers/quantizer_fp_quant.py +48 -78
  504. transformers/quantizers/quantizer_gptq.py +7 -24
  505. transformers/quantizers/quantizer_higgs.py +40 -54
  506. transformers/quantizers/quantizer_hqq.py +144 -153
  507. transformers/quantizers/quantizer_mxfp4.py +13 -167
  508. transformers/quantizers/quantizer_quanto.py +20 -64
  509. transformers/quantizers/quantizer_quark.py +36 -17
  510. transformers/quantizers/quantizer_spqr.py +1 -4
  511. transformers/quantizers/quantizer_torchao.py +23 -202
  512. transformers/quantizers/quantizer_vptq.py +8 -22
  513. transformers/quantizers/quantizers_utils.py +20 -0
  514. transformers/testing_utils.py +297 -36
  515. transformers/tokenization_mistral_common.py +4 -0
  516. transformers/tokenization_utils_base.py +113 -222
  517. transformers/tokenization_utils_tokenizers.py +168 -107
  518. transformers/trainer.py +28 -31
  519. transformers/trainer_jit_checkpoint.py +126 -0
  520. transformers/trainer_utils.py +1 -1
  521. transformers/training_args.py +66 -28
  522. transformers/utils/__init__.py +3 -4
  523. transformers/utils/auto_docstring.py +1 -0
  524. transformers/utils/generic.py +27 -1
  525. transformers/utils/hub.py +5 -15
  526. transformers/utils/import_utils.py +61 -16
  527. transformers/utils/kernel_config.py +4 -2
  528. transformers/utils/loading_report.py +19 -10
  529. transformers/utils/quantization_config.py +75 -242
  530. transformers/video_processing_utils.py +1 -2
  531. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/METADATA +274 -227
  532. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/RECORD +536 -520
  533. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/WHEEL +1 -1
  534. transformers/kernels/__init__.py +0 -0
  535. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
  536. transformers/models/roformer/tokenization_roformer_fast.py +0 -160
  537. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/entry_points.txt +0 -0
  538. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info/licenses}/LICENSE +0 -0
  539. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/top_level.txt +0 -0
@@ -14,7 +14,7 @@
14
14
  # limitations under the License.
15
15
  """Tokenization class for Funnel Transformer."""
16
16
 
17
- from typing import Optional
17
+ from typing import Optional, Union
18
18
 
19
19
  from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers, processors
20
20
  from tokenizers.models import WordPiece
@@ -83,16 +83,17 @@ class FunnelTokenizer(TokenizersBackend):
83
83
  value for `lowercase` (as in the original BERT).
84
84
  wordpieces_prefix (`str`, *optional*, defaults to `"##"`):
85
85
  The prefix for subwords.
86
- vocab (`dict`, *optional*):
86
+ vocab (`str` or `dict[str, int]`, *optional*):
87
87
  Custom vocabulary dictionary.
88
88
  """
89
89
 
90
90
  vocab_files_names = VOCAB_FILES_NAMES
91
- slow_tokenizer_class = None
91
+ model = WordPiece
92
92
  cls_token_type_id: int = 2
93
93
 
94
94
  def __init__(
95
95
  self,
96
+ vocab: Optional[Union[str, dict[str, int]]] = None,
96
97
  do_lower_case: bool = True,
97
98
  unk_token: str = "<unk>",
98
99
  sep_token: str = "<sep>",
@@ -105,23 +106,18 @@ class FunnelTokenizer(TokenizersBackend):
105
106
  tokenize_chinese_chars: bool = True,
106
107
  strip_accents: Optional[bool] = None,
107
108
  wordpieces_prefix: str = "##",
108
- vocab: Optional[dict] = None,
109
- vocab_file: Optional[str] = None,
110
109
  **kwargs,
111
110
  ):
112
- self.vocab_file = vocab_file
113
111
  self.do_lower_case = do_lower_case
114
112
  self.tokenize_chinese_chars = tokenize_chinese_chars
115
113
  self.strip_accents = strip_accents
116
114
  self.clean_text = clean_text
117
115
  self.wordpieces_prefix = wordpieces_prefix
118
116
 
119
- if vocab is not None:
120
- self._vocab = (
121
- {token: idx for idx, (token, _score) in enumerate(vocab)} if isinstance(vocab, list) else vocab
122
- )
123
- else:
124
- self._vocab = {
117
+ self._vocab = (
118
+ vocab
119
+ if vocab is not None
120
+ else {
125
121
  str(pad_token): 0,
126
122
  str(unk_token): 1,
127
123
  str(cls_token): 2,
@@ -130,6 +126,7 @@ class FunnelTokenizer(TokenizersBackend):
130
126
  str(bos_token): 5,
131
127
  str(eos_token): 6,
132
128
  }
129
+ )
133
130
 
134
131
  self._tokenizer = Tokenizer(WordPiece(self._vocab, unk_token=str(unk_token)))
135
132
 
@@ -142,19 +139,7 @@ class FunnelTokenizer(TokenizersBackend):
142
139
  self._tokenizer.pre_tokenizer = pre_tokenizers.BertPreTokenizer()
143
140
  self._tokenizer.decoder = decoders.WordPiece(prefix=wordpieces_prefix)
144
141
 
145
- self._tokenizer.post_processor = processors.TemplateProcessing(
146
- single=f"{cls_token}:2 $A:0 {sep_token}:0", # token_type_id is 2 for Funnel transformer
147
- pair=f"{cls_token}:2 $A:0 {sep_token}:0 $B:1 {sep_token}:1",
148
- special_tokens=[
149
- (str(cls_token), self._vocab.get(str(cls_token), 2)),
150
- (str(sep_token), self._vocab.get(str(sep_token), 3)),
151
- ],
152
- )
153
-
154
- tokenizer_object = self._tokenizer
155
-
156
142
  super().__init__(
157
- tokenizer_object=tokenizer_object,
158
143
  do_lower_case=do_lower_case,
159
144
  unk_token=unk_token,
160
145
  sep_token=sep_token,
@@ -169,6 +154,14 @@ class FunnelTokenizer(TokenizersBackend):
169
154
  wordpieces_prefix=wordpieces_prefix,
170
155
  **kwargs,
171
156
  )
157
+ self._tokenizer.post_processor = processors.TemplateProcessing(
158
+ single=f"{cls_token}:2 $A:0 {sep_token}:0", # token_type_id is 2 for Funnel transformer
159
+ pair=f"{cls_token}:2 $A:0 {sep_token}:0 $B:1 {sep_token}:1",
160
+ special_tokens=[
161
+ (str(cls_token), self.cls_token_id),
162
+ (str(sep_token), self.sep_token_id),
163
+ ],
164
+ )
172
165
 
173
166
 
174
167
  __all__ = ["FunnelTokenizer"]
@@ -337,13 +337,13 @@ class FuyuProcessor(ProcessorMixin):
337
337
  r"""
338
338
  Constructs a Fuyu processor which wraps a Fuyu image processor and a Llama tokenizer into a single processor.
339
339
 
340
- [`FuyuProcessor`] offers all the functionalities of [`FuyuImageProcessor`] and [`LlamaTokenizerFast`]. See the
340
+ [`FuyuProcessor`] offers all the functionalities of [`FuyuImageProcessor`] and [`TokenizersBackend`]. See the
341
341
  [`~FuyuProcessor.__call__`] and [`~FuyuProcessor.decode`] for more information.
342
342
 
343
343
  Args:
344
344
  image_processor ([`FuyuImageProcessor`]):
345
345
  The image processor is a required input.
346
- tokenizer ([`LlamaTokenizerFast`]):
346
+ tokenizer ([`TokenizersBackend`]):
347
347
  The tokenizer is a required input.
348
348
  """
349
349
 
@@ -486,7 +486,7 @@ class FuyuProcessor(ProcessorMixin):
486
486
  ) -> "FuyuBatchFeature":
487
487
  """
488
488
  Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
489
- and `kwargs` arguments to LlamaTokenizerFast's [`~LlamaTokenizerFast.__call__`] if `text` is not `None` to
489
+ and `kwargs` arguments to TokenizersBackend's [`~TokenizersBackend.__call__`] if `text` is not `None` to
490
490
  encode the text. To prepare the image(s), this method forwards the `images` and `kwargs` arguments to
491
491
  FuyuImageProcessor's [`~FuyuImageProcessor.__call__`] if `images` is not `None`. Please refer to the docstring
492
492
  of the above two methods for more information.
@@ -29,7 +29,7 @@ from ... import initialization as init
29
29
  from ...activations import ACT2FN
30
30
  from ...cache_utils import Cache, DynamicCache
31
31
  from ...generation import GenerationMixin
32
- from ...integrations import use_kernel_func_from_hub
32
+ from ...integrations import use_kernel_func_from_hub, use_kernelized_func
33
33
  from ...masking_utils import create_causal_mask
34
34
  from ...modeling_layers import (
35
35
  GenericForSequenceClassification,
@@ -41,7 +41,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
41
41
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
42
42
  from ...processing_utils import Unpack
43
43
  from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
44
- from ...utils.generic import check_model_inputs
44
+ from ...utils.generic import check_model_inputs, maybe_autocast
45
45
  from .configuration_gemma import GemmaConfig
46
46
 
47
47
 
@@ -137,7 +137,7 @@ class GemmaRotaryEmbedding(nn.Module):
137
137
  position_ids_expanded = position_ids[:, None, :].float()
138
138
 
139
139
  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
140
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
140
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
141
141
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
142
142
  emb = torch.cat((freqs, freqs), dim=-1)
143
143
  cos = emb.cos() * self.attention_scaling
@@ -219,6 +219,7 @@ def eager_attention_forward(
219
219
  return attn_output, attn_weights
220
220
 
221
221
 
222
+ @use_kernelized_func(apply_rotary_pos_emb)
222
223
  class GemmaAttention(nn.Module):
223
224
  """Multi-headed attention from 'Attention Is All You Need' paper"""
224
225
 
@@ -244,7 +245,6 @@ class GemmaAttention(nn.Module):
244
245
  self.o_proj = nn.Linear(
245
246
  config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
246
247
  )
247
- self.rotary_fn = apply_rotary_pos_emb
248
248
 
249
249
  def forward(
250
250
  self,
@@ -12,12 +12,11 @@
12
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
- from typing import Optional
15
+ from typing import Optional, Union
16
16
 
17
17
  from tokenizers import Tokenizer, decoders, normalizers, pre_tokenizers
18
18
  from tokenizers.models import BPE
19
19
 
20
- from ...tokenization_utils_base import generate_merges
21
20
  from ...tokenization_utils_tokenizers import TokenizersBackend
22
21
  from ...utils import logging
23
22
 
@@ -30,7 +29,7 @@ class GemmaTokenizer(TokenizersBackend):
30
29
  """
31
30
  Construct a fast Gemma tokenizer (backed by HuggingFace's tokenizers library).
32
31
 
33
- This tokenizer uses a Unigram model with ByteFallback, no prefix space, and a normalizer that replaces
32
+ This tokenizer uses a BPE model with byte fallback, no prefix space, and a normalizer that replaces
34
33
  spaces with "▁".
35
34
 
36
35
  Args:
@@ -50,48 +49,37 @@ class GemmaTokenizer(TokenizersBackend):
50
49
  Whether or not to add a `bos_token` at the start of sequences.
51
50
  add_eos_token (`bool`, optional, defaults to False):
52
51
  Whether or not to add an `eos_token` at the end of sequences.
53
- vocab (`dict`, optional):
52
+ vocab (`str` or `dict[str, int]`, optional):
54
53
  Custom vocabulary dict. If not provided, a minimal vocabulary is created using the special tokens.
55
54
  """
56
55
 
57
56
  vocab_files_names = VOCAB_FILES_NAMES
58
- slow_tokenizer_class = None
59
57
  padding_side = "left"
60
58
  model_input_names = ["input_ids", "attention_mask"]
59
+ model = BPE
61
60
 
62
61
  def __init__(
63
62
  self,
63
+ vocab: Optional[Union[str, dict[str, int]]] = None,
64
+ merges: Optional[Union[str, list[str]]] = None,
64
65
  unk_token: str = "<unk>",
65
66
  bos_token: str = "<bos>",
66
67
  eos_token: str = "<eos>",
67
68
  pad_token: str = "<pad>",
68
69
  mask_token: str = "<mask>",
69
- add_bos_token: bool = True,
70
- add_eos_token: bool = False,
71
- vocab: Optional[dict] = None,
72
- merges: Optional[list[tuple[str, str]]] = None,
73
70
  **kwargs,
74
71
  ):
75
- self._add_bos_token = add_bos_token
76
- self._add_eos_token = add_eos_token
77
-
78
- special_tokens = {str(pad_token), str(eos_token), str(bos_token), str(unk_token)}
79
-
80
- if vocab is not None:
81
- self._vocab = (
82
- {token: idx for idx, (token, _score) in enumerate(vocab)} if isinstance(vocab, list) else vocab
83
- )
84
- else:
85
- self._vocab = {
72
+ if vocab is None:
73
+ vocab = {
86
74
  str(pad_token): 0,
87
75
  str(eos_token): 1,
88
76
  str(bos_token): 2,
89
77
  str(unk_token): 3,
90
78
  str(mask_token): 4,
91
79
  }
80
+ self._vocab = vocab
81
+ self._merges = merges or []
92
82
 
93
- filtered_vocab = {t: i for t, i in self._vocab.items() if t not in special_tokens}
94
- self._merges = merges if merges is not None else generate_merges(filtered_vocab)
95
83
  self._tokenizer = Tokenizer(
96
84
  BPE(
97
85
  vocab=self._vocab,
@@ -108,17 +96,12 @@ class GemmaTokenizer(TokenizersBackend):
108
96
  )
109
97
  self._tokenizer.normalizer = normalizers.Replace(" ", "▁")
110
98
  self._tokenizer.pre_tokenizer = pre_tokenizers.Split(" ", "merged_with_previous")
111
- tokenizer_object = self._tokenizer
112
-
113
99
  super().__init__(
114
- tokenizer_object=tokenizer_object,
115
100
  unk_token=unk_token,
116
101
  bos_token=bos_token,
117
102
  eos_token=eos_token,
118
103
  pad_token=pad_token,
119
104
  mask_token=mask_token,
120
- add_bos_token=add_bos_token,
121
- add_eos_token=add_eos_token,
122
105
  **kwargs,
123
106
  )
124
107
 
@@ -29,7 +29,7 @@ from ... import initialization as init
29
29
  from ...activations import ACT2FN
30
30
  from ...cache_utils import Cache, DynamicCache
31
31
  from ...generation import GenerationMixin
32
- from ...integrations import use_kernel_func_from_hub
32
+ from ...integrations import use_kernel_func_from_hub, use_kernelized_func
33
33
  from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
34
34
  from ...modeling_flash_attention_utils import FlashAttentionKwargs
35
35
  from ...modeling_layers import (
@@ -42,7 +42,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
42
42
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
43
43
  from ...processing_utils import Unpack
44
44
  from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
45
- from ...utils.generic import check_model_inputs
45
+ from ...utils.generic import check_model_inputs, maybe_autocast
46
46
  from .configuration_gemma2 import Gemma2Config
47
47
 
48
48
 
@@ -138,7 +138,7 @@ class Gemma2RotaryEmbedding(nn.Module):
138
138
  position_ids_expanded = position_ids[:, None, :].float()
139
139
 
140
140
  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
141
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
141
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
142
142
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
143
143
  emb = torch.cat((freqs, freqs), dim=-1)
144
144
  cos = emb.cos() * self.attention_scaling
@@ -229,6 +229,7 @@ def eager_attention_forward(
229
229
  return attn_output, attn_weights
230
230
 
231
231
 
232
+ @use_kernelized_func(apply_rotary_pos_emb)
232
233
  class Gemma2Attention(nn.Module):
233
234
  """Multi-headed attention from 'Attention Is All You Need' paper"""
234
235
 
@@ -255,7 +256,6 @@ class Gemma2Attention(nn.Module):
255
256
  self.o_proj = nn.Linear(
256
257
  config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
257
258
  )
258
- self.rotary_fn = apply_rotary_pos_emb
259
259
  self.attn_logit_softcapping = self.config.attn_logit_softcapping
260
260
  self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
261
261
 
@@ -34,6 +34,7 @@ from ...modeling_rope_utils import (
34
34
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
35
35
  from ...processing_utils import Unpack
36
36
  from ...utils import TransformersKwargs, logging
37
+ from ...utils.generic import maybe_autocast
37
38
  from ..gemma.modeling_gemma import (
38
39
  GemmaAttention,
39
40
  GemmaForCausalLM,
@@ -252,7 +253,7 @@ class Gemma2RotaryEmbedding(GemmaRotaryEmbedding):
252
253
  position_ids_expanded = position_ids[:, None, :].float()
253
254
 
254
255
  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
255
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
256
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
256
257
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
257
258
  emb = torch.cat((freqs, freqs), dim=-1)
258
259
  cos = emb.cos() * self.attention_scaling
@@ -31,16 +31,15 @@ from ...activations import ACT2FN
31
31
  from ...cache_utils import Cache, DynamicCache
32
32
  from ...configuration_utils import PreTrainedConfig
33
33
  from ...generation import GenerationMixin
34
- from ...integrations import use_kernel_func_from_hub
34
+ from ...integrations import use_kernel_func_from_hub, use_kernelized_func
35
35
  from ...masking_utils import create_causal_mask, create_masks_for_generate, create_sliding_window_causal_mask
36
- from ...modeling_flash_attention_utils import FlashAttentionKwargs
37
36
  from ...modeling_layers import GenericForSequenceClassification, GradientCheckpointingLayer
38
37
  from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
39
38
  from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
40
39
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
41
40
  from ...processing_utils import Unpack
42
41
  from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple, logging
43
- from ...utils.generic import check_model_inputs
42
+ from ...utils.generic import check_model_inputs, maybe_autocast
44
43
  from ..auto import AutoModel
45
44
  from .configuration_gemma3 import Gemma3Config, Gemma3TextConfig
46
45
 
@@ -215,7 +214,7 @@ class Gemma3RotaryEmbedding(nn.Module):
215
214
  position_ids_expanded = position_ids[:, None, :].float()
216
215
 
217
216
  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
218
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
217
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
219
218
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
220
219
  emb = torch.cat((freqs, freqs), dim=-1)
221
220
  cos = emb.cos() * attention_scaling
@@ -306,6 +305,7 @@ def eager_attention_forward(
306
305
  return attn_output, attn_weights
307
306
 
308
307
 
308
+ @use_kernelized_func(apply_rotary_pos_emb)
309
309
  class Gemma3Attention(nn.Module):
310
310
  """Multi-headed attention from 'Attention Is All You Need' paper"""
311
311
 
@@ -332,7 +332,6 @@ class Gemma3Attention(nn.Module):
332
332
  self.o_proj = nn.Linear(
333
333
  config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
334
334
  )
335
- self.rotary_fn = apply_rotary_pos_emb
336
335
  self.attn_logit_softcapping = self.config.attn_logit_softcapping
337
336
  self.sliding_window = config.sliding_window if self.layer_type == "sliding_attention" else None
338
337
  self.is_sliding = self.layer_type == "sliding_attention"
@@ -347,7 +346,7 @@ class Gemma3Attention(nn.Module):
347
346
  attention_mask: Optional[torch.Tensor] = None,
348
347
  past_key_values: Optional[Cache] = None,
349
348
  cache_position: Optional[torch.LongTensor] = None,
350
- **kwargs: Unpack[FlashAttentionKwargs],
349
+ **kwargs: Unpack[TransformersKwargs],
351
350
  ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
352
351
  input_shape = hidden_states.shape[:-1]
353
352
  hidden_shape = (*input_shape, -1, self.head_dim)
@@ -409,23 +408,19 @@ class Gemma3DecoderLayer(GradientCheckpointingLayer):
409
408
  attention_mask: Optional[torch.Tensor] = None,
410
409
  position_ids: Optional[torch.LongTensor] = None,
411
410
  past_key_values: Optional[Cache] = None,
412
- output_attentions: Optional[bool] = False,
413
- use_cache: Optional[bool] = False,
414
411
  cache_position: Optional[torch.LongTensor] = None,
415
- **kwargs,
412
+ **kwargs: Unpack[TransformersKwargs],
416
413
  ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
417
414
  residual = hidden_states
418
415
 
419
416
  hidden_states = self.input_layernorm(hidden_states)
420
417
 
421
- hidden_states, self_attn_weights = self.self_attn(
418
+ hidden_states, _ = self.self_attn(
422
419
  hidden_states=hidden_states,
423
420
  position_embeddings=position_embeddings,
424
421
  attention_mask=attention_mask,
425
422
  position_ids=position_ids,
426
423
  past_key_values=past_key_values,
427
- output_attentions=output_attentions,
428
- use_cache=use_cache,
429
424
  cache_position=cache_position,
430
425
  **kwargs,
431
426
  )
@@ -438,12 +433,7 @@ class Gemma3DecoderLayer(GradientCheckpointingLayer):
438
433
  hidden_states = self.post_feedforward_layernorm(hidden_states)
439
434
  hidden_states = residual + hidden_states
440
435
 
441
- outputs = (hidden_states,)
442
-
443
- if output_attentions:
444
- outputs += (self_attn_weights,)
445
-
446
- return outputs
436
+ return hidden_states
447
437
 
448
438
 
449
439
  @auto_docstring
@@ -527,30 +517,16 @@ class Gemma3TextModel(Gemma3PreTrainedModel):
527
517
  past_key_values: Optional[Cache] = None,
528
518
  inputs_embeds: Optional[torch.FloatTensor] = None,
529
519
  use_cache: Optional[bool] = None,
530
- output_attentions: Optional[bool] = None,
531
- output_hidden_states: Optional[bool] = None,
532
520
  cache_position: Optional[torch.LongTensor] = None,
533
521
  **kwargs: Unpack[TransformersKwargs],
534
522
  ) -> BaseModelOutputWithPast:
535
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
536
- output_hidden_states = (
537
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
538
- )
539
- use_cache = use_cache if use_cache is not None else self.config.use_cache
540
-
541
523
  if (input_ids is None) ^ (inputs_embeds is not None):
542
524
  raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
543
525
 
544
- if self.gradient_checkpointing and self.training and use_cache:
545
- logger.warning_once(
546
- "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
547
- )
548
- use_cache = False
549
-
550
526
  if inputs_embeds is None:
551
527
  inputs_embeds = self.embed_tokens(input_ids)
552
528
 
553
- if use_cache and past_key_values is None and not self.training:
529
+ if use_cache and past_key_values is None:
554
530
  past_key_values = DynamicCache(config=self.config)
555
531
 
556
532
  if cache_position is None:
@@ -591,41 +567,22 @@ class Gemma3TextModel(Gemma3PreTrainedModel):
591
567
  for layer_type in self.config.layer_types:
592
568
  position_embeddings[layer_type] = self.rotary_emb(hidden_states, position_ids, layer_type)
593
569
 
594
- # decoder layers
595
- all_hidden_states = () if output_hidden_states else None
596
- all_self_attns = () if output_attentions else None
597
-
598
570
  for decoder_layer in self.layers[: self.config.num_hidden_layers]:
599
- if output_hidden_states:
600
- all_hidden_states += (hidden_states,)
601
-
602
- layer_outputs = decoder_layer(
571
+ hidden_states = decoder_layer(
603
572
  hidden_states,
604
573
  attention_mask=causal_mask_mapping[decoder_layer.attention_type],
605
574
  position_embeddings=position_embeddings[decoder_layer.attention_type],
606
575
  position_ids=position_ids,
607
576
  past_key_values=past_key_values,
608
- output_attentions=output_attentions,
609
- use_cache=use_cache,
610
577
  cache_position=cache_position,
611
578
  **kwargs,
612
579
  )
613
580
 
614
- hidden_states = layer_outputs[0]
615
-
616
- if output_attentions:
617
- all_self_attns += (layer_outputs[1],)
618
-
619
581
  hidden_states = self.norm(hidden_states)
620
582
 
621
- if output_hidden_states:
622
- all_hidden_states += (hidden_states,)
623
-
624
583
  return BaseModelOutputWithPast(
625
584
  last_hidden_state=hidden_states,
626
585
  past_key_values=past_key_values,
627
- hidden_states=all_hidden_states,
628
- attentions=all_self_attns,
629
586
  )
630
587
 
631
588
 
@@ -918,10 +875,7 @@ class Gemma3Model(Gemma3PreTrainedModel):
918
875
  inputs_embeds: Optional[torch.FloatTensor] = None,
919
876
  labels: Optional[torch.LongTensor] = None,
920
877
  use_cache: Optional[bool] = None,
921
- output_attentions: Optional[bool] = None,
922
- output_hidden_states: Optional[bool] = None,
923
- return_dict: Optional[bool] = None,
924
- **lm_kwargs,
878
+ **lm_kwargs: Unpack[TransformersKwargs],
925
879
  ) -> Union[tuple, Gemma3ModelOutputWithPast]:
926
880
  r"""
927
881
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -953,12 +907,6 @@ class Gemma3Model(Gemma3PreTrainedModel):
953
907
  if (input_ids is None) ^ (inputs_embeds is not None):
954
908
  raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
955
909
 
956
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
957
- output_hidden_states = (
958
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
959
- )
960
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
961
-
962
910
  # Replace image id with PAD if the image token if OOV, to avoid index-errors
963
911
  if input_ids is not None and self.config.image_token_id >= self.vocab_size:
964
912
  special_image_mask = input_ids == self.config.image_token_id
@@ -1005,8 +953,6 @@ class Gemma3Model(Gemma3PreTrainedModel):
1005
953
  past_key_values=past_key_values,
1006
954
  inputs_embeds=inputs_embeds,
1007
955
  use_cache=use_cache,
1008
- output_attentions=output_attentions,
1009
- output_hidden_states=output_hidden_states,
1010
956
  return_dict=True,
1011
957
  cache_position=cache_position,
1012
958
  **lm_kwargs,
@@ -1014,7 +960,7 @@ class Gemma3Model(Gemma3PreTrainedModel):
1014
960
 
1015
961
  return Gemma3ModelOutputWithPast(
1016
962
  last_hidden_state=outputs.last_hidden_state,
1017
- past_key_values=outputs.past_key_values if use_cache else None,
963
+ past_key_values=outputs.past_key_values,
1018
964
  hidden_states=outputs.hidden_states,
1019
965
  attentions=outputs.attentions,
1020
966
  image_hidden_states=image_features if pixel_values is not None else None,
@@ -1053,6 +999,7 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
1053
999
  def get_image_features(self, pixel_values):
1054
1000
  return self.model.get_image_features(pixel_values)
1055
1001
 
1002
+ @can_return_tuple
1056
1003
  @auto_docstring
1057
1004
  def forward(
1058
1005
  self,
@@ -1066,11 +1013,8 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
1066
1013
  inputs_embeds: Optional[torch.FloatTensor] = None,
1067
1014
  labels: Optional[torch.LongTensor] = None,
1068
1015
  use_cache: Optional[bool] = None,
1069
- output_attentions: Optional[bool] = None,
1070
- output_hidden_states: Optional[bool] = None,
1071
- return_dict: Optional[bool] = None,
1072
1016
  logits_to_keep: Union[int, torch.Tensor] = 0,
1073
- **lm_kwargs,
1017
+ **lm_kwargs: Unpack[TransformersKwargs],
1074
1018
  ) -> Union[tuple, Gemma3CausalLMOutputWithPast]:
1075
1019
  r"""
1076
1020
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -1116,13 +1060,6 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
1116
1060
  "user\nYou are a helpful assistant.\n\n\n\n\n\nWhere is the cat standing?\nmodel\nBased on the image, the cat is standing in a snowy area, likely outdoors. It appears to"
1117
1061
  ```
1118
1062
  """
1119
-
1120
- output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
1121
- output_hidden_states = (
1122
- output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
1123
- )
1124
- return_dict = return_dict if return_dict is not None else self.config.use_return_dict
1125
-
1126
1063
  outputs = self.model(
1127
1064
  input_ids=input_ids,
1128
1065
  pixel_values=pixel_values,
@@ -1133,9 +1070,6 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
1133
1070
  inputs_embeds=inputs_embeds,
1134
1071
  use_cache=use_cache,
1135
1072
  labels=labels,
1136
- output_attentions=output_attentions,
1137
- output_hidden_states=output_hidden_states,
1138
- return_dict=return_dict,
1139
1073
  cache_position=cache_position,
1140
1074
  **lm_kwargs,
1141
1075
  )
@@ -1167,10 +1101,6 @@ class Gemma3ForConditionalGeneration(Gemma3PreTrainedModel, GenerationMixin):
1167
1101
  flat_labels = shift_labels.view(-1).to(shift_logits.device)
1168
1102
  loss = loss_fct(flat_logits, flat_labels)
1169
1103
 
1170
- if not return_dict:
1171
- output = (logits,) + outputs[1:]
1172
- return (loss,) + output if loss is not None else output
1173
-
1174
1104
  return Gemma3CausalLMOutputWithPast(
1175
1105
  loss=loss,
1176
1106
  logits=logits,