transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (539)
  1. transformers/__init__.py +30 -3
  2. transformers/cli/serve.py +47 -17
  3. transformers/conversion_mapping.py +15 -2
  4. transformers/convert_slow_tokenizer.py +225 -10
  5. transformers/core_model_loading.py +196 -135
  6. transformers/data/data_collator.py +12 -4
  7. transformers/dependency_versions_table.py +1 -2
  8. transformers/dynamic_module_utils.py +1 -2
  9. transformers/feature_extraction_utils.py +1 -2
  10. transformers/file_utils.py +0 -1
  11. transformers/generation/__init__.py +11 -1
  12. transformers/generation/configuration_utils.py +3 -2
  13. transformers/generation/continuous_batching/__init__.py +4 -0
  14. transformers/generation/continuous_batching/continuous_api.py +134 -79
  15. transformers/image_processing_base.py +1 -2
  16. transformers/integrations/__init__.py +4 -2
  17. transformers/integrations/accelerate.py +15 -3
  18. transformers/integrations/aqlm.py +38 -66
  19. transformers/integrations/awq.py +48 -514
  20. transformers/integrations/bitnet.py +45 -100
  21. transformers/integrations/bitsandbytes.py +79 -191
  22. transformers/integrations/deepspeed.py +1 -0
  23. transformers/integrations/eetq.py +84 -79
  24. transformers/integrations/fbgemm_fp8.py +191 -145
  25. transformers/integrations/finegrained_fp8.py +236 -193
  26. transformers/integrations/fp_quant.py +92 -0
  27. transformers/integrations/ggml.py +11 -1
  28. transformers/integrations/higgs.py +40 -62
  29. transformers/integrations/hub_kernels.py +42 -3
  30. transformers/integrations/integration_utils.py +10 -0
  31. transformers/integrations/mxfp4.py +25 -65
  32. transformers/integrations/peft.py +7 -29
  33. transformers/integrations/quanto.py +73 -55
  34. transformers/integrations/quark.py +55 -0
  35. transformers/integrations/spqr.py +44 -90
  36. transformers/integrations/torchao.py +32 -38
  37. transformers/integrations/vptq.py +42 -59
  38. transformers/modelcard.py +1 -2
  39. transformers/modeling_gguf_pytorch_utils.py +8 -0
  40. transformers/modeling_rope_utils.py +30 -6
  41. transformers/modeling_utils.py +116 -112
  42. transformers/models/__init__.py +3 -0
  43. transformers/models/afmoe/modeling_afmoe.py +4 -4
  44. transformers/models/albert/tokenization_albert.py +6 -12
  45. transformers/models/align/modeling_align.py +2 -0
  46. transformers/models/altclip/modeling_altclip.py +4 -0
  47. transformers/models/apertus/modeling_apertus.py +4 -4
  48. transformers/models/arcee/modeling_arcee.py +4 -4
  49. transformers/models/aria/modeling_aria.py +4 -4
  50. transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
  51. transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
  52. transformers/models/auto/configuration_auto.py +11 -0
  53. transformers/models/auto/feature_extraction_auto.py +2 -0
  54. transformers/models/auto/image_processing_auto.py +1 -0
  55. transformers/models/auto/modeling_auto.py +6 -0
  56. transformers/models/auto/processing_auto.py +18 -10
  57. transformers/models/auto/tokenization_auto.py +74 -472
  58. transformers/models/autoformer/modeling_autoformer.py +4 -0
  59. transformers/models/bamba/modeling_bamba.py +4 -3
  60. transformers/models/bark/modeling_bark.py +2 -0
  61. transformers/models/bart/modeling_bart.py +7 -0
  62. transformers/models/barthez/tokenization_barthez.py +5 -10
  63. transformers/models/beit/modeling_beit.py +6 -1
  64. transformers/models/bert/tokenization_bert.py +8 -21
  65. transformers/models/big_bird/modeling_big_bird.py +6 -0
  66. transformers/models/big_bird/tokenization_big_bird.py +18 -42
  67. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +8 -2
  68. transformers/models/biogpt/modeling_biogpt.py +2 -0
  69. transformers/models/biogpt/modular_biogpt.py +2 -0
  70. transformers/models/bit/modeling_bit.py +11 -2
  71. transformers/models/bitnet/modeling_bitnet.py +4 -4
  72. transformers/models/blenderbot/modeling_blenderbot.py +5 -0
  73. transformers/models/blenderbot/tokenization_blenderbot.py +12 -16
  74. transformers/models/blenderbot_small/modeling_blenderbot_small.py +5 -0
  75. transformers/models/blip/modeling_blip_text.py +2 -0
  76. transformers/models/blip_2/modeling_blip_2.py +2 -1
  77. transformers/models/bloom/modeling_bloom.py +4 -0
  78. transformers/models/blt/modeling_blt.py +2 -2
  79. transformers/models/blt/modular_blt.py +2 -2
  80. transformers/models/bridgetower/modeling_bridgetower.py +5 -1
  81. transformers/models/bros/modeling_bros.py +4 -0
  82. transformers/models/camembert/tokenization_camembert.py +8 -12
  83. transformers/models/canine/modeling_canine.py +5 -0
  84. transformers/models/chameleon/modeling_chameleon.py +2 -1
  85. transformers/models/chinese_clip/modeling_chinese_clip.py +3 -0
  86. transformers/models/clap/modeling_clap.py +5 -0
  87. transformers/models/clip/tokenization_clip.py +22 -44
  88. transformers/models/clipseg/modeling_clipseg.py +5 -0
  89. transformers/models/clvp/modeling_clvp.py +5 -0
  90. transformers/models/clvp/tokenization_clvp.py +1 -63
  91. transformers/models/code_llama/tokenization_code_llama.py +20 -43
  92. transformers/models/codegen/tokenization_codegen.py +14 -43
  93. transformers/models/cohere/modeling_cohere.py +4 -3
  94. transformers/models/cohere/modular_cohere.py +2 -1
  95. transformers/models/cohere/tokenization_cohere.py +12 -42
  96. transformers/models/cohere2/modeling_cohere2.py +7 -6
  97. transformers/models/cohere2/modular_cohere2.py +5 -5
  98. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -3
  99. transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
  100. transformers/models/colqwen2/modeling_colqwen2.py +1 -0
  101. transformers/models/colqwen2/modular_colqwen2.py +1 -0
  102. transformers/models/conditional_detr/modeling_conditional_detr.py +5 -0
  103. transformers/models/convbert/modeling_convbert.py +6 -0
  104. transformers/models/convnext/modeling_convnext.py +2 -4
  105. transformers/models/convnextv2/modeling_convnextv2.py +2 -4
  106. transformers/models/csm/modeling_csm.py +4 -3
  107. transformers/models/ctrl/modeling_ctrl.py +1 -0
  108. transformers/models/cvt/modeling_cvt.py +2 -0
  109. transformers/models/cwm/modeling_cwm.py +4 -4
  110. transformers/models/d_fine/modeling_d_fine.py +2 -0
  111. transformers/models/d_fine/modular_d_fine.py +1 -0
  112. transformers/models/dab_detr/modeling_dab_detr.py +4 -0
  113. transformers/models/dac/modeling_dac.py +2 -2
  114. transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
  115. transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
  116. transformers/models/dbrx/modeling_dbrx.py +2 -2
  117. transformers/models/deberta/modeling_deberta.py +5 -0
  118. transformers/models/deberta/tokenization_deberta.py +11 -20
  119. transformers/models/deberta_v2/modeling_deberta_v2.py +6 -0
  120. transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
  121. transformers/models/decision_transformer/modeling_decision_transformer.py +4 -1
  122. transformers/models/deepseek_v2/modeling_deepseek_v2.py +2 -3
  123. transformers/models/deepseek_v2/modular_deepseek_v2.py +2 -2
  124. transformers/models/deepseek_v3/modeling_deepseek_v3.py +3 -2
  125. transformers/models/deepseek_v3/modular_deepseek_v3.py +1 -0
  126. transformers/models/deformable_detr/modeling_deformable_detr.py +4 -0
  127. transformers/models/depth_anything/modeling_depth_anything.py +1 -0
  128. transformers/models/depth_pro/modeling_depth_pro.py +2 -0
  129. transformers/models/detr/modeling_detr.py +5 -0
  130. transformers/models/dia/modeling_dia.py +4 -3
  131. transformers/models/dia/modular_dia.py +0 -1
  132. transformers/models/diffllama/modeling_diffllama.py +2 -2
  133. transformers/models/dinat/modeling_dinat.py +3 -0
  134. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
  135. transformers/models/dinov3_vit/modeling_dinov3_vit.py +2 -2
  136. transformers/models/dinov3_vit/modular_dinov3_vit.py +2 -2
  137. transformers/models/distilbert/tokenization_distilbert.py +13 -0
  138. transformers/models/doge/modeling_doge.py +2 -3
  139. transformers/models/doge/modular_doge.py +0 -1
  140. transformers/models/donut/modeling_donut_swin.py +2 -0
  141. transformers/models/dots1/modeling_dots1.py +10 -7
  142. transformers/models/dots1/modular_dots1.py +5 -3
  143. transformers/models/dpr/modeling_dpr.py +5 -0
  144. transformers/models/dpr/tokenization_dpr.py +12 -0
  145. transformers/models/edgetam/modeling_edgetam.py +1 -1
  146. transformers/models/edgetam_video/modeling_edgetam_video.py +1 -0
  147. transformers/models/edgetam_video/modular_edgetam_video.py +1 -0
  148. transformers/models/efficientloftr/modeling_efficientloftr.py +2 -2
  149. transformers/models/efficientnet/modeling_efficientnet.py +2 -0
  150. transformers/models/emu3/modeling_emu3.py +4 -4
  151. transformers/models/eomt/image_processing_eomt.py +13 -1
  152. transformers/models/eomt/image_processing_eomt_fast.py +14 -2
  153. transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
  154. transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
  155. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +5 -5
  156. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +2 -2
  157. transformers/models/esm/modeling_esmfold.py +5 -4
  158. transformers/models/evolla/modeling_evolla.py +4 -4
  159. transformers/models/exaone4/modeling_exaone4.py +2 -2
  160. transformers/models/exaone4/modular_exaone4.py +0 -1
  161. transformers/models/falcon/modeling_falcon.py +6 -1
  162. transformers/models/falcon_h1/modeling_falcon_h1.py +4 -3
  163. transformers/models/falcon_mamba/modeling_falcon_mamba.py +25 -35
  164. transformers/models/falcon_mamba/modular_falcon_mamba.py +12 -31
  165. transformers/{kernels/falcon_mamba → models/fast_vlm}/__init__.py +15 -3
  166. transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
  167. transformers/models/fast_vlm/modeling_fast_vlm.py +455 -0
  168. transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
  169. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +8 -3
  170. transformers/models/flaubert/modeling_flaubert.py +7 -0
  171. transformers/models/flava/modeling_flava.py +6 -1
  172. transformers/models/flex_olmo/modeling_flex_olmo.py +4 -5
  173. transformers/models/florence2/modeling_florence2.py +2 -1
  174. transformers/models/florence2/modular_florence2.py +2 -1
  175. transformers/models/fnet/modeling_fnet.py +7 -0
  176. transformers/models/focalnet/modeling_focalnet.py +4 -0
  177. transformers/models/fsmt/modeling_fsmt.py +2 -0
  178. transformers/models/funnel/modeling_funnel.py +8 -0
  179. transformers/models/funnel/tokenization_funnel.py +17 -24
  180. transformers/models/fuyu/processing_fuyu.py +3 -3
  181. transformers/models/gemma/modeling_gemma.py +4 -4
  182. transformers/models/gemma/tokenization_gemma.py +10 -27
  183. transformers/models/gemma2/modeling_gemma2.py +4 -4
  184. transformers/models/gemma2/modular_gemma2.py +2 -1
  185. transformers/models/gemma3/modeling_gemma3.py +14 -84
  186. transformers/models/gemma3/modular_gemma3.py +12 -81
  187. transformers/models/gemma3n/modeling_gemma3n.py +18 -209
  188. transformers/models/gemma3n/modular_gemma3n.py +17 -59
  189. transformers/models/git/modeling_git.py +2 -0
  190. transformers/models/glm/modeling_glm.py +4 -4
  191. transformers/models/glm4/modeling_glm4.py +4 -4
  192. transformers/models/glm4_moe/modeling_glm4_moe.py +5 -3
  193. transformers/models/glm4v/configuration_glm4v.py +3 -1
  194. transformers/models/glm4v/modeling_glm4v.py +3 -3
  195. transformers/models/glm4v/modular_glm4v.py +6 -4
  196. transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
  197. transformers/models/glm4v_moe/modeling_glm4v_moe.py +6 -5
  198. transformers/models/glm4v_moe/modular_glm4v_moe.py +1 -1
  199. transformers/models/glpn/modeling_glpn.py +2 -0
  200. transformers/models/gpt2/modeling_gpt2.py +5 -1
  201. transformers/models/gpt2/tokenization_gpt2.py +16 -44
  202. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +1 -0
  203. transformers/models/gpt_neo/modeling_gpt_neo.py +4 -0
  204. transformers/models/gpt_neox/modeling_gpt_neox.py +5 -2
  205. transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
  206. transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
  207. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +3 -1
  208. transformers/models/gpt_oss/modeling_gpt_oss.py +5 -6
  209. transformers/models/gpt_oss/modular_gpt_oss.py +3 -5
  210. transformers/models/gptj/modeling_gptj.py +3 -0
  211. transformers/models/granite/modeling_granite.py +4 -4
  212. transformers/models/granitemoe/modeling_granitemoe.py +4 -6
  213. transformers/models/granitemoe/modular_granitemoe.py +0 -2
  214. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +4 -6
  215. transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -6
  216. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -0
  217. transformers/models/groupvit/modeling_groupvit.py +3 -0
  218. transformers/models/helium/modeling_helium.py +4 -3
  219. transformers/models/herbert/tokenization_herbert.py +9 -25
  220. transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -1
  221. transformers/models/hgnet_v2/modular_hgnet_v2.py +6 -1
  222. transformers/models/hiera/modeling_hiera.py +4 -0
  223. transformers/models/hubert/modeling_hubert.py +3 -0
  224. transformers/models/hubert/modular_hubert.py +1 -0
  225. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +4 -4
  226. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +4 -4
  227. transformers/models/ibert/modeling_ibert.py +6 -0
  228. transformers/models/idefics/modeling_idefics.py +5 -21
  229. transformers/models/imagegpt/modeling_imagegpt.py +2 -1
  230. transformers/models/informer/modeling_informer.py +4 -0
  231. transformers/models/informer/modular_informer.py +1 -0
  232. transformers/models/internvl/modeling_internvl.py +2 -4
  233. transformers/models/internvl/modular_internvl.py +2 -4
  234. transformers/models/jamba/modeling_jamba.py +2 -2
  235. transformers/models/janus/modeling_janus.py +1 -0
  236. transformers/models/janus/modular_janus.py +1 -0
  237. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  238. transformers/models/kosmos2/modeling_kosmos2.py +1 -0
  239. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +3 -1
  240. transformers/models/lasr/__init__.py +29 -0
  241. transformers/models/lasr/configuration_lasr.py +244 -0
  242. transformers/models/lasr/feature_extraction_lasr.py +277 -0
  243. transformers/models/lasr/modeling_lasr.py +729 -0
  244. transformers/models/lasr/modular_lasr.py +569 -0
  245. transformers/models/lasr/processing_lasr.py +96 -0
  246. transformers/models/lasr/tokenization_lasr.py +186 -0
  247. transformers/models/layoutlm/modeling_layoutlm.py +5 -0
  248. transformers/models/layoutlmv2/modeling_layoutlmv2.py +4 -0
  249. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +10 -53
  250. transformers/models/layoutlmv3/modeling_layoutlmv3.py +4 -0
  251. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
  252. transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
  253. transformers/models/led/modeling_led.py +6 -0
  254. transformers/models/levit/modeling_levit.py +3 -0
  255. transformers/models/lfm2/modeling_lfm2.py +4 -5
  256. transformers/models/lfm2/modular_lfm2.py +0 -1
  257. transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -5
  258. transformers/models/lightglue/modeling_lightglue.py +3 -1
  259. transformers/models/lightglue/modular_lightglue.py +1 -0
  260. transformers/models/lilt/modeling_lilt.py +4 -0
  261. transformers/models/llama/modeling_llama.py +4 -4
  262. transformers/models/llama/tokenization_llama.py +15 -43
  263. transformers/models/llama4/modeling_llama4.py +3 -2
  264. transformers/models/longcat_flash/modeling_longcat_flash.py +4 -4
  265. transformers/models/longcat_flash/modular_longcat_flash.py +2 -2
  266. transformers/models/longformer/modeling_longformer.py +6 -0
  267. transformers/models/longt5/modeling_longt5.py +4 -0
  268. transformers/models/luke/modeling_luke.py +9 -0
  269. transformers/models/luke/tokenization_luke.py +11 -38
  270. transformers/models/lxmert/modeling_lxmert.py +2 -0
  271. transformers/models/m2m_100/modeling_m2m_100.py +4 -0
  272. transformers/models/mamba/modeling_mamba.py +14 -22
  273. transformers/models/marian/modeling_marian.py +5 -0
  274. transformers/models/markuplm/modeling_markuplm.py +4 -0
  275. transformers/models/markuplm/tokenization_markuplm.py +28 -61
  276. transformers/models/mask2former/modeling_mask2former.py +2 -0
  277. transformers/models/maskformer/modeling_maskformer.py +2 -0
  278. transformers/models/maskformer/modeling_maskformer_swin.py +2 -0
  279. transformers/models/mbart/modeling_mbart.py +7 -0
  280. transformers/models/mbart/tokenization_mbart.py +11 -52
  281. transformers/models/mbart50/tokenization_mbart50.py +7 -10
  282. transformers/models/megatron_bert/modeling_megatron_bert.py +7 -0
  283. transformers/models/mgp_str/modeling_mgp_str.py +2 -0
  284. transformers/models/mimi/modeling_mimi.py +3 -1
  285. transformers/models/minimax/modeling_minimax.py +4 -4
  286. transformers/models/ministral/modeling_ministral.py +4 -4
  287. transformers/models/ministral3/configuration_ministral3.py +1 -1
  288. transformers/models/ministral3/modeling_ministral3.py +4 -3
  289. transformers/models/mistral/modeling_mistral.py +4 -3
  290. transformers/models/mixtral/modeling_mixtral.py +4 -4
  291. transformers/models/mllama/modeling_mllama.py +2 -2
  292. transformers/models/mluke/tokenization_mluke.py +6 -6
  293. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -0
  294. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
  295. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
  296. transformers/models/mobilevit/modeling_mobilevit.py +3 -0
  297. transformers/models/mobilevitv2/modeling_mobilevitv2.py +3 -0
  298. transformers/models/modernbert/modeling_modernbert.py +4 -1
  299. transformers/models/modernbert/modular_modernbert.py +2 -0
  300. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +8 -9
  301. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +6 -7
  302. transformers/models/moonshine/modeling_moonshine.py +4 -2
  303. transformers/models/moshi/modeling_moshi.py +5 -2
  304. transformers/models/mpnet/modeling_mpnet.py +5 -0
  305. transformers/models/mpnet/tokenization_mpnet.py +5 -13
  306. transformers/models/mpt/modeling_mpt.py +2 -0
  307. transformers/models/mra/modeling_mra.py +6 -0
  308. transformers/models/mt5/modeling_mt5.py +7 -0
  309. transformers/models/musicgen/modeling_musicgen.py +2 -0
  310. transformers/models/musicgen_melody/modeling_musicgen_melody.py +3 -0
  311. transformers/models/mvp/modeling_mvp.py +7 -0
  312. transformers/models/nanochat/modeling_nanochat.py +4 -4
  313. transformers/models/nemotron/modeling_nemotron.py +4 -2
  314. transformers/models/nllb/tokenization_nllb.py +8 -22
  315. transformers/models/nougat/tokenization_nougat.py +11 -59
  316. transformers/models/nystromformer/modeling_nystromformer.py +6 -0
  317. transformers/models/olmo/modeling_olmo.py +4 -4
  318. transformers/models/olmo/modular_olmo.py +2 -2
  319. transformers/models/olmo2/modeling_olmo2.py +4 -5
  320. transformers/models/olmo2/modular_olmo2.py +0 -1
  321. transformers/models/olmo3/modeling_olmo3.py +4 -4
  322. transformers/models/olmoe/modeling_olmoe.py +4 -4
  323. transformers/models/omdet_turbo/modeling_omdet_turbo.py +2 -0
  324. transformers/models/oneformer/modeling_oneformer.py +4 -1
  325. transformers/models/openai/modeling_openai.py +3 -0
  326. transformers/models/openai/tokenization_openai.py +10 -46
  327. transformers/models/opt/modeling_opt.py +2 -0
  328. transformers/models/owlv2/modeling_owlv2.py +4 -0
  329. transformers/models/owlvit/modeling_owlvit.py +4 -0
  330. transformers/models/paddleocr_vl/__init__.py +32 -0
  331. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
  332. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +503 -0
  333. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
  334. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1668 -0
  335. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1349 -0
  336. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
  337. transformers/models/parakeet/configuration_parakeet.py +4 -6
  338. transformers/models/parakeet/modeling_parakeet.py +9 -6
  339. transformers/models/parakeet/modular_parakeet.py +2 -2
  340. transformers/models/parakeet/processing_parakeet.py +1 -0
  341. transformers/models/patchtsmixer/modeling_patchtsmixer.py +6 -0
  342. transformers/models/patchtst/modeling_patchtst.py +20 -2
  343. transformers/models/pegasus/modeling_pegasus.py +5 -0
  344. transformers/models/pegasus/tokenization_pegasus.py +17 -44
  345. transformers/models/pegasus_x/modeling_pegasus_x.py +4 -0
  346. transformers/models/perceiver/modeling_perceiver.py +8 -0
  347. transformers/models/persimmon/modeling_persimmon.py +2 -1
  348. transformers/models/phi/modeling_phi.py +4 -5
  349. transformers/models/phi/modular_phi.py +0 -1
  350. transformers/models/phi3/modeling_phi3.py +2 -1
  351. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +5 -5
  352. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +4 -4
  353. transformers/models/phimoe/modeling_phimoe.py +4 -4
  354. transformers/models/phimoe/modular_phimoe.py +2 -2
  355. transformers/models/pix2struct/modeling_pix2struct.py +2 -0
  356. transformers/models/pixtral/modeling_pixtral.py +2 -1
  357. transformers/models/plbart/modeling_plbart.py +6 -0
  358. transformers/models/plbart/modular_plbart.py +2 -0
  359. transformers/models/plbart/tokenization_plbart.py +0 -2
  360. transformers/models/poolformer/modeling_poolformer.py +2 -0
  361. transformers/models/pop2piano/modeling_pop2piano.py +2 -0
  362. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
  363. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
  364. transformers/models/prophetnet/modeling_prophetnet.py +3 -0
  365. transformers/models/pvt/modeling_pvt.py +2 -0
  366. transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
  367. transformers/models/qwen2/modeling_qwen2.py +4 -4
  368. transformers/models/qwen2/tokenization_qwen2.py +14 -18
  369. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
  370. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +13 -16
  371. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +14 -16
  372. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
  373. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -6
  374. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +3 -5
  375. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -0
  376. transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
  377. transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
  378. transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -16
  379. transformers/models/qwen3/modeling_qwen3.py +4 -4
  380. transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
  381. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -3
  382. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +21 -23
  383. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +14 -16
  384. transformers/models/qwen3_vl/modeling_qwen3_vl.py +39 -37
  385. transformers/models/qwen3_vl/modular_qwen3_vl.py +37 -35
  386. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +39 -37
  387. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +4 -1
  388. transformers/models/rag/modeling_rag.py +1 -0
  389. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +15 -1
  390. transformers/models/reformer/modeling_reformer.py +4 -0
  391. transformers/models/reformer/tokenization_reformer.py +11 -28
  392. transformers/models/regnet/modeling_regnet.py +6 -1
  393. transformers/models/rembert/modeling_rembert.py +6 -0
  394. transformers/models/rembert/tokenization_rembert.py +3 -10
  395. transformers/models/resnet/modeling_resnet.py +11 -2
  396. transformers/models/roberta/tokenization_roberta.py +18 -27
  397. transformers/models/roformer/modeling_roformer.py +6 -0
  398. transformers/models/roformer/tokenization_roformer.py +77 -412
  399. transformers/models/rt_detr/modeling_rt_detr.py +2 -0
  400. transformers/models/rt_detr/modeling_rt_detr_resnet.py +5 -1
  401. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +2 -0
  402. transformers/models/rwkv/modeling_rwkv.py +1 -0
  403. transformers/models/sam2/modeling_sam2.py +2 -2
  404. transformers/models/sam2/modular_sam2.py +2 -2
  405. transformers/models/sam2_video/modeling_sam2_video.py +1 -0
  406. transformers/models/sam2_video/modular_sam2_video.py +1 -0
  407. transformers/models/sam3/modeling_sam3.py +77 -80
  408. transformers/models/sam3_tracker/modeling_sam3_tracker.py +6 -1
  409. transformers/models/sam3_tracker/modular_sam3_tracker.py +6 -1
  410. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +1 -0
  411. transformers/models/sam3_video/modeling_sam3_video.py +1 -0
  412. transformers/models/seamless_m4t/modeling_seamless_m4t.py +5 -1
  413. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
  414. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +5 -1
  415. transformers/models/seed_oss/modeling_seed_oss.py +2 -2
  416. transformers/models/segformer/modeling_segformer.py +4 -1
  417. transformers/models/seggpt/modeling_seggpt.py +2 -0
  418. transformers/models/sew/modeling_sew.py +3 -0
  419. transformers/models/sew/modular_sew.py +1 -0
  420. transformers/models/sew_d/modeling_sew_d.py +3 -0
  421. transformers/models/siglip2/modeling_siglip2.py +4 -0
  422. transformers/models/siglip2/modular_siglip2.py +4 -0
  423. transformers/models/smollm3/modeling_smollm3.py +4 -4
  424. transformers/models/smolvlm/processing_smolvlm.py +0 -7
  425. transformers/models/speech_to_text/modeling_speech_to_text.py +4 -0
  426. transformers/models/speecht5/modeling_speecht5.py +13 -1
  427. transformers/models/splinter/modeling_splinter.py +3 -0
  428. transformers/models/splinter/tokenization_splinter.py +9 -28
  429. transformers/models/squeezebert/modeling_squeezebert.py +6 -0
  430. transformers/models/stablelm/modeling_stablelm.py +3 -1
  431. transformers/models/starcoder2/modeling_starcoder2.py +4 -3
  432. transformers/models/superglue/modeling_superglue.py +1 -0
  433. transformers/models/superpoint/modeling_superpoint.py +1 -0
  434. transformers/models/swiftformer/modeling_swiftformer.py +2 -0
  435. transformers/models/swin/modeling_swin.py +4 -0
  436. transformers/models/swin2sr/modeling_swin2sr.py +2 -0
  437. transformers/models/swinv2/modeling_swinv2.py +4 -0
  438. transformers/models/t5/modeling_t5.py +7 -0
  439. transformers/models/t5/tokenization_t5.py +4 -8
  440. transformers/models/t5gemma/modeling_t5gemma.py +5 -5
  441. transformers/models/t5gemma2/modeling_t5gemma2.py +6 -6
  442. transformers/models/table_transformer/modeling_table_transformer.py +4 -0
  443. transformers/models/tapas/modeling_tapas.py +3 -0
  444. transformers/models/textnet/modeling_textnet.py +11 -2
  445. transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
  446. transformers/models/timesfm/modeling_timesfm.py +2 -0
  447. transformers/models/timesfm/modular_timesfm.py +2 -0
  448. transformers/models/timesformer/modeling_timesformer.py +2 -0
  449. transformers/models/timm_wrapper/modeling_timm_wrapper.py +1 -1
  450. transformers/models/trocr/modeling_trocr.py +2 -0
  451. transformers/models/tvp/modeling_tvp.py +2 -0
  452. transformers/models/udop/modeling_udop.py +4 -0
  453. transformers/models/udop/tokenization_udop.py +5 -13
  454. transformers/models/umt5/modeling_umt5.py +7 -0
  455. transformers/models/unispeech/modeling_unispeech.py +4 -0
  456. transformers/models/unispeech/modular_unispeech.py +2 -0
  457. transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
  458. transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
  459. transformers/models/univnet/modeling_univnet.py +1 -0
  460. transformers/models/upernet/modeling_upernet.py +1 -0
  461. transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
  462. transformers/models/vilt/modeling_vilt.py +6 -0
  463. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
  464. transformers/models/visual_bert/modeling_visual_bert.py +6 -0
  465. transformers/models/vitdet/modeling_vitdet.py +2 -0
  466. transformers/models/vitmatte/modeling_vitmatte.py +1 -0
  467. transformers/models/vits/modeling_vits.py +1 -0
  468. transformers/models/vjepa2/modeling_vjepa2.py +1 -0
  469. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
  470. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +5 -0
  471. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +5 -0
  472. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +6 -0
  473. transformers/models/wavlm/modeling_wavlm.py +5 -0
  474. transformers/models/whisper/modeling_whisper.py +6 -0
  475. transformers/models/whisper/tokenization_whisper.py +4 -15
  476. transformers/models/x_clip/modeling_x_clip.py +3 -0
  477. transformers/models/xglm/modeling_xglm.py +1 -0
  478. transformers/models/xglm/tokenization_xglm.py +4 -9
  479. transformers/models/xlm/modeling_xlm.py +5 -0
  480. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
  481. transformers/models/xlnet/tokenization_xlnet.py +3 -7
  482. transformers/models/yoso/modeling_yoso.py +6 -0
  483. transformers/models/zamba/modeling_zamba.py +2 -0
  484. transformers/models/zamba2/modeling_zamba2.py +4 -2
  485. transformers/models/zamba2/modular_zamba2.py +1 -1
  486. transformers/models/zoedepth/modeling_zoedepth.py +1 -0
  487. transformers/pipelines/__init__.py +2 -3
  488. transformers/pipelines/base.py +1 -9
  489. transformers/pipelines/document_question_answering.py +3 -1
  490. transformers/pipelines/text_generation.py +1 -1
  491. transformers/processing_utils.py +23 -11
  492. transformers/quantizers/base.py +35 -110
  493. transformers/quantizers/quantizer_aqlm.py +1 -5
  494. transformers/quantizers/quantizer_auto_round.py +1 -2
  495. transformers/quantizers/quantizer_awq.py +17 -81
  496. transformers/quantizers/quantizer_bitnet.py +3 -8
  497. transformers/quantizers/quantizer_bnb_4bit.py +13 -110
  498. transformers/quantizers/quantizer_bnb_8bit.py +16 -92
  499. transformers/quantizers/quantizer_compressed_tensors.py +1 -5
  500. transformers/quantizers/quantizer_eetq.py +14 -62
  501. transformers/quantizers/quantizer_fbgemm_fp8.py +34 -125
  502. transformers/quantizers/quantizer_finegrained_fp8.py +13 -105
  503. transformers/quantizers/quantizer_fp_quant.py +48 -78
  504. transformers/quantizers/quantizer_gptq.py +7 -24
  505. transformers/quantizers/quantizer_higgs.py +40 -54
  506. transformers/quantizers/quantizer_hqq.py +144 -153
  507. transformers/quantizers/quantizer_mxfp4.py +13 -167
  508. transformers/quantizers/quantizer_quanto.py +20 -64
  509. transformers/quantizers/quantizer_quark.py +36 -17
  510. transformers/quantizers/quantizer_spqr.py +1 -4
  511. transformers/quantizers/quantizer_torchao.py +23 -202
  512. transformers/quantizers/quantizer_vptq.py +8 -22
  513. transformers/quantizers/quantizers_utils.py +20 -0
  514. transformers/testing_utils.py +297 -36
  515. transformers/tokenization_mistral_common.py +4 -0
  516. transformers/tokenization_utils_base.py +113 -222
  517. transformers/tokenization_utils_tokenizers.py +168 -107
  518. transformers/trainer.py +28 -31
  519. transformers/trainer_jit_checkpoint.py +126 -0
  520. transformers/trainer_utils.py +1 -1
  521. transformers/training_args.py +66 -28
  522. transformers/utils/__init__.py +3 -4
  523. transformers/utils/auto_docstring.py +1 -0
  524. transformers/utils/generic.py +27 -1
  525. transformers/utils/hub.py +5 -15
  526. transformers/utils/import_utils.py +61 -16
  527. transformers/utils/kernel_config.py +4 -2
  528. transformers/utils/loading_report.py +19 -10
  529. transformers/utils/quantization_config.py +75 -242
  530. transformers/video_processing_utils.py +1 -2
  531. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/METADATA +274 -227
  532. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/RECORD +536 -520
  533. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/WHEEL +1 -1
  534. transformers/kernels/__init__.py +0 -0
  535. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
  536. transformers/models/roformer/tokenization_roformer_fast.py +0 -160
  537. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/entry_points.txt +0 -0
  538. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info/licenses}/LICENSE +0 -0
  539. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,135 @@
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/paddleocr_vl/modular_paddleocr_vl.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_paddleocr_vl.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ # Copyright 2025 The PaddlePaddle Team and The HuggingFace Inc. team. All rights reserved.
8
+ #
9
+ # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
10
+ # and OPT implementations in this library. It has been modified from its
11
+ # original forms to accommodate minor architectural differences compared
12
+ # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
13
+ #
14
+ # Licensed under the Apache License, Version 2.0 (the "License");
15
+ # you may not use this file except in compliance with the License.
16
+ # You may obtain a copy of the License at
17
+ #
18
+ # http://www.apache.org/licenses/LICENSE-2.0
19
+ #
20
+ # Unless required by applicable law or agreed to in writing, software
21
+ # distributed under the License is distributed on an "AS IS" BASIS,
22
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
23
+ # See the License for the specific language governing permissions and
24
+ # limitations under the License.
25
+
26
+ from typing import Union
27
+
28
+ from ...image_processing_utils import BatchFeature
29
+ from ...image_utils import ImageInput
30
+ from ...processing_utils import ProcessingKwargs, ProcessorMixin, Unpack
31
+ from ...tokenization_utils_base import PreTokenizedInput, TextInput
32
+
33
+
34
+ class PaddleOCRVLProcessorKwargs(ProcessingKwargs, total=False):
35
+ _defaults = {
36
+ "text_kwargs": {
37
+ "padding": False,
38
+ },
39
+ }
40
+
41
+
42
+ class PaddleOCRVLProcessor(ProcessorMixin):
43
+ r"""
44
+ [`PaddleOCRVLProcessor`] offers all the functionalities of [`PaddleOCRVLImageProcessor`] and [`LLamaTokenizerFast`]. See the
45
+ [`~PaddleOCRVLProcessor.__call__`] and [`~PaddleOCRVLProcessor.decode`] for more information.
46
+ Args:
47
+ image_processor ([`PaddleOCRVLImageProcessor`], *optional*):
48
+ The image processor is a required input.
49
+ tokenizer ([`LLamaTokenizerFast`], *optional*):
50
+ The tokenizer is a required input.
51
+ chat_template (`str`, *optional*): A Jinja template which will be used to convert lists of messages
52
+ in a chat into a tokenizable string.
53
+ """
54
+
55
+ image_processor_class = "AutoImageProcessor"
56
+ tokenizer_class = "AutoTokenizer"
57
+
58
+ def __init__(self, image_processor=None, tokenizer=None, chat_template=None, **kwargs):
59
+ self.image_token = tokenizer.image_token
60
+ super().__init__(image_processor, tokenizer, chat_template=chat_template)
61
+
62
+ def __call__(
63
+ self,
64
+ images: ImageInput = None,
65
+ text: Union[TextInput, PreTokenizedInput, list[TextInput], list[PreTokenizedInput]] = None,
66
+ **kwargs: Unpack[PaddleOCRVLProcessorKwargs],
67
+ ) -> BatchFeature:
68
+ """
69
+ Args:
70
+ images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
71
+ The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
72
+ tensor. Both channels-first and channels-last formats are supported.
73
+ text (`str`, `List[str]`, `List[List[str]]`):
74
+ The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
75
+ (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
76
+ `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
77
+ return_tensors (`str` or [`~utils.TensorType`], *optional*):
78
+ If set, will return tensors of a particular framework. Acceptable values are:
79
+ - `'tf'`: Return TensorFlow `tf.constant` objects.
80
+ - `'pt'`: Return PyTorch `torch.Tensor` objects.
81
+ - `'np'`: Return NumPy `np.ndarray` objects.
82
+ - `'jax'`: Return JAX `jnp.ndarray` objects.
83
+
84
+ Returns:
85
+ [`BatchFeature`]: A [`BatchFeature`] with the following fields:
86
+
87
+ - **input_ids** -- List of token ids to be fed to a model. Returned when `text` is not `None`.
88
+ - **attention_mask** -- List of indices specifying which tokens should be attended to by the model (when
89
+ `return_attention_mask=True` or if *"attention_mask"* is in `self.model_input_names` and if `text` is not
90
+ `None`).
91
+ - **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
92
+ - **image_grid_thw** -- List of image 3D grid in LLM. Returned when `images` is not `None`.
93
+ """
94
+ output_kwargs = self._merge_kwargs(
95
+ PaddleOCRVLProcessorKwargs,
96
+ tokenizer_init_kwargs=self.tokenizer.init_kwargs,
97
+ **kwargs,
98
+ )
99
+
100
+ if images is not None:
101
+ image_inputs = self.image_processor(images=images, **output_kwargs["images_kwargs"])
102
+ image_grid_thw = image_inputs["image_grid_thw"]
103
+
104
+ else:
105
+ image_inputs = {}
106
+ image_grid_thw = None
107
+
108
+ if not isinstance(text, list):
109
+ text = [text]
110
+
111
+ text = text.copy()
112
+
113
+ if image_grid_thw is not None:
114
+ index = 0
115
+ for i in range(len(text)):
116
+ while self.image_token in text[i]:
117
+ text[i] = text[i].replace(
118
+ self.image_token,
119
+ "<|placeholder|>"
120
+ * (
121
+ image_grid_thw[index].prod()
122
+ // self.image_processor.merge_size
123
+ // self.image_processor.merge_size
124
+ ),
125
+ 1,
126
+ )
127
+ index += 1
128
+ text[i] = text[i].replace("<|placeholder|>", self.image_token)
129
+
130
+ text_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
131
+
132
+ return BatchFeature(data={**text_inputs, **image_inputs})
133
+
134
+
135
+ __all__ = ["PaddleOCRVLProcessor"]
@@ -121,9 +121,6 @@ class ParakeetEncoderConfig(PreTrainedConfig):
121
121
  initializer_range=0.02,
