transformers 5.0.0rc0-py3-none-any.whl → 5.0.0rc1-py3-none-any.whl

This diff compares the contents of two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (539)
  1. transformers/__init__.py +30 -3
  2. transformers/cli/serve.py +47 -17
  3. transformers/conversion_mapping.py +15 -2
  4. transformers/convert_slow_tokenizer.py +225 -10
  5. transformers/core_model_loading.py +196 -135
  6. transformers/data/data_collator.py +12 -4
  7. transformers/dependency_versions_table.py +1 -2
  8. transformers/dynamic_module_utils.py +1 -2
  9. transformers/feature_extraction_utils.py +1 -2
  10. transformers/file_utils.py +0 -1
  11. transformers/generation/__init__.py +11 -1
  12. transformers/generation/configuration_utils.py +3 -2
  13. transformers/generation/continuous_batching/__init__.py +4 -0
  14. transformers/generation/continuous_batching/continuous_api.py +134 -79
  15. transformers/image_processing_base.py +1 -2
  16. transformers/integrations/__init__.py +4 -2
  17. transformers/integrations/accelerate.py +15 -3
  18. transformers/integrations/aqlm.py +38 -66
  19. transformers/integrations/awq.py +48 -514
  20. transformers/integrations/bitnet.py +45 -100
  21. transformers/integrations/bitsandbytes.py +79 -191
  22. transformers/integrations/deepspeed.py +1 -0
  23. transformers/integrations/eetq.py +84 -79
  24. transformers/integrations/fbgemm_fp8.py +191 -145
  25. transformers/integrations/finegrained_fp8.py +236 -193
  26. transformers/integrations/fp_quant.py +92 -0
  27. transformers/integrations/ggml.py +11 -1
  28. transformers/integrations/higgs.py +40 -62
  29. transformers/integrations/hub_kernels.py +42 -3
  30. transformers/integrations/integration_utils.py +10 -0
  31. transformers/integrations/mxfp4.py +25 -65
  32. transformers/integrations/peft.py +7 -29
  33. transformers/integrations/quanto.py +73 -55
  34. transformers/integrations/quark.py +55 -0
  35. transformers/integrations/spqr.py +44 -90
  36. transformers/integrations/torchao.py +32 -38
  37. transformers/integrations/vptq.py +42 -59
  38. transformers/modelcard.py +1 -2
  39. transformers/modeling_gguf_pytorch_utils.py +8 -0
  40. transformers/modeling_rope_utils.py +30 -6
  41. transformers/modeling_utils.py +116 -112
  42. transformers/models/__init__.py +3 -0
  43. transformers/models/afmoe/modeling_afmoe.py +4 -4
  44. transformers/models/albert/tokenization_albert.py +6 -12
  45. transformers/models/align/modeling_align.py +2 -0
  46. transformers/models/altclip/modeling_altclip.py +4 -0
  47. transformers/models/apertus/modeling_apertus.py +4 -4
  48. transformers/models/arcee/modeling_arcee.py +4 -4
  49. transformers/models/aria/modeling_aria.py +4 -4
  50. transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
  51. transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
  52. transformers/models/auto/configuration_auto.py +11 -0
  53. transformers/models/auto/feature_extraction_auto.py +2 -0
  54. transformers/models/auto/image_processing_auto.py +1 -0
  55. transformers/models/auto/modeling_auto.py +6 -0
  56. transformers/models/auto/processing_auto.py +18 -10
  57. transformers/models/auto/tokenization_auto.py +74 -472
  58. transformers/models/autoformer/modeling_autoformer.py +4 -0
  59. transformers/models/bamba/modeling_bamba.py +4 -3
  60. transformers/models/bark/modeling_bark.py +2 -0
  61. transformers/models/bart/modeling_bart.py +7 -0
  62. transformers/models/barthez/tokenization_barthez.py +5 -10
  63. transformers/models/beit/modeling_beit.py +6 -1
  64. transformers/models/bert/tokenization_bert.py +8 -21
  65. transformers/models/big_bird/modeling_big_bird.py +6 -0
  66. transformers/models/big_bird/tokenization_big_bird.py +18 -42
  67. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +8 -2
  68. transformers/models/biogpt/modeling_biogpt.py +2 -0
  69. transformers/models/biogpt/modular_biogpt.py +2 -0
  70. transformers/models/bit/modeling_bit.py +11 -2
  71. transformers/models/bitnet/modeling_bitnet.py +4 -4
  72. transformers/models/blenderbot/modeling_blenderbot.py +5 -0
  73. transformers/models/blenderbot/tokenization_blenderbot.py +12 -16
  74. transformers/models/blenderbot_small/modeling_blenderbot_small.py +5 -0
  75. transformers/models/blip/modeling_blip_text.py +2 -0
  76. transformers/models/blip_2/modeling_blip_2.py +2 -1
  77. transformers/models/bloom/modeling_bloom.py +4 -0
  78. transformers/models/blt/modeling_blt.py +2 -2
  79. transformers/models/blt/modular_blt.py +2 -2
  80. transformers/models/bridgetower/modeling_bridgetower.py +5 -1
  81. transformers/models/bros/modeling_bros.py +4 -0
  82. transformers/models/camembert/tokenization_camembert.py +8 -12
  83. transformers/models/canine/modeling_canine.py +5 -0
  84. transformers/models/chameleon/modeling_chameleon.py +2 -1
  85. transformers/models/chinese_clip/modeling_chinese_clip.py +3 -0
  86. transformers/models/clap/modeling_clap.py +5 -0
  87. transformers/models/clip/tokenization_clip.py +22 -44
  88. transformers/models/clipseg/modeling_clipseg.py +5 -0
  89. transformers/models/clvp/modeling_clvp.py +5 -0
  90. transformers/models/clvp/tokenization_clvp.py +1 -63
  91. transformers/models/code_llama/tokenization_code_llama.py +20 -43
  92. transformers/models/codegen/tokenization_codegen.py +14 -43
  93. transformers/models/cohere/modeling_cohere.py +4 -3
  94. transformers/models/cohere/modular_cohere.py +2 -1
  95. transformers/models/cohere/tokenization_cohere.py +12 -42
  96. transformers/models/cohere2/modeling_cohere2.py +7 -6
  97. transformers/models/cohere2/modular_cohere2.py +5 -5
  98. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -3
  99. transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
  100. transformers/models/colqwen2/modeling_colqwen2.py +1 -0
  101. transformers/models/colqwen2/modular_colqwen2.py +1 -0
  102. transformers/models/conditional_detr/modeling_conditional_detr.py +5 -0
  103. transformers/models/convbert/modeling_convbert.py +6 -0
  104. transformers/models/convnext/modeling_convnext.py +2 -4
  105. transformers/models/convnextv2/modeling_convnextv2.py +2 -4
  106. transformers/models/csm/modeling_csm.py +4 -3
  107. transformers/models/ctrl/modeling_ctrl.py +1 -0
  108. transformers/models/cvt/modeling_cvt.py +2 -0
  109. transformers/models/cwm/modeling_cwm.py +4 -4
  110. transformers/models/d_fine/modeling_d_fine.py +2 -0
  111. transformers/models/d_fine/modular_d_fine.py +1 -0
  112. transformers/models/dab_detr/modeling_dab_detr.py +4 -0
  113. transformers/models/dac/modeling_dac.py +2 -2
  114. transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
  115. transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
  116. transformers/models/dbrx/modeling_dbrx.py +2 -2
  117. transformers/models/deberta/modeling_deberta.py +5 -0
  118. transformers/models/deberta/tokenization_deberta.py +11 -20
  119. transformers/models/deberta_v2/modeling_deberta_v2.py +6 -0
  120. transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
  121. transformers/models/decision_transformer/modeling_decision_transformer.py +4 -1
  122. transformers/models/deepseek_v2/modeling_deepseek_v2.py +2 -3
  123. transformers/models/deepseek_v2/modular_deepseek_v2.py +2 -2
  124. transformers/models/deepseek_v3/modeling_deepseek_v3.py +3 -2
  125. transformers/models/deepseek_v3/modular_deepseek_v3.py +1 -0
  126. transformers/models/deformable_detr/modeling_deformable_detr.py +4 -0
  127. transformers/models/depth_anything/modeling_depth_anything.py +1 -0
  128. transformers/models/depth_pro/modeling_depth_pro.py +2 -0
  129. transformers/models/detr/modeling_detr.py +5 -0
  130. transformers/models/dia/modeling_dia.py +4 -3
  131. transformers/models/dia/modular_dia.py +0 -1
  132. transformers/models/diffllama/modeling_diffllama.py +2 -2
  133. transformers/models/dinat/modeling_dinat.py +3 -0
  134. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
  135. transformers/models/dinov3_vit/modeling_dinov3_vit.py +2 -2
  136. transformers/models/dinov3_vit/modular_dinov3_vit.py +2 -2
  137. transformers/models/distilbert/tokenization_distilbert.py +13 -0
  138. transformers/models/doge/modeling_doge.py +2 -3
  139. transformers/models/doge/modular_doge.py +0 -1
  140. transformers/models/donut/modeling_donut_swin.py +2 -0
  141. transformers/models/dots1/modeling_dots1.py +10 -7
  142. transformers/models/dots1/modular_dots1.py +5 -3
  143. transformers/models/dpr/modeling_dpr.py +5 -0
  144. transformers/models/dpr/tokenization_dpr.py +12 -0
  145. transformers/models/edgetam/modeling_edgetam.py +1 -1
  146. transformers/models/edgetam_video/modeling_edgetam_video.py +1 -0
  147. transformers/models/edgetam_video/modular_edgetam_video.py +1 -0
  148. transformers/models/efficientloftr/modeling_efficientloftr.py +2 -2
  149. transformers/models/efficientnet/modeling_efficientnet.py +2 -0
  150. transformers/models/emu3/modeling_emu3.py +4 -4
  151. transformers/models/eomt/image_processing_eomt.py +13 -1
  152. transformers/models/eomt/image_processing_eomt_fast.py +14 -2
  153. transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
  154. transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
  155. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +5 -5
  156. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +2 -2
  157. transformers/models/esm/modeling_esmfold.py +5 -4
  158. transformers/models/evolla/modeling_evolla.py +4 -4
  159. transformers/models/exaone4/modeling_exaone4.py +2 -2
  160. transformers/models/exaone4/modular_exaone4.py +0 -1
  161. transformers/models/falcon/modeling_falcon.py +6 -1
  162. transformers/models/falcon_h1/modeling_falcon_h1.py +4 -3
  163. transformers/models/falcon_mamba/modeling_falcon_mamba.py +25 -35
  164. transformers/models/falcon_mamba/modular_falcon_mamba.py +12 -31
  165. transformers/{kernels/falcon_mamba → models/fast_vlm}/__init__.py +15 -3
  166. transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
  167. transformers/models/fast_vlm/modeling_fast_vlm.py +455 -0
  168. transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
  169. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +8 -3
  170. transformers/models/flaubert/modeling_flaubert.py +7 -0
  171. transformers/models/flava/modeling_flava.py +6 -1
  172. transformers/models/flex_olmo/modeling_flex_olmo.py +4 -5
  173. transformers/models/florence2/modeling_florence2.py +2 -1
  174. transformers/models/florence2/modular_florence2.py +2 -1
  175. transformers/models/fnet/modeling_fnet.py +7 -0
  176. transformers/models/focalnet/modeling_focalnet.py +4 -0
  177. transformers/models/fsmt/modeling_fsmt.py +2 -0
  178. transformers/models/funnel/modeling_funnel.py +8 -0
  179. transformers/models/funnel/tokenization_funnel.py +17 -24
  180. transformers/models/fuyu/processing_fuyu.py +3 -3
  181. transformers/models/gemma/modeling_gemma.py +4 -4
  182. transformers/models/gemma/tokenization_gemma.py +10 -27
  183. transformers/models/gemma2/modeling_gemma2.py +4 -4
  184. transformers/models/gemma2/modular_gemma2.py +2 -1
  185. transformers/models/gemma3/modeling_gemma3.py +14 -84
  186. transformers/models/gemma3/modular_gemma3.py +12 -81
  187. transformers/models/gemma3n/modeling_gemma3n.py +18 -209
  188. transformers/models/gemma3n/modular_gemma3n.py +17 -59
  189. transformers/models/git/modeling_git.py +2 -0
  190. transformers/models/glm/modeling_glm.py +4 -4
  191. transformers/models/glm4/modeling_glm4.py +4 -4
  192. transformers/models/glm4_moe/modeling_glm4_moe.py +5 -3
  193. transformers/models/glm4v/configuration_glm4v.py +3 -1
  194. transformers/models/glm4v/modeling_glm4v.py +3 -3
  195. transformers/models/glm4v/modular_glm4v.py +6 -4
  196. transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
  197. transformers/models/glm4v_moe/modeling_glm4v_moe.py +6 -5
  198. transformers/models/glm4v_moe/modular_glm4v_moe.py +1 -1
  199. transformers/models/glpn/modeling_glpn.py +2 -0
  200. transformers/models/gpt2/modeling_gpt2.py +5 -1
  201. transformers/models/gpt2/tokenization_gpt2.py +16 -44
  202. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +1 -0
  203. transformers/models/gpt_neo/modeling_gpt_neo.py +4 -0
  204. transformers/models/gpt_neox/modeling_gpt_neox.py +5 -2
  205. transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
  206. transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
  207. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +3 -1
  208. transformers/models/gpt_oss/modeling_gpt_oss.py +5 -6
  209. transformers/models/gpt_oss/modular_gpt_oss.py +3 -5
  210. transformers/models/gptj/modeling_gptj.py +3 -0
  211. transformers/models/granite/modeling_granite.py +4 -4
  212. transformers/models/granitemoe/modeling_granitemoe.py +4 -6
  213. transformers/models/granitemoe/modular_granitemoe.py +0 -2
  214. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +4 -6
  215. transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -6
  216. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -0
  217. transformers/models/groupvit/modeling_groupvit.py +3 -0
  218. transformers/models/helium/modeling_helium.py +4 -3
  219. transformers/models/herbert/tokenization_herbert.py +9 -25
  220. transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -1
  221. transformers/models/hgnet_v2/modular_hgnet_v2.py +6 -1
  222. transformers/models/hiera/modeling_hiera.py +4 -0
  223. transformers/models/hubert/modeling_hubert.py +3 -0
  224. transformers/models/hubert/modular_hubert.py +1 -0
  225. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +4 -4
  226. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +4 -4
  227. transformers/models/ibert/modeling_ibert.py +6 -0
  228. transformers/models/idefics/modeling_idefics.py +5 -21
  229. transformers/models/imagegpt/modeling_imagegpt.py +2 -1
  230. transformers/models/informer/modeling_informer.py +4 -0
  231. transformers/models/informer/modular_informer.py +1 -0
  232. transformers/models/internvl/modeling_internvl.py +2 -4
  233. transformers/models/internvl/modular_internvl.py +2 -4
  234. transformers/models/jamba/modeling_jamba.py +2 -2
  235. transformers/models/janus/modeling_janus.py +1 -0
  236. transformers/models/janus/modular_janus.py +1 -0
  237. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  238. transformers/models/kosmos2/modeling_kosmos2.py +1 -0
  239. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +3 -1
  240. transformers/models/lasr/__init__.py +29 -0
  241. transformers/models/lasr/configuration_lasr.py +244 -0
  242. transformers/models/lasr/feature_extraction_lasr.py +277 -0
  243. transformers/models/lasr/modeling_lasr.py +729 -0
  244. transformers/models/lasr/modular_lasr.py +569 -0
  245. transformers/models/lasr/processing_lasr.py +96 -0
  246. transformers/models/lasr/tokenization_lasr.py +186 -0
  247. transformers/models/layoutlm/modeling_layoutlm.py +5 -0
  248. transformers/models/layoutlmv2/modeling_layoutlmv2.py +4 -0
  249. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +10 -53
  250. transformers/models/layoutlmv3/modeling_layoutlmv3.py +4 -0
  251. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
  252. transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
  253. transformers/models/led/modeling_led.py +6 -0
  254. transformers/models/levit/modeling_levit.py +3 -0
  255. transformers/models/lfm2/modeling_lfm2.py +4 -5
  256. transformers/models/lfm2/modular_lfm2.py +0 -1
  257. transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -5
  258. transformers/models/lightglue/modeling_lightglue.py +3 -1
  259. transformers/models/lightglue/modular_lightglue.py +1 -0
  260. transformers/models/lilt/modeling_lilt.py +4 -0
  261. transformers/models/llama/modeling_llama.py +4 -4
  262. transformers/models/llama/tokenization_llama.py +15 -43
  263. transformers/models/llama4/modeling_llama4.py +3 -2
  264. transformers/models/longcat_flash/modeling_longcat_flash.py +4 -4
  265. transformers/models/longcat_flash/modular_longcat_flash.py +2 -2
  266. transformers/models/longformer/modeling_longformer.py +6 -0
  267. transformers/models/longt5/modeling_longt5.py +4 -0
  268. transformers/models/luke/modeling_luke.py +9 -0
  269. transformers/models/luke/tokenization_luke.py +11 -38
  270. transformers/models/lxmert/modeling_lxmert.py +2 -0
  271. transformers/models/m2m_100/modeling_m2m_100.py +4 -0
  272. transformers/models/mamba/modeling_mamba.py +14 -22
  273. transformers/models/marian/modeling_marian.py +5 -0
  274. transformers/models/markuplm/modeling_markuplm.py +4 -0
  275. transformers/models/markuplm/tokenization_markuplm.py +28 -61
  276. transformers/models/mask2former/modeling_mask2former.py +2 -0
  277. transformers/models/maskformer/modeling_maskformer.py +2 -0
  278. transformers/models/maskformer/modeling_maskformer_swin.py +2 -0
  279. transformers/models/mbart/modeling_mbart.py +7 -0
  280. transformers/models/mbart/tokenization_mbart.py +11 -52
  281. transformers/models/mbart50/tokenization_mbart50.py +7 -10
  282. transformers/models/megatron_bert/modeling_megatron_bert.py +7 -0
  283. transformers/models/mgp_str/modeling_mgp_str.py +2 -0
  284. transformers/models/mimi/modeling_mimi.py +3 -1
  285. transformers/models/minimax/modeling_minimax.py +4 -4
  286. transformers/models/ministral/modeling_ministral.py +4 -4
  287. transformers/models/ministral3/configuration_ministral3.py +1 -1
  288. transformers/models/ministral3/modeling_ministral3.py +4 -3
  289. transformers/models/mistral/modeling_mistral.py +4 -3
  290. transformers/models/mixtral/modeling_mixtral.py +4 -4
  291. transformers/models/mllama/modeling_mllama.py +2 -2
  292. transformers/models/mluke/tokenization_mluke.py +6 -6
  293. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -0
  294. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
  295. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
  296. transformers/models/mobilevit/modeling_mobilevit.py +3 -0
  297. transformers/models/mobilevitv2/modeling_mobilevitv2.py +3 -0
  298. transformers/models/modernbert/modeling_modernbert.py +4 -1
  299. transformers/models/modernbert/modular_modernbert.py +2 -0
  300. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +8 -9
  301. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +6 -7
  302. transformers/models/moonshine/modeling_moonshine.py +4 -2
  303. transformers/models/moshi/modeling_moshi.py +5 -2
  304. transformers/models/mpnet/modeling_mpnet.py +5 -0
  305. transformers/models/mpnet/tokenization_mpnet.py +5 -13
  306. transformers/models/mpt/modeling_mpt.py +2 -0
  307. transformers/models/mra/modeling_mra.py +6 -0
  308. transformers/models/mt5/modeling_mt5.py +7 -0
  309. transformers/models/musicgen/modeling_musicgen.py +2 -0
  310. transformers/models/musicgen_melody/modeling_musicgen_melody.py +3 -0
  311. transformers/models/mvp/modeling_mvp.py +7 -0
  312. transformers/models/nanochat/modeling_nanochat.py +4 -4
  313. transformers/models/nemotron/modeling_nemotron.py +4 -2
  314. transformers/models/nllb/tokenization_nllb.py +8 -22
  315. transformers/models/nougat/tokenization_nougat.py +11 -59
  316. transformers/models/nystromformer/modeling_nystromformer.py +6 -0
  317. transformers/models/olmo/modeling_olmo.py +4 -4
  318. transformers/models/olmo/modular_olmo.py +2 -2
  319. transformers/models/olmo2/modeling_olmo2.py +4 -5
  320. transformers/models/olmo2/modular_olmo2.py +0 -1
  321. transformers/models/olmo3/modeling_olmo3.py +4 -4
  322. transformers/models/olmoe/modeling_olmoe.py +4 -4
  323. transformers/models/omdet_turbo/modeling_omdet_turbo.py +2 -0
  324. transformers/models/oneformer/modeling_oneformer.py +4 -1
  325. transformers/models/openai/modeling_openai.py +3 -0
  326. transformers/models/openai/tokenization_openai.py +10 -46
  327. transformers/models/opt/modeling_opt.py +2 -0
  328. transformers/models/owlv2/modeling_owlv2.py +4 -0
  329. transformers/models/owlvit/modeling_owlvit.py +4 -0
  330. transformers/models/paddleocr_vl/__init__.py +32 -0
  331. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
  332. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +503 -0
  333. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
  334. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1668 -0
  335. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1349 -0
  336. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
  337. transformers/models/parakeet/configuration_parakeet.py +4 -6
  338. transformers/models/parakeet/modeling_parakeet.py +9 -6
  339. transformers/models/parakeet/modular_parakeet.py +2 -2
  340. transformers/models/parakeet/processing_parakeet.py +1 -0
  341. transformers/models/patchtsmixer/modeling_patchtsmixer.py +6 -0
  342. transformers/models/patchtst/modeling_patchtst.py +20 -2
  343. transformers/models/pegasus/modeling_pegasus.py +5 -0
  344. transformers/models/pegasus/tokenization_pegasus.py +17 -44
  345. transformers/models/pegasus_x/modeling_pegasus_x.py +4 -0
  346. transformers/models/perceiver/modeling_perceiver.py +8 -0
  347. transformers/models/persimmon/modeling_persimmon.py +2 -1
  348. transformers/models/phi/modeling_phi.py +4 -5
  349. transformers/models/phi/modular_phi.py +0 -1
  350. transformers/models/phi3/modeling_phi3.py +2 -1
  351. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +5 -5
  352. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +4 -4
  353. transformers/models/phimoe/modeling_phimoe.py +4 -4
  354. transformers/models/phimoe/modular_phimoe.py +2 -2
  355. transformers/models/pix2struct/modeling_pix2struct.py +2 -0
  356. transformers/models/pixtral/modeling_pixtral.py +2 -1
  357. transformers/models/plbart/modeling_plbart.py +6 -0
  358. transformers/models/plbart/modular_plbart.py +2 -0
  359. transformers/models/plbart/tokenization_plbart.py +0 -2
  360. transformers/models/poolformer/modeling_poolformer.py +2 -0
  361. transformers/models/pop2piano/modeling_pop2piano.py +2 -0
  362. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
  363. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
  364. transformers/models/prophetnet/modeling_prophetnet.py +3 -0
  365. transformers/models/pvt/modeling_pvt.py +2 -0
  366. transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
  367. transformers/models/qwen2/modeling_qwen2.py +4 -4
  368. transformers/models/qwen2/tokenization_qwen2.py +14 -18
  369. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
  370. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +13 -16
  371. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +14 -16
  372. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
  373. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +5 -6
  374. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +3 -5
  375. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -0
  376. transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
  377. transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
  378. transformers/models/qwen2_vl/modeling_qwen2_vl.py +6 -16
  379. transformers/models/qwen3/modeling_qwen3.py +4 -4
  380. transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
  381. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -3
  382. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +21 -23
  383. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +14 -16
  384. transformers/models/qwen3_vl/modeling_qwen3_vl.py +39 -37
  385. transformers/models/qwen3_vl/modular_qwen3_vl.py +37 -35
  386. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +39 -37
  387. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +4 -1
  388. transformers/models/rag/modeling_rag.py +1 -0
  389. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +15 -1
  390. transformers/models/reformer/modeling_reformer.py +4 -0
  391. transformers/models/reformer/tokenization_reformer.py +11 -28
  392. transformers/models/regnet/modeling_regnet.py +6 -1
  393. transformers/models/rembert/modeling_rembert.py +6 -0
  394. transformers/models/rembert/tokenization_rembert.py +3 -10
  395. transformers/models/resnet/modeling_resnet.py +11 -2
  396. transformers/models/roberta/tokenization_roberta.py +18 -27
  397. transformers/models/roformer/modeling_roformer.py +6 -0
  398. transformers/models/roformer/tokenization_roformer.py +77 -412
  399. transformers/models/rt_detr/modeling_rt_detr.py +2 -0
  400. transformers/models/rt_detr/modeling_rt_detr_resnet.py +5 -1
  401. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +2 -0
  402. transformers/models/rwkv/modeling_rwkv.py +1 -0
  403. transformers/models/sam2/modeling_sam2.py +2 -2
  404. transformers/models/sam2/modular_sam2.py +2 -2
  405. transformers/models/sam2_video/modeling_sam2_video.py +1 -0
  406. transformers/models/sam2_video/modular_sam2_video.py +1 -0
  407. transformers/models/sam3/modeling_sam3.py +77 -80
  408. transformers/models/sam3_tracker/modeling_sam3_tracker.py +6 -1
  409. transformers/models/sam3_tracker/modular_sam3_tracker.py +6 -1
  410. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +1 -0
  411. transformers/models/sam3_video/modeling_sam3_video.py +1 -0
  412. transformers/models/seamless_m4t/modeling_seamless_m4t.py +5 -1
  413. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
  414. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +5 -1
  415. transformers/models/seed_oss/modeling_seed_oss.py +2 -2
  416. transformers/models/segformer/modeling_segformer.py +4 -1
  417. transformers/models/seggpt/modeling_seggpt.py +2 -0
  418. transformers/models/sew/modeling_sew.py +3 -0
  419. transformers/models/sew/modular_sew.py +1 -0
  420. transformers/models/sew_d/modeling_sew_d.py +3 -0
  421. transformers/models/siglip2/modeling_siglip2.py +4 -0
  422. transformers/models/siglip2/modular_siglip2.py +4 -0
  423. transformers/models/smollm3/modeling_smollm3.py +4 -4
  424. transformers/models/smolvlm/processing_smolvlm.py +0 -7
  425. transformers/models/speech_to_text/modeling_speech_to_text.py +4 -0
  426. transformers/models/speecht5/modeling_speecht5.py +13 -1
  427. transformers/models/splinter/modeling_splinter.py +3 -0
  428. transformers/models/splinter/tokenization_splinter.py +9 -28
  429. transformers/models/squeezebert/modeling_squeezebert.py +6 -0
  430. transformers/models/stablelm/modeling_stablelm.py +3 -1
  431. transformers/models/starcoder2/modeling_starcoder2.py +4 -3
  432. transformers/models/superglue/modeling_superglue.py +1 -0
  433. transformers/models/superpoint/modeling_superpoint.py +1 -0
  434. transformers/models/swiftformer/modeling_swiftformer.py +2 -0
  435. transformers/models/swin/modeling_swin.py +4 -0
  436. transformers/models/swin2sr/modeling_swin2sr.py +2 -0
  437. transformers/models/swinv2/modeling_swinv2.py +4 -0
  438. transformers/models/t5/modeling_t5.py +7 -0
  439. transformers/models/t5/tokenization_t5.py +4 -8
  440. transformers/models/t5gemma/modeling_t5gemma.py +5 -5
  441. transformers/models/t5gemma2/modeling_t5gemma2.py +6 -6
  442. transformers/models/table_transformer/modeling_table_transformer.py +4 -0
  443. transformers/models/tapas/modeling_tapas.py +3 -0
  444. transformers/models/textnet/modeling_textnet.py +11 -2
  445. transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
  446. transformers/models/timesfm/modeling_timesfm.py +2 -0
  447. transformers/models/timesfm/modular_timesfm.py +2 -0
  448. transformers/models/timesformer/modeling_timesformer.py +2 -0
  449. transformers/models/timm_wrapper/modeling_timm_wrapper.py +1 -1
  450. transformers/models/trocr/modeling_trocr.py +2 -0
  451. transformers/models/tvp/modeling_tvp.py +2 -0
  452. transformers/models/udop/modeling_udop.py +4 -0
  453. transformers/models/udop/tokenization_udop.py +5 -13
  454. transformers/models/umt5/modeling_umt5.py +7 -0
  455. transformers/models/unispeech/modeling_unispeech.py +4 -0
  456. transformers/models/unispeech/modular_unispeech.py +2 -0
  457. transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
  458. transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
  459. transformers/models/univnet/modeling_univnet.py +1 -0
  460. transformers/models/upernet/modeling_upernet.py +1 -0
  461. transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
  462. transformers/models/vilt/modeling_vilt.py +6 -0
  463. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
  464. transformers/models/visual_bert/modeling_visual_bert.py +6 -0
  465. transformers/models/vitdet/modeling_vitdet.py +2 -0
  466. transformers/models/vitmatte/modeling_vitmatte.py +1 -0
  467. transformers/models/vits/modeling_vits.py +1 -0
  468. transformers/models/vjepa2/modeling_vjepa2.py +1 -0
  469. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
  470. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +5 -0
  471. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +5 -0
  472. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +6 -0
  473. transformers/models/wavlm/modeling_wavlm.py +5 -0
  474. transformers/models/whisper/modeling_whisper.py +6 -0
  475. transformers/models/whisper/tokenization_whisper.py +4 -15
  476. transformers/models/x_clip/modeling_x_clip.py +3 -0
  477. transformers/models/xglm/modeling_xglm.py +1 -0
  478. transformers/models/xglm/tokenization_xglm.py +4 -9
  479. transformers/models/xlm/modeling_xlm.py +5 -0
  480. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
  481. transformers/models/xlnet/tokenization_xlnet.py +3 -7
  482. transformers/models/yoso/modeling_yoso.py +6 -0
  483. transformers/models/zamba/modeling_zamba.py +2 -0
  484. transformers/models/zamba2/modeling_zamba2.py +4 -2
  485. transformers/models/zamba2/modular_zamba2.py +1 -1
  486. transformers/models/zoedepth/modeling_zoedepth.py +1 -0
  487. transformers/pipelines/__init__.py +2 -3
  488. transformers/pipelines/base.py +1 -9
  489. transformers/pipelines/document_question_answering.py +3 -1
  490. transformers/pipelines/text_generation.py +1 -1
  491. transformers/processing_utils.py +23 -11
  492. transformers/quantizers/base.py +35 -110
  493. transformers/quantizers/quantizer_aqlm.py +1 -5
  494. transformers/quantizers/quantizer_auto_round.py +1 -2
  495. transformers/quantizers/quantizer_awq.py +17 -81
  496. transformers/quantizers/quantizer_bitnet.py +3 -8
  497. transformers/quantizers/quantizer_bnb_4bit.py +13 -110
  498. transformers/quantizers/quantizer_bnb_8bit.py +16 -92
  499. transformers/quantizers/quantizer_compressed_tensors.py +1 -5
  500. transformers/quantizers/quantizer_eetq.py +14 -62
  501. transformers/quantizers/quantizer_fbgemm_fp8.py +34 -125
  502. transformers/quantizers/quantizer_finegrained_fp8.py +13 -105
  503. transformers/quantizers/quantizer_fp_quant.py +48 -78
  504. transformers/quantizers/quantizer_gptq.py +7 -24
  505. transformers/quantizers/quantizer_higgs.py +40 -54
  506. transformers/quantizers/quantizer_hqq.py +144 -153
  507. transformers/quantizers/quantizer_mxfp4.py +13 -167
  508. transformers/quantizers/quantizer_quanto.py +20 -64
  509. transformers/quantizers/quantizer_quark.py +36 -17
  510. transformers/quantizers/quantizer_spqr.py +1 -4
  511. transformers/quantizers/quantizer_torchao.py +23 -202
  512. transformers/quantizers/quantizer_vptq.py +8 -22
  513. transformers/quantizers/quantizers_utils.py +20 -0
  514. transformers/testing_utils.py +297 -36
  515. transformers/tokenization_mistral_common.py +4 -0
  516. transformers/tokenization_utils_base.py +113 -222
  517. transformers/tokenization_utils_tokenizers.py +168 -107
  518. transformers/trainer.py +28 -31
  519. transformers/trainer_jit_checkpoint.py +126 -0
  520. transformers/trainer_utils.py +1 -1
  521. transformers/training_args.py +66 -28
  522. transformers/utils/__init__.py +3 -4
  523. transformers/utils/auto_docstring.py +1 -0
  524. transformers/utils/generic.py +27 -1
  525. transformers/utils/hub.py +5 -15
  526. transformers/utils/import_utils.py +61 -16
  527. transformers/utils/kernel_config.py +4 -2
  528. transformers/utils/loading_report.py +19 -10
  529. transformers/utils/quantization_config.py +75 -242
  530. transformers/video_processing_utils.py +1 -2
  531. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/METADATA +274 -227
  532. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/RECORD +536 -520
  533. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/WHEEL +1 -1
  534. transformers/kernels/__init__.py +0 -0
  535. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
  536. transformers/models/roformer/tokenization_roformer_fast.py +0 -160
  537. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/entry_points.txt +0 -0
  538. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info/licenses}/LICENSE +0 -0
  539. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc1.dist-info}/top_level.txt +0 -0
--- transformers-5.0.0rc0.dist-info/WHEEL
+++ transformers-5.0.0rc1.dist-info/WHEEL
@@ -1,5 +1,5 @@
 Wheel-Version: 1.0
-Generator: bdist_wheel (0.45.1)
+Generator: setuptools (80.9.0)
 Root-Is-Purelib: true
 Tag: py3-none-any
 
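The only substantive change in the hunk above is the `Generator` field: the rc1 wheel was produced by setuptools' own wheel builder rather than the standalone `wheel` package's `bdist_wheel`. A wheel is a zip archive, so this metadata can be checked locally with the standard library alone; a minimal sketch (the local filename is an assumption):

```python
# Read the WHEEL metadata straight out of the archive; the dist-info path
# matches the WHEEL entry in the file list above.
from zipfile import ZipFile

with ZipFile("transformers-5.0.0rc1-py3-none-any.whl") as whl:
    print(whl.read("transformers-5.0.0rc1.dist-info/WHEEL").decode())
    # -> Wheel-Version: 1.0 / Generator: setuptools (80.9.0) / ...
```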
--- transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py
+++ /dev/null
@@ -1,529 +0,0 @@
-# coding=utf-8
-# Copyright 2024 Tri Dao, Albert Gu, Technological Innovation Institute and HuggingFace Inc. team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# Original code from: https://github.com/state-spaces/mamba/blob/main/mamba_ssm/ops/selective_scan_interface.py
-
-import torch
-import torch.nn.functional as F
-from einops import rearrange, repeat
-from torch.cuda.amp import custom_bwd, custom_fwd
-
-
-try:
-    import causal_conv1d_cuda
-except ImportError:
-    causal_conv1d_cuda = None
-
-import mamba_ssm
-import selective_scan_cuda
-
-
-# For BC for old mamba-ssm versions: https://github.com/huggingface/transformers/pull/33195#discussion_r1736401127
-if hasattr(mamba_ssm.ops.triton, "layernorm"):
-    from mamba_ssm.ops.triton.layernorm import _layer_norm_fwd
-else:
-    from mamba_ssm.ops.triton.layer_norm import _layer_norm_fwd
-
-
-class SelectiveScanFn(torch.autograd.Function):
-    @staticmethod
-    def forward(
-        ctx, u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False
-    ):
-        if u.stride(-1) != 1:
-            u = u.contiguous()
-        if delta.stride(-1) != 1:
-            delta = delta.contiguous()
-        if D is not None:
-            D = D.contiguous()
-        if B.stride(-1) != 1:
-            B = B.contiguous()
-        if C.stride(-1) != 1:
-            C = C.contiguous()
-        if z is not None and z.stride(-1) != 1:
-            z = z.contiguous()
-        if B.dim() == 3:
-            B = rearrange(B, "b dstate l -> b 1 dstate l")
-            ctx.squeeze_B = True
-        if C.dim() == 3:
-            C = rearrange(C, "b dstate l -> b 1 dstate l")
-            ctx.squeeze_C = True
-        out, x, *rest = selective_scan_cuda.fwd(u, delta, A, B, C, D, z, delta_bias, delta_softplus)
-        ctx.delta_softplus = delta_softplus
-        ctx.has_z = z is not None
-        last_state = x[:, :, -1, 1::2]  # (batch, dim, dstate)
-        if not ctx.has_z:
-            ctx.save_for_backward(u, delta, A, B, C, D, delta_bias, x)
-            return out if not return_last_state else (out, last_state)
-        else:
-            ctx.save_for_backward(u, delta, A, B, C, D, z, delta_bias, x, out)
-            out_z = rest[0]
-            return out_z if not return_last_state else (out_z, last_state)
-
-    @staticmethod
-    def backward(ctx, dout, *args):
-        if not ctx.has_z:
-            u, delta, A, B, C, D, delta_bias, x = ctx.saved_tensors
-            z = None
-            out = None
-        else:
-            u, delta, A, B, C, D, z, delta_bias, x, out = ctx.saved_tensors
-        if dout.stride(-1) != 1:
-            dout = dout.contiguous()
-        # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
-        # backward of selective_scan_cuda with the backward of chunk).
-        # Here we just pass in None and dz will be allocated in the C++ code.
-        du, ddelta, dA, dB, dC, dD, ddelta_bias, *rest = selective_scan_cuda.bwd(
-            u,
-            delta,
-            A,
-            B,
-            C,
-            D,
-            z,
-            delta_bias,
-            dout,
-            x,
-            out,
-            None,
-            ctx.delta_softplus,
-            False,  # option to recompute out_z, not used here
-        )
-        dz = rest[0] if ctx.has_z else None
-        dB = dB.squeeze(1) if getattr(ctx, "squeeze_B", False) else dB
-        dC = dC.squeeze(1) if getattr(ctx, "squeeze_C", False) else dC
-        return (
-            du,
-            ddelta,
-            dA,
-            dB,
-            dC,
-            dD if D is not None else None,
-            dz,
-            ddelta_bias if delta_bias is not None else None,
-            None,
-            None,
-        )
-
-
-def rms_norm_forward(
-    x,
-    weight,
-    bias,
-    eps=1e-6,
-    is_rms_norm=True,
-):
-    # x (b l) d
-    if x.stride(-1) != 1:
-        x = x.contiguous()
-    weight = weight.contiguous()
-    if bias is not None:
-        bias = bias.contiguous()
-    y = _layer_norm_fwd(x, weight, bias, eps, None, residual_dtype=None, is_rms_norm=is_rms_norm)[0]
-    # y (b l) d
-    return y
-
-
-def selective_scan_fn(
-    u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False
-):
-    """if return_last_state is True, returns (out, last_state)
-    last_state has shape (batch, dim, dstate). Note that the gradient of the last state is
-    not considered in the backward pass.
-    """
-    return SelectiveScanFn.apply(u, delta, A, B, C, D, z, delta_bias, delta_softplus, return_last_state)
-
-
-def selective_scan_ref(
-    u, delta, A, B, C, D=None, z=None, delta_bias=None, delta_softplus=False, return_last_state=False
-):
-    """
-    u: r(B D L)
-    delta: r(B D L)
-    A: c(D N) or r(D N)
-    B: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
-    C: c(D N) or r(B N L) or r(B N 2L) or r(B G N L) or (B G N L)
-    D: r(D)
-    z: r(B D L)
-    delta_bias: r(D), fp32
-
-    out: r(B D L)
-    last_state (optional): r(B D dstate) or c(B D dstate)
-    """
-    dtype_in = u.dtype
-    u = u.float()
-    delta = delta.float()
-    if delta_bias is not None:
-        delta = delta + delta_bias[..., None].float()
-    if delta_softplus:
-        delta = F.softplus(delta)
-    batch, dim, dstate = u.shape[0], A.shape[0], A.shape[1]
-    is_variable_B = B.dim() >= 3
-    is_variable_C = C.dim() >= 3
-    if A.is_complex():
-        if is_variable_B:
-            B = torch.view_as_complex(rearrange(B.float(), "... (L two) -> ... L two", two=2))
-        if is_variable_C:
-            C = torch.view_as_complex(rearrange(C.float(), "... (L two) -> ... L two", two=2))
-    else:
-        B = B.float()
-        C = C.float()
-    x = A.new_zeros((batch, dim, dstate))
-    ys = []
-    deltaA = torch.exp(torch.einsum("bdl,dn->bdln", delta, A))
-    if not is_variable_B:
-        deltaB_u = torch.einsum("bdl,dn,bdl->bdln", delta, B, u)
-    else:
-        if B.dim() == 3:
-            deltaB_u = torch.einsum("bdl,bnl,bdl->bdln", delta, B, u)
-        else:
-            B = repeat(B, "B G N L -> B (G H) N L", H=dim // B.shape[1])
-            deltaB_u = torch.einsum("bdl,bdnl,bdl->bdln", delta, B, u)
-    if is_variable_C and C.dim() == 4:
-        C = repeat(C, "B G N L -> B (G H) N L", H=dim // C.shape[1])
-    last_state = None
-    for i in range(u.shape[2]):
-        x = deltaA[:, :, i] * x + deltaB_u[:, :, i]
-        if not is_variable_C:
-            y = torch.einsum("bdn,dn->bd", x, C)
-        else:
-            if C.dim() == 3:
-                y = torch.einsum("bdn,bn->bd", x, C[:, :, i])
-            else:
-                y = torch.einsum("bdn,bdn->bd", x, C[:, :, :, i])
-        if i == u.shape[2] - 1:
-            last_state = x
-        if y.is_complex():
-            y = y.real * 2
-        ys.append(y)
-    y = torch.stack(ys, dim=2)  # (batch dim L)
-    out = y if D is None else y + u * rearrange(D, "d -> d 1")
-    if z is not None:
-        out = out * F.silu(z)
-    out = out.to(dtype=dtype_in)
-    return out if not return_last_state else (out, last_state)
-
-
-class MambaInnerFn(torch.autograd.Function):
-    @staticmethod
-    @custom_fwd
-    def forward(
-        ctx,
-        xz,
-        conv1d_weight,
-        conv1d_bias,
-        x_proj_weight,
-        delta_proj_weight,
-        out_proj_weight,
-        out_proj_bias,
-        A,
-        B=None,
-        C=None,
-        D=None,
-        delta_bias=None,
-        B_proj_bias=None,
-        C_proj_bias=None,
-        delta_softplus=True,
-        checkpoint_lvl=1,
-        b_rms_weight=None,
-        c_rms_weight=None,
-        dt_rms_weight=None,
-        b_c_dt_rms_eps=1e-6,
-    ):
-        """
-        xz: (batch, dim, seqlen)
-        """
-        assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
-        assert checkpoint_lvl in [0, 1]
-        L = xz.shape[-1]
-        delta_rank = delta_proj_weight.shape[1]
-        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
-        if torch.is_autocast_enabled():
-            # NOTE: `torch.get_autocast_dtype` is there starting from PyTorch 2.4
-            target_dtype = (
-                torch.get_autocast_dtype("cuda")
-                if hasattr(torch, "get_autocast_dtype")
-                else torch.get_autocast_gpu_dtype()
-            )
-            x_proj_weight = x_proj_weight.to(dtype=target_dtype)
-            delta_proj_weight = delta_proj_weight.to(dtype=target_dtype)
-            out_proj_weight = out_proj_weight.to(dtype=target_dtype)
-            out_proj_bias = out_proj_bias.to(dtype=target_dtype) if out_proj_bias is not None else None
-        if xz.stride(-1) != 1:
-            xz = xz.contiguous()
-        conv1d_weight = rearrange(conv1d_weight, "d 1 w -> d w")
-        x, z = xz.chunk(2, dim=1)
-        conv1d_bias = conv1d_bias.contiguous() if conv1d_bias is not None else None
-        conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, None, None, True)
-        # We're being very careful here about the layout, to avoid extra transposes.
-        # We want delta to have d as the slowest moving dimension
-        # and L as the fastest moving dimension, since those are what the ssm_scan kernel expects.
-        x_dbl = F.linear(rearrange(conv1d_out, "b d l -> (b l) d"), x_proj_weight)  # (bl d)
-        delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L)
-        ctx.is_variable_B = B is None
-        ctx.is_variable_C = C is None
-        ctx.B_proj_bias_is_None = B_proj_bias is None
-        ctx.C_proj_bias_is_None = C_proj_bias is None
-        if B is None:  # variable B
-            B = x_dbl[:, delta_rank : delta_rank + d_state]  # (bl dstate)
-            if B_proj_bias is not None:
-                B = B + B_proj_bias.to(dtype=B.dtype)
-            if not A.is_complex():
-                # B = rearrange(B, "(b l) dstate -> b dstate l", l=L).contiguous()
-                B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
-            else:
-                B = rearrange(B, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
-        else:
-            if B.stride(-1) != 1:
-                B = B.contiguous()
-        if C is None:  # variable C
-            C = x_dbl[:, -d_state:]  # (bl dstate)
-            if C_proj_bias is not None:
-                C = C + C_proj_bias.to(dtype=C.dtype)
-            if not A.is_complex():
-                # C = rearrange(C, "(b l) dstate -> b dstate l", l=L).contiguous()
-                C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
-            else:
-                C = rearrange(C, "(b l) (dstate two) -> b 1 dstate (l two)", l=L, two=2).contiguous()
-        else:
-            if C.stride(-1) != 1:
-                C = C.contiguous()
-        if D is not None:
-            D = D.contiguous()
-
-        if b_rms_weight is not None:
-            B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
-            B = rms_norm_forward(B, b_rms_weight, bias=None, eps=b_c_dt_rms_eps)
-            B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
-        if c_rms_weight is not None:
-            C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
-            C = rms_norm_forward(C, c_rms_weight, bias=None, eps=b_c_dt_rms_eps)
-            C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
-        if dt_rms_weight is not None:
-            delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
-            delta = rms_norm_forward(delta, dt_rms_weight, bias=None, eps=b_c_dt_rms_eps)
-            delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()
-
-        out, scan_intermediates, out_z = selective_scan_cuda.fwd(
-            conv1d_out, delta, A, B, C, D, z, delta_bias, delta_softplus
-        )
-        ctx.delta_softplus = delta_softplus
-        ctx.out_proj_bias_is_None = out_proj_bias is None
-        ctx.checkpoint_lvl = checkpoint_lvl
-        ctx.b_rms_weight = b_rms_weight
-        ctx.c_rms_weight = c_rms_weight
-        ctx.dt_rms_weight = dt_rms_weight
-        ctx.b_c_dt_rms_eps = b_c_dt_rms_eps
-        if checkpoint_lvl >= 1:  # Will recompute conv1d_out and delta in the backward pass
-            conv1d_out, delta = None, None
-        ctx.save_for_backward(
-            xz,
-            conv1d_weight,
-            conv1d_bias,
-            x_dbl,
-            x_proj_weight,
-            delta_proj_weight,
-            out_proj_weight,
-            conv1d_out,
-            delta,
-            A,
-            B,
-            C,
-            D,
-            delta_bias,
-            scan_intermediates,
-            b_rms_weight,
-            c_rms_weight,
-            dt_rms_weight,
-            out,
-        )
-        return F.linear(rearrange(out_z, "b d l -> b l d"), out_proj_weight, out_proj_bias)
-
-    @staticmethod
-    @custom_bwd
-    def backward(ctx, dout):
-        # dout: (batch, seqlen, dim)
-        assert causal_conv1d_cuda is not None, "causal_conv1d_cuda is not available. Please install causal-conv1d."
-        (
-            xz,
-            conv1d_weight,
-            conv1d_bias,
-            x_dbl,
-            x_proj_weight,
-            delta_proj_weight,
-            out_proj_weight,
-            conv1d_out,
-            delta,
-            A,
-            B,
-            C,
-            D,
-            delta_bias,
-            scan_intermediates,
-            b_rms_weight,
-            c_rms_weight,
-            dt_rms_weight,
-            out,
-        ) = ctx.saved_tensors
-        L = xz.shape[-1]
-        delta_rank = delta_proj_weight.shape[1]
-        d_state = A.shape[-1] * (1 if not A.is_complex() else 2)
-        x, z = xz.chunk(2, dim=1)
-        if dout.stride(-1) != 1:
-            dout = dout.contiguous()
-        if ctx.checkpoint_lvl == 1:
-            conv1d_out = causal_conv1d_cuda.causal_conv1d_fwd(x, conv1d_weight, conv1d_bias, None, None, None, True)
-            delta = rearrange(delta_proj_weight @ x_dbl[:, :delta_rank].t(), "d (b l) -> b d l", l=L)
-            if dt_rms_weight is not None:
-                delta = rearrange(delta, "b d l -> (b l) d", l=L).contiguous()
-                delta = rms_norm_forward(delta, ctx.dt_rms_weight, None, ctx.b_c_dt_rms_eps)
-                delta = rearrange(delta, "(b l) d -> b d l", l=L).contiguous()
-            if b_rms_weight is not None:
-                # Recompute & RMSNorm B
-                B = rearrange(B, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
-                B = rms_norm_forward(B, ctx.b_rms_weight, None, ctx.b_c_dt_rms_eps)
-                B = rearrange(B, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
-            if c_rms_weight is not None:
-                # Recompute & RMSNorm C
-                C = rearrange(C, "b 1 dstate l -> (b l) dstate", l=L).contiguous()
-                C = rms_norm_forward(C, ctx.c_rms_weight, None, ctx.b_c_dt_rms_eps)
-                C = rearrange(C, "(b l) dstate -> b 1 dstate l", l=L).contiguous()
-
-        # The kernel supports passing in a pre-allocated dz (e.g., in case we want to fuse the
-        # backward of selective_scan_cuda with the backward of chunk).
-        dxz = torch.empty_like(xz)  # (batch, dim, seqlen)
-        dx, dz = dxz.chunk(2, dim=1)
-        dout = rearrange(dout, "b l e -> e (b l)")
-        dout_y = rearrange(out_proj_weight.t() @ dout, "d (b l) -> b d l", l=L)
-        dconv1d_out, ddelta, dA, dB, dC, dD, ddelta_bias, dz, out_z = selective_scan_cuda.bwd(
-            conv1d_out,
-            delta,
-            A,
-            B,
-            C,
-            D,
-            z,
-            delta_bias,
-            dout_y,
-            scan_intermediates,
-            out,
-            dz,
-            ctx.delta_softplus,
-            True,  # option to recompute out_z
-        )
-        dout_proj_weight = torch.einsum("eB,dB->ed", dout, rearrange(out_z, "b d l -> d (b l)"))
-        dout_proj_bias = dout.sum(dim=(0, 1)) if not ctx.out_proj_bias_is_None else None
-        dD = dD if D is not None else None
-        dx_dbl = torch.empty_like(x_dbl)
-        dB_proj_bias = None
-        if ctx.is_variable_B:
-            if not A.is_complex():
-                dB = rearrange(dB, "b 1 dstate l -> (b l) dstate").contiguous()
-            else:
-                dB = rearrange(dB, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
-            dB_proj_bias = dB.sum(0) if not ctx.B_proj_bias_is_None else None
-            dx_dbl[:, delta_rank : delta_rank + d_state] = dB  # (bl d)
-            dB = None
-        dC_proj_bias = None
-        if ctx.is_variable_C:
-            if not A.is_complex():
-                dC = rearrange(dC, "b 1 dstate l -> (b l) dstate").contiguous()
-            else:
-                dC = rearrange(dC, "b 1 dstate (l two) -> (b l) (dstate two)", two=2).contiguous()
-            dC_proj_bias = dC.sum(0) if not ctx.C_proj_bias_is_None else None
-            dx_dbl[:, -d_state:] = dC  # (bl d)
-            dC = None
-        ddelta = rearrange(ddelta, "b d l -> d (b l)")
-        ddelta_proj_weight = torch.einsum("dB,Br->dr", ddelta, x_dbl[:, :delta_rank])
-        dx_dbl[:, :delta_rank] = torch.einsum("dB,dr->Br", ddelta, delta_proj_weight)
-        dconv1d_out = rearrange(dconv1d_out, "b d l -> d (b l)")
-        dx_proj_weight = torch.einsum("Br,Bd->rd", dx_dbl, rearrange(conv1d_out, "b d l -> (b l) d"))
-        dconv1d_out = torch.addmm(dconv1d_out, x_proj_weight.t(), dx_dbl.t(), out=dconv1d_out)
-        dconv1d_out = rearrange(dconv1d_out, "d (b l) -> b d l", b=x.shape[0], l=x.shape[-1])
-        # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
-        # backward of conv1d with the backward of chunk).
-        dx, dconv1d_weight, dconv1d_bias, *_ = causal_conv1d_cuda.causal_conv1d_bwd(
-            x, conv1d_weight, conv1d_bias, dconv1d_out, None, None, None, dx, False, True
-        )
-        dconv1d_bias = dconv1d_bias if conv1d_bias is not None else None
-        dconv1d_weight = rearrange(dconv1d_weight, "d w -> d 1 w")
-        return (
-            dxz,
-            dconv1d_weight,
-            dconv1d_bias,
-            dx_proj_weight,
-            ddelta_proj_weight,
-            dout_proj_weight,
-            dout_proj_bias,
-            dA,
-            dB,
-            dC,
-            dD,
-            ddelta_bias if delta_bias is not None else None,
-            # 6-None are delta_softplus, checkpoint_lvl, b_rms_weight, c_rms_weight, dt_rms_weight, b_c_dt_rms_eps
-            dB_proj_bias,
-            dC_proj_bias,
-            None,
-            None,
-            None,
-            None,
-            None,
-            None,
-        )
-
-
-def mamba_inner_fn(
-    xz,
-    conv1d_weight,
-    conv1d_bias,
-    x_proj_weight,
-    delta_proj_weight,
-    out_proj_weight,
-    out_proj_bias,
-    A,
-    B=None,
-    C=None,
-    D=None,
-    delta_bias=None,
-    B_proj_bias=None,
-    C_proj_bias=None,
-    delta_softplus=True,
-    checkpoint_lvl=1,
-    b_rms_weight=None,
-    c_rms_weight=None,
-    dt_rms_weight=None,
-    b_c_dt_rms_eps=1e-6,
-):
-    return MambaInnerFn.apply(
-        xz,
-        conv1d_weight,
-        conv1d_bias,
-        x_proj_weight,
-        delta_proj_weight,
-        out_proj_weight,
-        out_proj_bias,
-        A,
-        B,
-        C,
-        D,
-        delta_bias,
-        B_proj_bias,
-        C_proj_bias,
-        delta_softplus,
-        checkpoint_lvl,
-        b_rms_weight,
-        c_rms_weight,
-        dt_rms_weight,
-        b_c_dt_rms_eps,
-    )
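For orientation, the recurrence at the heart of the deleted `selective_scan_ref` is compact. Below is a sketch distilled from the code above, restricted to real-valued `A` with time-invariant `B` and `C`; it is a readability aid, not a substitute for the fused `selective_scan_cuda` path:

```python
import torch

def selective_scan_minimal(u, delta, A, B, C):
    # Shapes follow the deleted docstring: u, delta are (batch, dim, seqlen);
    # A, B, C are (dim, dstate). Discretize A and the input term, then run the
    # sequential update x_i = exp(delta_i * A) * x_{i-1} + delta_i * B * u_i
    # with readout y_i = <x_i, C>.
    batch, dim, seqlen = u.shape
    x = u.new_zeros(batch, dim, A.shape[1])                     # hidden state (b, d, n)
    deltaA = torch.exp(torch.einsum("bdl,dn->bdln", delta, A))  # discretized A
    deltaB_u = torch.einsum("bdl,dn,bdl->bdln", delta, B, u)    # discretized input term
    ys = []
    for i in range(seqlen):
        x = deltaA[:, :, i] * x + deltaB_u[:, :, i]             # state update
        ys.append(torch.einsum("bdn,dn->bd", x, C))             # readout
    return torch.stack(ys, dim=2)                               # (batch, dim, seqlen)
```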
--- transformers/models/roformer/tokenization_roformer_fast.py
+++ /dev/null
@@ -1,160 +0,0 @@
-# coding=utf-8
-# Copyright 2021 The HuggingFace Inc. team. All rights reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-"""Tokenization classes for RoFormer."""
-
-import json
-from typing import Optional
-
-from tokenizers import normalizers
-from tokenizers.pre_tokenizers import BertPreTokenizer, PreTokenizer
-
-from ...tokenization_utils_tokenizers import PreTrainedTokenizerFast
-from ...utils import logging
-from .tokenization_roformer import RoFormerTokenizer
-from .tokenization_utils import JiebaPreTokenizer
-
-
-logger = logging.get_logger(__name__)
-
-VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt", "tokenizer_file": "tokenizer.json"}
-
-
-class RoFormerTokenizerFast(PreTrainedTokenizerFast):
-    r"""
-    Construct a "fast" RoFormer tokenizer (backed by HuggingFace's *tokenizers* library).
-
-    [`RoFormerTokenizerFast`] is almost identical to [`BertTokenizerFast`] and runs end-to-end tokenization:
-    punctuation splitting and wordpiece. There are some difference between them when tokenizing Chinese.
-
-    This tokenizer inherits from [`PreTrainedTokenizerFast`] which contains most of the main methods. Users should
-    refer to this superclass for more information regarding those methods.
-
-    Example:
-
-    ```python
-    >>> from transformers import RoFormerTokenizerFast
-
-    >>> tokenizer = RoFormerTokenizerFast.from_pretrained("junnyu/roformer_chinese_base")
-    >>> tokenizer.tokenize("今天天气非常好。")
-    ['今', '天', '天', '气', '非常', '好', '。']
-    ```"""
-
-    vocab_files_names = VOCAB_FILES_NAMES
-    slow_tokenizer_class = RoFormerTokenizer
-
-    def __init__(
-        self,
-        vocab_file=None,
-        tokenizer_file=None,
-        do_lower_case=True,
-        unk_token="[UNK]",
-        sep_token="[SEP]",
-        pad_token="[PAD]",
-        cls_token="[CLS]",
-        mask_token="[MASK]",
-        tokenize_chinese_chars=True,
-        strip_accents=None,
-        **kwargs,
-    ):
-        super().__init__(
-            vocab_file,
-            tokenizer_file=tokenizer_file,
-            do_lower_case=do_lower_case,
-            unk_token=unk_token,
-            sep_token=sep_token,
-            pad_token=pad_token,
-            cls_token=cls_token,
-            mask_token=mask_token,
-            tokenize_chinese_chars=tokenize_chinese_chars,
-            strip_accents=strip_accents,
-            **kwargs,
-        )
-
-        normalizer_state = json.loads(self.backend_tokenizer.normalizer.__getstate__())
-        normalizer_class = getattr(normalizers, normalizer_state.pop("type"))
-        normalizer_state["lowercase"] = do_lower_case
-        normalizer_state["strip_accents"] = strip_accents
-        self.backend_tokenizer.normalizer = normalizer_class(**normalizer_state)
-
-        vocab = self.backend_tokenizer.get_vocab()
-        self.backend_tokenizer.pre_tokenizer = PreTokenizer.custom(JiebaPreTokenizer(vocab))
-
-        self.do_lower_case = do_lower_case
-        self.strip_accents = strip_accents
-
-    def _post_init(self):
-        super()._post_init()
-        normalizer_state = json.loads(self.backend_tokenizer.normalizer.__getstate__())
-        normalizer_class = getattr(normalizers, normalizer_state.pop("type"))
-        normalizer_state["lowercase"] = self.do_lower_case
-        normalizer_state["strip_accents"] = getattr(self, "strip_accents", None)
-        self.backend_tokenizer.normalizer = normalizer_class(**normalizer_state)
-        vocab = self.backend_tokenizer.get_vocab()
-        self.backend_tokenizer.pre_tokenizer = PreTokenizer.custom(JiebaPreTokenizer(vocab))
-
-    def __getstate__(self):
-        state = self.__dict__.copy()
-        state["_tokenizer"].pre_tokenizer = BertPreTokenizer()
-        return state
-
-    def __setstate__(self, d):
-        self.__dict__ = d
-        vocab = self.__dict__["_tokenizer"].get_vocab()
-        self.__dict__["_tokenizer"].pre_tokenizer = PreTokenizer.custom(JiebaPreTokenizer(vocab))
-
-    def build_inputs_with_special_tokens(self, token_ids_0, token_ids_1=None):
-        """
-        Build model inputs from a sequence or a pair of sequence for sequence classification tasks by concatenating and
-        adding special tokens. A RoFormer sequence has the following format:
-
-        - single sequence: `[CLS] X [SEP]`
-        - pair of sequences: `[CLS] A [SEP] B [SEP]`
-
-        Args:
-            token_ids_0 (`List[int]`):
-                List of IDs to which the special tokens will be added.
-            token_ids_1 (`List[int]`, *optional*):
-                Optional second list of IDs for sequence pairs.
-
-        Returns:
-            `List[int]`: List of [input IDs](../glossary#input-ids) with the appropriate special tokens.
-        """
-        output = [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
-
-        if token_ids_1 is not None:
-            output += token_ids_1 + [self.sep_token_id]
-
-        return output
-
-    def save_vocabulary(self, save_directory: str, filename_prefix: Optional[str] = None) -> tuple[str]:
-        files = self._tokenizer.model.save(save_directory, name=filename_prefix)
-        return tuple(files)
-
-    def save_pretrained(
-        self,
-        save_directory,
-        legacy_format=None,
-        filename_prefix=None,
-        push_to_hub=False,
-        **kwargs,
-    ):
-        self.backend_tokenizer.pre_tokenizer = BertPreTokenizer()
-        result = super().save_pretrained(save_directory, legacy_format, filename_prefix, push_to_hub, **kwargs)
-        vocab = self.backend_tokenizer.get_vocab()
-        self.backend_tokenizer.pre_tokenizer = PreTokenizer.custom(JiebaPreTokenizer(vocab))
-        return result
-
-
-__all__ = ["RoFormerTokenizerFast"]
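One pattern in the deleted class is worth calling out: `PreTokenizer.custom(...)` wraps a Python object that the `tokenizers` library cannot serialize, which is why the class swaps in a plain `BertPreTokenizer` before pickling (`__getstate__`) or saving (`save_pretrained`) and reinstalls the Jieba pre-tokenizer afterwards. A minimal sketch of that swap as a context manager, assuming a custom object that exposes the `pre_tokenize(self, pretok)` hook:

```python
from contextlib import contextmanager

from tokenizers import Tokenizer
from tokenizers.pre_tokenizers import BertPreTokenizer, PreTokenizer


@contextmanager
def serializable_pre_tokenizer(tokenizer: Tokenizer, custom_obj):
    # Swap in a serializable placeholder so saving/pickling can succeed...
    tokenizer.pre_tokenizer = BertPreTokenizer()
    try:
        yield tokenizer
    finally:
        # ...then reinstall the non-serializable custom pre-tokenizer.
        tokenizer.pre_tokenizer = PreTokenizer.custom(custom_obj)


# Usage (my_jieba is hypothetical, standing in for JiebaPreTokenizer(vocab)):
#     with serializable_pre_tokenizer(tok, my_jieba):
#         tok.save("tokenizer.json")
```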