transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (835)
  1. transformers/__init__.py +49 -3
  2. transformers/activations.py +1 -1
  3. transformers/audio_utils.py +0 -1
  4. transformers/cache_utils.py +17 -15
  5. transformers/cli/serve.py +47 -17
  6. transformers/configuration_utils.py +114 -70
  7. transformers/conversion_mapping.py +83 -7
  8. transformers/convert_slow_tokenizer.py +225 -10
  9. transformers/core_model_loading.py +374 -147
  10. transformers/data/data_collator.py +12 -4
  11. transformers/dependency_versions_table.py +2 -3
  12. transformers/dynamic_module_utils.py +1 -2
  13. transformers/feature_extraction_utils.py +55 -24
  14. transformers/file_utils.py +0 -1
  15. transformers/generation/__init__.py +11 -1
  16. transformers/generation/candidate_generator.py +79 -31
  17. transformers/generation/configuration_utils.py +165 -124
  18. transformers/generation/continuous_batching/__init__.py +4 -0
  19. transformers/generation/continuous_batching/cache.py +47 -18
  20. transformers/generation/continuous_batching/cache_manager.py +131 -34
  21. transformers/generation/continuous_batching/continuous_api.py +228 -136
  22. transformers/generation/continuous_batching/requests.py +28 -1
  23. transformers/generation/continuous_batching/scheduler.py +11 -4
  24. transformers/generation/stopping_criteria.py +1 -1
  25. transformers/generation/utils.py +108 -110
  26. transformers/generation/watermarking.py +8 -5
  27. transformers/image_processing_base.py +3 -14
  28. transformers/image_processing_utils_fast.py +15 -4
  29. transformers/initialization.py +37 -0
  30. transformers/integrations/__init__.py +16 -2
  31. transformers/integrations/accelerate.py +58 -113
  32. transformers/integrations/aqlm.py +36 -66
  33. transformers/integrations/awq.py +46 -515
  34. transformers/integrations/bitnet.py +47 -105
  35. transformers/integrations/bitsandbytes.py +91 -202
  36. transformers/integrations/deepspeed.py +18 -2
  37. transformers/integrations/eetq.py +84 -81
  38. transformers/integrations/fbgemm_fp8.py +191 -145
  39. transformers/integrations/finegrained_fp8.py +241 -208
  40. transformers/integrations/flash_attention.py +2 -2
  41. transformers/integrations/fp_quant.py +92 -0
  42. transformers/integrations/ggml.py +11 -1
  43. transformers/integrations/higgs.py +37 -62
  44. transformers/integrations/hub_kernels.py +65 -8
  45. transformers/integrations/integration_utils.py +45 -0
  46. transformers/integrations/mistral.py +12 -0
  47. transformers/integrations/moe.py +240 -0
  48. transformers/integrations/mxfp4.py +28 -74
  49. transformers/integrations/peft.py +12 -29
  50. transformers/integrations/quanto.py +77 -56
  51. transformers/integrations/quark.py +55 -0
  52. transformers/integrations/spqr.py +42 -90
  53. transformers/integrations/tensor_parallel.py +167 -221
  54. transformers/integrations/torchao.py +32 -38
  55. transformers/integrations/vptq.py +40 -59
  56. transformers/modelcard.py +1 -2
  57. transformers/modeling_gguf_pytorch_utils.py +74 -19
  58. transformers/modeling_rope_utils.py +107 -86
  59. transformers/modeling_utils.py +611 -527
  60. transformers/models/__init__.py +22 -0
  61. transformers/models/afmoe/modeling_afmoe.py +10 -19
  62. transformers/models/afmoe/modular_afmoe.py +5 -13
  63. transformers/models/aimv2/modeling_aimv2.py +4 -0
  64. transformers/models/aimv2/modular_aimv2.py +4 -0
  65. transformers/models/albert/modeling_albert.py +3 -0
  66. transformers/models/albert/tokenization_albert.py +6 -12
  67. transformers/models/align/modeling_align.py +14 -6
  68. transformers/models/altclip/modeling_altclip.py +11 -3
  69. transformers/models/apertus/modeling_apertus.py +8 -6
  70. transformers/models/apertus/modular_apertus.py +4 -1
  71. transformers/models/arcee/modeling_arcee.py +5 -5
  72. transformers/models/aria/modeling_aria.py +12 -8
  73. transformers/models/aria/modular_aria.py +7 -3
  74. transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
  75. transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
  76. transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
  77. transformers/models/auto/auto_factory.py +1 -1
  78. transformers/models/auto/configuration_auto.py +38 -0
  79. transformers/models/auto/feature_extraction_auto.py +9 -3
  80. transformers/models/auto/image_processing_auto.py +5 -2
  81. transformers/models/auto/modeling_auto.py +37 -0
  82. transformers/models/auto/processing_auto.py +22 -10
  83. transformers/models/auto/tokenization_auto.py +147 -566
  84. transformers/models/auto/video_processing_auto.py +5 -2
  85. transformers/models/autoformer/modeling_autoformer.py +4 -0
  86. transformers/models/aya_vision/modeling_aya_vision.py +7 -3
  87. transformers/models/bamba/modeling_bamba.py +21 -21
  88. transformers/models/bamba/modular_bamba.py +17 -16
  89. transformers/models/bark/modeling_bark.py +11 -0
  90. transformers/models/bart/configuration_bart.py +0 -1
  91. transformers/models/bart/modeling_bart.py +14 -0
  92. transformers/models/barthez/tokenization_barthez.py +5 -10
  93. transformers/models/beit/image_processing_beit_fast.py +0 -1
  94. transformers/models/beit/modeling_beit.py +6 -1
  95. transformers/models/bert/modeling_bert.py +3 -0
  96. transformers/models/bert/tokenization_bert.py +8 -21
  97. transformers/models/bert_generation/modeling_bert_generation.py +2 -0
  98. transformers/models/big_bird/modeling_big_bird.py +9 -0
  99. transformers/models/big_bird/tokenization_big_bird.py +18 -42
  100. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +15 -2
  101. transformers/models/biogpt/modeling_biogpt.py +2 -0
  102. transformers/models/biogpt/modular_biogpt.py +2 -0
  103. transformers/models/bit/modeling_bit.py +16 -3
  104. transformers/models/bitnet/modeling_bitnet.py +5 -5
  105. transformers/models/blenderbot/modeling_blenderbot.py +12 -0
  106. transformers/models/blenderbot/tokenization_blenderbot.py +18 -23
  107. transformers/models/blenderbot_small/modeling_blenderbot_small.py +12 -0
  108. transformers/models/blip/modeling_blip.py +2 -0
  109. transformers/models/blip/modeling_blip_text.py +10 -0
  110. transformers/models/blip_2/modeling_blip_2.py +4 -1
  111. transformers/models/bloom/modeling_bloom.py +17 -44
  112. transformers/models/blt/modeling_blt.py +164 -4
  113. transformers/models/blt/modular_blt.py +170 -5
  114. transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
  115. transformers/models/bridgetower/modeling_bridgetower.py +11 -1
  116. transformers/models/bros/modeling_bros.py +12 -0
  117. transformers/models/camembert/modeling_camembert.py +109 -106
  118. transformers/models/camembert/tokenization_camembert.py +8 -12
  119. transformers/models/canine/modeling_canine.py +11 -0
  120. transformers/models/canine/tokenization_canine.py +2 -0
  121. transformers/models/chameleon/modeling_chameleon.py +11 -5
  122. transformers/models/chinese_clip/modeling_chinese_clip.py +9 -3
  123. transformers/models/clap/feature_extraction_clap.py +2 -2
  124. transformers/models/clap/modeling_clap.py +30 -15
  125. transformers/models/clip/modeling_clip.py +2 -0
  126. transformers/models/clip/tokenization_clip.py +22 -44
  127. transformers/models/clipseg/modeling_clipseg.py +9 -0
  128. transformers/models/clvp/modeling_clvp.py +19 -3
  129. transformers/models/clvp/tokenization_clvp.py +1 -63
  130. transformers/models/code_llama/tokenization_code_llama.py +20 -43
  131. transformers/models/codegen/modeling_codegen.py +13 -4
  132. transformers/models/codegen/tokenization_codegen.py +14 -43
  133. transformers/models/cohere/modeling_cohere.py +5 -4
  134. transformers/models/cohere/modular_cohere.py +2 -1
  135. transformers/models/cohere/tokenization_cohere.py +12 -42
  136. transformers/models/cohere2/modeling_cohere2.py +8 -7
  137. transformers/models/cohere2/modular_cohere2.py +5 -5
  138. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -4
  139. transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
  140. transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
  141. transformers/models/colqwen2/modeling_colqwen2.py +1 -0
  142. transformers/models/colqwen2/modular_colqwen2.py +1 -0
  143. transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
  144. transformers/models/conditional_detr/modeling_conditional_detr.py +9 -1
  145. transformers/models/convbert/modeling_convbert.py +9 -0
  146. transformers/models/convnext/image_processing_convnext.py +2 -2
  147. transformers/models/convnext/image_processing_convnext_fast.py +9 -13
  148. transformers/models/convnext/modeling_convnext.py +2 -4
  149. transformers/models/convnextv2/modeling_convnextv2.py +2 -4
  150. transformers/models/csm/generation_csm.py +19 -22
  151. transformers/models/csm/modeling_csm.py +7 -4
  152. transformers/models/csm/modular_csm.py +2 -0
  153. transformers/models/ctrl/modeling_ctrl.py +15 -2
  154. transformers/models/cvt/modeling_cvt.py +7 -1
  155. transformers/models/cwm/modeling_cwm.py +5 -5
  156. transformers/models/d_fine/configuration_d_fine.py +3 -4
  157. transformers/models/d_fine/modeling_d_fine.py +48 -39
  158. transformers/models/d_fine/modular_d_fine.py +16 -4
  159. transformers/models/dab_detr/configuration_dab_detr.py +2 -2
  160. transformers/models/dab_detr/modeling_dab_detr.py +5 -1
  161. transformers/models/dac/modeling_dac.py +6 -6
  162. transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
  163. transformers/models/data2vec/modeling_data2vec_text.py +7 -0
  164. transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
  165. transformers/models/data2vec/modular_data2vec_text.py +7 -0
  166. transformers/models/dbrx/configuration_dbrx.py +9 -1
  167. transformers/models/dbrx/modeling_dbrx.py +3 -3
  168. transformers/models/deberta/modeling_deberta.py +7 -0
  169. transformers/models/deberta/tokenization_deberta.py +11 -20
  170. transformers/models/deberta_v2/modeling_deberta_v2.py +8 -0
  171. transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
  172. transformers/models/decision_transformer/modeling_decision_transformer.py +12 -6
  173. transformers/models/deepseek_v2/modeling_deepseek_v2.py +9 -7
  174. transformers/models/deepseek_v2/modular_deepseek_v2.py +6 -4
  175. transformers/models/deepseek_v3/modeling_deepseek_v3.py +12 -7
  176. transformers/models/deepseek_v3/modular_deepseek_v3.py +7 -2
  177. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
  178. transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
  179. transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
  180. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
  181. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
  182. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
  183. transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
  184. transformers/models/deformable_detr/modeling_deformable_detr.py +5 -1
  185. transformers/models/depth_anything/configuration_depth_anything.py +2 -3
  186. transformers/models/depth_anything/modeling_depth_anything.py +1 -0
  187. transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
  188. transformers/models/depth_pro/modeling_depth_pro.py +2 -0
  189. transformers/models/detr/configuration_detr.py +1 -1
  190. transformers/models/detr/modeling_detr.py +13 -1
  191. transformers/models/dia/generation_dia.py +3 -10
  192. transformers/models/dia/modeling_dia.py +16 -4
  193. transformers/models/dia/modular_dia.py +11 -1
  194. transformers/models/dia/processing_dia.py +1 -1
  195. transformers/models/diffllama/modeling_diffllama.py +5 -5
  196. transformers/models/diffllama/modular_diffllama.py +2 -2
  197. transformers/models/dinat/modeling_dinat.py +3 -0
  198. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
  199. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
  200. transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -2
  201. transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -2
  202. transformers/models/distilbert/modeling_distilbert.py +11 -9
  203. transformers/models/distilbert/tokenization_distilbert.py +13 -0
  204. transformers/models/doge/modeling_doge.py +3 -4
  205. transformers/models/doge/modular_doge.py +0 -1
  206. transformers/models/donut/image_processing_donut_fast.py +0 -1
  207. transformers/models/donut/modeling_donut_swin.py +18 -12
  208. transformers/models/dots1/modeling_dots1.py +23 -11
  209. transformers/models/dots1/modular_dots1.py +5 -3
  210. transformers/models/dpr/modeling_dpr.py +5 -0
  211. transformers/models/dpr/tokenization_dpr.py +12 -0
  212. transformers/models/dpt/configuration_dpt.py +1 -1
  213. transformers/models/dpt/image_processing_dpt_fast.py +1 -2
  214. transformers/models/dpt/modular_dpt.py +1 -2
  215. transformers/models/edgetam/configuration_edgetam.py +1 -1
  216. transformers/models/edgetam/modeling_edgetam.py +6 -3
  217. transformers/models/edgetam/modular_edgetam.py +15 -14
  218. transformers/models/edgetam_video/modeling_edgetam_video.py +56 -43
  219. transformers/models/edgetam_video/modular_edgetam_video.py +14 -19
  220. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
  221. transformers/models/efficientloftr/modeling_efficientloftr.py +16 -3
  222. transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
  223. transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
  224. transformers/models/efficientnet/modeling_efficientnet.py +7 -1
  225. transformers/models/electra/modeling_electra.py +7 -0
  226. transformers/models/emu3/modeling_emu3.py +12 -6
  227. transformers/models/emu3/modular_emu3.py +7 -1
  228. transformers/models/encodec/modeling_encodec.py +14 -0
  229. transformers/models/eomt/image_processing_eomt.py +13 -1
  230. transformers/models/eomt/image_processing_eomt_fast.py +60 -16
  231. transformers/models/eomt/modeling_eomt.py +7 -0
  232. transformers/models/eomt/modular_eomt.py +7 -0
  233. transformers/models/ernie/modeling_ernie.py +6 -0
  234. transformers/models/ernie/modular_ernie.py +6 -0
  235. transformers/models/ernie4_5/modeling_ernie4_5.py +5 -5
  236. transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
  237. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +20 -17
  238. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +11 -37
  239. transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
  240. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
  241. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
  242. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
  243. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
  244. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
  245. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
  246. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
  247. transformers/models/esm/modeling_esm.py +6 -0
  248. transformers/models/esm/modeling_esmfold.py +11 -5
  249. transformers/models/evolla/modeling_evolla.py +13 -5
  250. transformers/models/evolla/modular_evolla.py +8 -0
  251. transformers/models/exaone4/modeling_exaone4.py +3 -3
  252. transformers/models/exaone4/modular_exaone4.py +0 -1
  253. transformers/models/falcon/modeling_falcon.py +9 -4
  254. transformers/models/falcon_h1/modeling_falcon_h1.py +32 -26
  255. transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
  256. transformers/models/falcon_mamba/modeling_falcon_mamba.py +31 -37
  257. transformers/models/falcon_mamba/modular_falcon_mamba.py +19 -33
  258. transformers/models/fast_vlm/__init__.py +27 -0
  259. transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
  260. transformers/models/fast_vlm/modeling_fast_vlm.py +459 -0
  261. transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
  262. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +31 -13
  263. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
  264. transformers/models/flaubert/modeling_flaubert.py +21 -15
  265. transformers/models/flava/image_processing_flava_fast.py +0 -2
  266. transformers/models/flava/modeling_flava.py +10 -2
  267. transformers/models/flex_olmo/modeling_flex_olmo.py +10 -8
  268. transformers/models/florence2/modeling_florence2.py +22 -4
  269. transformers/models/florence2/modular_florence2.py +15 -1
  270. transformers/models/fnet/modeling_fnet.py +14 -0
  271. transformers/models/focalnet/modeling_focalnet.py +4 -0
  272. transformers/models/fsmt/modeling_fsmt.py +2 -0
  273. transformers/models/funnel/modeling_funnel.py +8 -0
  274. transformers/models/funnel/tokenization_funnel.py +17 -24
  275. transformers/models/fuyu/image_processing_fuyu.py +1 -1
  276. transformers/models/fuyu/modeling_fuyu.py +3 -1
  277. transformers/models/fuyu/processing_fuyu.py +19 -3
  278. transformers/models/gemma/modeling_gemma.py +14 -16
  279. transformers/models/gemma/modular_gemma.py +9 -11
  280. transformers/models/gemma/tokenization_gemma.py +10 -27
  281. transformers/models/gemma2/modeling_gemma2.py +5 -5
  282. transformers/models/gemma2/modular_gemma2.py +3 -2
  283. transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
  284. transformers/models/gemma3/modeling_gemma3.py +42 -91
  285. transformers/models/gemma3/modular_gemma3.py +38 -87
  286. transformers/models/gemma3n/configuration_gemma3n.py +3 -0
  287. transformers/models/gemma3n/modeling_gemma3n.py +65 -218
  288. transformers/models/gemma3n/modular_gemma3n.py +68 -68
  289. transformers/models/git/modeling_git.py +183 -126
  290. transformers/models/glm/modeling_glm.py +5 -5
  291. transformers/models/glm4/modeling_glm4.py +5 -5
  292. transformers/models/glm46v/image_processing_glm46v.py +0 -4
  293. transformers/models/glm46v/modeling_glm46v.py +3 -1
  294. transformers/models/glm46v/modular_glm46v.py +3 -0
  295. transformers/models/glm4_moe/modeling_glm4_moe.py +13 -7
  296. transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
  297. transformers/models/glm4v/configuration_glm4v.py +3 -1
  298. transformers/models/glm4v/image_processing_glm4v.py +0 -4
  299. transformers/models/glm4v/modeling_glm4v.py +18 -8
  300. transformers/models/glm4v/modular_glm4v.py +17 -7
  301. transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
  302. transformers/models/glm4v_moe/modeling_glm4v_moe.py +44 -27
  303. transformers/models/glm4v_moe/modular_glm4v_moe.py +13 -1
  304. transformers/models/glmasr/__init__.py +30 -0
  305. transformers/models/glmasr/configuration_glmasr.py +197 -0
  306. transformers/models/glmasr/modeling_glmasr.py +512 -0
  307. transformers/models/glmasr/modular_glmasr.py +433 -0
  308. transformers/models/glmasr/processing_glmasr.py +332 -0
  309. transformers/models/glpn/image_processing_glpn_fast.py +0 -1
  310. transformers/models/glpn/modeling_glpn.py +2 -0
  311. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
  312. transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
  313. transformers/models/gpt2/modeling_gpt2.py +13 -6
  314. transformers/models/gpt2/tokenization_gpt2.py +16 -44
  315. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +4 -8
  316. transformers/models/gpt_neo/modeling_gpt_neo.py +19 -3
  317. transformers/models/gpt_neox/modeling_gpt_neox.py +6 -3
  318. transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
  319. transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
  320. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +4 -2
  321. transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
  322. transformers/models/gpt_oss/modeling_gpt_oss.py +10 -14
  323. transformers/models/gpt_oss/modular_gpt_oss.py +8 -12
  324. transformers/models/gptj/modeling_gptj.py +18 -6
  325. transformers/models/granite/modeling_granite.py +5 -5
  326. transformers/models/granite_speech/modeling_granite_speech.py +15 -1
  327. transformers/models/granitemoe/modeling_granitemoe.py +6 -9
  328. transformers/models/granitemoe/modular_granitemoe.py +1 -4
  329. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
  330. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +36 -28
  331. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
  332. transformers/models/granitemoeshared/modeling_granitemoeshared.py +6 -9
  333. transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
  334. transformers/models/grounding_dino/modeling_grounding_dino.py +8 -4
  335. transformers/models/groupvit/modeling_groupvit.py +9 -1
  336. transformers/models/helium/modeling_helium.py +5 -4
  337. transformers/models/herbert/tokenization_herbert.py +9 -25
  338. transformers/models/hgnet_v2/modeling_hgnet_v2.py +16 -1
  339. transformers/models/hgnet_v2/modular_hgnet_v2.py +16 -1
  340. transformers/models/hiera/modeling_hiera.py +4 -0
  341. transformers/models/hubert/modeling_hubert.py +7 -0
  342. transformers/models/hubert/modular_hubert.py +5 -0
  343. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +5 -5
  344. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
  345. transformers/models/hunyuan_v1_moe/__init__.py +1 -1
  346. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +15 -7
  347. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
  348. transformers/models/ibert/modeling_ibert.py +22 -0
  349. transformers/models/idefics/modeling_idefics.py +15 -21
  350. transformers/models/idefics2/modeling_idefics2.py +7 -1
  351. transformers/models/idefics3/modeling_idefics3.py +5 -1
  352. transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
  353. transformers/models/imagegpt/modeling_imagegpt.py +11 -3
  354. transformers/models/informer/modeling_informer.py +4 -0
  355. transformers/models/informer/modular_informer.py +1 -0
  356. transformers/models/instructblip/modeling_instructblip.py +2 -0
  357. transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
  358. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
  359. transformers/models/internvl/modeling_internvl.py +13 -12
  360. transformers/models/internvl/modular_internvl.py +7 -13
  361. transformers/models/internvl/video_processing_internvl.py +0 -1
  362. transformers/models/jais2/__init__.py +27 -0
  363. transformers/models/jais2/configuration_jais2.py +152 -0
  364. transformers/models/jais2/modeling_jais2.py +486 -0
  365. transformers/models/jais2/modular_jais2.py +196 -0
  366. transformers/models/jamba/modeling_jamba.py +25 -20
  367. transformers/models/jamba/modular_jamba.py +17 -17
  368. transformers/models/janus/image_processing_janus_fast.py +0 -1
  369. transformers/models/janus/modeling_janus.py +16 -7
  370. transformers/models/janus/modular_janus.py +17 -7
  371. transformers/models/jetmoe/modeling_jetmoe.py +4 -4
  372. transformers/models/jetmoe/modular_jetmoe.py +1 -0
  373. transformers/models/kosmos2/modeling_kosmos2.py +15 -2
  374. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
  375. transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
  376. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +12 -4
  377. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
  378. transformers/models/lasr/__init__.py +29 -0
  379. transformers/models/lasr/configuration_lasr.py +248 -0
  380. transformers/models/lasr/feature_extraction_lasr.py +277 -0
  381. transformers/models/lasr/modeling_lasr.py +730 -0
  382. transformers/models/lasr/modular_lasr.py +576 -0
  383. transformers/models/lasr/processing_lasr.py +94 -0
  384. transformers/models/lasr/tokenization_lasr.py +186 -0
  385. transformers/models/layoutlm/modeling_layoutlm.py +10 -3
  386. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
  387. transformers/models/layoutlmv2/modeling_layoutlmv2.py +16 -0
  388. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +11 -53
  389. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
  390. transformers/models/layoutlmv3/modeling_layoutlmv3.py +33 -5
  391. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
  392. transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
  393. transformers/models/led/modeling_led.py +12 -0
  394. transformers/models/levit/modeling_levit.py +21 -0
  395. transformers/models/lfm2/modeling_lfm2.py +5 -6
  396. transformers/models/lfm2/modular_lfm2.py +0 -1
  397. transformers/models/lfm2_moe/modeling_lfm2_moe.py +17 -8
  398. transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
  399. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  400. transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
  401. transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
  402. transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
  403. transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
  404. transformers/models/lightglue/modeling_lightglue.py +3 -1
  405. transformers/models/lightglue/modular_lightglue.py +1 -0
  406. transformers/models/lilt/modeling_lilt.py +23 -15
  407. transformers/models/llama/modeling_llama.py +5 -5
  408. transformers/models/llama/tokenization_llama.py +15 -43
  409. transformers/models/llama4/image_processing_llama4_fast.py +1 -2
  410. transformers/models/llama4/modeling_llama4.py +11 -6
  411. transformers/models/llava/image_processing_llava_fast.py +0 -1
  412. transformers/models/llava/modeling_llava.py +12 -7
  413. transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
  414. transformers/models/llava_next/modeling_llava_next.py +7 -3
  415. transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
  416. transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
  417. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
  418. transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
  419. transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
  420. transformers/models/longcat_flash/modeling_longcat_flash.py +6 -5
  421. transformers/models/longcat_flash/modular_longcat_flash.py +3 -2
  422. transformers/models/longformer/modeling_longformer.py +6 -0
  423. transformers/models/longt5/modeling_longt5.py +4 -4
  424. transformers/models/luke/modeling_luke.py +9 -0
  425. transformers/models/luke/tokenization_luke.py +11 -38
  426. transformers/models/lxmert/modeling_lxmert.py +2 -0
  427. transformers/models/m2m_100/modeling_m2m_100.py +14 -0
  428. transformers/models/mamba/modeling_mamba.py +16 -23
  429. transformers/models/mamba2/modeling_mamba2.py +24 -23
  430. transformers/models/marian/configuration_marian.py +1 -1
  431. transformers/models/marian/modeling_marian.py +8 -0
  432. transformers/models/markuplm/modeling_markuplm.py +9 -8
  433. transformers/models/markuplm/tokenization_markuplm.py +28 -61
  434. transformers/models/mask2former/configuration_mask2former.py +3 -3
  435. transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
  436. transformers/models/mask2former/modeling_mask2former.py +11 -0
  437. transformers/models/maskformer/configuration_maskformer.py +3 -3
  438. transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
  439. transformers/models/maskformer/modeling_maskformer.py +11 -1
  440. transformers/models/maskformer/modeling_maskformer_swin.py +21 -15
  441. transformers/models/mbart/configuration_mbart.py +1 -0
  442. transformers/models/mbart/modeling_mbart.py +14 -0
  443. transformers/models/mbart/tokenization_mbart.py +11 -52
  444. transformers/models/mbart50/tokenization_mbart50.py +7 -10
  445. transformers/models/megatron_bert/modeling_megatron_bert.py +9 -0
  446. transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
  447. transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
  448. transformers/models/mgp_str/modeling_mgp_str.py +2 -0
  449. transformers/models/mimi/modeling_mimi.py +28 -5
  450. transformers/models/minimax/modeling_minimax.py +19 -6
  451. transformers/models/minimax/modular_minimax.py +12 -1
  452. transformers/models/ministral/modeling_ministral.py +5 -5
  453. transformers/models/ministral3/configuration_ministral3.py +1 -1
  454. transformers/models/ministral3/modeling_ministral3.py +5 -4
  455. transformers/models/mistral/modeling_mistral.py +5 -4
  456. transformers/models/mistral3/modeling_mistral3.py +10 -4
  457. transformers/models/mistral3/modular_mistral3.py +3 -1
  458. transformers/models/mixtral/modeling_mixtral.py +15 -7
  459. transformers/models/mixtral/modular_mixtral.py +6 -2
  460. transformers/models/mlcd/modeling_mlcd.py +6 -0
  461. transformers/models/mlcd/modular_mlcd.py +4 -0
  462. transformers/models/mllama/modeling_mllama.py +15 -4
  463. transformers/models/mluke/tokenization_mluke.py +6 -6
  464. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
  465. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +8 -4
  466. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
  467. transformers/models/mobilebert/modeling_mobilebert.py +2 -0
  468. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
  469. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
  470. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
  471. transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
  472. transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
  473. transformers/models/mobilevit/modeling_mobilevit.py +7 -0
  474. transformers/models/mobilevitv2/modeling_mobilevitv2.py +7 -0
  475. transformers/models/modernbert/modeling_modernbert.py +16 -2
  476. transformers/models/modernbert/modular_modernbert.py +14 -1
  477. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +17 -10
  478. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +15 -8
  479. transformers/models/moonshine/modeling_moonshine.py +5 -3
  480. transformers/models/moshi/modeling_moshi.py +26 -53
  481. transformers/models/mpnet/modeling_mpnet.py +7 -0
  482. transformers/models/mpnet/tokenization_mpnet.py +5 -13
  483. transformers/models/mpt/modeling_mpt.py +2 -0
  484. transformers/models/mra/modeling_mra.py +10 -1
  485. transformers/models/mt5/configuration_mt5.py +2 -3
  486. transformers/models/mt5/modeling_mt5.py +7 -10
  487. transformers/models/musicgen/modeling_musicgen.py +7 -9
  488. transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -0
  489. transformers/models/mvp/modeling_mvp.py +14 -0
  490. transformers/models/nanochat/modeling_nanochat.py +5 -5
  491. transformers/models/nemotron/modeling_nemotron.py +7 -5
  492. transformers/models/nllb/tokenization_nllb.py +8 -22
  493. transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
  494. transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
  495. transformers/models/nougat/image_processing_nougat_fast.py +0 -1
  496. transformers/models/nougat/tokenization_nougat.py +15 -68
  497. transformers/models/nystromformer/modeling_nystromformer.py +13 -0
  498. transformers/models/olmo/modeling_olmo.py +5 -5
  499. transformers/models/olmo/modular_olmo.py +2 -2
  500. transformers/models/olmo2/modeling_olmo2.py +5 -6
  501. transformers/models/olmo2/modular_olmo2.py +0 -1
  502. transformers/models/olmo3/modeling_olmo3.py +5 -5
  503. transformers/models/olmoe/modeling_olmoe.py +15 -7
  504. transformers/models/olmoe/modular_olmoe.py +4 -2
  505. transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
  506. transformers/models/omdet_turbo/modeling_omdet_turbo.py +6 -0
  507. transformers/models/oneformer/configuration_oneformer.py +3 -3
  508. transformers/models/oneformer/modeling_oneformer.py +11 -39
  509. transformers/models/openai/modeling_openai.py +15 -0
  510. transformers/models/openai/tokenization_openai.py +10 -46
  511. transformers/models/opt/modeling_opt.py +2 -0
  512. transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
  513. transformers/models/ovis2/modeling_ovis2.py +15 -3
  514. transformers/models/ovis2/modular_ovis2.py +8 -0
  515. transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
  516. transformers/models/owlv2/modeling_owlv2.py +11 -3
  517. transformers/models/owlv2/modular_owlv2.py +0 -2
  518. transformers/models/owlvit/modeling_owlvit.py +11 -3
  519. transformers/models/paddleocr_vl/__init__.py +32 -0
  520. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
  521. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +504 -0
  522. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
  523. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1682 -0
  524. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1359 -0
  525. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
  526. transformers/models/paligemma/modeling_paligemma.py +25 -17
  527. transformers/models/parakeet/configuration_parakeet.py +4 -6
  528. transformers/models/parakeet/modeling_parakeet.py +14 -6
  529. transformers/models/parakeet/modular_parakeet.py +7 -2
  530. transformers/models/parakeet/processing_parakeet.py +1 -0
  531. transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
  532. transformers/models/patchtsmixer/modeling_patchtsmixer.py +10 -0
  533. transformers/models/patchtst/modeling_patchtst.py +25 -6
  534. transformers/models/pe_audio/__init__.py +30 -0
  535. transformers/models/pe_audio/configuration_pe_audio.py +206 -0
  536. transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
  537. transformers/models/pe_audio/modeling_pe_audio.py +820 -0
  538. transformers/models/pe_audio/modular_pe_audio.py +299 -0
  539. transformers/{kernels/falcon_mamba/__init__.py → models/pe_audio/processing_pe_audio.py} +11 -2
  540. transformers/models/pe_audio_video/__init__.py +29 -0
  541. transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
  542. transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
  543. transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
  544. transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
  545. transformers/models/pe_video/__init__.py +30 -0
  546. transformers/models/pe_video/configuration_pe_video.py +211 -0
  547. transformers/models/pe_video/modeling_pe_video.py +636 -0
  548. transformers/models/pe_video/modular_pe_video.py +219 -0
  549. transformers/models/pe_video/processing_pe_video.py +10 -0
  550. transformers/models/pe_video/video_processing_pe_video.py +66 -0
  551. transformers/models/pegasus/configuration_pegasus.py +1 -0
  552. transformers/models/pegasus/modeling_pegasus.py +8 -0
  553. transformers/models/pegasus/tokenization_pegasus.py +17 -44
  554. transformers/models/pegasus_x/modeling_pegasus_x.py +5 -0
  555. transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
  556. transformers/models/perceiver/modeling_perceiver.py +13 -1
  557. transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
  558. transformers/models/perception_lm/modeling_perception_lm.py +7 -3
  559. transformers/models/perception_lm/modular_perception_lm.py +7 -3
  560. transformers/models/persimmon/modeling_persimmon.py +3 -2
  561. transformers/models/phi/modeling_phi.py +5 -6
  562. transformers/models/phi/modular_phi.py +0 -1
  563. transformers/models/phi3/modeling_phi3.py +3 -2
  564. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +9 -6
  565. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +7 -4
  566. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
  567. transformers/models/phimoe/modeling_phimoe.py +15 -7
  568. transformers/models/phimoe/modular_phimoe.py +3 -3
  569. transformers/models/pix2struct/modeling_pix2struct.py +2 -0
  570. transformers/models/pix2struct/processing_pix2struct.py +0 -4
  571. transformers/models/pixio/__init__.py +30 -0
  572. transformers/models/pixio/configuration_pixio.py +151 -0
  573. transformers/models/pixio/modeling_pixio.py +507 -0
  574. transformers/models/pixio/modular_pixio.py +404 -0
  575. transformers/models/pixtral/modeling_pixtral.py +3 -2
  576. transformers/models/pixtral/processing_pixtral.py +3 -1
  577. transformers/models/plbart/configuration_plbart.py +1 -0
  578. transformers/models/plbart/modeling_plbart.py +13 -0
  579. transformers/models/plbart/modular_plbart.py +8 -0
  580. transformers/models/plbart/tokenization_plbart.py +0 -2
  581. transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
  582. transformers/models/poolformer/modeling_poolformer.py +13 -1
  583. transformers/models/pop2piano/configuration_pop2piano.py +0 -1
  584. transformers/models/pop2piano/modeling_pop2piano.py +2 -0
  585. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
  586. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
  587. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
  588. transformers/models/prophetnet/modeling_prophetnet.py +5 -1
  589. transformers/models/pvt/modeling_pvt.py +2 -0
  590. transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
  591. transformers/models/qwen2/modeling_qwen2.py +5 -5
  592. transformers/models/qwen2/tokenization_qwen2.py +14 -18
  593. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
  594. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +116 -79
  595. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +71 -33
  596. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
  597. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +23 -11
  598. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +29 -27
  599. transformers/models/qwen2_audio/modeling_qwen2_audio.py +4 -2
  600. transformers/models/qwen2_moe/modeling_qwen2_moe.py +15 -7
  601. transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
  602. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
  603. transformers/models/qwen2_vl/modeling_qwen2_vl.py +23 -20
  604. transformers/models/qwen3/modeling_qwen3.py +5 -5
  605. transformers/models/qwen3_moe/modeling_qwen3_moe.py +15 -7
  606. transformers/models/qwen3_next/modeling_qwen3_next.py +7 -8
  607. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
  608. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +112 -68
  609. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +62 -20
  610. transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
  611. transformers/models/qwen3_vl/modeling_qwen3_vl.py +57 -42
  612. transformers/models/qwen3_vl/modular_qwen3_vl.py +59 -46
  613. transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
  614. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +132 -148
  615. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +36 -82
  616. transformers/models/rag/configuration_rag.py +0 -8
  617. transformers/models/rag/modeling_rag.py +8 -9
  618. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +18 -3
  619. transformers/models/reformer/modeling_reformer.py +13 -1
  620. transformers/models/reformer/tokenization_reformer.py +11 -28
  621. transformers/models/regnet/modeling_regnet.py +10 -1
  622. transformers/models/rembert/modeling_rembert.py +13 -1
  623. transformers/models/rembert/tokenization_rembert.py +3 -10
  624. transformers/models/resnet/modeling_resnet.py +19 -5
  625. transformers/models/roberta/modeling_roberta.py +3 -0
  626. transformers/models/roberta/modular_roberta.py +3 -0
  627. transformers/models/roberta/tokenization_roberta.py +18 -27
  628. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
  629. transformers/models/roc_bert/modeling_roc_bert.py +3 -0
  630. transformers/models/roformer/modeling_roformer.py +6 -0
  631. transformers/models/roformer/tokenization_roformer.py +77 -412
  632. transformers/models/rt_detr/configuration_rt_detr.py +1 -1
  633. transformers/models/rt_detr/modeling_rt_detr.py +6 -0
  634. transformers/models/rt_detr/modeling_rt_detr_resnet.py +13 -4
  635. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
  636. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +9 -0
  637. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
  638. transformers/models/rwkv/modeling_rwkv.py +2 -1
  639. transformers/models/sam/configuration_sam.py +1 -0
  640. transformers/models/sam/image_processing_sam_fast.py +0 -1
  641. transformers/models/sam/modeling_sam.py +4 -1
  642. transformers/models/sam2/configuration_sam2.py +1 -1
  643. transformers/models/sam2/modeling_sam2.py +7 -3
  644. transformers/models/sam2/modular_sam2.py +7 -3
  645. transformers/models/sam2_video/modeling_sam2_video.py +52 -43
  646. transformers/models/sam2_video/modular_sam2_video.py +32 -18
  647. transformers/models/sam3/configuration_sam3.py +21 -1
  648. transformers/models/sam3/modeling_sam3.py +100 -80
  649. transformers/models/sam3_tracker/modeling_sam3_tracker.py +8 -1
  650. transformers/models/sam3_tracker/modular_sam3_tracker.py +8 -1
  651. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
  652. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +27 -15
  653. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
  654. transformers/models/sam3_video/configuration_sam3_video.py +14 -0
  655. transformers/models/sam3_video/modeling_sam3_video.py +4 -3
  656. transformers/models/sam3_video/processing_sam3_video.py +1 -1
  657. transformers/models/sam_hq/configuration_sam_hq.py +1 -0
  658. transformers/models/sam_hq/modeling_sam_hq.py +26 -23
  659. transformers/models/seamless_m4t/modeling_seamless_m4t.py +32 -12
  660. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
  661. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +11 -1
  662. transformers/models/seed_oss/modeling_seed_oss.py +3 -3
  663. transformers/models/segformer/image_processing_segformer_fast.py +0 -1
  664. transformers/models/segformer/modeling_segformer.py +6 -3
  665. transformers/models/segformer/modular_segformer.py +0 -1
  666. transformers/models/seggpt/modeling_seggpt.py +2 -0
  667. transformers/models/sew/modeling_sew.py +3 -0
  668. transformers/models/sew/modular_sew.py +1 -0
  669. transformers/models/sew_d/modeling_sew_d.py +3 -0
  670. transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
  671. transformers/models/siglip/modeling_siglip.py +24 -2
  672. transformers/models/siglip2/modeling_siglip2.py +67 -41
  673. transformers/models/siglip2/modular_siglip2.py +4 -0
  674. transformers/models/smollm3/modeling_smollm3.py +5 -5
  675. transformers/models/smolvlm/modeling_smolvlm.py +5 -1
  676. transformers/models/smolvlm/processing_smolvlm.py +0 -7
  677. transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
  678. transformers/models/speech_to_text/modeling_speech_to_text.py +14 -0
  679. transformers/models/speecht5/modeling_speecht5.py +41 -1
  680. transformers/models/splinter/modeling_splinter.py +12 -3
  681. transformers/models/splinter/tokenization_splinter.py +9 -28
  682. transformers/models/squeezebert/modeling_squeezebert.py +8 -0
  683. transformers/models/stablelm/modeling_stablelm.py +4 -2
  684. transformers/models/starcoder2/modeling_starcoder2.py +5 -4
  685. transformers/models/superglue/image_processing_superglue_fast.py +1 -2
  686. transformers/models/superglue/modeling_superglue.py +1 -0
  687. transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
  688. transformers/models/superpoint/modeling_superpoint.py +1 -0
  689. transformers/models/swiftformer/modeling_swiftformer.py +6 -0
  690. transformers/models/swin/modeling_swin.py +20 -12
  691. transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
  692. transformers/models/swin2sr/modeling_swin2sr.py +51 -33
  693. transformers/models/swinv2/modeling_swinv2.py +45 -33
  694. transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
  695. transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
  696. transformers/models/t5/configuration_t5.py +7 -1
  697. transformers/models/t5/modeling_t5.py +8 -7
  698. transformers/models/t5/tokenization_t5.py +4 -8
  699. transformers/models/t5gemma/modeling_t5gemma.py +6 -6
  700. transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
  701. transformers/models/t5gemma2/modeling_t5gemma2.py +19 -10
  702. transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
  703. transformers/models/table_transformer/configuration_table_transformer.py +1 -1
  704. transformers/models/table_transformer/modeling_table_transformer.py +5 -1
  705. transformers/models/tapas/modeling_tapas.py +3 -0
  706. transformers/models/textnet/image_processing_textnet_fast.py +0 -1
  707. transformers/models/textnet/modeling_textnet.py +11 -2
  708. transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
  709. transformers/models/timesfm/modeling_timesfm.py +14 -0
  710. transformers/models/timesfm/modular_timesfm.py +14 -0
  711. transformers/models/timesformer/modeling_timesformer.py +2 -0
  712. transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
  713. transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
  714. transformers/models/timm_wrapper/modeling_timm_wrapper.py +20 -14
  715. transformers/models/trocr/modeling_trocr.py +3 -2
  716. transformers/models/tvp/configuration_tvp.py +5 -1
  717. transformers/models/tvp/modeling_tvp.py +6 -4
  718. transformers/models/udop/configuration_udop.py +1 -0
  719. transformers/models/udop/modeling_udop.py +7 -7
  720. transformers/models/udop/tokenization_udop.py +5 -13
  721. transformers/models/umt5/configuration_umt5.py +2 -2
  722. transformers/models/umt5/modeling_umt5.py +7 -6
  723. transformers/models/unispeech/modeling_unispeech.py +4 -0
  724. transformers/models/unispeech/modular_unispeech.py +2 -0
  725. transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
  726. transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
  727. transformers/models/univnet/modeling_univnet.py +1 -0
  728. transformers/models/upernet/modeling_upernet.py +1 -0
  729. transformers/models/vaultgemma/modeling_vaultgemma.py +5 -5
  730. transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
  731. transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
  732. transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
  733. transformers/models/video_llava/modeling_video_llava.py +7 -3
  734. transformers/models/vilt/configuration_vilt.py +2 -2
  735. transformers/models/vilt/modeling_vilt.py +13 -0
  736. transformers/models/vipllava/modeling_vipllava.py +7 -3
  737. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
  738. transformers/models/visual_bert/modeling_visual_bert.py +8 -0
  739. transformers/models/vitdet/modeling_vitdet.py +2 -0
  740. transformers/models/vitmatte/configuration_vitmatte.py +1 -1
  741. transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
  742. transformers/models/vitmatte/modeling_vitmatte.py +5 -0
  743. transformers/models/vitpose/configuration_vitpose.py +1 -1
  744. transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
  745. transformers/models/vits/modeling_vits.py +1 -0
  746. transformers/models/vjepa2/modeling_vjepa2.py +1 -0
  747. transformers/models/voxtral/modeling_voxtral.py +2 -2
  748. transformers/models/voxtral/modular_voxtral.py +2 -2
  749. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
  750. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +21 -10
  751. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +12 -0
  752. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +27 -11
  753. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
  754. transformers/models/wavlm/modeling_wavlm.py +5 -0
  755. transformers/models/whisper/generation_whisper.py +1 -0
  756. transformers/models/whisper/modeling_whisper.py +11 -3
  757. transformers/models/whisper/tokenization_whisper.py +4 -15
  758. transformers/models/x_clip/modeling_x_clip.py +5 -0
  759. transformers/models/xcodec/modeling_xcodec.py +5 -0
  760. transformers/models/xglm/modeling_xglm.py +11 -0
  761. transformers/models/xglm/tokenization_xglm.py +4 -9
  762. transformers/models/xlm/modeling_xlm.py +18 -14
  763. transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
  764. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
  765. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
  766. transformers/models/xlnet/modeling_xlnet.py +3 -1
  767. transformers/models/xlnet/tokenization_xlnet.py +3 -7
  768. transformers/models/xmod/modeling_xmod.py +3 -0
  769. transformers/models/yoso/modeling_yoso.py +10 -1
  770. transformers/models/zamba/modeling_zamba.py +4 -1
  771. transformers/models/zamba2/modeling_zamba2.py +7 -4
  772. transformers/models/zamba2/modular_zamba2.py +1 -1
  773. transformers/models/zoedepth/configuration_zoedepth.py +1 -1
  774. transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
  775. transformers/models/zoedepth/modeling_zoedepth.py +8 -0
  776. transformers/pipelines/__init__.py +11 -9
  777. transformers/pipelines/automatic_speech_recognition.py +20 -12
  778. transformers/pipelines/base.py +2 -10
  779. transformers/pipelines/document_question_answering.py +4 -2
  780. transformers/pipelines/question_answering.py +1 -1
  781. transformers/pipelines/text_generation.py +1 -1
  782. transformers/pipelines/text_to_audio.py +2 -2
  783. transformers/processing_utils.py +133 -50
  784. transformers/quantizers/auto.py +2 -4
  785. transformers/quantizers/base.py +44 -174
  786. transformers/quantizers/quantizer_aqlm.py +2 -23
  787. transformers/quantizers/quantizer_auto_round.py +2 -12
  788. transformers/quantizers/quantizer_awq.py +20 -89
  789. transformers/quantizers/quantizer_bitnet.py +4 -14
  790. transformers/quantizers/quantizer_bnb_4bit.py +18 -155
  791. transformers/quantizers/quantizer_bnb_8bit.py +24 -110
  792. transformers/quantizers/quantizer_compressed_tensors.py +2 -9
  793. transformers/quantizers/quantizer_eetq.py +16 -74
  794. transformers/quantizers/quantizer_fbgemm_fp8.py +38 -138
  795. transformers/quantizers/quantizer_finegrained_fp8.py +26 -113
  796. transformers/quantizers/quantizer_fp_quant.py +52 -82
  797. transformers/quantizers/quantizer_gptq.py +8 -28
  798. transformers/quantizers/quantizer_higgs.py +42 -60
  799. transformers/quantizers/quantizer_hqq.py +144 -153
  800. transformers/quantizers/quantizer_mxfp4.py +14 -194
  801. transformers/quantizers/quantizer_quanto.py +35 -79
  802. transformers/quantizers/quantizer_quark.py +36 -17
  803. transformers/quantizers/quantizer_spqr.py +4 -12
  804. transformers/quantizers/quantizer_torchao.py +50 -325
  805. transformers/quantizers/quantizer_vptq.py +4 -27
  806. transformers/quantizers/quantizers_utils.py +20 -0
  807. transformers/testing_utils.py +324 -47
  808. transformers/tokenization_mistral_common.py +7 -2
  809. transformers/tokenization_utils_base.py +116 -224
  810. transformers/tokenization_utils_tokenizers.py +190 -106
  811. transformers/trainer.py +51 -32
  812. transformers/trainer_callback.py +8 -0
  813. transformers/trainer_jit_checkpoint.py +126 -0
  814. transformers/trainer_seq2seq.py +4 -0
  815. transformers/trainer_utils.py +1 -1
  816. transformers/training_args.py +74 -38
  817. transformers/utils/__init__.py +7 -4
  818. transformers/utils/attention_visualizer.py +4 -4
  819. transformers/utils/auto_docstring.py +35 -25
  820. transformers/utils/generic.py +47 -1
  821. transformers/utils/hub.py +5 -15
  822. transformers/utils/import_utils.py +112 -25
  823. transformers/utils/kernel_config.py +74 -19
  824. transformers/utils/loading_report.py +19 -10
  825. transformers/utils/quantization_config.py +78 -245
  826. transformers/video_processing_utils.py +17 -14
  827. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +275 -229
  828. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +832 -777
  829. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +1 -1
  830. transformers/kernels/__init__.py +0 -0
  831. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
  832. transformers/models/roformer/tokenization_roformer_fast.py +0 -160
  833. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
  834. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info/licenses}/LICENSE +0 -0
  835. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
