transformers 5.0.0__py3-none-any.whl → 5.0.0rc0__py3-none-any.whl

This diff shows the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
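Each entry below lists a file path with the number of lines added (+N) and removed (-M) between the two wheels. As a rough way to reproduce these per-file counts locally, here is a minimal sketch, assuming `pip` is on PATH and both wheels are still downloadable from PyPI; it is illustrative only, not the registry's own diff tooling:

```python
# Minimal sketch: download both published wheels and compute per-file
# +N/-M line stats like the listing below. Illustrative only; this is
# not the diff service's actual implementation.
import difflib
import subprocess
import tempfile
import zipfile
from pathlib import Path

def download_wheel(version: str, dest: Path) -> Path:
    dest.mkdir(parents=True, exist_ok=True)
    # --no-deps / --only-binary fetch just the one wheel, no dependencies.
    subprocess.run(
        ["pip", "download", "--no-deps", "--only-binary=:all:",
         f"transformers=={version}", "-d", str(dest)],
        check=True,
    )
    return next(dest.glob("*.whl"))

def line_delta(zf_a: zipfile.ZipFile, zf_b: zipfile.ZipFile, name: str):
    # Count added/removed lines for one file, like the "+N -M" columns.
    a = zf_a.read(name).decode("utf-8", "replace").splitlines()
    b = zf_b.read(name).decode("utf-8", "replace").splitlines()
    added = removed = 0
    for line in difflib.unified_diff(a, b, lineterm=""):
        if line.startswith("+") and not line.startswith("+++"):
            added += 1
        elif line.startswith("-") and not line.startswith("---"):
            removed += 1
    return added, removed

with tempfile.TemporaryDirectory() as tmp:
    base = download_wheel("5.0.0", Path(tmp) / "base")        # left side of the title
    target = download_wheel("5.0.0rc0", Path(tmp) / "target")  # right side
    with zipfile.ZipFile(base) as zb, zipfile.ZipFile(target) as zt:
        common = sorted(set(zb.namelist()) & set(zt.namelist()))
        changed = 0
        for name in common:
            if name.endswith(".py"):
                plus, minus = line_delta(zb, zt, name)
                if plus or minus:
                    changed += 1
                    print(f"{name} +{plus} -{minus}")
        print("files changed:", changed)
```

Note that this sketch only covers files present in both wheels; renamed files (such as entry 72 below, shown in git's `{old → new}` notation) and added or removed files would additionally require comparing the two file sets.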
Files changed (1606)
  1. transformers/__init__.py +36 -55
  2. transformers/activations.py +1 -1
  3. transformers/audio_utils.py +33 -32
  4. transformers/cache_utils.py +139 -32
  5. transformers/cli/chat.py +3 -3
  6. transformers/cli/serve.py +19 -49
  7. transformers/cli/transformers.py +1 -2
  8. transformers/configuration_utils.py +155 -129
  9. transformers/conversion_mapping.py +22 -158
  10. transformers/convert_slow_tokenizer.py +17 -227
  11. transformers/core_model_loading.py +185 -528
  12. transformers/data/data_collator.py +4 -12
  13. transformers/data/processors/glue.py +1 -0
  14. transformers/data/processors/utils.py +1 -0
  15. transformers/data/processors/xnli.py +1 -0
  16. transformers/dependency_versions_check.py +1 -0
  17. transformers/dependency_versions_table.py +7 -5
  18. transformers/distributed/configuration_utils.py +2 -1
  19. transformers/dynamic_module_utils.py +25 -24
  20. transformers/feature_extraction_sequence_utils.py +23 -19
  21. transformers/feature_extraction_utils.py +33 -64
  22. transformers/file_utils.py +1 -0
  23. transformers/generation/__init__.py +1 -11
  24. transformers/generation/candidate_generator.py +33 -80
  25. transformers/generation/configuration_utils.py +133 -189
  26. transformers/generation/continuous_batching/__init__.py +1 -4
  27. transformers/generation/continuous_batching/cache.py +25 -83
  28. transformers/generation/continuous_batching/cache_manager.py +45 -155
  29. transformers/generation/continuous_batching/continuous_api.py +147 -270
  30. transformers/generation/continuous_batching/requests.py +3 -51
  31. transformers/generation/continuous_batching/scheduler.py +105 -160
  32. transformers/generation/logits_process.py +128 -0
  33. transformers/generation/stopping_criteria.py +1 -1
  34. transformers/generation/streamers.py +1 -0
  35. transformers/generation/utils.py +123 -122
  36. transformers/generation/watermarking.py +6 -8
  37. transformers/hf_argparser.py +13 -9
  38. transformers/hyperparameter_search.py +2 -1
  39. transformers/image_processing_base.py +23 -12
  40. transformers/image_processing_utils.py +15 -11
  41. transformers/image_processing_utils_fast.py +75 -85
  42. transformers/image_transforms.py +42 -73
  43. transformers/image_utils.py +32 -30
  44. transformers/initialization.py +0 -37
  45. transformers/integrations/__init__.py +2 -16
  46. transformers/integrations/accelerate.py +113 -58
  47. transformers/integrations/aqlm.py +66 -36
  48. transformers/integrations/awq.py +516 -45
  49. transformers/integrations/bitnet.py +105 -47
  50. transformers/integrations/bitsandbytes.py +202 -91
  51. transformers/integrations/deepspeed.py +4 -161
  52. transformers/integrations/eetq.py +82 -84
  53. transformers/integrations/executorch.py +1 -1
  54. transformers/integrations/fbgemm_fp8.py +145 -190
  55. transformers/integrations/finegrained_fp8.py +215 -249
  56. transformers/integrations/flash_attention.py +3 -3
  57. transformers/integrations/flex_attention.py +1 -1
  58. transformers/integrations/fp_quant.py +0 -90
  59. transformers/integrations/ggml.py +2 -11
  60. transformers/integrations/higgs.py +62 -37
  61. transformers/integrations/hub_kernels.py +8 -65
  62. transformers/integrations/integration_utils.py +3 -47
  63. transformers/integrations/mistral.py +0 -12
  64. transformers/integrations/mxfp4.py +80 -33
  65. transformers/integrations/peft.py +191 -483
  66. transformers/integrations/quanto.py +56 -77
  67. transformers/integrations/spqr.py +90 -42
  68. transformers/integrations/tensor_parallel.py +221 -167
  69. transformers/integrations/torchao.py +43 -35
  70. transformers/integrations/vptq.py +59 -40
  71. transformers/kernels/__init__.py +0 -0
  72. transformers/{models/pe_audio_video/processing_pe_audio_video.py → kernels/falcon_mamba/__init__.py} +3 -12
  73. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +529 -0
  74. transformers/loss/loss_utils.py +0 -2
  75. transformers/masking_utils.py +55 -51
  76. transformers/model_debugging_utils.py +5 -4
  77. transformers/modelcard.py +194 -15
  78. transformers/modeling_attn_mask_utils.py +19 -19
  79. transformers/modeling_flash_attention_utils.py +27 -27
  80. transformers/modeling_gguf_pytorch_utils.py +24 -79
  81. transformers/modeling_layers.py +22 -21
  82. transformers/modeling_outputs.py +253 -242
  83. transformers/modeling_rope_utils.py +117 -138
  84. transformers/modeling_utils.py +739 -850
  85. transformers/models/__init__.py +0 -27
  86. transformers/models/afmoe/configuration_afmoe.py +33 -40
  87. transformers/models/afmoe/modeling_afmoe.py +54 -42
  88. transformers/models/afmoe/modular_afmoe.py +33 -23
  89. transformers/models/aimv2/configuration_aimv2.py +10 -2
  90. transformers/models/aimv2/modeling_aimv2.py +42 -47
  91. transformers/models/aimv2/modular_aimv2.py +19 -17
  92. transformers/models/albert/configuration_albert.py +2 -8
  93. transformers/models/albert/modeling_albert.py +69 -70
  94. transformers/models/albert/tokenization_albert.py +14 -5
  95. transformers/models/align/configuration_align.py +6 -8
  96. transformers/models/align/modeling_align.py +89 -94
  97. transformers/models/align/processing_align.py +30 -2
  98. transformers/models/altclip/configuration_altclip.py +7 -4
  99. transformers/models/altclip/modeling_altclip.py +103 -114
  100. transformers/models/altclip/processing_altclip.py +15 -2
  101. transformers/models/apertus/__init__.py +1 -0
  102. transformers/models/apertus/configuration_apertus.py +28 -23
  103. transformers/models/apertus/modeling_apertus.py +40 -39
  104. transformers/models/apertus/modular_apertus.py +38 -37
  105. transformers/models/arcee/configuration_arcee.py +30 -25
  106. transformers/models/arcee/modeling_arcee.py +39 -36
  107. transformers/models/arcee/modular_arcee.py +23 -20
  108. transformers/models/aria/configuration_aria.py +44 -31
  109. transformers/models/aria/image_processing_aria.py +27 -25
  110. transformers/models/aria/modeling_aria.py +106 -110
  111. transformers/models/aria/modular_aria.py +127 -118
  112. transformers/models/aria/processing_aria.py +35 -28
  113. transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +1 -0
  114. transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py +6 -3
  115. transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +8 -6
  116. transformers/models/audioflamingo3/__init__.py +1 -0
  117. transformers/models/audioflamingo3/configuration_audioflamingo3.py +1 -0
  118. transformers/models/audioflamingo3/modeling_audioflamingo3.py +49 -58
  119. transformers/models/audioflamingo3/modular_audioflamingo3.py +43 -53
  120. transformers/models/audioflamingo3/processing_audioflamingo3.py +30 -33
  121. transformers/models/auto/auto_factory.py +7 -6
  122. transformers/models/auto/configuration_auto.py +5 -66
  123. transformers/models/auto/feature_extraction_auto.py +10 -14
  124. transformers/models/auto/image_processing_auto.py +41 -32
  125. transformers/models/auto/modeling_auto.py +188 -46
  126. transformers/models/auto/processing_auto.py +11 -24
  127. transformers/models/auto/tokenization_auto.py +588 -171
  128. transformers/models/auto/video_processing_auto.py +10 -12
  129. transformers/models/autoformer/configuration_autoformer.py +7 -4
  130. transformers/models/autoformer/modeling_autoformer.py +101 -104
  131. transformers/models/aya_vision/configuration_aya_vision.py +1 -4
  132. transformers/models/aya_vision/modeling_aya_vision.py +102 -71
  133. transformers/models/aya_vision/modular_aya_vision.py +74 -46
  134. transformers/models/aya_vision/processing_aya_vision.py +53 -25
  135. transformers/models/bamba/configuration_bamba.py +39 -34
  136. transformers/models/bamba/modeling_bamba.py +86 -82
  137. transformers/models/bamba/modular_bamba.py +72 -70
  138. transformers/models/bark/configuration_bark.py +8 -6
  139. transformers/models/bark/generation_configuration_bark.py +5 -3
  140. transformers/models/bark/modeling_bark.py +57 -54
  141. transformers/models/bark/processing_bark.py +41 -19
  142. transformers/models/bart/configuration_bart.py +6 -9
  143. transformers/models/bart/modeling_bart.py +126 -135
  144. transformers/models/barthez/tokenization_barthez.py +11 -3
  145. transformers/models/bartpho/tokenization_bartpho.py +7 -6
  146. transformers/models/beit/configuration_beit.py +11 -0
  147. transformers/models/beit/image_processing_beit.py +56 -53
  148. transformers/models/beit/image_processing_beit_fast.py +12 -10
  149. transformers/models/beit/modeling_beit.py +60 -69
  150. transformers/models/bert/configuration_bert.py +2 -12
  151. transformers/models/bert/modeling_bert.py +122 -114
  152. transformers/models/bert/tokenization_bert.py +23 -8
  153. transformers/models/bert/tokenization_bert_legacy.py +5 -3
  154. transformers/models/bert_generation/configuration_bert_generation.py +2 -17
  155. transformers/models/bert_generation/modeling_bert_generation.py +49 -49
  156. transformers/models/bert_generation/tokenization_bert_generation.py +3 -2
  157. transformers/models/bert_japanese/tokenization_bert_japanese.py +6 -5
  158. transformers/models/bertweet/tokenization_bertweet.py +3 -1
  159. transformers/models/big_bird/configuration_big_bird.py +9 -12
  160. transformers/models/big_bird/modeling_big_bird.py +109 -116
  161. transformers/models/big_bird/tokenization_big_bird.py +43 -16
  162. transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py +9 -9
  163. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +117 -130
  164. transformers/models/biogpt/configuration_biogpt.py +2 -8
  165. transformers/models/biogpt/modeling_biogpt.py +76 -72
  166. transformers/models/biogpt/modular_biogpt.py +66 -62
  167. transformers/models/biogpt/tokenization_biogpt.py +5 -3
  168. transformers/models/bit/configuration_bit.py +1 -0
  169. transformers/models/bit/image_processing_bit.py +24 -21
  170. transformers/models/bit/image_processing_bit_fast.py +1 -0
  171. transformers/models/bit/modeling_bit.py +12 -25
  172. transformers/models/bitnet/configuration_bitnet.py +28 -23
  173. transformers/models/bitnet/modeling_bitnet.py +39 -36
  174. transformers/models/bitnet/modular_bitnet.py +6 -4
  175. transformers/models/blenderbot/configuration_blenderbot.py +5 -8
  176. transformers/models/blenderbot/modeling_blenderbot.py +96 -77
  177. transformers/models/blenderbot/tokenization_blenderbot.py +24 -18
  178. transformers/models/blenderbot_small/configuration_blenderbot_small.py +5 -8
  179. transformers/models/blenderbot_small/modeling_blenderbot_small.py +69 -79
  180. transformers/models/blenderbot_small/tokenization_blenderbot_small.py +3 -1
  181. transformers/models/blip/configuration_blip.py +10 -9
  182. transformers/models/blip/image_processing_blip.py +20 -17
  183. transformers/models/blip/image_processing_blip_fast.py +1 -0
  184. transformers/models/blip/modeling_blip.py +108 -117
  185. transformers/models/blip/modeling_blip_text.py +65 -73
  186. transformers/models/blip/processing_blip.py +36 -5
  187. transformers/models/blip_2/configuration_blip_2.py +2 -2
  188. transformers/models/blip_2/modeling_blip_2.py +118 -146
  189. transformers/models/blip_2/processing_blip_2.py +38 -8
  190. transformers/models/bloom/configuration_bloom.py +2 -5
  191. transformers/models/bloom/modeling_bloom.py +104 -77
  192. transformers/models/blt/configuration_blt.py +86 -94
  193. transformers/models/blt/modeling_blt.py +81 -238
  194. transformers/models/blt/modular_blt.py +65 -228
  195. transformers/models/bridgetower/configuration_bridgetower.py +2 -7
  196. transformers/models/bridgetower/image_processing_bridgetower.py +35 -34
  197. transformers/models/bridgetower/image_processing_bridgetower_fast.py +16 -13
  198. transformers/models/bridgetower/modeling_bridgetower.py +119 -141
  199. transformers/models/bridgetower/processing_bridgetower.py +16 -2
  200. transformers/models/bros/configuration_bros.py +18 -24
  201. transformers/models/bros/modeling_bros.py +80 -90
  202. transformers/models/bros/processing_bros.py +12 -2
  203. transformers/models/byt5/tokenization_byt5.py +6 -4
  204. transformers/models/camembert/configuration_camembert.py +2 -8
  205. transformers/models/camembert/modeling_camembert.py +195 -196
  206. transformers/models/camembert/modular_camembert.py +54 -51
  207. transformers/models/camembert/tokenization_camembert.py +13 -6
  208. transformers/models/canine/configuration_canine.py +2 -4
  209. transformers/models/canine/modeling_canine.py +75 -84
  210. transformers/models/canine/tokenization_canine.py +1 -2
  211. transformers/models/chameleon/configuration_chameleon.py +34 -29
  212. transformers/models/chameleon/image_processing_chameleon.py +24 -21
  213. transformers/models/chameleon/image_processing_chameleon_fast.py +6 -5
  214. transformers/models/chameleon/modeling_chameleon.py +93 -142
  215. transformers/models/chameleon/processing_chameleon.py +41 -16
  216. transformers/models/chinese_clip/configuration_chinese_clip.py +8 -10
  217. transformers/models/chinese_clip/image_processing_chinese_clip.py +24 -21
  218. transformers/models/chinese_clip/image_processing_chinese_clip_fast.py +1 -0
  219. transformers/models/chinese_clip/modeling_chinese_clip.py +92 -96
  220. transformers/models/chinese_clip/processing_chinese_clip.py +15 -2
  221. transformers/models/clap/configuration_clap.py +9 -4
  222. transformers/models/clap/feature_extraction_clap.py +12 -11
  223. transformers/models/clap/modeling_clap.py +123 -136
  224. transformers/models/clap/processing_clap.py +15 -2
  225. transformers/models/clip/configuration_clip.py +2 -4
  226. transformers/models/clip/image_processing_clip.py +24 -21
  227. transformers/models/clip/image_processing_clip_fast.py +1 -9
  228. transformers/models/clip/modeling_clip.py +65 -65
  229. transformers/models/clip/processing_clip.py +14 -2
  230. transformers/models/clip/tokenization_clip.py +46 -21
  231. transformers/models/clipseg/configuration_clipseg.py +2 -4
  232. transformers/models/clipseg/modeling_clipseg.py +109 -119
  233. transformers/models/clipseg/processing_clipseg.py +42 -19
  234. transformers/models/clvp/configuration_clvp.py +5 -15
  235. transformers/models/clvp/feature_extraction_clvp.py +10 -7
  236. transformers/models/clvp/modeling_clvp.py +146 -155
  237. transformers/models/clvp/number_normalizer.py +2 -1
  238. transformers/models/clvp/processing_clvp.py +20 -3
  239. transformers/models/clvp/tokenization_clvp.py +64 -1
  240. transformers/models/code_llama/tokenization_code_llama.py +44 -18
  241. transformers/models/codegen/configuration_codegen.py +4 -4
  242. transformers/models/codegen/modeling_codegen.py +53 -63
  243. transformers/models/codegen/tokenization_codegen.py +47 -17
  244. transformers/models/cohere/configuration_cohere.py +30 -25
  245. transformers/models/cohere/modeling_cohere.py +42 -40
  246. transformers/models/cohere/modular_cohere.py +29 -26
  247. transformers/models/cohere/tokenization_cohere.py +46 -15
  248. transformers/models/cohere2/configuration_cohere2.py +32 -31
  249. transformers/models/cohere2/modeling_cohere2.py +44 -42
  250. transformers/models/cohere2/modular_cohere2.py +54 -54
  251. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +14 -13
  252. transformers/models/cohere2_vision/modeling_cohere2_vision.py +58 -59
  253. transformers/models/cohere2_vision/modular_cohere2_vision.py +46 -45
  254. transformers/models/cohere2_vision/processing_cohere2_vision.py +36 -6
  255. transformers/models/colpali/configuration_colpali.py +1 -0
  256. transformers/models/colpali/modeling_colpali.py +16 -14
  257. transformers/models/colpali/modular_colpali.py +51 -11
  258. transformers/models/colpali/processing_colpali.py +52 -14
  259. transformers/models/colqwen2/modeling_colqwen2.py +28 -28
  260. transformers/models/colqwen2/modular_colqwen2.py +74 -37
  261. transformers/models/colqwen2/processing_colqwen2.py +52 -16
  262. transformers/models/conditional_detr/configuration_conditional_detr.py +2 -1
  263. transformers/models/conditional_detr/image_processing_conditional_detr.py +70 -67
  264. transformers/models/conditional_detr/image_processing_conditional_detr_fast.py +36 -36
  265. transformers/models/conditional_detr/modeling_conditional_detr.py +87 -99
  266. transformers/models/conditional_detr/modular_conditional_detr.py +3 -49
  267. transformers/models/convbert/configuration_convbert.py +8 -11
  268. transformers/models/convbert/modeling_convbert.py +87 -94
  269. transformers/models/convbert/tokenization_convbert.py +1 -0
  270. transformers/models/convnext/configuration_convnext.py +1 -0
  271. transformers/models/convnext/image_processing_convnext.py +23 -20
  272. transformers/models/convnext/image_processing_convnext_fast.py +21 -16
  273. transformers/models/convnext/modeling_convnext.py +12 -9
  274. transformers/models/convnextv2/configuration_convnextv2.py +1 -0
  275. transformers/models/convnextv2/modeling_convnextv2.py +12 -9
  276. transformers/models/cpm/tokenization_cpm.py +7 -6
  277. transformers/models/cpm/tokenization_cpm_fast.py +5 -3
  278. transformers/models/cpmant/configuration_cpmant.py +1 -4
  279. transformers/models/cpmant/modeling_cpmant.py +40 -38
  280. transformers/models/cpmant/tokenization_cpmant.py +3 -1
  281. transformers/models/csm/configuration_csm.py +66 -58
  282. transformers/models/csm/generation_csm.py +35 -31
  283. transformers/models/csm/modeling_csm.py +85 -85
  284. transformers/models/csm/modular_csm.py +58 -58
  285. transformers/models/csm/processing_csm.py +68 -25
  286. transformers/models/ctrl/configuration_ctrl.py +1 -16
  287. transformers/models/ctrl/modeling_ctrl.py +44 -54
  288. transformers/models/ctrl/tokenization_ctrl.py +1 -0
  289. transformers/models/cvt/configuration_cvt.py +1 -0
  290. transformers/models/cvt/modeling_cvt.py +16 -20
  291. transformers/models/cwm/__init__.py +1 -0
  292. transformers/models/cwm/configuration_cwm.py +12 -8
  293. transformers/models/cwm/modeling_cwm.py +39 -37
  294. transformers/models/cwm/modular_cwm.py +12 -10
  295. transformers/models/d_fine/configuration_d_fine.py +5 -7
  296. transformers/models/d_fine/modeling_d_fine.py +128 -138
  297. transformers/models/d_fine/modular_d_fine.py +18 -33
  298. transformers/models/dab_detr/configuration_dab_detr.py +3 -6
  299. transformers/models/dab_detr/modeling_dab_detr.py +75 -81
  300. transformers/models/dac/configuration_dac.py +1 -0
  301. transformers/models/dac/feature_extraction_dac.py +9 -6
  302. transformers/models/dac/modeling_dac.py +26 -24
  303. transformers/models/data2vec/configuration_data2vec_audio.py +2 -4
  304. transformers/models/data2vec/configuration_data2vec_text.py +3 -11
  305. transformers/models/data2vec/configuration_data2vec_vision.py +1 -0
  306. transformers/models/data2vec/modeling_data2vec_audio.py +56 -57
  307. transformers/models/data2vec/modeling_data2vec_text.py +93 -98
  308. transformers/models/data2vec/modeling_data2vec_vision.py +45 -49
  309. transformers/models/data2vec/modular_data2vec_audio.py +1 -6
  310. transformers/models/data2vec/modular_data2vec_text.py +54 -58
  311. transformers/models/dbrx/configuration_dbrx.py +22 -36
  312. transformers/models/dbrx/modeling_dbrx.py +45 -42
  313. transformers/models/dbrx/modular_dbrx.py +33 -31
  314. transformers/models/deberta/configuration_deberta.py +1 -6
  315. transformers/models/deberta/modeling_deberta.py +60 -64
  316. transformers/models/deberta/tokenization_deberta.py +21 -9
  317. transformers/models/deberta_v2/configuration_deberta_v2.py +1 -6
  318. transformers/models/deberta_v2/modeling_deberta_v2.py +65 -71
  319. transformers/models/deberta_v2/tokenization_deberta_v2.py +29 -11
  320. transformers/models/decision_transformer/configuration_decision_transformer.py +2 -3
  321. transformers/models/decision_transformer/modeling_decision_transformer.py +56 -60
  322. transformers/models/deepseek_v2/configuration_deepseek_v2.py +44 -39
  323. transformers/models/deepseek_v2/modeling_deepseek_v2.py +43 -43
  324. transformers/models/deepseek_v2/modular_deepseek_v2.py +49 -48
  325. transformers/models/deepseek_v3/configuration_deepseek_v3.py +45 -40
  326. transformers/models/deepseek_v3/modeling_deepseek_v3.py +42 -45
  327. transformers/models/deepseek_v3/modular_deepseek_v3.py +9 -14
  328. transformers/models/deepseek_vl/configuration_deepseek_vl.py +3 -2
  329. transformers/models/deepseek_vl/image_processing_deepseek_vl.py +26 -25
  330. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +10 -10
  331. transformers/models/deepseek_vl/modeling_deepseek_vl.py +48 -57
  332. transformers/models/deepseek_vl/modular_deepseek_vl.py +43 -14
  333. transformers/models/deepseek_vl/processing_deepseek_vl.py +41 -10
  334. transformers/models/deepseek_vl_hybrid/configuration_deepseek_vl_hybrid.py +5 -3
  335. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py +35 -35
  336. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +24 -20
  337. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +61 -109
  338. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +118 -146
  339. transformers/models/deepseek_vl_hybrid/processing_deepseek_vl_hybrid.py +44 -12
  340. transformers/models/deformable_detr/configuration_deformable_detr.py +3 -2
  341. transformers/models/deformable_detr/image_processing_deformable_detr.py +61 -59
  342. transformers/models/deformable_detr/image_processing_deformable_detr_fast.py +28 -28
  343. transformers/models/deformable_detr/modeling_deformable_detr.py +82 -88
  344. transformers/models/deformable_detr/modular_deformable_detr.py +3 -1
  345. transformers/models/deit/configuration_deit.py +1 -0
  346. transformers/models/deit/image_processing_deit.py +21 -18
  347. transformers/models/deit/image_processing_deit_fast.py +1 -0
  348. transformers/models/deit/modeling_deit.py +22 -24
  349. transformers/models/depth_anything/configuration_depth_anything.py +4 -2
  350. transformers/models/depth_anything/modeling_depth_anything.py +10 -10
  351. transformers/models/depth_pro/configuration_depth_pro.py +1 -0
  352. transformers/models/depth_pro/image_processing_depth_pro.py +23 -22
  353. transformers/models/depth_pro/image_processing_depth_pro_fast.py +10 -8
  354. transformers/models/depth_pro/modeling_depth_pro.py +27 -31
  355. transformers/models/detr/configuration_detr.py +2 -1
  356. transformers/models/detr/image_processing_detr.py +66 -64
  357. transformers/models/detr/image_processing_detr_fast.py +34 -33
  358. transformers/models/detr/modeling_detr.py +79 -95
  359. transformers/models/dia/configuration_dia.py +15 -9
  360. transformers/models/dia/feature_extraction_dia.py +9 -6
  361. transformers/models/dia/generation_dia.py +50 -48
  362. transformers/models/dia/modeling_dia.py +69 -78
  363. transformers/models/dia/modular_dia.py +56 -64
  364. transformers/models/dia/processing_dia.py +29 -39
  365. transformers/models/dia/tokenization_dia.py +6 -3
  366. transformers/models/diffllama/configuration_diffllama.py +30 -25
  367. transformers/models/diffllama/modeling_diffllama.py +49 -46
  368. transformers/models/diffllama/modular_diffllama.py +19 -17
  369. transformers/models/dinat/configuration_dinat.py +1 -0
  370. transformers/models/dinat/modeling_dinat.py +44 -47
  371. transformers/models/dinov2/configuration_dinov2.py +1 -0
  372. transformers/models/dinov2/modeling_dinov2.py +15 -15
  373. transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py +1 -1
  374. transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +15 -16
  375. transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py +9 -9
  376. transformers/models/dinov3_convnext/configuration_dinov3_convnext.py +7 -4
  377. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +6 -3
  378. transformers/models/dinov3_vit/configuration_dinov3_vit.py +8 -5
  379. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +9 -7
  380. transformers/models/dinov3_vit/modeling_dinov3_vit.py +18 -19
  381. transformers/models/dinov3_vit/modular_dinov3_vit.py +15 -16
  382. transformers/models/distilbert/configuration_distilbert.py +2 -8
  383. transformers/models/distilbert/modeling_distilbert.py +55 -55
  384. transformers/models/distilbert/tokenization_distilbert.py +1 -13
  385. transformers/models/doge/__init__.py +1 -0
  386. transformers/models/doge/configuration_doge.py +32 -39
  387. transformers/models/doge/modeling_doge.py +49 -45
  388. transformers/models/doge/modular_doge.py +63 -71
  389. transformers/models/donut/configuration_donut_swin.py +1 -0
  390. transformers/models/donut/image_processing_donut.py +29 -26
  391. transformers/models/donut/image_processing_donut_fast.py +15 -9
  392. transformers/models/donut/modeling_donut_swin.py +58 -62
  393. transformers/models/donut/processing_donut.py +26 -5
  394. transformers/models/dots1/configuration_dots1.py +33 -41
  395. transformers/models/dots1/modeling_dots1.py +45 -54
  396. transformers/models/dots1/modular_dots1.py +4 -5
  397. transformers/models/dpr/configuration_dpr.py +2 -19
  398. transformers/models/dpr/modeling_dpr.py +39 -42
  399. transformers/models/dpr/tokenization_dpr.py +9 -19
  400. transformers/models/dpr/tokenization_dpr_fast.py +9 -7
  401. transformers/models/dpt/configuration_dpt.py +2 -1
  402. transformers/models/dpt/image_processing_dpt.py +66 -65
  403. transformers/models/dpt/image_processing_dpt_fast.py +20 -18
  404. transformers/models/dpt/modeling_dpt.py +30 -32
  405. transformers/models/dpt/modular_dpt.py +17 -15
  406. transformers/models/edgetam/configuration_edgetam.py +3 -2
  407. transformers/models/edgetam/modeling_edgetam.py +86 -86
  408. transformers/models/edgetam/modular_edgetam.py +26 -21
  409. transformers/models/edgetam_video/__init__.py +1 -0
  410. transformers/models/edgetam_video/configuration_edgetam_video.py +1 -0
  411. transformers/models/edgetam_video/modeling_edgetam_video.py +158 -169
  412. transformers/models/edgetam_video/modular_edgetam_video.py +37 -30
  413. transformers/models/efficientloftr/configuration_efficientloftr.py +5 -4
  414. transformers/models/efficientloftr/image_processing_efficientloftr.py +16 -14
  415. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +9 -9
  416. transformers/models/efficientloftr/modeling_efficientloftr.py +38 -59
  417. transformers/models/efficientloftr/modular_efficientloftr.py +3 -1
  418. transformers/models/efficientnet/configuration_efficientnet.py +1 -0
  419. transformers/models/efficientnet/image_processing_efficientnet.py +32 -28
  420. transformers/models/efficientnet/image_processing_efficientnet_fast.py +19 -17
  421. transformers/models/efficientnet/modeling_efficientnet.py +15 -19
  422. transformers/models/electra/configuration_electra.py +3 -13
  423. transformers/models/electra/modeling_electra.py +103 -108
  424. transformers/models/emu3/configuration_emu3.py +17 -13
  425. transformers/models/emu3/image_processing_emu3.py +39 -44
  426. transformers/models/emu3/modeling_emu3.py +108 -148
  427. transformers/models/emu3/modular_emu3.py +73 -115
  428. transformers/models/emu3/processing_emu3.py +43 -18
  429. transformers/models/encodec/configuration_encodec.py +4 -2
  430. transformers/models/encodec/feature_extraction_encodec.py +13 -10
  431. transformers/models/encodec/modeling_encodec.py +29 -39
  432. transformers/models/encoder_decoder/configuration_encoder_decoder.py +2 -12
  433. transformers/models/encoder_decoder/modeling_encoder_decoder.py +43 -37
  434. transformers/models/eomt/configuration_eomt.py +1 -0
  435. transformers/models/eomt/image_processing_eomt.py +56 -66
  436. transformers/models/eomt/image_processing_eomt_fast.py +33 -76
  437. transformers/models/eomt/modeling_eomt.py +18 -23
  438. transformers/models/eomt/modular_eomt.py +13 -18
  439. transformers/models/ernie/configuration_ernie.py +3 -24
  440. transformers/models/ernie/modeling_ernie.py +132 -127
  441. transformers/models/ernie/modular_ernie.py +103 -97
  442. transformers/models/ernie4_5/configuration_ernie4_5.py +27 -23
  443. transformers/models/ernie4_5/modeling_ernie4_5.py +38 -36
  444. transformers/models/ernie4_5/modular_ernie4_5.py +4 -3
  445. transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py +36 -32
  446. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +55 -56
  447. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +46 -18
  448. transformers/models/esm/configuration_esm.py +15 -11
  449. transformers/models/esm/modeling_esm.py +34 -38
  450. transformers/models/esm/modeling_esmfold.py +49 -53
  451. transformers/models/esm/openfold_utils/chunk_utils.py +6 -6
  452. transformers/models/esm/openfold_utils/loss.py +2 -1
  453. transformers/models/esm/openfold_utils/protein.py +16 -15
  454. transformers/models/esm/openfold_utils/tensor_utils.py +6 -6
  455. transformers/models/esm/tokenization_esm.py +4 -2
  456. transformers/models/evolla/configuration_evolla.py +40 -50
  457. transformers/models/evolla/modeling_evolla.py +66 -71
  458. transformers/models/evolla/modular_evolla.py +47 -53
  459. transformers/models/evolla/processing_evolla.py +35 -23
  460. transformers/models/exaone4/configuration_exaone4.py +25 -23
  461. transformers/models/exaone4/modeling_exaone4.py +38 -35
  462. transformers/models/exaone4/modular_exaone4.py +46 -44
  463. transformers/models/falcon/configuration_falcon.py +26 -31
  464. transformers/models/falcon/modeling_falcon.py +80 -82
  465. transformers/models/falcon_h1/configuration_falcon_h1.py +51 -45
  466. transformers/models/falcon_h1/modeling_falcon_h1.py +82 -85
  467. transformers/models/falcon_h1/modular_falcon_h1.py +51 -56
  468. transformers/models/falcon_mamba/configuration_falcon_mamba.py +2 -1
  469. transformers/models/falcon_mamba/modeling_falcon_mamba.py +82 -75
  470. transformers/models/falcon_mamba/modular_falcon_mamba.py +45 -28
  471. transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py +6 -2
  472. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +60 -76
  473. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +3 -2
  474. transformers/models/flaubert/configuration_flaubert.py +5 -10
  475. transformers/models/flaubert/modeling_flaubert.py +143 -145
  476. transformers/models/flaubert/tokenization_flaubert.py +5 -3
  477. transformers/models/flava/configuration_flava.py +6 -5
  478. transformers/models/flava/image_processing_flava.py +67 -66
  479. transformers/models/flava/image_processing_flava_fast.py +49 -46
  480. transformers/models/flava/modeling_flava.py +136 -153
  481. transformers/models/flava/processing_flava.py +12 -2
  482. transformers/models/flex_olmo/__init__.py +1 -0
  483. transformers/models/flex_olmo/configuration_flex_olmo.py +32 -28
  484. transformers/models/flex_olmo/modeling_flex_olmo.py +47 -47
  485. transformers/models/flex_olmo/modular_flex_olmo.py +44 -40
  486. transformers/models/florence2/configuration_florence2.py +1 -0
  487. transformers/models/florence2/modeling_florence2.py +69 -111
  488. transformers/models/florence2/modular_florence2.py +101 -104
  489. transformers/models/florence2/processing_florence2.py +47 -18
  490. transformers/models/fnet/configuration_fnet.py +2 -6
  491. transformers/models/fnet/modeling_fnet.py +80 -83
  492. transformers/models/fnet/tokenization_fnet.py +1 -0
  493. transformers/models/focalnet/configuration_focalnet.py +1 -0
  494. transformers/models/focalnet/modeling_focalnet.py +45 -51
  495. transformers/models/fsmt/configuration_fsmt.py +17 -12
  496. transformers/models/fsmt/modeling_fsmt.py +48 -49
  497. transformers/models/fsmt/tokenization_fsmt.py +5 -3
  498. transformers/models/funnel/configuration_funnel.py +1 -8
  499. transformers/models/funnel/modeling_funnel.py +93 -99
  500. transformers/models/funnel/tokenization_funnel.py +27 -17
  501. transformers/models/fuyu/configuration_fuyu.py +34 -28
  502. transformers/models/fuyu/image_processing_fuyu.py +31 -29
  503. transformers/models/fuyu/image_processing_fuyu_fast.py +17 -17
  504. transformers/models/fuyu/modeling_fuyu.py +53 -53
  505. transformers/models/fuyu/processing_fuyu.py +34 -23
  506. transformers/models/gemma/configuration_gemma.py +30 -25
  507. transformers/models/gemma/modeling_gemma.py +50 -46
  508. transformers/models/gemma/modular_gemma.py +47 -42
  509. transformers/models/gemma/tokenization_gemma.py +30 -10
  510. transformers/models/gemma2/configuration_gemma2.py +35 -30
  511. transformers/models/gemma2/modeling_gemma2.py +42 -39
  512. transformers/models/gemma2/modular_gemma2.py +66 -63
  513. transformers/models/gemma3/configuration_gemma3.py +44 -44
  514. transformers/models/gemma3/image_processing_gemma3.py +31 -29
  515. transformers/models/gemma3/image_processing_gemma3_fast.py +13 -11
  516. transformers/models/gemma3/modeling_gemma3.py +207 -159
  517. transformers/models/gemma3/modular_gemma3.py +204 -153
  518. transformers/models/gemma3/processing_gemma3.py +5 -5
  519. transformers/models/gemma3n/configuration_gemma3n.py +26 -36
  520. transformers/models/gemma3n/feature_extraction_gemma3n.py +11 -9
  521. transformers/models/gemma3n/modeling_gemma3n.py +356 -222
  522. transformers/models/gemma3n/modular_gemma3n.py +207 -230
  523. transformers/models/gemma3n/processing_gemma3n.py +26 -12
  524. transformers/models/git/configuration_git.py +8 -5
  525. transformers/models/git/modeling_git.py +204 -266
  526. transformers/models/git/processing_git.py +14 -2
  527. transformers/models/glm/configuration_glm.py +28 -24
  528. transformers/models/glm/modeling_glm.py +40 -37
  529. transformers/models/glm/modular_glm.py +7 -4
  530. transformers/models/glm4/configuration_glm4.py +28 -24
  531. transformers/models/glm4/modeling_glm4.py +42 -40
  532. transformers/models/glm4/modular_glm4.py +10 -8
  533. transformers/models/glm46v/configuration_glm46v.py +1 -0
  534. transformers/models/glm46v/image_processing_glm46v.py +40 -35
  535. transformers/models/glm46v/image_processing_glm46v_fast.py +9 -9
  536. transformers/models/glm46v/modeling_glm46v.py +90 -137
  537. transformers/models/glm46v/modular_glm46v.py +3 -4
  538. transformers/models/glm46v/processing_glm46v.py +41 -7
  539. transformers/models/glm46v/video_processing_glm46v.py +11 -9
  540. transformers/models/glm4_moe/configuration_glm4_moe.py +32 -40
  541. transformers/models/glm4_moe/modeling_glm4_moe.py +42 -45
  542. transformers/models/glm4_moe/modular_glm4_moe.py +34 -42
  543. transformers/models/glm4v/configuration_glm4v.py +20 -18
  544. transformers/models/glm4v/image_processing_glm4v.py +40 -34
  545. transformers/models/glm4v/image_processing_glm4v_fast.py +9 -8
  546. transformers/models/glm4v/modeling_glm4v.py +205 -254
  547. transformers/models/glm4v/modular_glm4v.py +224 -210
  548. transformers/models/glm4v/processing_glm4v.py +41 -7
  549. transformers/models/glm4v/video_processing_glm4v.py +11 -9
  550. transformers/models/glm4v_moe/configuration_glm4v_moe.py +125 -136
  551. transformers/models/glm4v_moe/modeling_glm4v_moe.py +368 -377
  552. transformers/models/glm4v_moe/modular_glm4v_moe.py +169 -83
  553. transformers/models/glpn/configuration_glpn.py +1 -0
  554. transformers/models/glpn/image_processing_glpn.py +12 -11
  555. transformers/models/glpn/image_processing_glpn_fast.py +13 -11
  556. transformers/models/glpn/modeling_glpn.py +14 -16
  557. transformers/models/got_ocr2/configuration_got_ocr2.py +12 -4
  558. transformers/models/got_ocr2/image_processing_got_ocr2.py +24 -22
  559. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +11 -9
  560. transformers/models/got_ocr2/modeling_got_ocr2.py +80 -77
  561. transformers/models/got_ocr2/modular_got_ocr2.py +51 -54
  562. transformers/models/got_ocr2/processing_got_ocr2.py +63 -42
  563. transformers/models/gpt2/configuration_gpt2.py +2 -13
  564. transformers/models/gpt2/modeling_gpt2.py +115 -120
  565. transformers/models/gpt2/tokenization_gpt2.py +46 -15
  566. transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +2 -5
  567. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +89 -79
  568. transformers/models/gpt_neo/configuration_gpt_neo.py +2 -9
  569. transformers/models/gpt_neo/modeling_gpt_neo.py +67 -83
  570. transformers/models/gpt_neox/configuration_gpt_neox.py +25 -25
  571. transformers/models/gpt_neox/modeling_gpt_neox.py +75 -76
  572. transformers/models/gpt_neox/modular_gpt_neox.py +66 -67
  573. transformers/models/gpt_neox/tokenization_gpt_neox.py +51 -9
  574. transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py +19 -24
  575. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +47 -46
  576. transformers/models/gpt_neox_japanese/tokenization_gpt_neox_japanese.py +3 -1
  577. transformers/models/gpt_oss/configuration_gpt_oss.py +28 -46
  578. transformers/models/gpt_oss/modeling_gpt_oss.py +121 -83
  579. transformers/models/gpt_oss/modular_gpt_oss.py +103 -64
  580. transformers/models/gpt_sw3/tokenization_gpt_sw3.py +4 -4
  581. transformers/models/gptj/configuration_gptj.py +4 -4
  582. transformers/models/gptj/modeling_gptj.py +87 -101
  583. transformers/models/granite/configuration_granite.py +33 -28
  584. transformers/models/granite/modeling_granite.py +46 -44
  585. transformers/models/granite/modular_granite.py +31 -29
  586. transformers/models/granite_speech/configuration_granite_speech.py +1 -0
  587. transformers/models/granite_speech/feature_extraction_granite_speech.py +3 -1
  588. transformers/models/granite_speech/modeling_granite_speech.py +52 -82
  589. transformers/models/granite_speech/processing_granite_speech.py +4 -11
  590. transformers/models/granitemoe/configuration_granitemoe.py +36 -31
  591. transformers/models/granitemoe/modeling_granitemoe.py +46 -41
  592. transformers/models/granitemoe/modular_granitemoe.py +27 -22
  593. transformers/models/granitemoehybrid/__init__.py +1 -0
  594. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +47 -46
  595. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +93 -97
  596. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +21 -54
  597. transformers/models/granitemoeshared/configuration_granitemoeshared.py +37 -33
  598. transformers/models/granitemoeshared/modeling_granitemoeshared.py +61 -54
  599. transformers/models/granitemoeshared/modular_granitemoeshared.py +21 -19
  600. transformers/models/grounding_dino/configuration_grounding_dino.py +4 -6
  601. transformers/models/grounding_dino/image_processing_grounding_dino.py +62 -60
  602. transformers/models/grounding_dino/image_processing_grounding_dino_fast.py +29 -28
  603. transformers/models/grounding_dino/modeling_grounding_dino.py +140 -155
  604. transformers/models/grounding_dino/modular_grounding_dino.py +3 -2
  605. transformers/models/grounding_dino/processing_grounding_dino.py +38 -10
  606. transformers/models/groupvit/configuration_groupvit.py +2 -4
  607. transformers/models/groupvit/modeling_groupvit.py +93 -107
  608. transformers/models/helium/configuration_helium.py +29 -25
  609. transformers/models/helium/modeling_helium.py +40 -38
  610. transformers/models/helium/modular_helium.py +7 -3
  611. transformers/models/herbert/tokenization_herbert.py +28 -10
  612. transformers/models/hgnet_v2/configuration_hgnet_v2.py +1 -0
  613. transformers/models/hgnet_v2/modeling_hgnet_v2.py +10 -24
  614. transformers/models/hgnet_v2/modular_hgnet_v2.py +10 -24
  615. transformers/models/hiera/configuration_hiera.py +1 -0
  616. transformers/models/hiera/modeling_hiera.py +66 -72
  617. transformers/models/hubert/configuration_hubert.py +2 -4
  618. transformers/models/hubert/modeling_hubert.py +37 -42
  619. transformers/models/hubert/modular_hubert.py +11 -13
  620. transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py +31 -26
  621. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +38 -35
  622. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +6 -4
  623. transformers/models/hunyuan_v1_moe/__init__.py +1 -1
  624. transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py +36 -31
  625. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +42 -47
  626. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +9 -9
  627. transformers/models/ibert/configuration_ibert.py +2 -4
  628. transformers/models/ibert/modeling_ibert.py +62 -82
  629. transformers/models/ibert/quant_modules.py +1 -0
  630. transformers/models/idefics/configuration_idefics.py +8 -5
  631. transformers/models/idefics/image_processing_idefics.py +15 -13
  632. transformers/models/idefics/modeling_idefics.py +82 -75
  633. transformers/models/idefics/perceiver.py +3 -1
  634. transformers/models/idefics/processing_idefics.py +48 -32
  635. transformers/models/idefics/vision.py +25 -24
  636. transformers/models/idefics2/configuration_idefics2.py +3 -1
  637. transformers/models/idefics2/image_processing_idefics2.py +32 -31
  638. transformers/models/idefics2/image_processing_idefics2_fast.py +8 -8
  639. transformers/models/idefics2/modeling_idefics2.py +101 -127
  640. transformers/models/idefics2/processing_idefics2.py +68 -10
  641. transformers/models/idefics3/configuration_idefics3.py +4 -1
  642. transformers/models/idefics3/image_processing_idefics3.py +43 -42
  643. transformers/models/idefics3/image_processing_idefics3_fast.py +15 -40
  644. transformers/models/idefics3/modeling_idefics3.py +90 -115
  645. transformers/models/idefics3/processing_idefics3.py +69 -15
  646. transformers/models/ijepa/configuration_ijepa.py +1 -0
  647. transformers/models/ijepa/modeling_ijepa.py +11 -10
  648. transformers/models/ijepa/modular_ijepa.py +7 -5
  649. transformers/models/imagegpt/configuration_imagegpt.py +2 -9
  650. transformers/models/imagegpt/image_processing_imagegpt.py +18 -17
  651. transformers/models/imagegpt/image_processing_imagegpt_fast.py +16 -11
  652. transformers/models/imagegpt/modeling_imagegpt.py +65 -76
  653. transformers/models/informer/configuration_informer.py +9 -6
  654. transformers/models/informer/modeling_informer.py +86 -88
  655. transformers/models/informer/modular_informer.py +16 -14
  656. transformers/models/instructblip/configuration_instructblip.py +2 -2
  657. transformers/models/instructblip/modeling_instructblip.py +63 -103
  658. transformers/models/instructblip/processing_instructblip.py +36 -10
  659. transformers/models/instructblipvideo/configuration_instructblipvideo.py +2 -2
  660. transformers/models/instructblipvideo/modeling_instructblipvideo.py +139 -157
  661. transformers/models/instructblipvideo/modular_instructblipvideo.py +64 -73
  662. transformers/models/instructblipvideo/processing_instructblipvideo.py +33 -14
  663. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +8 -6
  664. transformers/models/internvl/configuration_internvl.py +1 -0
  665. transformers/models/internvl/modeling_internvl.py +106 -85
  666. transformers/models/internvl/modular_internvl.py +67 -47
  667. transformers/models/internvl/processing_internvl.py +45 -12
  668. transformers/models/internvl/video_processing_internvl.py +12 -10
  669. transformers/models/jamba/configuration_jamba.py +8 -5
  670. transformers/models/jamba/modeling_jamba.py +66 -68
  671. transformers/models/jamba/modular_jamba.py +55 -54
  672. transformers/models/janus/configuration_janus.py +1 -0
  673. transformers/models/janus/image_processing_janus.py +37 -35
  674. transformers/models/janus/image_processing_janus_fast.py +20 -18
  675. transformers/models/janus/modeling_janus.py +191 -115
  676. transformers/models/janus/modular_janus.py +84 -133
  677. transformers/models/janus/processing_janus.py +43 -17
  678. transformers/models/jetmoe/configuration_jetmoe.py +26 -24
  679. transformers/models/jetmoe/modeling_jetmoe.py +46 -43
  680. transformers/models/jetmoe/modular_jetmoe.py +33 -31
  681. transformers/models/kosmos2/configuration_kosmos2.py +9 -10
  682. transformers/models/kosmos2/modeling_kosmos2.py +173 -208
  683. transformers/models/kosmos2/processing_kosmos2.py +55 -40
  684. transformers/models/kosmos2_5/__init__.py +1 -0
  685. transformers/models/kosmos2_5/configuration_kosmos2_5.py +9 -8
  686. transformers/models/kosmos2_5/image_processing_kosmos2_5.py +12 -10
  687. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +13 -4
  688. transformers/models/kosmos2_5/modeling_kosmos2_5.py +118 -132
  689. transformers/models/kosmos2_5/processing_kosmos2_5.py +29 -8
  690. transformers/models/kyutai_speech_to_text/configuration_kyutai_speech_to_text.py +28 -31
  691. transformers/models/kyutai_speech_to_text/feature_extraction_kyutai_speech_to_text.py +14 -12
  692. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +100 -110
  693. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +22 -28
  694. transformers/models/kyutai_speech_to_text/processing_kyutai_speech_to_text.py +8 -2
  695. transformers/models/layoutlm/configuration_layoutlm.py +2 -14
  696. transformers/models/layoutlm/modeling_layoutlm.py +72 -77
  697. transformers/models/layoutlmv2/configuration_layoutlmv2.py +17 -14
  698. transformers/models/layoutlmv2/image_processing_layoutlmv2.py +21 -18
  699. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +9 -7
  700. transformers/models/layoutlmv2/modeling_layoutlmv2.py +50 -64
  701. transformers/models/layoutlmv2/processing_layoutlmv2.py +44 -14
  702. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +126 -73
  703. transformers/models/layoutlmv3/configuration_layoutlmv3.py +19 -16
  704. transformers/models/layoutlmv3/image_processing_layoutlmv3.py +26 -24
  705. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +11 -9
  706. transformers/models/layoutlmv3/modeling_layoutlmv3.py +56 -82
  707. transformers/models/layoutlmv3/processing_layoutlmv3.py +46 -14
  708. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +134 -74
  709. transformers/models/layoutxlm/configuration_layoutxlm.py +17 -14
  710. transformers/models/layoutxlm/modular_layoutxlm.py +1 -0
  711. transformers/models/layoutxlm/processing_layoutxlm.py +44 -14
  712. transformers/models/layoutxlm/tokenization_layoutxlm.py +113 -77
  713. transformers/models/led/configuration_led.py +12 -8
  714. transformers/models/led/modeling_led.py +266 -124
  715. transformers/models/levit/configuration_levit.py +1 -0
  716. transformers/models/levit/image_processing_levit.py +21 -19
  717. transformers/models/levit/image_processing_levit_fast.py +5 -4
  718. transformers/models/levit/modeling_levit.py +19 -38
  719. transformers/models/lfm2/configuration_lfm2.py +30 -27
  720. transformers/models/lfm2/modeling_lfm2.py +50 -47
  721. transformers/models/lfm2/modular_lfm2.py +30 -29
  722. transformers/models/lfm2_moe/__init__.py +1 -0
  723. transformers/models/lfm2_moe/configuration_lfm2_moe.py +9 -6
  724. transformers/models/lfm2_moe/modeling_lfm2_moe.py +53 -61
  725. transformers/models/lfm2_moe/modular_lfm2_moe.py +37 -13
  726. transformers/models/lfm2_vl/configuration_lfm2_vl.py +1 -4
  727. transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py +12 -41
  728. transformers/models/lfm2_vl/modeling_lfm2_vl.py +66 -84
  729. transformers/models/lfm2_vl/modular_lfm2_vl.py +56 -70
  730. transformers/models/lfm2_vl/processing_lfm2_vl.py +76 -96
  731. transformers/models/lightglue/image_processing_lightglue.py +15 -16
  732. transformers/models/lightglue/image_processing_lightglue_fast.py +9 -9
  733. transformers/models/lightglue/modeling_lightglue.py +31 -31
  734. transformers/models/lightglue/modular_lightglue.py +28 -29
  735. transformers/models/lilt/configuration_lilt.py +2 -6
  736. transformers/models/lilt/modeling_lilt.py +70 -76
  737. transformers/models/llama/configuration_llama.py +31 -26
  738. transformers/models/llama/modeling_llama.py +39 -36
  739. transformers/models/llama/tokenization_llama.py +44 -14
  740. transformers/models/llama4/configuration_llama4.py +30 -27
  741. transformers/models/llama4/image_processing_llama4_fast.py +14 -12
  742. transformers/models/llama4/modeling_llama4.py +113 -120
  743. transformers/models/llama4/processing_llama4.py +57 -33
  744. transformers/models/llava/configuration_llava.py +1 -10
  745. transformers/models/llava/image_processing_llava.py +28 -25
  746. transformers/models/llava/image_processing_llava_fast.py +11 -9
  747. transformers/models/llava/modeling_llava.py +109 -85
  748. transformers/models/llava/processing_llava.py +51 -18
  749. transformers/models/llava_next/configuration_llava_next.py +2 -2
  750. transformers/models/llava_next/image_processing_llava_next.py +45 -43
  751. transformers/models/llava_next/image_processing_llava_next_fast.py +13 -11
  752. transformers/models/llava_next/modeling_llava_next.py +107 -110
  753. transformers/models/llava_next/processing_llava_next.py +47 -18
  754. transformers/models/llava_next_video/configuration_llava_next_video.py +7 -4
  755. transformers/models/llava_next_video/modeling_llava_next_video.py +158 -175
  756. transformers/models/llava_next_video/modular_llava_next_video.py +150 -155
  757. transformers/models/llava_next_video/processing_llava_next_video.py +63 -21
  758. transformers/models/llava_next_video/video_processing_llava_next_video.py +1 -0
  759. transformers/models/llava_onevision/configuration_llava_onevision.py +7 -4
  760. transformers/models/llava_onevision/image_processing_llava_onevision.py +42 -40
  761. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +15 -14
  762. transformers/models/llava_onevision/modeling_llava_onevision.py +169 -177
  763. transformers/models/llava_onevision/modular_llava_onevision.py +156 -163
  764. transformers/models/llava_onevision/processing_llava_onevision.py +53 -21
  765. transformers/models/llava_onevision/video_processing_llava_onevision.py +1 -0
  766. transformers/models/longcat_flash/__init__.py +1 -0
  767. transformers/models/longcat_flash/configuration_longcat_flash.py +42 -37
  768. transformers/models/longcat_flash/modeling_longcat_flash.py +36 -36
  769. transformers/models/longcat_flash/modular_longcat_flash.py +21 -21
  770. transformers/models/longformer/configuration_longformer.py +5 -5
  771. transformers/models/longformer/modeling_longformer.py +101 -105
  772. transformers/models/longt5/configuration_longt5.py +7 -9
  773. transformers/models/longt5/modeling_longt5.py +49 -49
  774. transformers/models/luke/configuration_luke.py +2 -8
  775. transformers/models/luke/modeling_luke.py +181 -188
  776. transformers/models/luke/tokenization_luke.py +140 -107
  777. transformers/models/lxmert/configuration_lxmert.py +1 -16
  778. transformers/models/lxmert/modeling_lxmert.py +74 -65
  779. transformers/models/m2m_100/configuration_m2m_100.py +9 -7
  780. transformers/models/m2m_100/modeling_m2m_100.py +71 -83
  781. transformers/models/m2m_100/tokenization_m2m_100.py +8 -8
  782. transformers/models/mamba/configuration_mamba.py +2 -1
  783. transformers/models/mamba/modeling_mamba.py +66 -58
  784. transformers/models/mamba2/configuration_mamba2.py +8 -5
  785. transformers/models/mamba2/modeling_mamba2.py +69 -68
  786. transformers/models/marian/configuration_marian.py +5 -10
  787. transformers/models/marian/modeling_marian.py +87 -93
  788. transformers/models/marian/tokenization_marian.py +6 -6
  789. transformers/models/markuplm/configuration_markuplm.py +7 -4
  790. transformers/models/markuplm/feature_extraction_markuplm.py +2 -1
  791. transformers/models/markuplm/modeling_markuplm.py +70 -69
  792. transformers/models/markuplm/processing_markuplm.py +38 -31
  793. transformers/models/markuplm/tokenization_markuplm.py +136 -93
  794. transformers/models/mask2former/configuration_mask2former.py +8 -5
  795. transformers/models/mask2former/image_processing_mask2former.py +85 -84
  796. transformers/models/mask2former/image_processing_mask2former_fast.py +40 -37
  797. transformers/models/mask2former/modeling_mask2former.py +103 -118
  798. transformers/models/mask2former/modular_mask2former.py +8 -6
  799. transformers/models/maskformer/configuration_maskformer.py +9 -6
  800. transformers/models/maskformer/configuration_maskformer_swin.py +1 -0
  801. transformers/models/maskformer/image_processing_maskformer.py +85 -84
  802. transformers/models/maskformer/image_processing_maskformer_fast.py +40 -36
  803. transformers/models/maskformer/modeling_maskformer.py +65 -79
  804. transformers/models/maskformer/modeling_maskformer_swin.py +32 -36
  805. transformers/models/mbart/configuration_mbart.py +4 -9
  806. transformers/models/mbart/modeling_mbart.py +116 -131
  807. transformers/models/mbart/tokenization_mbart.py +54 -11
  808. transformers/models/mbart50/tokenization_mbart50.py +13 -8
  809. transformers/models/megatron_bert/configuration_megatron_bert.py +3 -13
  810. transformers/models/megatron_bert/modeling_megatron_bert.py +150 -148
  811. transformers/models/metaclip_2/configuration_metaclip_2.py +1 -4
  812. transformers/models/metaclip_2/modeling_metaclip_2.py +84 -91
  813. transformers/models/metaclip_2/modular_metaclip_2.py +45 -61
  814. transformers/models/mgp_str/configuration_mgp_str.py +1 -0
  815. transformers/models/mgp_str/modeling_mgp_str.py +18 -20
  816. transformers/models/mgp_str/processing_mgp_str.py +20 -3
  817. transformers/models/mgp_str/tokenization_mgp_str.py +3 -1
  818. transformers/models/mimi/configuration_mimi.py +40 -42
  819. transformers/models/mimi/modeling_mimi.py +113 -142
  820. transformers/models/minimax/__init__.py +1 -0
  821. transformers/models/minimax/configuration_minimax.py +43 -37
  822. transformers/models/minimax/modeling_minimax.py +51 -61
  823. transformers/models/minimax/modular_minimax.py +62 -68
  824. transformers/models/ministral/configuration_ministral.py +29 -25
  825. transformers/models/ministral/modeling_ministral.py +38 -36
  826. transformers/models/ministral/modular_ministral.py +37 -32
  827. transformers/models/ministral3/configuration_ministral3.py +27 -24
  828. transformers/models/ministral3/modeling_ministral3.py +37 -36
  829. transformers/models/ministral3/modular_ministral3.py +5 -4
  830. transformers/models/mistral/configuration_mistral.py +29 -24
  831. transformers/models/mistral/modeling_mistral.py +37 -36
  832. transformers/models/mistral/modular_mistral.py +12 -11
  833. transformers/models/mistral3/configuration_mistral3.py +1 -4
  834. transformers/models/mistral3/modeling_mistral3.py +86 -89
  835. transformers/models/mistral3/modular_mistral3.py +68 -69
  836. transformers/models/mixtral/configuration_mixtral.py +34 -29
  837. transformers/models/mixtral/modeling_mixtral.py +45 -50
  838. transformers/models/mixtral/modular_mixtral.py +31 -32
  839. transformers/models/mlcd/configuration_mlcd.py +1 -0
  840. transformers/models/mlcd/modeling_mlcd.py +14 -20
  841. transformers/models/mlcd/modular_mlcd.py +13 -17
  842. transformers/models/mllama/configuration_mllama.py +15 -10
  843. transformers/models/mllama/image_processing_mllama.py +25 -23
  844. transformers/models/mllama/image_processing_mllama_fast.py +11 -11
  845. transformers/models/mllama/modeling_mllama.py +94 -105
  846. transformers/models/mllama/processing_mllama.py +55 -6
  847. transformers/models/mluke/tokenization_mluke.py +107 -101
  848. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +3 -5
  849. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +140 -155
  850. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +3 -5
  851. transformers/models/mobilebert/configuration_mobilebert.py +2 -4
  852. transformers/models/mobilebert/modeling_mobilebert.py +85 -77
  853. transformers/models/mobilebert/tokenization_mobilebert.py +1 -0
  854. transformers/models/mobilenet_v1/configuration_mobilenet_v1.py +1 -0
  855. transformers/models/mobilenet_v1/image_processing_mobilenet_v1.py +23 -20
  856. transformers/models/mobilenet_v1/image_processing_mobilenet_v1_fast.py +1 -0
  857. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +16 -15
  858. transformers/models/mobilenet_v2/configuration_mobilenet_v2.py +1 -0
  859. transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py +51 -48
  860. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +15 -13
  861. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +22 -24
  862. transformers/models/mobilevit/configuration_mobilevit.py +1 -0
  863. transformers/models/mobilevit/image_processing_mobilevit.py +49 -46
  864. transformers/models/mobilevit/image_processing_mobilevit_fast.py +14 -12
  865. transformers/models/mobilevit/modeling_mobilevit.py +21 -28
  866. transformers/models/mobilevitv2/configuration_mobilevitv2.py +1 -0
  867. transformers/models/mobilevitv2/modeling_mobilevitv2.py +22 -28
  868. transformers/models/modernbert/configuration_modernbert.py +42 -44
  869. transformers/models/modernbert/modeling_modernbert.py +133 -145
  870. transformers/models/modernbert/modular_modernbert.py +170 -186
  871. transformers/models/modernbert_decoder/configuration_modernbert_decoder.py +40 -40
  872. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +57 -62
  873. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +86 -94
  874. transformers/models/moonshine/configuration_moonshine.py +31 -34
  875. transformers/models/moonshine/modeling_moonshine.py +71 -71
  876. transformers/models/moonshine/modular_moonshine.py +83 -88
  877. transformers/models/moshi/configuration_moshi.py +23 -46
  878. transformers/models/moshi/modeling_moshi.py +187 -157
  879. transformers/models/mpnet/configuration_mpnet.py +2 -6
  880. transformers/models/mpnet/modeling_mpnet.py +57 -62
  881. transformers/models/mpnet/tokenization_mpnet.py +15 -4
  882. transformers/models/mpt/configuration_mpt.py +9 -5
  883. transformers/models/mpt/modeling_mpt.py +60 -60
  884. transformers/models/mra/configuration_mra.py +2 -8
  885. transformers/models/mra/modeling_mra.py +57 -64
  886. transformers/models/mt5/configuration_mt5.py +8 -10
  887. transformers/models/mt5/modeling_mt5.py +95 -87
  888. transformers/models/musicgen/configuration_musicgen.py +8 -12
  889. transformers/models/musicgen/modeling_musicgen.py +122 -118
  890. transformers/models/musicgen/processing_musicgen.py +21 -3
  891. transformers/models/musicgen_melody/configuration_musicgen_melody.py +8 -15
  892. transformers/models/musicgen_melody/feature_extraction_musicgen_melody.py +9 -8
  893. transformers/models/musicgen_melody/modeling_musicgen_melody.py +123 -117
  894. transformers/models/musicgen_melody/processing_musicgen_melody.py +22 -3
  895. transformers/models/mvp/configuration_mvp.py +5 -8
  896. transformers/models/mvp/modeling_mvp.py +123 -135
  897. transformers/models/myt5/tokenization_myt5.py +10 -8
  898. transformers/models/nanochat/configuration_nanochat.py +8 -5
  899. transformers/models/nanochat/modeling_nanochat.py +40 -37
  900. transformers/models/nanochat/modular_nanochat.py +14 -12
  901. transformers/models/nemotron/configuration_nemotron.py +30 -25
  902. transformers/models/nemotron/modeling_nemotron.py +57 -56
  903. transformers/models/nllb/tokenization_nllb.py +28 -12
  904. transformers/models/nllb_moe/configuration_nllb_moe.py +9 -7
  905. transformers/models/nllb_moe/modeling_nllb_moe.py +69 -77
  906. transformers/models/nougat/image_processing_nougat.py +32 -29
  907. transformers/models/nougat/image_processing_nougat_fast.py +14 -12
  908. transformers/models/nougat/processing_nougat.py +39 -37
  909. transformers/models/nougat/tokenization_nougat.py +73 -18
  910. transformers/models/nystromformer/configuration_nystromformer.py +2 -8
  911. transformers/models/nystromformer/modeling_nystromformer.py +63 -74
  912. transformers/models/olmo/configuration_olmo.py +28 -23
  913. transformers/models/olmo/modeling_olmo.py +39 -36
  914. transformers/models/olmo/modular_olmo.py +11 -7
  915. transformers/models/olmo2/configuration_olmo2.py +28 -23
  916. transformers/models/olmo2/modeling_olmo2.py +41 -37
  917. transformers/models/olmo2/modular_olmo2.py +32 -29
  918. transformers/models/olmo3/__init__.py +1 -0
  919. transformers/models/olmo3/configuration_olmo3.py +30 -26
  920. transformers/models/olmo3/modeling_olmo3.py +39 -36
  921. transformers/models/olmo3/modular_olmo3.py +40 -37
  922. transformers/models/olmoe/configuration_olmoe.py +33 -29
  923. transformers/models/olmoe/modeling_olmoe.py +46 -52
  924. transformers/models/olmoe/modular_olmoe.py +15 -16
  925. transformers/models/omdet_turbo/configuration_omdet_turbo.py +4 -2
  926. transformers/models/omdet_turbo/modeling_omdet_turbo.py +47 -53
  927. transformers/models/omdet_turbo/processing_omdet_turbo.py +67 -19
  928. transformers/models/oneformer/configuration_oneformer.py +8 -5
  929. transformers/models/oneformer/image_processing_oneformer.py +84 -83
  930. transformers/models/oneformer/image_processing_oneformer_fast.py +42 -41
  931. transformers/models/oneformer/modeling_oneformer.py +171 -147
  932. transformers/models/oneformer/processing_oneformer.py +43 -28
  933. transformers/models/openai/configuration_openai.py +1 -16
  934. transformers/models/openai/modeling_openai.py +51 -65
  935. transformers/models/openai/tokenization_openai.py +47 -8
  936. transformers/models/opt/configuration_opt.py +7 -6
  937. transformers/models/opt/modeling_opt.py +76 -78
  938. transformers/models/ovis2/__init__.py +1 -0
  939. transformers/models/ovis2/configuration_ovis2.py +1 -0
  940. transformers/models/ovis2/image_processing_ovis2.py +24 -22
  941. transformers/models/ovis2/image_processing_ovis2_fast.py +11 -9
  942. transformers/models/ovis2/modeling_ovis2.py +142 -111
  943. transformers/models/ovis2/modular_ovis2.py +45 -90
  944. transformers/models/ovis2/processing_ovis2.py +40 -12
  945. transformers/models/owlv2/configuration_owlv2.py +2 -4
  946. transformers/models/owlv2/image_processing_owlv2.py +21 -20
  947. transformers/models/owlv2/image_processing_owlv2_fast.py +15 -12
  948. transformers/models/owlv2/modeling_owlv2.py +117 -133
  949. transformers/models/owlv2/modular_owlv2.py +14 -11
  950. transformers/models/owlv2/processing_owlv2.py +49 -20
  951. transformers/models/owlvit/configuration_owlvit.py +2 -4
  952. transformers/models/owlvit/image_processing_owlvit.py +22 -21
  953. transformers/models/owlvit/image_processing_owlvit_fast.py +3 -2
  954. transformers/models/owlvit/modeling_owlvit.py +116 -132
  955. transformers/models/owlvit/processing_owlvit.py +48 -20
  956. transformers/models/paligemma/configuration_paligemma.py +1 -4
  957. transformers/models/paligemma/modeling_paligemma.py +93 -103
  958. transformers/models/paligemma/processing_paligemma.py +66 -13
  959. transformers/models/parakeet/configuration_parakeet.py +14 -7
  960. transformers/models/parakeet/feature_extraction_parakeet.py +12 -10
  961. transformers/models/parakeet/modeling_parakeet.py +28 -32
  962. transformers/models/parakeet/modular_parakeet.py +20 -23
  963. transformers/models/parakeet/processing_parakeet.py +5 -13
  964. transformers/models/parakeet/{tokenization_parakeet.py → tokenization_parakeet_fast.py} +7 -5
  965. transformers/models/patchtsmixer/configuration_patchtsmixer.py +8 -5
  966. transformers/models/patchtsmixer/modeling_patchtsmixer.py +62 -70
  967. transformers/models/patchtst/configuration_patchtst.py +9 -6
  968. transformers/models/patchtst/modeling_patchtst.py +80 -97
  969. transformers/models/pegasus/configuration_pegasus.py +5 -8
  970. transformers/models/pegasus/modeling_pegasus.py +66 -72
  971. transformers/models/pegasus/tokenization_pegasus.py +45 -15
  972. transformers/models/pegasus_x/configuration_pegasus_x.py +4 -5
  973. transformers/models/pegasus_x/modeling_pegasus_x.py +52 -55
  974. transformers/models/perceiver/configuration_perceiver.py +1 -0
  975. transformers/models/perceiver/image_processing_perceiver.py +25 -22
  976. transformers/models/perceiver/image_processing_perceiver_fast.py +9 -7
  977. transformers/models/perceiver/modeling_perceiver.py +146 -165
  978. transformers/models/perceiver/tokenization_perceiver.py +6 -3
  979. transformers/models/perception_lm/configuration_perception_lm.py +1 -0
  980. transformers/models/perception_lm/image_processing_perception_lm_fast.py +10 -8
  981. transformers/models/perception_lm/modeling_perception_lm.py +70 -71
  982. transformers/models/perception_lm/modular_perception_lm.py +61 -65
  983. transformers/models/perception_lm/processing_perception_lm.py +47 -13
  984. transformers/models/perception_lm/video_processing_perception_lm.py +1 -0
  985. transformers/models/persimmon/configuration_persimmon.py +28 -23
  986. transformers/models/persimmon/modeling_persimmon.py +45 -43
  987. transformers/models/phi/configuration_phi.py +28 -23
  988. transformers/models/phi/modeling_phi.py +43 -40
  989. transformers/models/phi/modular_phi.py +24 -23
  990. transformers/models/phi3/configuration_phi3.py +33 -28
  991. transformers/models/phi3/modeling_phi3.py +38 -36
  992. transformers/models/phi3/modular_phi3.py +17 -13
  993. transformers/models/phi4_multimodal/configuration_phi4_multimodal.py +33 -30
  994. transformers/models/phi4_multimodal/feature_extraction_phi4_multimodal.py +9 -7
  995. transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py +11 -11
  996. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +78 -95
  997. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +80 -98
  998. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +44 -7
  999. transformers/models/phimoe/configuration_phimoe.py +36 -31
  1000. transformers/models/phimoe/modeling_phimoe.py +45 -50
  1001. transformers/models/phimoe/modular_phimoe.py +4 -3
  1002. transformers/models/phobert/tokenization_phobert.py +6 -4
  1003. transformers/models/pix2struct/configuration_pix2struct.py +10 -12
  1004. transformers/models/pix2struct/image_processing_pix2struct.py +19 -15
  1005. transformers/models/pix2struct/image_processing_pix2struct_fast.py +15 -12
  1006. transformers/models/pix2struct/modeling_pix2struct.py +52 -58
  1007. transformers/models/pix2struct/processing_pix2struct.py +30 -5
  1008. transformers/models/pixtral/configuration_pixtral.py +14 -11
  1009. transformers/models/pixtral/image_processing_pixtral.py +28 -26
  1010. transformers/models/pixtral/image_processing_pixtral_fast.py +11 -10
  1011. transformers/models/pixtral/modeling_pixtral.py +34 -28
  1012. transformers/models/pixtral/processing_pixtral.py +53 -21
  1013. transformers/models/plbart/configuration_plbart.py +5 -8
  1014. transformers/models/plbart/modeling_plbart.py +106 -119
  1015. transformers/models/plbart/modular_plbart.py +33 -39
  1016. transformers/models/plbart/tokenization_plbart.py +7 -4
  1017. transformers/models/poolformer/configuration_poolformer.py +1 -0
  1018. transformers/models/poolformer/image_processing_poolformer.py +24 -21
  1019. transformers/models/poolformer/image_processing_poolformer_fast.py +15 -13
  1020. transformers/models/poolformer/modeling_poolformer.py +13 -23
  1021. transformers/models/pop2piano/configuration_pop2piano.py +8 -7
  1022. transformers/models/pop2piano/feature_extraction_pop2piano.py +9 -6
  1023. transformers/models/pop2piano/modeling_pop2piano.py +24 -26
  1024. transformers/models/pop2piano/processing_pop2piano.py +33 -25
  1025. transformers/models/pop2piano/tokenization_pop2piano.py +23 -15
  1026. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +3 -3
  1027. transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py +28 -28
  1028. transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything_fast.py +21 -20
  1029. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +13 -16
  1030. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +13 -16
  1031. transformers/models/prophetnet/configuration_prophetnet.py +38 -37
  1032. transformers/models/prophetnet/modeling_prophetnet.py +131 -114
  1033. transformers/models/prophetnet/tokenization_prophetnet.py +16 -14
  1034. transformers/models/pvt/configuration_pvt.py +1 -0
  1035. transformers/models/pvt/image_processing_pvt.py +27 -24
  1036. transformers/models/pvt/image_processing_pvt_fast.py +2 -1
  1037. transformers/models/pvt/modeling_pvt.py +21 -21
  1038. transformers/models/pvt_v2/configuration_pvt_v2.py +4 -2
  1039. transformers/models/pvt_v2/modeling_pvt_v2.py +25 -28
  1040. transformers/models/qwen2/configuration_qwen2.py +25 -32
  1041. transformers/models/qwen2/modeling_qwen2.py +38 -36
  1042. transformers/models/qwen2/modular_qwen2.py +12 -11
  1043. transformers/models/qwen2/tokenization_qwen2.py +23 -12
  1044. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +26 -32
  1045. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +277 -340
  1046. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +211 -278
  1047. transformers/models/qwen2_5_omni/processing_qwen2_5_omni.py +49 -41
  1048. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +35 -29
  1049. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +148 -203
  1050. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +118 -93
  1051. transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py +43 -7
  1052. transformers/models/qwen2_audio/configuration_qwen2_audio.py +1 -0
  1053. transformers/models/qwen2_audio/modeling_qwen2_audio.py +40 -40
  1054. transformers/models/qwen2_audio/processing_qwen2_audio.py +42 -13
  1055. transformers/models/qwen2_moe/configuration_qwen2_moe.py +35 -42
  1056. transformers/models/qwen2_moe/modeling_qwen2_moe.py +46 -51
  1057. transformers/models/qwen2_moe/modular_qwen2_moe.py +10 -7
  1058. transformers/models/qwen2_vl/configuration_qwen2_vl.py +34 -29
  1059. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +42 -41
  1060. transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py +15 -12
  1061. transformers/models/qwen2_vl/modeling_qwen2_vl.py +153 -199
  1062. transformers/models/qwen2_vl/processing_qwen2_vl.py +44 -7
  1063. transformers/models/qwen2_vl/video_processing_qwen2_vl.py +18 -38
  1064. transformers/models/qwen3/configuration_qwen3.py +27 -34
  1065. transformers/models/qwen3/modeling_qwen3.py +39 -36
  1066. transformers/models/qwen3/modular_qwen3.py +6 -4
  1067. transformers/models/qwen3_moe/configuration_qwen3_moe.py +32 -39
  1068. transformers/models/qwen3_moe/modeling_qwen3_moe.py +46 -51
  1069. transformers/models/qwen3_moe/modular_qwen3_moe.py +13 -10
  1070. transformers/models/qwen3_next/configuration_qwen3_next.py +35 -45
  1071. transformers/models/qwen3_next/modeling_qwen3_next.py +51 -47
  1072. transformers/models/qwen3_next/modular_qwen3_next.py +35 -34
  1073. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +101 -135
  1074. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +252 -355
  1075. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +196 -250
  1076. transformers/models/qwen3_omni_moe/processing_qwen3_omni_moe.py +48 -40
  1077. transformers/models/qwen3_vl/configuration_qwen3_vl.py +29 -27
  1078. transformers/models/qwen3_vl/modeling_qwen3_vl.py +155 -233
  1079. transformers/models/qwen3_vl/modular_qwen3_vl.py +179 -206
  1080. transformers/models/qwen3_vl/processing_qwen3_vl.py +42 -6
  1081. transformers/models/qwen3_vl/video_processing_qwen3_vl.py +12 -10
  1082. transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py +30 -23
  1083. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +303 -358
  1084. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +124 -87
  1085. transformers/models/rag/configuration_rag.py +15 -6
  1086. transformers/models/rag/modeling_rag.py +130 -127
  1087. transformers/models/rag/retrieval_rag.py +5 -3
  1088. transformers/models/rag/tokenization_rag.py +50 -0
  1089. transformers/models/recurrent_gemma/configuration_recurrent_gemma.py +30 -29
  1090. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +42 -53
  1091. transformers/models/reformer/configuration_reformer.py +8 -7
  1092. transformers/models/reformer/modeling_reformer.py +69 -80
  1093. transformers/models/reformer/tokenization_reformer.py +31 -11
  1094. transformers/models/regnet/configuration_regnet.py +1 -0
  1095. transformers/models/regnet/modeling_regnet.py +8 -15
  1096. transformers/models/rembert/configuration_rembert.py +2 -8
  1097. transformers/models/rembert/modeling_rembert.py +111 -121
  1098. transformers/models/rembert/tokenization_rembert.py +12 -2
  1099. transformers/models/resnet/configuration_resnet.py +1 -0
  1100. transformers/models/resnet/modeling_resnet.py +13 -27
  1101. transformers/models/roberta/configuration_roberta.py +3 -11
  1102. transformers/models/roberta/modeling_roberta.py +93 -94
  1103. transformers/models/roberta/modular_roberta.py +58 -58
  1104. transformers/models/roberta/tokenization_roberta.py +29 -17
  1105. transformers/models/roberta/tokenization_roberta_old.py +4 -2
  1106. transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py +3 -11
  1107. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +93 -94
  1108. transformers/models/roc_bert/configuration_roc_bert.py +2 -8
  1109. transformers/models/roc_bert/modeling_roc_bert.py +121 -122
  1110. transformers/models/roc_bert/tokenization_roc_bert.py +94 -88
  1111. transformers/models/roformer/configuration_roformer.py +3 -13
  1112. transformers/models/roformer/modeling_roformer.py +81 -85
  1113. transformers/models/roformer/tokenization_roformer.py +412 -74
  1114. transformers/models/roformer/tokenization_roformer_fast.py +160 -0
  1115. transformers/models/roformer/tokenization_utils.py +1 -0
  1116. transformers/models/rt_detr/configuration_rt_detr.py +2 -1
  1117. transformers/models/rt_detr/configuration_rt_detr_resnet.py +1 -0
  1118. transformers/models/rt_detr/image_processing_rt_detr.py +55 -54
  1119. transformers/models/rt_detr/image_processing_rt_detr_fast.py +26 -26
  1120. transformers/models/rt_detr/modeling_rt_detr.py +90 -99
  1121. transformers/models/rt_detr/modeling_rt_detr_resnet.py +6 -13
  1122. transformers/models/rt_detr/modular_rt_detr.py +16 -16
  1123. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +4 -6
  1124. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +90 -101
  1125. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +12 -19
  1126. transformers/models/rwkv/configuration_rwkv.py +4 -2
  1127. transformers/models/rwkv/modeling_rwkv.py +32 -31
  1128. transformers/models/sam/configuration_sam.py +1 -3
  1129. transformers/models/sam/image_processing_sam.py +60 -59
  1130. transformers/models/sam/image_processing_sam_fast.py +27 -25
  1131. transformers/models/sam/modeling_sam.py +41 -47
  1132. transformers/models/sam/processing_sam.py +27 -39
  1133. transformers/models/sam2/configuration_sam2.py +3 -2
  1134. transformers/models/sam2/image_processing_sam2_fast.py +15 -14
  1135. transformers/models/sam2/modeling_sam2.py +90 -96
  1136. transformers/models/sam2/modular_sam2.py +91 -86
  1137. transformers/models/sam2/processing_sam2.py +47 -31
  1138. transformers/models/sam2_video/configuration_sam2_video.py +1 -0
  1139. transformers/models/sam2_video/modeling_sam2_video.py +144 -151
  1140. transformers/models/sam2_video/modular_sam2_video.py +104 -101
  1141. transformers/models/sam2_video/processing_sam2_video.py +66 -49
  1142. transformers/models/sam2_video/video_processing_sam2_video.py +4 -1
  1143. transformers/models/sam3/configuration_sam3.py +2 -21
  1144. transformers/models/sam3/image_processing_sam3_fast.py +20 -17
  1145. transformers/models/sam3/modeling_sam3.py +170 -184
  1146. transformers/models/sam3/modular_sam3.py +8 -3
  1147. transformers/models/sam3/processing_sam3.py +52 -37
  1148. transformers/models/sam3_tracker/__init__.py +1 -0
  1149. transformers/models/sam3_tracker/configuration_sam3_tracker.py +3 -1
  1150. transformers/models/sam3_tracker/modeling_sam3_tracker.py +77 -82
  1151. transformers/models/sam3_tracker/modular_sam3_tracker.py +3 -8
  1152. transformers/models/sam3_tracker/processing_sam3_tracker.py +48 -31
  1153. transformers/models/sam3_tracker_video/__init__.py +1 -0
  1154. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +1 -25
  1155. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +122 -135
  1156. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +26 -35
  1157. transformers/models/sam3_tracker_video/processing_sam3_tracker_video.py +66 -50
  1158. transformers/models/sam3_video/configuration_sam3_video.py +1 -14
  1159. transformers/models/sam3_video/modeling_sam3_video.py +34 -33
  1160. transformers/models/sam3_video/processing_sam3_video.py +46 -26
  1161. transformers/models/sam_hq/__init__.py +1 -1
  1162. transformers/models/sam_hq/configuration_sam_hq.py +1 -3
  1163. transformers/models/sam_hq/modeling_sam_hq.py +69 -74
  1164. transformers/models/sam_hq/modular_sam_hq.py +25 -23
  1165. transformers/models/sam_hq/{processing_sam_hq.py → processing_samhq.py} +29 -41
  1166. transformers/models/seamless_m4t/configuration_seamless_m4t.py +10 -8
  1167. transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py +11 -8
  1168. transformers/models/seamless_m4t/modeling_seamless_m4t.py +194 -212
  1169. transformers/models/seamless_m4t/processing_seamless_m4t.py +39 -18
  1170. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +77 -40
  1171. transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py +10 -8
  1172. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +196 -204
  1173. transformers/models/seed_oss/configuration_seed_oss.py +32 -28
  1174. transformers/models/seed_oss/modeling_seed_oss.py +35 -33
  1175. transformers/models/seed_oss/modular_seed_oss.py +4 -3
  1176. transformers/models/segformer/configuration_segformer.py +10 -0
  1177. transformers/models/segformer/image_processing_segformer.py +42 -39
  1178. transformers/models/segformer/image_processing_segformer_fast.py +12 -10
  1179. transformers/models/segformer/modeling_segformer.py +31 -34
  1180. transformers/models/segformer/modular_segformer.py +10 -8
  1181. transformers/models/seggpt/configuration_seggpt.py +1 -0
  1182. transformers/models/seggpt/image_processing_seggpt.py +41 -38
  1183. transformers/models/seggpt/modeling_seggpt.py +38 -50
  1184. transformers/models/sew/configuration_sew.py +2 -4
  1185. transformers/models/sew/modeling_sew.py +36 -38
  1186. transformers/models/sew/modular_sew.py +13 -13
  1187. transformers/models/sew_d/configuration_sew_d.py +2 -4
  1188. transformers/models/sew_d/modeling_sew_d.py +30 -31
  1189. transformers/models/shieldgemma2/configuration_shieldgemma2.py +1 -0
  1190. transformers/models/shieldgemma2/modeling_shieldgemma2.py +17 -16
  1191. transformers/models/shieldgemma2/processing_shieldgemma2.py +5 -3
  1192. transformers/models/siglip/configuration_siglip.py +2 -4
  1193. transformers/models/siglip/image_processing_siglip.py +20 -17
  1194. transformers/models/siglip/image_processing_siglip_fast.py +1 -0
  1195. transformers/models/siglip/modeling_siglip.py +75 -84
  1196. transformers/models/siglip/processing_siglip.py +14 -2
  1197. transformers/models/siglip/tokenization_siglip.py +7 -6
  1198. transformers/models/siglip2/configuration_siglip2.py +2 -5
  1199. transformers/models/siglip2/image_processing_siglip2.py +16 -15
  1200. transformers/models/siglip2/image_processing_siglip2_fast.py +7 -6
  1201. transformers/models/siglip2/modeling_siglip2.py +129 -143
  1202. transformers/models/siglip2/modular_siglip2.py +46 -47
  1203. transformers/models/siglip2/processing_siglip2.py +14 -2
  1204. transformers/models/smollm3/configuration_smollm3.py +32 -29
  1205. transformers/models/smollm3/modeling_smollm3.py +39 -36
  1206. transformers/models/smollm3/modular_smollm3.py +35 -33
  1207. transformers/models/smolvlm/configuration_smolvlm.py +4 -2
  1208. transformers/models/smolvlm/image_processing_smolvlm.py +43 -42
  1209. transformers/models/smolvlm/image_processing_smolvlm_fast.py +15 -41
  1210. transformers/models/smolvlm/modeling_smolvlm.py +94 -126
  1211. transformers/models/smolvlm/modular_smolvlm.py +39 -50
  1212. transformers/models/smolvlm/processing_smolvlm.py +83 -15
  1213. transformers/models/smolvlm/video_processing_smolvlm.py +18 -16
  1214. transformers/models/speech_encoder_decoder/configuration_speech_encoder_decoder.py +1 -0
  1215. transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +27 -26
  1216. transformers/models/speech_to_text/configuration_speech_to_text.py +9 -9
  1217. transformers/models/speech_to_text/feature_extraction_speech_to_text.py +13 -10
  1218. transformers/models/speech_to_text/modeling_speech_to_text.py +54 -66
  1219. transformers/models/speech_to_text/processing_speech_to_text.py +30 -4
  1220. transformers/models/speech_to_text/tokenization_speech_to_text.py +6 -5
  1221. transformers/models/speecht5/configuration_speecht5.py +9 -7
  1222. transformers/models/speecht5/feature_extraction_speecht5.py +37 -16
  1223. transformers/models/speecht5/modeling_speecht5.py +175 -213
  1224. transformers/models/speecht5/number_normalizer.py +1 -0
  1225. transformers/models/speecht5/processing_speecht5.py +37 -3
  1226. transformers/models/speecht5/tokenization_speecht5.py +5 -4
  1227. transformers/models/splinter/configuration_splinter.py +7 -6
  1228. transformers/models/splinter/modeling_splinter.py +59 -71
  1229. transformers/models/splinter/tokenization_splinter.py +30 -9
  1230. transformers/models/squeezebert/configuration_squeezebert.py +2 -14
  1231. transformers/models/squeezebert/modeling_squeezebert.py +62 -68
  1232. transformers/models/squeezebert/tokenization_squeezebert.py +1 -0
  1233. transformers/models/stablelm/configuration_stablelm.py +29 -24
  1234. transformers/models/stablelm/modeling_stablelm.py +45 -44
  1235. transformers/models/starcoder2/configuration_starcoder2.py +27 -30
  1236. transformers/models/starcoder2/modeling_starcoder2.py +41 -39
  1237. transformers/models/starcoder2/modular_starcoder2.py +16 -14
  1238. transformers/models/superglue/configuration_superglue.py +3 -7
  1239. transformers/models/superglue/image_processing_superglue.py +15 -15
  1240. transformers/models/superglue/image_processing_superglue_fast.py +10 -9
  1241. transformers/models/superglue/modeling_superglue.py +37 -42
  1242. transformers/models/superpoint/image_processing_superpoint.py +15 -15
  1243. transformers/models/superpoint/image_processing_superpoint_fast.py +11 -8
  1244. transformers/models/superpoint/modeling_superpoint.py +16 -18
  1245. transformers/models/swiftformer/configuration_swiftformer.py +1 -0
  1246. transformers/models/swiftformer/modeling_swiftformer.py +14 -18
  1247. transformers/models/swin/configuration_swin.py +1 -0
  1248. transformers/models/swin/modeling_swin.py +86 -86
  1249. transformers/models/swin2sr/configuration_swin2sr.py +1 -0
  1250. transformers/models/swin2sr/image_processing_swin2sr.py +13 -10
  1251. transformers/models/swin2sr/image_processing_swin2sr_fast.py +8 -4
  1252. transformers/models/swin2sr/modeling_swin2sr.py +63 -81
  1253. transformers/models/swinv2/configuration_swinv2.py +1 -0
  1254. transformers/models/swinv2/modeling_swinv2.py +104 -108
  1255. transformers/models/switch_transformers/configuration_switch_transformers.py +7 -11
  1256. transformers/models/switch_transformers/modeling_switch_transformers.py +44 -37
  1257. transformers/models/switch_transformers/modular_switch_transformers.py +41 -34
  1258. transformers/models/t5/configuration_t5.py +8 -14
  1259. transformers/models/t5/modeling_t5.py +92 -88
  1260. transformers/models/t5/tokenization_t5.py +9 -3
  1261. transformers/models/t5gemma/configuration_t5gemma.py +41 -43
  1262. transformers/models/t5gemma/modeling_t5gemma.py +107 -104
  1263. transformers/models/t5gemma/modular_t5gemma.py +120 -124
  1264. transformers/models/t5gemma2/configuration_t5gemma2.py +120 -80
  1265. transformers/models/t5gemma2/modeling_t5gemma2.py +125 -141
  1266. transformers/models/t5gemma2/modular_t5gemma2.py +104 -393
  1267. transformers/models/table_transformer/configuration_table_transformer.py +2 -1
  1268. transformers/models/table_transformer/modeling_table_transformer.py +49 -51
  1269. transformers/models/tapas/configuration_tapas.py +2 -12
  1270. transformers/models/tapas/modeling_tapas.py +67 -68
  1271. transformers/models/tapas/tokenization_tapas.py +153 -115
  1272. transformers/models/textnet/configuration_textnet.py +1 -0
  1273. transformers/models/textnet/image_processing_textnet.py +25 -22
  1274. transformers/models/textnet/image_processing_textnet_fast.py +10 -8
  1275. transformers/models/textnet/modeling_textnet.py +16 -28
  1276. transformers/models/time_series_transformer/configuration_time_series_transformer.py +8 -5
  1277. transformers/models/time_series_transformer/modeling_time_series_transformer.py +81 -83
  1278. transformers/models/timesfm/configuration_timesfm.py +1 -0
  1279. transformers/models/timesfm/modeling_timesfm.py +22 -33
  1280. transformers/models/timesfm/modular_timesfm.py +21 -32
  1281. transformers/models/timesformer/configuration_timesformer.py +1 -0
  1282. transformers/models/timesformer/modeling_timesformer.py +16 -15
  1283. transformers/models/timm_backbone/configuration_timm_backbone.py +1 -0
  1284. transformers/models/timm_backbone/modeling_timm_backbone.py +15 -17
  1285. transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -5
  1286. transformers/models/timm_wrapper/image_processing_timm_wrapper.py +5 -4
  1287. transformers/models/timm_wrapper/modeling_timm_wrapper.py +29 -34
  1288. transformers/models/trocr/configuration_trocr.py +8 -11
  1289. transformers/models/trocr/modeling_trocr.py +44 -45
  1290. transformers/models/trocr/processing_trocr.py +25 -5
  1291. transformers/models/tvp/configuration_tvp.py +2 -5
  1292. transformers/models/tvp/image_processing_tvp.py +52 -50
  1293. transformers/models/tvp/image_processing_tvp_fast.py +15 -15
  1294. transformers/models/tvp/modeling_tvp.py +27 -27
  1295. transformers/models/tvp/processing_tvp.py +14 -2
  1296. transformers/models/udop/configuration_udop.py +7 -16
  1297. transformers/models/udop/modeling_udop.py +73 -71
  1298. transformers/models/udop/processing_udop.py +26 -7
  1299. transformers/models/udop/tokenization_udop.py +105 -84
  1300. transformers/models/umt5/configuration_umt5.py +7 -8
  1301. transformers/models/umt5/modeling_umt5.py +90 -94
  1302. transformers/models/unispeech/configuration_unispeech.py +2 -4
  1303. transformers/models/unispeech/modeling_unispeech.py +49 -51
  1304. transformers/models/unispeech/modular_unispeech.py +22 -22
  1305. transformers/models/unispeech_sat/configuration_unispeech_sat.py +2 -4
  1306. transformers/models/unispeech_sat/modeling_unispeech_sat.py +65 -69
  1307. transformers/models/unispeech_sat/modular_unispeech_sat.py +23 -23
  1308. transformers/models/univnet/feature_extraction_univnet.py +14 -14
  1309. transformers/models/univnet/modeling_univnet.py +8 -8
  1310. transformers/models/upernet/configuration_upernet.py +1 -0
  1311. transformers/models/upernet/modeling_upernet.py +13 -11
  1312. transformers/models/vaultgemma/__init__.py +1 -0
  1313. transformers/models/vaultgemma/configuration_vaultgemma.py +33 -29
  1314. transformers/models/vaultgemma/modeling_vaultgemma.py +41 -39
  1315. transformers/models/vaultgemma/modular_vaultgemma.py +31 -29
  1316. transformers/models/video_llama_3/configuration_video_llama_3.py +0 -4
  1317. transformers/models/video_llama_3/image_processing_video_llama_3.py +42 -43
  1318. transformers/models/video_llama_3/image_processing_video_llama_3_fast.py +14 -12
  1319. transformers/models/video_llama_3/modeling_video_llama_3.py +109 -157
  1320. transformers/models/video_llama_3/modular_video_llama_3.py +146 -155
  1321. transformers/models/video_llama_3/processing_video_llama_3.py +39 -5
  1322. transformers/models/video_llama_3/video_processing_video_llama_3.py +23 -42
  1323. transformers/models/video_llava/configuration_video_llava.py +1 -4
  1324. transformers/models/video_llava/image_processing_video_llava.py +38 -35
  1325. transformers/models/video_llava/modeling_video_llava.py +146 -146
  1326. transformers/models/video_llava/processing_video_llava.py +78 -38
  1327. transformers/models/video_llava/video_processing_video_llava.py +1 -0
  1328. transformers/models/videomae/configuration_videomae.py +1 -0
  1329. transformers/models/videomae/image_processing_videomae.py +34 -31
  1330. transformers/models/videomae/modeling_videomae.py +17 -14
  1331. transformers/models/videomae/video_processing_videomae.py +1 -0
  1332. transformers/models/vilt/configuration_vilt.py +4 -6
  1333. transformers/models/vilt/image_processing_vilt.py +30 -29
  1334. transformers/models/vilt/image_processing_vilt_fast.py +16 -15
  1335. transformers/models/vilt/modeling_vilt.py +90 -116
  1336. transformers/models/vilt/processing_vilt.py +14 -2
  1337. transformers/models/vipllava/configuration_vipllava.py +1 -4
  1338. transformers/models/vipllava/modeling_vipllava.py +70 -99
  1339. transformers/models/vipllava/modular_vipllava.py +54 -78
  1340. transformers/models/vision_encoder_decoder/configuration_vision_encoder_decoder.py +1 -0
  1341. transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +27 -28
  1342. transformers/models/vision_text_dual_encoder/configuration_vision_text_dual_encoder.py +1 -0
  1343. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +41 -46
  1344. transformers/models/vision_text_dual_encoder/processing_vision_text_dual_encoder.py +16 -2
  1345. transformers/models/visual_bert/configuration_visual_bert.py +2 -6
  1346. transformers/models/visual_bert/modeling_visual_bert.py +92 -98
  1347. transformers/models/vit/configuration_vit.py +1 -0
  1348. transformers/models/vit/image_processing_vit.py +22 -19
  1349. transformers/models/vit/image_processing_vit_fast.py +1 -0
  1350. transformers/models/vit/modeling_vit.py +17 -17
  1351. transformers/models/vit_mae/configuration_vit_mae.py +1 -0
  1352. transformers/models/vit_mae/modeling_vit_mae.py +27 -29
  1353. transformers/models/vit_msn/configuration_vit_msn.py +1 -0
  1354. transformers/models/vit_msn/modeling_vit_msn.py +16 -18
  1355. transformers/models/vitdet/configuration_vitdet.py +1 -0
  1356. transformers/models/vitdet/modeling_vitdet.py +14 -14
  1357. transformers/models/vitmatte/configuration_vitmatte.py +5 -2
  1358. transformers/models/vitmatte/image_processing_vitmatte.py +18 -15
  1359. transformers/models/vitmatte/image_processing_vitmatte_fast.py +18 -16
  1360. transformers/models/vitmatte/modeling_vitmatte.py +11 -14
  1361. transformers/models/vitpose/configuration_vitpose.py +7 -4
  1362. transformers/models/vitpose/image_processing_vitpose.py +25 -24
  1363. transformers/models/vitpose/image_processing_vitpose_fast.py +11 -9
  1364. transformers/models/vitpose/modeling_vitpose.py +14 -14
  1365. transformers/models/vitpose_backbone/configuration_vitpose_backbone.py +1 -0
  1366. transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +10 -8
  1367. transformers/models/vits/configuration_vits.py +1 -4
  1368. transformers/models/vits/modeling_vits.py +42 -44
  1369. transformers/models/vits/tokenization_vits.py +4 -3
  1370. transformers/models/vivit/configuration_vivit.py +1 -0
  1371. transformers/models/vivit/image_processing_vivit.py +39 -36
  1372. transformers/models/vivit/modeling_vivit.py +8 -6
  1373. transformers/models/vjepa2/__init__.py +1 -0
  1374. transformers/models/vjepa2/configuration_vjepa2.py +1 -0
  1375. transformers/models/vjepa2/modeling_vjepa2.py +32 -31
  1376. transformers/models/vjepa2/video_processing_vjepa2.py +1 -0
  1377. transformers/models/voxtral/__init__.py +1 -0
  1378. transformers/models/voxtral/configuration_voxtral.py +2 -0
  1379. transformers/models/voxtral/modeling_voxtral.py +47 -40
  1380. transformers/models/voxtral/modular_voxtral.py +40 -37
  1381. transformers/models/voxtral/processing_voxtral.py +48 -25
  1382. transformers/models/wav2vec2/configuration_wav2vec2.py +2 -4
  1383. transformers/models/wav2vec2/feature_extraction_wav2vec2.py +10 -7
  1384. transformers/models/wav2vec2/modeling_wav2vec2.py +121 -73
  1385. transformers/models/wav2vec2/processing_wav2vec2.py +35 -6
  1386. transformers/models/wav2vec2/tokenization_wav2vec2.py +332 -20
  1387. transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py +2 -4
  1388. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +62 -70
  1389. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +48 -57
  1390. transformers/models/wav2vec2_bert/processing_wav2vec2_bert.py +35 -6
  1391. transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py +2 -4
  1392. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +77 -90
  1393. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +30 -37
  1394. transformers/models/wav2vec2_phoneme/tokenization_wav2vec2_phoneme.py +17 -16
  1395. transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py +55 -36
  1396. transformers/models/wavlm/configuration_wavlm.py +2 -4
  1397. transformers/models/wavlm/modeling_wavlm.py +48 -50
  1398. transformers/models/wavlm/modular_wavlm.py +5 -4
  1399. transformers/models/whisper/configuration_whisper.py +5 -6
  1400. transformers/models/whisper/english_normalizer.py +4 -3
  1401. transformers/models/whisper/feature_extraction_whisper.py +24 -9
  1402. transformers/models/whisper/generation_whisper.py +48 -26
  1403. transformers/models/whisper/modeling_whisper.py +73 -79
  1404. transformers/models/whisper/processing_whisper.py +20 -3
  1405. transformers/models/whisper/tokenization_whisper.py +43 -11
  1406. transformers/models/x_clip/configuration_x_clip.py +2 -4
  1407. transformers/models/x_clip/modeling_x_clip.py +93 -96
  1408. transformers/models/x_clip/processing_x_clip.py +14 -2
  1409. transformers/models/xcodec/configuration_xcodec.py +6 -4
  1410. transformers/models/xcodec/modeling_xcodec.py +17 -20
  1411. transformers/models/xglm/configuration_xglm.py +8 -9
  1412. transformers/models/xglm/modeling_xglm.py +55 -60
  1413. transformers/models/xglm/tokenization_xglm.py +11 -3
  1414. transformers/models/xlm/configuration_xlm.py +8 -10
  1415. transformers/models/xlm/modeling_xlm.py +144 -144
  1416. transformers/models/xlm/tokenization_xlm.py +5 -3
  1417. transformers/models/xlm_roberta/configuration_xlm_roberta.py +3 -11
  1418. transformers/models/xlm_roberta/modeling_xlm_roberta.py +194 -195
  1419. transformers/models/xlm_roberta/modular_xlm_roberta.py +53 -50
  1420. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +18 -8
  1421. transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py +2 -10
  1422. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +93 -94
  1423. transformers/models/xlm_roberta_xl/modular_xlm_roberta_xl.py +70 -67
  1424. transformers/models/xlnet/configuration_xlnet.py +12 -3
  1425. transformers/models/xlnet/modeling_xlnet.py +163 -152
  1426. transformers/models/xlnet/tokenization_xlnet.py +9 -2
  1427. transformers/models/xlstm/configuration_xlstm.py +12 -8
  1428. transformers/models/xlstm/modeling_xlstm.py +65 -62
  1429. transformers/models/xmod/configuration_xmod.py +3 -11
  1430. transformers/models/xmod/modeling_xmod.py +110 -108
  1431. transformers/models/yolos/configuration_yolos.py +1 -0
  1432. transformers/models/yolos/image_processing_yolos.py +62 -60
  1433. transformers/models/yolos/image_processing_yolos_fast.py +45 -42
  1434. transformers/models/yolos/modeling_yolos.py +16 -16
  1435. transformers/models/yolos/modular_yolos.py +19 -17
  1436. transformers/models/yoso/configuration_yoso.py +2 -8
  1437. transformers/models/yoso/modeling_yoso.py +63 -70
  1438. transformers/models/zamba/configuration_zamba.py +8 -5
  1439. transformers/models/zamba/modeling_zamba.py +78 -81
  1440. transformers/models/zamba2/configuration_zamba2.py +50 -44
  1441. transformers/models/zamba2/modeling_zamba2.py +97 -97
  1442. transformers/models/zamba2/modular_zamba2.py +48 -46
  1443. transformers/models/zoedepth/configuration_zoedepth.py +2 -1
  1444. transformers/models/zoedepth/image_processing_zoedepth.py +29 -28
  1445. transformers/models/zoedepth/image_processing_zoedepth_fast.py +24 -21
  1446. transformers/models/zoedepth/modeling_zoedepth.py +18 -26
  1447. transformers/pipelines/__init__.py +114 -57
  1448. transformers/pipelines/any_to_any.py +22 -14
  1449. transformers/pipelines/audio_utils.py +2 -1
  1450. transformers/pipelines/automatic_speech_recognition.py +12 -20
  1451. transformers/pipelines/base.py +27 -15
  1452. transformers/{models/pe_audio/processing_pe_audio.py → pipelines/deprecated/__init__.py} +3 -10
  1453. transformers/pipelines/deprecated/text2text_generation.py +408 -0
  1454. transformers/pipelines/document_question_answering.py +2 -4
  1455. transformers/pipelines/image_text_to_text.py +1 -0
  1456. transformers/pipelines/image_to_text.py +229 -0
  1457. transformers/pipelines/question_answering.py +44 -5
  1458. transformers/pipelines/text_classification.py +14 -1
  1459. transformers/pipelines/text_generation.py +1 -1
  1460. transformers/pipelines/text_to_audio.py +2 -2
  1461. transformers/pipelines/token_classification.py +22 -1
  1462. transformers/pipelines/video_classification.py +9 -1
  1463. transformers/pipelines/zero_shot_audio_classification.py +1 -0
  1464. transformers/pipelines/zero_shot_classification.py +6 -0
  1465. transformers/pipelines/zero_shot_image_classification.py +7 -0
  1466. transformers/processing_utils.py +145 -230
  1467. transformers/quantizers/auto.py +4 -2
  1468. transformers/quantizers/base.py +173 -53
  1469. transformers/quantizers/quantizer_aqlm.py +23 -2
  1470. transformers/quantizers/quantizer_auto_round.py +12 -2
  1471. transformers/quantizers/quantizer_awq.py +89 -20
  1472. transformers/quantizers/quantizer_bitnet.py +14 -4
  1473. transformers/quantizers/quantizer_bnb_4bit.py +155 -18
  1474. transformers/quantizers/quantizer_bnb_8bit.py +110 -24
  1475. transformers/quantizers/quantizer_compressed_tensors.py +9 -2
  1476. transformers/quantizers/quantizer_eetq.py +74 -16
  1477. transformers/quantizers/quantizer_fbgemm_fp8.py +138 -38
  1478. transformers/quantizers/quantizer_finegrained_fp8.py +113 -26
  1479. transformers/quantizers/quantizer_fp_quant.py +82 -52
  1480. transformers/quantizers/quantizer_gptq.py +28 -8
  1481. transformers/quantizers/quantizer_higgs.py +60 -42
  1482. transformers/quantizers/quantizer_hqq.py +153 -144
  1483. transformers/quantizers/quantizer_mxfp4.py +194 -14
  1484. transformers/quantizers/quantizer_quanto.py +79 -35
  1485. transformers/quantizers/quantizer_quark.py +18 -36
  1486. transformers/quantizers/quantizer_spqr.py +12 -4
  1487. transformers/quantizers/quantizer_torchao.py +325 -50
  1488. transformers/quantizers/quantizer_vptq.py +27 -4
  1489. transformers/quantizers/quantizers_utils.py +0 -20
  1490. transformers/safetensors_conversion.py +3 -9
  1491. transformers/testing_utils.py +82 -326
  1492. transformers/tokenization_mistral_common.py +903 -568
  1493. transformers/tokenization_utils_base.py +340 -220
  1494. transformers/tokenization_utils_sentencepiece.py +6 -5
  1495. transformers/tokenization_utils_tokenizers.py +113 -226
  1496. transformers/trainer.py +53 -60
  1497. transformers/trainer_callback.py +0 -8
  1498. transformers/trainer_seq2seq.py +1 -5
  1499. transformers/trainer_utils.py +1 -1
  1500. transformers/training_args.py +41 -77
  1501. transformers/utils/__init__.py +4 -8
  1502. transformers/utils/attention_visualizer.py +5 -5
  1503. transformers/utils/auto_docstring.py +37 -599
  1504. transformers/utils/doc.py +36 -4
  1505. transformers/utils/dummy_pt_objects.py +42 -0
  1506. transformers/utils/generic.py +28 -111
  1507. transformers/utils/hub.py +15 -5
  1508. transformers/utils/import_utils.py +32 -165
  1509. transformers/utils/kernel_config.py +19 -74
  1510. transformers/utils/loading_report.py +15 -25
  1511. transformers/utils/quantization_config.py +241 -72
  1512. transformers/video_processing_utils.py +39 -41
  1513. transformers/video_utils.py +22 -18
  1514. {transformers-5.0.0.dist-info → transformers-5.0.0rc0.dist-info}/METADATA +236 -284
  1515. transformers-5.0.0rc0.dist-info/RECORD +1987 -0
  1516. {transformers-5.0.0.dist-info → transformers-5.0.0rc0.dist-info}/WHEEL +1 -1
  1517. transformers/integrations/moe.py +0 -360
  1518. transformers/integrations/quark.py +0 -53
  1519. transformers/loss/loss_lw_detr.py +0 -356
  1520. transformers/models/ernie4_5_vl_moe/__init__.py +0 -31
  1521. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +0 -340
  1522. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +0 -455
  1523. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +0 -231
  1524. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +0 -1936
  1525. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +0 -1925
  1526. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +0 -249
  1527. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +0 -593
  1528. transformers/models/fast_vlm/__init__.py +0 -27
  1529. transformers/models/fast_vlm/configuration_fast_vlm.py +0 -137
  1530. transformers/models/fast_vlm/modeling_fast_vlm.py +0 -432
  1531. transformers/models/fast_vlm/modular_fast_vlm.py +0 -373
  1532. transformers/models/glm4_moe_lite/__init__.py +0 -28
  1533. transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py +0 -233
  1534. transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py +0 -740
  1535. transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py +0 -302
  1536. transformers/models/glm_image/__init__.py +0 -31
  1537. transformers/models/glm_image/configuration_glm_image.py +0 -351
  1538. transformers/models/glm_image/image_processing_glm_image.py +0 -503
  1539. transformers/models/glm_image/image_processing_glm_image_fast.py +0 -294
  1540. transformers/models/glm_image/modeling_glm_image.py +0 -1642
  1541. transformers/models/glm_image/modular_glm_image.py +0 -1531
  1542. transformers/models/glm_image/processing_glm_image.py +0 -217
  1543. transformers/models/glmasr/__init__.py +0 -29
  1544. transformers/models/glmasr/configuration_glmasr.py +0 -196
  1545. transformers/models/glmasr/modeling_glmasr.py +0 -517
  1546. transformers/models/glmasr/modular_glmasr.py +0 -443
  1547. transformers/models/glmasr/processing_glmasr.py +0 -331
  1548. transformers/models/jais2/__init__.py +0 -27
  1549. transformers/models/jais2/configuration_jais2.py +0 -148
  1550. transformers/models/jais2/modeling_jais2.py +0 -484
  1551. transformers/models/jais2/modular_jais2.py +0 -194
  1552. transformers/models/lasr/__init__.py +0 -29
  1553. transformers/models/lasr/configuration_lasr.py +0 -244
  1554. transformers/models/lasr/feature_extraction_lasr.py +0 -275
  1555. transformers/models/lasr/modeling_lasr.py +0 -727
  1556. transformers/models/lasr/modular_lasr.py +0 -574
  1557. transformers/models/lasr/processing_lasr.py +0 -100
  1558. transformers/models/lasr/tokenization_lasr.py +0 -184
  1559. transformers/models/lighton_ocr/__init__.py +0 -28
  1560. transformers/models/lighton_ocr/configuration_lighton_ocr.py +0 -128
  1561. transformers/models/lighton_ocr/modeling_lighton_ocr.py +0 -463
  1562. transformers/models/lighton_ocr/modular_lighton_ocr.py +0 -404
  1563. transformers/models/lighton_ocr/processing_lighton_ocr.py +0 -229
  1564. transformers/models/lw_detr/__init__.py +0 -27
  1565. transformers/models/lw_detr/configuration_lw_detr.py +0 -374
  1566. transformers/models/lw_detr/modeling_lw_detr.py +0 -1702
  1567. transformers/models/lw_detr/modular_lw_detr.py +0 -1615
  1568. transformers/models/minimax_m2/__init__.py +0 -28
  1569. transformers/models/minimax_m2/configuration_minimax_m2.py +0 -188
  1570. transformers/models/minimax_m2/modeling_minimax_m2.py +0 -704
  1571. transformers/models/minimax_m2/modular_minimax_m2.py +0 -346
  1572. transformers/models/paddleocr_vl/__init__.py +0 -31
  1573. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +0 -335
  1574. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +0 -503
  1575. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +0 -209
  1576. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +0 -1683
  1577. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +0 -1380
  1578. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +0 -133
  1579. transformers/models/pe_audio/__init__.py +0 -29
  1580. transformers/models/pe_audio/configuration_pe_audio.py +0 -204
  1581. transformers/models/pe_audio/feature_extraction_pe_audio.py +0 -160
  1582. transformers/models/pe_audio/modeling_pe_audio.py +0 -819
  1583. transformers/models/pe_audio/modular_pe_audio.py +0 -298
  1584. transformers/models/pe_audio_video/__init__.py +0 -28
  1585. transformers/models/pe_audio_video/configuration_pe_audio_video.py +0 -223
  1586. transformers/models/pe_audio_video/modeling_pe_audio_video.py +0 -971
  1587. transformers/models/pe_audio_video/modular_pe_audio_video.py +0 -763
  1588. transformers/models/pe_video/__init__.py +0 -29
  1589. transformers/models/pe_video/configuration_pe_video.py +0 -209
  1590. transformers/models/pe_video/modeling_pe_video.py +0 -647
  1591. transformers/models/pe_video/modular_pe_video.py +0 -231
  1592. transformers/models/pe_video/processing_pe_video.py +0 -10
  1593. transformers/models/pe_video/video_processing_pe_video.py +0 -64
  1594. transformers/models/pixio/__init__.py +0 -29
  1595. transformers/models/pixio/configuration_pixio.py +0 -150
  1596. transformers/models/pixio/modeling_pixio.py +0 -507
  1597. transformers/models/pixio/modular_pixio.py +0 -403
  1598. transformers/models/solar_open/__init__.py +0 -27
  1599. transformers/models/solar_open/configuration_solar_open.py +0 -184
  1600. transformers/models/solar_open/modeling_solar_open.py +0 -642
  1601. transformers/models/solar_open/modular_solar_open.py +0 -224
  1602. transformers/trainer_jit_checkpoint.py +0 -125
  1603. transformers-5.0.0.dist-info/RECORD +0 -2068
  1604. {transformers-5.0.0.dist-info/licenses → transformers-5.0.0rc0.dist-info}/LICENSE +0 -0
  1605. {transformers-5.0.0.dist-info → transformers-5.0.0rc0.dist-info}/entry_points.txt +0 -0
  1606. {transformers-5.0.0.dist-info → transformers-5.0.0rc0.dist-info}/top_level.txt +0 -0
transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py
@@ -4,6 +4,7 @@
 # the file from the modular. If any change should be done, please apply the change to the
 # modular_qwen3_omni_moe.py file directly. One of our CI enforces this.
 # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+# coding=utf-8
 # Copyright 2025 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
 #
 #
@@ -22,7 +23,7 @@
 import math
 from collections.abc import Callable
 from dataclasses import dataclass
-from typing import Optional
+from typing import Optional, Union
 
 import numpy as np
 import torch
@@ -34,18 +35,13 @@ from ... import initialization as init
 from ...activations import ACT2FN
 from ...cache_utils import Cache, DynamicCache
 from ...generation import GenerationMixin
-from ...integrations import (
-    use_experts_implementation,
-    use_kernel_forward_from_hub,
-    use_kernel_func_from_hub,
-    use_kernelized_func,
-)
+from ...integrations import use_kernel_forward_from_hub, use_kernel_func_from_hub
 from ...masking_utils import create_causal_mask, create_sliding_window_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GradientCheckpointingLayer
 from ...modeling_outputs import (
+    BaseModelOutput,
     BaseModelOutputWithPast,
-    BaseModelOutputWithPooling,
     CausalLMOutputWithPast,
     MoeCausalLMOutputWithPast,
     MoeModelOutputWithPast,
@@ -53,14 +49,8 @@ from ...modeling_outputs import (
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
-from ...utils import auto_docstring, can_return_tuple, is_grouped_mm_available, torch_compilable_check
-from ...utils.generic import (
-    OutputRecorder,
-    TransformersKwargs,
-    check_model_inputs,
-    is_flash_attention_requested,
-    maybe_autocast,
-)
+from ...utils import auto_docstring, can_return_tuple
+from ...utils.generic import OutputRecorder, TransformersKwargs, check_model_inputs
 from .configuration_qwen3_omni_moe import (
     Qwen3OmniMoeAudioEncoderConfig,
     Qwen3OmniMoeCode2WavConfig,
@@ -74,38 +64,6 @@ from .configuration_qwen3_omni_moe import (
 )
 
 
-@dataclass
-@auto_docstring
-class BaseModelOutputWithDeepstackFeatures(BaseModelOutputWithPooling):
-    r"""
-    deepstack_features (`List[torch.FloatTensor]`, *optional*):
-        List of hidden-states (feature maps) from deepstack layers.
-    """
-
-    deepstack_features: list[torch.FloatTensor] | None = None
-
-
-class SinusoidsPositionEmbedding(nn.Module):
-    def __init__(self, length, channels, max_timescale=10000):
-        super().__init__()
-        self.length = length
-        self.channels = channels
-        self.max_timescale = max_timescale
-        if channels % 2 != 0:
-            raise ValueError("SinusoidsPositionEmbedding needs even channels input")
-        log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
-        inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2).float())
-        scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
-        self.register_buffer(
-            "positional_embedding",
-            torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1),
-            persistent=False,
-        )
-
-    def forward(self, seqlen: int):
-        return self.positional_embedding[:seqlen, :]
-
-
 @auto_docstring
 class Qwen3OmniMoePreTrainedModel(PreTrainedModel):
     config: Qwen3OmniMoeConfig
@@ -127,19 +85,6 @@ class Qwen3OmniMoePreTrainedModel(PreTrainedModel):
             init.normal_(module.experts.gate_up_proj, mean=0.0, std=std)
             init.normal_(module.experts.down_proj, mean=0.0, std=std)
             init.normal_(module.gate.weight, mean=0.0, std=std)
-        elif isinstance(module, Qwen3OmniMoeCode2Wav):
-            init.copy_(
-                module.code_offset,
-                torch.arange(module.config.num_quantizers).view(1, -1, 1) * module.config.codebook_size,
-            )
-        elif isinstance(module, SinusoidsPositionEmbedding):
-            log_timescale_increment = np.log(module.max_timescale) / (module.channels // 2 - 1)
-            inv_timescales = torch.exp(-log_timescale_increment * torch.arange(module.channels // 2).float())
-            scaled_time = torch.arange(module.length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
-            init.copy_(module.positional_embedding, torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1))
-        elif isinstance(module, Qwen3OmniMoeVisionRotaryEmbedding):
-            inv_freq = 1.0 / (module.theta ** (torch.arange(0, module.dim, 2, dtype=torch.float) / module.dim))
-            init.copy_(module.inv_freq, inv_freq)
 
 
 def _get_feat_extract_output_lengths(input_lengths):
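
Note on the dropped `Qwen3OmniMoeVisionRotaryEmbedding` branch above: it rebuilt the standard rotary-embedding inverse frequencies at init time. A minimal standalone sketch of that arithmetic, with illustrative `dim` and `theta` values (in the model these come from the vision config as `module.dim` and `module.theta`):

    import torch

    # Illustrative values only; the model reads these from its vision config.
    dim, theta = 128, 10000.0

    # inv_freq[i] = theta^(-2i/dim): one inverse frequency per pair of channels,
    # the same expression used in the removed _init_weights branch.
    inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
    assert inv_freq.shape == (dim // 2,)
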
@@ -270,13 +215,13 @@ class Qwen3OmniMoePreTrainedModelForConditionalGeneration(Qwen3OmniMoePreTrained
 
     def get_rope_index(
         self,
-        input_ids: torch.LongTensor | None = None,
-        image_grid_thw: torch.LongTensor | None = None,
-        video_grid_thw: torch.LongTensor | None = None,
-        attention_mask: torch.Tensor | None = None,
+        input_ids: Optional[torch.LongTensor] = None,
+        image_grid_thw: Optional[torch.LongTensor] = None,
+        video_grid_thw: Optional[torch.LongTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         use_audio_in_video: bool = False,
-        audio_seqlens: torch.LongTensor | None = None,
-        second_per_grids: torch.Tensor | None = None,
+        audio_seqlens: Optional[torch.LongTensor] = None,
+        second_per_grids: Optional[torch.Tensor] = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """
         Calculate the 3D rope index based on image and video's temporal, height and width in LLM.
@@ -528,7 +473,7 @@ def eager_attention_forward(
528
473
  query: torch.Tensor,
529
474
  key: torch.Tensor,
530
475
  value: torch.Tensor,
531
- attention_mask: torch.Tensor | None,
476
+ attention_mask: Optional[torch.Tensor],
532
477
  scaling: float,
533
478
  dropout: float = 0.0,
534
479
  **kwargs,
@@ -578,10 +523,10 @@ class Qwen3OmniMoeAudioAttention(nn.Module):
578
523
  def forward(
579
524
  self,
580
525
  hidden_states: torch.Tensor,
581
- cu_seqlens: torch.Tensor | None = None,
582
- attention_mask: torch.Tensor | None = None,
526
+ cu_seqlens: Optional[torch.Tensor] = None,
527
+ attention_mask: Optional[torch.Tensor] = None,
583
528
  **kwargs,
584
- ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
529
+ ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
585
530
  """Input shape: Batch x Time x Channel"""
586
531
 
587
532
  seq_length, _ = hidden_states.size()
@@ -638,7 +583,7 @@ class Qwen3OmniMoeAudioEncoderLayer(GradientCheckpointingLayer):
638
583
  self,
639
584
  hidden_states: torch.Tensor,
640
585
  cu_seqlens: torch.Tensor,
641
- attention_mask: torch.Tensor | None = None,
586
+ attention_mask: Optional[torch.Tensor] = None,
642
587
  **kwargs,
643
588
  ) -> torch.Tensor:
644
589
  """
@@ -675,6 +620,24 @@ class Qwen3OmniMoeAudioEncoderLayer(GradientCheckpointingLayer):
          return outputs


+ class SinusoidsPositionEmbedding(nn.Module):
+     def __init__(self, length, channels, max_timescale=10000):
+         super().__init__()
+         if channels % 2 != 0:
+             raise ValueError("SinusoidsPositionEmbedding needs even channels input")
+         log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
+         inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2).float())
+         scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
+         self.register_buffer(
+             "positional_embedding",
+             torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1),
+             persistent=False,
+         )
+
+     def forward(self, seqlen: int):
+         return self.positional_embedding[:seqlen, :]
+
+
  @auto_docstring(
      custom_intro="""
      Transformer encoder consisting of *config.encoder_layers* self attention layers. Each layer is a
@@ -687,10 +650,6 @@ class Qwen3OmniMoeAudioEncoder(Qwen3OmniMoePreTrainedModel):
      input_modalities = "audio"
      _no_split_modules = ["Qwen3OmniMoeAudioEncoderLayer"]
      _supports_sdpa = True
-     _can_record_outputs = {
-         "hidden_states": Qwen3OmniMoeAudioEncoderLayer,
-         "attentions": Qwen3OmniMoeAudioAttention,
-     }

      def __init__(self, config: Qwen3OmniMoeAudioEncoderConfig):
          super().__init__(config)
@@ -737,7 +696,7 @@ class Qwen3OmniMoeAudioEncoder(Qwen3OmniMoePreTrainedModel):
          # NOTE: the created attention mask only approximates the ragged FA2 attention by
          # allowing bidirectional attention within `cu_seqlens` blocks, and not attending between
          # blocks. Though it will not be a 100% match for FA2's `varlen` path
-         if is_flash_attention_requested(self.config):
+         if self.config._attn_implementation == "flash_attention_2":
              return None

          seq_length = inputs_tensor.shape[0]
@@ -751,14 +710,12 @@ class Qwen3OmniMoeAudioEncoder(Qwen3OmniMoePreTrainedModel):
              attention_mask[..., cu_seqlens[i - 1] : cu_seqlens[i], cu_seqlens[i - 1] : cu_seqlens[i]] = 0
          return attention_mask

-     @check_model_inputs(tie_last_hidden_states=False)
      @auto_docstring
      def forward(
          self,
          input_features,
          feature_lens=None,
          aftercnn_lens=None,
-         **kwargs,
      ):
          r"""
          feature_lens (`torch.LongTensor` of shape `(batch_size,)`):
@@ -769,7 +726,11 @@ class Qwen3OmniMoeAudioEncoder(Qwen3OmniMoePreTrainedModel):
          aftercnn_lens = _get_feat_extract_output_lengths(feature_lens)
          chunk_num = torch.ceil(feature_lens / (self.n_window * 2)).long()

-         chunk_lengths = torch.full((chunk_num.sum(),), self.n_window * 2, dtype=torch.long, device=feature_lens.device)
+         chunk_lengths = torch.tensor(
+             [self.n_window * 2] * chunk_num.sum(),
+             dtype=torch.long,
+             device=feature_lens.device,
+         )
          tail_chunk_index = F.pad(chunk_num, (1, 0), value=-1).cumsum(0)[1:]
          chunk_lengths[tail_chunk_index] = feature_lens % (self.n_window * 2)
          chunk_lengths[chunk_lengths == 0] = self.n_window * 2
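
The mask built above emulates FlashAttention's varlen path for the SDPA/eager backends: positions may attend bidirectionally inside their own `cu_seqlens` block and nowhere else. A minimal standalone sketch of that construction (helper name hypothetical):

```python
import torch


def block_diagonal_mask(cu_seqlens: torch.Tensor, dtype: torch.dtype = torch.float32) -> torch.Tensor:
    """Additive mask: 0 inside each [cu_seqlens[i-1], cu_seqlens[i]) block, -inf elsewhere."""
    seq_length = int(cu_seqlens[-1])
    mask = torch.full((seq_length, seq_length), torch.finfo(dtype).min, dtype=dtype)
    for i in range(1, len(cu_seqlens)):
        lo, hi = int(cu_seqlens[i - 1]), int(cu_seqlens[i])
        mask[lo:hi, lo:hi] = 0
    return mask


mask = block_diagonal_mask(torch.tensor([0, 3, 5]))  # two segments of lengths 3 and 2
assert (mask[:3, :3] == 0).all() and mask[0, 3] != 0
```
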
@@ -821,7 +782,7 @@ class Qwen3OmniMoeAudioEncoder(Qwen3OmniMoePreTrainedModel):
          hidden_states = self.proj1(hidden_states)
          hidden_states = self.act(hidden_states)
          hidden_states = self.proj2(hidden_states)
-         return BaseModelOutputWithPooling(last_hidden_state=hidden_states)
+         return BaseModelOutput(last_hidden_state=hidden_states)

      def padded_and_mask_function(self, tensor_list, tensor_len, padding_value=0, padding_side="right"):
          """
@@ -910,8 +871,8 @@ class Qwen3OmniMoeVisionAttention(nn.Module):
          self,
          hidden_states: torch.Tensor,
          cu_seqlens: torch.Tensor,
-         rotary_pos_emb: torch.Tensor | None = None,
-         position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
+         rotary_pos_emb: Optional[torch.Tensor] = None,
+         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
          **kwargs,
      ) -> torch.Tensor:
          seq_length = hidden_states.shape[0]
@@ -929,8 +890,8 @@ class Qwen3OmniMoeVisionAttention(nn.Module):
          if self.config._attn_implementation != "eager":
              attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]

-         if is_flash_attention_requested(self.config):
-             # Flash Attention: Use cu_seqlens for variable length attention
+         if self.config._attn_implementation == "flash_attention_2":
+             # Flash Attention 2: Use cu_seqlens for variable length attention
              max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
              attn_output, _ = attention_interface(
                  self,
@@ -998,13 +959,44 @@ class Qwen3OmniMoeVisionPatchMerger(nn.Module):
          return hidden


+ class Qwen3OmniMoeVisionMLP(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         self.intermediate_size = config.intermediate_size
+         self.linear_fc1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=True)
+         self.linear_fc2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=True)
+         self.act_fn = ACT2FN[config.hidden_act]
+
+     def forward(self, hidden_state):
+         return self.linear_fc2(self.act_fn(self.linear_fc1(hidden_state)))
+
+
+ class Qwen3OmniMoeVisionPatchEmbed(nn.Module):
+     def __init__(self, config) -> None:
+         super().__init__()
+         self.patch_size = config.patch_size
+         self.temporal_patch_size = config.temporal_patch_size
+         self.in_channels = config.in_channels
+         self.embed_dim = config.hidden_size
+
+         kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
+         self.proj = nn.Conv3d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True)
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         target_dtype = self.proj.weight.dtype
+         hidden_states = hidden_states.view(
+             -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
+         )
+         hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
+         return hidden_states
+
+
  class Qwen3OmniMoeVisionRotaryEmbedding(nn.Module):
      inv_freq: torch.Tensor  # fix linting for `register_buffer`

      def __init__(self, dim: int, theta: float = 10000.0) -> None:
          super().__init__()
-         self.dim = dim
-         self.theta = theta
          inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
          self.register_buffer("inv_freq", inv_freq, persistent=False)

@@ -1014,38 +1006,6 @@ class Qwen3OmniMoeVisionRotaryEmbedding(nn.Module):
          return freqs


- class Qwen3OmniMoeTextTopKRouter(nn.Module):
-     def __init__(self, config):
-         super().__init__()
-         self.top_k = config.num_experts_per_tok
-         self.num_experts = config.num_experts
-         self.hidden_dim = config.hidden_size
-         self.weight = nn.Parameter(torch.zeros(self.num_experts, self.hidden_dim))
-
-     def forward(self, hidden_states):
-         hidden_states = hidden_states.reshape(-1, self.hidden_dim)
-         router_logits = F.linear(hidden_states, self.weight)  # (seq_len, num_experts)
-         router_logits = torch.nn.functional.softmax(router_logits, dtype=torch.float, dim=-1)
-         router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)  # (seq_len, top_k)
-         router_top_value /= router_top_value.sum(dim=-1, keepdim=True)
-         router_top_value = router_top_value.to(router_logits.dtype)
-         router_scores = router_top_value
-         return router_logits, router_scores, router_indices
-
-
- class Qwen3OmniMoeVisionMLP(nn.Module):
-     def __init__(self, config):
-         super().__init__()
-         self.hidden_size = config.hidden_size
-         self.intermediate_size = config.intermediate_size
-         self.linear_fc1 = nn.Linear(self.hidden_size, self.intermediate_size, bias=True)
-         self.linear_fc2 = nn.Linear(self.intermediate_size, self.hidden_size, bias=True)
-         self.act_fn = ACT2FN[config.hidden_act]
-
-     def forward(self, hidden_state):
-         return self.linear_fc2(self.act_fn(self.linear_fc1(hidden_state)))
-
-
  class Qwen3OmniMoeVisionBlock(GradientCheckpointingLayer):
      def __init__(self, config, attn_implementation: str = "sdpa") -> None:
          super().__init__()
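
The deleted `Qwen3OmniMoeTextTopKRouter` is the usual softmax-then-top-k MoE gate: probabilities over all experts, keep the `top_k` largest per token, then renormalize the kept weights so they sum to 1. A self-contained sketch of that routing math (not the removed class itself):

```python
import torch
import torch.nn.functional as F


def topk_route(hidden_states: torch.Tensor, gate_weight: torch.Tensor, top_k: int):
    """hidden_states: [tokens, hidden]; gate_weight: [num_experts, hidden]."""
    probs = F.softmax(hidden_states @ gate_weight.T, dtype=torch.float, dim=-1)  # [tokens, experts]
    top_value, top_index = torch.topk(probs, top_k, dim=-1)                      # [tokens, top_k]
    top_value = top_value / top_value.sum(dim=-1, keepdim=True)                  # renormalize kept weights
    return probs, top_value, top_index


probs, scores, idx = topk_route(torch.randn(4, 16), torch.randn(8, 16), top_k=2)
assert torch.allclose(scores.sum(-1), torch.ones(4))
```
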
@@ -1058,8 +1018,8 @@ class Qwen3OmniMoeVisionBlock(GradientCheckpointingLayer):
          self,
          hidden_states: torch.Tensor,
          cu_seqlens: torch.Tensor,
-         rotary_pos_emb: torch.Tensor | None = None,
-         position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
+         rotary_pos_emb: Optional[torch.Tensor] = None,
+         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
          **kwargs,
      ) -> torch.Tensor:
          hidden_states = hidden_states + self.attn(
@@ -1073,34 +1033,9 @@ class Qwen3OmniMoeVisionBlock(GradientCheckpointingLayer):
          return hidden_states


- class Qwen3OmniMoeVisionPatchEmbed(nn.Module):
-     def __init__(self, config) -> None:
-         super().__init__()
-         self.patch_size = config.patch_size
-         self.temporal_patch_size = config.temporal_patch_size
-         self.in_channels = config.in_channels
-         self.embed_dim = config.hidden_size
-
-         kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
-         self.proj = nn.Conv3d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size, bias=True)
-
-     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-         target_dtype = self.proj.weight.dtype
-         hidden_states = hidden_states.view(
-             -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
-         )
-         hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
-         return hidden_states
-
-
  class Qwen3OmniMoeVisionEncoder(Qwen3OmniMoePreTrainedModel):
      config: Qwen3OmniMoeVisionEncoderConfig
      _no_split_modules = ["Qwen3OmniMoeVisionBlock"]
-     _can_record_outputs = {
-         "router_logits": OutputRecorder(Qwen3OmniMoeTextTopKRouter, layer_name="mlp.gate", index=0),
-         "hidden_states": Qwen3OmniMoeVisionBlock,
-         "attentions": Qwen3OmniMoeVisionAttention,
-     }

      def __init__(self, config, *inputs, **kwargs) -> None:
          super().__init__(config, *inputs, **kwargs)
@@ -1137,8 +1072,6 @@ class Qwen3OmniMoeVisionEncoder(Qwen3OmniMoePreTrainedModel):

          self.gradient_checkpointing = False

-         self.post_init()
-
      def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
          merge_size = self.spatial_merge_size

@@ -1238,10 +1171,7 @@ class Qwen3OmniMoeVisionEncoder(Qwen3OmniMoePreTrainedModel):
          patch_pos_embeds = torch.cat(patch_pos_embeds_permute)
          return patch_pos_embeds

-     @check_model_inputs
-     def forward(
-         self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs: Unpack[TransformersKwargs]
-     ) -> tuple | BaseModelOutputWithDeepstackFeatures:
+     def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
          """
          Args:
              hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
@@ -1289,13 +1219,9 @@ class Qwen3OmniMoeVisionEncoder(Qwen3OmniMoePreTrainedModel):
              )
              deepstack_feature_lists.append(deepstack_feature)

-         merged_hidden_states = self.merger(hidden_states)
+         hidden_states = self.merger(hidden_states)

-         return BaseModelOutputWithDeepstackFeatures(
-             last_hidden_state=hidden_states,
-             pooler_output=merged_hidden_states,
-             deepstack_features=deepstack_feature_lists,
-         )
+         return hidden_states, deepstack_feature_lists

      @property
      def deepstack_merger_list(self):
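
Note the return-type change above: 5.0.0's vision encoder returned a structured `BaseModelOutputWithDeepstackFeatures`, while the rc returns a plain `(merged_embeddings, deepstack_features)` tuple. A hedged sketch of code that tolerates both conventions (helper name hypothetical):

```python
def read_vision_output(out):
    """Accept either the rc0 plain tuple or the 5.0.0 structured ModelOutput."""
    if isinstance(out, tuple):  # 5.0.0rc0: (merged_embeddings, deepstack_features)
        return out
    # 5.0.0: structured output with pooled embeddings and deepstack feature list
    return out.pooler_output, out.deepstack_features
```
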
@@ -1319,15 +1245,15 @@ class Qwen3OmniMoeThinkerTextRotaryEmbedding(nn.Module):
          inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

          self.register_buffer("inv_freq", inv_freq, persistent=False)
-         self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
+         self.original_inv_freq = inv_freq

          self.mrope_section = config.rope_parameters.get("mrope_section", [24, 20, 20])

      @staticmethod
      def compute_default_rope_parameters(
-         config: Qwen3OmniMoeTextConfig | None = None,
+         config: Optional[Qwen3OmniMoeTextConfig] = None,
          device: Optional["torch.device"] = None,
-         seq_len: int | None = None,
+         seq_len: Optional[int] = None,
      ) -> tuple["torch.Tensor", float]:
          """
          Computes the inverse frequencies according to the original RoPE implementation
@@ -1364,7 +1290,7 @@ class Qwen3OmniMoeThinkerTextRotaryEmbedding(nn.Module):
          position_ids_expanded = position_ids[:, :, None, :].float()  # shape (3, bs, 1, positions)

          device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-         with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
+         with torch.autocast(device_type=device_type, enabled=False):  # Force float32
              freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
              freqs = self.apply_interleaved_mrope(freqs, self.mrope_section)
              emb = torch.cat((freqs, freqs), dim=-1)
@@ -1391,7 +1317,6 @@ class Qwen3OmniMoeThinkerTextRotaryEmbedding(nn.Module):
          return freqs_t


- @use_experts_implementation
  class Qwen3OmniMoeThinkerTextExperts(nn.Module):
      """
      ModuleList of experts.
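
`compute_default_rope_parameters` (unchanged here apart from its annotations) implements the standard RoPE frequency schedule, inv_freq[i] = theta^(-2i/dim), one frequency per pair of channels. A minimal reproduction of that formula:

```python
import torch


def default_rope_inv_freq(dim: int, theta: float = 10000.0) -> torch.Tensor:
    """inv_freq[i] = 1 / theta**(2i / dim); lower-indexed channels rotate fastest."""
    return 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))


inv_freq = default_rope_inv_freq(dim=128)
assert inv_freq.shape == (64,) and inv_freq[0] == 1.0
```
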
@@ -1490,7 +1415,7 @@ class Qwen3OmniMoeThinkerTextRMSNorm(nn.Module):


  @use_kernel_func_from_hub("rotary_pos_emb")
- def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
+ def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
      """Applies Rotary Position Embedding to the query and key tensors.

      Args:
@@ -1498,6 +1423,8 @@ def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
          k (`torch.Tensor`): The key tensor.
          cos (`torch.Tensor`): The cosine part of the rotary embedding.
          sin (`torch.Tensor`): The sine part of the rotary embedding.
+         position_ids (`torch.Tensor`, *optional*):
+             Deprecated and unused.
          unsqueeze_dim (`int`, *optional*, defaults to 1):
              The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
              sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
@@ -1515,7 +1442,6 @@ def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
      return q_embed, k_embed


- @use_kernelized_func(apply_rotary_pos_emb)
  class Qwen3OmniMoeThinkerTextAttention(nn.Module):
      """Multi-headed attention from 'Attention Is All You Need' paper"""

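The hunk above only touches the signature and docstring of `apply_rotary_pos_emb`; its body (visible via the trailing `return q_embed, k_embed`) is the standard rotation. A hedged sketch of that standard formula, under the usual `rotate_half` convention (not copied from this file):

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotates half the hidden dims of the input."""
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary(q, k, cos, sin, unsqueeze_dim=1):
    # Broadcast cos/sin over the head dimension, then rotate q and k in place of
    # adding absolute position embeddings.
    cos, sin = cos.unsqueeze(unsqueeze_dim), sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```
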
@@ -1541,6 +1467,7 @@ class Qwen3OmniMoeThinkerTextAttention(nn.Module):
          self.o_proj = nn.Linear(
              config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
          )
+         self.rotary_fn = apply_rotary_pos_emb
          self.q_norm = Qwen3OmniMoeThinkerTextRMSNorm(
              self.head_dim, eps=config.rms_norm_eps
          )  # unlike olmo, only on the head dim!
@@ -1553,11 +1480,11 @@ class Qwen3OmniMoeThinkerTextAttention(nn.Module):
          self,
          hidden_states: torch.Tensor,
          position_embeddings: tuple[torch.Tensor, torch.Tensor],
-         attention_mask: torch.Tensor | None,
-         past_key_values: Cache | None = None,
-         cache_position: torch.LongTensor | None = None,
+         attention_mask: Optional[torch.Tensor],
+         past_key_values: Optional[Cache] = None,
+         cache_position: Optional[torch.LongTensor] = None,
          **kwargs: Unpack[FlashAttentionKwargs],
-     ) -> tuple[torch.Tensor, torch.Tensor | None]:
+     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
          input_shape = hidden_states.shape[:-1]
          hidden_shape = (*input_shape, -1, self.head_dim)

@@ -1627,12 +1554,12 @@ class Qwen3OmniMoeThinkerTextDecoderLayer(GradientCheckpointingLayer):
      def forward(
          self,
          hidden_states: torch.Tensor,
-         attention_mask: torch.Tensor | None = None,
-         position_ids: torch.LongTensor | None = None,
-         past_key_values: Cache | None = None,
-         use_cache: bool | None = False,
-         cache_position: torch.LongTensor | None = None,
-         position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[Cache] = None,
+         use_cache: Optional[bool] = False,
+         cache_position: Optional[torch.LongTensor] = None,
+         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> torch.Tensor:
          residual = hidden_states
@@ -1668,9 +1595,7 @@ class Qwen3OmniMoeThinkerTextPreTrainedModel(PreTrainedModel):
      _supports_flash_attn = True
      _supports_sdpa = True
      _supports_flex_attn = True
-     _can_compile_fullgraph = (
-         is_grouped_mm_available()
-     )  # https://huggingface.co/docs/transformers/experts_interface#torchcompile
+     _can_compile_fullgraph = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
      _supports_attention_backend = True
      _can_record_outputs = {
          "router_logits": OutputRecorder(Qwen3OmniMoeThinkerTextTopKRouter, layer_name="mlp.gate", index=0),
@@ -1747,18 +1672,18 @@ class Qwen3OmniMoeThinkerTextModel(Qwen3OmniMoePreTrainedModel):
      @auto_docstring
      def forward(
          self,
-         input_ids: torch.LongTensor | None = None,
-         attention_mask: torch.Tensor | None = None,
-         position_ids: torch.LongTensor | None = None,
-         past_key_values: Cache | None = None,
-         inputs_embeds: torch.FloatTensor | None = None,
-         use_cache: bool | None = None,
-         cache_position: torch.LongTensor | None = None,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[Cache] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = None,
+         cache_position: Optional[torch.LongTensor] = None,
          # args for deepstack
-         visual_pos_masks: torch.Tensor | None = None,
-         deepstack_visual_embeds: list[torch.Tensor] | None = None,
+         visual_pos_masks: Optional[torch.Tensor] = None,
+         deepstack_visual_embeds: Optional[list[torch.Tensor]] = None,
          **kwargs: Unpack[FlashAttentionKwargs],
-     ) -> tuple | BaseModelOutputWithPast:
+     ) -> Union[tuple, BaseModelOutputWithPast]:
          r"""
          visual_pos_masks (`torch.Tensor` of shape `(batch_size, seqlen)`, *optional*):
              The mask of the visual positions.
@@ -1856,15 +1781,15 @@ class Qwen3OmniMoeThinkerCausalLMOutputWithPast(MoeCausalLMOutputWithPast):
          The rope index difference between sequence length and multimodal rope.
      """

-     rope_deltas: torch.LongTensor | None = None
+     rope_deltas: Optional[torch.LongTensor] = None


  def load_balancing_loss_func(
-     gate_logits: torch.Tensor | tuple[torch.Tensor] | None,
-     num_experts: int | None = None,
+     gate_logits: Union[torch.Tensor, tuple[torch.Tensor], None],
+     num_experts: Optional[int] = None,
      top_k=2,
-     attention_mask: torch.Tensor | None = None,
- ) -> torch.Tensor | int:
+     attention_mask: Optional[torch.Tensor] = None,
+ ) -> Union[torch.Tensor, int]:
      r"""
      Computes auxiliary load balancing loss as in Switch Transformer - implemented in Pytorch.

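The Switch-Transformer auxiliary loss referenced in this docstring multiplies, per expert, the fraction of tokens routed to it by its mean router probability, scaled by the expert count. A hedged sketch of that computation for a single layer, ignoring the optional attention mask handled by the real function:

```python
import torch
import torch.nn.functional as F


def switch_aux_loss(router_probs: torch.Tensor, top_k: int, num_experts: int) -> torch.Tensor:
    """router_probs: [tokens, num_experts] softmax outputs for one MoE layer."""
    _, selected = torch.topk(router_probs, top_k, dim=-1)
    expert_mask = F.one_hot(selected, num_experts)       # [tokens, top_k, experts]
    tokens_per_expert = expert_mask.float().mean(dim=0)  # routing frequency, [top_k, experts]
    prob_per_expert = router_probs.mean(dim=0)           # mean router probability, [experts]
    return num_experts * torch.sum(tokens_per_expert.mean(dim=0) * prob_per_expert)


loss = switch_aux_loss(torch.softmax(torch.randn(32, 8), dim=-1), top_k=2, num_experts=8)
```
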
@@ -1969,6 +1894,7 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
          self.vocab_size = config.text_config.vocab_size
          self.model = Qwen3OmniMoeThinkerTextModel._from_config(config.text_config)
          self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
+         self.pad_token_id = self.config.pad_token_id if self.config.pad_token_id is not None else -1
          self.spatial_merge_size = config.vision_config.spatial_merge_size
          self.rope_deltas = None
          self.num_experts = config.text_config.num_experts
@@ -1982,56 +1908,52 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
      def set_input_embeddings(self, value):
          self.model.set_input_embeddings(value)

-     @can_return_tuple
-     @auto_docstring
      def get_video_features(
-         self,
-         pixel_values_videos: torch.FloatTensor,
-         video_grid_thw: torch.LongTensor | None = None,
-         **kwargs: Unpack[TransformersKwargs],
-     ) -> tuple | BaseModelOutputWithPooling:
-         r"""
-         pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
-             The tensors corresponding to the input videos.
-         video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
-             The temporal, height and width of feature shape of each video in LLM.
+         self, pixel_values_videos: torch.FloatTensor, video_grid_thw: Optional[torch.LongTensor] = None
+     ):
+         """
+         Encodes videos into continuous embeddings that can be forwarded to the language model.
+
+         Args:
+             pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+                 The tensors corresponding to the input videos.
+             video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
+                 The temporal, height and width of feature shape of each video in LLM.
          """
          pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
-         return self.visual(pixel_values_videos, grid_thw=video_grid_thw, **kwargs)
+         video_embeds = self.visual(pixel_values_videos, grid_thw=video_grid_thw)
+         return video_embeds

-     @can_return_tuple
-     @auto_docstring
-     def get_image_features(
-         self,
-         pixel_values: torch.FloatTensor,
-         image_grid_thw: torch.LongTensor | None = None,
-         **kwargs: Unpack[TransformersKwargs],
-     ) -> tuple | BaseModelOutputWithPooling:
-         r"""
-         pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
-             The tensors corresponding to the input images.
-         image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
-             The temporal, height and width of feature shape of each image in LLM.
+     def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: Optional[torch.LongTensor] = None):
+         """
+         Encodes images into continuous embeddings that can be forwarded to the language model.
+
+         Args:
+             pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+                 The tensors corresponding to the input images.
+             image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+                 The temporal, height and width of feature shape of each image in LLM.
          """
          pixel_values = pixel_values.type(self.visual.dtype)
-         return self.visual(pixel_values, grid_thw=image_grid_thw, **kwargs)
+         image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
+         return image_embeds

-     @can_return_tuple
-     @auto_docstring
      def get_audio_features(
          self,
          input_features: torch.FloatTensor,
-         feature_attention_mask: torch.LongTensor | None = None,
-         audio_feature_lengths: torch.LongTensor | None = None,
-         **kwargs: Unpack[TransformersKwargs],
-     ) -> tuple | BaseModelOutputWithPooling:
-         r"""
-         input_features (`torch.FloatTensor`):
-             The tensors corresponding to the input audios.
-         feature_attention_mask (`torch.LongTensor`, *optional*):
-             Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`:
-         audio_feature_lengths (`torch.LongTensor` of shape `(num_audios)`, *optional*):
-             The length of feature shape of each audio in LLM.
+         feature_attention_mask: Optional[torch.LongTensor] = None,
+         audio_feature_lengths: Optional[torch.LongTensor] = None,
+     ):
+         """
+         Encodes audios into continuous embeddings that can be forwarded to the language model.
+
+         Args:
+             input_features (`torch.FloatTensor`):
+                 The tensors corresponding to the input audios.
+             feature_attention_mask (`torch.LongTensor`, *optional*):
+                 Mask to avoid performing attention on padding feature indices. Mask values selected in `[0, 1]`:
+             audio_feature_lengths (`torch.LongTensor` of shape `(num_audios)`, *optional*):
+                 The length of feature shape of each audio in LLM.
          """
          if feature_attention_mask is not None:
              audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
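
The line above is the whole trick behind `feature_attention_mask`: per-sample valid feature lengths are just the row sums of a 0/1 padding mask. A tiny runnable illustration:

```python
import torch

# Two audio clips padded to 5 feature frames; the first has 3 valid frames.
feature_attention_mask = torch.tensor([[1, 1, 1, 0, 0], [1, 1, 1, 1, 1]])
audio_feature_lengths = torch.sum(feature_attention_mask, dim=1)
assert audio_feature_lengths.tolist() == [3, 5]
```
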
@@ -2043,18 +1965,17 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
          audio_outputs = self.audio_tower(
              input_features,
              feature_lens=feature_lens,
-             return_dict=True,
-             **kwargs,
          )
+         audio_features = audio_outputs.last_hidden_state

-         return audio_outputs
+         return audio_features

      def get_placeholder_mask(
          self,
          input_ids: torch.LongTensor,
          inputs_embeds: torch.FloatTensor,
-         image_features: torch.FloatTensor | None = None,
-         video_features: torch.FloatTensor | None = None,
+         image_features: Optional[torch.FloatTensor] = None,
+         video_features: Optional[torch.FloatTensor] = None,
      ):
          """
          Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
@@ -2082,18 +2003,16 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(

          n_image_tokens = special_image_mask.sum()
          special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
-         if image_features is not None:
-             torch_compilable_check(
-                 inputs_embeds[special_image_mask].numel() == image_features.numel(),
-                 f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {image_features.shape[0]}",
+         if image_features is not None and inputs_embeds[special_image_mask].numel() != image_features.numel():
+             raise ValueError(
+                 f"Image features and image tokens do not match: tokens: {n_image_tokens}, features {image_features.shape[0]}"
              )

          n_video_tokens = special_video_mask.sum()
          special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
-         if video_features is not None:
-             torch_compilable_check(
-                 inputs_embeds[special_video_mask].numel() == video_features.numel(),
-                 f"Video features and video tokens do not match, tokens: {n_video_tokens}, features: {video_features.shape[0]}",
+         if video_features is not None and inputs_embeds[special_video_mask].numel() != video_features.numel():
+             raise ValueError(
+                 f"Video features and video tokens do not match: tokens: {n_video_tokens}, features {video_features.shape[0]}"
              )

          special_audio_mask = special_audio_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
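
Once validated, these masks drive a `masked_scatter` that swaps modality features into the placeholder token slots. A hedged, self-contained sketch of that fill pattern:

```python
import torch

# Expand a [batch, seq] boolean placeholder mask over the hidden dimension,
# then scatter one feature row into each masked token position.
inputs_embeds = torch.zeros(1, 6, 4)
placeholder_mask = torch.tensor([[False, True, True, False, False, False]])
features = torch.ones(2, 4)  # one row per placeholder token

mask = placeholder_mask.unsqueeze(-1).expand_as(inputs_embeds)
inputs_embeds = inputs_embeds.masked_scatter(mask, features)
assert inputs_embeds[0, 1].eq(1).all() and inputs_embeds[0, 0].eq(0).all()
```
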
@@ -2118,12 +2037,12 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
          rope_deltas=None,
          labels=None,
          use_cache=None,
-         output_router_logits: bool | None = None,
+         output_router_logits: Optional[bool] = None,
          use_audio_in_video=None,
          cache_position=None,
          video_second_per_grid=None,
          **kwargs,
-     ) -> tuple | Qwen3OmniMoeThinkerCausalLMOutputWithPast:
+     ) -> Union[tuple, Qwen3OmniMoeThinkerCausalLMOutputWithPast]:
          r"""
          image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
              The temporal, height and width of feature shape of each image in LLM.
@@ -2196,18 +2115,13 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
                  input_features,
                  feature_attention_mask=feature_attention_mask,
                  audio_feature_lengths=audio_feature_lengths,
-                 return_dict=True,
-             ).last_hidden_state
+             )
              audio_features = audio_features.to(inputs_embeds.device, inputs_embeds.dtype)
              _, _, audio_mask = self.get_placeholder_mask(input_ids, inputs_embeds=inputs_embeds)
              inputs_embeds = inputs_embeds.masked_scatter(audio_mask, audio_features)

          if pixel_values is not None:
-             image_outputs: BaseModelOutputWithDeepstackFeatures = self.get_image_features(
-                 pixel_values, image_grid_thw, return_dict=True
-             )
-             image_embeds = image_outputs.pooler_output
-             image_embeds_multiscale = image_outputs.deepstack_features
+             image_embeds, image_embeds_multiscale = self.get_image_features(pixel_values, image_grid_thw)
              image_embeds = image_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
              image_mask, _, _ = self.get_placeholder_mask(
                  input_ids, inputs_embeds=inputs_embeds, image_features=image_embeds
@@ -2215,9 +2129,7 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
              inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)

          if pixel_values_videos is not None:
-             video_embeds, video_embeds_multiscale = self.get_video_features(
-                 pixel_values_videos, video_grid_thw, return_dict=True
-             ).pooler_output
+             video_embeds, video_embeds_multiscale = self.get_video_features(pixel_values_videos, video_grid_thw)

              video_embeds = video_embeds.to(inputs_embeds.device, inputs_embeds.dtype)
              _, video_mask, _ = self.get_placeholder_mask(
@@ -2253,8 +2165,11 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
              audio_feature_lengths = None

          if attention_mask is not None and position_ids is None:
-             past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()
-             if past_key_values_length == 0 or self.rope_deltas is None:
+             if (
+                 cache_position is None
+                 or (cache_position is not None and cache_position[0] == 0)
+                 or self.rope_deltas is None
+             ):
                  delta0 = (1 - attention_mask).sum(dim=-1).unsqueeze(1)
                  position_ids, rope_deltas = self.get_rope_index(
                      input_ids,
@@ -2269,7 +2184,7 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
                  self.rope_deltas = rope_deltas
              else:
                  batch_size, seq_length = input_ids.shape
-                 delta = (past_key_values_length + self.rope_deltas).to(input_ids.device)
+                 delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
                  position_ids = torch.arange(seq_length, device=input_ids.device)
                  position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                  position_ids = position_ids.add(delta)
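
On decode steps the multimodal rope index is not recomputed; positions continue from the cache offset shifted by the stored `rope_deltas`. A hedged numeric sketch of that update:

```python
import torch

# With a cached rope delta of -2 (multimodal tokens compressed the rope index)
# and 10 tokens already in the KV cache, the single new token gets rope
# position 10 + (-2) = 8.
cache_position = torch.tensor([10])
rope_deltas = torch.tensor([[-2]])
seq_length = 1

delta = cache_position[0] + rope_deltas
position_ids = torch.arange(seq_length).view(1, -1).add(delta)
assert position_ids.item() == 8
```
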
@@ -2335,7 +2250,6 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
          feature_attention_mask=None,
          use_audio_in_video=False,
          video_second_per_grid=None,
-         is_first_iteration=False,
          **kwargs,
      ):
          model_inputs = super().prepare_inputs_for_generation(
@@ -2354,13 +2268,12 @@ class Qwen3OmniMoeThinkerForConditionalGeneration(
              feature_attention_mask=feature_attention_mask,
              use_audio_in_video=use_audio_in_video,
              video_second_per_grid=video_second_per_grid,
-             is_first_iteration=is_first_iteration,
              **kwargs,
          )

          model_inputs["position_ids"] = None

-         if not is_first_iteration and use_cache:
+         if cache_position[0] != 0:
              model_inputs["pixel_values"] = None
              model_inputs["pixel_values_videos"] = None
              model_inputs["input_features"] = None
@@ -2386,7 +2299,7 @@ class Qwen3OmniMoeTalkerCodePredictorOutputWithPast(CausalLMOutputWithPast):
          Current generation step of code predictor model.
      """

-     generation_steps: int | None = None
+     generation_steps: Optional[int] = None


  @use_kernel_forward_from_hub("RMSNorm")
@@ -2410,7 +2323,6 @@ class Qwen3OmniMoeRMSNorm(nn.Module):
          return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"


- @use_kernelized_func(apply_rotary_pos_emb)
  class Qwen3OmniMoeTalkerCodePredictorAttention(nn.Module):
      """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -2437,6 +2349,7 @@ class Qwen3OmniMoeTalkerCodePredictorAttention(nn.Module):
          self.o_proj = nn.Linear(
              config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
          )
+         self.rotary_fn = apply_rotary_pos_emb
          self.q_norm = Qwen3OmniMoeRMSNorm(self.head_dim, eps=config.rms_norm_eps)  # unlike olmo, only on the head dim!
          self.k_norm = Qwen3OmniMoeRMSNorm(
              self.head_dim, eps=config.rms_norm_eps
@@ -2447,11 +2360,11 @@ class Qwen3OmniMoeTalkerCodePredictorAttention(nn.Module):
          self,
          hidden_states: torch.Tensor,
          position_embeddings: tuple[torch.Tensor, torch.Tensor],
-         attention_mask: torch.Tensor | None,
-         past_key_values: Cache | None = None,
-         cache_position: torch.LongTensor | None = None,
+         attention_mask: Optional[torch.Tensor],
+         past_key_values: Optional[Cache] = None,
+         cache_position: Optional[torch.LongTensor] = None,
          **kwargs: Unpack[FlashAttentionKwargs],
-     ) -> tuple[torch.Tensor, torch.Tensor | None]:
+     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
          input_shape = hidden_states.shape[:-1]
          hidden_shape = (*input_shape, -1, self.head_dim)

@@ -2518,12 +2431,12 @@ class Qwen3OmniMoeTalkerCodePredictorDecoderLayer(GradientCheckpointingLayer):
      def forward(
          self,
          hidden_states: torch.Tensor,
-         attention_mask: torch.Tensor | None = None,
-         position_ids: torch.LongTensor | None = None,
-         past_key_values: Cache | None = None,
-         use_cache: bool | None = False,
-         cache_position: torch.LongTensor | None = None,
-         position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[Cache] = None,
+         use_cache: Optional[bool] = False,
+         cache_position: Optional[torch.LongTensor] = None,
+         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> torch.Tensor:
          residual = hidden_states
@@ -2566,13 +2479,13 @@ class Qwen3OmniMoeRotaryEmbedding(nn.Module):
          inv_freq, self.attention_scaling = rope_init_fn(self.config, device)

          self.register_buffer("inv_freq", inv_freq, persistent=False)
-         self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
+         self.original_inv_freq = inv_freq

      @staticmethod
      def compute_default_rope_parameters(
-         config: Qwen3OmniMoeConfig | None = None,
+         config: Optional[Qwen3OmniMoeConfig] = None,
          device: Optional["torch.device"] = None,
-         seq_len: int | None = None,
+         seq_len: Optional[int] = None,
      ) -> tuple["torch.Tensor", float]:
          """
          Computes the inverse frequencies according to the original RoPE implementation
@@ -2605,7 +2518,7 @@ class Qwen3OmniMoeRotaryEmbedding(nn.Module):
          position_ids_expanded = position_ids[:, None, :].float()

          device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
-         with maybe_autocast(device_type=device_type, enabled=False):  # Force float32
+         with torch.autocast(device_type=device_type, enabled=False):  # Force float32
              freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
              emb = torch.cat((freqs, freqs), dim=-1)
              cos = emb.cos() * self.attention_scaling
@@ -2648,13 +2561,13 @@ class Qwen3OmniMoeTalkerCodePredictorModel(Qwen3OmniMoePreTrainedModel):
      @auto_docstring
      def forward(
          self,
-         input_ids: torch.LongTensor | None = None,
-         attention_mask: torch.Tensor | None = None,
-         position_ids: torch.LongTensor | None = None,
-         past_key_values: Cache | None = None,
-         inputs_embeds: torch.FloatTensor | None = None,
-         use_cache: bool | None = None,
-         cache_position: torch.LongTensor | None = None,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[Cache] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = None,
+         cache_position: Optional[torch.LongTensor] = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> BaseModelOutputWithPast:
          if input_ids is not None:
@@ -2811,7 +2724,7 @@ class Qwen3OmniMoeTalkerOutputWithPast(MoeCausalLMOutputWithPast):
          Current generation step, used to track which `trailing_text_hidden` should be used.
      """

-     generation_step: int | None = None
+     generation_step: Optional[int] = None


  class Qwen3OmniMoeTalkerRotaryEmbedding(Qwen3OmniMoeThinkerTextRotaryEmbedding):
@@ -2834,7 +2747,6 @@ class Qwen3OmniMoeTalkerTextMLP(nn.Module):
          return down_proj


- @use_experts_implementation
  class Qwen3OmniMoeTalkerTextExperts(nn.Module):
      """Collection of expert weights stored as 3D tensors."""

@@ -2937,12 +2849,12 @@ class Qwen3OmniMoeTalkerDecoderLayer(GradientCheckpointingLayer):
      def forward(
          self,
          hidden_states: torch.Tensor,
-         attention_mask: torch.Tensor | None = None,
-         position_ids: torch.LongTensor | None = None,
-         past_key_values: Cache | None = None,
-         use_cache: bool | None = False,
-         cache_position: torch.LongTensor | None = None,
-         position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[Cache] = None,
+         use_cache: Optional[bool] = False,
+         cache_position: Optional[torch.LongTensor] = None,
+         position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
          **kwargs: Unpack[TransformersKwargs],
      ) -> torch.Tensor:
          residual = hidden_states
@@ -3004,18 +2916,18 @@ class Qwen3OmniMoeTalkerModel(Qwen3OmniMoePreTrainedModel):
      @auto_docstring
      def forward(
          self,
-         input_ids: torch.LongTensor | None = None,
-         attention_mask: torch.Tensor | None = None,
-         position_ids: torch.LongTensor | None = None,
-         past_key_values: Cache | None = None,
-         inputs_embeds: torch.FloatTensor | None = None,
-         use_cache: bool | None = None,
-         cache_position: torch.LongTensor | None = None,
+         input_ids: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[Cache] = None,
+         inputs_embeds: Optional[torch.FloatTensor] = None,
+         use_cache: Optional[bool] = None,
+         cache_position: Optional[torch.LongTensor] = None,
          # args for deepstack
-         visual_pos_masks: torch.Tensor | None = None,
-         deepstack_visual_embeds: list[torch.Tensor] | None = None,
+         visual_pos_masks: Optional[torch.Tensor] = None,
+         deepstack_visual_embeds: Optional[list[torch.Tensor]] = None,
          **kwargs: Unpack[FlashAttentionKwargs],
-     ) -> tuple | BaseModelOutputWithPast:
+     ) -> Union[tuple, BaseModelOutputWithPast]:
          r"""
          visual_pos_masks (`torch.Tensor` of shape `(batch_size, seqlen)`, *optional*):
              The mask of the visual positions.
@@ -3110,9 +3022,9 @@ class Qwen3OmniMoeTalkerModel(Qwen3OmniMoePreTrainedModel):

  @auto_docstring
  class Qwen3OmniMoeTalkerForConditionalGeneration(Qwen3OmniMoeThinkerTextPreTrainedModel, GenerationMixin):
-     _tied_weights_keys = {"codec_head": "model.codec_embedding.weight"}
-     _tp_plan = {"codec_head": "colwise_rep"}
-     _pp_plan = {"codec_head": (["hidden_states"], ["logits"])}
+     _tied_weights_keys = {"lm_head.weight": "model.embed_tokens.weight"}
+     _tp_plan = {"lm_head": "colwise_rep"}
+     _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}
      config_class = Qwen3OmniMoeTalkerConfig
      base_model_prefix = "talker"
      _no_split_modules = ["Qwen3OmniMoeTalkerCodePredictorModelForConditionalGeneration"]
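
`_tied_weights_keys` maps an output projection onto an embedding matrix so the two share storage; the hunk above retargets the talker's tying from the codec head/embedding pair to the generic `lm_head`/`embed_tokens` pair. A minimal illustration of what tying means:

```python
import torch.nn as nn

# The output head reuses the embedding matrix, so a single parameter tensor
# sits behind both modules (and is saved once in the checkpoint).
embed_tokens = nn.Embedding(1000, 64)
lm_head = nn.Linear(64, 1000, bias=False)
lm_head.weight = embed_tokens.weight  # tie

assert lm_head.weight.data_ptr() == embed_tokens.weight.data_ptr()
```
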
@@ -3191,9 +3103,12 @@ class Qwen3OmniMoeTalkerForConditionalGeneration(Qwen3OmniMoeThinkerTextPreTrain
          if inputs_embeds is not None and inputs_embeds.shape[1] > 1:
              generation_step = -1
              residual_codes = None
-         if position_ids is None:
-             past_key_values_length = 0 if past_key_values is None else past_key_values.get_seq_length()
-             if past_key_values_length == 0 or self.rope_deltas is None:
+         if attention_mask is not None:
+             if (
+                 cache_position is None
+                 or (cache_position is not None and cache_position[0] == 0)
+                 or self.rope_deltas is None
+             ):
                  delta0 = (1 - attention_mask).sum(dim=-1).unsqueeze(1)
                  position_ids, rope_deltas = self.get_rope_index(
                      talker_input_ids,
@@ -3208,7 +3123,7 @@ class Qwen3OmniMoeTalkerForConditionalGeneration(Qwen3OmniMoeThinkerTextPreTrain
                  self.rope_deltas = rope_deltas
              else:
                  batch_size, seq_length = input_ids.shape
-                 delta = (past_key_values_length + self.rope_deltas).to(input_ids.device)
+                 delta = cache_position[0] + self.rope_deltas if cache_position is not None else 0
                  position_ids = torch.arange(seq_length, device=input_ids.device)
                  position_ids = position_ids.view(1, -1).expand(batch_size, -1)
                  position_ids = position_ids.add(delta)
@@ -3259,13 +3174,13 @@ class Qwen3OmniMoeTalkerForConditionalGeneration(Qwen3OmniMoeThinkerTextPreTrain
      # Should inherit from PretrainedModel, but cannot inherit multiple classes in modular
      def get_rope_index(
          self,
-         input_ids: torch.LongTensor | None = None,
-         image_grid_thw: torch.LongTensor | None = None,
-         video_grid_thw: torch.LongTensor | None = None,
-         attention_mask: torch.Tensor | None = None,
+         input_ids: Optional[torch.LongTensor] = None,
+         image_grid_thw: Optional[torch.LongTensor] = None,
+         video_grid_thw: Optional[torch.LongTensor] = None,
+         attention_mask: Optional[torch.Tensor] = None,
          use_audio_in_video: bool = False,
-         audio_seqlens: torch.LongTensor | None = None,
-         second_per_grids: torch.Tensor | None = None,
+         audio_seqlens: Optional[torch.LongTensor] = None,
+         second_per_grids: Optional[torch.Tensor] = None,
      ) -> tuple[torch.Tensor, torch.Tensor]:
          return Qwen3OmniMoePreTrainedModelForConditionalGeneration.get_rope_index(
              self,
@@ -3303,31 +3218,15 @@ class Qwen3OmniMoeTalkerForConditionalGeneration(Qwen3OmniMoeThinkerTextPreTrain
          return model_kwargs

      def prepare_inputs_for_generation(
-         self,
-         input_ids,
-         past_key_values=None,
-         attention_mask=None,
-         inputs_embeds=None,
-         cache_position=None,
-         is_first_iteration=False,
-         **kwargs,
+         self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, cache_position=None, **kwargs
      ):
          hidden_states = kwargs.pop("hidden_states", None)
          inputs = super().prepare_inputs_for_generation(
-             input_ids,
-             past_key_values,
-             attention_mask,
-             inputs_embeds,
-             cache_position,
-             is_first_iteration=is_first_iteration,
-             **kwargs,
+             input_ids, past_key_values, attention_mask, inputs_embeds, cache_position, **kwargs
          )
-
-         # Qwen3-Omni will prepare position ids in forward with deltas
-         inputs["position_ids"] = None
-
+         # Decode stage
          # TODO(raushan, gante): Refactor this part to a utility function
-         if not is_first_iteration and kwargs.get("use_cache", True):
+         if cache_position[0] != 0:
              input_ids = input_ids[:, -1:]
              generation_step = kwargs.get("generation_step")
              trailing_text_hidden = kwargs.get("trailing_text_hidden")
@@ -3453,7 +3352,6 @@ class Qwen3OmniMoeConvNeXtBlock(nn.Module):
          return hidden_states


- @use_kernelized_func(apply_rotary_pos_emb)
  class Qwen3OmniMoeCode2WavAttention(nn.Module):
      """Multi-headed attention from 'Attention Is All You Need' paper"""

@@ -3480,6 +3378,7 @@ class Qwen3OmniMoeCode2WavAttention(nn.Module):
          self.o_proj = nn.Linear(
              config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
          )
+         self.rotary_fn = apply_rotary_pos_emb
          self.q_norm = nn.Identity()
          self.k_norm = nn.Identity()
          self.sliding_window = config.sliding_window
@@ -3488,11 +3387,11 @@ class Qwen3OmniMoeCode2WavAttention(nn.Module):
          self,
          hidden_states: torch.Tensor,
          position_embeddings: tuple[torch.Tensor, torch.Tensor],
-         attention_mask: torch.Tensor | None,
-         past_key_values: Cache | None = None,
-         cache_position: torch.LongTensor | None = None,
+         attention_mask: Optional[torch.Tensor],
+         past_key_values: Optional[Cache] = None,
+         cache_position: Optional[torch.LongTensor] = None,
          **kwargs: Unpack[FlashAttentionKwargs],
-     ) -> tuple[torch.Tensor, torch.Tensor | None]:
+     ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
          input_shape = hidden_states.shape[:-1]
          hidden_shape = (*input_shape, -1, self.head_dim)

@@ -3596,13 +3495,13 @@ class Qwen3OmniMoeCode2WavTransformerLayer(GradientCheckpointingLayer):
      def forward(
          self,
          hidden_states: torch.Tensor,
-         attention_mask: torch.Tensor | None = None,
-         position_ids: torch.LongTensor | None = None,
-         past_key_values: Cache | None = None,
-         use_cache: bool | None = False,
-         cache_position: torch.LongTensor | None = None,
+         attention_mask: Optional[torch.Tensor] = None,
+         position_ids: Optional[torch.LongTensor] = None,
+         past_key_values: Optional[Cache] = None,
+         use_cache: Optional[bool] = False,
+         cache_position: Optional[torch.LongTensor] = None,
          **kwargs,
-     ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
+     ) -> tuple[torch.FloatTensor, Optional[tuple[torch.FloatTensor, torch.FloatTensor]]]:
          """
          Args:
              hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
@@ -3819,9 +3718,7 @@ class Qwen3OmniMoeCode2WavDecoderBlock(Qwen3OmniMoePreTrainedModel):

          self.block = nn.ModuleList(block)

-         self.post_init()
-
-     def forward(self, hidden, **kwargs):
+     def forward(self, hidden):
          for block in self.block:
              hidden = block(hidden)
          return hidden
@@ -3863,7 +3760,7 @@ class Qwen3OmniMoeCode2Wav(Qwen3OmniMoePreTrainedModel):

          self.post_init()

-     def forward(self, codes, **kwargs):
+     def forward(self, codes):
          if codes.shape[1] != self.config.num_quantizers:
              raise ValueError(f"Expected {self.config.num_quantizers} layer of codes, got {codes.shape[1]}")
          hidden = self.code_embedding(codes + self.code_offset).mean(1)
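
`code_offset` (initialized earlier in `_init_weights` as `arange(num_quantizers) * codebook_size`) shifts each residual quantizer's codes into its own slice of one shared embedding table, and the per-quantizer embeddings are then averaged. A hedged, runnable sketch of that lookup:

```python
import torch
import torch.nn as nn

num_quantizers, codebook_size, hidden_size = 4, 16, 8
code_embedding = nn.Embedding(num_quantizers * codebook_size, hidden_size)
code_offset = torch.arange(num_quantizers).view(1, -1, 1) * codebook_size

codes = torch.randint(0, codebook_size, (2, num_quantizers, 5))  # [batch, quantizers, frames]
hidden = code_embedding(codes + code_offset).mean(1)             # average over quantizers
assert hidden.shape == (2, 5, hidden_size)
```
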
@@ -3995,10 +3892,10 @@ class Qwen3OmniMoeForConditionalGeneration(Qwen3OmniMoePreTrainedModel, Generati
      @torch.no_grad()
      def generate(
          self,
-         input_ids: torch.Tensor | None = None,
+         input_ids: Optional[torch.Tensor] = None,
          speaker: str = "Ethan",
          use_audio_in_video: bool = False,
-         return_audio: bool | None = None,
+         return_audio: Optional[bool] = None,
          thinker_max_new_tokens: int = 1024,
          thinker_eos_token_id: int = 151645,
          talker_max_new_tokens: int = 4096,
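
The `generate` signature above exposes per-stage budgets (thinker vs. talker). A hedged call sketch; the checkpoint id, processor call, and return unpacking follow the usual Qwen-Omni conventions in transformers and are not taken from this diff, so verify them against the shipped docstring:

```python
from transformers import AutoProcessor, Qwen3OmniMoeForConditionalGeneration

processor = AutoProcessor.from_pretrained("Qwen/Qwen3-Omni-30B-A3B-Instruct")
model = Qwen3OmniMoeForConditionalGeneration.from_pretrained("Qwen/Qwen3-Omni-30B-A3B-Instruct")

inputs = processor(text="Describe this clip.", return_tensors="pt")
text_ids, audio = model.generate(
    **inputs,
    speaker="Ethan",            # talker voice, the default in the signature above
    return_audio=True,          # also synthesize a waveform via the talker + code2wav stack
    thinker_max_new_tokens=256,
)
```
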