122
122
  **kwargs,
123
123
  ):
124
- super().__init__(
125
- **kwargs,
126
- )
127
124
  self.hidden_size = hidden_size
128
125
  self.num_hidden_layers = num_hidden_layers
129
126
  self.num_attention_heads = num_attention_heads
@@ -133,10 +130,7 @@ class ParakeetEncoderConfig(PreTrainedConfig):
133
130
  self.attention_bias = attention_bias
134
131
  self.convolution_bias = convolution_bias
135
132
 
136
- if (conv_kernel_size - 1) % 2 != 0:
137
- raise ValueError(f"conv_kernel_size must be odd, got {conv_kernel_size}")
138
133
  self.conv_kernel_size = conv_kernel_size
139
-
140
134
  self.subsampling_conv_kernel_size = subsampling_conv_kernel_size
141
135
  self.subsampling_conv_stride = subsampling_conv_stride
142
136
 
@@ -153,6 +147,10 @@ class ParakeetEncoderConfig(PreTrainedConfig):
153
147
  self.scale_input = scale_input
154
148
  self.initializer_range = initializer_range
155
149
 
150
+ super().__init__(
151
+ **kwargs,
152
+ )
153
+
156
154
 
157
155
  class ParakeetCTCConfig(PreTrainedConfig):