@@ -31,7 +31,7 @@ from ... import initialization as init
31
31
  from ...activations import ACT2FN
32
32
  from ...cache_utils import Cache, DynamicCache
33
33
  from ...generation import GenerationMixin
34
- from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
34
+ from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
35
35
  from ...masking_utils import create_bidirectional_mask, create_causal_mask
36
36
  from ...modeling_layers import GradientCheckpointingLayer
37
37
  from ...modeling_outputs import (
@@ -45,7 +45,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
45
45
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
46
46
  from ...processing_utils import Unpack
47
47
  from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
48
- from ...utils.generic import OutputRecorder, check_model_inputs
48
+ from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
49
49
  from .configuration_evolla import EvollaConfig, SaProtConfig
50
50
 
51
51
 
@@ -185,6 +185,7 @@ class EvollaSaProtRotaryEmbedding(nn.Module):
185
185
 
186
186
  def __init__(self, dim: int):
187
187
  super().__init__()
188
+ self.dim = dim
188
189
  # Generate and save the inverse frequency buffer (non trainable)
189
190
  inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
190
191
  self.register_buffer("inv_freq", inv_freq)
@@ -518,12 +519,19 @@ class EvollaSaProtPreTrainedModel(PreTrainedModel):
518
519
  ],
519
520
  }
