transformers-5.0.0rc3-py3-none-any.whl → transformers-5.1.0-py3-none-any.whl

This diff compares the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions exactly as they appear in their public registry.
Files changed (1021)
  1. transformers/__init__.py +4 -11
  2. transformers/activations.py +2 -2
  3. transformers/backbone_utils.py +326 -0
  4. transformers/cache_utils.py +11 -2
  5. transformers/cli/serve.py +11 -8
  6. transformers/configuration_utils.py +1 -69
  7. transformers/conversion_mapping.py +146 -26
  8. transformers/convert_slow_tokenizer.py +6 -4
  9. transformers/core_model_loading.py +207 -118
  10. transformers/dependency_versions_check.py +0 -1
  11. transformers/dependency_versions_table.py +7 -8
  12. transformers/file_utils.py +0 -2
  13. transformers/generation/candidate_generator.py +1 -2
  14. transformers/generation/continuous_batching/cache.py +40 -38
  15. transformers/generation/continuous_batching/cache_manager.py +3 -16
  16. transformers/generation/continuous_batching/continuous_api.py +94 -406
  17. transformers/generation/continuous_batching/input_ouputs.py +464 -0
  18. transformers/generation/continuous_batching/requests.py +54 -17
  19. transformers/generation/continuous_batching/scheduler.py +77 -95
  20. transformers/generation/logits_process.py +10 -5
  21. transformers/generation/stopping_criteria.py +1 -2
  22. transformers/generation/utils.py +75 -95
  23. transformers/image_processing_utils.py +0 -3
  24. transformers/image_processing_utils_fast.py +17 -18
  25. transformers/image_transforms.py +44 -13
  26. transformers/image_utils.py +0 -5
  27. transformers/initialization.py +57 -0
  28. transformers/integrations/__init__.py +10 -24
  29. transformers/integrations/accelerate.py +47 -11
  30. transformers/integrations/deepspeed.py +145 -3
  31. transformers/integrations/executorch.py +2 -6
  32. transformers/integrations/finegrained_fp8.py +142 -7
  33. transformers/integrations/flash_attention.py +2 -7
  34. transformers/integrations/hub_kernels.py +18 -7
  35. transformers/integrations/moe.py +226 -106
  36. transformers/integrations/mxfp4.py +47 -34
  37. transformers/integrations/peft.py +488 -176
  38. transformers/integrations/tensor_parallel.py +641 -581
  39. transformers/masking_utils.py +153 -9
  40. transformers/modeling_flash_attention_utils.py +1 -2
  41. transformers/modeling_utils.py +359 -358
  42. transformers/models/__init__.py +6 -0
  43. transformers/models/afmoe/configuration_afmoe.py +14 -4
  44. transformers/models/afmoe/modeling_afmoe.py +8 -8
  45. transformers/models/afmoe/modular_afmoe.py +7 -7
  46. transformers/models/aimv2/configuration_aimv2.py +2 -7
  47. transformers/models/aimv2/modeling_aimv2.py +26 -24
  48. transformers/models/aimv2/modular_aimv2.py +8 -12
  49. transformers/models/albert/configuration_albert.py +8 -1
  50. transformers/models/albert/modeling_albert.py +3 -3
  51. transformers/models/align/configuration_align.py +8 -5
  52. transformers/models/align/modeling_align.py +22 -24
  53. transformers/models/altclip/configuration_altclip.py +4 -6
  54. transformers/models/altclip/modeling_altclip.py +30 -26
  55. transformers/models/apertus/configuration_apertus.py +5 -7
  56. transformers/models/apertus/modeling_apertus.py +4 -4
  57. transformers/models/apertus/modular_apertus.py +8 -10
  58. transformers/models/arcee/configuration_arcee.py +5 -7
  59. transformers/models/arcee/modeling_arcee.py +4 -4
  60. transformers/models/aria/configuration_aria.py +11 -21
  61. transformers/models/aria/modeling_aria.py +39 -36
  62. transformers/models/aria/modular_aria.py +33 -39
  63. transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +3 -3
  64. transformers/models/audioflamingo3/modeling_audioflamingo3.py +39 -30
  65. transformers/models/audioflamingo3/modular_audioflamingo3.py +41 -27
  66. transformers/models/auto/auto_factory.py +8 -6
  67. transformers/models/auto/configuration_auto.py +22 -0
  68. transformers/models/auto/image_processing_auto.py +17 -13
  69. transformers/models/auto/modeling_auto.py +15 -0
  70. transformers/models/auto/processing_auto.py +9 -18
  71. transformers/models/auto/tokenization_auto.py +17 -15
  72. transformers/models/autoformer/modeling_autoformer.py +2 -1
  73. transformers/models/aya_vision/configuration_aya_vision.py +4 -0
  74. transformers/models/aya_vision/modeling_aya_vision.py +29 -62
  75. transformers/models/aya_vision/modular_aya_vision.py +20 -45
  76. transformers/models/bamba/configuration_bamba.py +17 -7
  77. transformers/models/bamba/modeling_bamba.py +23 -55
  78. transformers/models/bamba/modular_bamba.py +19 -54
  79. transformers/models/bark/configuration_bark.py +2 -1
  80. transformers/models/bark/modeling_bark.py +24 -10
  81. transformers/models/bart/configuration_bart.py +9 -4
  82. transformers/models/bart/modeling_bart.py +9 -12
  83. transformers/models/beit/configuration_beit.py +2 -4
  84. transformers/models/beit/image_processing_beit_fast.py +3 -3
  85. transformers/models/beit/modeling_beit.py +14 -9
  86. transformers/models/bert/configuration_bert.py +12 -1
  87. transformers/models/bert/modeling_bert.py +6 -30
  88. transformers/models/bert_generation/configuration_bert_generation.py +17 -1
  89. transformers/models/bert_generation/modeling_bert_generation.py +6 -6
  90. transformers/models/big_bird/configuration_big_bird.py +12 -8
  91. transformers/models/big_bird/modeling_big_bird.py +0 -15
  92. transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py +9 -8
  93. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +9 -7
  94. transformers/models/biogpt/configuration_biogpt.py +8 -1
  95. transformers/models/biogpt/modeling_biogpt.py +4 -8
  96. transformers/models/biogpt/modular_biogpt.py +1 -5
  97. transformers/models/bit/configuration_bit.py +2 -4
  98. transformers/models/bit/modeling_bit.py +6 -5
  99. transformers/models/bitnet/configuration_bitnet.py +5 -7
  100. transformers/models/bitnet/modeling_bitnet.py +3 -4
  101. transformers/models/bitnet/modular_bitnet.py +3 -4
  102. transformers/models/blenderbot/configuration_blenderbot.py +8 -4
  103. transformers/models/blenderbot/modeling_blenderbot.py +4 -4
  104. transformers/models/blenderbot_small/configuration_blenderbot_small.py +8 -4
  105. transformers/models/blenderbot_small/modeling_blenderbot_small.py +4 -4
  106. transformers/models/blip/configuration_blip.py +9 -9
  107. transformers/models/blip/modeling_blip.py +55 -37
  108. transformers/models/blip_2/configuration_blip_2.py +2 -1
  109. transformers/models/blip_2/modeling_blip_2.py +81 -56
  110. transformers/models/bloom/configuration_bloom.py +5 -1
  111. transformers/models/bloom/modeling_bloom.py +2 -1
  112. transformers/models/blt/configuration_blt.py +23 -12
  113. transformers/models/blt/modeling_blt.py +20 -14
  114. transformers/models/blt/modular_blt.py +70 -10
  115. transformers/models/bridgetower/configuration_bridgetower.py +7 -1
  116. transformers/models/bridgetower/image_processing_bridgetower_fast.py +6 -6
  117. transformers/models/bridgetower/modeling_bridgetower.py +29 -15
  118. transformers/models/bros/configuration_bros.py +24 -17
  119. transformers/models/camembert/configuration_camembert.py +8 -1
  120. transformers/models/camembert/modeling_camembert.py +6 -6
  121. transformers/models/canine/configuration_canine.py +4 -1
  122. transformers/models/chameleon/configuration_chameleon.py +5 -7
  123. transformers/models/chameleon/image_processing_chameleon_fast.py +5 -5
  124. transformers/models/chameleon/modeling_chameleon.py +82 -36
  125. transformers/models/chinese_clip/configuration_chinese_clip.py +10 -7
  126. transformers/models/chinese_clip/modeling_chinese_clip.py +28 -29
  127. transformers/models/clap/configuration_clap.py +4 -8
  128. transformers/models/clap/modeling_clap.py +21 -22
  129. transformers/models/clip/configuration_clip.py +4 -1
  130. transformers/models/clip/image_processing_clip_fast.py +9 -0
  131. transformers/models/clip/modeling_clip.py +25 -22
  132. transformers/models/clipseg/configuration_clipseg.py +4 -1
  133. transformers/models/clipseg/modeling_clipseg.py +27 -25
  134. transformers/models/clipseg/processing_clipseg.py +11 -3
  135. transformers/models/clvp/configuration_clvp.py +14 -2
  136. transformers/models/clvp/modeling_clvp.py +19 -30
  137. transformers/models/codegen/configuration_codegen.py +4 -3
  138. transformers/models/codegen/modeling_codegen.py +2 -1
  139. transformers/models/cohere/configuration_cohere.py +5 -7
  140. transformers/models/cohere/modeling_cohere.py +4 -4
  141. transformers/models/cohere/modular_cohere.py +3 -3
  142. transformers/models/cohere2/configuration_cohere2.py +6 -8
  143. transformers/models/cohere2/modeling_cohere2.py +4 -4
  144. transformers/models/cohere2/modular_cohere2.py +9 -11
  145. transformers/models/cohere2_vision/configuration_cohere2_vision.py +5 -1
  146. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +3 -3
  147. transformers/models/cohere2_vision/modeling_cohere2_vision.py +24 -25
  148. transformers/models/cohere2_vision/modular_cohere2_vision.py +20 -20
  149. transformers/models/colqwen2/modeling_colqwen2.py +7 -6
  150. transformers/models/colqwen2/modular_colqwen2.py +7 -6
  151. transformers/models/conditional_detr/configuration_conditional_detr.py +19 -46
  152. transformers/models/conditional_detr/image_processing_conditional_detr.py +3 -4
  153. transformers/models/conditional_detr/image_processing_conditional_detr_fast.py +28 -14
  154. transformers/models/conditional_detr/modeling_conditional_detr.py +794 -942
  155. transformers/models/conditional_detr/modular_conditional_detr.py +901 -3
  156. transformers/models/convbert/configuration_convbert.py +11 -7
  157. transformers/models/convnext/configuration_convnext.py +2 -4
  158. transformers/models/convnext/image_processing_convnext_fast.py +2 -2
  159. transformers/models/convnext/modeling_convnext.py +7 -6
  160. transformers/models/convnextv2/configuration_convnextv2.py +2 -4
  161. transformers/models/convnextv2/modeling_convnextv2.py +7 -6
  162. transformers/models/cpmant/configuration_cpmant.py +4 -0
  163. transformers/models/csm/configuration_csm.py +9 -15
  164. transformers/models/csm/modeling_csm.py +3 -3
  165. transformers/models/ctrl/configuration_ctrl.py +16 -0
  166. transformers/models/ctrl/modeling_ctrl.py +13 -25
  167. transformers/models/cwm/configuration_cwm.py +5 -7
  168. transformers/models/cwm/modeling_cwm.py +4 -4
  169. transformers/models/d_fine/configuration_d_fine.py +10 -56
  170. transformers/models/d_fine/modeling_d_fine.py +728 -868
  171. transformers/models/d_fine/modular_d_fine.py +335 -412
  172. transformers/models/dab_detr/configuration_dab_detr.py +22 -48
  173. transformers/models/dab_detr/modeling_dab_detr.py +11 -7
  174. transformers/models/dac/modeling_dac.py +1 -1
  175. transformers/models/data2vec/configuration_data2vec_audio.py +4 -1
  176. transformers/models/data2vec/configuration_data2vec_text.py +11 -2
  177. transformers/models/data2vec/modeling_data2vec_audio.py +3 -3
  178. transformers/models/data2vec/modeling_data2vec_text.py +6 -6
  179. transformers/models/data2vec/modeling_data2vec_vision.py +4 -2
  180. transformers/models/dbrx/configuration_dbrx.py +11 -3
  181. transformers/models/dbrx/modeling_dbrx.py +6 -6
  182. transformers/models/dbrx/modular_dbrx.py +6 -6
  183. transformers/models/deberta/configuration_deberta.py +6 -0
  184. transformers/models/deberta_v2/configuration_deberta_v2.py +6 -0
  185. transformers/models/decision_transformer/configuration_decision_transformer.py +3 -1
  186. transformers/models/decision_transformer/modeling_decision_transformer.py +3 -3
  187. transformers/models/deepseek_v2/configuration_deepseek_v2.py +7 -10
  188. transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -8
  189. transformers/models/deepseek_v2/modular_deepseek_v2.py +8 -10
  190. transformers/models/deepseek_v3/configuration_deepseek_v3.py +7 -10
  191. transformers/models/deepseek_v3/modeling_deepseek_v3.py +7 -7
  192. transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -5
  193. transformers/models/deepseek_vl/configuration_deepseek_vl.py +4 -0
  194. transformers/models/deepseek_vl/image_processing_deepseek_vl.py +2 -2
  195. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +5 -5
  196. transformers/models/deepseek_vl/modeling_deepseek_vl.py +17 -12
  197. transformers/models/deepseek_vl/modular_deepseek_vl.py +4 -0
  198. transformers/models/deepseek_vl_hybrid/configuration_deepseek_vl_hybrid.py +4 -0
  199. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py +2 -2
  200. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +6 -6
  201. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +68 -24
  202. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +70 -19
  203. transformers/models/deformable_detr/configuration_deformable_detr.py +22 -45
  204. transformers/models/deformable_detr/image_processing_deformable_detr_fast.py +25 -11
  205. transformers/models/deformable_detr/modeling_deformable_detr.py +410 -607
  206. transformers/models/deformable_detr/modular_deformable_detr.py +1385 -3
  207. transformers/models/deit/modeling_deit.py +11 -7
  208. transformers/models/depth_anything/configuration_depth_anything.py +12 -42
  209. transformers/models/depth_anything/modeling_depth_anything.py +5 -3
  210. transformers/models/depth_pro/image_processing_depth_pro_fast.py +2 -2
  211. transformers/models/depth_pro/modeling_depth_pro.py +8 -4
  212. transformers/models/detr/configuration_detr.py +18 -49
  213. transformers/models/detr/image_processing_detr_fast.py +11 -11
  214. transformers/models/detr/modeling_detr.py +695 -734
  215. transformers/models/dia/configuration_dia.py +4 -7
  216. transformers/models/dia/generation_dia.py +8 -17
  217. transformers/models/dia/modeling_dia.py +7 -7
  218. transformers/models/dia/modular_dia.py +4 -4
  219. transformers/models/diffllama/configuration_diffllama.py +5 -7
  220. transformers/models/diffllama/modeling_diffllama.py +3 -8
  221. transformers/models/diffllama/modular_diffllama.py +2 -7
  222. transformers/models/dinat/configuration_dinat.py +2 -4
  223. transformers/models/dinat/modeling_dinat.py +7 -6
  224. transformers/models/dinov2/configuration_dinov2.py +2 -4
  225. transformers/models/dinov2/modeling_dinov2.py +9 -8
  226. transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py +2 -4
  227. transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +9 -8
  228. transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py +6 -7
  229. transformers/models/dinov3_convnext/configuration_dinov3_convnext.py +2 -4
  230. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +2 -3
  231. transformers/models/dinov3_vit/configuration_dinov3_vit.py +2 -4
  232. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +2 -2
  233. transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -6
  234. transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -6
  235. transformers/models/distilbert/configuration_distilbert.py +8 -1
  236. transformers/models/distilbert/modeling_distilbert.py +3 -3
  237. transformers/models/doge/configuration_doge.py +17 -7
  238. transformers/models/doge/modeling_doge.py +4 -4
  239. transformers/models/doge/modular_doge.py +20 -10
  240. transformers/models/donut/image_processing_donut_fast.py +4 -4
  241. transformers/models/dots1/configuration_dots1.py +16 -7
  242. transformers/models/dots1/modeling_dots1.py +4 -4
  243. transformers/models/dpr/configuration_dpr.py +19 -1
  244. transformers/models/dpt/configuration_dpt.py +23 -65
  245. transformers/models/dpt/image_processing_dpt_fast.py +5 -5
  246. transformers/models/dpt/modeling_dpt.py +19 -15
  247. transformers/models/dpt/modular_dpt.py +4 -4
  248. transformers/models/edgetam/configuration_edgetam.py +1 -1
  249. transformers/models/edgetam/modeling_edgetam.py +53 -53
  250. transformers/models/edgetam/modular_edgetam.py +5 -7
  251. transformers/models/edgetam_video/modeling_edgetam_video.py +55 -56
  252. transformers/models/edgetam_video/modular_edgetam_video.py +9 -9
  253. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +4 -3
  254. transformers/models/efficientloftr/modeling_efficientloftr.py +19 -9
  255. transformers/models/efficientnet/image_processing_efficientnet_fast.py +2 -2
  256. transformers/models/electra/configuration_electra.py +13 -2
  257. transformers/models/electra/modeling_electra.py +6 -6
  258. transformers/models/emu3/configuration_emu3.py +12 -10
  259. transformers/models/emu3/modeling_emu3.py +84 -47
  260. transformers/models/emu3/modular_emu3.py +77 -39
  261. transformers/models/encoder_decoder/configuration_encoder_decoder.py +12 -1
  262. transformers/models/encoder_decoder/modeling_encoder_decoder.py +20 -24
  263. transformers/models/eomt/configuration_eomt.py +12 -13
  264. transformers/models/eomt/image_processing_eomt_fast.py +3 -3
  265. transformers/models/eomt/modeling_eomt.py +3 -3
  266. transformers/models/eomt/modular_eomt.py +17 -17
  267. transformers/models/eomt_dinov3/__init__.py +28 -0
  268. transformers/models/eomt_dinov3/configuration_eomt_dinov3.py +204 -0
  269. transformers/models/eomt_dinov3/modeling_eomt_dinov3.py +1376 -0
  270. transformers/models/eomt_dinov3/modular_eomt_dinov3.py +454 -0
  271. transformers/models/ernie/configuration_ernie.py +24 -2
  272. transformers/models/ernie/modeling_ernie.py +6 -30
  273. transformers/models/ernie4_5/configuration_ernie4_5.py +5 -7
  274. transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
  275. transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py +7 -10
  276. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +4 -4
  277. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +17 -6
  278. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +229 -188
  279. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +79 -55
  280. transformers/models/esm/configuration_esm.py +9 -11
  281. transformers/models/esm/modeling_esm.py +3 -3
  282. transformers/models/esm/modeling_esmfold.py +1 -6
  283. transformers/models/esm/openfold_utils/protein.py +2 -3
  284. transformers/models/evolla/configuration_evolla.py +21 -8
  285. transformers/models/evolla/modeling_evolla.py +11 -7
  286. transformers/models/evolla/modular_evolla.py +5 -1
  287. transformers/models/exaone4/configuration_exaone4.py +8 -5
  288. transformers/models/exaone4/modeling_exaone4.py +4 -4
  289. transformers/models/exaone4/modular_exaone4.py +11 -8
  290. transformers/models/exaone_moe/__init__.py +27 -0
  291. transformers/models/exaone_moe/configuration_exaone_moe.py +235 -0
  292. transformers/models/exaone_moe/modeling_exaone_moe.py +665 -0
  293. transformers/models/exaone_moe/modular_exaone_moe.py +373 -0
  294. transformers/models/falcon/configuration_falcon.py +9 -1
  295. transformers/models/falcon/modeling_falcon.py +3 -8
  296. transformers/models/falcon_h1/configuration_falcon_h1.py +17 -8
  297. transformers/models/falcon_h1/modeling_falcon_h1.py +22 -54
  298. transformers/models/falcon_h1/modular_falcon_h1.py +21 -52
  299. transformers/models/falcon_mamba/configuration_falcon_mamba.py +5 -1
  300. transformers/models/falcon_mamba/modeling_falcon_mamba.py +18 -26
  301. transformers/models/falcon_mamba/modular_falcon_mamba.py +4 -0
  302. transformers/models/fast_vlm/configuration_fast_vlm.py +10 -1
  303. transformers/models/fast_vlm/modeling_fast_vlm.py +37 -64
  304. transformers/models/fast_vlm/modular_fast_vlm.py +146 -35
  305. transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py +0 -1
  306. transformers/models/flaubert/configuration_flaubert.py +10 -4
  307. transformers/models/flaubert/modeling_flaubert.py +1 -1
  308. transformers/models/flava/configuration_flava.py +4 -3
  309. transformers/models/flava/image_processing_flava_fast.py +4 -4
  310. transformers/models/flava/modeling_flava.py +36 -28
  311. transformers/models/flex_olmo/configuration_flex_olmo.py +11 -14
  312. transformers/models/flex_olmo/modeling_flex_olmo.py +4 -4
  313. transformers/models/flex_olmo/modular_flex_olmo.py +11 -14
  314. transformers/models/florence2/configuration_florence2.py +4 -0
  315. transformers/models/florence2/modeling_florence2.py +57 -32
  316. transformers/models/florence2/modular_florence2.py +48 -26
  317. transformers/models/fnet/configuration_fnet.py +6 -1
  318. transformers/models/focalnet/configuration_focalnet.py +2 -4
  319. transformers/models/focalnet/modeling_focalnet.py +10 -7
  320. transformers/models/fsmt/configuration_fsmt.py +12 -16
  321. transformers/models/funnel/configuration_funnel.py +8 -0
  322. transformers/models/fuyu/configuration_fuyu.py +5 -8
  323. transformers/models/fuyu/image_processing_fuyu_fast.py +5 -4
  324. transformers/models/fuyu/modeling_fuyu.py +24 -23
  325. transformers/models/gemma/configuration_gemma.py +5 -7
  326. transformers/models/gemma/modeling_gemma.py +4 -4
  327. transformers/models/gemma/modular_gemma.py +5 -7
  328. transformers/models/gemma2/configuration_gemma2.py +5 -7
  329. transformers/models/gemma2/modeling_gemma2.py +4 -4
  330. transformers/models/gemma2/modular_gemma2.py +8 -10
  331. transformers/models/gemma3/configuration_gemma3.py +28 -22
  332. transformers/models/gemma3/image_processing_gemma3_fast.py +2 -2
  333. transformers/models/gemma3/modeling_gemma3.py +37 -33
  334. transformers/models/gemma3/modular_gemma3.py +46 -42
  335. transformers/models/gemma3n/configuration_gemma3n.py +35 -22
  336. transformers/models/gemma3n/modeling_gemma3n.py +86 -58
  337. transformers/models/gemma3n/modular_gemma3n.py +112 -75
  338. transformers/models/git/configuration_git.py +5 -7
  339. transformers/models/git/modeling_git.py +31 -41
  340. transformers/models/glm/configuration_glm.py +7 -9
  341. transformers/models/glm/modeling_glm.py +4 -4
  342. transformers/models/glm4/configuration_glm4.py +7 -9
  343. transformers/models/glm4/modeling_glm4.py +4 -4
  344. transformers/models/glm46v/configuration_glm46v.py +4 -0
  345. transformers/models/glm46v/image_processing_glm46v.py +5 -2
  346. transformers/models/glm46v/image_processing_glm46v_fast.py +2 -2
  347. transformers/models/glm46v/modeling_glm46v.py +91 -46
  348. transformers/models/glm46v/modular_glm46v.py +4 -0
  349. transformers/models/glm4_moe/configuration_glm4_moe.py +17 -7
  350. transformers/models/glm4_moe/modeling_glm4_moe.py +4 -4
  351. transformers/models/glm4_moe/modular_glm4_moe.py +17 -7
  352. transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py +8 -10
  353. transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py +7 -7
  354. transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py +8 -10
  355. transformers/models/glm4v/configuration_glm4v.py +12 -8
  356. transformers/models/glm4v/image_processing_glm4v.py +5 -2
  357. transformers/models/glm4v/image_processing_glm4v_fast.py +2 -2
  358. transformers/models/glm4v/modeling_glm4v.py +120 -63
  359. transformers/models/glm4v/modular_glm4v.py +82 -50
  360. transformers/models/glm4v_moe/configuration_glm4v_moe.py +18 -6
  361. transformers/models/glm4v_moe/modeling_glm4v_moe.py +115 -63
  362. transformers/models/glm4v_moe/modular_glm4v_moe.py +23 -12
  363. transformers/models/glm_image/configuration_glm_image.py +26 -20
  364. transformers/models/glm_image/image_processing_glm_image.py +1 -1
  365. transformers/models/glm_image/image_processing_glm_image_fast.py +5 -7
  366. transformers/models/glm_image/modeling_glm_image.py +337 -236
  367. transformers/models/glm_image/modular_glm_image.py +415 -255
  368. transformers/models/glm_image/processing_glm_image.py +65 -17
  369. transformers/{pipelines/deprecated → models/glm_ocr}/__init__.py +15 -2
  370. transformers/models/glm_ocr/configuration_glm_ocr.py +312 -0
  371. transformers/models/glm_ocr/modeling_glm_ocr.py +1633 -0
  372. transformers/models/glm_ocr/modular_glm_ocr.py +428 -0
  373. transformers/models/glmasr/modeling_glmasr.py +34 -28
  374. transformers/models/glmasr/modular_glmasr.py +23 -11
  375. transformers/models/glpn/image_processing_glpn_fast.py +3 -3
  376. transformers/models/glpn/modeling_glpn.py +4 -2
  377. transformers/models/got_ocr2/configuration_got_ocr2.py +6 -6
  378. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +3 -3
  379. transformers/models/got_ocr2/modeling_got_ocr2.py +31 -37
  380. transformers/models/got_ocr2/modular_got_ocr2.py +30 -19
  381. transformers/models/gpt2/configuration_gpt2.py +13 -1
  382. transformers/models/gpt2/modeling_gpt2.py +5 -5
  383. transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +7 -1
  384. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +5 -4
  385. transformers/models/gpt_neo/configuration_gpt_neo.py +9 -1
  386. transformers/models/gpt_neo/modeling_gpt_neo.py +3 -7
  387. transformers/models/gpt_neox/configuration_gpt_neox.py +8 -3
  388. transformers/models/gpt_neox/modeling_gpt_neox.py +4 -4
  389. transformers/models/gpt_neox/modular_gpt_neox.py +4 -4
  390. transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py +9 -1
  391. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +2 -2
  392. transformers/models/gpt_oss/configuration_gpt_oss.py +10 -6
  393. transformers/models/gpt_oss/modeling_gpt_oss.py +46 -79
  394. transformers/models/gpt_oss/modular_gpt_oss.py +45 -78
  395. transformers/models/gptj/configuration_gptj.py +4 -4
  396. transformers/models/gptj/modeling_gptj.py +3 -7
  397. transformers/models/granite/configuration_granite.py +5 -7
  398. transformers/models/granite/modeling_granite.py +4 -4
  399. transformers/models/granite_speech/modeling_granite_speech.py +63 -37
  400. transformers/models/granitemoe/configuration_granitemoe.py +5 -7
  401. transformers/models/granitemoe/modeling_granitemoe.py +4 -4
  402. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +17 -7
  403. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +22 -54
  404. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +39 -45
  405. transformers/models/granitemoeshared/configuration_granitemoeshared.py +6 -7
  406. transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -4
  407. transformers/models/grounding_dino/configuration_grounding_dino.py +10 -45
  408. transformers/models/grounding_dino/image_processing_grounding_dino_fast.py +11 -11
  409. transformers/models/grounding_dino/modeling_grounding_dino.py +68 -86
  410. transformers/models/groupvit/configuration_groupvit.py +4 -1
  411. transformers/models/groupvit/modeling_groupvit.py +29 -22
  412. transformers/models/helium/configuration_helium.py +5 -7
  413. transformers/models/helium/modeling_helium.py +4 -4
  414. transformers/models/hgnet_v2/configuration_hgnet_v2.py +2 -4
  415. transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -5
  416. transformers/models/hgnet_v2/modular_hgnet_v2.py +7 -8
  417. transformers/models/hiera/configuration_hiera.py +2 -4
  418. transformers/models/hiera/modeling_hiera.py +11 -8
  419. transformers/models/hubert/configuration_hubert.py +4 -1
  420. transformers/models/hubert/modeling_hubert.py +7 -4
  421. transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py +5 -7
  422. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +28 -4
  423. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +28 -6
  424. transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py +6 -8
  425. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +22 -9
  426. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +22 -8
  427. transformers/models/ibert/configuration_ibert.py +4 -1
  428. transformers/models/idefics/configuration_idefics.py +5 -7
  429. transformers/models/idefics/modeling_idefics.py +3 -4
  430. transformers/models/idefics/vision.py +5 -4
  431. transformers/models/idefics2/configuration_idefics2.py +1 -2
  432. transformers/models/idefics2/image_processing_idefics2_fast.py +1 -0
  433. transformers/models/idefics2/modeling_idefics2.py +72 -50
  434. transformers/models/idefics3/configuration_idefics3.py +1 -3
  435. transformers/models/idefics3/image_processing_idefics3_fast.py +29 -3
  436. transformers/models/idefics3/modeling_idefics3.py +63 -40
  437. transformers/models/ijepa/modeling_ijepa.py +3 -3
  438. transformers/models/imagegpt/configuration_imagegpt.py +9 -1
  439. transformers/models/imagegpt/image_processing_imagegpt_fast.py +2 -2
  440. transformers/models/imagegpt/modeling_imagegpt.py +8 -4
  441. transformers/models/informer/modeling_informer.py +3 -3
  442. transformers/models/instructblip/configuration_instructblip.py +2 -1
  443. transformers/models/instructblip/modeling_instructblip.py +65 -39
  444. transformers/models/instructblipvideo/configuration_instructblipvideo.py +2 -1
  445. transformers/models/instructblipvideo/modeling_instructblipvideo.py +60 -57
  446. transformers/models/instructblipvideo/modular_instructblipvideo.py +43 -32
  447. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +2 -2
  448. transformers/models/internvl/configuration_internvl.py +5 -0
  449. transformers/models/internvl/modeling_internvl.py +35 -55
  450. transformers/models/internvl/modular_internvl.py +26 -38
  451. transformers/models/internvl/video_processing_internvl.py +2 -2
  452. transformers/models/jais2/configuration_jais2.py +5 -7
  453. transformers/models/jais2/modeling_jais2.py +4 -4
  454. transformers/models/jamba/configuration_jamba.py +5 -7
  455. transformers/models/jamba/modeling_jamba.py +4 -4
  456. transformers/models/jamba/modular_jamba.py +3 -3
  457. transformers/models/janus/image_processing_janus.py +2 -2
  458. transformers/models/janus/image_processing_janus_fast.py +8 -8
  459. transformers/models/janus/modeling_janus.py +63 -146
  460. transformers/models/janus/modular_janus.py +62 -20
  461. transformers/models/jetmoe/configuration_jetmoe.py +6 -4
  462. transformers/models/jetmoe/modeling_jetmoe.py +3 -3
  463. transformers/models/jetmoe/modular_jetmoe.py +3 -3
  464. transformers/models/kosmos2/configuration_kosmos2.py +10 -8
  465. transformers/models/kosmos2/modeling_kosmos2.py +56 -34
  466. transformers/models/kosmos2_5/configuration_kosmos2_5.py +8 -8
  467. transformers/models/kosmos2_5/modeling_kosmos2_5.py +54 -63
  468. transformers/models/kyutai_speech_to_text/configuration_kyutai_speech_to_text.py +8 -3
  469. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +44 -40
  470. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +1 -1
  471. transformers/models/lasr/configuration_lasr.py +2 -4
  472. transformers/models/lasr/modeling_lasr.py +3 -3
  473. transformers/models/lasr/modular_lasr.py +3 -3
  474. transformers/models/layoutlm/configuration_layoutlm.py +14 -1
  475. transformers/models/layoutlm/modeling_layoutlm.py +3 -3
  476. transformers/models/layoutlmv2/configuration_layoutlmv2.py +14 -16
  477. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +2 -2
  478. transformers/models/layoutlmv3/configuration_layoutlmv3.py +16 -18
  479. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +2 -2
  480. transformers/models/layoutxlm/configuration_layoutxlm.py +14 -16
  481. transformers/models/led/configuration_led.py +7 -8
  482. transformers/models/levit/image_processing_levit_fast.py +4 -4
  483. transformers/models/lfm2/configuration_lfm2.py +5 -7
  484. transformers/models/lfm2/modeling_lfm2.py +4 -4
  485. transformers/models/lfm2/modular_lfm2.py +3 -3
  486. transformers/models/lfm2_moe/configuration_lfm2_moe.py +5 -7
  487. transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -4
  488. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  489. transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py +9 -15
  490. transformers/models/lfm2_vl/modeling_lfm2_vl.py +42 -28
  491. transformers/models/lfm2_vl/modular_lfm2_vl.py +42 -27
  492. transformers/models/lightglue/image_processing_lightglue_fast.py +4 -3
  493. transformers/models/lightglue/modeling_lightglue.py +3 -3
  494. transformers/models/lightglue/modular_lightglue.py +3 -3
  495. transformers/models/lighton_ocr/modeling_lighton_ocr.py +31 -28
  496. transformers/models/lighton_ocr/modular_lighton_ocr.py +19 -18
  497. transformers/models/lilt/configuration_lilt.py +6 -1
  498. transformers/models/llama/configuration_llama.py +5 -7
  499. transformers/models/llama/modeling_llama.py +4 -4
  500. transformers/models/llama4/configuration_llama4.py +67 -47
  501. transformers/models/llama4/image_processing_llama4_fast.py +3 -3
  502. transformers/models/llama4/modeling_llama4.py +46 -44
  503. transformers/models/llava/configuration_llava.py +10 -0
  504. transformers/models/llava/image_processing_llava_fast.py +3 -3
  505. transformers/models/llava/modeling_llava.py +38 -65
  506. transformers/models/llava_next/configuration_llava_next.py +2 -1
  507. transformers/models/llava_next/image_processing_llava_next_fast.py +6 -6
  508. transformers/models/llava_next/modeling_llava_next.py +61 -60
  509. transformers/models/llava_next_video/configuration_llava_next_video.py +10 -6
  510. transformers/models/llava_next_video/modeling_llava_next_video.py +115 -100
  511. transformers/models/llava_next_video/modular_llava_next_video.py +110 -101
  512. transformers/models/llava_onevision/configuration_llava_onevision.py +10 -6
  513. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +8 -7
  514. transformers/models/llava_onevision/modeling_llava_onevision.py +111 -105
  515. transformers/models/llava_onevision/modular_llava_onevision.py +106 -101
  516. transformers/models/longcat_flash/configuration_longcat_flash.py +7 -10
  517. transformers/models/longcat_flash/modeling_longcat_flash.py +7 -7
  518. transformers/models/longcat_flash/modular_longcat_flash.py +6 -5
  519. transformers/models/longformer/configuration_longformer.py +4 -1
  520. transformers/models/longt5/configuration_longt5.py +9 -6
  521. transformers/models/longt5/modeling_longt5.py +2 -1
  522. transformers/models/luke/configuration_luke.py +8 -1
  523. transformers/models/lw_detr/configuration_lw_detr.py +19 -31
  524. transformers/models/lw_detr/modeling_lw_detr.py +43 -44
  525. transformers/models/lw_detr/modular_lw_detr.py +36 -38
  526. transformers/models/lxmert/configuration_lxmert.py +16 -0
  527. transformers/models/m2m_100/configuration_m2m_100.py +7 -8
  528. transformers/models/m2m_100/modeling_m2m_100.py +3 -3
  529. transformers/models/mamba/configuration_mamba.py +5 -2
  530. transformers/models/mamba/modeling_mamba.py +18 -26
  531. transformers/models/mamba2/configuration_mamba2.py +5 -7
  532. transformers/models/mamba2/modeling_mamba2.py +22 -33
  533. transformers/models/marian/configuration_marian.py +10 -4
  534. transformers/models/marian/modeling_marian.py +4 -4
  535. transformers/models/markuplm/configuration_markuplm.py +4 -6
  536. transformers/models/markuplm/modeling_markuplm.py +3 -3
  537. transformers/models/mask2former/configuration_mask2former.py +12 -47
  538. transformers/models/mask2former/image_processing_mask2former_fast.py +8 -8
  539. transformers/models/mask2former/modeling_mask2former.py +18 -12
  540. transformers/models/maskformer/configuration_maskformer.py +14 -45
  541. transformers/models/maskformer/configuration_maskformer_swin.py +2 -4
  542. transformers/models/maskformer/image_processing_maskformer_fast.py +8 -8
  543. transformers/models/maskformer/modeling_maskformer.py +15 -9
  544. transformers/models/maskformer/modeling_maskformer_swin.py +2 -3
  545. transformers/models/mbart/configuration_mbart.py +9 -4
  546. transformers/models/mbart/modeling_mbart.py +9 -6
  547. transformers/models/megatron_bert/configuration_megatron_bert.py +13 -2
  548. transformers/models/megatron_bert/modeling_megatron_bert.py +0 -15
  549. transformers/models/metaclip_2/configuration_metaclip_2.py +4 -1
  550. transformers/models/metaclip_2/modeling_metaclip_2.py +49 -42
  551. transformers/models/metaclip_2/modular_metaclip_2.py +41 -25
  552. transformers/models/mgp_str/modeling_mgp_str.py +4 -2
  553. transformers/models/mimi/configuration_mimi.py +4 -0
  554. transformers/models/mimi/modeling_mimi.py +40 -36
  555. transformers/models/minimax/configuration_minimax.py +8 -11
  556. transformers/models/minimax/modeling_minimax.py +5 -5
  557. transformers/models/minimax/modular_minimax.py +9 -12
  558. transformers/models/minimax_m2/configuration_minimax_m2.py +8 -31
  559. transformers/models/minimax_m2/modeling_minimax_m2.py +4 -4
  560. transformers/models/minimax_m2/modular_minimax_m2.py +8 -31
  561. transformers/models/ministral/configuration_ministral.py +5 -7
  562. transformers/models/ministral/modeling_ministral.py +4 -4
  563. transformers/models/ministral/modular_ministral.py +5 -8
  564. transformers/models/ministral3/configuration_ministral3.py +4 -4
  565. transformers/models/ministral3/modeling_ministral3.py +4 -4
  566. transformers/models/ministral3/modular_ministral3.py +3 -3
  567. transformers/models/mistral/configuration_mistral.py +5 -7
  568. transformers/models/mistral/modeling_mistral.py +4 -4
  569. transformers/models/mistral/modular_mistral.py +3 -3
  570. transformers/models/mistral3/configuration_mistral3.py +4 -0
  571. transformers/models/mistral3/modeling_mistral3.py +36 -40
  572. transformers/models/mistral3/modular_mistral3.py +31 -32
  573. transformers/models/mixtral/configuration_mixtral.py +8 -11
  574. transformers/models/mixtral/modeling_mixtral.py +4 -4
  575. transformers/models/mlcd/modeling_mlcd.py +7 -5
  576. transformers/models/mlcd/modular_mlcd.py +7 -5
  577. transformers/models/mllama/configuration_mllama.py +5 -7
  578. transformers/models/mllama/image_processing_mllama_fast.py +6 -5
  579. transformers/models/mllama/modeling_mllama.py +19 -19
  580. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +10 -45
  581. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +66 -84
  582. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +10 -45
  583. transformers/models/mobilebert/configuration_mobilebert.py +4 -1
  584. transformers/models/mobilebert/modeling_mobilebert.py +3 -3
  585. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +4 -4
  586. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +4 -2
  587. transformers/models/mobilevit/image_processing_mobilevit_fast.py +4 -4
  588. transformers/models/mobilevit/modeling_mobilevit.py +4 -2
  589. transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -2
  590. transformers/models/modernbert/configuration_modernbert.py +46 -21
  591. transformers/models/modernbert/modeling_modernbert.py +146 -899
  592. transformers/models/modernbert/modular_modernbert.py +185 -908
  593. transformers/models/modernbert_decoder/configuration_modernbert_decoder.py +21 -13
  594. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -17
  595. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +24 -23
  596. transformers/models/moonshine/configuration_moonshine.py +12 -7
  597. transformers/models/moonshine/modeling_moonshine.py +7 -7
  598. transformers/models/moonshine/modular_moonshine.py +19 -13
  599. transformers/models/moshi/configuration_moshi.py +28 -2
  600. transformers/models/moshi/modeling_moshi.py +4 -9
  601. transformers/models/mpnet/configuration_mpnet.py +6 -1
  602. transformers/models/mpt/configuration_mpt.py +16 -0
  603. transformers/models/mra/configuration_mra.py +8 -1
  604. transformers/models/mt5/configuration_mt5.py +9 -5
  605. transformers/models/mt5/modeling_mt5.py +5 -8
  606. transformers/models/musicgen/configuration_musicgen.py +12 -7
  607. transformers/models/musicgen/modeling_musicgen.py +6 -5
  608. transformers/models/musicgen_melody/configuration_musicgen_melody.py +15 -7
  609. transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -17
  610. transformers/models/mvp/configuration_mvp.py +8 -4
  611. transformers/models/mvp/modeling_mvp.py +6 -4
  612. transformers/models/nanochat/configuration_nanochat.py +5 -7
  613. transformers/models/nanochat/modeling_nanochat.py +4 -4
  614. transformers/models/nanochat/modular_nanochat.py +4 -4
  615. transformers/models/nemotron/configuration_nemotron.py +5 -7
  616. transformers/models/nemotron/modeling_nemotron.py +4 -14
  617. transformers/models/nllb/tokenization_nllb.py +7 -5
  618. transformers/models/nllb_moe/configuration_nllb_moe.py +7 -9
  619. transformers/models/nllb_moe/modeling_nllb_moe.py +3 -3
  620. transformers/models/nougat/image_processing_nougat_fast.py +8 -8
  621. transformers/models/nystromformer/configuration_nystromformer.py +8 -1
  622. transformers/models/olmo/configuration_olmo.py +5 -7
  623. transformers/models/olmo/modeling_olmo.py +4 -4
  624. transformers/models/olmo/modular_olmo.py +3 -3
  625. transformers/models/olmo2/configuration_olmo2.py +9 -11
  626. transformers/models/olmo2/modeling_olmo2.py +4 -4
  627. transformers/models/olmo2/modular_olmo2.py +7 -7
  628. transformers/models/olmo3/configuration_olmo3.py +10 -11
  629. transformers/models/olmo3/modeling_olmo3.py +4 -4
  630. transformers/models/olmo3/modular_olmo3.py +13 -14
  631. transformers/models/olmoe/configuration_olmoe.py +5 -7
  632. transformers/models/olmoe/modeling_olmoe.py +4 -4
  633. transformers/models/olmoe/modular_olmoe.py +3 -3
  634. transformers/models/omdet_turbo/configuration_omdet_turbo.py +14 -49
  635. transformers/models/omdet_turbo/modeling_omdet_turbo.py +22 -18
  636. transformers/models/oneformer/configuration_oneformer.py +9 -46
  637. transformers/models/oneformer/image_processing_oneformer_fast.py +8 -8
  638. transformers/models/oneformer/modeling_oneformer.py +14 -9
  639. transformers/models/openai/configuration_openai.py +16 -0
  640. transformers/models/opt/configuration_opt.py +6 -6
  641. transformers/models/opt/modeling_opt.py +5 -5
  642. transformers/models/ovis2/configuration_ovis2.py +4 -0
  643. transformers/models/ovis2/image_processing_ovis2_fast.py +3 -3
  644. transformers/models/ovis2/modeling_ovis2.py +58 -99
  645. transformers/models/ovis2/modular_ovis2.py +52 -13
  646. transformers/models/owlv2/configuration_owlv2.py +4 -1
  647. transformers/models/owlv2/image_processing_owlv2_fast.py +5 -5
  648. transformers/models/owlv2/modeling_owlv2.py +40 -27
  649. transformers/models/owlv2/modular_owlv2.py +5 -5
  650. transformers/models/owlvit/configuration_owlvit.py +4 -1
  651. transformers/models/owlvit/modeling_owlvit.py +40 -27
  652. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +9 -10
  653. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +88 -87
  654. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +82 -53
  655. transformers/models/paligemma/configuration_paligemma.py +4 -0
  656. transformers/models/paligemma/modeling_paligemma.py +30 -26
  657. transformers/models/parakeet/configuration_parakeet.py +2 -4
  658. transformers/models/parakeet/modeling_parakeet.py +3 -3
  659. transformers/models/parakeet/modular_parakeet.py +3 -3
  660. transformers/models/patchtsmixer/modeling_patchtsmixer.py +3 -3
  661. transformers/models/patchtst/modeling_patchtst.py +3 -3
  662. transformers/models/pe_audio/modeling_pe_audio.py +4 -4
  663. transformers/models/pe_audio/modular_pe_audio.py +1 -1
  664. transformers/models/pe_audio_video/modeling_pe_audio_video.py +4 -4
  665. transformers/models/pe_audio_video/modular_pe_audio_video.py +4 -4
  666. transformers/models/pe_video/modeling_pe_video.py +36 -24
  667. transformers/models/pe_video/modular_pe_video.py +36 -23
  668. transformers/models/pegasus/configuration_pegasus.py +8 -5
  669. transformers/models/pegasus/modeling_pegasus.py +4 -4
  670. transformers/models/pegasus_x/configuration_pegasus_x.py +5 -3
  671. transformers/models/pegasus_x/modeling_pegasus_x.py +3 -3
  672. transformers/models/perceiver/image_processing_perceiver_fast.py +2 -2
  673. transformers/models/perceiver/modeling_perceiver.py +17 -9
  674. transformers/models/perception_lm/modeling_perception_lm.py +26 -27
  675. transformers/models/perception_lm/modular_perception_lm.py +27 -25
  676. transformers/models/persimmon/configuration_persimmon.py +5 -7
  677. transformers/models/persimmon/modeling_persimmon.py +5 -5
  678. transformers/models/phi/configuration_phi.py +8 -6
  679. transformers/models/phi/modeling_phi.py +4 -4
  680. transformers/models/phi/modular_phi.py +3 -3
  681. transformers/models/phi3/configuration_phi3.py +9 -11
  682. transformers/models/phi3/modeling_phi3.py +4 -4
  683. transformers/models/phi3/modular_phi3.py +3 -3
  684. transformers/models/phi4_multimodal/configuration_phi4_multimodal.py +11 -13
  685. transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py +4 -4
  686. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +46 -61
  687. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +44 -30
  688. transformers/models/phimoe/configuration_phimoe.py +5 -7
  689. transformers/models/phimoe/modeling_phimoe.py +15 -39
  690. transformers/models/phimoe/modular_phimoe.py +12 -7
  691. transformers/models/pix2struct/configuration_pix2struct.py +12 -9
  692. transformers/models/pix2struct/image_processing_pix2struct_fast.py +5 -5
  693. transformers/models/pix2struct/modeling_pix2struct.py +14 -7
  694. transformers/models/pixio/configuration_pixio.py +2 -4
  695. transformers/models/pixio/modeling_pixio.py +9 -8
  696. transformers/models/pixio/modular_pixio.py +4 -2
  697. transformers/models/pixtral/image_processing_pixtral_fast.py +5 -5
  698. transformers/models/pixtral/modeling_pixtral.py +9 -12
  699. transformers/models/plbart/configuration_plbart.py +8 -5
  700. transformers/models/plbart/modeling_plbart.py +9 -7
  701. transformers/models/plbart/modular_plbart.py +1 -1
  702. transformers/models/poolformer/image_processing_poolformer_fast.py +7 -7
  703. transformers/models/pop2piano/configuration_pop2piano.py +7 -6
  704. transformers/models/pop2piano/modeling_pop2piano.py +2 -1
  705. transformers/models/pp_doclayout_v3/__init__.py +30 -0
  706. transformers/models/pp_doclayout_v3/configuration_pp_doclayout_v3.py +277 -0
  707. transformers/models/pp_doclayout_v3/image_processing_pp_doclayout_v3_fast.py +305 -0
  708. transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py +2083 -0
  709. transformers/models/pp_doclayout_v3/modular_pp_doclayout_v3.py +1549 -0
  710. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +12 -46
  711. transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything_fast.py +6 -6
  712. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +8 -6
  713. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +12 -10
  714. transformers/models/prophetnet/configuration_prophetnet.py +11 -10
  715. transformers/models/prophetnet/modeling_prophetnet.py +12 -23
  716. transformers/models/pvt/image_processing_pvt.py +7 -7
  717. transformers/models/pvt/image_processing_pvt_fast.py +1 -1
  718. transformers/models/pvt_v2/configuration_pvt_v2.py +2 -4
  719. transformers/models/pvt_v2/modeling_pvt_v2.py +6 -5
  720. transformers/models/qwen2/configuration_qwen2.py +14 -4
  721. transformers/models/qwen2/modeling_qwen2.py +4 -4
  722. transformers/models/qwen2/modular_qwen2.py +3 -3
  723. transformers/models/qwen2/tokenization_qwen2.py +0 -4
  724. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +17 -5
  725. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +108 -88
  726. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +115 -87
  727. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +7 -10
  728. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +98 -53
  729. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +18 -6
  730. transformers/models/qwen2_audio/modeling_qwen2_audio.py +12 -12
  731. transformers/models/qwen2_moe/configuration_qwen2_moe.py +14 -4
  732. transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
  733. transformers/models/qwen2_moe/modular_qwen2_moe.py +3 -3
  734. transformers/models/qwen2_vl/configuration_qwen2_vl.py +7 -10
  735. transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py +4 -6
  736. transformers/models/qwen2_vl/modeling_qwen2_vl.py +97 -53
  737. transformers/models/qwen2_vl/video_processing_qwen2_vl.py +4 -6
  738. transformers/models/qwen3/configuration_qwen3.py +15 -5
  739. transformers/models/qwen3/modeling_qwen3.py +4 -4
  740. transformers/models/qwen3/modular_qwen3.py +3 -3
  741. transformers/models/qwen3_moe/configuration_qwen3_moe.py +20 -7
  742. transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
  743. transformers/models/qwen3_next/configuration_qwen3_next.py +16 -4
  744. transformers/models/qwen3_next/modeling_qwen3_next.py +5 -5
  745. transformers/models/qwen3_next/modular_qwen3_next.py +4 -4
  746. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +55 -19
  747. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +161 -98
  748. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +107 -34
  749. transformers/models/qwen3_vl/configuration_qwen3_vl.py +7 -6
  750. transformers/models/qwen3_vl/modeling_qwen3_vl.py +115 -49
  751. transformers/models/qwen3_vl/modular_qwen3_vl.py +88 -37
  752. transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py +7 -6
  753. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +173 -99
  754. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +23 -7
  755. transformers/models/rag/configuration_rag.py +6 -6
  756. transformers/models/rag/modeling_rag.py +3 -3
  757. transformers/models/rag/retrieval_rag.py +1 -1
  758. transformers/models/recurrent_gemma/configuration_recurrent_gemma.py +8 -6
  759. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +4 -5
  760. transformers/models/reformer/configuration_reformer.py +7 -7
  761. transformers/models/rembert/configuration_rembert.py +8 -1
  762. transformers/models/rembert/modeling_rembert.py +0 -22
  763. transformers/models/resnet/configuration_resnet.py +2 -4
  764. transformers/models/resnet/modeling_resnet.py +6 -5
  765. transformers/models/roberta/configuration_roberta.py +11 -2
  766. transformers/models/roberta/modeling_roberta.py +6 -6
  767. transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py +11 -2
  768. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +6 -6
  769. transformers/models/roc_bert/configuration_roc_bert.py +8 -1
  770. transformers/models/roc_bert/modeling_roc_bert.py +6 -41
  771. transformers/models/roformer/configuration_roformer.py +13 -2
  772. transformers/models/roformer/modeling_roformer.py +0 -14
  773. transformers/models/rt_detr/configuration_rt_detr.py +8 -49
  774. transformers/models/rt_detr/configuration_rt_detr_resnet.py +2 -4
  775. transformers/models/rt_detr/image_processing_rt_detr_fast.py +24 -11
  776. transformers/models/rt_detr/modeling_rt_detr.py +578 -737
  777. transformers/models/rt_detr/modeling_rt_detr_resnet.py +2 -3
  778. transformers/models/rt_detr/modular_rt_detr.py +1508 -6
  779. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +12 -57
  780. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +318 -453
  781. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +25 -66
  782. transformers/models/rwkv/configuration_rwkv.py +2 -3
  783. transformers/models/rwkv/modeling_rwkv.py +0 -23
  784. transformers/models/sam/configuration_sam.py +2 -0
  785. transformers/models/sam/image_processing_sam_fast.py +4 -4
  786. transformers/models/sam/modeling_sam.py +13 -8
  787. transformers/models/sam/processing_sam.py +3 -3
  788. transformers/models/sam2/configuration_sam2.py +1 -1
  789. transformers/models/sam2/modeling_sam2.py +56 -52
  790. transformers/models/sam2/modular_sam2.py +47 -55
  791. transformers/models/sam2_video/modeling_sam2_video.py +50 -51
  792. transformers/models/sam2_video/modular_sam2_video.py +12 -10
  793. transformers/models/sam3/modeling_sam3.py +43 -47
  794. transformers/models/sam3/processing_sam3.py +8 -4
  795. transformers/models/sam3_tracker/configuration_sam3_tracker.py +1 -2
  796. transformers/models/sam3_tracker/modeling_sam3_tracker.py +50 -49
  797. transformers/models/sam3_tracker/modular_sam3_tracker.py +0 -1
  798. transformers/models/sam3_tracker/processing_sam3_tracker.py +0 -1
  799. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +50 -49
  800. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +10 -22
  801. transformers/models/sam3_video/modeling_sam3_video.py +27 -14
  802. transformers/models/sam_hq/configuration_sam_hq.py +2 -0
  803. transformers/models/sam_hq/modeling_sam_hq.py +13 -9
  804. transformers/models/sam_hq/modular_sam_hq.py +6 -6
  805. transformers/models/sam_hq/processing_sam_hq.py +7 -6
  806. transformers/models/seamless_m4t/configuration_seamless_m4t.py +8 -9
  807. transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py +8 -9
  808. transformers/models/seed_oss/configuration_seed_oss.py +7 -9
  809. transformers/models/seed_oss/modeling_seed_oss.py +4 -4
  810. transformers/models/seed_oss/modular_seed_oss.py +3 -3
  811. transformers/models/segformer/image_processing_segformer_fast.py +4 -4
  812. transformers/models/segformer/modeling_segformer.py +4 -2
  813. transformers/models/segformer/modular_segformer.py +3 -3
  814. transformers/models/seggpt/modeling_seggpt.py +20 -8
  815. transformers/models/sew/configuration_sew.py +4 -1
  816. transformers/models/sew/modeling_sew.py +9 -5
  817. transformers/models/sew/modular_sew.py +2 -1
  818. transformers/models/sew_d/configuration_sew_d.py +4 -1
  819. transformers/models/sew_d/modeling_sew_d.py +4 -1
  820. transformers/models/shieldgemma2/modeling_shieldgemma2.py +4 -4
  821. transformers/models/siglip/configuration_siglip.py +4 -1
  822. transformers/models/siglip/modeling_siglip.py +27 -71
  823. transformers/models/siglip2/__init__.py +1 -0
  824. transformers/models/siglip2/configuration_siglip2.py +4 -2
  825. transformers/models/siglip2/image_processing_siglip2_fast.py +2 -2
  826. transformers/models/siglip2/modeling_siglip2.py +37 -78
  827. transformers/models/siglip2/modular_siglip2.py +74 -25
  828. transformers/models/siglip2/tokenization_siglip2.py +95 -0
  829. transformers/models/smollm3/configuration_smollm3.py +6 -6
  830. transformers/models/smollm3/modeling_smollm3.py +4 -4
  831. transformers/models/smollm3/modular_smollm3.py +9 -9
  832. transformers/models/smolvlm/configuration_smolvlm.py +1 -3
  833. transformers/models/smolvlm/image_processing_smolvlm_fast.py +29 -3
  834. transformers/models/smolvlm/modeling_smolvlm.py +75 -46
  835. transformers/models/smolvlm/modular_smolvlm.py +36 -23
  836. transformers/models/smolvlm/video_processing_smolvlm.py +9 -9
  837. transformers/models/solar_open/__init__.py +27 -0
  838. transformers/models/solar_open/configuration_solar_open.py +184 -0
  839. transformers/models/solar_open/modeling_solar_open.py +642 -0
  840. transformers/models/solar_open/modular_solar_open.py +224 -0
  841. transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +6 -4
  842. transformers/models/speech_to_text/configuration_speech_to_text.py +9 -8
  843. transformers/models/speech_to_text/modeling_speech_to_text.py +3 -3
  844. transformers/models/speecht5/configuration_speecht5.py +7 -8
  845. transformers/models/splinter/configuration_splinter.py +6 -6
  846. transformers/models/splinter/modeling_splinter.py +8 -3
  847. transformers/models/squeezebert/configuration_squeezebert.py +14 -1
  848. transformers/models/stablelm/configuration_stablelm.py +8 -6
  849. transformers/models/stablelm/modeling_stablelm.py +5 -5
  850. transformers/models/starcoder2/configuration_starcoder2.py +11 -5
  851. transformers/models/starcoder2/modeling_starcoder2.py +5 -5
  852. transformers/models/starcoder2/modular_starcoder2.py +4 -4
  853. transformers/models/superglue/configuration_superglue.py +4 -0
  854. transformers/models/superglue/image_processing_superglue_fast.py +4 -3
  855. transformers/models/superglue/modeling_superglue.py +9 -4
  856. transformers/models/superpoint/image_processing_superpoint_fast.py +3 -4
  857. transformers/models/superpoint/modeling_superpoint.py +4 -2
  858. transformers/models/swin/configuration_swin.py +2 -4
  859. transformers/models/swin/modeling_swin.py +11 -8
  860. transformers/models/swin2sr/image_processing_swin2sr_fast.py +2 -2
  861. transformers/models/swin2sr/modeling_swin2sr.py +4 -2
  862. transformers/models/swinv2/configuration_swinv2.py +2 -4
  863. transformers/models/swinv2/modeling_swinv2.py +10 -7
  864. transformers/models/switch_transformers/configuration_switch_transformers.py +11 -6
  865. transformers/models/switch_transformers/modeling_switch_transformers.py +3 -3
  866. transformers/models/switch_transformers/modular_switch_transformers.py +3 -3
  867. transformers/models/t5/configuration_t5.py +9 -8
  868. transformers/models/t5/modeling_t5.py +5 -8
  869. transformers/models/t5gemma/configuration_t5gemma.py +10 -25
  870. transformers/models/t5gemma/modeling_t5gemma.py +9 -9
  871. transformers/models/t5gemma/modular_t5gemma.py +11 -24
  872. transformers/models/t5gemma2/configuration_t5gemma2.py +35 -48
  873. transformers/models/t5gemma2/modeling_t5gemma2.py +143 -100
  874. transformers/models/t5gemma2/modular_t5gemma2.py +152 -136
  875. transformers/models/table_transformer/configuration_table_transformer.py +18 -49
  876. transformers/models/table_transformer/modeling_table_transformer.py +27 -53
  877. transformers/models/tapas/configuration_tapas.py +12 -1
  878. transformers/models/tapas/modeling_tapas.py +1 -1
  879. transformers/models/tapas/tokenization_tapas.py +1 -0
  880. transformers/models/textnet/configuration_textnet.py +4 -6
  881. transformers/models/textnet/image_processing_textnet_fast.py +3 -3
  882. transformers/models/textnet/modeling_textnet.py +15 -14
  883. transformers/models/time_series_transformer/modeling_time_series_transformer.py +3 -3
  884. transformers/models/timesfm/modeling_timesfm.py +5 -6
  885. transformers/models/timesfm/modular_timesfm.py +5 -6
  886. transformers/models/timm_backbone/configuration_timm_backbone.py +33 -7
  887. transformers/models/timm_backbone/modeling_timm_backbone.py +21 -24
  888. transformers/models/timm_wrapper/modeling_timm_wrapper.py +9 -4
  889. transformers/models/trocr/configuration_trocr.py +11 -7
  890. transformers/models/trocr/modeling_trocr.py +4 -2
  891. transformers/models/tvp/configuration_tvp.py +10 -35
  892. transformers/models/tvp/image_processing_tvp_fast.py +6 -5
  893. transformers/models/tvp/modeling_tvp.py +1 -1
  894. transformers/models/udop/configuration_udop.py +16 -7
  895. transformers/models/udop/modeling_udop.py +10 -6
  896. transformers/models/umt5/configuration_umt5.py +8 -6
  897. transformers/models/umt5/modeling_umt5.py +7 -3
  898. transformers/models/unispeech/configuration_unispeech.py +4 -1
  899. transformers/models/unispeech/modeling_unispeech.py +7 -4
  900. transformers/models/unispeech_sat/configuration_unispeech_sat.py +4 -1
  901. transformers/models/unispeech_sat/modeling_unispeech_sat.py +7 -4
  902. transformers/models/upernet/configuration_upernet.py +8 -35
  903. transformers/models/upernet/modeling_upernet.py +1 -1
  904. transformers/models/vaultgemma/configuration_vaultgemma.py +5 -7
  905. transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
  906. transformers/models/video_llama_3/configuration_video_llama_3.py +4 -0
  907. transformers/models/video_llama_3/image_processing_video_llama_3_fast.py +4 -6
  908. transformers/models/video_llama_3/modeling_video_llama_3.py +85 -48
  909. transformers/models/video_llama_3/modular_video_llama_3.py +56 -43
  910. transformers/models/video_llama_3/video_processing_video_llama_3.py +29 -8
  911. transformers/models/video_llava/configuration_video_llava.py +4 -0
  912. transformers/models/video_llava/modeling_video_llava.py +87 -89
  913. transformers/models/videomae/modeling_videomae.py +4 -5
  914. transformers/models/vilt/configuration_vilt.py +4 -1
  915. transformers/models/vilt/image_processing_vilt_fast.py +6 -6
  916. transformers/models/vilt/modeling_vilt.py +27 -12
  917. transformers/models/vipllava/configuration_vipllava.py +4 -0
  918. transformers/models/vipllava/modeling_vipllava.py +57 -31
  919. transformers/models/vipllava/modular_vipllava.py +50 -24
  920. transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +10 -6
  921. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +27 -20
  922. transformers/models/visual_bert/configuration_visual_bert.py +6 -1
  923. transformers/models/vit/configuration_vit.py +2 -2
  924. transformers/models/vit/modeling_vit.py +7 -5
  925. transformers/models/vit_mae/modeling_vit_mae.py +11 -7
  926. transformers/models/vit_msn/modeling_vit_msn.py +11 -7
  927. transformers/models/vitdet/configuration_vitdet.py +2 -4
  928. transformers/models/vitdet/modeling_vitdet.py +2 -3
  929. transformers/models/vitmatte/configuration_vitmatte.py +6 -35
  930. transformers/models/vitmatte/image_processing_vitmatte_fast.py +2 -2
  931. transformers/models/vitmatte/modeling_vitmatte.py +1 -1
  932. transformers/models/vitpose/configuration_vitpose.py +6 -43
  933. transformers/models/vitpose/modeling_vitpose.py +5 -3
  934. transformers/models/vitpose_backbone/configuration_vitpose_backbone.py +2 -4
  935. transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +5 -6
  936. transformers/models/vits/configuration_vits.py +4 -0
  937. transformers/models/vits/modeling_vits.py +9 -7
  938. transformers/models/vivit/modeling_vivit.py +4 -4
  939. transformers/models/vjepa2/modeling_vjepa2.py +9 -9
  940. transformers/models/voxtral/configuration_voxtral.py +0 -1
  941. transformers/models/voxtral/modeling_voxtral.py +25 -24
  942. transformers/models/voxtral/modular_voxtral.py +26 -20
  943. transformers/models/wav2vec2/configuration_wav2vec2.py +4 -1
  944. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -4
  945. transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py +4 -1
  946. transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py +4 -1
  947. transformers/models/wavlm/configuration_wavlm.py +4 -1
  948. transformers/models/wavlm/modeling_wavlm.py +4 -1
  949. transformers/models/whisper/configuration_whisper.py +6 -4
  950. transformers/models/whisper/generation_whisper.py +0 -1
  951. transformers/models/whisper/modeling_whisper.py +3 -3
  952. transformers/models/x_clip/configuration_x_clip.py +4 -1
  953. transformers/models/x_clip/modeling_x_clip.py +26 -27
  954. transformers/models/xglm/configuration_xglm.py +9 -7
  955. transformers/models/xlm/configuration_xlm.py +10 -7
  956. transformers/models/xlm/modeling_xlm.py +1 -1
  957. transformers/models/xlm_roberta/configuration_xlm_roberta.py +11 -2
  958. transformers/models/xlm_roberta/modeling_xlm_roberta.py +6 -6
  959. transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py +10 -1
  960. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +6 -6
  961. transformers/models/xlnet/configuration_xlnet.py +3 -1
  962. transformers/models/xlstm/configuration_xlstm.py +5 -7
  963. transformers/models/xlstm/modeling_xlstm.py +0 -32
  964. transformers/models/xmod/configuration_xmod.py +11 -2
  965. transformers/models/xmod/modeling_xmod.py +13 -16
  966. transformers/models/yolos/image_processing_yolos_fast.py +25 -28
  967. transformers/models/yolos/modeling_yolos.py +7 -7
  968. transformers/models/yolos/modular_yolos.py +16 -16
  969. transformers/models/yoso/configuration_yoso.py +8 -1
  970. transformers/models/youtu/__init__.py +27 -0
  971. transformers/models/youtu/configuration_youtu.py +194 -0
  972. transformers/models/youtu/modeling_youtu.py +619 -0
  973. transformers/models/youtu/modular_youtu.py +254 -0
  974. transformers/models/zamba/configuration_zamba.py +5 -7
  975. transformers/models/zamba/modeling_zamba.py +25 -56
  976. transformers/models/zamba2/configuration_zamba2.py +8 -13
  977. transformers/models/zamba2/modeling_zamba2.py +53 -78
  978. transformers/models/zamba2/modular_zamba2.py +36 -29
  979. transformers/models/zoedepth/configuration_zoedepth.py +17 -40
  980. transformers/models/zoedepth/image_processing_zoedepth_fast.py +9 -9
  981. transformers/models/zoedepth/modeling_zoedepth.py +5 -3
  982. transformers/pipelines/__init__.py +1 -61
  983. transformers/pipelines/any_to_any.py +1 -1
  984. transformers/pipelines/automatic_speech_recognition.py +0 -2
  985. transformers/pipelines/base.py +1 -1
  986. transformers/pipelines/image_text_to_text.py +1 -1
  987. transformers/pipelines/text_to_audio.py +5 -1
  988. transformers/processing_utils.py +35 -44
  989. transformers/pytorch_utils.py +2 -26
  990. transformers/quantizers/quantizer_compressed_tensors.py +7 -5
  991. transformers/quantizers/quantizer_fbgemm_fp8.py +20 -23
  992. transformers/quantizers/quantizer_finegrained_fp8.py +14 -20
  993. transformers/quantizers/quantizer_mxfp4.py +1 -1
  994. transformers/quantizers/quantizer_torchao.py +0 -16
  995. transformers/safetensors_conversion.py +11 -4
  996. transformers/testing_utils.py +3 -28
  997. transformers/tokenization_mistral_common.py +9 -0
  998. transformers/tokenization_python.py +6 -4
  999. transformers/tokenization_utils_base.py +119 -219
  1000. transformers/tokenization_utils_tokenizers.py +31 -2
  1001. transformers/trainer.py +25 -33
  1002. transformers/trainer_seq2seq.py +1 -1
  1003. transformers/training_args.py +411 -417
  1004. transformers/utils/__init__.py +1 -4
  1005. transformers/utils/auto_docstring.py +15 -18
  1006. transformers/utils/backbone_utils.py +13 -373
  1007. transformers/utils/doc.py +4 -36
  1008. transformers/utils/generic.py +69 -33
  1009. transformers/utils/import_utils.py +72 -75
  1010. transformers/utils/loading_report.py +133 -105
  1011. transformers/utils/quantization_config.py +0 -21
  1012. transformers/video_processing_utils.py +5 -5
  1013. transformers/video_utils.py +3 -1
  1014. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/METADATA +118 -237
  1015. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/RECORD +1019 -994
  1016. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/WHEEL +1 -1
  1017. transformers/pipelines/deprecated/text2text_generation.py +0 -408
  1018. transformers/pipelines/image_to_text.py +0 -189
  1019. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/entry_points.txt +0 -0
  1020. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/licenses/LICENSE +0 -0
  1021. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1376 @@
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/eomt_dinov3/modular_eomt_dinov3.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change is needed, please apply it to the
5
+ # modular_eomt_dinov3.py file directly. One of our CI checks enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ # Copyright 2026 the HuggingFace Team. All rights reserved.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+
21
+ import math
22
+ from collections.abc import Callable
23
+ from dataclasses import dataclass
24
+ from typing import Optional
25
+
26
+ import numpy as np
27
+ import torch
28
+ import torch.nn.functional as F
29
+ from torch import Tensor, nn
30
+
31
+ from ... import initialization as init
32
+ from ...activations import ACT2FN
33
+ from ...file_utils import ModelOutput, is_scipy_available, requires_backends
34
+ from ...modeling_layers import GradientCheckpointingLayer
35
+ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
36
+ from ...processing_utils import Unpack
37
+ from ...pytorch_utils import compile_compatible_method_lru_cache
38
+ from ...utils import TransformersKwargs, auto_docstring, is_accelerate_available
39
+ from ...utils.generic import check_model_inputs, maybe_autocast
40
+ from .configuration_eomt_dinov3 import EomtDinov3Config
41
+
42
+
43
+ if is_scipy_available():
44
+ from scipy.optimize import linear_sum_assignment
45
+
46
+ if is_accelerate_available():
47
+ from accelerate import PartialState
48
+ from accelerate.utils import reduce
49
+
50
+
51
+ def rotate_half(x):
52
+ """Rotates half the hidden dims of the input."""
53
+ x1 = x[..., : x.shape[-1] // 2]
54
+ x2 = x[..., x.shape[-1] // 2 :]
55
+ return torch.cat((-x2, x1), dim=-1)
56
+
57
+
58
+ def eager_attention_forward(
59
+ module: nn.Module,
60
+ query: torch.Tensor,
61
+ key: torch.Tensor,
62
+ value: torch.Tensor,
63
+ attention_mask: torch.Tensor | None,
64
+ scaling: float | None = None,
65
+ dropout: float = 0.0,
66
+ **kwargs: Unpack[TransformersKwargs],
67
+ ):
68
+ if scaling is None:
69
+ scaling = query.size(-1) ** -0.5
70
+
71
+ # Take the dot product between "query" and "key" to get the raw attention scores.
72
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
73
+
74
+ if attention_mask is not None:
75
+ attention_mask = attention_mask[:, :, :, : key.shape[-2]]
76
+ attn_weights = attn_weights + attention_mask
77
+
78
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
79
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
80
+
81
+ attn_output = torch.matmul(attn_weights, value)
82
+ attn_output = attn_output.transpose(1, 2).contiguous()
83
+
84
+ return attn_output, attn_weights
85
+
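+ # Editor's sketch (hypothetical helper, not part of the generated file): a minimal shape
+ # walk-through of `eager_attention_forward`. Inputs are (batch, num_heads, seq_len, head_dim);
+ # the output comes back transposed to (batch, seq_len, num_heads, head_dim), ready to be
+ # reshaped and fed to the output projection.
+ def _editor_example_eager_attention():
+     query = key = value = torch.randn(1, 2, 5, 4)  # (batch, heads, seq, head_dim)
+     module = nn.Module()  # only `.training` is read, to gate dropout
+     attn_output, attn_weights = eager_attention_forward(module, query, key, value, attention_mask=None)
+     assert attn_output.shape == (1, 5, 2, 4)
+     assert attn_weights.shape == (1, 2, 5, 5)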
86
+
87
+ def apply_rotary_pos_emb(
88
+ q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, **kwargs
89
+ ) -> tuple[torch.Tensor, torch.Tensor]:
90
+ """Applies Rotary Position Embedding to the query and key tensors, but only to the patch tokens,
91
+ ignoring the prefix tokens (cls token and register tokens).
92
+
93
+ Args:
94
+ q (`torch.Tensor`): The query tensor.
95
+ k (`torch.Tensor`): The key tensor.
96
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
97
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
98
+
99
+ Returns:
100
+ `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
101
+ """
102
+
103
+ num_tokens = q.shape[-2]
104
+ num_patches = sin.shape[-2]
105
+ num_prefix_tokens = num_tokens - num_patches # cls token + register tokens
106
+
107
+ q_prefix_tokens, q_patches = q.split((num_prefix_tokens, num_patches), dim=-2)
108
+ k_prefix_tokens, k_patches = k.split((num_prefix_tokens, num_patches), dim=-2)
109
+
110
+ # apply rope only to patch tokens
111
+ q_patches = (q_patches * cos) + (rotate_half(q_patches) * sin)
112
+ k_patches = (k_patches * cos) + (rotate_half(k_patches) * sin)
113
+
114
+ q = torch.cat((q_prefix_tokens, q_patches), dim=-2)
115
+ k = torch.cat((k_prefix_tokens, k_patches), dim=-2)
116
+
117
+ return q, k
118
+
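+ # Editor's sketch (hypothetical helper, not part of the generated file): checks that
+ # `apply_rotary_pos_emb` leaves the leading prefix tokens (cls + registers) untouched and
+ # only rotates the trailing patch tokens.
+ def _editor_example_rope_skips_prefix_tokens():
+     num_prefix, num_patches, head_dim = 2, 4, 8
+     q = k = torch.randn(1, 1, num_prefix + num_patches, head_dim)
+     angles = torch.randn(num_patches, head_dim)  # assumed per-patch angles
+     q_rot, _ = apply_rotary_pos_emb(q, k, angles.cos(), angles.sin())
+     assert torch.equal(q_rot[..., :num_prefix, :], q[..., :num_prefix, :])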
119
+
120
+ class EomtDinov3Attention(nn.Module):
121
+ """
122
+ Multi-headed attention compatible with ALL_ATTENTION_FUNCTIONS.
123
+ """
124
+
125
+ def __init__(self, config: EomtDinov3Config):
126
+ super().__init__()
127
+ self.config = config
128
+ self.embed_dim = config.hidden_size
129
+ self.num_heads = config.num_attention_heads
130
+ self.head_dim = self.embed_dim // self.num_heads
131
+ self.is_causal = False
132
+
133
+ self.scaling = self.head_dim**-0.5
135
+
136
+ self.dropout = config.attention_dropout
137
+ self.k_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.key_bias)
138
+ self.v_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.value_bias)
139
+
140
+ self.q_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.query_bias)
141
+ self.o_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=config.proj_bias)
142
+
143
+ def forward(
144
+ self,
145
+ hidden_states: torch.Tensor,
146
+ attention_mask: torch.Tensor | None = None,
147
+ position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
148
+ **kwargs: Unpack[TransformersKwargs],
149
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
150
+ """Input shape: Batch x Time x Channel"""
151
+
152
+ batch_size, patches, _ = hidden_states.size()
153
+
154
+ query_states = self.q_proj(hidden_states)
155
+ key_states = self.k_proj(hidden_states)
156
+ value_states = self.v_proj(hidden_states)
157
+
158
+ query_states = query_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
159
+ key_states = key_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
160
+ value_states = value_states.view(batch_size, patches, self.num_heads, self.head_dim).transpose(1, 2)
161
+
162
+ cos, sin = position_embeddings
163
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
164
+
165
+ attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
166
+ self.config._attn_implementation, eager_attention_forward
167
+ )
168
+
169
+ attn_output, attn_weights = attention_interface(
170
+ self,
171
+ query_states,
172
+ key_states,
173
+ value_states,
174
+ attention_mask,
175
+ dropout=0.0 if not self.training else self.dropout,
176
+ scaling=self.scaling,
177
+ **kwargs,
178
+ )
179
+
180
+ attn_output = attn_output.reshape(batch_size, patches, -1).contiguous()
181
+ attn_output = self.o_proj(attn_output)
182
+
183
+ return attn_output, attn_weights
184
+
185
+
186
+ class EomtDinov3Embeddings(nn.Module):
187
+ """
188
+ Construct the CLS token, mask token, register tokens, and patch embeddings.
189
+ """
190
+
191
+ def __init__(self, config: EomtDinov3Config):
192
+ super().__init__()
193
+ self.config = config
194
+ self.cls_token = nn.Parameter(torch.randn(1, 1, config.hidden_size))
195
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, config.hidden_size))
196
+ self.register_tokens = nn.Parameter(torch.empty(1, config.num_register_tokens, config.hidden_size))
197
+ self.patch_embeddings = nn.Conv2d(
198
+ config.num_channels, config.hidden_size, kernel_size=config.patch_size, stride=config.patch_size
199
+ )
200
+ self.num_prefix_tokens = 1 + config.num_register_tokens
201
+
202
+ def forward(self, pixel_values: torch.Tensor, bool_masked_pos: torch.Tensor | None = None) -> torch.Tensor:
203
+ batch_size = pixel_values.shape[0]
204
+ target_dtype = self.patch_embeddings.weight.dtype
205
+
206
+ # (batch_size, num_channels, height, width) -> (batch_size, num_patches, hidden_size)
207
+ patch_embeddings = self.patch_embeddings(pixel_values.to(dtype=target_dtype))
208
+ patch_embeddings = patch_embeddings.flatten(2).transpose(1, 2)
209
+
210
+ if bool_masked_pos is not None:
211
+ mask_token = self.mask_token.to(patch_embeddings.dtype)
212
+ patch_embeddings = torch.where(bool_masked_pos.unsqueeze(-1), mask_token, patch_embeddings)
213
+
214
+ # Add CLS and register tokens
215
+ cls_token = self.cls_token.expand(batch_size, -1, -1)
216
+ register_tokens = self.register_tokens.expand(batch_size, -1, -1)
217
+ embeddings = torch.cat([cls_token, register_tokens, patch_embeddings], dim=1)
218
+
219
+ return embeddings
220
+
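+ # Editor's note: a worked sequence-length count (hypothetical helper with assumed numbers,
+ # not part of the generated file). With patch_size=16, a 224x224 image yields
+ # (224 // 16) ** 2 = 196 patch tokens, preceded by 1 cls token and the register tokens.
+ def _editor_example_sequence_length(image_size=224, patch_size=16, num_register_tokens=4):
+     num_patches = (image_size // patch_size) ** 2
+     return 1 + num_register_tokens + num_patches  # 201 with these assumed values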
221
+
222
+ def drop_path(input: torch.Tensor, drop_prob: float = 0.0, training: bool = False) -> torch.Tensor:
223
+ """
224
+ Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
225
+
226
+ """
227
+ if drop_prob == 0.0 or not training:
228
+ return input
229
+ keep_prob = 1 - drop_prob
230
+ shape = (input.shape[0],) + (1,) * (input.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
231
+ random_tensor = keep_prob + torch.rand(shape, dtype=input.dtype, device=input.device)
232
+ random_tensor.floor_() # binarize
233
+ output = input.div(keep_prob) * random_tensor
234
+ return output
235
+
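+ # Editor's sketch (hypothetical helper, not part of the generated file): `drop_path` is the
+ # identity in eval mode, and in training each sample is either zeroed or rescaled by
+ # 1 / keep_prob so the expected value is preserved.
+ def _editor_example_drop_path():
+     x = torch.ones(8, 3)
+     assert torch.equal(drop_path(x, drop_prob=0.5, training=False), x)
+     out = drop_path(x, drop_prob=0.5, training=True)
+     assert set(out.unique().tolist()) <= {0.0, 2.0}  # dropped, or scaled by 1 / 0.5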
236
+
237
+ class EomtDinov3DropPath(nn.Module):
238
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
239
+
240
+ def __init__(self, drop_prob: float | None = None) -> None:
241
+ super().__init__()
242
+ self.drop_prob = drop_prob
243
+
244
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
245
+ return drop_path(hidden_states, self.drop_prob, self.training)
246
+
247
+ def extra_repr(self) -> str:
248
+ return f"p={self.drop_prob}"
249
+
250
+
251
+ class EomtDinov3MLP(nn.Module):
252
+ def __init__(self, config):
253
+ super().__init__()
254
+ self.config = config
255
+ self.hidden_size = config.hidden_size
256
+ self.intermediate_size = config.intermediate_size
257
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
258
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
259
+ self.act_fn = ACT2FN[config.hidden_act]
260
+
261
+ def forward(self, x):
262
+ return self.down_proj(self.act_fn(self.up_proj(x)))
263
+
264
+
265
+ class EomtDinov3GatedMLP(nn.Module):
266
+ def __init__(self, config):
267
+ super().__init__()
268
+ self.config = config
269
+ self.hidden_size = config.hidden_size
270
+ self.intermediate_size = config.intermediate_size
271
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
272
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
273
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
274
+ self.act_fn = ACT2FN[config.hidden_act]
275
+
276
+ def forward(self, x):
277
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
278
+ return down_proj
279
+
280
+
281
+ class EomtDinov3Layer(GradientCheckpointingLayer):
282
+ """This corresponds to the Block class in the original implementation."""
283
+
284
+ def __init__(self, config: EomtDinov3Config):
285
+ super().__init__()
286
+
287
+ self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
288
+ self.attention = EomtDinov3Attention(config)
289
+ self.layer_scale1 = EomtDinov3LayerScale(config)
290
+ self.drop_path = EomtDinov3DropPath(config.drop_path_rate) if config.drop_path_rate > 0.0 else nn.Identity()
291
+
292
+ self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
293
+
294
+ if config.use_gated_mlp:
295
+ self.mlp = EomtDinov3GatedMLP(config)
296
+ else:
297
+ self.mlp = EomtDinov3MLP(config)
298
+ self.layer_scale2 = EomtDinov3LayerScale(config)
299
+
300
+ def forward(
301
+ self,
302
+ hidden_states: torch.Tensor,
303
+ attention_mask: torch.Tensor | None = None,
304
+ position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
305
+ ) -> torch.Tensor:
306
+ # Attention with residual connection
307
+ residual = hidden_states
308
+ hidden_states = self.norm1(hidden_states)
309
+ hidden_states, _ = self.attention(
310
+ hidden_states,
311
+ attention_mask=attention_mask,
312
+ position_embeddings=position_embeddings,
313
+ )
314
+ hidden_states = self.layer_scale1(hidden_states)
315
+ hidden_states = self.drop_path(hidden_states) + residual
316
+
317
+ # MLP with residual connection
318
+ residual = hidden_states
319
+ hidden_states = self.norm2(hidden_states)
320
+ hidden_states = self.mlp(hidden_states)
321
+ hidden_states = self.layer_scale2(hidden_states)
322
+ hidden_states = self.drop_path(hidden_states) + residual
323
+
324
+ return hidden_states
325
+
326
+
327
+ class EomtDinov3LayerScale(nn.Module):
328
+ def __init__(self, config) -> None:
329
+ super().__init__()
330
+ self.lambda1 = nn.Parameter(config.layerscale_value * torch.ones(config.hidden_size))
331
+
332
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
333
+ return hidden_state * self.lambda1
334
+
335
+
336
+ @compile_compatible_method_lru_cache(maxsize=32)
337
+ def get_patches_center_coordinates(
338
+ num_patches_h: int, num_patches_w: int, dtype: torch.dtype, device: torch.device
339
+ ) -> torch.Tensor:
340
+ """
341
+ Computes the 2D coordinates of the centers of image patches, normalized to the range [-1, +1].
342
+ The center of each patch is exactly halfway between its top-left and bottom-right corners.
343
+
344
+ Args:
345
+ num_patches_h (int): Number of patches along the vertical (height) axis.
346
+ num_patches_w (int): Number of patches along the horizontal (width) axis.
347
+ dtype (torch.dtype): The desired data type of the returned tensor.
+ device (torch.device): The device on which to create the tensor.
348
+
349
+ Returns:
350
+ torch.Tensor: A tensor of shape (height * width, 2), where each row contains the (y, x)
351
+ coordinates of a patch center, normalized to [-1, +1].
352
+ """
353
+ coords_h = torch.arange(0.5, num_patches_h, dtype=dtype, device=device)
354
+ coords_w = torch.arange(0.5, num_patches_w, dtype=dtype, device=device)
355
+ coords_h = coords_h / num_patches_h
356
+ coords_w = coords_w / num_patches_w
357
+ # (height, width, 2) -> (height * width, 2)
358
+ coords = torch.stack(torch.meshgrid(coords_h, coords_w, indexing="ij"), dim=-1)
359
+ coords = coords.flatten(0, 1)
360
+ # Shift range [0, 1] to [-1, +1]
361
+ coords = 2.0 * coords - 1.0
362
+ return coords
363
+
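+ # Editor's sketch (hypothetical helper, not part of the generated file): for a 2x2 patch
+ # grid, the centers sit at 0.25 and 0.75 of each axis, which map to +/-0.5 after shifting
+ # [0, 1] to [-1, +1].
+ def _editor_example_patch_centers():
+     coords = get_patches_center_coordinates(2, 2, dtype=torch.float32, device=torch.device("cpu"))
+     expected = torch.tensor([[-0.5, -0.5], [-0.5, 0.5], [0.5, -0.5], [0.5, 0.5]])
+     assert torch.allclose(coords, expected)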
364
+
365
+ def augment_patches_center_coordinates(
366
+ coords: torch.Tensor,
367
+ shift: float | None = None,
368
+ jitter: float | None = None,
369
+ rescale: float | None = None,
370
+ ) -> torch.Tensor:
371
+ # Shift coords by adding a uniform value in [-shift, shift]
372
+ if shift is not None:
373
+ shift_hw = torch.empty((1, 2), device=coords.device, dtype=coords.dtype)
374
+ shift_hw = shift_hw.uniform_(-shift, shift)
375
+ coords = coords + shift_hw
376
+
377
+ # Jitter coords by multiplying the range [-1, 1] by a log-uniform value in [1/jitter, jitter]
378
+ if jitter is not None:
379
+ jitter_range = np.log(jitter)
380
+ jitter_hw = torch.empty((1, 2), device=coords.device, dtype=coords.dtype)
381
+ jitter_hw = jitter_hw.uniform_(-jitter_range, jitter_range).exp()
382
+ coords = coords * jitter_hw
383
+
384
+ # Rescale coords by multiplying the range [-1, 1] by a log-uniform value in [1/rescale, rescale]
385
+ if rescale is not None:
386
+ rescale_range = np.log(rescale)
387
+ rescale_hw = torch.empty(1, device=coords.device, dtype=coords.dtype)
388
+ rescale_hw = rescale_hw.uniform_(-rescale_range, rescale_range).exp()
389
+ coords = coords * rescale_hw
390
+
391
+ return coords
392
+
393
+
394
+ class EomtDinov3RotaryEmbedding(nn.Module):
395
+ inv_freq: Tensor
396
+
397
+ def __init__(self, config: EomtDinov3Config, device=None):
398
+ super().__init__()
399
+ self.config = config
400
+
401
+ self.rope_type = self.config.rope_parameters["rope_type"]
402
+ rope_init_fn: Callable = self.compute_default_rope_parameters
403
+ if self.rope_type != "default":
404
+ raise ValueError("`EomtDinov3` only supports `default` RoPE! Please check your `rope_type`")
405
+ inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
406
+
407
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
408
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
409
+
410
+ def forward(self, pixel_values: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
411
+ _, _, height, width = pixel_values.shape
412
+ num_patches_h = height // self.config.patch_size
413
+ num_patches_w = width // self.config.patch_size
414
+
415
+ device = pixel_values.device
416
+ device_type = device.type if isinstance(device.type, str) and device.type != "mps" else "cpu"
417
+
418
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
419
+ # Although we could precompute static patch_coords from image_size and patch_size in the config,
420
+ # the model was trained with random_scale, so it can process images of varying sizes.
421
+ # Therefore, it's better to compute patch_coords dynamically (with lru_cache).
422
+ patch_coords = get_patches_center_coordinates(
423
+ num_patches_h, num_patches_w, dtype=torch.float32, device=device
424
+ )
425
+ if self.training:
426
+ patch_coords = augment_patches_center_coordinates(
427
+ patch_coords,
428
+ shift=self.config.pos_embed_shift,
429
+ jitter=self.config.pos_embed_jitter,
430
+ rescale=self.config.pos_embed_rescale,
431
+ )
432
+
433
+ # (height * width, 2, head_dim / 4) -> (height * width, head_dim / 2) -> (height * width, head_dim)
434
+ angles = 2 * math.pi * patch_coords[:, :, None] * self.inv_freq[None, None, :]
435
+ angles = angles.flatten(1, 2)
436
+ angles = angles.tile(2)
437
+
438
+ cos = torch.cos(angles)
439
+ sin = torch.sin(angles)
440
+
441
+ dtype = pixel_values.dtype
442
+ return cos.to(dtype=dtype), sin.to(dtype=dtype)
443
+
444
+ @staticmethod
445
+ def compute_default_rope_parameters(
446
+ config: EomtDinov3Config | None = None,
447
+ device: Optional["torch.device"] = None,
448
+ seq_len: int | None = None,
449
+ ) -> tuple[torch.Tensor, float]:
450
+ """
451
+ Computes the inverse frequencies according to the original RoPE implementation.
452
+ Args:
453
+ config ([`~transformers.PreTrainedConfig`]):
454
+ The model configuration.
455
+ device (`torch.device`):
456
+ The device to use for initialization of the inverse frequencies.
457
+ seq_len (`int`, *optional*):
458
+ The current sequence length. Unused for this type of RoPE.
459
+ Returns:
460
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
461
+ post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
462
+ """
463
+ base = config.rope_parameters["rope_theta"]
464
+ head_dim = config.hidden_size // config.num_attention_heads
465
+
466
+ attention_factor = 1.0 # Unused in this type of RoPE
467
+
468
+ # Compute the inverse frequencies
469
+ inv_freq = 1 / base ** torch.arange(0, 1, 4 / head_dim, dtype=torch.float32, device=device)
470
+ return inv_freq, attention_factor
471
+
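+ # Editor's sketch (hypothetical helper, not part of the generated file; assumes a
+ # rope_theta of 10000): with head_dim = 64, `torch.arange(0, 1, 4 / head_dim)` has
+ # head_dim / 4 = 16 entries, so the angle pipeline goes
+ # (num_patches, 2, 16) -> flatten -> (num_patches, 32) -> tile(2) -> (num_patches, 64).
+ def _editor_example_rope_angle_shapes():
+     head_dim, num_patches = 64, 9
+     inv_freq = 1 / 10000 ** torch.arange(0, 1, 4 / head_dim)
+     coords = torch.randn(num_patches, 2)
+     angles = (2 * math.pi * coords[:, :, None] * inv_freq[None, None, :]).flatten(1, 2).tile(2)
+     assert angles.shape == (num_patches, head_dim)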
472
+
473
+ # Adapted from https://github.com/facebookresearch/detectron2/blob/main/projects/PointRend/point_rend/point_features.py
474
+ def sample_point(
475
+ input_features: torch.Tensor, point_coordinates: torch.Tensor, add_dim=False, **kwargs
476
+ ) -> torch.Tensor:
477
+ """
478
+ A wrapper around `torch.nn.functional.grid_sample` to support 3D point_coordinates tensors.
479
+
480
+ Args:
481
+ input_features (`torch.Tensor` of shape (batch_size, channels, height, width)):
482
+ A tensor that contains features map on a height * width grid
483
+ point_coordinates (`torch.Tensor` of shape (batch_size, num_points, 2) or (batch_size, grid_height, grid_width, 2)):
485
+ A tensor that contains [0, 1] * [0, 1] normalized point coordinates
486
+ add_dim (`bool`):
487
+ boolean value to keep track of added dimension
488
+
489
+ Returns:
490
+ point_features (`torch.Tensor` of shape (batch_size, channels, num_points) or (batch_size, channels,
491
+ height_grid, width_grid):
492
+ A tensor that contains features for points in `point_coordinates`.
493
+ """
494
+ if point_coordinates.dim() == 3:
495
+ add_dim = True
496
+ point_coordinates = point_coordinates.unsqueeze(2)
497
+
498
+ # use nn.function.grid_sample to get features for points in `point_coordinates` via bilinear interpolation
499
+ point_features = torch.nn.functional.grid_sample(input_features, 2.0 * point_coordinates - 1.0, **kwargs)
500
+ if add_dim:
501
+ point_features = point_features.squeeze(3)
502
+
503
+ return point_features
504
+
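+ # Editor's sketch (hypothetical helper, not part of the generated file): 3D coordinates
+ # gain a dummy grid dimension for `grid_sample` and lose it again on the way out.
+ def _editor_example_sample_point():
+     features = torch.randn(2, 8, 16, 16)  # (batch, channels, height, width)
+     coordinates = torch.rand(2, 100, 2)   # (batch, num_points, 2), normalized to [0, 1]
+     sampled = sample_point(features, coordinates, align_corners=False)
+     assert sampled.shape == (2, 8, 100)   # (batch, channels, num_points)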
505
+
506
+ def pair_wise_dice_loss(inputs: Tensor, labels: Tensor) -> Tensor:
507
+ """
508
+ A pairwise version of the dice loss; see `dice_loss` for usage.
509
+
510
+ Args:
511
+ inputs (`torch.Tensor`):
512
+ A tensor representing a mask
513
+ labels (`torch.Tensor`):
514
+ A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs
515
+ (0 for the negative class and 1 for the positive class).
516
+
517
+ Returns:
518
+ `torch.Tensor`: The computed loss between each pair.
519
+ """
520
+ inputs = inputs.sigmoid().flatten(1)
521
+ numerator = 2 * torch.matmul(inputs, labels.T)
522
+ # using broadcasting to get a [num_queries, NUM_CLASSES] matrix
523
+ denominator = inputs.sum(-1)[:, None] + labels.sum(-1)[None, :]
524
+ loss = 1 - (numerator + 1) / (denominator + 1)
525
+ return loss
526
+
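+ # Editor's sketch (hypothetical helper, not part of the generated file): the pairwise loss
+ # returns one cost per (prediction, target) pair, i.e. exactly the cost-matrix layout that
+ # the Hungarian matcher below consumes.
+ def _editor_example_pair_wise_dice():
+     pred_logits = torch.randn(5, 128)                     # (num_queries, num_points)
+     target_masks = torch.randint(0, 2, (3, 128)).float()  # (num_targets, num_points), binary
+     cost = pair_wise_dice_loss(pred_logits, target_masks)
+     assert cost.shape == (5, 3)                           # (num_queries, num_targets)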
527
+
528
+ def pair_wise_sigmoid_cross_entropy_loss(inputs: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
529
+ r"""
530
+ A pairwise version of the cross entropy loss; see `sigmoid_cross_entropy_loss` for usage.
531
+
532
+ Args:
533
+ inputs (`torch.Tensor`):
534
+ A tensor representing a mask.
535
+ labels (`torch.Tensor`):
536
+ A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs
537
+ (0 for the negative class and 1 for the positive class).
538
+
539
+ Returns:
540
+ loss (`torch.Tensor`): The computed loss between each pair.
541
+ """
542
+
543
+ height_and_width = inputs.shape[1]
544
+
545
+ criterion = nn.BCEWithLogitsLoss(reduction="none")
546
+ cross_entropy_loss_pos = criterion(inputs, torch.ones_like(inputs))
547
+ cross_entropy_loss_neg = criterion(inputs, torch.zeros_like(inputs))
548
+
549
+ loss_pos = torch.matmul(cross_entropy_loss_pos / height_and_width, labels.T)
550
+ loss_neg = torch.matmul(cross_entropy_loss_neg / height_and_width, (1 - labels).T)
551
+ loss = loss_pos + loss_neg
552
+ return loss
553
+
554
+
555
+ # Adapted from https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/matcher.py
556
+ class EomtDinov3HungarianMatcher(nn.Module):
557
+ """This class computes an assignment between the labels and the predictions of the network.
558
+
559
+ For efficiency reasons, the labels don't include the no_object class. Because of this, in general, there are more
560
+ predictions than labels. In this case, we do a 1-to-1 matching of the best predictions, while the others are
561
+ un-matched (and thus treated as non-objects).
562
+ """
563
+
564
+ def __init__(
565
+ self, cost_class: float = 1.0, cost_mask: float = 1.0, cost_dice: float = 1.0, num_points: int = 12544
566
+ ):
567
+ """Creates the matcher
568
+
569
+ Params:
570
+ cost_class (`float`, *optional*, defaults to 1.0):
571
+ Relative weight of the classification error in the matching cost.
572
+ cost_mask (`float`, *optional*, defaults to 1.0):
573
+ This is the relative weight of the focal loss of the binary mask in the matching cost.
574
+ cost_dice (`float`, *optional*, defaults to 1.0):
575
+ This is the relative weight of the dice loss of the binary mask in the matching cost.
576
+ num_points (`int`, *optional*, defaults to 12544):
577
+ No. of points to sample on which the mask loss will be calculated. The same set of K points are
578
+ uniformly sampled for all prediction and ground truth masks to construct the cost matrix for bipartite
579
+ matching.
580
+ """
581
+ super().__init__()
582
+ if cost_class == 0 and cost_mask == 0 and cost_dice == 0:
583
+ raise ValueError("All costs can't be 0")
584
+
585
+ self.num_points = num_points
586
+ self.cost_class = cost_class
587
+ self.cost_mask = cost_mask
588
+ self.cost_dice = cost_dice
589
+
590
+ @torch.no_grad()
591
+ def forward(
592
+ self,
593
+ masks_queries_logits: torch.Tensor,
594
+ class_queries_logits: torch.Tensor,
595
+ mask_labels: torch.Tensor,
596
+ class_labels: torch.Tensor,
597
+ ) -> list[tuple[Tensor]]:
598
+ """
599
+ Params:
600
+ masks_queries_logits (`torch.Tensor`):
601
+ A tensor of dim `batch_size, num_queries, height, width` with the predicted masks.
602
+ class_queries_logits (`torch.Tensor`):
603
+ A tensor of dim `batch_size, num_queries, num_labels` with the classification logits.
604
+ class_labels (`torch.Tensor`):
605
+ A tensor of dim `num_target_boxes` (where num_target_boxes is the number of ground-truth objects in the
606
+ target) containing the class labels.
607
+ mask_labels (`torch.Tensor`):
608
+ A tensor of dim `num_target_boxes, height, width` containing the target masks.
609
+
610
+ Returns:
611
+ matched_indices (`list[tuple[Tensor]]`): A list of size batch_size, containing tuples of (index_i, index_j)
612
+ where:
613
+ - index_i is the indices of the selected predictions (in order)
614
+ - index_j is the indices of the corresponding selected labels (in order)
615
+ For each batch element, it holds:
616
+ len(index_i) = len(index_j) = min(num_queries, num_target_boxes).
617
+ """
618
+ indices: list[tuple[np.ndarray]] = []
619
+
620
+ # iterate through batch size
621
+ batch_size = masks_queries_logits.shape[0]
622
+ for i in range(batch_size):
623
+ pred_probs = class_queries_logits[i].softmax(-1)
624
+ pred_mask = masks_queries_logits[i]
625
+
626
+ # Compute the classification cost. Contrary to the loss, we don't use the NLL, but approximate it with 1 - proba[target class]. The 1 is a constant that doesn't change the matching, so it can be omitted.
627
+ cost_class = -pred_probs[:, class_labels[i]]
628
+ target_mask = mask_labels[i].to(pred_mask)
629
+ target_mask = target_mask[:, None]
630
+ pred_mask = pred_mask[:, None]
631
+
632
+ # Sample ground truth and predicted masks
633
+ point_coordinates = torch.rand(1, self.num_points, 2, device=pred_mask.device)
634
+
635
+ target_coordinates = point_coordinates.repeat(target_mask.shape[0], 1, 1)
636
+ target_mask = sample_point(target_mask, target_coordinates, align_corners=False).squeeze(1)
637
+
638
+ pred_coordinates = point_coordinates.repeat(pred_mask.shape[0], 1, 1)
639
+ pred_mask = sample_point(pred_mask, pred_coordinates, align_corners=False).squeeze(1)
640
+
641
+ # compute the cross entropy loss between each mask pairs -> shape (num_queries, num_labels)
642
+ cost_mask = pair_wise_sigmoid_cross_entropy_loss(pred_mask, target_mask)
643
+ # Compute the dice loss between each mask pairs -> shape (num_queries, num_labels)
644
+ cost_dice = pair_wise_dice_loss(pred_mask, target_mask)
645
+ # final cost matrix
646
+ cost_matrix = self.cost_mask * cost_mask + self.cost_class * cost_class + self.cost_dice * cost_dice
647
+ # eliminate infinite values in cost_matrix to avoid the error ``ValueError: cost matrix is infeasible``
648
+ cost_matrix = torch.minimum(cost_matrix, torch.tensor(1e10))
649
+ cost_matrix = torch.maximum(cost_matrix, torch.tensor(-1e10))
650
+ cost_matrix = torch.nan_to_num(cost_matrix, 0)
651
+ # do the assignment using the hungarian algorithm in scipy
652
+ assigned_indices: tuple[np.ndarray] = linear_sum_assignment(cost_matrix.cpu())
653
+ indices.append(assigned_indices)
654
+
655
+ # It could be stacked in one tensor
656
+ matched_indices = [
657
+ (torch.as_tensor(i, dtype=torch.int64), torch.as_tensor(j, dtype=torch.int64)) for i, j in indices
658
+ ]
659
+ return matched_indices
660
+
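+ # Editor's sketch (hypothetical helper, not part of the generated file; assumes scipy is
+ # available): `linear_sum_assignment` picks the permutation with minimal total cost. Here
+ # the off-diagonal assignment (cost 1 + 1) beats the diagonal one (cost 2 + 2).
+ def _editor_example_hungarian_assignment():
+     cost_matrix = torch.tensor([[2.0, 1.0], [1.0, 2.0]])
+     row_idx, col_idx = linear_sum_assignment(cost_matrix.cpu())
+     assert row_idx.tolist() == [0, 1] and col_idx.tolist() == [1, 0]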
661
+
662
+ def dice_loss(inputs: Tensor, labels: Tensor, num_masks: int) -> Tensor:
663
+ r"""
664
+ Compute the DICE loss, similar to generalized IOU for masks as follows:
665
+
666
+ $$ \mathcal{L}_{\text{dice}}(x, y) = 1 - \frac{2 \cdot x \cap y}{x \cup y + 1} $$
667
+
668
+ In practice, since `labels` is a binary mask (only 0s and 1s), dice can be computed as follows:
669
+
670
+ $$ \mathcal{L}_{\text{dice}}(x, y) = 1 - \frac{2 \cdot x \cdot y}{x + y + 1} $$
671
+
672
+ Args:
673
+ inputs (`torch.Tensor`):
674
+ A tensor representing a mask.
675
+ labels (`torch.Tensor`):
676
+ A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs
677
+ (0 for the negative class and 1 for the positive class).
678
+ num_masks (`int`):
679
+ The number of masks present in the current batch, used for normalization.
680
+
681
+ Returns:
682
+ `torch.Tensor`: The computed loss.
683
+ """
684
+ probs = inputs.sigmoid().flatten(1)
685
+ numerator = 2 * (probs * labels).sum(-1)
686
+ denominator = probs.sum(-1) + labels.sum(-1)
687
+ loss = 1 - (numerator + 1) / (denominator + 1)
688
+ loss = loss.sum() / num_masks
689
+ return loss
690
+
691
+
692
+ def sigmoid_cross_entropy_loss(inputs: torch.Tensor, labels: torch.Tensor, num_masks: int) -> torch.Tensor:
693
+ r"""
694
+ Args:
695
+ inputs (`torch.Tensor`):
696
+ A float tensor of arbitrary shape.
697
+ labels (`torch.Tensor`):
698
+ A tensor with the same shape as inputs. Stores the binary classification labels for each element in inputs
699
+ (0 for the negative class and 1 for the positive class).
700
+
701
+ Returns:
702
+ loss (`torch.Tensor`): The computed loss.
703
+ """
704
+ criterion = nn.BCEWithLogitsLoss(reduction="none")
705
+ cross_entropy_loss = criterion(inputs, labels)
706
+
707
+ loss = cross_entropy_loss.mean(1).sum() / num_masks
708
+ return loss
709
+
710
+
711
+ # Adapted from https://github.com/facebookresearch/Mask2Former/blob/main/mask2former/modeling/criterion.py
712
+ class EomtDinov3Loss(nn.Module):
713
+ def __init__(self, config: EomtDinov3Config, weight_dict: dict[str, float]):
714
+ """
715
+ The EomtDinov3 Loss. The loss is computed very similarly to DETR. The process happens in two steps: 1) we
716
+ compute the Hungarian assignment between the ground-truth masks and the outputs of the model; 2) we supervise each pair
717
+ of matched ground-truth / prediction (supervising both class and mask).
718
+
719
+ Args:
720
+ config (`EomtDinov3Config`):
721
+ The configuration for EomtDinov3 model also containing loss calculation specific parameters.
722
+ weight_dict (`dict[str, float]`):
723
+ A dictionary of weights to be applied to the different losses.
724
+ """
725
+ super().__init__()
726
+ requires_backends(self, ["scipy"])
727
+ self.num_labels = config.num_labels
728
+ self.weight_dict = weight_dict
729
+
730
+ # Weight to apply to the null class
731
+ self.eos_coef = config.no_object_weight
732
+ empty_weight = torch.ones(self.num_labels + 1)
733
+ empty_weight[-1] = self.eos_coef
734
+ self.register_buffer("empty_weight", empty_weight)
735
+
736
+ # pointwise mask loss parameters
737
+ self.num_points = config.train_num_points
738
+ self.oversample_ratio = config.oversample_ratio
739
+ self.importance_sample_ratio = config.importance_sample_ratio
740
+
741
+ self.matcher = EomtDinov3HungarianMatcher(
742
+ cost_class=config.class_weight,
743
+ cost_dice=config.dice_weight,
744
+ cost_mask=config.mask_weight,
745
+ num_points=self.num_points,
746
+ )
747
+
748
+ def _max_by_axis(self, sizes: list[list[int]]) -> list[int]:
749
+ maxes = sizes[0]
750
+ for sublist in sizes[1:]:
751
+ for index, item in enumerate(sublist):
752
+ maxes[index] = max(maxes[index], item)
753
+ return maxes
754
+
755
+ # Adapted from nested_tensor_from_tensor_list() in original implementation
756
+ def _pad_images_to_max_in_batch(self, tensors: list[Tensor]) -> tuple[Tensor, Tensor]:
757
+ # get the maximum size in the batch
758
+ max_size = self._max_by_axis([list(tensor.shape) for tensor in tensors])
759
+ # compute final size
760
+ batch_shape = [len(tensors)] + max_size
761
+ batch_size, _, height, width = batch_shape
762
+ dtype = tensors[0].dtype
763
+ device = tensors[0].device
764
+ padded_tensors = torch.zeros(batch_shape, dtype=dtype, device=device)
765
+ padding_masks = torch.ones((batch_size, height, width), dtype=torch.bool, device=device)
766
+ # pad the tensors to the size of the biggest one
767
+ for tensor, padded_tensor, padding_mask in zip(tensors, padded_tensors, padding_masks):
768
+ padded_tensor[: tensor.shape[0], : tensor.shape[1], : tensor.shape[2]].copy_(tensor)
769
+ padding_mask[: tensor.shape[1], : tensor.shape[2]] = False
770
+
771
+ return padded_tensors, padding_masks
772
+
773
+ def loss_labels(
774
+ self, class_queries_logits: Tensor, class_labels: list[Tensor], indices: tuple[np.ndarray]
775
+ ) -> dict[str, Tensor]:
776
+ """Compute the losses related to the labels using cross entropy.
777
+
778
+ Args:
779
+ class_queries_logits (`torch.Tensor`):
780
+ A tensor of shape `batch_size, num_queries, num_labels`
781
+ class_labels (`list[torch.Tensor]`):
782
+ List of class labels of shape `(labels)`.
783
+ indices (`tuple[np.ndarray]`):
784
+ The indices computed by the Hungarian matcher.
785
+
786
+ Returns:
787
+ `dict[str, Tensor]`: A dict of `torch.Tensor` containing the following key:
788
+ - **loss_cross_entropy** -- The loss computed using cross entropy on the predicted and ground truth labels.
789
+ """
790
+ pred_logits = class_queries_logits
791
+ batch_size, num_queries, _ = pred_logits.shape
792
+ criterion = nn.CrossEntropyLoss(weight=self.empty_weight)
793
+ idx = self._get_predictions_permutation_indices(indices)  # tuple of (batch_indices, prediction_indices)
794
+ target_classes_o = torch.cat(
795
+ [target[j] for target, (_, j) in zip(class_labels, indices)]
796
+ ) # 1-D tensor of the class labels for all matched targets across the batch
797
+ target_classes = torch.full(
798
+ (batch_size, num_queries), fill_value=self.num_labels, dtype=torch.int64, device=pred_logits.device
799
+ )
800
+ target_classes[idx] = target_classes_o
801
+ # Permute pred_logits (batch_size, num_queries, num_labels) -> (batch_size, num_labels, num_queries)
802
+ pred_logits_transposed = pred_logits.transpose(1, 2)
803
+ loss_ce = criterion(pred_logits_transposed, target_classes)
804
+ losses = {"loss_cross_entropy": loss_ce}
805
+ return losses
806
+
807
+ def loss_masks(
808
+ self,
809
+ masks_queries_logits: torch.Tensor,
810
+ mask_labels: list[torch.Tensor],
811
+ indices: tuple[np.ndarray],
812
+ num_masks: int,
813
+ ) -> dict[str, torch.Tensor]:
814
+ """Compute the losses related to the masks using sigmoid_cross_entropy_loss and dice loss.
815
+
816
+ Args:
817
+ masks_queries_logits (`torch.Tensor`):
818
+ A tensor of shape `(batch_size, num_queries, height, width)`.
819
+ mask_labels (`list[torch.Tensor]`):
820
+ List of mask labels of shape `(labels, height, width)`.
821
+ indices (`tuple[np.ndarray]`):
822
+ The indices computed by the Hungarian matcher.
823
+ num_masks (`int`):
824
+ The number of masks, used for normalization.
825
+
826
+ Returns:
827
+ losses (`dict[str, Tensor]`): A dict of `torch.Tensor` containing two keys:
828
+ - **loss_mask** -- The loss computed using sigmoid cross entropy loss on the predicted and ground truth
829
+ masks.
830
+ - **loss_dice** -- The loss computed using dice loss on the predicted and ground truth
831
+ masks.
832
+ """
833
+ src_idx = self._get_predictions_permutation_indices(indices)
834
+ tgt_idx = self._get_targets_permutation_indices(indices)
835
+ # shape (batch_size * num_queries, height, width)
836
+ pred_masks = masks_queries_logits[src_idx]
837
+ # shape (batch_size, num_queries, height, width)
838
+ # pad all and stack the targets to the num_labels dimension
839
+ target_masks, _ = self._pad_images_to_max_in_batch(mask_labels)
840
+ target_masks = target_masks[tgt_idx]
841
+
842
+ # No need to upsample predictions as we are using normalized coordinates
843
+ pred_masks = pred_masks[:, None]
844
+ target_masks = target_masks[:, None]
845
+
846
+ # Sample point coordinates
847
+ with torch.no_grad():
848
+ point_coordinates = self.sample_points_using_uncertainty(
849
+ pred_masks,
850
+ lambda logits: self.calculate_uncertainty(logits),
851
+ self.num_points,
852
+ self.oversample_ratio,
853
+ self.importance_sample_ratio,
854
+ )
855
+
856
+ point_labels = sample_point(target_masks, point_coordinates, align_corners=False).squeeze(1)
857
+
858
+ point_logits = sample_point(pred_masks, point_coordinates, align_corners=False).squeeze(1)
859
+
860
+ losses = {
861
+ "loss_mask": sigmoid_cross_entropy_loss(point_logits, point_labels, num_masks),
862
+ "loss_dice": dice_loss(point_logits, point_labels, num_masks),
863
+ }
864
+
865
+ del pred_masks
866
+ del target_masks
867
+ return losses
868
+
869
+ def _get_predictions_permutation_indices(self, indices):
870
+ # Permute predictions following indices
871
+ batch_indices = torch.cat([torch.full_like(src, i) for i, (src, _) in enumerate(indices)])
872
+ predictions_indices = torch.cat([src for (src, _) in indices])
873
+ return batch_indices, predictions_indices
874
+
875
+ def _get_targets_permutation_indices(self, indices):
876
+ # Permute labels following indices
877
+ batch_indices = torch.cat([torch.full_like(tgt, i) for i, (_, tgt) in enumerate(indices)])
878
+ target_indices = torch.cat([tgt for (_, tgt) in indices])
879
+ return batch_indices, target_indices
880
+
881
+ def calculate_uncertainty(self, logits: torch.Tensor) -> torch.Tensor:
882
+ """
883
+ In the Mask2Former paper, uncertainty is estimated as the L1 distance between 0.0 and the logit prediction in 'logits'
884
+ for the foreground class in `classes`.
885
+
886
+ Args:
887
+ logits (`torch.Tensor`):
888
+ A tensor of shape (R, 1, ...) for class-specific or class-agnostic, where R is the total number of predicted masks in all images and C is
889
+ the number of foreground classes. The values are logits.
890
+
891
+ Returns:
892
+ scores (`torch.Tensor`): A tensor of shape (R, 1, ...) that contains uncertainty scores with the most
893
+ uncertain locations having the highest uncertainty score.
894
+ """
895
+ uncertainty_scores = -(torch.abs(logits))
896
+ return uncertainty_scores
897
+
898
+ def sample_points_using_uncertainty(
899
+ self,
900
+ logits: torch.Tensor,
901
+ uncertainty_function,
902
+ num_points: int,
903
+ oversample_ratio: int,
904
+ importance_sample_ratio: float,
905
+ ) -> torch.Tensor:
906
+ """
907
+ This function is meant for sampling points in [0, 1] * [0, 1] coordinate space based on their uncertainty. The
908
+ uncertainty is calculated for each point using the passed `uncertainty function` that takes points logit
909
+ prediction as input.
910
+
911
+ Args:
912
+ logits (`torch.Tensor`):
913
+ Logit predictions for P points.
914
+ uncertainty_function:
915
+ A function that takes logit predictions for P points and returns their uncertainties.
916
+ num_points (`int`):
917
+ The number of points P to sample.
918
+ oversample_ratio (`int`):
919
+ Oversampling parameter.
920
+ importance_sample_ratio (`float`):
921
+ Ratio of points that are sampled via importance sampling.
922
+
923
+ Returns:
924
+ point_coordinates (`torch.Tensor`):
925
+ Coordinates for P sampled points.
926
+ """
927
+
928
+ num_boxes = logits.shape[0]
929
+ num_points_sampled = int(num_points * oversample_ratio)
930
+
931
+ # Get random point coordinates
932
+ point_coordinates = torch.rand(num_boxes, num_points_sampled, 2, device=logits.device)
933
+ # Get sampled prediction value for the point coordinates
934
+ point_logits = sample_point(logits, point_coordinates, align_corners=False)
935
+ # Calculate the uncertainties based on the sampled prediction values of the points
936
+ point_uncertainties = uncertainty_function(point_logits)
937
+
938
+ num_uncertain_points = int(importance_sample_ratio * num_points)
939
+ num_random_points = num_points - num_uncertain_points
940
+
941
+ idx = torch.topk(point_uncertainties[:, 0, :], k=num_uncertain_points, dim=1)[1]
942
+ shift = num_points_sampled * torch.arange(num_boxes, dtype=torch.long, device=logits.device)
943
+ idx += shift[:, None]
944
+ point_coordinates = point_coordinates.view(-1, 2)[idx.view(-1), :].view(num_boxes, num_uncertain_points, 2)
945
+
946
+ if num_random_points > 0:
947
+ point_coordinates = torch.cat(
948
+ [point_coordinates, torch.rand(num_boxes, num_random_points, 2, device=logits.device)],
949
+ dim=1,
950
+ )
951
+ return point_coordinates
952
+
953
+ def forward(
954
+ self,
955
+ masks_queries_logits: torch.Tensor,
956
+ class_queries_logits: torch.Tensor,
957
+ mask_labels: list[torch.Tensor],
958
+ class_labels: list[torch.Tensor],
959
+ auxiliary_predictions: dict[str, torch.Tensor] | None = None,
960
+ ) -> dict[str, torch.Tensor]:
961
+ """
962
+ This performs the loss computation.
963
+
964
+ Args:
965
+ masks_queries_logits (`torch.Tensor`):
966
+ A tensor of shape `(batch_size, num_queries, height, width)`.
967
+ class_queries_logits (`torch.Tensor`):
968
+ A tensor of shape `(batch_size, num_queries, num_labels)`.
969
+ mask_labels (`list[torch.Tensor]`):
970
+ List of mask labels of shape `(labels, height, width)`.
971
+ class_labels (`list[torch.Tensor]`):
972
+ List of class labels of shape `(labels)`.
973
+ auxiliary_predictions (`dict[str, torch.Tensor]`, *optional*):
974
+ if `use_auxiliary_loss` was set to `true` in [`EomtDinov3Config`], then it contains the logits from
975
+ the inner layers of the EomtDinov3MaskedAttentionDecoder.
976
+
977
+ Returns:
978
+ losses (`dict[str, Tensor]`): A dict of `torch.Tensor` containing three keys:
979
+ - **loss_cross_entropy** -- The loss computed using cross entropy on the predicted and ground truth labels.
980
+ - **loss_mask** -- The loss computed using sigmoid cross_entropy loss on the predicted and ground truth
981
+ masks.
982
+ - **loss_dice** -- The loss computed using dice loss on the predicted and ground truth
983
+ masks.
984
+ if `use_auxiliary_loss` was set to `true` in [`EomtDinov3Config`], the dictionary contains additional
985
+ losses for each auxiliary prediction.
986
+ """
987
+
988
+ # retrieve the matching between the outputs of the last layer and the labels
989
+ indices = self.matcher(masks_queries_logits, class_queries_logits, mask_labels, class_labels)
990
+ # compute the average number of target masks for normalization purposes
991
+ num_masks = self.get_num_masks(class_labels, device=class_labels[0].device)
992
+ # get all the losses
993
+ losses: dict[str, Tensor] = {
994
+ **self.loss_masks(masks_queries_logits, mask_labels, indices, num_masks),
995
+ **self.loss_labels(class_queries_logits, class_labels, indices),
996
+ }
997
+ # in case of auxiliary losses, we repeat this process with the output of each intermediate layer.
998
+ if auxiliary_predictions is not None:
999
+ for idx, aux_outputs in enumerate(auxiliary_predictions):
1000
+ masks_queries_logits = aux_outputs["masks_queries_logits"]
1001
+ class_queries_logits = aux_outputs["class_queries_logits"]
1002
+ loss_dict = self.forward(masks_queries_logits, class_queries_logits, mask_labels, class_labels)
1003
+ loss_dict = {f"{key}_{idx}": value for key, value in loss_dict.items()}
1004
+ losses.update(loss_dict)
1005
+
1006
+ return losses
1007
+
1008
+ def get_num_masks(self, class_labels: torch.Tensor, device: torch.device) -> torch.Tensor:
1009
+ """
1010
+ Computes the average number of target masks across the batch, for normalization purposes.
1011
+ """
1012
+ num_masks = sum(len(classes) for classes in class_labels)
1013
+ num_masks = torch.as_tensor(num_masks, dtype=torch.float, device=device)
1014
+ world_size = 1
1015
+ if is_accelerate_available():
1016
+ if PartialState._shared_state != {}:
1017
+ num_masks = reduce(num_masks)
1018
+ world_size = PartialState().num_processes
1019
+
1020
+ num_masks = torch.clamp(num_masks / world_size, min=1)
1021
+ return num_masks
1022
+
1023
+
1024
+ @dataclass
1025
+ @auto_docstring(
1026
+ custom_intro="""
1027
+ Class for outputs of [`EomtDinov3ForUniversalSegmentation`].
1028
+
1029
+ This output can be directly passed to [`~EomtDinov3ImageProcessor.post_process_semantic_segmentation`] or
1030
+ [`~EomtDinov3ImageProcessor.post_process_instance_segmentation`] or
1031
+ [`~EomtDinov3ImageProcessor.post_process_panoptic_segmentation`] to compute final segmentation maps. Please see
1032
+ [`~EomtDinov3ImageProcessor`] for details regarding usage.
1033
+ """
1034
+ )
1035
+ class EomtDinov3ForUniversalSegmentationOutput(ModelOutput):
+     r"""
+     loss (`torch.Tensor`, *optional*):
+         The computed loss, returned when labels are present.
+     class_queries_logits (`torch.FloatTensor`):
+         A tensor of shape `(batch_size, num_queries, num_labels + 1)` representing the proposed classes for each
+         query. Note the `+ 1` is needed because we incorporate the null class.
+     masks_queries_logits (`torch.FloatTensor`):
+         A tensor of shape `(batch_size, num_queries, height, width)` representing the proposed masks for each
+         query.
+     last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
+         Last hidden states (final feature map) of the last layer of the model.
+     hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
+         Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each stage) of
+         shape `(batch_size, sequence_length, hidden_size)`. Hidden states of the model at the output of each layer.
+     attentions (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
+         Tuple of `tuple(torch.FloatTensor)` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
+         sequence_length)`. Self- and cross-attention weights from the transformer decoder.
+     patch_offsets (`list[torch.Tensor]`, *optional*):
+         List of tuples indicating the image index and the start and end positions of patches for semantic
+         segmentation.
+     """
+ 
+     loss: torch.FloatTensor | None = None
+     class_queries_logits: torch.FloatTensor | None = None
+     masks_queries_logits: torch.FloatTensor | None = None
+     last_hidden_state: torch.FloatTensor | None = None
+     hidden_states: tuple[torch.FloatTensor] | None = None
+     attentions: tuple[torch.FloatTensor] | None = None
+     patch_offsets: list[torch.Tensor] | None = None
+ 
+ 
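+ # Usage sketch for the output class above (illustrative; the checkpoint id is a placeholder, not
+ # a confirmed released model):
+ #
+ #     from transformers import AutoImageProcessor, EomtDinov3ForUniversalSegmentation
+ #
+ #     processor = AutoImageProcessor.from_pretrained("<eomt-dinov3-checkpoint>")
+ #     model = EomtDinov3ForUniversalSegmentation.from_pretrained("<eomt-dinov3-checkpoint>")
+ #     inputs = processor(images=image, return_tensors="pt")
+ #     outputs = model(**inputs)
+ #     semantic_map = processor.post_process_semantic_segmentation(
+ #         outputs, target_sizes=[image.size[::-1]]
+ #     )[0]
+ 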
+ @auto_docstring
+ class EomtDinov3PreTrainedModel(PreTrainedModel):
+     """
+     An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
+     models.
+     """
+ 
+     config: EomtDinov3Config
+     base_model_prefix = "eomt_dinov3"
+     main_input_name = "pixel_values"
+     input_modalities = ("image",)
+     supports_gradient_checkpointing = False
+     _no_split_modules = ["EomtDinov3Layer"]
+     _supports_sdpa = True
+     _can_record_outputs = {
+         "hidden_states": EomtDinov3Layer,
+         "attentions": EomtDinov3Attention,
+     }
+     config_class = EomtDinov3Config
+ 
+     @torch.no_grad()
+     def _init_weights(self, module: nn.Module) -> None:
+         super()._init_weights(module)
+         std = self.config.initializer_range
+         if isinstance(module, EomtDinov3LayerScale):
+             if hasattr(module, "lambda1"):
+                 init.constant_(module.lambda1, self.config.layerscale_value)
+         elif isinstance(module, EomtDinov3Embeddings):
+             init.trunc_normal_(module.cls_token, mean=0.0, std=std)
+             init.zeros_(module.register_tokens)
+         elif isinstance(module, EomtDinov3Loss):
+             empty_weight = torch.ones(module.num_labels + 1)
+             empty_weight[-1] = module.eos_coef
+             init.copy_(module.empty_weight, empty_weight)
+         elif isinstance(module, EomtDinov3ForUniversalSegmentation):
+             init.ones_(module.attn_mask_probs)
+ 
+ 
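+ # Initialization note (illustrative): `empty_weight` down-weights the "no object" class in the
+ # cross entropy, e.g. with num_labels=3 and eos_coef=0.1 the class weights become
+ # [1.0, 1.0, 1.0, 0.1], where the last entry corresponds to the appended null class.
+ 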
+ class EomtDinov3LayerNorm2d(nn.LayerNorm):
+     def __init__(self, num_channels, eps=1e-6, affine=True):
+         super().__init__(num_channels, eps=eps, elementwise_affine=affine)
+ 
+     def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
+         hidden_state = hidden_state.permute(0, 2, 3, 1)
+         hidden_state = F.layer_norm(hidden_state, self.normalized_shape, self.weight, self.bias, self.eps)
+         hidden_state = hidden_state.permute(0, 3, 1, 2)
+         return hidden_state
+ 
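+ # Equivalence sketch for the channels-first LayerNorm above (illustrative): the normalization
+ # runs over the channel dimension by permuting NCHW -> NHWC and back.
+ #
+ #     x = torch.randn(2, 8, 4, 4)  # (batch, channels, height, width)
+ #     ln = EomtDinov3LayerNorm2d(8)
+ #     ref = F.layer_norm(x.permute(0, 2, 3, 1), (8,), ln.weight, ln.bias, ln.eps)
+ #     torch.testing.assert_close(ln(x), ref.permute(0, 3, 1, 2))
+ 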
+ 
+ class EomtDinov3ScaleLayer(nn.Module):
+     def __init__(self, config: EomtDinov3Config):
+         super().__init__()
+         hidden_size = config.hidden_size
+         self.conv1 = nn.ConvTranspose2d(hidden_size, hidden_size, kernel_size=2, stride=2)
+         self.activation = ACT2FN[config.hidden_act]
+         self.conv2 = nn.Conv2d(
+             hidden_size,
+             hidden_size,
+             kernel_size=3,
+             padding=1,
+             groups=hidden_size,
+             bias=False,
+         )
+ 
+         self.layernorm2d = EomtDinov3LayerNorm2d(hidden_size)
+ 
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         hidden_states = self.conv1(hidden_states)
+         hidden_states = self.activation(hidden_states)
+         hidden_states = self.conv2(hidden_states)
+         hidden_states = self.layernorm2d(hidden_states)
+         return hidden_states
+ 
+ 
+ class EomtDinov3ScaleBlock(nn.Module):
+     def __init__(self, config: EomtDinov3Config):
+         super().__init__()
+         self.num_blocks = config.num_upscale_blocks
+         self.block = nn.ModuleList([EomtDinov3ScaleLayer(config) for _ in range(self.num_blocks)])
+ 
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         for block in self.block:
+             hidden_states = block(hidden_states)
+         return hidden_states
+ 
+ 
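+ # Shape sketch (illustrative): each scale layer doubles the spatial size through its stride-2
+ # `ConvTranspose2d`, so the block upscales by a factor of 2 ** config.num_upscale_blocks, e.g.
+ # with two blocks a 32x32 feature map becomes 128x128:
+ #
+ #     block = EomtDinov3ScaleBlock(config)  # assuming config.num_upscale_blocks == 2
+ #     block(torch.randn(1, config.hidden_size, 32, 32)).shape  # -> (1, hidden_size, 128, 128)
+ 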
+ class EomtDinov3MaskHead(nn.Module):
+     def __init__(self, config: EomtDinov3Config):
+         super().__init__()
+ 
+         hidden_size = config.hidden_size
+         self.fc1 = nn.Linear(hidden_size, hidden_size)
+         self.fc2 = nn.Linear(hidden_size, hidden_size)
+         self.fc3 = nn.Linear(hidden_size, hidden_size)
+         self.activation = ACT2FN[config.hidden_act]
+ 
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         hidden_states = self.activation(self.fc1(hidden_states))
+         hidden_states = self.activation(self.fc2(hidden_states))
+         hidden_states = self.fc3(hidden_states)
+         return hidden_states
+ 
+ 
+ @auto_docstring(
+     custom_intro="""
+     The EoMT-DINOv3 model with a segmentation head on top, for instance/semantic/panoptic segmentation.
+     """,
+ )
+ class EomtDinov3ForUniversalSegmentation(EomtDinov3PreTrainedModel):
+     main_input_name = "pixel_values"
+ 
+     def __init__(self, config: EomtDinov3Config):
+         super().__init__(config)
+         self.config = config
+         self.num_hidden_layers = config.num_hidden_layers
+         self.embeddings = EomtDinov3Embeddings(config)
+         self.layernorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
+ 
+         self.query = nn.Embedding(config.num_queries, config.hidden_size)
+         self.layers = nn.ModuleList([EomtDinov3Layer(config) for _ in range(config.num_hidden_layers)])
+ 
+         self.upscale_block = EomtDinov3ScaleBlock(config)
+         self.mask_head = EomtDinov3MaskHead(config)
+ 
+         self.class_predictor = nn.Linear(config.hidden_size, config.num_labels + 1)
+ 
+         self.grid_size = (config.image_size // config.patch_size, config.image_size // config.patch_size)
+         self.weight_dict: dict[str, float] = {
+             "loss_cross_entropy": config.class_weight,
+             "loss_mask": config.mask_weight,
+             "loss_dice": config.dice_weight,
+         }
+ 
+         self.criterion = EomtDinov3Loss(config=config, weight_dict=self.weight_dict)
+ 
+         self.register_buffer("attn_mask_probs", torch.ones(config.num_blocks))
+ 
+         self.num_prefix_tokens = 1 + config.num_register_tokens
+         self.dropout = nn.Dropout(config.hidden_dropout_prob)
+         self.embeddings.register_parameter("mask_token", None)
+ 
+         self.rope_embeddings = EomtDinov3RotaryEmbedding(config)
+ 
+         self.post_init()
+ 
+     def get_loss_dict(
+         self,
+         masks_queries_logits: Tensor,
+         class_queries_logits: Tensor,
+         mask_labels: Tensor,
+         class_labels: Tensor,
+         auxiliary_predictions: dict[str, Tensor],
+     ) -> dict[str, Tensor]:
+         loss_dict: dict[str, Tensor] = self.criterion(
+             masks_queries_logits=masks_queries_logits,
+             class_queries_logits=class_queries_logits,
+             mask_labels=mask_labels,
+             class_labels=class_labels,
+             auxiliary_predictions=auxiliary_predictions,
+         )
+ 
+         # weight each loss by `self.weight_dict[<LOSS_NAME>]`, including auxiliary losses
+         for key, weight in self.weight_dict.items():
+             for loss_key, loss in loss_dict.items():
+                 if key in loss_key:
+                     # in-place multiplication, so the tensor stored in `loss_dict` is updated
+                     loss *= weight
+ 
+         return loss_dict
+ 
+     def get_loss(self, loss_dict: dict[str, Tensor]) -> Tensor:
+         return sum(loss_dict.values())
+ 
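+     # Weighting sketch (illustrative numbers): with class_weight=2.0, mask_weight=5.0 and
+     # dice_weight=5.0, a raw dict {"loss_cross_entropy": 0.7, "loss_mask": 0.3, "loss_dice": 0.4}
+     # is rescaled in place to {1.4, 1.5, 2.0}, and `get_loss` sums it to 4.9. Auxiliary entries
+     # such as "loss_mask_0" match the substring test above and receive the same weights.
+ 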
+     @check_model_inputs
+     @auto_docstring
+     def forward(
+         self,
+         pixel_values: Tensor,
+         mask_labels: list[Tensor] | None = None,
+         class_labels: list[Tensor] | None = None,
+         patch_offsets: list[Tensor] | None = None,
+         **kwargs: Unpack[TransformersKwargs],
+     ) -> EomtDinov3ForUniversalSegmentationOutput:
+         r"""
+         mask_labels (`list[torch.Tensor]`, *optional*):
+             List of mask labels of shape `(num_labels, height, width)` to be fed to a model.
+         class_labels (`list[torch.LongTensor]`, *optional*):
+             List of target class labels of shape `(num_labels, height, width)` to be fed to a model. They identify
+             the labels of `mask_labels`, e.g. the label of `mask_labels[i][j]` is `class_labels[i][j]`.
+         patch_offsets (`list[torch.Tensor]`, *optional*):
+             List of tuples indicating the image index and the start and end positions of patches for semantic
+             segmentation.
+         """
+         masks_queries_logits_per_layer, class_queries_logits_per_layer = (), ()
+ 
+         hidden_states = self.dropout(self.embeddings(pixel_values))
+         position_embeddings = self.rope_embeddings(pixel_values.to(hidden_states.dtype))
+         attention_mask = None
+ 
+         for idx, layer_module in enumerate(self.layers):
+             if idx == self.num_hidden_layers - self.config.num_blocks:
+                 query = self.query.weight[None, :, :].expand(hidden_states.shape[0], -1, -1).to(hidden_states.device)
+                 hidden_states = torch.cat((query, hidden_states), dim=1)
+ 
+             if idx >= self.num_hidden_layers - self.config.num_blocks and (
+                 self.training or self.attn_mask_probs[idx - self.num_hidden_layers + self.config.num_blocks] > 0
+             ):
+                 norm_hidden_states = self.layernorm(hidden_states)
+                 masks_queries_logits, class_queries_logits = self.predict(norm_hidden_states)
+ 
+                 masks_queries_logits_per_layer += (masks_queries_logits,)
+                 class_queries_logits_per_layer += (class_queries_logits,)
+ 
+                 attention_mask = torch.ones(
+                     hidden_states.shape[0],
+                     hidden_states.shape[1],
+                     hidden_states.shape[1],
+                     device=hidden_states.device,
+                     dtype=torch.bool,
+                 )
+ 
+                 interpolated_logits = F.interpolate(masks_queries_logits, size=self.grid_size, mode="bilinear")
+                 interpolated_logits = interpolated_logits.view(
+                     interpolated_logits.size(0), interpolated_logits.size(1), -1
+                 )
+ 
+                 num_query_tokens = self.config.num_queries
+                 encoder_start_tokens = num_query_tokens + self.num_prefix_tokens
+ 
+                 # Set attention mask for queries to focus on encoder tokens based on interpolated logits
+                 attention_mask[:, :num_query_tokens, encoder_start_tokens:] = interpolated_logits > 0
+ 
+                 # Disable the attention mask for a random subset of query tokens.
+                 attention_mask = self._disable_attention_mask(
+                     attention_mask,
+                     prob=self.attn_mask_probs[idx - self.num_hidden_layers + self.config.num_blocks],
+                     num_query_tokens=num_query_tokens,
+                     encoder_start_tokens=encoder_start_tokens,
+                     device=attention_mask.device,
+                 )
+ 
+                 # Expand attention mask to a 4d mask.
+                 attention_mask = attention_mask[:, None, ...].expand(-1, self.config.num_attention_heads, -1, -1)
+                 dtype_min = torch.finfo(hidden_states.dtype).min
+                 attention_mask = attention_mask.to(hidden_states.dtype).masked_fill(~attention_mask, dtype_min)
+ 
+             hidden_states = layer_module(
+                 hidden_states,
+                 attention_mask=attention_mask,
+                 position_embeddings=position_embeddings,
+             )
+ 
+         sequence_output = self.layernorm(hidden_states)
+ 
+         masks_queries_logits, class_queries_logits = self.predict(sequence_output)
+         masks_queries_logits_per_layer += (masks_queries_logits,)
+         class_queries_logits_per_layer += (class_queries_logits,)
+ 
+         loss = None
+         if mask_labels is not None and class_labels is not None:
+             loss = 0.0
+             for masks_queries_logits, class_queries_logits in zip(
+                 masks_queries_logits_per_layer, class_queries_logits_per_layer
+             ):
+                 loss_dict = self.get_loss_dict(
+                     masks_queries_logits=masks_queries_logits,
+                     class_queries_logits=class_queries_logits,
+                     mask_labels=mask_labels,
+                     class_labels=class_labels,
+                     auxiliary_predictions=None,
+                 )
+                 loss += self.get_loss(loss_dict)
+ 
+         return EomtDinov3ForUniversalSegmentationOutput(
+             loss=loss,
+             masks_queries_logits=masks_queries_logits,
+             class_queries_logits=class_queries_logits,
+             last_hidden_state=sequence_output,
+             patch_offsets=patch_offsets,
+         )
+ 
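+     # Masked-attention sketch (illustrative reading of the loop above): in the final
+     # `config.num_blocks` layers, each query attends only to the patch tokens whose interpolated
+     # mask logit is currently positive; `attn_mask_probs` (a buffer of ones, typically annealed
+     # by the training recipe) sets the per-block fraction of queries that keep this restriction,
+     # and `_disable_attention_mask` below reverts the remaining queries to full attention.
+ 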
+     def get_input_embeddings(self):
+         return self.embeddings.patch_embeddings
+ 
+     def predict(self, logits: torch.Tensor):
+         query_tokens = logits[:, : self.config.num_queries, :]
+         class_logits = self.class_predictor(query_tokens)
+ 
+         prefix_tokens = logits[:, self.config.num_queries + self.embeddings.num_prefix_tokens :, :]
+         prefix_tokens = prefix_tokens.transpose(1, 2)
+ 
+         prefix_tokens = prefix_tokens.reshape(prefix_tokens.shape[0], -1, *self.grid_size)
+ 
+         query_tokens = self.mask_head(query_tokens)
+         prefix_tokens = self.upscale_block(prefix_tokens)
+ 
+         mask_logits = torch.einsum("bqc, bchw -> bqhw", query_tokens, prefix_tokens)
+ 
+         return mask_logits, class_logits
+ 
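+     # Shape sketch for `predict` (illustrative): with query tokens (B, Q, C) and upscaled
+     # features (B, C, H, W), the einsum "bqc, bchw -> bqhw" is a per-query dot product over the
+     # channel dimension, equivalent to:
+     #
+     #     torch.bmm(query_tokens, prefix_tokens.flatten(2)).view(-1, query_tokens.shape[1], H, W)
+ 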
+     @staticmethod
+     def _disable_attention_mask(attn_mask, prob, num_query_tokens, encoder_start_tokens, device):
+         if prob < 1:
+             # Randomly select the queries whose mask is dropped, based on the keep probability `prob`
+             random_queries = torch.rand(attn_mask.shape[0], num_query_tokens, device=device) > prob
+ 
+             # Re-enable full attention over the encoder tokens for the selected queries
+             attn_mask[:, :num_query_tokens, encoder_start_tokens:][random_queries] = 1
+ 
+         return attn_mask
+ 
+ 
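+ # Disabling sketch (illustrative): with prob=0.25, `torch.rand(batch, num_queries) > prob` marks
+ # roughly 75% of the queries per sample; their rows over the encoder tokens are reset to True, so
+ # those queries fall back to plain (unmasked) attention for the current block.
+ 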
+ __all__ = ["EomtDinov3PreTrainedModel", "EomtDinov3ForUniversalSegmentation"]