158
156
  r"""
@@ -29,13 +29,13 @@ from torch import nn
29
29
 
30
30
  from ... import initialization as init
31
31
  from ...activations import ACT2FN
32
- from ...integrations import use_kernel_func_from_hub
32
+ from ...integrations import use_kernel_func_from_hub, use_kernelized_func
33
33
  from ...modeling_layers import GradientCheckpointingLayer
34
34
  from ...modeling_outputs import BaseModelOutput, CausalLMOutput
35
35
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
36
36
  from ...processing_utils import Unpack
37
37
  from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple
38
- from ...utils.generic import check_model_inputs
38
+ from ...utils.generic import check_model_inputs, maybe_autocast
39
39
  from .configuration_parakeet import ParakeetCTCConfig, ParakeetEncoderConfig
40
40
 
41
41
 
@@ -88,7 +88,7 @@ class ParakeetEncoderRelPositionalEncoding(nn.Module):
88
88
  if isinstance(hidden_states.device.type, str) and hidden_states.device.type != "mps"
89
89
  else "cpu"
90
90
  )
91
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
91
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
92
92
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
93
93
  sin = freqs.sin()
94
94
  cos = freqs.cos()
@@ -155,7 +155,7 @@ class ParakeetEncoderConvolutionModule(nn.Module):
155
155
 
156
156
  Args:
157
157
  hidden_states (`torch.Tensor` of shape `(batch, time, channels)`): Input tensor.
158
- attention_mask (`torch.Tensor` of shape `(batch, 1, time)`): Attention mask.
158
+ attention_mask (`torch.Tensor` of shape `(batch, 1, time, time)`): Attention mask.
159
159
 
160
160
  Returns:
161
161
  `torch.Tensor`: Output tensor of shape `(batch, time, channels)`.
@@ -171,7 +171,10 @@ class ParakeetEncoderConvolutionModule(nn.Module):
171
171
 
172
172
  # Apply padding mask before convolution
173
173
  if attention_mask is not None:
174
- all_masked_rows = torch.all(~attention_mask, dim=-1)
174
+ if attention_mask.dtype == torch.bool:
175
+ all_masked_rows = torch.all(~attention_mask, dim=2)
176
+ else:
177
+ all_masked_rows = torch.all(~(attention_mask == 0.0), dim=2)
175
178
  hidden_states = hidden_states.masked_fill(all_masked_rows, 0.0)
176
179
 
177
180
  # 1D Depthwise Conv
@@ -256,6 +259,7 @@ def eager_attention_forward(
256
259
  return attn_output, attn_weights
257
260
 
258
261
 
262
+ @use_kernelized_func(apply_rotary_pos_emb)
259
263
  class ParakeetEncoderAttention(nn.Module):
260
264
  """Multi-head attention with relative positional encoding. See section 3.3 of https://huggingface.co/papers/1901.02860."""