520
521
 
522
+ def _init_weights(self, module):
523
+ super()._init_weights(module)
524
+ if isinstance(module, EvollaSaProtRotaryEmbedding):
525
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, module.dim, 2, dtype=torch.int64).float() / module.dim))
526
+ init.copy_(module.inv_freq, inv_freq)
527
+
521
528
 
522
529
  class EvollaSaProtProteinEncoder(EvollaSaProtPreTrainedModel):
523
530
  def __init__(self, config: SaProtConfig):
524
531
  super().__init__(config)
525
532
  self.embeddings = EvollaSaProtEmbeddings(config)
526
533
  self.encoder = EvollaSaProtEncoder(config)
534
+ self.post_init()
527
535
 
528
536
  def get_input_embeddings(self):
529
537
  return self.embeddings.word_embeddings
@@ -980,7 +988,7 @@ class EvollaRotaryEmbedding(nn.Module):
980
988
  inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
981
989
 
982
990
  self.register_buffer("inv_freq", inv_freq, persistent=False)
983
- self.original_inv_freq = inv_freq
991
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
984
992
 
985
993
  @staticmethod
986
994
  def compute_default_rope_parameters(
@@ -1019,7 +1027,7 @@ class EvollaRotaryEmbedding(nn.Module):
1019
1027
  position_ids_expanded = position_ids[:, None, :].float()
1020
1028
 
1021
1029
  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
1022
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
1030
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
1023
1031
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
1024
1032
  emb = torch.cat((freqs, freqs), dim=-1)
1025
1033
  cos = emb.cos() * self.attention_scaling
@@ -1091,6 +1099,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
1091
1099
  return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
1092
1100
 
1093
1101
 
1102
+ @use_kernelized_func(apply_rotary_pos_emb)
1094
1103
  class EvollaAttention(nn.Module):
1095
1104
  """Multi-headed attention from 'Attention Is All You Need' paper"""
1096
1105
 
@@ -1116,7 +1125,6 @@ class EvollaAttention(nn.Module):
1116
1125
  self.o_proj = nn.Linear(
1117
1126
  config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
1118
1127
  )
1119
- self.rotary_fn = apply_rotary_pos_emb
1120
1128
 
1121
1129
  def forward(
1122
1130
  self,
@@ -91,6 +91,7 @@ class EvollaSaProtRotaryEmbedding(nn.Module):
91
91
 
92
92
  def __init__(self, dim: int):
93
93
  super().__init__()
94
+ self.dim = dim
94
95
  # Generate and save the inverse frequency buffer (non trainable)
95
96
  inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2, dtype=torch.int64).float() / dim))
