transformers 5.0.0rc0__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (835)
  1. transformers/__init__.py +49 -3
  2. transformers/activations.py +1 -1
  3. transformers/audio_utils.py +0 -1
  4. transformers/cache_utils.py +17 -15
  5. transformers/cli/serve.py +47 -17
  6. transformers/configuration_utils.py +114 -70
  7. transformers/conversion_mapping.py +83 -7
  8. transformers/convert_slow_tokenizer.py +225 -10
  9. transformers/core_model_loading.py +374 -147
  10. transformers/data/data_collator.py +12 -4
  11. transformers/dependency_versions_table.py +2 -3
  12. transformers/dynamic_module_utils.py +1 -2
  13. transformers/feature_extraction_utils.py +55 -24
  14. transformers/file_utils.py +0 -1
  15. transformers/generation/__init__.py +11 -1
  16. transformers/generation/candidate_generator.py +79 -31
  17. transformers/generation/configuration_utils.py +165 -124
  18. transformers/generation/continuous_batching/__init__.py +4 -0
  19. transformers/generation/continuous_batching/cache.py +47 -18
  20. transformers/generation/continuous_batching/cache_manager.py +131 -34
  21. transformers/generation/continuous_batching/continuous_api.py +228 -136
  22. transformers/generation/continuous_batching/requests.py +28 -1
  23. transformers/generation/continuous_batching/scheduler.py +11 -4
  24. transformers/generation/stopping_criteria.py +1 -1
  25. transformers/generation/utils.py +108 -110
  26. transformers/generation/watermarking.py +8 -5
  27. transformers/image_processing_base.py +3 -14
  28. transformers/image_processing_utils_fast.py +15 -4
  29. transformers/initialization.py +37 -0
  30. transformers/integrations/__init__.py +16 -2
  31. transformers/integrations/accelerate.py +58 -113
  32. transformers/integrations/aqlm.py +36 -66
  33. transformers/integrations/awq.py +46 -515
  34. transformers/integrations/bitnet.py +47 -105
  35. transformers/integrations/bitsandbytes.py +91 -202
  36. transformers/integrations/deepspeed.py +18 -2
  37. transformers/integrations/eetq.py +84 -81
  38. transformers/integrations/fbgemm_fp8.py +191 -145
  39. transformers/integrations/finegrained_fp8.py +241 -208
  40. transformers/integrations/flash_attention.py +2 -2
  41. transformers/integrations/fp_quant.py +92 -0
  42. transformers/integrations/ggml.py +11 -1
  43. transformers/integrations/higgs.py +37 -62
  44. transformers/integrations/hub_kernels.py +65 -8
  45. transformers/integrations/integration_utils.py +45 -0
  46. transformers/integrations/mistral.py +12 -0
  47. transformers/integrations/moe.py +240 -0
  48. transformers/integrations/mxfp4.py +28 -74
  49. transformers/integrations/peft.py +12 -29
  50. transformers/integrations/quanto.py +77 -56
  51. transformers/integrations/quark.py +55 -0
  52. transformers/integrations/spqr.py +42 -90
  53. transformers/integrations/tensor_parallel.py +167 -221
  54. transformers/integrations/torchao.py +32 -38
  55. transformers/integrations/vptq.py +40 -59
  56. transformers/modelcard.py +1 -2
  57. transformers/modeling_gguf_pytorch_utils.py +74 -19
  58. transformers/modeling_rope_utils.py +107 -86
  59. transformers/modeling_utils.py +611 -527
  60. transformers/models/__init__.py +22 -0
  61. transformers/models/afmoe/modeling_afmoe.py +10 -19
  62. transformers/models/afmoe/modular_afmoe.py +5 -13
  63. transformers/models/aimv2/modeling_aimv2.py +4 -0
  64. transformers/models/aimv2/modular_aimv2.py +4 -0
  65. transformers/models/albert/modeling_albert.py +3 -0
  66. transformers/models/albert/tokenization_albert.py +6 -12
  67. transformers/models/align/modeling_align.py +14 -6
  68. transformers/models/altclip/modeling_altclip.py +11 -3
  69. transformers/models/apertus/modeling_apertus.py +8 -6
  70. transformers/models/apertus/modular_apertus.py +4 -1
  71. transformers/models/arcee/modeling_arcee.py +5 -5
  72. transformers/models/aria/modeling_aria.py +12 -8
  73. transformers/models/aria/modular_aria.py +7 -3
  74. transformers/models/audioflamingo3/modeling_audioflamingo3.py +1 -0
  75. transformers/models/audioflamingo3/modular_audioflamingo3.py +1 -0
  76. transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
  77. transformers/models/auto/auto_factory.py +1 -1
  78. transformers/models/auto/configuration_auto.py +38 -0
  79. transformers/models/auto/feature_extraction_auto.py +9 -3
  80. transformers/models/auto/image_processing_auto.py +5 -2
  81. transformers/models/auto/modeling_auto.py +37 -0
  82. transformers/models/auto/processing_auto.py +22 -10
  83. transformers/models/auto/tokenization_auto.py +147 -566
  84. transformers/models/auto/video_processing_auto.py +5 -2
  85. transformers/models/autoformer/modeling_autoformer.py +4 -0
  86. transformers/models/aya_vision/modeling_aya_vision.py +7 -3
  87. transformers/models/bamba/modeling_bamba.py +21 -21
  88. transformers/models/bamba/modular_bamba.py +17 -16
  89. transformers/models/bark/modeling_bark.py +11 -0
  90. transformers/models/bart/configuration_bart.py +0 -1
  91. transformers/models/bart/modeling_bart.py +14 -0
  92. transformers/models/barthez/tokenization_barthez.py +5 -10
  93. transformers/models/beit/image_processing_beit_fast.py +0 -1
  94. transformers/models/beit/modeling_beit.py +6 -1
  95. transformers/models/bert/modeling_bert.py +3 -0
  96. transformers/models/bert/tokenization_bert.py +8 -21
  97. transformers/models/bert_generation/modeling_bert_generation.py +2 -0
  98. transformers/models/big_bird/modeling_big_bird.py +9 -0
  99. transformers/models/big_bird/tokenization_big_bird.py +18 -42
  100. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +15 -2
  101. transformers/models/biogpt/modeling_biogpt.py +2 -0
  102. transformers/models/biogpt/modular_biogpt.py +2 -0
  103. transformers/models/bit/modeling_bit.py +16 -3
  104. transformers/models/bitnet/modeling_bitnet.py +5 -5
  105. transformers/models/blenderbot/modeling_blenderbot.py +12 -0
  106. transformers/models/blenderbot/tokenization_blenderbot.py +18 -23
  107. transformers/models/blenderbot_small/modeling_blenderbot_small.py +12 -0
  108. transformers/models/blip/modeling_blip.py +2 -0
  109. transformers/models/blip/modeling_blip_text.py +10 -0
  110. transformers/models/blip_2/modeling_blip_2.py +4 -1
  111. transformers/models/bloom/modeling_bloom.py +17 -44
  112. transformers/models/blt/modeling_blt.py +164 -4
  113. transformers/models/blt/modular_blt.py +170 -5
  114. transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
  115. transformers/models/bridgetower/modeling_bridgetower.py +11 -1
  116. transformers/models/bros/modeling_bros.py +12 -0
  117. transformers/models/camembert/modeling_camembert.py +109 -106
  118. transformers/models/camembert/tokenization_camembert.py +8 -12
  119. transformers/models/canine/modeling_canine.py +11 -0
  120. transformers/models/canine/tokenization_canine.py +2 -0
  121. transformers/models/chameleon/modeling_chameleon.py +11 -5
  122. transformers/models/chinese_clip/modeling_chinese_clip.py +9 -3
  123. transformers/models/clap/feature_extraction_clap.py +2 -2
  124. transformers/models/clap/modeling_clap.py +30 -15
  125. transformers/models/clip/modeling_clip.py +2 -0
  126. transformers/models/clip/tokenization_clip.py +22 -44
  127. transformers/models/clipseg/modeling_clipseg.py +9 -0
  128. transformers/models/clvp/modeling_clvp.py +19 -3
  129. transformers/models/clvp/tokenization_clvp.py +1 -63
  130. transformers/models/code_llama/tokenization_code_llama.py +20 -43
  131. transformers/models/codegen/modeling_codegen.py +13 -4
  132. transformers/models/codegen/tokenization_codegen.py +14 -43
  133. transformers/models/cohere/modeling_cohere.py +5 -4
  134. transformers/models/cohere/modular_cohere.py +2 -1
  135. transformers/models/cohere/tokenization_cohere.py +12 -42
  136. transformers/models/cohere2/modeling_cohere2.py +8 -7
  137. transformers/models/cohere2/modular_cohere2.py +5 -5
  138. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +4 -4
  139. transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
  140. transformers/models/cohere2_vision/modular_cohere2_vision.py +4 -3
  141. transformers/models/colqwen2/modeling_colqwen2.py +1 -0
  142. transformers/models/colqwen2/modular_colqwen2.py +1 -0
  143. transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
  144. transformers/models/conditional_detr/modeling_conditional_detr.py +9 -1
  145. transformers/models/convbert/modeling_convbert.py +9 -0
  146. transformers/models/convnext/image_processing_convnext.py +2 -2
  147. transformers/models/convnext/image_processing_convnext_fast.py +9 -13
  148. transformers/models/convnext/modeling_convnext.py +2 -4
  149. transformers/models/convnextv2/modeling_convnextv2.py +2 -4
  150. transformers/models/csm/generation_csm.py +19 -22
  151. transformers/models/csm/modeling_csm.py +7 -4
  152. transformers/models/csm/modular_csm.py +2 -0
  153. transformers/models/ctrl/modeling_ctrl.py +15 -2
  154. transformers/models/cvt/modeling_cvt.py +7 -1
  155. transformers/models/cwm/modeling_cwm.py +5 -5
  156. transformers/models/d_fine/configuration_d_fine.py +3 -4
  157. transformers/models/d_fine/modeling_d_fine.py +48 -39
  158. transformers/models/d_fine/modular_d_fine.py +16 -4
  159. transformers/models/dab_detr/configuration_dab_detr.py +2 -2
  160. transformers/models/dab_detr/modeling_dab_detr.py +5 -1
  161. transformers/models/dac/modeling_dac.py +6 -6
  162. transformers/models/data2vec/modeling_data2vec_audio.py +5 -0
  163. transformers/models/data2vec/modeling_data2vec_text.py +7 -0
  164. transformers/models/data2vec/modeling_data2vec_vision.py +4 -1
  165. transformers/models/data2vec/modular_data2vec_text.py +7 -0
  166. transformers/models/dbrx/configuration_dbrx.py +9 -1
  167. transformers/models/dbrx/modeling_dbrx.py +3 -3
  168. transformers/models/deberta/modeling_deberta.py +7 -0
  169. transformers/models/deberta/tokenization_deberta.py +11 -20
  170. transformers/models/deberta_v2/modeling_deberta_v2.py +8 -0
  171. transformers/models/deberta_v2/tokenization_deberta_v2.py +13 -28
  172. transformers/models/decision_transformer/modeling_decision_transformer.py +12 -6
  173. transformers/models/deepseek_v2/modeling_deepseek_v2.py +9 -7
  174. transformers/models/deepseek_v2/modular_deepseek_v2.py +6 -4
  175. transformers/models/deepseek_v3/modeling_deepseek_v3.py +12 -7
  176. transformers/models/deepseek_v3/modular_deepseek_v3.py +7 -2
  177. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
  178. transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
  179. transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
  180. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
  181. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
  182. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
  183. transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
  184. transformers/models/deformable_detr/modeling_deformable_detr.py +5 -1
  185. transformers/models/depth_anything/configuration_depth_anything.py +2 -3
  186. transformers/models/depth_anything/modeling_depth_anything.py +1 -0
  187. transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
  188. transformers/models/depth_pro/modeling_depth_pro.py +2 -0
  189. transformers/models/detr/configuration_detr.py +1 -1
  190. transformers/models/detr/modeling_detr.py +13 -1
  191. transformers/models/dia/generation_dia.py +3 -10
  192. transformers/models/dia/modeling_dia.py +16 -4
  193. transformers/models/dia/modular_dia.py +11 -1
  194. transformers/models/dia/processing_dia.py +1 -1
  195. transformers/models/diffllama/modeling_diffllama.py +5 -5
  196. transformers/models/diffllama/modular_diffllama.py +2 -2
  197. transformers/models/dinat/modeling_dinat.py +3 -0
  198. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +1 -1
  199. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
  200. transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -2
  201. transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -2
  202. transformers/models/distilbert/modeling_distilbert.py +11 -9
  203. transformers/models/distilbert/tokenization_distilbert.py +13 -0
  204. transformers/models/doge/modeling_doge.py +3 -4
  205. transformers/models/doge/modular_doge.py +0 -1
  206. transformers/models/donut/image_processing_donut_fast.py +0 -1
  207. transformers/models/donut/modeling_donut_swin.py +18 -12
  208. transformers/models/dots1/modeling_dots1.py +23 -11
  209. transformers/models/dots1/modular_dots1.py +5 -3
  210. transformers/models/dpr/modeling_dpr.py +5 -0
  211. transformers/models/dpr/tokenization_dpr.py +12 -0
  212. transformers/models/dpt/configuration_dpt.py +1 -1
  213. transformers/models/dpt/image_processing_dpt_fast.py +1 -2
  214. transformers/models/dpt/modular_dpt.py +1 -2
  215. transformers/models/edgetam/configuration_edgetam.py +1 -1
  216. transformers/models/edgetam/modeling_edgetam.py +6 -3
  217. transformers/models/edgetam/modular_edgetam.py +15 -14
  218. transformers/models/edgetam_video/modeling_edgetam_video.py +56 -43
  219. transformers/models/edgetam_video/modular_edgetam_video.py +14 -19
  220. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
  221. transformers/models/efficientloftr/modeling_efficientloftr.py +16 -3
  222. transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
  223. transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
  224. transformers/models/efficientnet/modeling_efficientnet.py +7 -1
  225. transformers/models/electra/modeling_electra.py +7 -0
  226. transformers/models/emu3/modeling_emu3.py +12 -6
  227. transformers/models/emu3/modular_emu3.py +7 -1
  228. transformers/models/encodec/modeling_encodec.py +14 -0
  229. transformers/models/eomt/image_processing_eomt.py +13 -1
  230. transformers/models/eomt/image_processing_eomt_fast.py +60 -16
  231. transformers/models/eomt/modeling_eomt.py +7 -0
  232. transformers/models/eomt/modular_eomt.py +7 -0
  233. transformers/models/ernie/modeling_ernie.py +6 -0
  234. transformers/models/ernie/modular_ernie.py +6 -0
  235. transformers/models/ernie4_5/modeling_ernie4_5.py +5 -5
  236. transformers/models/ernie4_5/modular_ernie4_5.py +2 -1
  237. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +20 -17
  238. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +11 -37
  239. transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
  240. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
  241. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
  242. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
  243. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
  244. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
  245. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
  246. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
  247. transformers/models/esm/modeling_esm.py +6 -0
  248. transformers/models/esm/modeling_esmfold.py +11 -5
  249. transformers/models/evolla/modeling_evolla.py +13 -5
  250. transformers/models/evolla/modular_evolla.py +8 -0
  251. transformers/models/exaone4/modeling_exaone4.py +3 -3
  252. transformers/models/exaone4/modular_exaone4.py +0 -1
  253. transformers/models/falcon/modeling_falcon.py +9 -4
  254. transformers/models/falcon_h1/modeling_falcon_h1.py +32 -26
  255. transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
  256. transformers/models/falcon_mamba/modeling_falcon_mamba.py +31 -37
  257. transformers/models/falcon_mamba/modular_falcon_mamba.py +19 -33
  258. transformers/models/fast_vlm/__init__.py +27 -0
  259. transformers/models/fast_vlm/configuration_fast_vlm.py +137 -0
  260. transformers/models/fast_vlm/modeling_fast_vlm.py +459 -0
  261. transformers/models/fast_vlm/modular_fast_vlm.py +273 -0
  262. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +31 -13
  263. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
  264. transformers/models/flaubert/modeling_flaubert.py +21 -15
  265. transformers/models/flava/image_processing_flava_fast.py +0 -2
  266. transformers/models/flava/modeling_flava.py +10 -2
  267. transformers/models/flex_olmo/modeling_flex_olmo.py +10 -8
  268. transformers/models/florence2/modeling_florence2.py +22 -4
  269. transformers/models/florence2/modular_florence2.py +15 -1
  270. transformers/models/fnet/modeling_fnet.py +14 -0
  271. transformers/models/focalnet/modeling_focalnet.py +4 -0
  272. transformers/models/fsmt/modeling_fsmt.py +2 -0
  273. transformers/models/funnel/modeling_funnel.py +8 -0
  274. transformers/models/funnel/tokenization_funnel.py +17 -24
  275. transformers/models/fuyu/image_processing_fuyu.py +1 -1
  276. transformers/models/fuyu/modeling_fuyu.py +3 -1
  277. transformers/models/fuyu/processing_fuyu.py +19 -3
  278. transformers/models/gemma/modeling_gemma.py +14 -16
  279. transformers/models/gemma/modular_gemma.py +9 -11
  280. transformers/models/gemma/tokenization_gemma.py +10 -27
  281. transformers/models/gemma2/modeling_gemma2.py +5 -5
  282. transformers/models/gemma2/modular_gemma2.py +3 -2
  283. transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
  284. transformers/models/gemma3/modeling_gemma3.py +42 -91
  285. transformers/models/gemma3/modular_gemma3.py +38 -87
  286. transformers/models/gemma3n/configuration_gemma3n.py +3 -0
  287. transformers/models/gemma3n/modeling_gemma3n.py +65 -218
  288. transformers/models/gemma3n/modular_gemma3n.py +68 -68
  289. transformers/models/git/modeling_git.py +183 -126
  290. transformers/models/glm/modeling_glm.py +5 -5
  291. transformers/models/glm4/modeling_glm4.py +5 -5
  292. transformers/models/glm46v/image_processing_glm46v.py +0 -4
  293. transformers/models/glm46v/modeling_glm46v.py +3 -1
  294. transformers/models/glm46v/modular_glm46v.py +3 -0
  295. transformers/models/glm4_moe/modeling_glm4_moe.py +13 -7
  296. transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
  297. transformers/models/glm4v/configuration_glm4v.py +3 -1
  298. transformers/models/glm4v/image_processing_glm4v.py +0 -4
  299. transformers/models/glm4v/modeling_glm4v.py +18 -8
  300. transformers/models/glm4v/modular_glm4v.py +17 -7
  301. transformers/models/glm4v_moe/configuration_glm4v_moe.py +3 -1
  302. transformers/models/glm4v_moe/modeling_glm4v_moe.py +44 -27
  303. transformers/models/glm4v_moe/modular_glm4v_moe.py +13 -1
  304. transformers/models/glmasr/__init__.py +30 -0
  305. transformers/models/glmasr/configuration_glmasr.py +197 -0
  306. transformers/models/glmasr/modeling_glmasr.py +512 -0
  307. transformers/models/glmasr/modular_glmasr.py +433 -0
  308. transformers/models/glmasr/processing_glmasr.py +332 -0
  309. transformers/models/glpn/image_processing_glpn_fast.py +0 -1
  310. transformers/models/glpn/modeling_glpn.py +2 -0
  311. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
  312. transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
  313. transformers/models/gpt2/modeling_gpt2.py +13 -6
  314. transformers/models/gpt2/tokenization_gpt2.py +16 -44
  315. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +4 -8
  316. transformers/models/gpt_neo/modeling_gpt_neo.py +19 -3
  317. transformers/models/gpt_neox/modeling_gpt_neox.py +6 -3
  318. transformers/models/gpt_neox/modular_gpt_neox.py +3 -0
  319. transformers/models/gpt_neox/tokenization_gpt_neox.py +10 -49
  320. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +4 -2
  321. transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
  322. transformers/models/gpt_oss/modeling_gpt_oss.py +10 -14
  323. transformers/models/gpt_oss/modular_gpt_oss.py +8 -12
  324. transformers/models/gptj/modeling_gptj.py +18 -6
  325. transformers/models/granite/modeling_granite.py +5 -5
  326. transformers/models/granite_speech/modeling_granite_speech.py +15 -1
  327. transformers/models/granitemoe/modeling_granitemoe.py +6 -9
  328. transformers/models/granitemoe/modular_granitemoe.py +1 -4
  329. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
  330. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +36 -28
  331. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
  332. transformers/models/granitemoeshared/modeling_granitemoeshared.py +6 -9
  333. transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
  334. transformers/models/grounding_dino/modeling_grounding_dino.py +8 -4
  335. transformers/models/groupvit/modeling_groupvit.py +9 -1
  336. transformers/models/helium/modeling_helium.py +5 -4
  337. transformers/models/herbert/tokenization_herbert.py +9 -25
  338. transformers/models/hgnet_v2/modeling_hgnet_v2.py +16 -1
  339. transformers/models/hgnet_v2/modular_hgnet_v2.py +16 -1
  340. transformers/models/hiera/modeling_hiera.py +4 -0
  341. transformers/models/hubert/modeling_hubert.py +7 -0
  342. transformers/models/hubert/modular_hubert.py +5 -0
  343. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +5 -5
  344. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
  345. transformers/models/hunyuan_v1_moe/__init__.py +1 -1
  346. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +15 -7
  347. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
  348. transformers/models/ibert/modeling_ibert.py +22 -0
  349. transformers/models/idefics/modeling_idefics.py +15 -21
  350. transformers/models/idefics2/modeling_idefics2.py +7 -1
  351. transformers/models/idefics3/modeling_idefics3.py +5 -1
  352. transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
  353. transformers/models/imagegpt/modeling_imagegpt.py +11 -3
  354. transformers/models/informer/modeling_informer.py +4 -0
  355. transformers/models/informer/modular_informer.py +1 -0
  356. transformers/models/instructblip/modeling_instructblip.py +2 -0
  357. transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
  358. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
  359. transformers/models/internvl/modeling_internvl.py +13 -12
  360. transformers/models/internvl/modular_internvl.py +7 -13
  361. transformers/models/internvl/video_processing_internvl.py +0 -1
  362. transformers/models/jais2/__init__.py +27 -0
  363. transformers/models/jais2/configuration_jais2.py +152 -0
  364. transformers/models/jais2/modeling_jais2.py +486 -0
  365. transformers/models/jais2/modular_jais2.py +196 -0
  366. transformers/models/jamba/modeling_jamba.py +25 -20
  367. transformers/models/jamba/modular_jamba.py +17 -17
  368. transformers/models/janus/image_processing_janus_fast.py +0 -1
  369. transformers/models/janus/modeling_janus.py +16 -7
  370. transformers/models/janus/modular_janus.py +17 -7
  371. transformers/models/jetmoe/modeling_jetmoe.py +4 -4
  372. transformers/models/jetmoe/modular_jetmoe.py +1 -0
  373. transformers/models/kosmos2/modeling_kosmos2.py +15 -2
  374. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
  375. transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
  376. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +12 -4
  377. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
  378. transformers/models/lasr/__init__.py +29 -0
  379. transformers/models/lasr/configuration_lasr.py +248 -0
  380. transformers/models/lasr/feature_extraction_lasr.py +277 -0
  381. transformers/models/lasr/modeling_lasr.py +730 -0
  382. transformers/models/lasr/modular_lasr.py +576 -0
  383. transformers/models/lasr/processing_lasr.py +94 -0
  384. transformers/models/lasr/tokenization_lasr.py +186 -0
  385. transformers/models/layoutlm/modeling_layoutlm.py +10 -3
  386. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
  387. transformers/models/layoutlmv2/modeling_layoutlmv2.py +16 -0
  388. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +11 -53
  389. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
  390. transformers/models/layoutlmv3/modeling_layoutlmv3.py +33 -5
  391. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +12 -61
  392. transformers/models/layoutxlm/tokenization_layoutxlm.py +13 -38
  393. transformers/models/led/modeling_led.py +12 -0
  394. transformers/models/levit/modeling_levit.py +21 -0
  395. transformers/models/lfm2/modeling_lfm2.py +5 -6
  396. transformers/models/lfm2/modular_lfm2.py +0 -1
  397. transformers/models/lfm2_moe/modeling_lfm2_moe.py +17 -8
  398. transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
  399. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  400. transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
  401. transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
  402. transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
  403. transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
  404. transformers/models/lightglue/modeling_lightglue.py +3 -1
  405. transformers/models/lightglue/modular_lightglue.py +1 -0
  406. transformers/models/lilt/modeling_lilt.py +23 -15
  407. transformers/models/llama/modeling_llama.py +5 -5
  408. transformers/models/llama/tokenization_llama.py +15 -43
  409. transformers/models/llama4/image_processing_llama4_fast.py +1 -2
  410. transformers/models/llama4/modeling_llama4.py +11 -6
  411. transformers/models/llava/image_processing_llava_fast.py +0 -1
  412. transformers/models/llava/modeling_llava.py +12 -7
  413. transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
  414. transformers/models/llava_next/modeling_llava_next.py +7 -3
  415. transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
  416. transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
  417. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
  418. transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
  419. transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
  420. transformers/models/longcat_flash/modeling_longcat_flash.py +6 -5
  421. transformers/models/longcat_flash/modular_longcat_flash.py +3 -2
  422. transformers/models/longformer/modeling_longformer.py +6 -0
  423. transformers/models/longt5/modeling_longt5.py +4 -4
  424. transformers/models/luke/modeling_luke.py +9 -0
  425. transformers/models/luke/tokenization_luke.py +11 -38
  426. transformers/models/lxmert/modeling_lxmert.py +2 -0
  427. transformers/models/m2m_100/modeling_m2m_100.py +14 -0
  428. transformers/models/mamba/modeling_mamba.py +16 -23
  429. transformers/models/mamba2/modeling_mamba2.py +24 -23
  430. transformers/models/marian/configuration_marian.py +1 -1
  431. transformers/models/marian/modeling_marian.py +8 -0
  432. transformers/models/markuplm/modeling_markuplm.py +9 -8
  433. transformers/models/markuplm/tokenization_markuplm.py +28 -61
  434. transformers/models/mask2former/configuration_mask2former.py +3 -3
  435. transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
  436. transformers/models/mask2former/modeling_mask2former.py +11 -0
  437. transformers/models/maskformer/configuration_maskformer.py +3 -3
  438. transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
  439. transformers/models/maskformer/modeling_maskformer.py +11 -1
  440. transformers/models/maskformer/modeling_maskformer_swin.py +21 -15
  441. transformers/models/mbart/configuration_mbart.py +1 -0
  442. transformers/models/mbart/modeling_mbart.py +14 -0
  443. transformers/models/mbart/tokenization_mbart.py +11 -52
  444. transformers/models/mbart50/tokenization_mbart50.py +7 -10
  445. transformers/models/megatron_bert/modeling_megatron_bert.py +9 -0
  446. transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
  447. transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
  448. transformers/models/mgp_str/modeling_mgp_str.py +2 -0
  449. transformers/models/mimi/modeling_mimi.py +28 -5
  450. transformers/models/minimax/modeling_minimax.py +19 -6
  451. transformers/models/minimax/modular_minimax.py +12 -1
  452. transformers/models/ministral/modeling_ministral.py +5 -5
  453. transformers/models/ministral3/configuration_ministral3.py +1 -1
  454. transformers/models/ministral3/modeling_ministral3.py +5 -4
  455. transformers/models/mistral/modeling_mistral.py +5 -4
  456. transformers/models/mistral3/modeling_mistral3.py +10 -4
  457. transformers/models/mistral3/modular_mistral3.py +3 -1
  458. transformers/models/mixtral/modeling_mixtral.py +15 -7
  459. transformers/models/mixtral/modular_mixtral.py +6 -2
  460. transformers/models/mlcd/modeling_mlcd.py +6 -0
  461. transformers/models/mlcd/modular_mlcd.py +4 -0
  462. transformers/models/mllama/modeling_mllama.py +15 -4
  463. transformers/models/mluke/tokenization_mluke.py +6 -6
  464. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
  465. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +8 -4
  466. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
  467. transformers/models/mobilebert/modeling_mobilebert.py +2 -0
  468. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +2 -0
  469. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
  470. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +3 -0
  471. transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
  472. transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
  473. transformers/models/mobilevit/modeling_mobilevit.py +7 -0
  474. transformers/models/mobilevitv2/modeling_mobilevitv2.py +7 -0
  475. transformers/models/modernbert/modeling_modernbert.py +16 -2
  476. transformers/models/modernbert/modular_modernbert.py +14 -1
  477. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +17 -10
  478. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +15 -8
  479. transformers/models/moonshine/modeling_moonshine.py +5 -3
  480. transformers/models/moshi/modeling_moshi.py +26 -53
  481. transformers/models/mpnet/modeling_mpnet.py +7 -0
  482. transformers/models/mpnet/tokenization_mpnet.py +5 -13
  483. transformers/models/mpt/modeling_mpt.py +2 -0
  484. transformers/models/mra/modeling_mra.py +10 -1
  485. transformers/models/mt5/configuration_mt5.py +2 -3
  486. transformers/models/mt5/modeling_mt5.py +7 -10
  487. transformers/models/musicgen/modeling_musicgen.py +7 -9
  488. transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -0
  489. transformers/models/mvp/modeling_mvp.py +14 -0
  490. transformers/models/nanochat/modeling_nanochat.py +5 -5
  491. transformers/models/nemotron/modeling_nemotron.py +7 -5
  492. transformers/models/nllb/tokenization_nllb.py +8 -22
  493. transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
  494. transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
  495. transformers/models/nougat/image_processing_nougat_fast.py +0 -1
  496. transformers/models/nougat/tokenization_nougat.py +15 -68
  497. transformers/models/nystromformer/modeling_nystromformer.py +13 -0
  498. transformers/models/olmo/modeling_olmo.py +5 -5
  499. transformers/models/olmo/modular_olmo.py +2 -2
  500. transformers/models/olmo2/modeling_olmo2.py +5 -6
  501. transformers/models/olmo2/modular_olmo2.py +0 -1
  502. transformers/models/olmo3/modeling_olmo3.py +5 -5
  503. transformers/models/olmoe/modeling_olmoe.py +15 -7
  504. transformers/models/olmoe/modular_olmoe.py +4 -2
  505. transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
  506. transformers/models/omdet_turbo/modeling_omdet_turbo.py +6 -0
  507. transformers/models/oneformer/configuration_oneformer.py +3 -3
  508. transformers/models/oneformer/modeling_oneformer.py +11 -39
  509. transformers/models/openai/modeling_openai.py +15 -0
  510. transformers/models/openai/tokenization_openai.py +10 -46
  511. transformers/models/opt/modeling_opt.py +2 -0
  512. transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
  513. transformers/models/ovis2/modeling_ovis2.py +15 -3
  514. transformers/models/ovis2/modular_ovis2.py +8 -0
  515. transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
  516. transformers/models/owlv2/modeling_owlv2.py +11 -3
  517. transformers/models/owlv2/modular_owlv2.py +0 -2
  518. transformers/models/owlvit/modeling_owlvit.py +11 -3
  519. transformers/models/paddleocr_vl/__init__.py +32 -0
  520. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +336 -0
  521. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +504 -0
  522. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +209 -0
  523. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +1682 -0
  524. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +1359 -0
  525. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +135 -0
  526. transformers/models/paligemma/modeling_paligemma.py +25 -17
  527. transformers/models/parakeet/configuration_parakeet.py +4 -6
  528. transformers/models/parakeet/modeling_parakeet.py +14 -6
  529. transformers/models/parakeet/modular_parakeet.py +7 -2
  530. transformers/models/parakeet/processing_parakeet.py +1 -0
  531. transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
  532. transformers/models/patchtsmixer/modeling_patchtsmixer.py +10 -0
  533. transformers/models/patchtst/modeling_patchtst.py +25 -6
  534. transformers/models/pe_audio/__init__.py +30 -0
  535. transformers/models/pe_audio/configuration_pe_audio.py +206 -0
  536. transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
  537. transformers/models/pe_audio/modeling_pe_audio.py +820 -0
  538. transformers/models/pe_audio/modular_pe_audio.py +299 -0
  539. transformers/{kernels/falcon_mamba/__init__.py → models/pe_audio/processing_pe_audio.py} +11 -2
  540. transformers/models/pe_audio_video/__init__.py +29 -0
  541. transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
  542. transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
  543. transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
  544. transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
  545. transformers/models/pe_video/__init__.py +30 -0
  546. transformers/models/pe_video/configuration_pe_video.py +211 -0
  547. transformers/models/pe_video/modeling_pe_video.py +636 -0
  548. transformers/models/pe_video/modular_pe_video.py +219 -0
  549. transformers/models/pe_video/processing_pe_video.py +10 -0
  550. transformers/models/pe_video/video_processing_pe_video.py +66 -0
  551. transformers/models/pegasus/configuration_pegasus.py +1 -0
  552. transformers/models/pegasus/modeling_pegasus.py +8 -0
  553. transformers/models/pegasus/tokenization_pegasus.py +17 -44
  554. transformers/models/pegasus_x/modeling_pegasus_x.py +5 -0
  555. transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
  556. transformers/models/perceiver/modeling_perceiver.py +13 -1
  557. transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
  558. transformers/models/perception_lm/modeling_perception_lm.py +7 -3
  559. transformers/models/perception_lm/modular_perception_lm.py +7 -3
  560. transformers/models/persimmon/modeling_persimmon.py +3 -2
  561. transformers/models/phi/modeling_phi.py +5 -6
  562. transformers/models/phi/modular_phi.py +0 -1
  563. transformers/models/phi3/modeling_phi3.py +3 -2
  564. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +9 -6
  565. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +7 -4
  566. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
  567. transformers/models/phimoe/modeling_phimoe.py +15 -7
  568. transformers/models/phimoe/modular_phimoe.py +3 -3
  569. transformers/models/pix2struct/modeling_pix2struct.py +2 -0
  570. transformers/models/pix2struct/processing_pix2struct.py +0 -4
  571. transformers/models/pixio/__init__.py +30 -0
  572. transformers/models/pixio/configuration_pixio.py +151 -0
  573. transformers/models/pixio/modeling_pixio.py +507 -0
  574. transformers/models/pixio/modular_pixio.py +404 -0
  575. transformers/models/pixtral/modeling_pixtral.py +3 -2
  576. transformers/models/pixtral/processing_pixtral.py +3 -1
  577. transformers/models/plbart/configuration_plbart.py +1 -0
  578. transformers/models/plbart/modeling_plbart.py +13 -0
  579. transformers/models/plbart/modular_plbart.py +8 -0
  580. transformers/models/plbart/tokenization_plbart.py +0 -2
  581. transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
  582. transformers/models/poolformer/modeling_poolformer.py +13 -1
  583. transformers/models/pop2piano/configuration_pop2piano.py +0 -1
  584. transformers/models/pop2piano/modeling_pop2piano.py +2 -0
  585. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
  586. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +1 -0
  587. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +1 -0
  588. transformers/models/prophetnet/modeling_prophetnet.py +5 -1
  589. transformers/models/pvt/modeling_pvt.py +2 -0
  590. transformers/models/pvt_v2/modeling_pvt_v2.py +3 -0
  591. transformers/models/qwen2/modeling_qwen2.py +5 -5
  592. transformers/models/qwen2/tokenization_qwen2.py +14 -18
  593. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +4 -2
  594. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +116 -79
  595. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +71 -33
  596. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +1 -1
  597. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +23 -11
  598. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +29 -27
  599. transformers/models/qwen2_audio/modeling_qwen2_audio.py +4 -2
  600. transformers/models/qwen2_moe/modeling_qwen2_moe.py +15 -7
  601. transformers/models/qwen2_vl/configuration_qwen2_vl.py +1 -1
  602. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
  603. transformers/models/qwen2_vl/modeling_qwen2_vl.py +23 -20
  604. transformers/models/qwen3/modeling_qwen3.py +5 -5
  605. transformers/models/qwen3_moe/modeling_qwen3_moe.py +15 -7
  606. transformers/models/qwen3_next/modeling_qwen3_next.py +7 -8
  607. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
  608. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +112 -68
  609. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +62 -20
  610. transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
  611. transformers/models/qwen3_vl/modeling_qwen3_vl.py +57 -42
  612. transformers/models/qwen3_vl/modular_qwen3_vl.py +59 -46
  613. transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
  614. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +132 -148
  615. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +36 -82
  616. transformers/models/rag/configuration_rag.py +0 -8
  617. transformers/models/rag/modeling_rag.py +8 -9
  618. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +18 -3
  619. transformers/models/reformer/modeling_reformer.py +13 -1
  620. transformers/models/reformer/tokenization_reformer.py +11 -28
  621. transformers/models/regnet/modeling_regnet.py +10 -1
  622. transformers/models/rembert/modeling_rembert.py +13 -1
  623. transformers/models/rembert/tokenization_rembert.py +3 -10
  624. transformers/models/resnet/modeling_resnet.py +19 -5
  625. transformers/models/roberta/modeling_roberta.py +3 -0
  626. transformers/models/roberta/modular_roberta.py +3 -0
  627. transformers/models/roberta/tokenization_roberta.py +18 -27
  628. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
  629. transformers/models/roc_bert/modeling_roc_bert.py +3 -0
  630. transformers/models/roformer/modeling_roformer.py +6 -0
  631. transformers/models/roformer/tokenization_roformer.py +77 -412
  632. transformers/models/rt_detr/configuration_rt_detr.py +1 -1
  633. transformers/models/rt_detr/modeling_rt_detr.py +6 -0
  634. transformers/models/rt_detr/modeling_rt_detr_resnet.py +13 -4
  635. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
  636. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +9 -0
  637. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
  638. transformers/models/rwkv/modeling_rwkv.py +2 -1
  639. transformers/models/sam/configuration_sam.py +1 -0
  640. transformers/models/sam/image_processing_sam_fast.py +0 -1
  641. transformers/models/sam/modeling_sam.py +4 -1
  642. transformers/models/sam2/configuration_sam2.py +1 -1
  643. transformers/models/sam2/modeling_sam2.py +7 -3
  644. transformers/models/sam2/modular_sam2.py +7 -3
  645. transformers/models/sam2_video/modeling_sam2_video.py +52 -43
  646. transformers/models/sam2_video/modular_sam2_video.py +32 -18
  647. transformers/models/sam3/configuration_sam3.py +21 -1
  648. transformers/models/sam3/modeling_sam3.py +100 -80
  649. transformers/models/sam3_tracker/modeling_sam3_tracker.py +8 -1
  650. transformers/models/sam3_tracker/modular_sam3_tracker.py +8 -1
  651. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
  652. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +27 -15
  653. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
  654. transformers/models/sam3_video/configuration_sam3_video.py +14 -0
  655. transformers/models/sam3_video/modeling_sam3_video.py +4 -3
  656. transformers/models/sam3_video/processing_sam3_video.py +1 -1
  657. transformers/models/sam_hq/configuration_sam_hq.py +1 -0
  658. transformers/models/sam_hq/modeling_sam_hq.py +26 -23
  659. transformers/models/seamless_m4t/modeling_seamless_m4t.py +32 -12
  660. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +27 -59
  661. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +11 -1
  662. transformers/models/seed_oss/modeling_seed_oss.py +3 -3
  663. transformers/models/segformer/image_processing_segformer_fast.py +0 -1
  664. transformers/models/segformer/modeling_segformer.py +6 -3
  665. transformers/models/segformer/modular_segformer.py +0 -1
  666. transformers/models/seggpt/modeling_seggpt.py +2 -0
  667. transformers/models/sew/modeling_sew.py +3 -0
  668. transformers/models/sew/modular_sew.py +1 -0
  669. transformers/models/sew_d/modeling_sew_d.py +3 -0
  670. transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
  671. transformers/models/siglip/modeling_siglip.py +24 -2
  672. transformers/models/siglip2/modeling_siglip2.py +67 -41
  673. transformers/models/siglip2/modular_siglip2.py +4 -0
  674. transformers/models/smollm3/modeling_smollm3.py +5 -5
  675. transformers/models/smolvlm/modeling_smolvlm.py +5 -1
  676. transformers/models/smolvlm/processing_smolvlm.py +0 -7
  677. transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
  678. transformers/models/speech_to_text/modeling_speech_to_text.py +14 -0
  679. transformers/models/speecht5/modeling_speecht5.py +41 -1
  680. transformers/models/splinter/modeling_splinter.py +12 -3
  681. transformers/models/splinter/tokenization_splinter.py +9 -28
  682. transformers/models/squeezebert/modeling_squeezebert.py +8 -0
  683. transformers/models/stablelm/modeling_stablelm.py +4 -2
  684. transformers/models/starcoder2/modeling_starcoder2.py +5 -4
  685. transformers/models/superglue/image_processing_superglue_fast.py +1 -2
  686. transformers/models/superglue/modeling_superglue.py +1 -0
  687. transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
  688. transformers/models/superpoint/modeling_superpoint.py +1 -0
  689. transformers/models/swiftformer/modeling_swiftformer.py +6 -0
  690. transformers/models/swin/modeling_swin.py +20 -12
  691. transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
  692. transformers/models/swin2sr/modeling_swin2sr.py +51 -33
  693. transformers/models/swinv2/modeling_swinv2.py +45 -33
  694. transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
  695. transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
  696. transformers/models/t5/configuration_t5.py +7 -1
  697. transformers/models/t5/modeling_t5.py +8 -7
  698. transformers/models/t5/tokenization_t5.py +4 -8
  699. transformers/models/t5gemma/modeling_t5gemma.py +6 -6
  700. transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
  701. transformers/models/t5gemma2/modeling_t5gemma2.py +19 -10
  702. transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
  703. transformers/models/table_transformer/configuration_table_transformer.py +1 -1
  704. transformers/models/table_transformer/modeling_table_transformer.py +5 -1
  705. transformers/models/tapas/modeling_tapas.py +3 -0
  706. transformers/models/textnet/image_processing_textnet_fast.py +0 -1
  707. transformers/models/textnet/modeling_textnet.py +11 -2
  708. transformers/models/time_series_transformer/modeling_time_series_transformer.py +4 -0
  709. transformers/models/timesfm/modeling_timesfm.py +14 -0
  710. transformers/models/timesfm/modular_timesfm.py +14 -0
  711. transformers/models/timesformer/modeling_timesformer.py +2 -0
  712. transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
  713. transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
  714. transformers/models/timm_wrapper/modeling_timm_wrapper.py +20 -14
  715. transformers/models/trocr/modeling_trocr.py +3 -2
  716. transformers/models/tvp/configuration_tvp.py +5 -1
  717. transformers/models/tvp/modeling_tvp.py +6 -4
  718. transformers/models/udop/configuration_udop.py +1 -0
  719. transformers/models/udop/modeling_udop.py +7 -7
  720. transformers/models/udop/tokenization_udop.py +5 -13
  721. transformers/models/umt5/configuration_umt5.py +2 -2
  722. transformers/models/umt5/modeling_umt5.py +7 -6
  723. transformers/models/unispeech/modeling_unispeech.py +4 -0
  724. transformers/models/unispeech/modular_unispeech.py +2 -0
  725. transformers/models/unispeech_sat/modeling_unispeech_sat.py +6 -0
  726. transformers/models/unispeech_sat/modular_unispeech_sat.py +2 -0
  727. transformers/models/univnet/modeling_univnet.py +1 -0
  728. transformers/models/upernet/modeling_upernet.py +1 -0
  729. transformers/models/vaultgemma/modeling_vaultgemma.py +5 -5
  730. transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
  731. transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
  732. transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
  733. transformers/models/video_llava/modeling_video_llava.py +7 -3
  734. transformers/models/vilt/configuration_vilt.py +2 -2
  735. transformers/models/vilt/modeling_vilt.py +13 -0
  736. transformers/models/vipllava/modeling_vipllava.py +7 -3
  737. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +1 -0
  738. transformers/models/visual_bert/modeling_visual_bert.py +8 -0
  739. transformers/models/vitdet/modeling_vitdet.py +2 -0
  740. transformers/models/vitmatte/configuration_vitmatte.py +1 -1
  741. transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
  742. transformers/models/vitmatte/modeling_vitmatte.py +5 -0
  743. transformers/models/vitpose/configuration_vitpose.py +1 -1
  744. transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
  745. transformers/models/vits/modeling_vits.py +1 -0
  746. transformers/models/vjepa2/modeling_vjepa2.py +1 -0
  747. transformers/models/voxtral/modeling_voxtral.py +2 -2
  748. transformers/models/voxtral/modular_voxtral.py +2 -2
  749. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -0
  750. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +21 -10
  751. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +12 -0
  752. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +27 -11
  753. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
  754. transformers/models/wavlm/modeling_wavlm.py +5 -0
  755. transformers/models/whisper/generation_whisper.py +1 -0
  756. transformers/models/whisper/modeling_whisper.py +11 -3
  757. transformers/models/whisper/tokenization_whisper.py +4 -15
  758. transformers/models/x_clip/modeling_x_clip.py +5 -0
  759. transformers/models/xcodec/modeling_xcodec.py +5 -0
  760. transformers/models/xglm/modeling_xglm.py +11 -0
  761. transformers/models/xglm/tokenization_xglm.py +4 -9
  762. transformers/models/xlm/modeling_xlm.py +18 -14
  763. transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
  764. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +9 -16
  765. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
  766. transformers/models/xlnet/modeling_xlnet.py +3 -1
  767. transformers/models/xlnet/tokenization_xlnet.py +3 -7
  768. transformers/models/xmod/modeling_xmod.py +3 -0
  769. transformers/models/yoso/modeling_yoso.py +10 -1
  770. transformers/models/zamba/modeling_zamba.py +4 -1
  771. transformers/models/zamba2/modeling_zamba2.py +7 -4
  772. transformers/models/zamba2/modular_zamba2.py +1 -1
  773. transformers/models/zoedepth/configuration_zoedepth.py +1 -1
  774. transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
  775. transformers/models/zoedepth/modeling_zoedepth.py +8 -0
  776. transformers/pipelines/__init__.py +11 -9
  777. transformers/pipelines/automatic_speech_recognition.py +20 -12
  778. transformers/pipelines/base.py +2 -10
  779. transformers/pipelines/document_question_answering.py +4 -2
  780. transformers/pipelines/question_answering.py +1 -1
  781. transformers/pipelines/text_generation.py +1 -1
  782. transformers/pipelines/text_to_audio.py +2 -2
  783. transformers/processing_utils.py +133 -50
  784. transformers/quantizers/auto.py +2 -4
  785. transformers/quantizers/base.py +44 -174
  786. transformers/quantizers/quantizer_aqlm.py +2 -23
  787. transformers/quantizers/quantizer_auto_round.py +2 -12
  788. transformers/quantizers/quantizer_awq.py +20 -89
  789. transformers/quantizers/quantizer_bitnet.py +4 -14
  790. transformers/quantizers/quantizer_bnb_4bit.py +18 -155
  791. transformers/quantizers/quantizer_bnb_8bit.py +24 -110
  792. transformers/quantizers/quantizer_compressed_tensors.py +2 -9
  793. transformers/quantizers/quantizer_eetq.py +16 -74
  794. transformers/quantizers/quantizer_fbgemm_fp8.py +38 -138
  795. transformers/quantizers/quantizer_finegrained_fp8.py +26 -113
  796. transformers/quantizers/quantizer_fp_quant.py +52 -82
  797. transformers/quantizers/quantizer_gptq.py +8 -28
  798. transformers/quantizers/quantizer_higgs.py +42 -60
  799. transformers/quantizers/quantizer_hqq.py +144 -153
  800. transformers/quantizers/quantizer_mxfp4.py +14 -194
  801. transformers/quantizers/quantizer_quanto.py +35 -79
  802. transformers/quantizers/quantizer_quark.py +36 -17
  803. transformers/quantizers/quantizer_spqr.py +4 -12
  804. transformers/quantizers/quantizer_torchao.py +50 -325
  805. transformers/quantizers/quantizer_vptq.py +4 -27
  806. transformers/quantizers/quantizers_utils.py +20 -0
  807. transformers/testing_utils.py +324 -47
  808. transformers/tokenization_mistral_common.py +7 -2
  809. transformers/tokenization_utils_base.py +116 -224
  810. transformers/tokenization_utils_tokenizers.py +190 -106
  811. transformers/trainer.py +51 -32
  812. transformers/trainer_callback.py +8 -0
  813. transformers/trainer_jit_checkpoint.py +126 -0
  814. transformers/trainer_seq2seq.py +4 -0
  815. transformers/trainer_utils.py +1 -1
  816. transformers/training_args.py +74 -38
  817. transformers/utils/__init__.py +7 -4
  818. transformers/utils/attention_visualizer.py +4 -4
  819. transformers/utils/auto_docstring.py +35 -25
  820. transformers/utils/generic.py +47 -1
  821. transformers/utils/hub.py +5 -15
  822. transformers/utils/import_utils.py +112 -25
  823. transformers/utils/kernel_config.py +74 -19
  824. transformers/utils/loading_report.py +19 -10
  825. transformers/utils/quantization_config.py +78 -245
  826. transformers/video_processing_utils.py +17 -14
  827. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +275 -229
  828. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +832 -777
  829. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +1 -1
  830. transformers/kernels/__init__.py +0 -0
  831. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +0 -529
  832. transformers/models/roformer/tokenization_roformer_fast.py +0 -160
  833. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
  834. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info/licenses}/LICENSE +0 -0
  835. {transformers-5.0.0rc0.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
transformers/models/blip/modeling_blip_text.py

@@ -21,6 +21,7 @@ import torch
 from torch import Tensor, device, nn
 from torch.nn import CrossEntropyLoss
 
+from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...generation import GenerationMixin
@@ -504,6 +505,11 @@ class BlipTextPreTrainedModel(PreTrainedModel):
     base_model_prefix = "bert"
     _no_split_modules = []
 
+    def _init_weights(self, module):
+        super()._init_weights(module)
+        if isinstance(module, BlipTextEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
+
 
 # Adapted from https://github.com/salesforce/BLIP/blob/3a29b7410476bf5f2ba0955827390eb6ea1f4f9d/models/med.py#L571
 class BlipTextModel(BlipTextPreTrainedModel):
@@ -609,6 +615,7 @@ class BlipTextModel(BlipTextPreTrainedModel):
         return_dict: Optional[bool] = None,
         is_decoder: Optional[bool] = False,
         cache_position: Optional[torch.Tensor] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
         r"""
         encoder_hidden_states (`torch.FloatTensor`, *optional*):
@@ -739,6 +746,8 @@ class BlipTextLMHeadModel(BlipTextPreTrainedModel, GenerationMixin):
         self.cls = BlipTextOnlyMLMHead(config)
         self.label_smoothing = config.label_smoothing
 
+        self.post_init()
+
     def get_input_embeddings(self):
         return self.bert.get_input_embeddings()
 
@@ -771,6 +780,7 @@ class BlipTextLMHeadModel(BlipTextPreTrainedModel, GenerationMixin):
         reduction: Optional[str] = "mean",
         cache_position: Optional[torch.Tensor] = None,
         logits_to_keep: Union[int, torch.Tensor] = 0,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
         r"""
         encoder_hidden_states (`torch.FloatTensor`, *optional*): Sequence of
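The blip_text hunks follow the initialization rework visible across this release: models gain `_init_weights` overrides built on the new `transformers/initialization.py` helpers (+37 -0 in the file list) and repopulate position buffers explicitly. A minimal self-contained sketch of that buffer-reinit pattern, using plain torch in place of the `init` helpers (`ToyEmbeddings` and `init_weights` are illustrative names, not transformers API):

import torch
from torch import nn


class ToyEmbeddings(nn.Module):
    """Stand-in for BlipTextEmbeddings: owns a `position_ids` buffer."""

    def __init__(self, max_positions: int = 8, dim: int = 4):
        super().__init__()
        self.word_embeddings = nn.Embedding(max_positions, dim)
        self.register_buffer("position_ids", torch.zeros(1, max_positions, dtype=torch.long))


def init_weights(module: nn.Module) -> None:
    # Mirrors the override above: after generic weight init, buffers that encode
    # positions are refilled rather than left at placeholder values.
    if isinstance(module, ToyEmbeddings):
        with torch.no_grad():
            module.position_ids.copy_(torch.arange(module.position_ids.shape[-1]).expand(1, -1))


model = ToyEmbeddings()
model.apply(init_weights)
print(model.position_ids)  # tensor([[0, 1, 2, 3, 4, 5, 6, 7]])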
@@ -428,6 +428,8 @@ class Blip2PreTrainedModel(PreTrainedModel):
428
428
  ),
429
429
  ):
430
430
  init.zeros_(module.query_tokens)
431
+ elif isinstance(module, Blip2TextEmbeddings):
432
+ init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
431
433
 
432
434
 
433
435
  # Copied from transformers.models.blip.modeling_blip.BlipEncoder with Blip->Blip2
@@ -603,7 +605,7 @@ class Blip2QFormerMultiHeadAttention(nn.Module):
 
         # This is actually dropping out entire tokens to attend to, which might
         # seem a bit unusual, but is taken from the original Transformer paper.
-        attention_probs_dropped = self.dropout(attention_probs)
+        attention_probs_dropped = self.dropout(attention_probs).to(value_layer.dtype)
 
         context_layer = torch.matmul(attention_probs_dropped, value_layer)
 
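The added `.to(value_layer.dtype)` matters when the softmax (and hence the dropout output) runs in float32 while the values are half precision: `torch.matmul` requires matching operand dtypes. A self-contained sketch of the failure mode and the fix, with illustrative shapes (half-precision matmul may need a recent PyTorch or a GPU):

```python
import torch
import torch.nn.functional as F

batch, heads, q_len, kv_len, head_dim = 1, 2, 4, 4, 8
scores = torch.randn(batch, heads, q_len, kv_len)
value_layer = torch.randn(batch, heads, kv_len, head_dim, dtype=torch.float16)

# Softmax (and token-level dropout) in float32 for numerical stability...
attention_probs = F.softmax(scores.float(), dim=-1)
attention_probs_dropped = F.dropout(attention_probs, p=0.1, training=True)

# ...then cast back before the matmul: float32 @ float16 raises a dtype error.
context_layer = torch.matmul(attention_probs_dropped.to(value_layer.dtype), value_layer)
print(context_layer.dtype)  # torch.float16
```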
@@ -1948,6 +1950,7 @@ class Blip2ForImageTextRetrieval(Blip2PreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, Blip2ImageTextMatchingModelOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
@@ -465,6 +465,7 @@ class BloomModel(BloomPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
@@ -713,36 +714,21 @@ class BloomForCausalLM(BloomPreTrainedModel, GenerationMixin):
         inputs_embeds=None,
         cache_position=None,
         use_cache=True,
+        is_first_iteration=False,
         **kwargs,
     ):
         # Overwritten because of the fixed-shape attention mask creation
 
-        # If we have cache: let's slice `input_ids` through `cache_position`, to keep only the unprocessed tokens
-        # Exception 1: when passing input_embeds, input_ids may be missing entries
-        # Exception 2: some generation methods do special slicing of input_ids, so we don't need to do it here
-        # Exception 3: with synced GPUs cache_position may go out of bounds, but we only want dummy token in that case.
-        # (we can't check exception 3 while compiling)
-        # Exception 4: If input_embeds are passed then slice it through `cache_position`, to keep only the unprocessed tokens and
-        # generate the first token for each sequence. Later use the generated Input ids for continuation.
-        if past_key_values is not None:
-            if inputs_embeds is not None and input_ids.shape[1] == 0:  # Exception 4
-                inputs_embeds = inputs_embeds[:, -cache_position.shape[0] :]
-            elif (
-                inputs_embeds is not None  # Exception 1
-                or cache_position[-1] >= input_ids.shape[1]  # Exception 3
-            ):
-                input_ids = input_ids[:, -cache_position.shape[0] :]
-            elif input_ids.shape[1] != cache_position.shape[0]:  # Default case (the "else", a no op, is Exception 2)
-                input_ids = input_ids[:, cache_position]
-
-        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
-        if inputs_embeds is not None and len(cache_position) == inputs_embeds.shape[1]:
-            model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None}
-        else:
-            # This `clone` call is needed to avoid recapturing cuda graphs with `torch.compile`'s `mode="reduce-overhead`, as otherwise the
-            # input `position_ids` would have various stride during the decoding. Here, simply using `.contiguous()` is not sufficient as in
-            # the batch size = 1 case, `position_ids` is already contiguous but with varying stride which retriggers a capture.
-            model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None}
+        model_inputs = super().prepare_inputs_for_generation(
+            input_ids,
+            past_key_values=past_key_values,
+            attention_mask=attention_mask,
+            inputs_embeds=inputs_embeds,
+            cache_position=cache_position,
+            use_cache=use_cache,
+            is_first_iteration=is_first_iteration,
+            **kwargs,
+        )
 
         # This part differs from other models because BLOOM needs a 2D mask to construct alibi tensor
         # The only difference is the usage of 2D instead of 4D mask, but the shape will be static
@@ -752,24 +738,8 @@ class BloomForCausalLM(BloomPreTrainedModel, GenerationMixin):
             diff = target_length - seq_length
 
             new_attn_mask = torch.zeros(batch_size, diff, device=attention_mask.device, dtype=attention_mask.dtype)
-            attention_mask = torch.cat(
-                [attention_mask, new_attn_mask],
-                dim=-1,
-            )
-
-        model_inputs.update(
-            {
-                "cache_position": cache_position,
-                "past_key_values": past_key_values,
-                "use_cache": use_cache,
-                "attention_mask": attention_mask,
-            }
-        )
-
-        # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
-        for key, value in kwargs.items():
-            if key not in model_inputs:
-                model_inputs[key] = value
+            attention_mask = torch.cat([attention_mask, new_attn_mask], dim=-1)
+            model_inputs["attention_mask"] = attention_mask
 
         return model_inputs
 
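As the retained comment notes, BLOOM builds its alibi tensor from a 2D mask, so with a pre-allocated static cache the mask is zero-padded out to the cache's fixed length instead of being expanded to a 4D mask. The padding step in isolation, with illustrative sizes:

```python
import torch

# A 2D padding mask for 2 sequences of length 5 (1 = attend, 0 = padding).
attention_mask = torch.ones(2, 5, dtype=torch.long)

# With a static cache pre-allocated for, say, 16 positions, the mask must
# match the fixed cache shape, so the future slots are filled with zeros.
target_length = 16
batch_size, seq_length = attention_mask.shape
diff = target_length - seq_length
new_attn_mask = torch.zeros(batch_size, diff, device=attention_mask.device, dtype=attention_mask.dtype)
attention_mask = torch.cat([attention_mask, new_attn_mask], dim=-1)

print(attention_mask.shape)  # torch.Size([2, 16])
```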
@@ -883,6 +853,7 @@ class BloomForSequenceClassification(BloomPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], SequenceClassifierOutputWithPast]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
@@ -1006,6 +977,7 @@ class BloomForTokenClassification(BloomPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
@@ -1084,6 +1056,7 @@ class BloomForQuestionAnswering(BloomPreTrainedModel):
         output_attentions: Optional[bool] = None,
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
+        **kwargs,
     ) -> Union[tuple, QuestionAnsweringModelOutput]:
         r"""
         input_ids (`torch.LongTensor` of shape `(batch_size, input_ids_length)`):
@@ -27,6 +27,7 @@ import torch.distributions
 import torch.nn as nn
 import torch.nn.functional as F
 
+from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...generation import GenerationMixin
@@ -38,7 +39,7 @@ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, can_return_tuple
-from ...utils.generic import OutputRecorder, check_model_inputs
+from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
 from .configuration_blt import (
     BltConfig,
     BltGlobalTransformerConfig,
@@ -102,7 +103,7 @@ class BltRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
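Registering `original_inv_freq` as a non-persistent buffer (on a cloned tensor) means the saved copy now follows the module on `.to(...)` device moves and is visible to module-level machinery, while still staying out of the checkpoint; a plain attribute does neither. A small sketch of the distinction:

```python
import torch
from torch import nn


class Rope(nn.Module):
    def __init__(self):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, 8, 2).float() / 8))
        self.register_buffer("inv_freq", inv_freq, persistent=False)
        # Buffer: moves with the module, excluded from the state dict.
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
        # Plain attribute: would NOT move on .to(device), and is invisible to
        # buffer-aware machinery (the behavior the diff fixes).
        self.plain_copy = inv_freq.clone()


rope = Rope()
print("original_inv_freq" in rope.state_dict())  # False (persistent=False)
rope = rope.to("cpu")  # on a GPU box, .to("cuda") moves buffers but not plain_copy
```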
@@ -141,7 +142,7 @@ class BltRotaryEmbedding(nn.Module):
         position_ids_expanded = position_ids[:, None, :].float()
 
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.repeat_interleave(freqs, 2, dim=-1)  # diff from Llama: we interleave() instead of cat()
             cos = emb.cos() * self.attention_scaling
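`maybe_autocast` replaces the bare `torch.autocast` guard here; per the inline comment the intent is unchanged: compute the rotary angles in float32 even inside an autocast region, since low-precision cos/sin of large position values loses accuracy. A sketch of the float32 RoPE computation using plain `torch.autocast` (assuming `maybe_autocast` behaves as a device-tolerant wrapper around it):

```python
import torch

inv_freq = 1.0 / (10000 ** (torch.arange(0, 64, 2).float() / 64))  # (32,)
position_ids = torch.arange(2048).unsqueeze(0)  # (1, 2048)

inv_freq_expanded = inv_freq[None, :, None].float().expand(1, -1, 1)  # (1, 32, 1)
position_ids_expanded = position_ids[:, None, :].float()  # (1, 1, 2048)

with torch.autocast(device_type="cpu", enabled=False):  # force float32
    freqs = (inv_freq_expanded @ position_ids_expanded).transpose(1, 2)  # (1, 2048, 32)
    emb = torch.repeat_interleave(freqs, 2, dim=-1)  # interleaved, as in BLT
    cos, sin = emb.cos(), emb.sin()

print(cos.shape, cos.dtype)  # torch.Size([1, 2048, 64]) torch.float32
```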
@@ -444,6 +445,163 @@ class BltPreTrainedModel(PreTrainedModel):
         "attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="local_decoder"),
     }
 
+    @torch.no_grad()
+    def _init_weights(self, module):
+        """
+        Initialize BLT weights following the original ByteLatentTransformer:
+
+        - Most weights are drawn from a truncated normal.
+        - Scale is ~ 1 / sqrt(model_dim) (or 1 / sqrt(hidden_dim) for FFN outputs).
+        - Norm layers are set to weight = 1, bias = 0.
+        """
+        class_name = module.__class__.__name__
+
+        # Norms: RMSNorm / LayerNorm
+        if isinstance(module, (BltRMSNorm, nn.LayerNorm)) or "RMSNorm" in class_name or "LayerNorm" in class_name:
+            if getattr(module, "weight", None) is not None:
+                init.ones_(module.weight)
+            if getattr(module, "bias", None) is not None:
+                init.zeros_(module.bias)
+            return
+
+        # Embeddings (encoder / patcher / hash embeddings)
+        if isinstance(module, nn.Embedding):
+            hidden_size = getattr(self.config, "hidden_size", None)
+            if hidden_size is None and hasattr(self.config, "encoder_config"):
+                hidden_size = getattr(self.config.encoder_config, "hidden_size", None)
+            if hidden_size is None:
+                hidden_size = module.embedding_dim
+
+            std = hidden_size**-0.5
+            init.trunc_normal_(
+                module.weight,
+                mean=0.0,
+                std=std,
+                a=-3 * std,
+                b=3 * std,
+            )
+            if module.padding_idx is not None:
+                init.zeros_(module.weight[module.padding_idx])
+            return
+
+        # Self-attention / cross-attention projections
+        if isinstance(module, (BltSelfAttention, BltCrossAttention)) or class_name in (
+            "MllamaTextSelfAttention",
+            "MllamaTextCrossAttention",
+        ):
+            dim = getattr(self.config, "hidden_size", None)
+            if dim is None and hasattr(module, "hidden_size"):
+                dim = module.hidden_size
+            if dim is None:
+                for name in ("q_proj", "k_proj", "v_proj", "o_proj", "dense"):
+                    proj = getattr(module, name, None)
+                    if proj is not None and hasattr(proj, "weight"):
+                        dim = proj.weight.shape[-1]
+                        break
+            if dim is None:
+                return
+
+            std = dim**-0.5
+
+            # Input projections (q, k, v)
+            for proj_name in ("q_proj", "k_proj", "v_proj"):
+                proj = getattr(module, proj_name, None)
+                if proj is not None and hasattr(proj, "weight"):
+                    init.trunc_normal_(
+                        proj.weight,
+                        mean=0.0,
+                        std=std,
+                        a=-3 * std,
+                        b=3 * std,
+                    )
+                    if getattr(proj, "bias", None) is not None:
+                        init.zeros_(proj.bias)
+
+            # Output projection: o_proj or dense
+            o_proj = getattr(module, "o_proj", getattr(module, "dense", None))
+            if o_proj is not None and hasattr(o_proj, "weight"):
+                init.trunc_normal_(
+                    o_proj.weight,
+                    mean=0.0,
+                    std=std,
+                    a=-3 * std,
+                    b=3 * std,
+                )
+                if getattr(o_proj, "bias", None) is not None:
+                    init.zeros_(o_proj.bias)
+            return
+
+        # MLP / FFN blocks
+        if isinstance(module, BltMLP) or class_name == "MllamaTextMLP":
+            hidden_size = getattr(self.config, "hidden_size", None)
+            if hidden_size is None and hasattr(self.config, "decoder_config"):
+                hidden_size = getattr(self.config.decoder_config, "hidden_size", None)
+            if hidden_size is None and hasattr(self.config, "encoder_config"):
+                hidden_size = getattr(self.config.encoder_config, "hidden_size", None)
+
+            # Input-side std
+            in_std = None
+            if hidden_size is not None:
+                in_std = hidden_size**-0.5
+
+            gate_proj = getattr(module, "gate_proj", getattr(module, "fc1", None))
+            up_proj = getattr(module, "up_proj", None)
+            down_proj = getattr(module, "down_proj", getattr(module, "fc2", None))
+
+            # gate / input projections
+            for proj in (gate_proj, up_proj):
+                if proj is not None and hasattr(proj, "weight"):
+                    std = in_std or (proj.weight.shape[1] ** -0.5)
+                    init.trunc_normal_(
+                        proj.weight,
+                        mean=0.0,
+                        std=std,
+                        a=-3 * std,
+                        b=3 * std,
+                    )
+                    if getattr(proj, "bias", None) is not None:
+                        init.zeros_(proj.bias)
+
+            # output / down projections
+            if down_proj is not None and hasattr(down_proj, "weight"):
+                hidden_dim = down_proj.weight.shape[1]
+                out_std = hidden_dim**-0.5
+                init.trunc_normal_(
+                    down_proj.weight,
+                    mean=0.0,
+                    std=out_std,
+                    a=-3 * out_std,
+                    b=3 * out_std,
+                )
+                if getattr(down_proj, "bias", None) is not None:
+                    init.zeros_(down_proj.bias)
+            return
+
+        # Generic Linear layers (projections, lm_head, etc.)
+        if isinstance(module, nn.Linear):
+            fan_in = module.in_features
+            std = fan_in**-0.5
+            init.trunc_normal_(
+                module.weight,
+                mean=0.0,
+                std=std,
+                a=-3 * std,
+                b=3 * std,
+            )
+            if module.bias is not None:
+                init.zeros_(module.bias)
+            return
+
+        if isinstance(module, BltRotaryEmbedding):
+            rope_fn = (
+                ROPE_INIT_FUNCTIONS[module.rope_type]
+                if module.rope_type != "default"
+                else module.compute_default_rope_parameters
+            )
+            buffer_value, _ = rope_fn(module.config)
+            init.copy_(module.inv_freq, buffer_value)
+            init.copy_(module.original_inv_freq, buffer_value)
+
 
 class BltLocalEncoder(BltPreTrainedModel):
     config: BltLocalEncoderConfig
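The core of the scheme above is a truncated normal with std = 1/sqrt(dim), cut off at ±3σ. A compact sketch using stock `torch.nn.init.trunc_normal_` in place of the library's `init.trunc_normal_` helper:

```python
import torch
from torch import nn

hidden_size = 512
linear = nn.Linear(hidden_size, hidden_size)

std = hidden_size**-0.5  # ~0.0442 for d_model = 512
nn.init.trunc_normal_(linear.weight, mean=0.0, std=std, a=-3 * std, b=3 * std)
nn.init.zeros_(linear.bias)

# Every sample falls inside the +/- 3 sigma truncation window.
assert linear.weight.abs().max().item() <= 3 * std
print(linear.weight.std().item())  # close to (slightly below) the nominal std
```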
@@ -753,6 +911,8 @@ class BltPatcher(BltPreTrainedModel):
             bias=False,
         )
 
+        self.post_init()
+
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -952,7 +1112,7 @@ def compute_hash_embeddings(
         hash_ids = byte_group_hash_function(local_encoder_tokens, group_size, prime, encoder_hash_byte_group_vocab)
         # Apply offset to get the correct slice of the fused embedding
         offset_hash_ids = hash_ids + embedding_idx * encoder_hash_byte_group_vocab
-        embeddings += encoder_hash_tok_embedding(offset_hash_ids)
+        embeddings += encoder_hash_tok_embedding(offset_hash_ids).to(embeddings.device)
         embedding_idx += 1
 
     return embeddings
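`compute_hash_embeddings` addresses several hash tables inside one fused embedding matrix by offsetting the ids with `table_idx * vocab`; the added `.to(embeddings.device)` keeps the accumulation valid when the fused table lives on a different device (e.g. under model parallelism). A sketch with illustrative sizes (the real code derives a distinct hash per byte-group size):

```python
import torch
from torch import nn

vocab, num_tables, dim = 1000, 3, 16
fused = nn.Embedding(num_tables * vocab, dim)  # one matrix holding all tables

tokens_hash = torch.randint(0, vocab, (2, 7))  # per-table hash ids in [0, vocab)
embeddings = torch.zeros(2, 7, dim)

for table_idx in range(num_tables):
    # The offset selects this table's slice of the fused embedding matrix.
    offset_ids = tokens_hash + table_idx * vocab
    # .to(...) mirrors the fix: align devices before accumulating.
    embeddings = embeddings + fused(offset_ids).to(embeddings.device)

print(embeddings.shape)  # torch.Size([2, 7, 16])
```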
@@ -22,14 +22,15 @@ import torch.distributions
 import torch.nn as nn
 import torch.nn.functional as F
 
+from ... import initialization as init
 from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...masking_utils import create_causal_mask
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
-from ...modeling_rope_utils import dynamic_rope_update
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, logging
-from ...utils.generic import OutputRecorder, check_model_inputs
+from ...utils.generic import OutputRecorder, check_model_inputs, maybe_autocast
 from ..cohere2.modeling_cohere2 import rotate_half  # noqa: F401
 from ..llama.modeling_llama import LlamaRotaryEmbedding
 from ..mllama.modeling_mllama import (
@@ -133,7 +134,7 @@ def compute_hash_embeddings(
         hash_ids = byte_group_hash_function(local_encoder_tokens, group_size, prime, encoder_hash_byte_group_vocab)
         # Apply offset to get the correct slice of the fused embedding
         offset_hash_ids = hash_ids + embedding_idx * encoder_hash_byte_group_vocab
-        embeddings += encoder_hash_tok_embedding(offset_hash_ids)
+        embeddings += encoder_hash_tok_embedding(offset_hash_ids).to(embeddings.device)
         embedding_idx += 1
 
     return embeddings
@@ -277,7 +278,7 @@ class BltRotaryEmbedding(LlamaRotaryEmbedding):
         position_ids_expanded = position_ids[:, None, :].float()
 
         device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+        with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
             emb = torch.repeat_interleave(freqs, 2, dim=-1)  # diff from Llama: we interleave() instead of cat()
             cos = emb.cos() * self.attention_scaling
@@ -360,8 +361,170 @@ class BltPreTrainedModel(MllamaPreTrainedModel):
         "attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="local_decoder"),
     }
 
+    # Weight initialization is adapted from:
+    # - https://github.com/facebookresearch/blt/blob/main/bytelatent/model/blt.py
+    # - https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/transformers_modeling_backend/model/model.py
+    #
+    # Both implementations use truncated normal initialization with std ~ 1 / sqrt(d_model)
+    # (or 1 / sqrt(hidden_dim) for FFN outputs), and unit initialization for normalization layers.
+    # We follow the same scheme here, but expressed in the Transformers APIs.
+
+    @torch.no_grad()
     def _init_weights(self, module):
-        raise AttributeError("No need to inherit it!")
+        """
+        Initialize BLT weights following the original ByteLatentTransformer:
+
+        - Most weights are drawn from a truncated normal.
+        - Scale is ~ 1 / sqrt(model_dim) (or 1 / sqrt(hidden_dim) for FFN outputs).
+        - Norm layers are set to weight = 1, bias = 0.
+        """
+        class_name = module.__class__.__name__
+
+        # Norms: RMSNorm / LayerNorm
+        if isinstance(module, (BltRMSNorm, nn.LayerNorm)) or "RMSNorm" in class_name or "LayerNorm" in class_name:
+            if getattr(module, "weight", None) is not None:
+                init.ones_(module.weight)
+            if getattr(module, "bias", None) is not None:
+                init.zeros_(module.bias)
+            return
+
+        # Embeddings (encoder / patcher / hash embeddings)
+        if isinstance(module, nn.Embedding):
+            hidden_size = getattr(self.config, "hidden_size", None)
+            if hidden_size is None and hasattr(self.config, "encoder_config"):
+                hidden_size = getattr(self.config.encoder_config, "hidden_size", None)
+            if hidden_size is None:
+                hidden_size = module.embedding_dim
+
+            std = hidden_size**-0.5
+            init.trunc_normal_(
+                module.weight,
+                mean=0.0,
+                std=std,
+                a=-3 * std,
+                b=3 * std,
+            )
+            if module.padding_idx is not None:
+                init.zeros_(module.weight[module.padding_idx])
+            return
+
+        # Self-attention / cross-attention projections
+        if isinstance(module, (BltSelfAttention, BltCrossAttention)) or class_name in (
+            "MllamaTextSelfAttention",
+            "MllamaTextCrossAttention",
+        ):
+            dim = getattr(self.config, "hidden_size", None)
+            if dim is None and hasattr(module, "hidden_size"):
+                dim = module.hidden_size
+            if dim is None:
+                for name in ("q_proj", "k_proj", "v_proj", "o_proj", "dense"):
+                    proj = getattr(module, name, None)
+                    if proj is not None and hasattr(proj, "weight"):
+                        dim = proj.weight.shape[-1]
+                        break
+            if dim is None:
+                return
+
+            std = dim**-0.5
+
+            # Input projections (q, k, v)
+            for proj_name in ("q_proj", "k_proj", "v_proj"):
+                proj = getattr(module, proj_name, None)
+                if proj is not None and hasattr(proj, "weight"):
+                    init.trunc_normal_(
+                        proj.weight,
+                        mean=0.0,
+                        std=std,
+                        a=-3 * std,
+                        b=3 * std,
+                    )
+                    if getattr(proj, "bias", None) is not None:
+                        init.zeros_(proj.bias)
+
+            # Output projection: o_proj or dense
+            o_proj = getattr(module, "o_proj", getattr(module, "dense", None))
+            if o_proj is not None and hasattr(o_proj, "weight"):
+                init.trunc_normal_(
+                    o_proj.weight,
+                    mean=0.0,
+                    std=std,
+                    a=-3 * std,
+                    b=3 * std,
+                )
+                if getattr(o_proj, "bias", None) is not None:
+                    init.zeros_(o_proj.bias)
+            return
+
+        # MLP / FFN blocks
+        if isinstance(module, BltMLP) or class_name == "MllamaTextMLP":
+            hidden_size = getattr(self.config, "hidden_size", None)
+            if hidden_size is None and hasattr(self.config, "decoder_config"):
+                hidden_size = getattr(self.config.decoder_config, "hidden_size", None)
+            if hidden_size is None and hasattr(self.config, "encoder_config"):
+                hidden_size = getattr(self.config.encoder_config, "hidden_size", None)
+
+            # Input-side std
+            in_std = None
+            if hidden_size is not None:
+                in_std = hidden_size**-0.5
+
+            gate_proj = getattr(module, "gate_proj", getattr(module, "fc1", None))
+            up_proj = getattr(module, "up_proj", None)
+            down_proj = getattr(module, "down_proj", getattr(module, "fc2", None))
+
+            # gate / input projections
+            for proj in (gate_proj, up_proj):
+                if proj is not None and hasattr(proj, "weight"):
+                    std = in_std or (proj.weight.shape[1] ** -0.5)
+                    init.trunc_normal_(
+                        proj.weight,
+                        mean=0.0,
+                        std=std,
+                        a=-3 * std,
+                        b=3 * std,
+                    )
+                    if getattr(proj, "bias", None) is not None:
+                        init.zeros_(proj.bias)
+
+            # output / down projections
+            if down_proj is not None and hasattr(down_proj, "weight"):
+                hidden_dim = down_proj.weight.shape[1]
+                out_std = hidden_dim**-0.5
+                init.trunc_normal_(
+                    down_proj.weight,
+                    mean=0.0,
+                    std=out_std,
+                    a=-3 * out_std,
+                    b=3 * out_std,
+                )
+                if getattr(down_proj, "bias", None) is not None:
+                    init.zeros_(down_proj.bias)
+            return
+
+        # Generic Linear layers (projections, lm_head, etc.)
+        if isinstance(module, nn.Linear):
+            fan_in = module.in_features
+            std = fan_in**-0.5
+            init.trunc_normal_(
+                module.weight,
+                mean=0.0,
+                std=std,
+                a=-3 * std,
+                b=3 * std,
+            )
+            if module.bias is not None:
+                init.zeros_(module.bias)
+            return
+
+        if isinstance(module, BltRotaryEmbedding):
+            rope_fn = (
+                ROPE_INIT_FUNCTIONS[module.rope_type]
+                if module.rope_type != "default"
+                else module.compute_default_rope_parameters
+            )
+            buffer_value, _ = rope_fn(module.config)
+            init.copy_(module.inv_freq, buffer_value)
+            init.copy_(module.original_inv_freq, buffer_value)
 
     def _update_causal_mask(self, module):
         raise AttributeError("No need to inherit it!")
@@ -634,6 +797,8 @@ class BltPatcher(BltPreTrainedModel):
             bias=False,
         )
 
+        self.post_init()
+
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
@@ -251,10 +251,8 @@ class BridgeTowerImageProcessorFast(BaseImageProcessorFast):
             processed_images, processed_masks = self.pad(
                 processed_images, return_mask=True, disable_grouping=disable_grouping
            )
-            processed_masks = torch.stack(processed_masks, dim=0) if return_tensors else processed_masks
             data["pixel_mask"] = processed_masks
 
-        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
         data["pixel_values"] = processed_images
 
         return BatchFeature(data=data, tensor_type=return_tensors)
@@ -943,6 +943,11 @@ class BridgeTowerPreTrainedModel(PreTrainedModel):
             init.ones_(module.weight)
         elif isinstance(module, BridgeTowerForContrastiveLearning):
             init.constant_(module.logit_scale, self.config.logit_scale_init_value)
+        elif isinstance(module, BridgeTowerVisionEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.num_positions).expand((1, -1)))
+        elif isinstance(module, BridgeTowerTextEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
+            init.zeros_(module.token_type_ids)
 
         if isinstance(module, (nn.Linear, BridgeTowerMLMHead)) and module.bias is not None:
             init.zeros_(module.bias)
@@ -955,12 +960,13 @@ class BridgeTowerVisionModel(BridgeTowerPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.visual = BridgeTowerVisionTransformer(config)
+        self.post_init()
 
     @property
     def dtype(self):
         return self.visual.embeddings.patch_embedding.weight.dtype
 
-    def forward(self, image, image_mask=None, interpolate_pos_encoding=False):
+    def forward(self, image, image_mask=None, interpolate_pos_encoding=False, **kwargs):
         return self.visual(image.type(self.dtype), image_mask, interpolate_pos_encoding)
 
 
@@ -1223,6 +1229,7 @@ class BridgeTowerModel(BridgeTowerPreTrainedModel):
         return_dict: Optional[bool] = None,
         labels: Optional[torch.LongTensor] = None,
         interpolate_pos_encoding: bool = False,
+        **kwargs,
     ) -> Union[tuple[torch.Tensor], BridgeTowerModelOutput]:
         r"""
         image_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`, *optional*):
@@ -1530,6 +1537,7 @@ class BridgeTowerForMaskedLM(BridgeTowerPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         labels: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Union[MaskedLMOutput, tuple[torch.FloatTensor]]:
         r"""
         image_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`, *optional*):
@@ -1630,6 +1638,7 @@ class BridgeTowerForImageAndTextRetrieval(BridgeTowerPreTrainedModel):
         output_hidden_states: Optional[bool] = None,
         return_dict: Optional[bool] = None,
         labels: Optional[torch.LongTensor] = None,
+        **kwargs,
     ) -> Union[SequenceClassifierOutput, tuple[torch.FloatTensor]]:
         r"""
         image_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`, *optional*):
@@ -1742,6 +1751,7 @@ class BridgeTowerForContrastiveLearning(BridgeTowerPreTrainedModel):
         output_hidden_states: Optional[bool] = True,
         return_dict: Optional[bool] = None,
         return_loss: Optional[bool] = None,
+        **kwargs,
     ) -> Union[BridgeTowerContrastiveOutput, tuple[torch.FloatTensor]]:
         r"""
         image_embeds (`torch.FloatTensor` of shape `(batch_size, num_patches, hidden_size)`, *optional*):