261
265
 
@@ -281,7 +285,6 @@ class ParakeetEncoderAttention(nn.Module):
281
285
  self.o_proj = nn.Linear(
282
286
  config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
283
287
  )
284
- self.rotary_fn = apply_rotary_pos_emb
285
288
  # W_{k,R} projection
286
289
  self.relative_k_proj = nn.Linear(config.hidden_size, config.num_attention_heads * self.head_dim, bias=False)
287
290
  # global content bias
@@ -29,7 +29,7 @@ from ...modeling_outputs import BaseModelOutput, CausalLMOutput
29
29
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
30
30
  from ...processing_utils import Unpack
31
31
  from ...utils import ModelOutput, TransformersKwargs, auto_docstring, can_return_tuple
32
- from ...utils.generic import check_model_inputs
32
+ from ...utils.generic import check_model_inputs, maybe_autocast
33
33
  from ..fastspeech2_conformer.modeling_fastspeech2_conformer import FastSpeech2ConformerConvolutionModule
34
34
  from ..llama.modeling_llama import LlamaAttention, eager_attention_forward
35
35
  from .configuration_parakeet import ParakeetCTCConfig, ParakeetEncoderConfig
@@ -84,7 +84,7 @@ class ParakeetEncoderRelPositionalEncoding(nn.Module):
84
84
  if isinstance(hidden_states.device.type, str) and hidden_states.device.type != "mps"