96
97
  self.register_buffer("inv_freq", inv_freq)
@@ -203,12 +204,19 @@ class EvollaSaProtPreTrainedModel(PreTrainedModel):
203
204
  ],
204
205
  }
205
206
 
207
+ def _init_weights(self, module):
208
+ super()._init_weights(module)
209
+ if isinstance(module, EvollaSaProtRotaryEmbedding):
210
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, module.dim, 2, dtype=torch.int64).float() / module.dim))
211
+ init.copy_(module.inv_freq, inv_freq)
212
+
206
213
 
207
214
  class EvollaSaProtProteinEncoder(EvollaSaProtPreTrainedModel):
208
215
  def __init__(self, config: SaProtConfig):
209
216
  super().__init__(config)
210
217
  self.embeddings = EvollaSaProtEmbeddings(config)
211
218
  self.encoder = EvollaSaProtEncoder(config)
219
+ self.post_init()
212
220
 
213
221
  def get_input_embeddings(self):
214
222
  return self.embeddings.word_embeddings
@@ -44,6 +44,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
44
44
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
45
45
  from ...processing_utils import Unpack
46
46
  from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
47
+ from ...utils.generic import maybe_autocast
47
48
  from .configuration_exaone4 import Exaone4Config
48
49
 
49
50
 
