transformers 5.0.0rc1__py3-none-any.whl → 5.0.0rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (671)
  1. transformers/__init__.py +20 -1
  2. transformers/activations.py +1 -1
  3. transformers/audio_utils.py +0 -1
  4. transformers/cache_utils.py +17 -15
  5. transformers/configuration_utils.py +114 -70
  6. transformers/conversion_mapping.py +68 -5
  7. transformers/core_model_loading.py +201 -35
  8. transformers/dependency_versions_table.py +1 -1
  9. transformers/feature_extraction_utils.py +54 -22
  10. transformers/generation/candidate_generator.py +79 -31
  11. transformers/generation/configuration_utils.py +162 -122
  12. transformers/generation/continuous_batching/cache.py +47 -18
  13. transformers/generation/continuous_batching/cache_manager.py +131 -34
  14. transformers/generation/continuous_batching/continuous_api.py +101 -64
  15. transformers/generation/continuous_batching/requests.py +28 -1
  16. transformers/generation/continuous_batching/scheduler.py +11 -4
  17. transformers/generation/stopping_criteria.py +1 -1
  18. transformers/generation/utils.py +108 -110
  19. transformers/generation/watermarking.py +8 -5
  20. transformers/image_processing_base.py +2 -12
  21. transformers/image_processing_utils_fast.py +15 -4
  22. transformers/initialization.py +37 -0
  23. transformers/integrations/__init__.py +12 -0
  24. transformers/integrations/accelerate.py +44 -111
  25. transformers/integrations/aqlm.py +3 -5
  26. transformers/integrations/awq.py +2 -5
  27. transformers/integrations/bitnet.py +5 -8
  28. transformers/integrations/bitsandbytes.py +16 -15
  29. transformers/integrations/deepspeed.py +18 -3
  30. transformers/integrations/eetq.py +3 -5
  31. transformers/integrations/fbgemm_fp8.py +1 -1
  32. transformers/integrations/finegrained_fp8.py +6 -16
  33. transformers/integrations/flash_attention.py +2 -2
  34. transformers/integrations/higgs.py +2 -5
  35. transformers/integrations/hub_kernels.py +23 -5
  36. transformers/integrations/integration_utils.py +35 -0
  37. transformers/integrations/mistral.py +12 -0
  38. transformers/integrations/moe.py +240 -0
  39. transformers/integrations/mxfp4.py +4 -10
  40. transformers/integrations/peft.py +5 -0
  41. transformers/integrations/quanto.py +5 -2
  42. transformers/integrations/spqr.py +3 -5
  43. transformers/integrations/tensor_parallel.py +167 -221
  44. transformers/integrations/vptq.py +3 -5
  45. transformers/modeling_gguf_pytorch_utils.py +66 -19
  46. transformers/modeling_rope_utils.py +78 -81
  47. transformers/modeling_utils.py +583 -503
  48. transformers/models/__init__.py +19 -0
  49. transformers/models/afmoe/modeling_afmoe.py +7 -16
  50. transformers/models/afmoe/modular_afmoe.py +5 -13
  51. transformers/models/aimv2/modeling_aimv2.py +4 -0
  52. transformers/models/aimv2/modular_aimv2.py +4 -0
  53. transformers/models/albert/modeling_albert.py +3 -0
  54. transformers/models/align/modeling_align.py +12 -6
  55. transformers/models/altclip/modeling_altclip.py +7 -3
  56. transformers/models/apertus/modeling_apertus.py +4 -2
  57. transformers/models/apertus/modular_apertus.py +4 -1
  58. transformers/models/arcee/modeling_arcee.py +1 -1
  59. transformers/models/aria/modeling_aria.py +8 -4
  60. transformers/models/aria/modular_aria.py +7 -3
  61. transformers/models/audioflamingo3/processing_audioflamingo3.py +27 -22
  62. transformers/models/auto/auto_factory.py +1 -1
  63. transformers/models/auto/configuration_auto.py +27 -0
  64. transformers/models/auto/feature_extraction_auto.py +7 -3
  65. transformers/models/auto/image_processing_auto.py +4 -2
  66. transformers/models/auto/modeling_auto.py +31 -0
  67. transformers/models/auto/processing_auto.py +4 -0
  68. transformers/models/auto/tokenization_auto.py +132 -153
  69. transformers/models/auto/video_processing_auto.py +5 -2
  70. transformers/models/aya_vision/modeling_aya_vision.py +7 -3
  71. transformers/models/bamba/modeling_bamba.py +18 -19
  72. transformers/models/bamba/modular_bamba.py +17 -16
  73. transformers/models/bark/modeling_bark.py +9 -0
  74. transformers/models/bart/configuration_bart.py +0 -1
  75. transformers/models/bart/modeling_bart.py +7 -0
  76. transformers/models/beit/image_processing_beit_fast.py +0 -1
  77. transformers/models/bert/modeling_bert.py +3 -0
  78. transformers/models/bert_generation/modeling_bert_generation.py +2 -0
  79. transformers/models/big_bird/modeling_big_bird.py +3 -0
  80. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +7 -0
  81. transformers/models/bit/modeling_bit.py +5 -1
  82. transformers/models/bitnet/modeling_bitnet.py +1 -1
  83. transformers/models/blenderbot/modeling_blenderbot.py +7 -0
  84. transformers/models/blenderbot/tokenization_blenderbot.py +6 -7
  85. transformers/models/blenderbot_small/modeling_blenderbot_small.py +7 -0
  86. transformers/models/blip/modeling_blip.py +2 -0
  87. transformers/models/blip/modeling_blip_text.py +8 -0
  88. transformers/models/blip_2/modeling_blip_2.py +2 -0
  89. transformers/models/bloom/modeling_bloom.py +13 -44
  90. transformers/models/blt/modeling_blt.py +162 -2
  91. transformers/models/blt/modular_blt.py +168 -3
  92. transformers/models/bridgetower/image_processing_bridgetower_fast.py +0 -2
  93. transformers/models/bridgetower/modeling_bridgetower.py +6 -0
  94. transformers/models/bros/modeling_bros.py +8 -0
  95. transformers/models/camembert/modeling_camembert.py +109 -106
  96. transformers/models/canine/modeling_canine.py +6 -0
  97. transformers/models/canine/tokenization_canine.py +2 -0
  98. transformers/models/chameleon/modeling_chameleon.py +9 -4
  99. transformers/models/chinese_clip/modeling_chinese_clip.py +6 -3
  100. transformers/models/clap/feature_extraction_clap.py +2 -2
  101. transformers/models/clap/modeling_clap.py +25 -15
  102. transformers/models/clip/modeling_clip.py +2 -0
  103. transformers/models/clipseg/modeling_clipseg.py +4 -0
  104. transformers/models/clvp/modeling_clvp.py +14 -3
  105. transformers/models/code_llama/tokenization_code_llama.py +1 -1
  106. transformers/models/codegen/modeling_codegen.py +13 -4
  107. transformers/models/cohere/modeling_cohere.py +1 -1
  108. transformers/models/cohere2/modeling_cohere2.py +1 -1
  109. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +0 -1
  110. transformers/models/cohere2_vision/modeling_cohere2_vision.py +7 -3
  111. transformers/models/conditional_detr/configuration_conditional_detr.py +1 -1
  112. transformers/models/conditional_detr/modeling_conditional_detr.py +4 -1
  113. transformers/models/convbert/modeling_convbert.py +3 -0
  114. transformers/models/convnext/image_processing_convnext.py +2 -2
  115. transformers/models/convnext/image_processing_convnext_fast.py +9 -13
  116. transformers/models/csm/generation_csm.py +19 -22
  117. transformers/models/csm/modeling_csm.py +3 -1
  118. transformers/models/csm/modular_csm.py +2 -0
  119. transformers/models/ctrl/modeling_ctrl.py +14 -2
  120. transformers/models/cvt/modeling_cvt.py +5 -1
  121. transformers/models/cwm/modeling_cwm.py +1 -1
  122. transformers/models/d_fine/configuration_d_fine.py +3 -4
  123. transformers/models/d_fine/modeling_d_fine.py +46 -39
  124. transformers/models/d_fine/modular_d_fine.py +15 -4
  125. transformers/models/dab_detr/configuration_dab_detr.py +2 -2
  126. transformers/models/dab_detr/modeling_dab_detr.py +1 -1
  127. transformers/models/dac/modeling_dac.py +4 -4
  128. transformers/models/data2vec/modeling_data2vec_text.py +7 -0
  129. transformers/models/data2vec/modular_data2vec_text.py +7 -0
  130. transformers/models/dbrx/configuration_dbrx.py +9 -1
  131. transformers/models/dbrx/modeling_dbrx.py +1 -1
  132. transformers/models/deberta/modeling_deberta.py +2 -0
  133. transformers/models/deberta_v2/modeling_deberta_v2.py +2 -0
  134. transformers/models/decision_transformer/modeling_decision_transformer.py +8 -5
  135. transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -4
  136. transformers/models/deepseek_v2/modular_deepseek_v2.py +4 -2
  137. transformers/models/deepseek_v3/modeling_deepseek_v3.py +9 -5
  138. transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -2
  139. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +0 -1
  140. transformers/models/deepseek_vl/modeling_deepseek_vl.py +9 -5
  141. transformers/models/deepseek_vl/modular_deepseek_vl.py +3 -0
  142. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +0 -4
  143. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +9 -5
  144. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +9 -9
  145. transformers/models/deformable_detr/configuration_deformable_detr.py +2 -2
  146. transformers/models/deformable_detr/modeling_deformable_detr.py +1 -1
  147. transformers/models/depth_anything/configuration_depth_anything.py +2 -3
  148. transformers/models/depth_pro/image_processing_depth_pro_fast.py +0 -1
  149. transformers/models/detr/configuration_detr.py +1 -1
  150. transformers/models/detr/modeling_detr.py +8 -1
  151. transformers/models/dia/generation_dia.py +3 -10
  152. transformers/models/dia/modeling_dia.py +12 -1
  153. transformers/models/dia/modular_dia.py +11 -0
  154. transformers/models/dia/processing_dia.py +1 -1
  155. transformers/models/diffllama/modeling_diffllama.py +3 -3
  156. transformers/models/diffllama/modular_diffllama.py +2 -2
  157. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +0 -1
  158. transformers/models/dinov3_vit/modeling_dinov3_vit.py +3 -0
  159. transformers/models/dinov3_vit/modular_dinov3_vit.py +3 -0
  160. transformers/models/distilbert/modeling_distilbert.py +11 -9
  161. transformers/models/doge/modeling_doge.py +1 -1
  162. transformers/models/donut/image_processing_donut_fast.py +0 -1
  163. transformers/models/donut/modeling_donut_swin.py +16 -12
  164. transformers/models/dots1/modeling_dots1.py +14 -5
  165. transformers/models/dpt/configuration_dpt.py +1 -1
  166. transformers/models/dpt/image_processing_dpt_fast.py +1 -2
  167. transformers/models/dpt/modular_dpt.py +1 -2
  168. transformers/models/edgetam/configuration_edgetam.py +1 -1
  169. transformers/models/edgetam/modeling_edgetam.py +5 -2
  170. transformers/models/edgetam/modular_edgetam.py +15 -14
  171. transformers/models/edgetam_video/modeling_edgetam_video.py +55 -43
  172. transformers/models/edgetam_video/modular_edgetam_video.py +13 -19
  173. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +1 -2
  174. transformers/models/efficientloftr/modeling_efficientloftr.py +14 -1
  175. transformers/models/efficientnet/image_processing_efficientnet.py +5 -6
  176. transformers/models/efficientnet/image_processing_efficientnet_fast.py +1 -2
  177. transformers/models/efficientnet/modeling_efficientnet.py +5 -1
  178. transformers/models/electra/modeling_electra.py +7 -0
  179. transformers/models/emu3/modeling_emu3.py +8 -2
  180. transformers/models/emu3/modular_emu3.py +7 -1
  181. transformers/models/encodec/modeling_encodec.py +14 -0
  182. transformers/models/eomt/image_processing_eomt_fast.py +46 -14
  183. transformers/models/eomt/modeling_eomt.py +7 -0
  184. transformers/models/eomt/modular_eomt.py +7 -0
  185. transformers/models/ernie/modeling_ernie.py +6 -0
  186. transformers/models/ernie/modular_ernie.py +6 -0
  187. transformers/models/ernie4_5/modeling_ernie4_5.py +1 -1
  188. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +16 -13
  189. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +9 -35
  190. transformers/models/ernie4_5_vl_moe/__init__.py +31 -0
  191. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +330 -0
  192. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +456 -0
  193. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +232 -0
  194. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +1898 -0
  195. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +1904 -0
  196. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +251 -0
  197. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +594 -0
  198. transformers/models/esm/modeling_esm.py +6 -0
  199. transformers/models/esm/modeling_esmfold.py +6 -1
  200. transformers/models/evolla/modeling_evolla.py +9 -1
  201. transformers/models/evolla/modular_evolla.py +8 -0
  202. transformers/models/exaone4/modeling_exaone4.py +1 -1
  203. transformers/models/falcon/modeling_falcon.py +3 -3
  204. transformers/models/falcon_h1/modeling_falcon_h1.py +28 -23
  205. transformers/models/falcon_h1/modular_falcon_h1.py +7 -2
  206. transformers/models/falcon_mamba/modeling_falcon_mamba.py +6 -2
  207. transformers/models/falcon_mamba/modular_falcon_mamba.py +7 -2
  208. transformers/models/fast_vlm/modeling_fast_vlm.py +7 -3
  209. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +23 -10
  210. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -0
  211. transformers/models/flaubert/modeling_flaubert.py +14 -15
  212. transformers/models/flava/image_processing_flava_fast.py +0 -2
  213. transformers/models/flava/modeling_flava.py +4 -1
  214. transformers/models/flex_olmo/modeling_flex_olmo.py +7 -4
  215. transformers/models/florence2/modeling_florence2.py +20 -3
  216. transformers/models/florence2/modular_florence2.py +13 -0
  217. transformers/models/fnet/modeling_fnet.py +7 -0
  218. transformers/models/fuyu/image_processing_fuyu.py +1 -1
  219. transformers/models/fuyu/modeling_fuyu.py +3 -1
  220. transformers/models/fuyu/processing_fuyu.py +16 -0
  221. transformers/models/gemma/modeling_gemma.py +10 -12
  222. transformers/models/gemma/modular_gemma.py +9 -11
  223. transformers/models/gemma2/modeling_gemma2.py +1 -1
  224. transformers/models/gemma2/modular_gemma2.py +1 -1
  225. transformers/models/gemma3/image_processing_gemma3_fast.py +0 -1
  226. transformers/models/gemma3/modeling_gemma3.py +28 -7
  227. transformers/models/gemma3/modular_gemma3.py +26 -6
  228. transformers/models/gemma3n/configuration_gemma3n.py +3 -0
  229. transformers/models/gemma3n/modeling_gemma3n.py +47 -9
  230. transformers/models/gemma3n/modular_gemma3n.py +51 -9
  231. transformers/models/git/modeling_git.py +181 -126
  232. transformers/models/glm/modeling_glm.py +1 -1
  233. transformers/models/glm4/modeling_glm4.py +1 -1
  234. transformers/models/glm46v/image_processing_glm46v.py +0 -4
  235. transformers/models/glm46v/modeling_glm46v.py +3 -1
  236. transformers/models/glm46v/modular_glm46v.py +3 -0
  237. transformers/models/glm4_moe/modeling_glm4_moe.py +9 -5
  238. transformers/models/glm4_moe/modular_glm4_moe.py +1 -1
  239. transformers/models/glm4v/image_processing_glm4v.py +0 -4
  240. transformers/models/glm4v/modeling_glm4v.py +15 -5
  241. transformers/models/glm4v/modular_glm4v.py +11 -3
  242. transformers/models/glm4v_moe/modeling_glm4v_moe.py +39 -23
  243. transformers/models/glm4v_moe/modular_glm4v_moe.py +12 -0
  244. transformers/models/glmasr/__init__.py +30 -0
  245. transformers/models/glmasr/configuration_glmasr.py +197 -0
  246. transformers/models/glmasr/modeling_glmasr.py +512 -0
  247. transformers/models/glmasr/modular_glmasr.py +433 -0
  248. transformers/models/glmasr/processing_glmasr.py +332 -0
  249. transformers/models/glpn/image_processing_glpn_fast.py +0 -1
  250. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +0 -1
  251. transformers/models/got_ocr2/modeling_got_ocr2.py +8 -3
  252. transformers/models/gpt2/modeling_gpt2.py +8 -5
  253. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +3 -8
  254. transformers/models/gpt_neo/modeling_gpt_neo.py +15 -3
  255. transformers/models/gpt_neox/modeling_gpt_neox.py +1 -1
  256. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +1 -1
  257. transformers/models/gpt_oss/configuration_gpt_oss.py +17 -0
  258. transformers/models/gpt_oss/modeling_gpt_oss.py +6 -9
  259. transformers/models/gpt_oss/modular_gpt_oss.py +5 -7
  260. transformers/models/gptj/modeling_gptj.py +15 -6
  261. transformers/models/granite/modeling_granite.py +1 -1
  262. transformers/models/granite_speech/modeling_granite_speech.py +15 -1
  263. transformers/models/granitemoe/modeling_granitemoe.py +2 -3
  264. transformers/models/granitemoe/modular_granitemoe.py +1 -2
  265. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +4 -0
  266. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +33 -23
  267. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +12 -2
  268. transformers/models/granitemoeshared/modeling_granitemoeshared.py +2 -3
  269. transformers/models/grounding_dino/configuration_grounding_dino.py +2 -3
  270. transformers/models/grounding_dino/modeling_grounding_dino.py +4 -4
  271. transformers/models/groupvit/modeling_groupvit.py +6 -1
  272. transformers/models/helium/modeling_helium.py +1 -1
  273. transformers/models/hgnet_v2/modeling_hgnet_v2.py +10 -0
  274. transformers/models/hgnet_v2/modular_hgnet_v2.py +10 -0
  275. transformers/models/hubert/modeling_hubert.py +4 -0
  276. transformers/models/hubert/modular_hubert.py +4 -0
  277. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +1 -1
  278. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +1 -1
  279. transformers/models/hunyuan_v1_moe/__init__.py +1 -1
  280. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +12 -4
  281. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +4 -2
  282. transformers/models/ibert/modeling_ibert.py +16 -0
  283. transformers/models/idefics/modeling_idefics.py +10 -0
  284. transformers/models/idefics2/modeling_idefics2.py +7 -1
  285. transformers/models/idefics3/modeling_idefics3.py +5 -1
  286. transformers/models/imagegpt/image_processing_imagegpt_fast.py +1 -5
  287. transformers/models/imagegpt/modeling_imagegpt.py +9 -2
  288. transformers/models/instructblip/modeling_instructblip.py +2 -0
  289. transformers/models/instructblipvideo/modeling_instructblipvideo.py +52 -50
  290. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +0 -1
  291. transformers/models/internvl/modeling_internvl.py +11 -8
  292. transformers/models/internvl/modular_internvl.py +5 -9
  293. transformers/models/internvl/video_processing_internvl.py +0 -1
  294. transformers/models/jais2/__init__.py +27 -0
  295. transformers/models/jais2/configuration_jais2.py +152 -0
  296. transformers/models/jais2/modeling_jais2.py +486 -0
  297. transformers/models/jais2/modular_jais2.py +196 -0
  298. transformers/models/jamba/modeling_jamba.py +24 -19
  299. transformers/models/jamba/modular_jamba.py +17 -17
  300. transformers/models/janus/image_processing_janus_fast.py +0 -1
  301. transformers/models/janus/modeling_janus.py +15 -7
  302. transformers/models/janus/modular_janus.py +16 -7
  303. transformers/models/jetmoe/modeling_jetmoe.py +2 -2
  304. transformers/models/jetmoe/modular_jetmoe.py +1 -0
  305. transformers/models/kosmos2/modeling_kosmos2.py +14 -2
  306. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -2
  307. transformers/models/kosmos2_5/modeling_kosmos2_5.py +10 -1
  308. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +9 -3
  309. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +9 -1
  310. transformers/models/lasr/configuration_lasr.py +4 -0
  311. transformers/models/lasr/modeling_lasr.py +3 -2
  312. transformers/models/lasr/modular_lasr.py +8 -1
  313. transformers/models/lasr/processing_lasr.py +0 -2
  314. transformers/models/layoutlm/modeling_layoutlm.py +5 -3
  315. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +0 -1
  316. transformers/models/layoutlmv2/modeling_layoutlmv2.py +12 -0
  317. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +1 -0
  318. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +0 -1
  319. transformers/models/layoutlmv3/modeling_layoutlmv3.py +29 -5
  320. transformers/models/led/modeling_led.py +6 -0
  321. transformers/models/levit/modeling_levit.py +18 -0
  322. transformers/models/lfm2/modeling_lfm2.py +1 -1
  323. transformers/models/lfm2_moe/modeling_lfm2_moe.py +14 -4
  324. transformers/models/lfm2_moe/modular_lfm2_moe.py +5 -28
  325. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  326. transformers/models/lfm2_vl/modeling_lfm2_vl.py +11 -5
  327. transformers/models/lfm2_vl/modular_lfm2_vl.py +4 -2
  328. transformers/models/lfm2_vl/processing_lfm2_vl.py +82 -42
  329. transformers/models/lightglue/image_processing_lightglue_fast.py +1 -2
  330. transformers/models/lilt/modeling_lilt.py +19 -15
  331. transformers/models/llama/modeling_llama.py +1 -1
  332. transformers/models/llama4/image_processing_llama4_fast.py +1 -2
  333. transformers/models/llama4/modeling_llama4.py +8 -4
  334. transformers/models/llava/image_processing_llava_fast.py +0 -1
  335. transformers/models/llava/modeling_llava.py +12 -7
  336. transformers/models/llava_next/image_processing_llava_next_fast.py +0 -1
  337. transformers/models/llava_next/modeling_llava_next.py +7 -3
  338. transformers/models/llava_next_video/modeling_llava_next_video.py +7 -3
  339. transformers/models/llava_next_video/modular_llava_next_video.py +7 -3
  340. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +0 -1
  341. transformers/models/llava_onevision/modeling_llava_onevision.py +7 -3
  342. transformers/models/llava_onevision/modular_llava_onevision.py +7 -4
  343. transformers/models/longcat_flash/modeling_longcat_flash.py +2 -1
  344. transformers/models/longcat_flash/modular_longcat_flash.py +1 -0
  345. transformers/models/longt5/modeling_longt5.py +0 -4
  346. transformers/models/m2m_100/modeling_m2m_100.py +10 -0
  347. transformers/models/mamba/modeling_mamba.py +2 -1
  348. transformers/models/mamba2/modeling_mamba2.py +24 -23
  349. transformers/models/marian/configuration_marian.py +1 -1
  350. transformers/models/marian/modeling_marian.py +3 -0
  351. transformers/models/markuplm/modeling_markuplm.py +5 -8
  352. transformers/models/mask2former/configuration_mask2former.py +3 -3
  353. transformers/models/mask2former/image_processing_mask2former_fast.py +1 -4
  354. transformers/models/mask2former/modeling_mask2former.py +9 -0
  355. transformers/models/maskformer/configuration_maskformer.py +3 -3
  356. transformers/models/maskformer/image_processing_maskformer_fast.py +1 -4
  357. transformers/models/maskformer/modeling_maskformer.py +9 -1
  358. transformers/models/maskformer/modeling_maskformer_swin.py +19 -15
  359. transformers/models/mbart/configuration_mbart.py +1 -0
  360. transformers/models/mbart/modeling_mbart.py +7 -0
  361. transformers/models/megatron_bert/modeling_megatron_bert.py +2 -0
  362. transformers/models/metaclip_2/modeling_metaclip_2.py +2 -0
  363. transformers/models/metaclip_2/modular_metaclip_2.py +2 -0
  364. transformers/models/mimi/modeling_mimi.py +25 -4
  365. transformers/models/minimax/modeling_minimax.py +16 -3
  366. transformers/models/minimax/modular_minimax.py +12 -1
  367. transformers/models/ministral/modeling_ministral.py +1 -1
  368. transformers/models/ministral3/modeling_ministral3.py +1 -1
  369. transformers/models/mistral/modeling_mistral.py +1 -1
  370. transformers/models/mistral3/modeling_mistral3.py +10 -4
  371. transformers/models/mistral3/modular_mistral3.py +3 -1
  372. transformers/models/mixtral/modeling_mixtral.py +12 -4
  373. transformers/models/mixtral/modular_mixtral.py +6 -2
  374. transformers/models/mlcd/modeling_mlcd.py +6 -0
  375. transformers/models/mlcd/modular_mlcd.py +4 -0
  376. transformers/models/mllama/modeling_mllama.py +13 -2
  377. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +1 -2
  378. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +4 -4
  379. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +1 -2
  380. transformers/models/mobilebert/modeling_mobilebert.py +2 -0
  381. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +0 -1
  382. transformers/models/mobilevit/image_processing_mobilevit.py +5 -5
  383. transformers/models/mobilevit/image_processing_mobilevit_fast.py +1 -2
  384. transformers/models/mobilevit/modeling_mobilevit.py +4 -0
  385. transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -0
  386. transformers/models/modernbert/modeling_modernbert.py +12 -1
  387. transformers/models/modernbert/modular_modernbert.py +12 -1
  388. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -1
  389. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +9 -1
  390. transformers/models/moonshine/modeling_moonshine.py +1 -1
  391. transformers/models/moshi/modeling_moshi.py +21 -51
  392. transformers/models/mpnet/modeling_mpnet.py +2 -0
  393. transformers/models/mra/modeling_mra.py +4 -1
  394. transformers/models/mt5/configuration_mt5.py +2 -3
  395. transformers/models/mt5/modeling_mt5.py +0 -10
  396. transformers/models/musicgen/modeling_musicgen.py +5 -9
  397. transformers/models/musicgen_melody/modeling_musicgen_melody.py +4 -0
  398. transformers/models/mvp/modeling_mvp.py +7 -0
  399. transformers/models/nanochat/modeling_nanochat.py +1 -1
  400. transformers/models/nemotron/modeling_nemotron.py +3 -3
  401. transformers/models/nllb_moe/configuration_nllb_moe.py +1 -0
  402. transformers/models/nllb_moe/modeling_nllb_moe.py +10 -0
  403. transformers/models/nougat/image_processing_nougat_fast.py +0 -1
  404. transformers/models/nougat/tokenization_nougat.py +11 -16
  405. transformers/models/nystromformer/modeling_nystromformer.py +7 -0
  406. transformers/models/olmo/modeling_olmo.py +1 -1
  407. transformers/models/olmo2/modeling_olmo2.py +1 -1
  408. transformers/models/olmo3/modeling_olmo3.py +1 -1
  409. transformers/models/olmoe/modeling_olmoe.py +12 -4
  410. transformers/models/olmoe/modular_olmoe.py +4 -2
  411. transformers/models/omdet_turbo/configuration_omdet_turbo.py +2 -2
  412. transformers/models/omdet_turbo/modeling_omdet_turbo.py +4 -0
  413. transformers/models/oneformer/configuration_oneformer.py +3 -3
  414. transformers/models/oneformer/modeling_oneformer.py +7 -38
  415. transformers/models/openai/modeling_openai.py +12 -0
  416. transformers/models/ovis2/image_processing_ovis2_fast.py +0 -1
  417. transformers/models/ovis2/modeling_ovis2.py +15 -3
  418. transformers/models/ovis2/modular_ovis2.py +8 -0
  419. transformers/models/owlv2/image_processing_owlv2_fast.py +0 -2
  420. transformers/models/owlv2/modeling_owlv2.py +7 -3
  421. transformers/models/owlv2/modular_owlv2.py +0 -2
  422. transformers/models/owlvit/modeling_owlvit.py +7 -3
  423. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +3 -2
  424. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +28 -14
  425. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +22 -12
  426. transformers/models/paligemma/modeling_paligemma.py +25 -17
  427. transformers/models/parakeet/modeling_parakeet.py +5 -0
  428. transformers/models/parakeet/modular_parakeet.py +5 -0
  429. transformers/models/parakeet/{tokenization_parakeet_fast.py → tokenization_parakeet.py} +3 -3
  430. transformers/models/patchtsmixer/modeling_patchtsmixer.py +4 -0
  431. transformers/models/patchtst/modeling_patchtst.py +5 -4
  432. transformers/models/pe_audio/__init__.py +30 -0
  433. transformers/models/pe_audio/configuration_pe_audio.py +206 -0
  434. transformers/models/pe_audio/feature_extraction_pe_audio.py +162 -0
  435. transformers/models/pe_audio/modeling_pe_audio.py +820 -0
  436. transformers/models/pe_audio/modular_pe_audio.py +299 -0
  437. transformers/models/pe_audio/processing_pe_audio.py +24 -0
  438. transformers/models/pe_audio_video/__init__.py +29 -0
  439. transformers/models/pe_audio_video/configuration_pe_audio_video.py +225 -0
  440. transformers/models/pe_audio_video/modeling_pe_audio_video.py +972 -0
  441. transformers/models/pe_audio_video/modular_pe_audio_video.py +764 -0
  442. transformers/models/pe_audio_video/processing_pe_audio_video.py +25 -0
  443. transformers/models/pe_video/__init__.py +30 -0
  444. transformers/models/pe_video/configuration_pe_video.py +211 -0
  445. transformers/models/pe_video/modeling_pe_video.py +636 -0
  446. transformers/models/pe_video/modular_pe_video.py +219 -0
  447. transformers/models/pe_video/processing_pe_video.py +10 -0
  448. transformers/models/pe_video/video_processing_pe_video.py +66 -0
  449. transformers/models/pegasus/configuration_pegasus.py +1 -0
  450. transformers/models/pegasus/modeling_pegasus.py +3 -0
  451. transformers/models/pegasus_x/modeling_pegasus_x.py +1 -0
  452. transformers/models/perceiver/image_processing_perceiver_fast.py +0 -1
  453. transformers/models/perceiver/modeling_perceiver.py +5 -1
  454. transformers/models/perception_lm/image_processing_perception_lm_fast.py +0 -1
  455. transformers/models/perception_lm/modeling_perception_lm.py +7 -3
  456. transformers/models/perception_lm/modular_perception_lm.py +7 -3
  457. transformers/models/persimmon/modeling_persimmon.py +1 -1
  458. transformers/models/phi/modeling_phi.py +1 -1
  459. transformers/models/phi3/modeling_phi3.py +1 -1
  460. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +4 -1
  461. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +3 -0
  462. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +0 -2
  463. transformers/models/phimoe/modeling_phimoe.py +12 -4
  464. transformers/models/phimoe/modular_phimoe.py +1 -1
  465. transformers/models/pix2struct/processing_pix2struct.py +0 -4
  466. transformers/models/pixio/__init__.py +30 -0
  467. transformers/models/pixio/configuration_pixio.py +151 -0
  468. transformers/models/pixio/modeling_pixio.py +507 -0
  469. transformers/models/pixio/modular_pixio.py +404 -0
  470. transformers/models/pixtral/modeling_pixtral.py +1 -1
  471. transformers/models/pixtral/processing_pixtral.py +3 -1
  472. transformers/models/plbart/configuration_plbart.py +1 -0
  473. transformers/models/plbart/modeling_plbart.py +7 -0
  474. transformers/models/plbart/modular_plbart.py +6 -0
  475. transformers/models/poolformer/image_processing_poolformer_fast.py +0 -1
  476. transformers/models/poolformer/modeling_poolformer.py +11 -1
  477. transformers/models/pop2piano/configuration_pop2piano.py +0 -1
  478. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +2 -3
  479. transformers/models/prophetnet/modeling_prophetnet.py +2 -1
  480. transformers/models/qwen2/modeling_qwen2.py +1 -1
  481. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +104 -64
  482. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +58 -18
  483. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +18 -5
  484. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +26 -22
  485. transformers/models/qwen2_audio/modeling_qwen2_audio.py +2 -2
  486. transformers/models/qwen2_moe/modeling_qwen2_moe.py +12 -4
  487. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +3 -2
  488. transformers/models/qwen2_vl/modeling_qwen2_vl.py +17 -4
  489. transformers/models/qwen3/modeling_qwen3.py +1 -1
  490. transformers/models/qwen3_moe/modeling_qwen3_moe.py +12 -4
  491. transformers/models/qwen3_next/modeling_qwen3_next.py +4 -6
  492. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +4 -0
  493. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +92 -46
  494. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +48 -4
  495. transformers/models/qwen3_vl/configuration_qwen3_vl.py +5 -5
  496. transformers/models/qwen3_vl/modeling_qwen3_vl.py +17 -4
  497. transformers/models/qwen3_vl/modular_qwen3_vl.py +21 -10
  498. transformers/models/qwen3_vl/processing_qwen3_vl.py +3 -3
  499. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +94 -112
  500. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +32 -81
  501. transformers/models/rag/configuration_rag.py +0 -8
  502. transformers/models/rag/modeling_rag.py +7 -9
  503. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +3 -2
  504. transformers/models/reformer/modeling_reformer.py +9 -1
  505. transformers/models/regnet/modeling_regnet.py +4 -0
  506. transformers/models/rembert/modeling_rembert.py +7 -1
  507. transformers/models/resnet/modeling_resnet.py +8 -3
  508. transformers/models/roberta/modeling_roberta.py +3 -0
  509. transformers/models/roberta/modular_roberta.py +3 -0
  510. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +3 -0
  511. transformers/models/roc_bert/modeling_roc_bert.py +3 -0
  512. transformers/models/rt_detr/configuration_rt_detr.py +1 -1
  513. transformers/models/rt_detr/modeling_rt_detr.py +4 -0
  514. transformers/models/rt_detr/modeling_rt_detr_resnet.py +8 -3
  515. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +2 -3
  516. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +7 -0
  517. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +8 -3
  518. transformers/models/rwkv/modeling_rwkv.py +1 -1
  519. transformers/models/sam/configuration_sam.py +1 -0
  520. transformers/models/sam/image_processing_sam_fast.py +0 -1
  521. transformers/models/sam/modeling_sam.py +4 -1
  522. transformers/models/sam2/configuration_sam2.py +1 -1
  523. transformers/models/sam2/modeling_sam2.py +5 -1
  524. transformers/models/sam2/modular_sam2.py +5 -1
  525. transformers/models/sam2_video/modeling_sam2_video.py +51 -43
  526. transformers/models/sam2_video/modular_sam2_video.py +31 -18
  527. transformers/models/sam3/configuration_sam3.py +21 -1
  528. transformers/models/sam3/modeling_sam3.py +23 -0
  529. transformers/models/sam3_tracker/modeling_sam3_tracker.py +2 -0
  530. transformers/models/sam3_tracker/modular_sam3_tracker.py +2 -0
  531. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +25 -0
  532. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +26 -15
  533. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +25 -2
  534. transformers/models/sam3_video/configuration_sam3_video.py +14 -0
  535. transformers/models/sam3_video/modeling_sam3_video.py +3 -3
  536. transformers/models/sam3_video/processing_sam3_video.py +1 -1
  537. transformers/models/sam_hq/configuration_sam_hq.py +1 -0
  538. transformers/models/sam_hq/modeling_sam_hq.py +26 -23
  539. transformers/models/seamless_m4t/modeling_seamless_m4t.py +27 -11
  540. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +6 -0
  541. transformers/models/seed_oss/modeling_seed_oss.py +1 -1
  542. transformers/models/segformer/image_processing_segformer_fast.py +0 -1
  543. transformers/models/segformer/modeling_segformer.py +2 -2
  544. transformers/models/segformer/modular_segformer.py +0 -1
  545. transformers/models/shieldgemma2/modeling_shieldgemma2.py +1 -0
  546. transformers/models/siglip/modeling_siglip.py +24 -2
  547. transformers/models/siglip2/modeling_siglip2.py +63 -41
  548. transformers/models/smollm3/modeling_smollm3.py +1 -1
  549. transformers/models/smolvlm/modeling_smolvlm.py +5 -1
  550. transformers/models/smolvlm/video_processing_smolvlm.py +0 -1
  551. transformers/models/speech_to_text/modeling_speech_to_text.py +10 -0
  552. transformers/models/speecht5/modeling_speecht5.py +28 -0
  553. transformers/models/splinter/modeling_splinter.py +9 -3
  554. transformers/models/squeezebert/modeling_squeezebert.py +2 -0
  555. transformers/models/stablelm/modeling_stablelm.py +1 -1
  556. transformers/models/starcoder2/modeling_starcoder2.py +1 -1
  557. transformers/models/superglue/image_processing_superglue_fast.py +1 -2
  558. transformers/models/superpoint/image_processing_superpoint_fast.py +1 -2
  559. transformers/models/swiftformer/modeling_swiftformer.py +4 -0
  560. transformers/models/swin/modeling_swin.py +16 -12
  561. transformers/models/swin2sr/image_processing_swin2sr_fast.py +0 -1
  562. transformers/models/swin2sr/modeling_swin2sr.py +49 -33
  563. transformers/models/swinv2/modeling_swinv2.py +41 -33
  564. transformers/models/switch_transformers/modeling_switch_transformers.py +2 -8
  565. transformers/models/switch_transformers/modular_switch_transformers.py +2 -8
  566. transformers/models/t5/configuration_t5.py +7 -1
  567. transformers/models/t5/modeling_t5.py +1 -7
  568. transformers/models/t5gemma/modeling_t5gemma.py +1 -1
  569. transformers/models/t5gemma2/configuration_t5gemma2.py +6 -42
  570. transformers/models/t5gemma2/modeling_t5gemma2.py +13 -4
  571. transformers/models/t5gemma2/modular_t5gemma2.py +289 -4
  572. transformers/models/table_transformer/configuration_table_transformer.py +1 -1
  573. transformers/models/table_transformer/modeling_table_transformer.py +1 -1
  574. transformers/models/textnet/image_processing_textnet_fast.py +0 -1
  575. transformers/models/timesfm/modeling_timesfm.py +12 -0
  576. transformers/models/timesfm/modular_timesfm.py +12 -0
  577. transformers/models/timm_backbone/modeling_timm_backbone.py +13 -9
  578. transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -0
  579. transformers/models/timm_wrapper/modeling_timm_wrapper.py +19 -13
  580. transformers/models/trocr/modeling_trocr.py +1 -2
  581. transformers/models/tvp/configuration_tvp.py +5 -1
  582. transformers/models/tvp/modeling_tvp.py +4 -4
  583. transformers/models/udop/configuration_udop.py +1 -0
  584. transformers/models/udop/modeling_udop.py +3 -7
  585. transformers/models/umt5/configuration_umt5.py +2 -2
  586. transformers/models/umt5/modeling_umt5.py +0 -6
  587. transformers/models/vaultgemma/modeling_vaultgemma.py +1 -1
  588. transformers/models/video_llama_3/image_processing_video_llama_3.py +3 -2
  589. transformers/models/video_llama_3/modeling_video_llama_3.py +12 -1
  590. transformers/models/video_llama_3/modular_video_llama_3.py +10 -1
  591. transformers/models/video_llava/modeling_video_llava.py +7 -3
  592. transformers/models/vilt/configuration_vilt.py +2 -2
  593. transformers/models/vilt/modeling_vilt.py +7 -0
  594. transformers/models/vipllava/modeling_vipllava.py +7 -3
  595. transformers/models/visual_bert/modeling_visual_bert.py +2 -0
  596. transformers/models/vitmatte/configuration_vitmatte.py +1 -1
  597. transformers/models/vitmatte/image_processing_vitmatte_fast.py +0 -1
  598. transformers/models/vitmatte/modeling_vitmatte.py +4 -0
  599. transformers/models/vitpose/configuration_vitpose.py +1 -1
  600. transformers/models/vitpose/image_processing_vitpose_fast.py +0 -1
  601. transformers/models/voxtral/modeling_voxtral.py +2 -2
  602. transformers/models/voxtral/modular_voxtral.py +2 -2
  603. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +16 -10
  604. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +7 -0
  605. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +21 -11
  606. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +21 -11
  607. transformers/models/whisper/generation_whisper.py +1 -0
  608. transformers/models/whisper/modeling_whisper.py +5 -3
  609. transformers/models/x_clip/modeling_x_clip.py +2 -0
  610. transformers/models/xcodec/modeling_xcodec.py +5 -0
  611. transformers/models/xglm/modeling_xglm.py +10 -0
  612. transformers/models/xlm/modeling_xlm.py +13 -14
  613. transformers/models/xlm_roberta/modeling_xlm_roberta.py +109 -106
  614. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +3 -0
  615. transformers/models/xlnet/modeling_xlnet.py +3 -1
  616. transformers/models/xmod/modeling_xmod.py +3 -0
  617. transformers/models/yoso/modeling_yoso.py +4 -1
  618. transformers/models/zamba/modeling_zamba.py +2 -1
  619. transformers/models/zamba2/modeling_zamba2.py +3 -2
  620. transformers/models/zoedepth/configuration_zoedepth.py +1 -1
  621. transformers/models/zoedepth/image_processing_zoedepth_fast.py +1 -3
  622. transformers/models/zoedepth/modeling_zoedepth.py +7 -0
  623. transformers/pipelines/__init__.py +9 -6
  624. transformers/pipelines/automatic_speech_recognition.py +20 -12
  625. transformers/pipelines/base.py +1 -1
  626. transformers/pipelines/document_question_answering.py +1 -1
  627. transformers/pipelines/question_answering.py +1 -1
  628. transformers/pipelines/text_to_audio.py +2 -2
  629. transformers/processing_utils.py +127 -56
  630. transformers/quantizers/auto.py +2 -4
  631. transformers/quantizers/base.py +9 -64
  632. transformers/quantizers/quantizer_aqlm.py +1 -18
  633. transformers/quantizers/quantizer_auto_round.py +1 -10
  634. transformers/quantizers/quantizer_awq.py +3 -8
  635. transformers/quantizers/quantizer_bitnet.py +1 -6
  636. transformers/quantizers/quantizer_bnb_4bit.py +9 -49
  637. transformers/quantizers/quantizer_bnb_8bit.py +9 -19
  638. transformers/quantizers/quantizer_compressed_tensors.py +1 -4
  639. transformers/quantizers/quantizer_eetq.py +2 -12
  640. transformers/quantizers/quantizer_fbgemm_fp8.py +5 -14
  641. transformers/quantizers/quantizer_finegrained_fp8.py +15 -10
  642. transformers/quantizers/quantizer_fp_quant.py +4 -4
  643. transformers/quantizers/quantizer_gptq.py +1 -4
  644. transformers/quantizers/quantizer_higgs.py +2 -6
  645. transformers/quantizers/quantizer_mxfp4.py +2 -28
  646. transformers/quantizers/quantizer_quanto.py +14 -14
  647. transformers/quantizers/quantizer_spqr.py +3 -8
  648. transformers/quantizers/quantizer_torchao.py +28 -124
  649. transformers/quantizers/quantizer_vptq.py +1 -10
  650. transformers/testing_utils.py +28 -12
  651. transformers/tokenization_mistral_common.py +3 -2
  652. transformers/tokenization_utils_base.py +3 -2
  653. transformers/tokenization_utils_tokenizers.py +25 -2
  654. transformers/trainer.py +24 -2
  655. transformers/trainer_callback.py +8 -0
  656. transformers/trainer_seq2seq.py +4 -0
  657. transformers/training_args.py +8 -10
  658. transformers/utils/__init__.py +4 -0
  659. transformers/utils/attention_visualizer.py +4 -4
  660. transformers/utils/auto_docstring.py +34 -25
  661. transformers/utils/generic.py +20 -0
  662. transformers/utils/import_utils.py +51 -9
  663. transformers/utils/kernel_config.py +71 -18
  664. transformers/utils/quantization_config.py +8 -8
  665. transformers/video_processing_utils.py +16 -12
  666. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/METADATA +5 -6
  667. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/RECORD +671 -632
  668. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/WHEEL +0 -0
  669. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/entry_points.txt +0 -0
  670. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/licenses/LICENSE +0 -0
  671. {transformers-5.0.0rc1.dist-info → transformers-5.0.0rc2.dist-info}/top_level.txt +0 -0
transformers/models/blt/modeling_blt.py

@@ -27,6 +27,7 @@ import torch.distributions
 import torch.nn as nn
 import torch.nn.functional as F
 
+from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...generation import GenerationMixin
@@ -102,7 +103,7 @@ class BltRotaryEmbedding(nn.Module):
         inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
 
         self.register_buffer("inv_freq", inv_freq, persistent=False)
-        self.original_inv_freq = inv_freq
+        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
 
     @staticmethod
     def compute_default_rope_parameters(
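The `original_inv_freq` change above swaps a plain tensor attribute for a non-persistent buffer. A buffer follows the module through `.to()` / `.cuda()` device moves and dtype casts, while `persistent=False` keeps it out of the `state_dict`; a bare attribute does neither. A minimal standalone sketch of the difference (an illustrative toy module, not the BLT class itself):

import torch
import torch.nn as nn

class ToyRope(nn.Module):
    def __init__(self):
        super().__init__()
        inv_freq = 1.0 / (10000 ** (torch.arange(0, 8, 2).float() / 8))
        # Moves with the module on .to()/.cuda(), excluded from the state dict:
        self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
        # Plain attribute: stays on whatever device it was created on:
        self.plain_attr = inv_freq.clone()

m = ToyRope()
print(list(m.state_dict()))  # [] -- non-persistent buffers are not serialized
if torch.cuda.is_available():
    m.to("cuda")
    print(m.original_inv_freq.device)  # cuda:0
    print(m.plain_attr.device)         # cpu -- the stale-device hazard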
@@ -444,6 +445,163 @@ class BltPreTrainedModel(PreTrainedModel):
         "attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="local_decoder"),
     }
 
+    @torch.no_grad()
+    def _init_weights(self, module):
+        """
+        Initialize BLT weights following the original ByteLatentTransformer:
+
+        - Most weights are drawn from a truncated normal.
+        - Scale is ~ 1 / sqrt(model_dim) (or 1 / sqrt(hidden_dim) for FFN outputs).
+        - Norm layers are set to weight = 1, bias = 0.
+        """
+        class_name = module.__class__.__name__
+
+        # Norms: RMSNorm / LayerNorm
+        if isinstance(module, (BltRMSNorm, nn.LayerNorm)) or "RMSNorm" in class_name or "LayerNorm" in class_name:
+            if getattr(module, "weight", None) is not None:
+                init.ones_(module.weight)
+            if getattr(module, "bias", None) is not None:
+                init.zeros_(module.bias)
+            return
+
+        # Embeddings (encoder / patcher / hash embeddings)
+        if isinstance(module, nn.Embedding):
+            hidden_size = getattr(self.config, "hidden_size", None)
+            if hidden_size is None and hasattr(self.config, "encoder_config"):
+                hidden_size = getattr(self.config.encoder_config, "hidden_size", None)
+            if hidden_size is None:
+                hidden_size = module.embedding_dim
+
+            std = hidden_size**-0.5
+            init.trunc_normal_(
+                module.weight,
+                mean=0.0,
+                std=std,
+                a=-3 * std,
+                b=3 * std,
+            )
+            if module.padding_idx is not None:
+                init.zeros_(module.weight[module.padding_idx])
+            return
+
+        # Self-attention / cross-attention projections
+        if isinstance(module, (BltSelfAttention, BltCrossAttention)) or class_name in (
+            "MllamaTextSelfAttention",
+            "MllamaTextCrossAttention",
+        ):
+            dim = getattr(self.config, "hidden_size", None)
+            if dim is None and hasattr(module, "hidden_size"):
+                dim = module.hidden_size
+            if dim is None:
+                for name in ("q_proj", "k_proj", "v_proj", "o_proj", "dense"):
+                    proj = getattr(module, name, None)
+                    if proj is not None and hasattr(proj, "weight"):
+                        dim = proj.weight.shape[-1]
+                        break
+            if dim is None:
+                return
+
+            std = dim**-0.5
+
+            # Input projections (q, k, v)
+            for proj_name in ("q_proj", "k_proj", "v_proj"):
+                proj = getattr(module, proj_name, None)
+                if proj is not None and hasattr(proj, "weight"):
+                    init.trunc_normal_(
+                        proj.weight,
+                        mean=0.0,
+                        std=std,
+                        a=-3 * std,
+                        b=3 * std,
+                    )
+                    if getattr(proj, "bias", None) is not None:
+                        init.zeros_(proj.bias)
+
+            # Output projection: o_proj or dense
+            o_proj = getattr(module, "o_proj", getattr(module, "dense", None))
+            if o_proj is not None and hasattr(o_proj, "weight"):
+                init.trunc_normal_(
+                    o_proj.weight,
+                    mean=0.0,
+                    std=std,
+                    a=-3 * std,
+                    b=3 * std,
+                )
+                if getattr(o_proj, "bias", None) is not None:
+                    init.zeros_(o_proj.bias)
+            return
+
+        # MLP / FFN blocks
+        if isinstance(module, BltMLP) or class_name == "MllamaTextMLP":
+            hidden_size = getattr(self.config, "hidden_size", None)
+            if hidden_size is None and hasattr(self.config, "decoder_config"):
+                hidden_size = getattr(self.config.decoder_config, "hidden_size", None)
+            if hidden_size is None and hasattr(self.config, "encoder_config"):
+                hidden_size = getattr(self.config.encoder_config, "hidden_size", None)
+
+            # Input-side std
+            in_std = None
+            if hidden_size is not None:
+                in_std = hidden_size**-0.5
+
+            gate_proj = getattr(module, "gate_proj", getattr(module, "fc1", None))
+            up_proj = getattr(module, "up_proj", None)
+            down_proj = getattr(module, "down_proj", getattr(module, "fc2", None))
+
+            # gate / input projections
+            for proj in (gate_proj, up_proj):
+                if proj is not None and hasattr(proj, "weight"):
+                    std = in_std or (proj.weight.shape[1] ** -0.5)
+                    init.trunc_normal_(
+                        proj.weight,
+                        mean=0.0,
+                        std=std,
+                        a=-3 * std,
+                        b=3 * std,
+                    )
+                    if getattr(proj, "bias", None) is not None:
+                        init.zeros_(proj.bias)
+
+            # output / down projections
+            if down_proj is not None and hasattr(down_proj, "weight"):
+                hidden_dim = down_proj.weight.shape[1]
+                out_std = hidden_dim**-0.5
+                init.trunc_normal_(
+                    down_proj.weight,
+                    mean=0.0,
+                    std=out_std,
+                    a=-3 * out_std,
+                    b=3 * out_std,
+                )
+                if getattr(down_proj, "bias", None) is not None:
+                    init.zeros_(down_proj.bias)
+            return
+
+        # Generic Linear layers (projections, lm_head, etc.)
+        if isinstance(module, nn.Linear):
+            fan_in = module.in_features
+            std = fan_in**-0.5
+            init.trunc_normal_(
+                module.weight,
+                mean=0.0,
+                std=std,
+                a=-3 * std,
+                b=3 * std,
+            )
+            if module.bias is not None:
+                init.zeros_(module.bias)
+            return
+
+        if isinstance(module, BltRotaryEmbedding):
+            rope_fn = (
+                ROPE_INIT_FUNCTIONS[module.rope_type]
+                if module.rope_type != "default"
+                else module.compute_default_rope_parameters
+            )
+            buffer_value, _ = rope_fn(module.config)
+            init.copy_(module.inv_freq, buffer_value)
+            init.copy_(module.original_inv_freq, buffer_value)
+
 
 class BltLocalEncoder(BltPreTrainedModel):
     config: BltLocalEncoderConfig
@@ -753,6 +911,8 @@ class BltPatcher(BltPreTrainedModel):
             bias=False,
         )
 
+        self.post_init()
+
     def forward(
         self,
        input_ids: Optional[torch.LongTensor] = None,
@@ -952,7 +1112,7 @@ def compute_hash_embeddings(
         hash_ids = byte_group_hash_function(local_encoder_tokens, group_size, prime, encoder_hash_byte_group_vocab)
         # Apply offset to get the correct slice of the fused embedding
         offset_hash_ids = hash_ids + embedding_idx * encoder_hash_byte_group_vocab
-        embeddings += encoder_hash_tok_embedding(offset_hash_ids)
+        embeddings += encoder_hash_tok_embedding(offset_hash_ids).to(embeddings.device)
         embedding_idx += 1
 
     return embeddings
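For context on the hunk above: several hash vocabularies are packed back to back into one fused embedding table, and `embedding_idx * encoder_hash_byte_group_vocab` shifts the hash ids into the slice for the current table; the added `.to(embeddings.device)` guards against the fused table living on a different device than the running sum (e.g. under offloaded or multi-device loading). A minimal sketch of the offset trick, with made-up sizes:

import torch
import torch.nn as nn

num_tables, vocab, dim = 3, 16, 4
# One fused table holding `num_tables` hash vocabularies back to back.
fused_embedding = nn.Embedding(num_tables * vocab, dim)

hash_ids = torch.tensor([[5, 9, 2]])   # per-table hash values in [0, vocab)
embeddings = torch.zeros(1, 3, dim)
for embedding_idx in range(num_tables):
    # Shift the ids into the slice owned by table `embedding_idx`.
    offset_hash_ids = hash_ids + embedding_idx * vocab
    embeddings += fused_embedding(offset_hash_ids).to(embeddings.device)
print(embeddings.shape)  # torch.Size([1, 3, 4])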
transformers/models/blt/modular_blt.py

@@ -22,10 +22,11 @@ import torch.distributions
 import torch.nn as nn
 import torch.nn.functional as F
 
+from ... import initialization as init
 from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
 from ...masking_utils import create_causal_mask
 from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast
-from ...modeling_rope_utils import dynamic_rope_update
+from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS
 from ...processing_utils import Unpack
 from ...utils import TransformersKwargs, auto_docstring, logging
@@ -133,7 +134,7 @@ def compute_hash_embeddings(
         hash_ids = byte_group_hash_function(local_encoder_tokens, group_size, prime, encoder_hash_byte_group_vocab)
         # Apply offset to get the correct slice of the fused embedding
         offset_hash_ids = hash_ids + embedding_idx * encoder_hash_byte_group_vocab
-        embeddings += encoder_hash_tok_embedding(offset_hash_ids)
+        embeddings += encoder_hash_tok_embedding(offset_hash_ids).to(embeddings.device)
         embedding_idx += 1
 
     return embeddings
@@ -360,8 +361,170 @@ class BltPreTrainedModel(MllamaPreTrainedModel):
360
361
  "attentions": OutputRecorder(BltSelfAttention, index=1, layer_name="local_decoder"),
361
362
  }
362
363
 
364
+ # Weight initialization is adapted from:
365
+ # - https://github.com/facebookresearch/blt/blob/main/bytelatent/model/blt.py
366
+ # - https://github.com/pytorch/torchtitan/blob/main/torchtitan/experiments/transformers_modeling_backend/model/model.py
367
+ #
368
+ # Both implementations use truncated normal initialization with std ~ 1 / sqrt(d_model)
369
+ # (or 1 / sqrt(hidden_dim) for FFN outputs), and unit initialization for normalization layers.
370
+ # We follow the same scheme here, but expressed in the Transformers APIs.
371
+
372
+ @torch.no_grad()
363
373
  def _init_weights(self, module):
364
- raise AttributeError("No need to inherit it!")
374
+ """
375
+ Initialize BLT weights following the original ByteLatentTransformer:
376
+
377
+ - Most weights are drawn from a truncated normal.
378
+ - Scale is ~ 1 / sqrt(model_dim) (or 1 / sqrt(hidden_dim) for FFN outputs).
379
+ - Norm layers are set to weight = 1, bias = 0.
380
+ """
381
+ class_name = module.__class__.__name__
382
+
383
+ # Norms: RMSNorm / LayerNorm
384
+ if isinstance(module, (BltRMSNorm, nn.LayerNorm)) or "RMSNorm" in class_name or "LayerNorm" in class_name:
385
+ if getattr(module, "weight", None) is not None:
386
+ init.ones_(module.weight)
+        if getattr(module, "bias", None) is not None:
+            init.zeros_(module.bias)
+        return
+
+        # Embeddings (encoder / patcher / hash embeddings)
+        if isinstance(module, nn.Embedding):
+            hidden_size = getattr(self.config, "hidden_size", None)
+            if hidden_size is None and hasattr(self.config, "encoder_config"):
+                hidden_size = getattr(self.config.encoder_config, "hidden_size", None)
+            if hidden_size is None:
+                hidden_size = module.embedding_dim
+
+            std = hidden_size**-0.5
+            init.trunc_normal_(
+                module.weight,
+                mean=0.0,
+                std=std,
+                a=-3 * std,
+                b=3 * std,
+            )
+            if module.padding_idx is not None:
+                init.zeros_(module.weight[module.padding_idx])
+            return
+
+        # Self-attention / cross-attention projections
+        if isinstance(module, (BltSelfAttention, BltCrossAttention)) or class_name in (
+            "MllamaTextSelfAttention",
+            "MllamaTextCrossAttention",
+        ):
+            dim = getattr(self.config, "hidden_size", None)
+            if dim is None and hasattr(module, "hidden_size"):
+                dim = module.hidden_size
+            if dim is None:
+                for name in ("q_proj", "k_proj", "v_proj", "o_proj", "dense"):
+                    proj = getattr(module, name, None)
+                    if proj is not None and hasattr(proj, "weight"):
+                        dim = proj.weight.shape[-1]
+                        break
+            if dim is None:
+                return
+
+            std = dim**-0.5
+
+            # Input projections (q, k, v)
+            for proj_name in ("q_proj", "k_proj", "v_proj"):
+                proj = getattr(module, proj_name, None)
+                if proj is not None and hasattr(proj, "weight"):
+                    init.trunc_normal_(
+                        proj.weight,
+                        mean=0.0,
+                        std=std,
+                        a=-3 * std,
+                        b=3 * std,
+                    )
+                    if getattr(proj, "bias", None) is not None:
+                        init.zeros_(proj.bias)
+
+            # Output projection: o_proj or dense
+            o_proj = getattr(module, "o_proj", getattr(module, "dense", None))
+            if o_proj is not None and hasattr(o_proj, "weight"):
+                init.trunc_normal_(
+                    o_proj.weight,
+                    mean=0.0,
+                    std=std,
+                    a=-3 * std,
+                    b=3 * std,
+                )
+                if getattr(o_proj, "bias", None) is not None:
+                    init.zeros_(o_proj.bias)
+            return
+
+        # MLP / FFN blocks
+        if isinstance(module, BltMLP) or class_name == "MllamaTextMLP":
+            hidden_size = getattr(self.config, "hidden_size", None)
+            if hidden_size is None and hasattr(self.config, "decoder_config"):
+                hidden_size = getattr(self.config.decoder_config, "hidden_size", None)
+            if hidden_size is None and hasattr(self.config, "encoder_config"):
+                hidden_size = getattr(self.config.encoder_config, "hidden_size", None)
+
+            # Input-side std
+            in_std = None
+            if hidden_size is not None:
+                in_std = hidden_size**-0.5
+
+            gate_proj = getattr(module, "gate_proj", getattr(module, "fc1", None))
+            up_proj = getattr(module, "up_proj", None)
+            down_proj = getattr(module, "down_proj", getattr(module, "fc2", None))
+
+            # Gate / input projections
+            for proj in (gate_proj, up_proj):
+                if proj is not None and hasattr(proj, "weight"):
+                    std = in_std or (proj.weight.shape[1] ** -0.5)
+                    init.trunc_normal_(
+                        proj.weight,
+                        mean=0.0,
+                        std=std,
+                        a=-3 * std,
+                        b=3 * std,
+                    )
+                    if getattr(proj, "bias", None) is not None:
+                        init.zeros_(proj.bias)
+
+            # Output / down projections
+            if down_proj is not None and hasattr(down_proj, "weight"):
+                hidden_dim = down_proj.weight.shape[1]
+                out_std = hidden_dim**-0.5
+                init.trunc_normal_(
+                    down_proj.weight,
+                    mean=0.0,
+                    std=out_std,
+                    a=-3 * out_std,
+                    b=3 * out_std,
+                )
+                if getattr(down_proj, "bias", None) is not None:
+                    init.zeros_(down_proj.bias)
+            return
+
+        # Generic Linear layers (projections, lm_head, etc.)
+        if isinstance(module, nn.Linear):
+            fan_in = module.in_features
+            std = fan_in**-0.5
+            init.trunc_normal_(
+                module.weight,
+                mean=0.0,
+                std=std,
+                a=-3 * std,
+                b=3 * std,
+            )
+            if module.bias is not None:
+                init.zeros_(module.bias)
+            return
+
+        if isinstance(module, BltRotaryEmbedding):
+            rope_fn = (
+                ROPE_INIT_FUNCTIONS[module.rope_type]
+                if module.rope_type != "default"
+                else module.compute_default_rope_parameters
+            )
+            buffer_value, _ = rope_fn(module.config)
+            init.copy_(module.inv_freq, buffer_value)
+            init.copy_(module.original_inv_freq, buffer_value)
 
     def _update_causal_mask(self, module):
         raise AttributeError("No need to inherit it!")
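Taken together, the branches above apply one recipe everywhere: truncated-normal weights with std = fan_in ** -0.5, hard-clipped at ±3σ, plus zeroed biases. A minimal standalone sketch of that recipe (the helper name is ours, not part of the diff):

from torch import nn

def blt_style_init_(linear: nn.Linear) -> None:
    # std scales inversely with the square root of fan-in, as in the branches above
    std = linear.in_features ** -0.5
    nn.init.trunc_normal_(linear.weight, mean=0.0, std=std, a=-3 * std, b=3 * std)
    if linear.bias is not None:
        nn.init.zeros_(linear.bias)

layer = nn.Linear(512, 512)
blt_style_init_(layer)
assert layer.weight.abs().max() <= 3 * 512 ** -0.5  # samples are clipped at ±3σ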
@@ -634,6 +797,8 @@ class BltPatcher(BltPreTrainedModel):
             bias=False,
         )
 
+        self.post_init()
+
     def forward(
         self,
         input_ids: Optional[torch.LongTensor] = None,
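The only functional change in this hunk is the trailing self.post_init(), which is what actually triggers the _init_weights dispatch (along with the library's other end-of-construction bookkeeping, such as weight tying) on a freshly built BltPatcher. A hedged sketch of the convention with toy stand-ins for the real classes:

from torch import nn

class ToyPreTrainedModel(nn.Module):
    """Illustrative stand-in for PreTrainedModel, not the transformers class."""

    def post_init(self):
        # the real post_init() does more (weight tying, hooks); here we only
        # mimic the recursive _init_weights dispatch it is responsible for
        self.apply(self._init_weights)

    def _init_weights(self, module):
        if isinstance(module, nn.Linear):
            nn.init.normal_(module.weight, std=0.02)

class ToyPatcher(ToyPreTrainedModel):
    def __init__(self, hidden_size: int = 64):
        super().__init__()
        self.lm_head = nn.Linear(hidden_size, hidden_size, bias=False)
        self.post_init()  # must run last, once every submodule exists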
@@ -251,10 +251,8 @@ class BridgeTowerImageProcessorFast(BaseImageProcessorFast):
         processed_images, processed_masks = self.pad(
             processed_images, return_mask=True, disable_grouping=disable_grouping
         )
-        processed_masks = torch.stack(processed_masks, dim=0) if return_tensors else processed_masks
         data["pixel_mask"] = processed_masks
 
-        processed_images = torch.stack(processed_images, dim=0) if return_tensors else processed_images
         data["pixel_values"] = processed_images
 
         return BatchFeature(data=data, tensor_type=return_tensors)
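The two torch.stack calls could be dropped because the final BatchFeature(..., tensor_type=return_tensors) already performs the tensor conversion, so stacking beforehand duplicated that work. A small illustration of the behavior being relied on (toy shapes; we assume rc2's BatchFeature stacks a list of equally sized tensors during conversion):

import torch
from transformers import BatchFeature

images = [torch.zeros(3, 4, 4) for _ in range(2)]  # two equally sized "images"
feat = BatchFeature(data={"pixel_values": images}, tensor_type="pt")
print(feat["pixel_values"].shape)  # expected: torch.Size([2, 3, 4, 4])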
@@ -943,6 +943,11 @@ class BridgeTowerPreTrainedModel(PreTrainedModel):
             init.ones_(module.weight)
         elif isinstance(module, BridgeTowerForContrastiveLearning):
             init.constant_(module.logit_scale, self.config.logit_scale_init_value)
+        elif isinstance(module, BridgeTowerVisionEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.num_positions).expand((1, -1)))
+        elif isinstance(module, BridgeTowerTextEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
+            init.zeros_(module.token_type_ids)
 
         if isinstance(module, (nn.Linear, BridgeTowerMLMHead)) and module.bias is not None:
             init.zeros_(module.bias)
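The new elif branches rebuild buffer contents inside _init_weights instead of trusting values baked in at construction time, so position_ids and token_type_ids can be re-created whenever weights are (re)initialized. A minimal sketch of the same pattern on a hypothetical module:

import torch
from torch import nn

class ToyEmbeddings(nn.Module):
    """Hypothetical module with a position_ids buffer, standing in for the BridgeTower embeddings."""

    def __init__(self, max_positions: int = 16):
        super().__init__()
        self.register_buffer("position_ids", torch.empty(1, max_positions, dtype=torch.long), persistent=False)

def init_buffers(module: nn.Module) -> None:
    # mirrors init.copy_(module.position_ids, torch.arange(...).expand((1, -1)))
    if isinstance(module, ToyEmbeddings):
        with torch.no_grad():
            module.position_ids.copy_(torch.arange(module.position_ids.shape[-1]).expand((1, -1)))

emb = ToyEmbeddings()
emb.apply(init_buffers)
assert emb.position_ids[0, -1].item() == 15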
@@ -955,6 +960,7 @@ class BridgeTowerVisionModel(BridgeTowerPreTrainedModel):
     def __init__(self, config):
         super().__init__(config)
         self.visual = BridgeTowerVisionTransformer(config)
+        self.post_init()
 
     @property
     def dtype(self):
@@ -522,6 +522,14 @@ class BrosPreTrainedModel(PreTrainedModel):
         std = self.config.initializer_range
         if isinstance(module, BrosRelationExtractor):
             init.normal_(module.dummy_node, std=std)
+        elif isinstance(module, BrosTextEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
+            init.zeros_(module.token_type_ids)
+        elif isinstance(module, BrosPositionalEmbedding1D):
+            inv_freq = 1 / (
+                10000 ** (torch.arange(0.0, module.dim_bbox_sinusoid_emb_1d, 2.0) / module.dim_bbox_sinusoid_emb_1d)
+            )
+            init.copy_(module.inv_freq, inv_freq)
 
 
 @auto_docstring
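The BrosPositionalEmbedding1D branch recomputes the classic sinusoidal inverse-frequency vector, inv_freq[i] = 1 / 10000 ** (2i / d). A quick standalone check of the formula and of how such frequencies are typically consumed (illustrative, not Bros code):

import torch

d = 8  # stands in for module.dim_bbox_sinusoid_emb_1d
inv_freq = 1 / (10000 ** (torch.arange(0.0, d, 2.0) / d))
print(inv_freq)  # tensor([1.0000, 0.1000, 0.0100, 0.0010])

# each position is multiplied by every frequency, then sin/cos pairs are concatenated
pos = torch.arange(4.0)
angles = pos[:, None] * inv_freq[None, :]              # (seq_len, d / 2)
emb = torch.cat([angles.sin(), angles.cos()], dim=-1)  # (seq_len, d)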
@@ -54,6 +54,112 @@ from .configuration_camembert import CamembertConfig
 logger = logging.get_logger(__name__)
 
 
+class CamembertEmbeddings(nn.Module):
+    """Construct the embeddings from word, position and token_type embeddings."""
+
+    def __init__(self, config):
+        super().__init__()
+        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
+        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
+
+        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+        self.dropout = nn.Dropout(config.hidden_dropout_prob)
+        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
+        self.register_buffer(
+            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
+        )
+        self.register_buffer(
+            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
+        )
+
+        self.padding_idx = config.pad_token_id
+        self.position_embeddings = nn.Embedding(
+            config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
+        )
+
+    def forward(
+        self,
+        input_ids: Optional[torch.LongTensor] = None,
+        token_type_ids: Optional[torch.LongTensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        inputs_embeds: Optional[torch.FloatTensor] = None,
+        past_key_values_length: int = 0,
+    ) -> torch.Tensor:
+        if position_ids is None:
+            if input_ids is not None:
+                # Create the position ids from the input token ids. Any padded tokens remain padded.
+                position_ids = self.create_position_ids_from_input_ids(
+                    input_ids, self.padding_idx, past_key_values_length
+                )
+            else:
+                position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds, self.padding_idx)
+
+        if input_ids is not None:
+            input_shape = input_ids.size()
+        else:
+            input_shape = inputs_embeds.size()[:-1]
+
+        batch_size, seq_length = input_shape
+
+        # When token_type_ids is not passed, fall back to the registered buffer, which is all zeros. The buffer
+        # lets users trace the model without passing token_type_ids explicitly; solves issue #5664.
+        if token_type_ids is None:
+            if hasattr(self, "token_type_ids"):
+                # NOTE: We assume either pos ids to have bsz == 1 (broadcastable) or bsz == effective bsz (input_shape[0])
+                buffered_token_type_ids = self.token_type_ids.expand(position_ids.shape[0], -1)
+                buffered_token_type_ids = torch.gather(buffered_token_type_ids, dim=1, index=position_ids)
+                token_type_ids = buffered_token_type_ids.expand(batch_size, seq_length)
+            else:
+                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
+
+        if inputs_embeds is None:
+            inputs_embeds = self.word_embeddings(input_ids)
+        token_type_embeddings = self.token_type_embeddings(token_type_ids)
+        embeddings = inputs_embeds + token_type_embeddings
+
+        position_embeddings = self.position_embeddings(position_ids)
+        embeddings = embeddings + position_embeddings
+
+        embeddings = self.LayerNorm(embeddings)
+        embeddings = self.dropout(embeddings)
+        return embeddings
+
+    @staticmethod
+    def create_position_ids_from_inputs_embeds(inputs_embeds, padding_idx):
+        """
+        We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
+
+        Args:
+            inputs_embeds: torch.Tensor
+
+        Returns: torch.Tensor
+        """
+        input_shape = inputs_embeds.size()[:-1]
+        sequence_length = input_shape[1]
+
+        position_ids = torch.arange(
+            padding_idx + 1, sequence_length + padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
+        )
+        return position_ids.unsqueeze(0).expand(input_shape)
+
+    @staticmethod
+    def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
+        """
+        Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding
+        symbols are ignored. This is modified from fairseq's `utils.make_positions`.
+
+        Args:
+            input_ids: torch.Tensor
+
+        Returns: torch.Tensor
+        """
+        # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
+        mask = input_ids.ne(padding_idx).int()
+        incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
+        return incremental_indices.long() + padding_idx
+
+
 def eager_attention_forward(
     module: nn.Module,
     query: torch.Tensor,
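The create_position_ids_from_input_ids helper above numbers real tokens from padding_idx + 1 while pinning padded positions to padding_idx, via a masked cumulative sum. A worked example with concrete values:

import torch

padding_idx = 1
input_ids = torch.tensor([[5, 6, 7, 1, 1]])  # token id 1 is padding

mask = input_ids.ne(padding_idx).int()           # [[1, 1, 1, 0, 0]]
incremental = torch.cumsum(mask, dim=1) * mask   # [[1, 2, 3, 0, 0]]
position_ids = incremental.long() + padding_idx  # [[2, 3, 4, 1, 1]]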
@@ -417,112 +523,9 @@ class CamembertPreTrainedModel(PreTrainedModel):
         super()._init_weights(module)
         if isinstance(module, CamembertLMHead):
             init.zeros_(module.bias)
-
-
-class CamembertEmbeddings(nn.Module):
-    """Construct the embeddings from word, position and token_type embeddings."""
-
-    def __init__(self, config):
-        super().__init__()
-        self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
-        self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
-
-        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-        self.dropout = nn.Dropout(config.hidden_dropout_prob)
-        # position_ids (1, len position emb) is contiguous in memory and exported when serialized
-        self.register_buffer(
-            "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
-        )
-        self.register_buffer(
-            "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
-        )
-
-        self.padding_idx = config.pad_token_id
-        self.position_embeddings = nn.Embedding(
-            config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
-        )
-
-    def forward(
-        self,
-        input_ids: Optional[torch.LongTensor] = None,
-        token_type_ids: Optional[torch.LongTensor] = None,
-        position_ids: Optional[torch.LongTensor] = None,
-        inputs_embeds: Optional[torch.FloatTensor] = None,
-        past_key_values_length: int = 0,
-    ) -> torch.Tensor:
-        if position_ids is None:
-            if input_ids is not None:
-                # Create the position ids from the input token ids. Any padded tokens remain padded.
-                position_ids = self.create_position_ids_from_input_ids(
-                    input_ids, self.padding_idx, past_key_values_length
-                )
-            else:
-                position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds, self.padding_idx)
-
-        if input_ids is not None:
-            input_shape = input_ids.size()
-        else:
-            input_shape = inputs_embeds.size()[:-1]
-
-        batch_size, seq_length = input_shape
-
-        # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
-        # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
-        # issue #5664
-        if token_type_ids is None:
-            if hasattr(self, "token_type_ids"):
-                # NOTE: We assume either pos ids to have bsz == 1 (broadcastable) or bsz == effective bsz (input_shape[0])
-                buffered_token_type_ids = self.token_type_ids.expand(position_ids.shape[0], -1)
-                buffered_token_type_ids = torch.gather(buffered_token_type_ids, dim=1, index=position_ids)
-                token_type_ids = buffered_token_type_ids.expand(batch_size, seq_length)
-            else:
-                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
-
-        if inputs_embeds is None:
-            inputs_embeds = self.word_embeddings(input_ids)
-        token_type_embeddings = self.token_type_embeddings(token_type_ids)
-        embeddings = inputs_embeds + token_type_embeddings
-
-        position_embeddings = self.position_embeddings(position_ids)
-        embeddings = embeddings + position_embeddings
-
-        embeddings = self.LayerNorm(embeddings)
-        embeddings = self.dropout(embeddings)
-        return embeddings
-
-    @staticmethod
-    def create_position_ids_from_inputs_embeds(inputs_embeds, padding_idx):
-        """
-        We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
-
-        Args:
-            inputs_embeds: torch.Tensor
-
-        Returns: torch.Tensor
-        """
-        input_shape = inputs_embeds.size()[:-1]
-        sequence_length = input_shape[1]
-
-        position_ids = torch.arange(
-            padding_idx + 1, sequence_length + padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
-        )
-        return position_ids.unsqueeze(0).expand(input_shape)
-
-    @staticmethod
-    def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
-        """
-        Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
-        are ignored. This is modified from fairseq's `utils.make_positions`.
-
-        Args:
-            x: torch.Tensor x:
-
-        Returns: torch.Tensor
-        """
-        # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
-        mask = input_ids.ne(padding_idx).int()
-        incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
-        return incremental_indices.long() + padding_idx
+        elif isinstance(module, CamembertEmbeddings):
+            init.copy_(module.position_ids, torch.arange(module.position_ids.shape[-1]).expand((1, -1)))
+            init.zeros_(module.token_type_ids)
 
 
 class CamembertEncoder(nn.Module):
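The new CamembertEmbeddings branch mirrors the BridgeTower and Bros changes above, and the underlying reason is the same in each case: buffers registered with persistent=False never reach the checkpoint, so _init_weights must be able to rebuild them from scratch. A short demonstration of that property:

import torch
from torch import nn

class Emb(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("position_ids", torch.arange(8).expand(1, -1), persistent=False)

# non-persistent buffers are excluded from the state dict, hence from saved checkpoints
print("position_ids" in Emb().state_dict())  # False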