85
85
  else "cpu"
86
86
  )
87
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
87
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
88
88
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
89
89
  sin = freqs.sin()
90
90
  cos = freqs.cos()
@@ -28,6 +28,7 @@ class ParakeetProcessorKwargs(ProcessingKwargs, total=False):
28
28
  "audio_kwargs": {
29
29
  "sampling_rate": 16000,
30
30
  "padding": "longest",
31
+ "return_attention_mask": True,
31
32
  },
32
33
  "text_kwargs": {
33
34
  "padding": True,
@@ -1141,6 +1141,7 @@ class PatchTSMixerEncoder(PatchTSMixerPreTrainedModel):
1141
1141
  past_values: torch.Tensor,
1142
1142
  output_hidden_states: Optional[bool] = False,
1143
1143
  return_dict: Optional[bool] = None,
1144
+ **kwargs,
1144
1145
  ) -> Union[tuple, PatchTSMixerEncoderOutput]:
1145
1146
  r"""
1146
1147
  past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
@@ -1251,6 +1252,7 @@ class PatchTSMixerModel(PatchTSMixerPreTrainedModel):
1251
1252
  observed_mask: Optional[torch.Tensor] = None,
1252
1253
  output_hidden_states: Optional[bool] = False,
1253
1254
  return_dict: Optional[bool] = None,
1255
+ **kwargs,
1254
1256
  ) -> PatchTSMixerModelOutput:
1255
1257
  r"""