@@ -85,7 +86,7 @@ class Exaone4RotaryEmbedding(nn.Module):
85
86
  inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
86
87
 
87
88
  self.register_buffer("inv_freq", inv_freq, persistent=False)
88
- self.original_inv_freq = inv_freq
89
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
89
90
 
90
91
  @staticmethod
91
92
  def compute_default_rope_parameters(
@@ -124,7 +125,7 @@ class Exaone4RotaryEmbedding(nn.Module):
124
125
  position_ids_expanded = position_ids[:, None, :].float()
125
126
 
126
127
  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
127
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
128
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
128
129
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
129
130
  emb = torch.cat((freqs, freqs), dim=-1)
130
131
  cos = emb.cos() * self.attention_scaling
@@ -239,7 +240,6 @@ class Exaone4Attention(nn.Module):
239
240
  attention_mask: Optional[torch.Tensor] = None,
240
241
  past_key_values: Optional[Cache] = None,
241
242
  cache_position: Optional[torch.LongTensor] = None,
242
- position_ids: Optional[torch.LongTensor] = None,
243
243
  **kwargs: Unpack[TransformersKwargs],
244
244
  ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
245
245
  input_shape = hidden_states.shape[:-1]
@@ -260,7 +260,6 @@ class Exaone4Attention(nn.Module):
260
260
  attention_mask: Optional[torch.Tensor] = None,
261
261
  past_key_values: Optional[Cache] = None,
262
262
  cache_position: Optional[torch.LongTensor] = None,
263
- position_ids: Optional[torch.LongTensor] = None,
264
263
  **kwargs: Unpack[TransformersKwargs],
265
264
  ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
266
265
  input_shape = hidden_states.shape[:-1]
@@ -48,6 +48,7 @@ from ...utils import (
48
48
  auto_docstring,
49
49
  logging,
50
50
  )
51
+ from ...utils.generic import maybe_autocast
51
52
  from .configuration_falcon import FalconConfig
52
53
 
53
54
 
@@ -121,7 +122,7 @@ class FalconRotaryEmbedding(nn.Module):
121
122
  inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
122
123
 
123
124
  self.register_buffer("inv_freq", inv_freq, persistent=False)
124
- self.original_inv_freq = inv_freq
125
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
125
126
 
126
127
  @staticmethod
127
128
  def compute_default_rope_parameters(
@@ -160,7 +161,7 @@ class FalconRotaryEmbedding(nn.Module):
160
161
  position_ids_expanded = position_ids[:, None, :].float()
161
162
 
162
163
  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
163
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
164
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
164
165
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
165
166
  emb = torch.cat((freqs, freqs), dim=-1)
166
167
  cos = emb.cos() * self.attention_scaling
@@ -520,8 +521,8 @@ class FalconFlashAttention2(FalconAttention):
520
521
  else torch.get_autocast_gpu_dtype()
521
522
  )
522
523
  # Handle the case where the model is quantized
523
- elif hasattr(self.config, "_pre_quantization_dtype"):
524
- target_dtype = self.config._pre_quantization_dtype
524
+ elif hasattr(self.config, "quantization_config"):
525
+ target_dtype = self.config.dtype
525
526
  else:
526
527
  target_dtype = self.query_key_value.weight.dtype
527
528
 
@@ -739,6 +740,7 @@ class FalconModel(FalconPreTrainedModel):
739
740
  output_hidden_states: Optional[bool] = None,
740
741
  return_dict: Optional[bool] = None,
741
742
  cache_position: Optional[torch.LongTensor] = None,
743
+ **kwargs,
742
744
  ) -> Union[tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
743
745
  r"""
744
746
  input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
@@ -1119,6 +1121,7 @@ class FalconForSequenceClassification(FalconPreTrainedModel):
1119
1121
  output_attentions: Optional[bool] = None,
1120
1122
  output_hidden_states: Optional[bool] = None,
1121
1123
  return_dict: Optional[bool] = None,
1124
+ **kwargs,
1122
1125
  ) -> Union[tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
1123
1126
  r"""
1124
1127
  input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
@@ -1243,6 +1246,7 @@ class FalconForTokenClassification(FalconPreTrainedModel):
1243
1246
  output_attentions: Optional[bool] = None,
1244
1247
  output_hidden_states: Optional[bool] = None,
1245
1248
  return_dict: Optional[bool] = None,
1249
+ **kwargs,
1246
1250
  ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
1247
1251
  r"""
1248
1252
  input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
@@ -1320,6 +1324,7 @@ class FalconForQuestionAnswering(FalconPreTrainedModel):
1320
1324
  output_attentions: Optional[bool] = None,
1321
1325
  output_hidden_states: Optional[bool] = None,
1322
1326
  return_dict: Optional[bool] = None,
1327
+ **kwargs,
1323
1328
  ) -> Union[tuple, QuestionAnsweringModelOutput]:
1324
1329
  r"""
1325
1330
  input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
@@ -36,7 +36,7 @@ from transformers.activations import ACT2FN
36
36
  from ... import initialization as init
37
37
  from ...cache_utils import Cache
38
38
  from ...generation import GenerationMixin
39
- from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
39
+ from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub, use_kernelized_func
40
40
  from ...modeling_attn_mask_utils import AttentionMaskConverter
41
41
  from ...modeling_flash_attention_utils import FlashAttentionKwargs
42
42
  from ...modeling_layers import GradientCheckpointingLayer
@@ -45,6 +45,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
45
45
  from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
46
46
  from ...processing_utils import Unpack
47
47
  from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torchdynamo_compiling, logging
48
+ from ...utils.generic import maybe_autocast
48
49
  from ...utils.import_utils import is_causal_conv1d_available, is_mamba_2_ssm_available
49
50
  from .configuration_falcon_h1 import FalconH1Config
50
51
 
@@ -240,7 +241,7 @@ class FalconH1RotaryEmbedding(nn.Module):
240
241
  inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
241
242
 
242
243
  self.register_buffer("inv_freq", inv_freq, persistent=False)
243
- self.original_inv_freq = inv_freq
244
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
244
245
 
245
246
  @staticmethod
246
247
  def compute_default_rope_parameters(
@@ -279,7 +280,7 @@ class FalconH1RotaryEmbedding(nn.Module):
279
280
  position_ids_expanded = position_ids[:, None, :].float()
280
281
 
281
282
  device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
282
- with torch.autocast(device_type=device_type, enabled=False): # Force float32
283
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
283
284
  freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
284
285
  emb = torch.cat((freqs, freqs), dim=-1)
285
286
  cos = emb.cos() * self.attention_scaling
@@ -361,6 +362,7 @@ def eager_attention_forward(
361
362
  return attn_output, attn_weights
362
363
 
363
364
 
365
+ @use_kernelized_func(apply_rotary_pos_emb)
364
366
  class FalconH1Attention(nn.Module):
365
367
  """Multi-headed attention from 'Attention Is All You Need' paper"""
366
368
 
@@ -386,7 +388,6 @@ class FalconH1Attention(nn.Module):
386
388
  self.o_proj = nn.Linear(
387
389
  config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
388
390
  )
389
- self.rotary_fn = apply_rotary_pos_emb
390
391
  self.key_multiplier = config.key_multiplier
391
392
 
392
393
  def forward(
@@ -1186,26 +1187,6 @@ class FalconH1DecoderLayer(GradientCheckpointingLayer):
1186
1187
  return outputs
1187
1188
 
1188
1189
 
1189
- @auto_docstring
1190
- class FalconH1PreTrainedModel(PreTrainedModel):
1191
- config: FalconH1Config
1192
- base_model_prefix = "model"
1193
- supports_gradient_checkpointing = True
1194
- _no_split_modules = ["FalconH1DecoderLayer"]
1195
- _skip_keys_device_placement = "past_key_values"
1196
- _supports_flash_attn = True
1197
- _supports_sdpa = True
1198
- _is_stateful = True
1199
-
1200
- @torch.no_grad()
1201
- def _init_weights(self, module):
1202
- super()._init_weights(module)
1203
- if isinstance(module, FalconH1Mixer):
1204
- init.ones_(module.dt_bias)
1205
- init.copy_(module.A_log, torch.log(torch.arange(1, module.num_heads + 1)))
1206
- init.ones_(module.D)
1207
-
1208
-
1209
1190
  def compute_mup_vector(config):
1210
1191
  """
1211
1192
  Computes the MuP vector based on model configuration.
@@ -1243,6 +1224,30 @@ def compute_mup_vector(config):
1243
1224
  return mup_vector
1244
1225
 
1245
1226
 
1227
+ @auto_docstring
1228
+ class FalconH1PreTrainedModel(PreTrainedModel):
1229
+ config: FalconH1Config
1230
+ base_model_prefix = "model"
1231
+ supports_gradient_checkpointing = True
1232
+ _no_split_modules = ["FalconH1DecoderLayer"]
1233
+ _skip_keys_device_placement = "past_key_values"
1234
+ _supports_flash_attn = True
1235
+ _supports_sdpa = True
1236
+ _is_stateful = True
1237
+
1238
+ @torch.no_grad()
1239
+ def _init_weights(self, module):
1240
+ super()._init_weights(module)
1241
+ if isinstance(module, FalconH1Mixer):
1242
+ init.ones_(module.dt_bias)
1243
+ init.copy_(module.A_log, torch.log(torch.arange(1, module.num_heads + 1)))
1244
+ init.ones_(module.D)
1245
+ elif isinstance(module, FalconH1Model):
1246
+ mup_vector = compute_mup_vector(module.config)
1247
+ for layer in module.layers:
1248
+ init.copy_(layer.mamba.mup_vector, mup_vector)
1249
+
1250
+
1246
1251
  @auto_docstring
1247
1252
  # Adapted from transformers.models.jamba.modeling_jamba.JambaModel
1248
1253
  class FalconH1Model(FalconH1PreTrainedModel):
@@ -1268,7 +1273,7 @@ class FalconH1Model(FalconH1PreTrainedModel):
1268
1273
  # Compute the MuP vector once and register it for all layers
1269
1274
  mup_vector = compute_mup_vector(config)
1270
1275
  for layer in self.layers:
1271
- layer.mamba.register_buffer("mup_vector", mup_vector, persistent=False)
1276
+ layer.mamba.register_buffer("mup_vector", mup_vector.clone(), persistent=False)
1272
1277
 
1273
1278
  # Initialize weights and apply final processing
1274
1279
  self.post_init()
@@ -1590,6 +1595,7 @@ class FalconH1ForCausalLM(FalconH1PreTrainedModel, GenerationMixin):
1590
1595
  cache_position=None,
1591
1596
  position_ids=None,
1592
1597
  use_cache=True,
1598
+ is_first_iteration=False,
1593
1599
  **kwargs,
1594
1600
  ):
1595
1601
  # Overwritten -- has a unique cache type, `FalconHybridMambaAttentionDynamicCache`
@@ -1627,7 +1633,7 @@ class FalconH1ForCausalLM(FalconH1PreTrainedModel, GenerationMixin):
1627
1633
  position_ids = position_ids[:, -input_ids.shape[1] :]
1628
1634
 
1629
1635
  # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1630
- if inputs_embeds is not None and empty_past_kv:
1636
+ if inputs_embeds is not None and is_first_iteration:
1631
1637
  model_inputs = {"inputs_embeds": inputs_embeds}
1632
1638
  else:
1633
1639
  model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
@@ -928,6 +928,10 @@ class FalconH1PreTrainedModel(PreTrainedModel):
928
928
  init.ones_(module.dt_bias)
929
929
  init.copy_(module.A_log, torch.log(torch.arange(1, module.num_heads + 1)))
930
930
  init.ones_(module.D)
931
+ elif isinstance(module, FalconH1Model):
932
+ mup_vector = compute_mup_vector(module.config)
933
+ for layer in module.layers:
934
+ init.copy_(layer.mamba.mup_vector, mup_vector)
931
935
 
932
936
 
933
937
  def compute_mup_vector(config):
@@ -992,7 +996,7 @@ class FalconH1Model(FalconH1PreTrainedModel):
992
996
  # Compute the MuP vector once and register it for all layers
993
997
  mup_vector = compute_mup_vector(config)
994
998
  for layer in self.layers:
995
- layer.mamba.register_buffer("mup_vector", mup_vector, persistent=False)
999
+ layer.mamba.register_buffer("mup_vector", mup_vector.clone(), persistent=False)
996
1000
 
997
1001
  # Initialize weights and apply final processing
998
1002
  self.post_init()
@@ -1298,6 +1302,7 @@ class FalconH1ForCausalLM(LlamaForCausalLM):
1298
1302
  cache_position=None,
1299
1303
  position_ids=None,
1300
1304
  use_cache=True,
1305
+ is_first_iteration=False,
1301
1306
  **kwargs,
1302
1307
  ):
1303
1308
  # Overwritten -- has a unique cache type, `FalconHybridMambaAttentionDynamicCache`
@@ -1335,7 +1340,7 @@ class FalconH1ForCausalLM(LlamaForCausalLM):
1335
1340
  position_ids = position_ids[:, -input_ids.shape[1] :]
1336
1341
 
1337
1342
  # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
1338
- if inputs_embeds is not None and empty_past_kv:
1343
+ if inputs_embeds is not None and is_first_iteration:
1339
1344
  model_inputs = {"inputs_embeds": inputs_embeds}
1340
1345
  else:
1341
1346
  model_inputs = {"input_ids": input_ids.contiguous()} # `contiguous()` needed for compilation use cases
@@ -31,15 +31,11 @@ from ... import initialization as init
31
31
  from ...activations import ACT2FN
32
32
  from ...configuration_utils import PreTrainedConfig
33
33
  from ...generation import GenerationMixin
34
- from ...integrations.hub_kernels import lazy_load_kernel
34
+ from ...integrations import lazy_load_kernel
35
35
  from ...modeling_layers import GradientCheckpointingLayer
36
36
  from ...modeling_utils import PreTrainedModel
37
37
  from ...utils import ModelOutput, auto_docstring, logging
38
- from ...utils.import_utils import (
39
- is_mamba_ssm_available,
40
- is_mambapy_available,
41
- is_torchdynamo_compiling,
42
- )
38
+ from ...utils.import_utils import is_mambapy_available, is_torchdynamo_compiling
43
39
  from .configuration_falcon_mamba import FalconMambaConfig
44
40
 
45
41
 
@@ -48,14 +44,6 @@ if is_mambapy_available():
48
44
  else:
49
45
  pscan = None
50
46
 
51
- if is_mamba_ssm_available():
52
- from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
53
- from mamba_ssm.ops.triton.selective_state_update import selective_state_update
54
-
55
- from ...kernels.falcon_mamba import mamba_inner_fn
56
- else:
57
- selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
58
-
59
47
 
60
48
  logger = logging.get_logger(__name__)
61
49
 
@@ -231,7 +219,27 @@ class FalconMambaMixer(nn.Module):
231
219
  self.out_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.use_bias)
232
220
  self.use_bias = config.use_bias
233
221
 
222
+ global causal_conv1d, causal_conv1d_update, causal_conv1d_fn
223
+ causal_conv1d = lazy_load_kernel("causal-conv1d")
224
+ causal_conv1d_update, causal_conv1d_fn = (
225
+ (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
226
+ if causal_conv1d is not None
227
+ else (None, None)
228
+ )
229
+ global falcon_mamba_ssm, selective_state_update, selective_scan_fn, falcon_mamba_inner_fn
230
+ falcon_mamba_ssm = lazy_load_kernel("falcon_mamba-ssm")
231
+ selective_state_update, selective_scan_fn, falcon_mamba_inner_fn = (
232
+ (
233
+ falcon_mamba_ssm.selective_state_update,
234
+ falcon_mamba_ssm.selective_scan_fn,
235
+ falcon_mamba_ssm.falcon_mamba_inner_fn,
236
+ )
237
+ if falcon_mamba_ssm is not None
238
+ else (None, None, None)
239
+ )
240
+
234
241
  self.warn_slow_implementation()
242
+
235
243
  # Triton expects to pass RMS weights even if they are non learnable, thus we need to create these weights here
236
244
  self.register_buffer(
237
245
  "b_c_rms", torch.nn.Parameter(torch.ones(self.ssm_state_size), requires_grad=False), persistent=False
@@ -242,14 +250,8 @@ class FalconMambaMixer(nn.Module):
242
250
  self.rms_eps = config.mixer_rms_eps
243
251
 
244
252
  def warn_slow_implementation(self):
245
- causal_conv1d = lazy_load_kernel("causal-conv1d")
246
- causal_conv1d_update, causal_conv1d_fn = (
247
- (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
248
- if causal_conv1d is not None
249
- else (None, None)
250
- )
251
253
  is_fast_path_available = all(
252
- (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
254
+ (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, falcon_mamba_inner_fn)
253
255
  )
254
256
  if not is_fast_path_available:
255
257
  if self.use_falcon_mambapy:
@@ -279,9 +281,8 @@ class FalconMambaMixer(nn.Module):
279
281
  ):
280
282
  # 1. Gated MLP's linear projection
281
283
  projected_states = self.in_proj(hidden_states).transpose(1, 2)
282
-
283
284
  if self.training and cache_params is None: # Doesn't support outputting the states -> used for training
284
- contextualized_states = mamba_inner_fn(
285
+ contextualized_states = falcon_mamba_inner_fn(
285
286
  projected_states,
286
287
  self.conv1d.weight,
287
288
  self.conv1d.bias if self.use_conv_bias else None,
@@ -302,12 +303,6 @@ class FalconMambaMixer(nn.Module):
302
303
  )
303
304
 
304
305
  else:
305
- causal_conv1d = lazy_load_kernel("causal-conv1d")
306
- causal_conv1d_update, causal_conv1d_fn = (
307
- (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
308
- if causal_conv1d is not None
309
- else (None, None)
310
- )
311
306
  hidden_states, gate = projected_states.chunk(2, dim=1)
312
307
 
313
308
  if attention_mask is not None:
@@ -350,7 +345,7 @@ class FalconMambaMixer(nn.Module):
350
345
 
351
346
  # In case the model has been quantized, we need a hack to properly call the `nn.Linear` module
352
347
  # at the price of a small overhead.
353
- if hasattr(self.config, "_pre_quantization_dtype"):
348
+ if hasattr(self.config, "quantization_config"):
354
349
  discrete_time_step = (self.dt_proj(time_step) - self.dt_proj.bias).transpose(1, 2)
355
350
  else:
356
351
  discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)
@@ -502,14 +497,8 @@ class FalconMambaMixer(nn.Module):
502
497
  cache_position: Optional[torch.LongTensor] = None,
503
498
  attention_mask: Optional[torch.LongTensor] = None,
504
499
  ):
505
- causal_conv1d = lazy_load_kernel("causal-conv1d")
506
- causal_conv1d_update, causal_conv1d_fn = (
507
- (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
508
- if causal_conv1d is not None
509
- else (None, None)
510
- )
511
500
  is_fast_path_available = all(
512
- (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
501
+ (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, falcon_mamba_inner_fn)
513
502
  )
514
503
  if is_fast_path_available and "cuda" in self.x_proj.weight.device.type and not is_torchdynamo_compiling():
515
504
  return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
@@ -624,6 +613,9 @@ class FalconMambaPreTrainedModel(PreTrainedModel):
624
613
  init.ones_(module.weight)
625
614
  elif isinstance(module, nn.Embedding):
626
615
  init.normal_(module.weight, std=std)
616
+ if isinstance(module, FalconMambaMixer):
617
+ init.ones_(module.b_c_rms)
618
+ init.ones_(module.dt_rms)
627
619
 
628
620
 
629
621
  @dataclass
@@ -703,6 +695,7 @@ class FalconMambaModel(FalconMambaPreTrainedModel):
703
695
  return_dict: Optional[bool] = None,
704
696
  cache_position: Optional[torch.LongTensor] = None,
705
697
  attention_mask: Optional[torch.LongTensor] = None,
698
+ **kwargs,
706
699
  ) -> Union[tuple, FalconMambaOutput]:
707
700
  r"""
708
701
  cache_params (`FalconMambaCache`, *optional*):
@@ -821,6 +814,7 @@ class FalconMambaForCausalLM(FalconMambaPreTrainedModel, GenerationMixin):
821
814
  cache_params: Optional[FalconMambaCache] = None,
822
815
  cache_position: Optional[torch.LongTensor] = None,
823
816
  attention_mask: Optional[torch.LongTensor] = None,
817
+ is_first_iteration: Optional[bool] = False,
824
818
  **kwargs,
825
819
  ):
826
820
  # Overwritten -- uses `cache_params` as opposed to `past_key_values`
@@ -19,9 +19,9 @@ from typing import Optional
19
19
  import torch
20
20
  from torch import nn
21
21
 
22
- from ...integrations.hub_kernels import lazy_load_kernel
22
+ from ... import initialization as init
23
23
  from ...utils import auto_docstring, logging
24
- from ...utils.import_utils import is_mamba_ssm_available, is_mambapy_available, is_torchdynamo_compiling
24
+ from ...utils.import_utils import is_mambapy_available, is_torchdynamo_compiling
25
25
  from ..mamba.configuration_mamba import MambaConfig
26
26
  from ..mamba.modeling_mamba import (
27
27
  MambaBlock,
@@ -43,13 +43,13 @@ if is_mambapy_available():
43
43
  else:
44
44
  pscan = None
45
45
 
46
- if is_mamba_ssm_available():
47
- from mamba_ssm.ops.selective_scan_interface import selective_scan_fn
48
- from mamba_ssm.ops.triton.selective_state_update import selective_state_update
49
-
50
- from ...kernels.falcon_mamba import mamba_inner_fn
51
- else:
52
- selective_state_update, selective_scan_fn, mamba_inner_fn = None, None, None
46
+ selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, falcon_mamba_inner_fn = (
47
+ None,
48
+ None,
49
+ None,
50
+ None,
51
+ None,
52
+ )
53
53
 
54
54
 
55
55
  class FalconMambaConfig(MambaConfig):
@@ -251,14 +251,8 @@ def rms_forward(hidden_states, variance_epsilon=1e-6):
251
251
 
252
252
  class FalconMambaMixer(MambaMixer):
253
253
  def warn_slow_implementation(self):
254
- causal_conv1d = lazy_load_kernel("causal-conv1d")
255
- causal_conv1d_update, causal_conv1d_fn = (
256
- (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
257
- if causal_conv1d is not None
258
- else (None, None)
259
- )
260
254
  is_fast_path_available = all(
261
- (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
255
+ (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, falcon_mamba_inner_fn)
262
256
  )
263
257
  if not is_fast_path_available:
264
258
  if self.use_falcon_mambapy:
@@ -281,6 +275,7 @@ class FalconMambaMixer(MambaMixer):
281
275
 
282
276
  def __init__(self, config: FalconMambaConfig, layer_idx: int):
283
277
  super().__init__(config, layer_idx)
278
+
284
279
  # Triton expects to pass RMS weights even if they are non learnable, thus we need to create these weights here
285
280
  self.register_buffer(
286
281
  "b_c_rms", torch.nn.Parameter(torch.ones(self.ssm_state_size), requires_grad=False), persistent=False
@@ -299,9 +294,8 @@ class FalconMambaMixer(MambaMixer):
299
294
  ):
300
295
  # 1. Gated MLP's linear projection
301
296
  projected_states = self.in_proj(hidden_states).transpose(1, 2)
302
-
303
297
  if self.training and cache_params is None: # Doesn't support outputting the states -> used for training
304
- contextualized_states = mamba_inner_fn(
298
+ contextualized_states = falcon_mamba_inner_fn(
305
299
  projected_states,
306
300
  self.conv1d.weight,
307
301
  self.conv1d.bias if self.use_conv_bias else None,
@@ -322,12 +316,6 @@ class FalconMambaMixer(MambaMixer):
322
316
  )
323
317
 
324
318
  else:
325
- causal_conv1d = lazy_load_kernel("causal-conv1d")
326
- causal_conv1d_update, causal_conv1d_fn = (
327
- (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
328
- if causal_conv1d is not None
329
- else (None, None)
330
- )
331
319
  hidden_states, gate = projected_states.chunk(2, dim=1)
332
320
 
333
321
  if attention_mask is not None:
@@ -370,7 +358,7 @@ class FalconMambaMixer(MambaMixer):
370
358
 
371
359
  # In case the model has been quantized, we need a hack to properly call the `nn.Linear` module
372
360
  # at the price of a small overhead.
373
- if hasattr(self.config, "_pre_quantization_dtype"):
361
+ if hasattr(self.config, "quantization_config"):
374
362
  discrete_time_step = (self.dt_proj(time_step) - self.dt_proj.bias).transpose(1, 2)
375
363
  else:
376
364
  discrete_time_step = self.dt_proj.weight @ time_step.transpose(1, 2)
@@ -521,14 +509,8 @@ class FalconMambaMixer(MambaMixer):
521
509
  cache_position: Optional[torch.LongTensor] = None,
522
510
  attention_mask: Optional[torch.LongTensor] = None,
523
511
  ):
524
- causal_conv1d = lazy_load_kernel("causal-conv1d")
525
- causal_conv1d_update, causal_conv1d_fn = (
526
- (causal_conv1d.causal_conv1d_update, causal_conv1d.causal_conv1d_fn)
527
- if causal_conv1d is not None
528
- else (None, None)
529
- )
530
512
  is_fast_path_available = all(
531
- (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, mamba_inner_fn)
513
+ (selective_state_update, selective_scan_fn, causal_conv1d_fn, causal_conv1d_update, falcon_mamba_inner_fn)
532
514
  )
533
515
  if is_fast_path_available and "cuda" in self.x_proj.weight.device.type and not is_torchdynamo_compiling():
534
516
  return self.cuda_kernels_forward(hidden_states, cache_params, cache_position, attention_mask)
@@ -548,7 +530,11 @@ class FalconMambaBlock(MambaBlock):
548
530
 
549
531
  @auto_docstring
550
532
  class FalconMambaPreTrainedModel(MambaPreTrainedModel):
551
- pass
533
+ def _init_weights(self, module):
534
+ super()._init_weights(module)
535
+ if isinstance(module, FalconMambaMixer):
536
+ init.ones_(module.b_c_rms)
537
+ init.ones_(module.dt_rms)
552
538
 
553
539
 
554
540
  class FalconMambaOutput(MambaOutput):