1256
1258
  past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
@@ -1362,6 +1364,7 @@ class PatchTSMixerForPretraining(PatchTSMixerPreTrainedModel):
1362
1364
  output_hidden_states: Optional[bool] = False,
1363
1365
  return_loss: bool = True,
1364
1366
  return_dict: Optional[bool] = None,
1367
+ **kwargs,
1365
1368
  ) -> PatchTSMixerForPreTrainingOutput:
1366
1369
  r"""
1367
1370
  past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
@@ -1574,6 +1577,7 @@ class PatchTSMixerForPrediction(PatchTSMixerPreTrainedModel):
1574
1577
  output_hidden_states: Optional[bool] = False,
1575
1578
  return_loss: bool = True,
1576
1579
  return_dict: Optional[bool] = None,
1580
+ **kwargs,
1577
1581
  ) -> PatchTSMixerForPredictionOutput:
1578
1582
  r"""
1579
1583
  past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
@@ -1797,6 +1801,7 @@ class PatchTSMixerForTimeSeriesClassification(PatchTSMixerPreTrainedModel):
1797
1801
  output_hidden_states: Optional[bool] = False,
1798
1802
  return_loss: bool = True,
1799
1803
  return_dict: Optional[bool] = None,
1804
+ **kwargs,
1800
1805
  ) -> PatchTSMixerForTimeSeriesClassificationOutput:
1801
1806
  r"""
1802
1807
  past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
@@ -1987,6 +1992,7 @@ class PatchTSMixerForRegression(PatchTSMixerPreTrainedModel):
1987
1992
  output_hidden_states: Optional[bool] = False,
1988
1993
  return_loss: bool = True,
1989
1994
  return_dict: Optional[bool] = None,
1995
+ **kwargs,
1990
1996
  ) -> PatchTSMixerForRegressionOutput:
1991
1997
  r"""
1992
1998
  past_values (`torch.FloatTensor` of shape `(batch_size, seq_length, num_input_channels)`):
@@ -24,6 +24,7 @@ from torch import nn
24
24
 
25
25
  from ... import initialization as init
26
26
  from ...activations import ACT2CLS
27
+ from ...integrations.deepspeed import is_deepspeed_zero3_enabled
27
28
  from ...modeling_flash_attention_utils import FlashAttentionKwargs
28
29
  from ...modeling_outputs import BaseModelOutput
29
30
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
@@ -418,7 +419,7 @@ class PatchTSTEncoderLayer(nn.Module):
418
419
  super().__init__()
419
420
 
420
421
  self.channel_attention = config.channel_attention
421
- # Multi-Head attention
422
+
422
423
  self.self_attn = PatchTSTAttention(
423
424
  embed_dim=config.d_model,
424
425
  num_heads=config.num_attention_heads,
@@ -555,6 +556,9 @@ class PatchTSTPreTrainedModel(PreTrainedModel):
555
556
  main_input_name = "past_values"
556
557
  input_modalities = ("time",)
557
558
  supports_gradient_checkpointing = False
559
+ _supports_flash_attn = True
560
+ _supports_sdpa = True
561
+ _supports_flex_attn = True
558
562
 
559
563
  @torch.no_grad()
560
564
  def _init_weights(self, module: nn.Module):
@@ -571,7 +575,15 @@ class PatchTSTPreTrainedModel(PreTrainedModel):
571
575
  init.normal_(module.cls_token, std=0.02)
572
576
  num_patches += 1
573
577
  # initialize positional encoding
574
- init.copy_(module.position_enc, module._init_pe(self.config, num_patches))
578
+ position_enc = module._init_pe(self.config, num_patches)
579
+ if is_deepspeed_zero3_enabled():
580
+ import deepspeed
581
+
582
+ with deepspeed.zero.GatheredParameters(module.position_enc, modifier_rank=None):
583
+ if module.position_enc.numel() > 0:
584
+ init.copy_(module.position_enc, position_enc)
585
+ else:
586
+ init.copy_(module.position_enc, position_enc)
575
587
  elif isinstance(module, nn.LayerNorm):
576
588
  init.zeros_(module.bias)
577
589
  init.ones_(module.weight)
@@ -704,6 +716,7 @@ class PatchTSTEncoder(PatchTSTPreTrainedModel):
704
716
  patch_input: torch.Tensor,
705
717
  output_hidden_states: Optional[bool] = None,
706
718
  output_attentions: Optional[bool] = None,
719
+ **kwargs,
707
720
  ) -> BaseModelOutput:
708
721
  """
709
722
  Parameters:
@@ -1092,6 +1105,7 @@ class PatchTSTModel(PatchTSTPreTrainedModel):
1092
1105
  output_hidden_states: Optional[bool] = None,
1093
1106
  output_attentions: Optional[bool] = None,
1094
1107
  return_dict: Optional[bool] = None,
1108
+ **kwargs,
1095
1109
  ) -> Union[tuple, PatchTSTModelOutput]:
1096
1110
  r"""
1097
1111
  Parameters:
@@ -1228,6 +1242,7 @@ class PatchTSTForPretraining(PatchTSTPreTrainedModel):
1228
1242
  output_hidden_states: Optional[bool] = None,
1229
1243
  output_attentions: Optional[bool] = None,
1230
1244
  return_dict: Optional[bool] = None,
1245
+ **kwargs,
1231
1246
  ) -> Union[tuple, PatchTSTForPretrainingOutput]:
1232
1247
  r"""
1233
1248
  Parameters:
@@ -1387,6 +1402,7 @@ class PatchTSTForClassification(PatchTSTPreTrainedModel):
1387
1402
  output_hidden_states: Optional[bool] = None,
1388
1403
  output_attentions: Optional[bool] = None,
1389
1404
  return_dict: Optional[bool] = None,
1405
+ **kwargs,
1390
1406
  ) -> Union[tuple, PatchTSTForClassificationOutput]:
1391
1407
  r"""
1392
1408
  past_values (`torch.Tensor` of shape `(bs, sequence_length, num_input_channels)`, *required*):
@@ -1594,6 +1610,7 @@ class PatchTSTForPrediction(PatchTSTPreTrainedModel):
1594
1610
  output_hidden_states: Optional[bool] = None,
1595
1611
  output_attentions: Optional[bool] = None,
1596
1612
  return_dict: Optional[bool] = None,
1613
+ **kwargs,
1597
1614
  ) -> Union[tuple, PatchTSTForPredictionOutput]:
1598
1615
  r"""
1599
1616
  Parameters:
@@ -1840,6 +1857,7 @@ class PatchTSTForRegression(PatchTSTPreTrainedModel):
1840
1857
  output_hidden_states: Optional[bool] = None,
1841
1858
  output_attentions: Optional[bool] = None,
1842
1859
  return_dict: Optional[bool] = None,
1860
+ **kwargs,
1843
1861
  ) -> Union[tuple, PatchTSTForRegressionOutput]:
1844
1862
  r"""
1845
1863
  past_values (`torch.Tensor` of shape `(bs, sequence_length, num_input_channels)`, *required*):
@@ -518,6 +518,7 @@ class PegasusEncoder(PegasusPreTrainedModel):
518
518
  output_attentions=None,
519
519
  output_hidden_states=None,
520
520
  return_dict=None,
521
+ **kwargs,
521
522
  ):
522
523
  r"""
523
524
  Args:
@@ -695,6 +696,7 @@ class PegasusDecoder(PegasusPreTrainedModel):
695
696
  output_hidden_states=None,
696
697
  return_dict=None,
697
698
  cache_position=None,
699
+ **kwargs,
698
700
  ):
699
701
  r"""
700
702
  Args:
@@ -946,6 +948,7 @@ class PegasusModel(PegasusPreTrainedModel):
946
948
  output_hidden_states: Optional[bool] = None,
947
949
  return_dict: Optional[bool] = None,
948
950
  cache_position: Optional[torch.Tensor] = None,
951
+ **kwargs,
949
952
  ) -> Union[tuple, Seq2SeqModelOutput]:
950
953
  r"""
951
954
  decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
@@ -1111,6 +1114,7 @@ class PegasusForConditionalGeneration(PegasusPreTrainedModel, GenerationMixin):
1111
1114
  output_hidden_states: Optional[bool] = None,
1112
1115
  return_dict: Optional[bool] = None,
1113
1116
  cache_position: Optional[torch.Tensor] = None,
1117
+ **kwargs,
1114
1118
  ) -> Union[tuple, Seq2SeqLMOutput]:
1115
1119
  r"""
1116
1120
  decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
@@ -1283,6 +1287,7 @@ class PegasusForCausalLM(PegasusPreTrainedModel, GenerationMixin):
1283
1287
  return_dict: Optional[bool] = None,
1284
1288
  cache_position: Optional[torch.LongTensor] = None,
1285
1289
  logits_to_keep: Union[int, torch.Tensor] = 0,
1290
+ **kwargs,
1286
1291
  ) -> Union[tuple, CausalLMOutputWithCrossAttentions]:
1287
1292
  r"""
1288
1293
  labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -14,6 +14,8 @@
14
14
  # limitations under the License.
15
15
  """Tokenization class for model PEGASUS."""
16
16
 
17
+ from typing import Optional, Union
18
+
17
19
  from tokenizers import Regex, Tokenizer, decoders, normalizers, pre_tokenizers, processors
18
20
  from tokenizers.models import Unigram
19
21
 
@@ -70,15 +72,17 @@ class PegasusTokenizer(TokenizersBackend):
70
72
  that uses the tokens 2 - 104 only for pretraining
71
73
  offset (`int`, *optional*, defaults to 103):
72
74
  Offset for additional special tokens.
73
- vocab (`dict`, *optional*):
74
- Custom vocabulary dictionary. If not provided, a blank vocabulary is initialized.
75
+ vocab (`str` or `list[tuple[str, float]]`, *optional*):
76
+ Custom vocabulary with `(token, score)` tuples. If not provided, a blank vocabulary is initialized.
75
77
  """
76
78
 
77
79
  vocab_files_names = VOCAB_FILES_NAMES
78
80
  model_input_names = ["input_ids", "attention_mask"]
81
+ model = Unigram
79
82
 
80
83
  def __init__(
81
84
  self,
85
+ vocab: Optional[Union[str, list[tuple[str, float]]]] = None,
82
86
  pad_token="<pad>",
83
87
  eos_token="</s>",
84
88
  unk_token="<unk>",
@@ -86,60 +90,27 @@ class PegasusTokenizer(TokenizersBackend):
86
90
  mask_token_sent="<mask_1>",
87
91
  additional_special_tokens=None,
88
92
  offset=103,
89
- vocab=None,
90
- vocab_file=None,
91
93
  **kwargs,
92
94
  ):
93
95
  self.offset = offset
94
- self.vocab_file = vocab_file
95
96
 
96
97
  if additional_special_tokens is None:
97
98
  additional_special_tokens = [mask_token_sent] if mask_token_sent is not None else []
98
99
  additional_special_tokens += [f"<unk_{i}>" for i in range(2, self.offset)]
99
100
 
100
- if vocab is not None:
101
- # For Pegasus, insert special tokens at the beginning
102
- special_tokens_set = {pad_token, eos_token, mask_token_sent, mask_token, unk_token}
103
- special_tokens_set.update(additional_special_tokens)
104
-
105
- # Build special tokens in correct order
106
- _vocab_list = [
107
- (str(pad_token), 0.0),
108
- (str(eos_token), 0.0),
109
- ]
110
- if mask_token_sent:
111
- _vocab_list.append((str(mask_token_sent), 0.0))
112
- for token in additional_special_tokens:
113
- if token not in [pad_token, eos_token, mask_token_sent]:
114
- _vocab_list.append((str(token), 0.0))
115
- if mask_token not in [t for t, _ in _vocab_list]:
116
- _vocab_list.append((str(mask_token), 0.0))
117
- _vocab_list.append((str(unk_token), 0.0))
118
-
119
- # Filter out special tokens from main vocab and combine
120
- filtered_vocab = [(t, s) for t, s in vocab if t not in special_tokens_set]
121
- _vocab_list = _vocab_list + filtered_vocab
122
- else:
123
- _vocab_list = [(str(unk_token), 0.0)]
124
-
125
- self._vocab = {token: idx for idx, (token, _) in enumerate(_vocab_list)}
126
-
127
- self._tokenizer = Tokenizer(Unigram(vocab=_vocab_list, unk_id=self._vocab.get(str(unk_token), 0)))
101
+ if vocab is None:
102
+ vocab = [(str(unk_token), 0.0), (str(pad_token), 0.0), (str(eos_token), 0.0), (str(mask_token), 0.0)]
128
103
 
104
+ self._vocab = vocab
105
+ self._tokenizer = Tokenizer(Unigram(vocab=vocab, unk_id=self._vocab.index((str(unk_token), 0.0), 1)))
129
106
  self._tokenizer.normalizer = normalizers.Sequence(
130
107
  [normalizers.Replace(Regex(r"\n"), " "), normalizers.Replace(Regex(r" {2,}"), " ")]
131
108
  )
132
109
 
133
- self._tokenizer.post_processor = processors.TemplateProcessing(
134
- single=f"$A {eos_token}",
135
- pair=f"$A $B {eos_token}",
136
- special_tokens=[(str(eos_token), self._vocab.get(str(eos_token), 1))],
137
- )
138
-
139
- tokenizer_object = self._tokenizer
110
+ self._tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(replacement="▁", prepend_scheme="always", split=True)
111
+ self._tokenizer.decoder = decoders.Metaspace(replacement="▁", prepend_scheme="always", split=True)
140
112
 
141
113
  super().__init__(
142
- tokenizer_object=tokenizer_object,
143
114
  pad_token=pad_token,
144
115
  eos_token=eos_token,
145
116
  unk_token=unk_token,
@@ -149,9 +120,11 @@ class PegasusTokenizer(TokenizersBackend):
149
120
  additional_special_tokens=additional_special_tokens,
150
121
  **kwargs,
151
122
  )
152
-
153
- self._tokenizer.pre_tokenizer = pre_tokenizers.Metaspace(replacement="▁", prepend_scheme="always", split=True)
154
- self._tokenizer.decoder = decoders.Metaspace(replacement="▁", prepend_scheme="always", split=True)
123
+ self._tokenizer.post_processor = processors.TemplateProcessing(
124
+ single=f"$A {eos_token}",
125
+ pair=f"$A $B {eos_token}",
126
+ special_tokens=[(str(eos_token), self.convert_tokens_to_ids(str(eos_token)))],
127
+ )
155
128
 
156
129
 
157
130
  __all__ = ["PegasusTokenizer"]
@@ -821,6 +821,7 @@ class PegasusXEncoder(PegasusXPreTrainedModel):
821
821
  output_attentions=None,
822
822
  output_hidden_states=None,
823
823
  return_dict=None,
824
+ **kwargs,
824
825
  ):
825
826
  r"""
826
827
  Args:
@@ -989,6 +990,7 @@ class PegasusXDecoder(PegasusXPreTrainedModel):
989
990
  output_hidden_states=None,
990
991
  return_dict=None,
991
992
  cache_position=None,
993
+ **kwargs,
992
994
  ):
993
995
  r"""
994
996
  Args:
@@ -1241,6 +1243,7 @@ class PegasusXModel(PegasusXPreTrainedModel):
1241
1243
  output_hidden_states: Optional[bool] = None,
1242
1244
  return_dict: Optional[bool] = None,
1243
1245
  cache_position: Optional[torch.Tensor] = None,
1246
+ **kwargs,
1244
1247
  ) -> Union[tuple, Seq2SeqModelOutput]:
1245
1248
  r"""
1246
1249
  decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
@@ -1388,6 +1391,7 @@ class PegasusXForConditionalGeneration(PegasusXPreTrainedModel, GenerationMixin)
1388
1391
  output_hidden_states: Optional[bool] = None,
1389
1392
  return_dict: Optional[bool] = None,
1390
1393
  cache_position: Optional[torch.Tensor] = None,
1394
+ **kwargs,
1391
1395
  ) -> Union[tuple, Seq2SeqLMOutput]:
1392
1396
  r"""
1393
1397
  decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
@@ -615,6 +615,7 @@ class PerceiverModel(PerceiverPreTrainedModel):
615
615
  output_hidden_states: Optional[bool] = None,
616
616
  interpolate_pos_encoding: bool = False,
617
617
  return_dict: Optional[bool] = None,
618
+ **kwargs,
618
619
  ) -> Union[tuple, PerceiverModelOutput]:
619
620
  r"""
620
621
  inputs (`torch.FloatTensor`):
@@ -850,6 +851,7 @@ class PerceiverForMaskedLM(PerceiverPreTrainedModel):
850
851
  labels: Optional[torch.Tensor] = None,
851
852
  return_dict: Optional[bool] = None,
852
853
  input_ids: Optional[torch.Tensor] = None,
854
+ **kwargs,
853
855
  ) -> Union[tuple, PerceiverMaskedLMOutput]:
854
856
  r"""
855
857
  inputs (`torch.FloatTensor`):
@@ -975,6 +977,7 @@ class PerceiverForSequenceClassification(PerceiverPreTrainedModel):
975
977
  labels: Optional[torch.Tensor] = None,
976
978
  return_dict: Optional[bool] = None,
977
979
  input_ids: Optional[torch.Tensor] = None,
980
+ **kwargs,
978
981
  ) -> Union[tuple, PerceiverClassifierOutput]:
979
982
  r"""
980
983
  inputs (`torch.FloatTensor`):
@@ -1107,6 +1110,7 @@ class PerceiverForImageClassificationLearned(PerceiverPreTrainedModel):
1107
1110
  interpolate_pos_encoding: bool = False,
1108
1111
  return_dict: Optional[bool] = None,
1109
1112
  pixel_values: Optional[torch.Tensor] = None,
1113
+ **kwargs,
1110
1114
  ) -> Union[tuple, PerceiverClassifierOutput]:
1111
1115
  r"""
1112
1116
  inputs (`torch.FloatTensor`):
@@ -1229,6 +1233,7 @@ class PerceiverForImageClassificationFourier(PerceiverPreTrainedModel):
1229
1233
  labels: Optional[torch.Tensor] = None,
1230
1234
  return_dict: Optional[bool] = None,
1231
1235
  pixel_values: Optional[torch.Tensor] = None,
1236
+ **kwargs,
1232
1237
  ) -> Union[tuple, PerceiverClassifierOutput]:
1233
1238
  r"""
1234
1239
  inputs (`torch.FloatTensor`):
@@ -1350,6 +1355,7 @@ class PerceiverForImageClassificationConvProcessing(PerceiverPreTrainedModel):
1350
1355
  labels: Optional[torch.Tensor] = None,
1351
1356
  return_dict: Optional[bool] = None,
1352
1357
  pixel_values: Optional[torch.Tensor] = None,
1358
+ **kwargs,
1353
1359
  ) -> Union[tuple, PerceiverClassifierOutput]:
1354
1360
  r"""
1355
1361
  inputs (`torch.FloatTensor`):
@@ -1487,6 +1493,7 @@ class PerceiverForOpticalFlow(PerceiverPreTrainedModel):
1487
1493
  output_hidden_states: Optional[bool] = None,
1488
1494
  labels: Optional[torch.Tensor] = None,
1489
1495
  return_dict: Optional[bool] = None,
1496
+ **kwargs,
1490
1497
  ) -> Union[tuple, PerceiverClassifierOutput]:
1491
1498
  r"""
1492
1499
  inputs (`torch.FloatTensor`):
@@ -1695,6 +1702,7 @@ class PerceiverForMultimodalAutoencoding(PerceiverPreTrainedModel):
1695
1702
  output_hidden_states: Optional[bool] = None,
1696
1703
  labels: Optional[torch.Tensor] = None,
1697
1704
  return_dict: Optional[bool] = None,
1705
+ **kwargs,
1698
1706
  ) -> Union[tuple, PerceiverClassifierOutput]:
1699
1707
  r"""
1700
1708
  inputs (`torch.FloatTensor`):
@@ -46,6 +46,7 @@ from ...modeling_rope_utils import (
46
46
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
47
47
  from ...processing_utils import Unpack
48
48
  from ...utils import auto_docstring, can_return_tuple, is_torch_flex_attn_available, logging
49
+ from ...utils.generic import maybe_autocast
49
50
  from .configuration_persimmon import PersimmonConfig
50
51
 
51
52
 
@@ -118,7 +119,7 @@ class PersimmonRotaryEmbedding(nn.Module):
118
119
  position_ids_expanded = position_ids[:, None, :].float()
119
120
 
120
121
  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
121
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
122
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
122
123
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
123
124
  emb = torch.cat((freqs, freqs), dim=-1)
124
125
  cos = emb.cos() * self.attention_scaling