transformers-5.0.0rc3-py3-none-any.whl → transformers-5.1.0-py3-none-any.whl

This diff compares the contents of two publicly available package versions as released to their public registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the registry.
Files changed (1021)
  1. transformers/__init__.py +4 -11
  2. transformers/activations.py +2 -2
  3. transformers/backbone_utils.py +326 -0
  4. transformers/cache_utils.py +11 -2
  5. transformers/cli/serve.py +11 -8
  6. transformers/configuration_utils.py +1 -69
  7. transformers/conversion_mapping.py +146 -26
  8. transformers/convert_slow_tokenizer.py +6 -4
  9. transformers/core_model_loading.py +207 -118
  10. transformers/dependency_versions_check.py +0 -1
  11. transformers/dependency_versions_table.py +7 -8
  12. transformers/file_utils.py +0 -2
  13. transformers/generation/candidate_generator.py +1 -2
  14. transformers/generation/continuous_batching/cache.py +40 -38
  15. transformers/generation/continuous_batching/cache_manager.py +3 -16
  16. transformers/generation/continuous_batching/continuous_api.py +94 -406
  17. transformers/generation/continuous_batching/input_ouputs.py +464 -0
  18. transformers/generation/continuous_batching/requests.py +54 -17
  19. transformers/generation/continuous_batching/scheduler.py +77 -95
  20. transformers/generation/logits_process.py +10 -5
  21. transformers/generation/stopping_criteria.py +1 -2
  22. transformers/generation/utils.py +75 -95
  23. transformers/image_processing_utils.py +0 -3
  24. transformers/image_processing_utils_fast.py +17 -18
  25. transformers/image_transforms.py +44 -13
  26. transformers/image_utils.py +0 -5
  27. transformers/initialization.py +57 -0
  28. transformers/integrations/__init__.py +10 -24
  29. transformers/integrations/accelerate.py +47 -11
  30. transformers/integrations/deepspeed.py +145 -3
  31. transformers/integrations/executorch.py +2 -6
  32. transformers/integrations/finegrained_fp8.py +142 -7
  33. transformers/integrations/flash_attention.py +2 -7
  34. transformers/integrations/hub_kernels.py +18 -7
  35. transformers/integrations/moe.py +226 -106
  36. transformers/integrations/mxfp4.py +47 -34
  37. transformers/integrations/peft.py +488 -176
  38. transformers/integrations/tensor_parallel.py +641 -581
  39. transformers/masking_utils.py +153 -9
  40. transformers/modeling_flash_attention_utils.py +1 -2
  41. transformers/modeling_utils.py +359 -358
  42. transformers/models/__init__.py +6 -0
  43. transformers/models/afmoe/configuration_afmoe.py +14 -4
  44. transformers/models/afmoe/modeling_afmoe.py +8 -8
  45. transformers/models/afmoe/modular_afmoe.py +7 -7
  46. transformers/models/aimv2/configuration_aimv2.py +2 -7
  47. transformers/models/aimv2/modeling_aimv2.py +26 -24
  48. transformers/models/aimv2/modular_aimv2.py +8 -12
  49. transformers/models/albert/configuration_albert.py +8 -1
  50. transformers/models/albert/modeling_albert.py +3 -3
  51. transformers/models/align/configuration_align.py +8 -5
  52. transformers/models/align/modeling_align.py +22 -24
  53. transformers/models/altclip/configuration_altclip.py +4 -6
  54. transformers/models/altclip/modeling_altclip.py +30 -26
  55. transformers/models/apertus/configuration_apertus.py +5 -7
  56. transformers/models/apertus/modeling_apertus.py +4 -4
  57. transformers/models/apertus/modular_apertus.py +8 -10
  58. transformers/models/arcee/configuration_arcee.py +5 -7
  59. transformers/models/arcee/modeling_arcee.py +4 -4
  60. transformers/models/aria/configuration_aria.py +11 -21
  61. transformers/models/aria/modeling_aria.py +39 -36
  62. transformers/models/aria/modular_aria.py +33 -39
  63. transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +3 -3
  64. transformers/models/audioflamingo3/modeling_audioflamingo3.py +39 -30
  65. transformers/models/audioflamingo3/modular_audioflamingo3.py +41 -27
  66. transformers/models/auto/auto_factory.py +8 -6
  67. transformers/models/auto/configuration_auto.py +22 -0
  68. transformers/models/auto/image_processing_auto.py +17 -13
  69. transformers/models/auto/modeling_auto.py +15 -0
  70. transformers/models/auto/processing_auto.py +9 -18
  71. transformers/models/auto/tokenization_auto.py +17 -15
  72. transformers/models/autoformer/modeling_autoformer.py +2 -1
  73. transformers/models/aya_vision/configuration_aya_vision.py +4 -0
  74. transformers/models/aya_vision/modeling_aya_vision.py +29 -62
  75. transformers/models/aya_vision/modular_aya_vision.py +20 -45
  76. transformers/models/bamba/configuration_bamba.py +17 -7
  77. transformers/models/bamba/modeling_bamba.py +23 -55
  78. transformers/models/bamba/modular_bamba.py +19 -54
  79. transformers/models/bark/configuration_bark.py +2 -1
  80. transformers/models/bark/modeling_bark.py +24 -10
  81. transformers/models/bart/configuration_bart.py +9 -4
  82. transformers/models/bart/modeling_bart.py +9 -12
  83. transformers/models/beit/configuration_beit.py +2 -4
  84. transformers/models/beit/image_processing_beit_fast.py +3 -3
  85. transformers/models/beit/modeling_beit.py +14 -9
  86. transformers/models/bert/configuration_bert.py +12 -1
  87. transformers/models/bert/modeling_bert.py +6 -30
  88. transformers/models/bert_generation/configuration_bert_generation.py +17 -1
  89. transformers/models/bert_generation/modeling_bert_generation.py +6 -6
  90. transformers/models/big_bird/configuration_big_bird.py +12 -8
  91. transformers/models/big_bird/modeling_big_bird.py +0 -15
  92. transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py +9 -8
  93. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +9 -7
  94. transformers/models/biogpt/configuration_biogpt.py +8 -1
  95. transformers/models/biogpt/modeling_biogpt.py +4 -8
  96. transformers/models/biogpt/modular_biogpt.py +1 -5
  97. transformers/models/bit/configuration_bit.py +2 -4
  98. transformers/models/bit/modeling_bit.py +6 -5
  99. transformers/models/bitnet/configuration_bitnet.py +5 -7
  100. transformers/models/bitnet/modeling_bitnet.py +3 -4
  101. transformers/models/bitnet/modular_bitnet.py +3 -4
  102. transformers/models/blenderbot/configuration_blenderbot.py +8 -4
  103. transformers/models/blenderbot/modeling_blenderbot.py +4 -4
  104. transformers/models/blenderbot_small/configuration_blenderbot_small.py +8 -4
  105. transformers/models/blenderbot_small/modeling_blenderbot_small.py +4 -4
  106. transformers/models/blip/configuration_blip.py +9 -9
  107. transformers/models/blip/modeling_blip.py +55 -37
  108. transformers/models/blip_2/configuration_blip_2.py +2 -1
  109. transformers/models/blip_2/modeling_blip_2.py +81 -56
  110. transformers/models/bloom/configuration_bloom.py +5 -1
  111. transformers/models/bloom/modeling_bloom.py +2 -1
  112. transformers/models/blt/configuration_blt.py +23 -12
  113. transformers/models/blt/modeling_blt.py +20 -14
  114. transformers/models/blt/modular_blt.py +70 -10
  115. transformers/models/bridgetower/configuration_bridgetower.py +7 -1
  116. transformers/models/bridgetower/image_processing_bridgetower_fast.py +6 -6
  117. transformers/models/bridgetower/modeling_bridgetower.py +29 -15
  118. transformers/models/bros/configuration_bros.py +24 -17
  119. transformers/models/camembert/configuration_camembert.py +8 -1
  120. transformers/models/camembert/modeling_camembert.py +6 -6
  121. transformers/models/canine/configuration_canine.py +4 -1
  122. transformers/models/chameleon/configuration_chameleon.py +5 -7
  123. transformers/models/chameleon/image_processing_chameleon_fast.py +5 -5
  124. transformers/models/chameleon/modeling_chameleon.py +82 -36
  125. transformers/models/chinese_clip/configuration_chinese_clip.py +10 -7
  126. transformers/models/chinese_clip/modeling_chinese_clip.py +28 -29
  127. transformers/models/clap/configuration_clap.py +4 -8
  128. transformers/models/clap/modeling_clap.py +21 -22
  129. transformers/models/clip/configuration_clip.py +4 -1
  130. transformers/models/clip/image_processing_clip_fast.py +9 -0
  131. transformers/models/clip/modeling_clip.py +25 -22
  132. transformers/models/clipseg/configuration_clipseg.py +4 -1
  133. transformers/models/clipseg/modeling_clipseg.py +27 -25
  134. transformers/models/clipseg/processing_clipseg.py +11 -3
  135. transformers/models/clvp/configuration_clvp.py +14 -2
  136. transformers/models/clvp/modeling_clvp.py +19 -30
  137. transformers/models/codegen/configuration_codegen.py +4 -3
  138. transformers/models/codegen/modeling_codegen.py +2 -1
  139. transformers/models/cohere/configuration_cohere.py +5 -7
  140. transformers/models/cohere/modeling_cohere.py +4 -4
  141. transformers/models/cohere/modular_cohere.py +3 -3
  142. transformers/models/cohere2/configuration_cohere2.py +6 -8
  143. transformers/models/cohere2/modeling_cohere2.py +4 -4
  144. transformers/models/cohere2/modular_cohere2.py +9 -11
  145. transformers/models/cohere2_vision/configuration_cohere2_vision.py +5 -1
  146. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +3 -3
  147. transformers/models/cohere2_vision/modeling_cohere2_vision.py +24 -25
  148. transformers/models/cohere2_vision/modular_cohere2_vision.py +20 -20
  149. transformers/models/colqwen2/modeling_colqwen2.py +7 -6
  150. transformers/models/colqwen2/modular_colqwen2.py +7 -6
  151. transformers/models/conditional_detr/configuration_conditional_detr.py +19 -46
  152. transformers/models/conditional_detr/image_processing_conditional_detr.py +3 -4
  153. transformers/models/conditional_detr/image_processing_conditional_detr_fast.py +28 -14
  154. transformers/models/conditional_detr/modeling_conditional_detr.py +794 -942
  155. transformers/models/conditional_detr/modular_conditional_detr.py +901 -3
  156. transformers/models/convbert/configuration_convbert.py +11 -7
  157. transformers/models/convnext/configuration_convnext.py +2 -4
  158. transformers/models/convnext/image_processing_convnext_fast.py +2 -2
  159. transformers/models/convnext/modeling_convnext.py +7 -6
  160. transformers/models/convnextv2/configuration_convnextv2.py +2 -4
  161. transformers/models/convnextv2/modeling_convnextv2.py +7 -6
  162. transformers/models/cpmant/configuration_cpmant.py +4 -0
  163. transformers/models/csm/configuration_csm.py +9 -15
  164. transformers/models/csm/modeling_csm.py +3 -3
  165. transformers/models/ctrl/configuration_ctrl.py +16 -0
  166. transformers/models/ctrl/modeling_ctrl.py +13 -25
  167. transformers/models/cwm/configuration_cwm.py +5 -7
  168. transformers/models/cwm/modeling_cwm.py +4 -4
  169. transformers/models/d_fine/configuration_d_fine.py +10 -56
  170. transformers/models/d_fine/modeling_d_fine.py +728 -868
  171. transformers/models/d_fine/modular_d_fine.py +335 -412
  172. transformers/models/dab_detr/configuration_dab_detr.py +22 -48
  173. transformers/models/dab_detr/modeling_dab_detr.py +11 -7
  174. transformers/models/dac/modeling_dac.py +1 -1
  175. transformers/models/data2vec/configuration_data2vec_audio.py +4 -1
  176. transformers/models/data2vec/configuration_data2vec_text.py +11 -2
  177. transformers/models/data2vec/modeling_data2vec_audio.py +3 -3
  178. transformers/models/data2vec/modeling_data2vec_text.py +6 -6
  179. transformers/models/data2vec/modeling_data2vec_vision.py +4 -2
  180. transformers/models/dbrx/configuration_dbrx.py +11 -3
  181. transformers/models/dbrx/modeling_dbrx.py +6 -6
  182. transformers/models/dbrx/modular_dbrx.py +6 -6
  183. transformers/models/deberta/configuration_deberta.py +6 -0
  184. transformers/models/deberta_v2/configuration_deberta_v2.py +6 -0
  185. transformers/models/decision_transformer/configuration_decision_transformer.py +3 -1
  186. transformers/models/decision_transformer/modeling_decision_transformer.py +3 -3
  187. transformers/models/deepseek_v2/configuration_deepseek_v2.py +7 -10
  188. transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -8
  189. transformers/models/deepseek_v2/modular_deepseek_v2.py +8 -10
  190. transformers/models/deepseek_v3/configuration_deepseek_v3.py +7 -10
  191. transformers/models/deepseek_v3/modeling_deepseek_v3.py +7 -7
  192. transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -5
  193. transformers/models/deepseek_vl/configuration_deepseek_vl.py +4 -0
  194. transformers/models/deepseek_vl/image_processing_deepseek_vl.py +2 -2
  195. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +5 -5
  196. transformers/models/deepseek_vl/modeling_deepseek_vl.py +17 -12
  197. transformers/models/deepseek_vl/modular_deepseek_vl.py +4 -0
  198. transformers/models/deepseek_vl_hybrid/configuration_deepseek_vl_hybrid.py +4 -0
  199. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py +2 -2
  200. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +6 -6
  201. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +68 -24
  202. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +70 -19
  203. transformers/models/deformable_detr/configuration_deformable_detr.py +22 -45
  204. transformers/models/deformable_detr/image_processing_deformable_detr_fast.py +25 -11
  205. transformers/models/deformable_detr/modeling_deformable_detr.py +410 -607
  206. transformers/models/deformable_detr/modular_deformable_detr.py +1385 -3
  207. transformers/models/deit/modeling_deit.py +11 -7
  208. transformers/models/depth_anything/configuration_depth_anything.py +12 -42
  209. transformers/models/depth_anything/modeling_depth_anything.py +5 -3
  210. transformers/models/depth_pro/image_processing_depth_pro_fast.py +2 -2
  211. transformers/models/depth_pro/modeling_depth_pro.py +8 -4
  212. transformers/models/detr/configuration_detr.py +18 -49
  213. transformers/models/detr/image_processing_detr_fast.py +11 -11
  214. transformers/models/detr/modeling_detr.py +695 -734
  215. transformers/models/dia/configuration_dia.py +4 -7
  216. transformers/models/dia/generation_dia.py +8 -17
  217. transformers/models/dia/modeling_dia.py +7 -7
  218. transformers/models/dia/modular_dia.py +4 -4
  219. transformers/models/diffllama/configuration_diffllama.py +5 -7
  220. transformers/models/diffllama/modeling_diffllama.py +3 -8
  221. transformers/models/diffllama/modular_diffllama.py +2 -7
  222. transformers/models/dinat/configuration_dinat.py +2 -4
  223. transformers/models/dinat/modeling_dinat.py +7 -6
  224. transformers/models/dinov2/configuration_dinov2.py +2 -4
  225. transformers/models/dinov2/modeling_dinov2.py +9 -8
  226. transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py +2 -4
  227. transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +9 -8
  228. transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py +6 -7
  229. transformers/models/dinov3_convnext/configuration_dinov3_convnext.py +2 -4
  230. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +2 -3
  231. transformers/models/dinov3_vit/configuration_dinov3_vit.py +2 -4
  232. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +2 -2
  233. transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -6
  234. transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -6
  235. transformers/models/distilbert/configuration_distilbert.py +8 -1
  236. transformers/models/distilbert/modeling_distilbert.py +3 -3
  237. transformers/models/doge/configuration_doge.py +17 -7
  238. transformers/models/doge/modeling_doge.py +4 -4
  239. transformers/models/doge/modular_doge.py +20 -10
  240. transformers/models/donut/image_processing_donut_fast.py +4 -4
  241. transformers/models/dots1/configuration_dots1.py +16 -7
  242. transformers/models/dots1/modeling_dots1.py +4 -4
  243. transformers/models/dpr/configuration_dpr.py +19 -1
  244. transformers/models/dpt/configuration_dpt.py +23 -65
  245. transformers/models/dpt/image_processing_dpt_fast.py +5 -5
  246. transformers/models/dpt/modeling_dpt.py +19 -15
  247. transformers/models/dpt/modular_dpt.py +4 -4
  248. transformers/models/edgetam/configuration_edgetam.py +1 -1
  249. transformers/models/edgetam/modeling_edgetam.py +53 -53
  250. transformers/models/edgetam/modular_edgetam.py +5 -7
  251. transformers/models/edgetam_video/modeling_edgetam_video.py +55 -56
  252. transformers/models/edgetam_video/modular_edgetam_video.py +9 -9
  253. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +4 -3
  254. transformers/models/efficientloftr/modeling_efficientloftr.py +19 -9
  255. transformers/models/efficientnet/image_processing_efficientnet_fast.py +2 -2
  256. transformers/models/electra/configuration_electra.py +13 -2
  257. transformers/models/electra/modeling_electra.py +6 -6
  258. transformers/models/emu3/configuration_emu3.py +12 -10
  259. transformers/models/emu3/modeling_emu3.py +84 -47
  260. transformers/models/emu3/modular_emu3.py +77 -39
  261. transformers/models/encoder_decoder/configuration_encoder_decoder.py +12 -1
  262. transformers/models/encoder_decoder/modeling_encoder_decoder.py +20 -24
  263. transformers/models/eomt/configuration_eomt.py +12 -13
  264. transformers/models/eomt/image_processing_eomt_fast.py +3 -3
  265. transformers/models/eomt/modeling_eomt.py +3 -3
  266. transformers/models/eomt/modular_eomt.py +17 -17
  267. transformers/models/eomt_dinov3/__init__.py +28 -0
  268. transformers/models/eomt_dinov3/configuration_eomt_dinov3.py +204 -0
  269. transformers/models/eomt_dinov3/modeling_eomt_dinov3.py +1376 -0
  270. transformers/models/eomt_dinov3/modular_eomt_dinov3.py +454 -0
  271. transformers/models/ernie/configuration_ernie.py +24 -2
  272. transformers/models/ernie/modeling_ernie.py +6 -30
  273. transformers/models/ernie4_5/configuration_ernie4_5.py +5 -7
  274. transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
  275. transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py +7 -10
  276. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +4 -4
  277. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +17 -6
  278. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +229 -188
  279. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +79 -55
  280. transformers/models/esm/configuration_esm.py +9 -11
  281. transformers/models/esm/modeling_esm.py +3 -3
  282. transformers/models/esm/modeling_esmfold.py +1 -6
  283. transformers/models/esm/openfold_utils/protein.py +2 -3
  284. transformers/models/evolla/configuration_evolla.py +21 -8
  285. transformers/models/evolla/modeling_evolla.py +11 -7
  286. transformers/models/evolla/modular_evolla.py +5 -1
  287. transformers/models/exaone4/configuration_exaone4.py +8 -5
  288. transformers/models/exaone4/modeling_exaone4.py +4 -4
  289. transformers/models/exaone4/modular_exaone4.py +11 -8
  290. transformers/models/exaone_moe/__init__.py +27 -0
  291. transformers/models/exaone_moe/configuration_exaone_moe.py +235 -0
  292. transformers/models/exaone_moe/modeling_exaone_moe.py +665 -0
  293. transformers/models/exaone_moe/modular_exaone_moe.py +373 -0
  294. transformers/models/falcon/configuration_falcon.py +9 -1
  295. transformers/models/falcon/modeling_falcon.py +3 -8
  296. transformers/models/falcon_h1/configuration_falcon_h1.py +17 -8
  297. transformers/models/falcon_h1/modeling_falcon_h1.py +22 -54
  298. transformers/models/falcon_h1/modular_falcon_h1.py +21 -52
  299. transformers/models/falcon_mamba/configuration_falcon_mamba.py +5 -1
  300. transformers/models/falcon_mamba/modeling_falcon_mamba.py +18 -26
  301. transformers/models/falcon_mamba/modular_falcon_mamba.py +4 -0
  302. transformers/models/fast_vlm/configuration_fast_vlm.py +10 -1
  303. transformers/models/fast_vlm/modeling_fast_vlm.py +37 -64
  304. transformers/models/fast_vlm/modular_fast_vlm.py +146 -35
  305. transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py +0 -1
  306. transformers/models/flaubert/configuration_flaubert.py +10 -4
  307. transformers/models/flaubert/modeling_flaubert.py +1 -1
  308. transformers/models/flava/configuration_flava.py +4 -3
  309. transformers/models/flava/image_processing_flava_fast.py +4 -4
  310. transformers/models/flava/modeling_flava.py +36 -28
  311. transformers/models/flex_olmo/configuration_flex_olmo.py +11 -14
  312. transformers/models/flex_olmo/modeling_flex_olmo.py +4 -4
  313. transformers/models/flex_olmo/modular_flex_olmo.py +11 -14
  314. transformers/models/florence2/configuration_florence2.py +4 -0
  315. transformers/models/florence2/modeling_florence2.py +57 -32
  316. transformers/models/florence2/modular_florence2.py +48 -26
  317. transformers/models/fnet/configuration_fnet.py +6 -1
  318. transformers/models/focalnet/configuration_focalnet.py +2 -4
  319. transformers/models/focalnet/modeling_focalnet.py +10 -7
  320. transformers/models/fsmt/configuration_fsmt.py +12 -16
  321. transformers/models/funnel/configuration_funnel.py +8 -0
  322. transformers/models/fuyu/configuration_fuyu.py +5 -8
  323. transformers/models/fuyu/image_processing_fuyu_fast.py +5 -4
  324. transformers/models/fuyu/modeling_fuyu.py +24 -23
  325. transformers/models/gemma/configuration_gemma.py +5 -7
  326. transformers/models/gemma/modeling_gemma.py +4 -4
  327. transformers/models/gemma/modular_gemma.py +5 -7
  328. transformers/models/gemma2/configuration_gemma2.py +5 -7
  329. transformers/models/gemma2/modeling_gemma2.py +4 -4
  330. transformers/models/gemma2/modular_gemma2.py +8 -10
  331. transformers/models/gemma3/configuration_gemma3.py +28 -22
  332. transformers/models/gemma3/image_processing_gemma3_fast.py +2 -2
  333. transformers/models/gemma3/modeling_gemma3.py +37 -33
  334. transformers/models/gemma3/modular_gemma3.py +46 -42
  335. transformers/models/gemma3n/configuration_gemma3n.py +35 -22
  336. transformers/models/gemma3n/modeling_gemma3n.py +86 -58
  337. transformers/models/gemma3n/modular_gemma3n.py +112 -75
  338. transformers/models/git/configuration_git.py +5 -7
  339. transformers/models/git/modeling_git.py +31 -41
  340. transformers/models/glm/configuration_glm.py +7 -9
  341. transformers/models/glm/modeling_glm.py +4 -4
  342. transformers/models/glm4/configuration_glm4.py +7 -9
  343. transformers/models/glm4/modeling_glm4.py +4 -4
  344. transformers/models/glm46v/configuration_glm46v.py +4 -0
  345. transformers/models/glm46v/image_processing_glm46v.py +5 -2
  346. transformers/models/glm46v/image_processing_glm46v_fast.py +2 -2
  347. transformers/models/glm46v/modeling_glm46v.py +91 -46
  348. transformers/models/glm46v/modular_glm46v.py +4 -0
  349. transformers/models/glm4_moe/configuration_glm4_moe.py +17 -7
  350. transformers/models/glm4_moe/modeling_glm4_moe.py +4 -4
  351. transformers/models/glm4_moe/modular_glm4_moe.py +17 -7
  352. transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py +8 -10
  353. transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py +7 -7
  354. transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py +8 -10
  355. transformers/models/glm4v/configuration_glm4v.py +12 -8
  356. transformers/models/glm4v/image_processing_glm4v.py +5 -2
  357. transformers/models/glm4v/image_processing_glm4v_fast.py +2 -2
  358. transformers/models/glm4v/modeling_glm4v.py +120 -63
  359. transformers/models/glm4v/modular_glm4v.py +82 -50
  360. transformers/models/glm4v_moe/configuration_glm4v_moe.py +18 -6
  361. transformers/models/glm4v_moe/modeling_glm4v_moe.py +115 -63
  362. transformers/models/glm4v_moe/modular_glm4v_moe.py +23 -12
  363. transformers/models/glm_image/configuration_glm_image.py +26 -20
  364. transformers/models/glm_image/image_processing_glm_image.py +1 -1
  365. transformers/models/glm_image/image_processing_glm_image_fast.py +5 -7
  366. transformers/models/glm_image/modeling_glm_image.py +337 -236
  367. transformers/models/glm_image/modular_glm_image.py +415 -255
  368. transformers/models/glm_image/processing_glm_image.py +65 -17
  369. transformers/{pipelines/deprecated → models/glm_ocr}/__init__.py +15 -2
  370. transformers/models/glm_ocr/configuration_glm_ocr.py +312 -0
  371. transformers/models/glm_ocr/modeling_glm_ocr.py +1633 -0
  372. transformers/models/glm_ocr/modular_glm_ocr.py +428 -0
  373. transformers/models/glmasr/modeling_glmasr.py +34 -28
  374. transformers/models/glmasr/modular_glmasr.py +23 -11
  375. transformers/models/glpn/image_processing_glpn_fast.py +3 -3
  376. transformers/models/glpn/modeling_glpn.py +4 -2
  377. transformers/models/got_ocr2/configuration_got_ocr2.py +6 -6
  378. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +3 -3
  379. transformers/models/got_ocr2/modeling_got_ocr2.py +31 -37
  380. transformers/models/got_ocr2/modular_got_ocr2.py +30 -19
  381. transformers/models/gpt2/configuration_gpt2.py +13 -1
  382. transformers/models/gpt2/modeling_gpt2.py +5 -5
  383. transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +7 -1
  384. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +5 -4
  385. transformers/models/gpt_neo/configuration_gpt_neo.py +9 -1
  386. transformers/models/gpt_neo/modeling_gpt_neo.py +3 -7
  387. transformers/models/gpt_neox/configuration_gpt_neox.py +8 -3
  388. transformers/models/gpt_neox/modeling_gpt_neox.py +4 -4
  389. transformers/models/gpt_neox/modular_gpt_neox.py +4 -4
  390. transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py +9 -1
  391. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +2 -2
  392. transformers/models/gpt_oss/configuration_gpt_oss.py +10 -6
  393. transformers/models/gpt_oss/modeling_gpt_oss.py +46 -79
  394. transformers/models/gpt_oss/modular_gpt_oss.py +45 -78
  395. transformers/models/gptj/configuration_gptj.py +4 -4
  396. transformers/models/gptj/modeling_gptj.py +3 -7
  397. transformers/models/granite/configuration_granite.py +5 -7
  398. transformers/models/granite/modeling_granite.py +4 -4
  399. transformers/models/granite_speech/modeling_granite_speech.py +63 -37
  400. transformers/models/granitemoe/configuration_granitemoe.py +5 -7
  401. transformers/models/granitemoe/modeling_granitemoe.py +4 -4
  402. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +17 -7
  403. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +22 -54
  404. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +39 -45
  405. transformers/models/granitemoeshared/configuration_granitemoeshared.py +6 -7
  406. transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -4
  407. transformers/models/grounding_dino/configuration_grounding_dino.py +10 -45
  408. transformers/models/grounding_dino/image_processing_grounding_dino_fast.py +11 -11
  409. transformers/models/grounding_dino/modeling_grounding_dino.py +68 -86
  410. transformers/models/groupvit/configuration_groupvit.py +4 -1
  411. transformers/models/groupvit/modeling_groupvit.py +29 -22
  412. transformers/models/helium/configuration_helium.py +5 -7
  413. transformers/models/helium/modeling_helium.py +4 -4
  414. transformers/models/hgnet_v2/configuration_hgnet_v2.py +2 -4
  415. transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -5
  416. transformers/models/hgnet_v2/modular_hgnet_v2.py +7 -8
  417. transformers/models/hiera/configuration_hiera.py +2 -4
  418. transformers/models/hiera/modeling_hiera.py +11 -8
  419. transformers/models/hubert/configuration_hubert.py +4 -1
  420. transformers/models/hubert/modeling_hubert.py +7 -4
  421. transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py +5 -7
  422. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +28 -4
  423. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +28 -6
  424. transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py +6 -8
  425. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +22 -9
  426. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +22 -8
  427. transformers/models/ibert/configuration_ibert.py +4 -1
  428. transformers/models/idefics/configuration_idefics.py +5 -7
  429. transformers/models/idefics/modeling_idefics.py +3 -4
  430. transformers/models/idefics/vision.py +5 -4
  431. transformers/models/idefics2/configuration_idefics2.py +1 -2
  432. transformers/models/idefics2/image_processing_idefics2_fast.py +1 -0
  433. transformers/models/idefics2/modeling_idefics2.py +72 -50
  434. transformers/models/idefics3/configuration_idefics3.py +1 -3
  435. transformers/models/idefics3/image_processing_idefics3_fast.py +29 -3
  436. transformers/models/idefics3/modeling_idefics3.py +63 -40
  437. transformers/models/ijepa/modeling_ijepa.py +3 -3
  438. transformers/models/imagegpt/configuration_imagegpt.py +9 -1
  439. transformers/models/imagegpt/image_processing_imagegpt_fast.py +2 -2
  440. transformers/models/imagegpt/modeling_imagegpt.py +8 -4
  441. transformers/models/informer/modeling_informer.py +3 -3
  442. transformers/models/instructblip/configuration_instructblip.py +2 -1
  443. transformers/models/instructblip/modeling_instructblip.py +65 -39
  444. transformers/models/instructblipvideo/configuration_instructblipvideo.py +2 -1
  445. transformers/models/instructblipvideo/modeling_instructblipvideo.py +60 -57
  446. transformers/models/instructblipvideo/modular_instructblipvideo.py +43 -32
  447. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +2 -2
  448. transformers/models/internvl/configuration_internvl.py +5 -0
  449. transformers/models/internvl/modeling_internvl.py +35 -55
  450. transformers/models/internvl/modular_internvl.py +26 -38
  451. transformers/models/internvl/video_processing_internvl.py +2 -2
  452. transformers/models/jais2/configuration_jais2.py +5 -7
  453. transformers/models/jais2/modeling_jais2.py +4 -4
  454. transformers/models/jamba/configuration_jamba.py +5 -7
  455. transformers/models/jamba/modeling_jamba.py +4 -4
  456. transformers/models/jamba/modular_jamba.py +3 -3
  457. transformers/models/janus/image_processing_janus.py +2 -2
  458. transformers/models/janus/image_processing_janus_fast.py +8 -8
  459. transformers/models/janus/modeling_janus.py +63 -146
  460. transformers/models/janus/modular_janus.py +62 -20
  461. transformers/models/jetmoe/configuration_jetmoe.py +6 -4
  462. transformers/models/jetmoe/modeling_jetmoe.py +3 -3
  463. transformers/models/jetmoe/modular_jetmoe.py +3 -3
  464. transformers/models/kosmos2/configuration_kosmos2.py +10 -8
  465. transformers/models/kosmos2/modeling_kosmos2.py +56 -34
  466. transformers/models/kosmos2_5/configuration_kosmos2_5.py +8 -8
  467. transformers/models/kosmos2_5/modeling_kosmos2_5.py +54 -63
  468. transformers/models/kyutai_speech_to_text/configuration_kyutai_speech_to_text.py +8 -3
  469. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +44 -40
  470. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +1 -1
  471. transformers/models/lasr/configuration_lasr.py +2 -4
  472. transformers/models/lasr/modeling_lasr.py +3 -3
  473. transformers/models/lasr/modular_lasr.py +3 -3
  474. transformers/models/layoutlm/configuration_layoutlm.py +14 -1
  475. transformers/models/layoutlm/modeling_layoutlm.py +3 -3
  476. transformers/models/layoutlmv2/configuration_layoutlmv2.py +14 -16
  477. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +2 -2
  478. transformers/models/layoutlmv3/configuration_layoutlmv3.py +16 -18
  479. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +2 -2
  480. transformers/models/layoutxlm/configuration_layoutxlm.py +14 -16
  481. transformers/models/led/configuration_led.py +7 -8
  482. transformers/models/levit/image_processing_levit_fast.py +4 -4
  483. transformers/models/lfm2/configuration_lfm2.py +5 -7
  484. transformers/models/lfm2/modeling_lfm2.py +4 -4
  485. transformers/models/lfm2/modular_lfm2.py +3 -3
  486. transformers/models/lfm2_moe/configuration_lfm2_moe.py +5 -7
  487. transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -4
  488. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  489. transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py +9 -15
  490. transformers/models/lfm2_vl/modeling_lfm2_vl.py +42 -28
  491. transformers/models/lfm2_vl/modular_lfm2_vl.py +42 -27
  492. transformers/models/lightglue/image_processing_lightglue_fast.py +4 -3
  493. transformers/models/lightglue/modeling_lightglue.py +3 -3
  494. transformers/models/lightglue/modular_lightglue.py +3 -3
  495. transformers/models/lighton_ocr/modeling_lighton_ocr.py +31 -28
  496. transformers/models/lighton_ocr/modular_lighton_ocr.py +19 -18
  497. transformers/models/lilt/configuration_lilt.py +6 -1
  498. transformers/models/llama/configuration_llama.py +5 -7
  499. transformers/models/llama/modeling_llama.py +4 -4
  500. transformers/models/llama4/configuration_llama4.py +67 -47
  501. transformers/models/llama4/image_processing_llama4_fast.py +3 -3
  502. transformers/models/llama4/modeling_llama4.py +46 -44
  503. transformers/models/llava/configuration_llava.py +10 -0
  504. transformers/models/llava/image_processing_llava_fast.py +3 -3
  505. transformers/models/llava/modeling_llava.py +38 -65
  506. transformers/models/llava_next/configuration_llava_next.py +2 -1
  507. transformers/models/llava_next/image_processing_llava_next_fast.py +6 -6
  508. transformers/models/llava_next/modeling_llava_next.py +61 -60
  509. transformers/models/llava_next_video/configuration_llava_next_video.py +10 -6
  510. transformers/models/llava_next_video/modeling_llava_next_video.py +115 -100
  511. transformers/models/llava_next_video/modular_llava_next_video.py +110 -101
  512. transformers/models/llava_onevision/configuration_llava_onevision.py +10 -6
  513. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +8 -7
  514. transformers/models/llava_onevision/modeling_llava_onevision.py +111 -105
  515. transformers/models/llava_onevision/modular_llava_onevision.py +106 -101
  516. transformers/models/longcat_flash/configuration_longcat_flash.py +7 -10
  517. transformers/models/longcat_flash/modeling_longcat_flash.py +7 -7
  518. transformers/models/longcat_flash/modular_longcat_flash.py +6 -5
  519. transformers/models/longformer/configuration_longformer.py +4 -1
  520. transformers/models/longt5/configuration_longt5.py +9 -6
  521. transformers/models/longt5/modeling_longt5.py +2 -1
  522. transformers/models/luke/configuration_luke.py +8 -1
  523. transformers/models/lw_detr/configuration_lw_detr.py +19 -31
  524. transformers/models/lw_detr/modeling_lw_detr.py +43 -44
  525. transformers/models/lw_detr/modular_lw_detr.py +36 -38
  526. transformers/models/lxmert/configuration_lxmert.py +16 -0
  527. transformers/models/m2m_100/configuration_m2m_100.py +7 -8
  528. transformers/models/m2m_100/modeling_m2m_100.py +3 -3
  529. transformers/models/mamba/configuration_mamba.py +5 -2
  530. transformers/models/mamba/modeling_mamba.py +18 -26
  531. transformers/models/mamba2/configuration_mamba2.py +5 -7
  532. transformers/models/mamba2/modeling_mamba2.py +22 -33
  533. transformers/models/marian/configuration_marian.py +10 -4
  534. transformers/models/marian/modeling_marian.py +4 -4
  535. transformers/models/markuplm/configuration_markuplm.py +4 -6
  536. transformers/models/markuplm/modeling_markuplm.py +3 -3
  537. transformers/models/mask2former/configuration_mask2former.py +12 -47
  538. transformers/models/mask2former/image_processing_mask2former_fast.py +8 -8
  539. transformers/models/mask2former/modeling_mask2former.py +18 -12
  540. transformers/models/maskformer/configuration_maskformer.py +14 -45
  541. transformers/models/maskformer/configuration_maskformer_swin.py +2 -4
  542. transformers/models/maskformer/image_processing_maskformer_fast.py +8 -8
  543. transformers/models/maskformer/modeling_maskformer.py +15 -9
  544. transformers/models/maskformer/modeling_maskformer_swin.py +2 -3
  545. transformers/models/mbart/configuration_mbart.py +9 -4
  546. transformers/models/mbart/modeling_mbart.py +9 -6
  547. transformers/models/megatron_bert/configuration_megatron_bert.py +13 -2
  548. transformers/models/megatron_bert/modeling_megatron_bert.py +0 -15
  549. transformers/models/metaclip_2/configuration_metaclip_2.py +4 -1
  550. transformers/models/metaclip_2/modeling_metaclip_2.py +49 -42
  551. transformers/models/metaclip_2/modular_metaclip_2.py +41 -25
  552. transformers/models/mgp_str/modeling_mgp_str.py +4 -2
  553. transformers/models/mimi/configuration_mimi.py +4 -0
  554. transformers/models/mimi/modeling_mimi.py +40 -36
  555. transformers/models/minimax/configuration_minimax.py +8 -11
  556. transformers/models/minimax/modeling_minimax.py +5 -5
  557. transformers/models/minimax/modular_minimax.py +9 -12
  558. transformers/models/minimax_m2/configuration_minimax_m2.py +8 -31
  559. transformers/models/minimax_m2/modeling_minimax_m2.py +4 -4
  560. transformers/models/minimax_m2/modular_minimax_m2.py +8 -31
  561. transformers/models/ministral/configuration_ministral.py +5 -7
  562. transformers/models/ministral/modeling_ministral.py +4 -4
  563. transformers/models/ministral/modular_ministral.py +5 -8
  564. transformers/models/ministral3/configuration_ministral3.py +4 -4
  565. transformers/models/ministral3/modeling_ministral3.py +4 -4
  566. transformers/models/ministral3/modular_ministral3.py +3 -3
  567. transformers/models/mistral/configuration_mistral.py +5 -7
  568. transformers/models/mistral/modeling_mistral.py +4 -4
  569. transformers/models/mistral/modular_mistral.py +3 -3
  570. transformers/models/mistral3/configuration_mistral3.py +4 -0
  571. transformers/models/mistral3/modeling_mistral3.py +36 -40
  572. transformers/models/mistral3/modular_mistral3.py +31 -32
  573. transformers/models/mixtral/configuration_mixtral.py +8 -11
  574. transformers/models/mixtral/modeling_mixtral.py +4 -4
  575. transformers/models/mlcd/modeling_mlcd.py +7 -5
  576. transformers/models/mlcd/modular_mlcd.py +7 -5
  577. transformers/models/mllama/configuration_mllama.py +5 -7
  578. transformers/models/mllama/image_processing_mllama_fast.py +6 -5
  579. transformers/models/mllama/modeling_mllama.py +19 -19
  580. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +10 -45
  581. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +66 -84
  582. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +10 -45
  583. transformers/models/mobilebert/configuration_mobilebert.py +4 -1
  584. transformers/models/mobilebert/modeling_mobilebert.py +3 -3
  585. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +4 -4
  586. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +4 -2
  587. transformers/models/mobilevit/image_processing_mobilevit_fast.py +4 -4
  588. transformers/models/mobilevit/modeling_mobilevit.py +4 -2
  589. transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -2
  590. transformers/models/modernbert/configuration_modernbert.py +46 -21
  591. transformers/models/modernbert/modeling_modernbert.py +146 -899
  592. transformers/models/modernbert/modular_modernbert.py +185 -908
  593. transformers/models/modernbert_decoder/configuration_modernbert_decoder.py +21 -13
  594. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -17
  595. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +24 -23
  596. transformers/models/moonshine/configuration_moonshine.py +12 -7
  597. transformers/models/moonshine/modeling_moonshine.py +7 -7
  598. transformers/models/moonshine/modular_moonshine.py +19 -13
  599. transformers/models/moshi/configuration_moshi.py +28 -2
  600. transformers/models/moshi/modeling_moshi.py +4 -9
  601. transformers/models/mpnet/configuration_mpnet.py +6 -1
  602. transformers/models/mpt/configuration_mpt.py +16 -0
  603. transformers/models/mra/configuration_mra.py +8 -1
  604. transformers/models/mt5/configuration_mt5.py +9 -5
  605. transformers/models/mt5/modeling_mt5.py +5 -8
  606. transformers/models/musicgen/configuration_musicgen.py +12 -7
  607. transformers/models/musicgen/modeling_musicgen.py +6 -5
  608. transformers/models/musicgen_melody/configuration_musicgen_melody.py +15 -7
  609. transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -17
  610. transformers/models/mvp/configuration_mvp.py +8 -4
  611. transformers/models/mvp/modeling_mvp.py +6 -4
  612. transformers/models/nanochat/configuration_nanochat.py +5 -7
  613. transformers/models/nanochat/modeling_nanochat.py +4 -4
  614. transformers/models/nanochat/modular_nanochat.py +4 -4
  615. transformers/models/nemotron/configuration_nemotron.py +5 -7
  616. transformers/models/nemotron/modeling_nemotron.py +4 -14
  617. transformers/models/nllb/tokenization_nllb.py +7 -5
  618. transformers/models/nllb_moe/configuration_nllb_moe.py +7 -9
  619. transformers/models/nllb_moe/modeling_nllb_moe.py +3 -3
  620. transformers/models/nougat/image_processing_nougat_fast.py +8 -8
  621. transformers/models/nystromformer/configuration_nystromformer.py +8 -1
  622. transformers/models/olmo/configuration_olmo.py +5 -7
  623. transformers/models/olmo/modeling_olmo.py +4 -4
  624. transformers/models/olmo/modular_olmo.py +3 -3
  625. transformers/models/olmo2/configuration_olmo2.py +9 -11
  626. transformers/models/olmo2/modeling_olmo2.py +4 -4
  627. transformers/models/olmo2/modular_olmo2.py +7 -7
  628. transformers/models/olmo3/configuration_olmo3.py +10 -11
  629. transformers/models/olmo3/modeling_olmo3.py +4 -4
  630. transformers/models/olmo3/modular_olmo3.py +13 -14
  631. transformers/models/olmoe/configuration_olmoe.py +5 -7
  632. transformers/models/olmoe/modeling_olmoe.py +4 -4
  633. transformers/models/olmoe/modular_olmoe.py +3 -3
  634. transformers/models/omdet_turbo/configuration_omdet_turbo.py +14 -49
  635. transformers/models/omdet_turbo/modeling_omdet_turbo.py +22 -18
  636. transformers/models/oneformer/configuration_oneformer.py +9 -46
  637. transformers/models/oneformer/image_processing_oneformer_fast.py +8 -8
  638. transformers/models/oneformer/modeling_oneformer.py +14 -9
  639. transformers/models/openai/configuration_openai.py +16 -0
  640. transformers/models/opt/configuration_opt.py +6 -6
  641. transformers/models/opt/modeling_opt.py +5 -5
  642. transformers/models/ovis2/configuration_ovis2.py +4 -0
  643. transformers/models/ovis2/image_processing_ovis2_fast.py +3 -3
  644. transformers/models/ovis2/modeling_ovis2.py +58 -99
  645. transformers/models/ovis2/modular_ovis2.py +52 -13
  646. transformers/models/owlv2/configuration_owlv2.py +4 -1
  647. transformers/models/owlv2/image_processing_owlv2_fast.py +5 -5
  648. transformers/models/owlv2/modeling_owlv2.py +40 -27
  649. transformers/models/owlv2/modular_owlv2.py +5 -5
  650. transformers/models/owlvit/configuration_owlvit.py +4 -1
  651. transformers/models/owlvit/modeling_owlvit.py +40 -27
  652. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +9 -10
  653. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +88 -87
  654. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +82 -53
  655. transformers/models/paligemma/configuration_paligemma.py +4 -0
  656. transformers/models/paligemma/modeling_paligemma.py +30 -26
  657. transformers/models/parakeet/configuration_parakeet.py +2 -4
  658. transformers/models/parakeet/modeling_parakeet.py +3 -3
  659. transformers/models/parakeet/modular_parakeet.py +3 -3
  660. transformers/models/patchtsmixer/modeling_patchtsmixer.py +3 -3
  661. transformers/models/patchtst/modeling_patchtst.py +3 -3
  662. transformers/models/pe_audio/modeling_pe_audio.py +4 -4
  663. transformers/models/pe_audio/modular_pe_audio.py +1 -1
  664. transformers/models/pe_audio_video/modeling_pe_audio_video.py +4 -4
  665. transformers/models/pe_audio_video/modular_pe_audio_video.py +4 -4
  666. transformers/models/pe_video/modeling_pe_video.py +36 -24
  667. transformers/models/pe_video/modular_pe_video.py +36 -23
  668. transformers/models/pegasus/configuration_pegasus.py +8 -5
  669. transformers/models/pegasus/modeling_pegasus.py +4 -4
  670. transformers/models/pegasus_x/configuration_pegasus_x.py +5 -3
  671. transformers/models/pegasus_x/modeling_pegasus_x.py +3 -3
  672. transformers/models/perceiver/image_processing_perceiver_fast.py +2 -2
  673. transformers/models/perceiver/modeling_perceiver.py +17 -9
  674. transformers/models/perception_lm/modeling_perception_lm.py +26 -27
  675. transformers/models/perception_lm/modular_perception_lm.py +27 -25
  676. transformers/models/persimmon/configuration_persimmon.py +5 -7
  677. transformers/models/persimmon/modeling_persimmon.py +5 -5
  678. transformers/models/phi/configuration_phi.py +8 -6
  679. transformers/models/phi/modeling_phi.py +4 -4
  680. transformers/models/phi/modular_phi.py +3 -3
  681. transformers/models/phi3/configuration_phi3.py +9 -11
  682. transformers/models/phi3/modeling_phi3.py +4 -4
  683. transformers/models/phi3/modular_phi3.py +3 -3
  684. transformers/models/phi4_multimodal/configuration_phi4_multimodal.py +11 -13
  685. transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py +4 -4
  686. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +46 -61
  687. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +44 -30
  688. transformers/models/phimoe/configuration_phimoe.py +5 -7
  689. transformers/models/phimoe/modeling_phimoe.py +15 -39
  690. transformers/models/phimoe/modular_phimoe.py +12 -7
  691. transformers/models/pix2struct/configuration_pix2struct.py +12 -9
  692. transformers/models/pix2struct/image_processing_pix2struct_fast.py +5 -5
  693. transformers/models/pix2struct/modeling_pix2struct.py +14 -7
  694. transformers/models/pixio/configuration_pixio.py +2 -4
  695. transformers/models/pixio/modeling_pixio.py +9 -8
  696. transformers/models/pixio/modular_pixio.py +4 -2
  697. transformers/models/pixtral/image_processing_pixtral_fast.py +5 -5
  698. transformers/models/pixtral/modeling_pixtral.py +9 -12
  699. transformers/models/plbart/configuration_plbart.py +8 -5
  700. transformers/models/plbart/modeling_plbart.py +9 -7
  701. transformers/models/plbart/modular_plbart.py +1 -1
  702. transformers/models/poolformer/image_processing_poolformer_fast.py +7 -7
  703. transformers/models/pop2piano/configuration_pop2piano.py +7 -6
  704. transformers/models/pop2piano/modeling_pop2piano.py +2 -1
  705. transformers/models/pp_doclayout_v3/__init__.py +30 -0
  706. transformers/models/pp_doclayout_v3/configuration_pp_doclayout_v3.py +277 -0
  707. transformers/models/pp_doclayout_v3/image_processing_pp_doclayout_v3_fast.py +305 -0
  708. transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py +2083 -0
  709. transformers/models/pp_doclayout_v3/modular_pp_doclayout_v3.py +1549 -0
  710. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +12 -46
  711. transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything_fast.py +6 -6
  712. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +8 -6
  713. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +12 -10
  714. transformers/models/prophetnet/configuration_prophetnet.py +11 -10
  715. transformers/models/prophetnet/modeling_prophetnet.py +12 -23
  716. transformers/models/pvt/image_processing_pvt.py +7 -7
  717. transformers/models/pvt/image_processing_pvt_fast.py +1 -1
  718. transformers/models/pvt_v2/configuration_pvt_v2.py +2 -4
  719. transformers/models/pvt_v2/modeling_pvt_v2.py +6 -5
  720. transformers/models/qwen2/configuration_qwen2.py +14 -4
  721. transformers/models/qwen2/modeling_qwen2.py +4 -4
  722. transformers/models/qwen2/modular_qwen2.py +3 -3
  723. transformers/models/qwen2/tokenization_qwen2.py +0 -4
  724. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +17 -5
  725. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +108 -88
  726. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +115 -87
  727. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +7 -10
  728. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +98 -53
  729. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +18 -6
  730. transformers/models/qwen2_audio/modeling_qwen2_audio.py +12 -12
  731. transformers/models/qwen2_moe/configuration_qwen2_moe.py +14 -4
  732. transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
  733. transformers/models/qwen2_moe/modular_qwen2_moe.py +3 -3
  734. transformers/models/qwen2_vl/configuration_qwen2_vl.py +7 -10
  735. transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py +4 -6
  736. transformers/models/qwen2_vl/modeling_qwen2_vl.py +97 -53
  737. transformers/models/qwen2_vl/video_processing_qwen2_vl.py +4 -6
  738. transformers/models/qwen3/configuration_qwen3.py +15 -5
  739. transformers/models/qwen3/modeling_qwen3.py +4 -4
  740. transformers/models/qwen3/modular_qwen3.py +3 -3
  741. transformers/models/qwen3_moe/configuration_qwen3_moe.py +20 -7
  742. transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
  743. transformers/models/qwen3_next/configuration_qwen3_next.py +16 -4
  744. transformers/models/qwen3_next/modeling_qwen3_next.py +5 -5
  745. transformers/models/qwen3_next/modular_qwen3_next.py +4 -4
  746. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +55 -19
  747. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +161 -98
  748. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +107 -34
  749. transformers/models/qwen3_vl/configuration_qwen3_vl.py +7 -6
  750. transformers/models/qwen3_vl/modeling_qwen3_vl.py +115 -49
  751. transformers/models/qwen3_vl/modular_qwen3_vl.py +88 -37
  752. transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py +7 -6
  753. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +173 -99
  754. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +23 -7
  755. transformers/models/rag/configuration_rag.py +6 -6
  756. transformers/models/rag/modeling_rag.py +3 -3
  757. transformers/models/rag/retrieval_rag.py +1 -1
  758. transformers/models/recurrent_gemma/configuration_recurrent_gemma.py +8 -6
  759. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +4 -5
  760. transformers/models/reformer/configuration_reformer.py +7 -7
  761. transformers/models/rembert/configuration_rembert.py +8 -1
  762. transformers/models/rembert/modeling_rembert.py +0 -22
  763. transformers/models/resnet/configuration_resnet.py +2 -4
  764. transformers/models/resnet/modeling_resnet.py +6 -5
  765. transformers/models/roberta/configuration_roberta.py +11 -2
  766. transformers/models/roberta/modeling_roberta.py +6 -6
  767. transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py +11 -2
  768. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +6 -6
  769. transformers/models/roc_bert/configuration_roc_bert.py +8 -1
  770. transformers/models/roc_bert/modeling_roc_bert.py +6 -41
  771. transformers/models/roformer/configuration_roformer.py +13 -2
  772. transformers/models/roformer/modeling_roformer.py +0 -14
  773. transformers/models/rt_detr/configuration_rt_detr.py +8 -49
  774. transformers/models/rt_detr/configuration_rt_detr_resnet.py +2 -4
  775. transformers/models/rt_detr/image_processing_rt_detr_fast.py +24 -11
  776. transformers/models/rt_detr/modeling_rt_detr.py +578 -737
  777. transformers/models/rt_detr/modeling_rt_detr_resnet.py +2 -3
  778. transformers/models/rt_detr/modular_rt_detr.py +1508 -6
  779. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +12 -57
  780. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +318 -453
  781. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +25 -66
  782. transformers/models/rwkv/configuration_rwkv.py +2 -3
  783. transformers/models/rwkv/modeling_rwkv.py +0 -23
  784. transformers/models/sam/configuration_sam.py +2 -0
  785. transformers/models/sam/image_processing_sam_fast.py +4 -4
  786. transformers/models/sam/modeling_sam.py +13 -8
  787. transformers/models/sam/processing_sam.py +3 -3
  788. transformers/models/sam2/configuration_sam2.py +1 -1
  789. transformers/models/sam2/modeling_sam2.py +56 -52
  790. transformers/models/sam2/modular_sam2.py +47 -55
  791. transformers/models/sam2_video/modeling_sam2_video.py +50 -51
  792. transformers/models/sam2_video/modular_sam2_video.py +12 -10
  793. transformers/models/sam3/modeling_sam3.py +43 -47
  794. transformers/models/sam3/processing_sam3.py +8 -4
  795. transformers/models/sam3_tracker/configuration_sam3_tracker.py +1 -2
  796. transformers/models/sam3_tracker/modeling_sam3_tracker.py +50 -49
  797. transformers/models/sam3_tracker/modular_sam3_tracker.py +0 -1
  798. transformers/models/sam3_tracker/processing_sam3_tracker.py +0 -1
  799. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +50 -49
  800. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +10 -22
  801. transformers/models/sam3_video/modeling_sam3_video.py +27 -14
  802. transformers/models/sam_hq/configuration_sam_hq.py +2 -0
  803. transformers/models/sam_hq/modeling_sam_hq.py +13 -9
  804. transformers/models/sam_hq/modular_sam_hq.py +6 -6
  805. transformers/models/sam_hq/processing_sam_hq.py +7 -6
  806. transformers/models/seamless_m4t/configuration_seamless_m4t.py +8 -9
  807. transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py +8 -9
  808. transformers/models/seed_oss/configuration_seed_oss.py +7 -9
  809. transformers/models/seed_oss/modeling_seed_oss.py +4 -4
  810. transformers/models/seed_oss/modular_seed_oss.py +3 -3
  811. transformers/models/segformer/image_processing_segformer_fast.py +4 -4
  812. transformers/models/segformer/modeling_segformer.py +4 -2
  813. transformers/models/segformer/modular_segformer.py +3 -3
  814. transformers/models/seggpt/modeling_seggpt.py +20 -8
  815. transformers/models/sew/configuration_sew.py +4 -1
  816. transformers/models/sew/modeling_sew.py +9 -5
  817. transformers/models/sew/modular_sew.py +2 -1
  818. transformers/models/sew_d/configuration_sew_d.py +4 -1
  819. transformers/models/sew_d/modeling_sew_d.py +4 -1
  820. transformers/models/shieldgemma2/modeling_shieldgemma2.py +4 -4
  821. transformers/models/siglip/configuration_siglip.py +4 -1
  822. transformers/models/siglip/modeling_siglip.py +27 -71
  823. transformers/models/siglip2/__init__.py +1 -0
  824. transformers/models/siglip2/configuration_siglip2.py +4 -2
  825. transformers/models/siglip2/image_processing_siglip2_fast.py +2 -2
  826. transformers/models/siglip2/modeling_siglip2.py +37 -78
  827. transformers/models/siglip2/modular_siglip2.py +74 -25
  828. transformers/models/siglip2/tokenization_siglip2.py +95 -0
  829. transformers/models/smollm3/configuration_smollm3.py +6 -6
  830. transformers/models/smollm3/modeling_smollm3.py +4 -4
  831. transformers/models/smollm3/modular_smollm3.py +9 -9
  832. transformers/models/smolvlm/configuration_smolvlm.py +1 -3
  833. transformers/models/smolvlm/image_processing_smolvlm_fast.py +29 -3
  834. transformers/models/smolvlm/modeling_smolvlm.py +75 -46
  835. transformers/models/smolvlm/modular_smolvlm.py +36 -23
  836. transformers/models/smolvlm/video_processing_smolvlm.py +9 -9
  837. transformers/models/solar_open/__init__.py +27 -0
  838. transformers/models/solar_open/configuration_solar_open.py +184 -0
  839. transformers/models/solar_open/modeling_solar_open.py +642 -0
  840. transformers/models/solar_open/modular_solar_open.py +224 -0
  841. transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +6 -4
  842. transformers/models/speech_to_text/configuration_speech_to_text.py +9 -8
  843. transformers/models/speech_to_text/modeling_speech_to_text.py +3 -3
  844. transformers/models/speecht5/configuration_speecht5.py +7 -8
  845. transformers/models/splinter/configuration_splinter.py +6 -6
  846. transformers/models/splinter/modeling_splinter.py +8 -3
  847. transformers/models/squeezebert/configuration_squeezebert.py +14 -1
  848. transformers/models/stablelm/configuration_stablelm.py +8 -6
  849. transformers/models/stablelm/modeling_stablelm.py +5 -5
  850. transformers/models/starcoder2/configuration_starcoder2.py +11 -5
  851. transformers/models/starcoder2/modeling_starcoder2.py +5 -5
  852. transformers/models/starcoder2/modular_starcoder2.py +4 -4
  853. transformers/models/superglue/configuration_superglue.py +4 -0
  854. transformers/models/superglue/image_processing_superglue_fast.py +4 -3
  855. transformers/models/superglue/modeling_superglue.py +9 -4
  856. transformers/models/superpoint/image_processing_superpoint_fast.py +3 -4
  857. transformers/models/superpoint/modeling_superpoint.py +4 -2
  858. transformers/models/swin/configuration_swin.py +2 -4
  859. transformers/models/swin/modeling_swin.py +11 -8
  860. transformers/models/swin2sr/image_processing_swin2sr_fast.py +2 -2
  861. transformers/models/swin2sr/modeling_swin2sr.py +4 -2
  862. transformers/models/swinv2/configuration_swinv2.py +2 -4
  863. transformers/models/swinv2/modeling_swinv2.py +10 -7
  864. transformers/models/switch_transformers/configuration_switch_transformers.py +11 -6
  865. transformers/models/switch_transformers/modeling_switch_transformers.py +3 -3
  866. transformers/models/switch_transformers/modular_switch_transformers.py +3 -3
  867. transformers/models/t5/configuration_t5.py +9 -8
  868. transformers/models/t5/modeling_t5.py +5 -8
  869. transformers/models/t5gemma/configuration_t5gemma.py +10 -25
  870. transformers/models/t5gemma/modeling_t5gemma.py +9 -9
  871. transformers/models/t5gemma/modular_t5gemma.py +11 -24
  872. transformers/models/t5gemma2/configuration_t5gemma2.py +35 -48
  873. transformers/models/t5gemma2/modeling_t5gemma2.py +143 -100
  874. transformers/models/t5gemma2/modular_t5gemma2.py +152 -136
  875. transformers/models/table_transformer/configuration_table_transformer.py +18 -49
  876. transformers/models/table_transformer/modeling_table_transformer.py +27 -53
  877. transformers/models/tapas/configuration_tapas.py +12 -1
  878. transformers/models/tapas/modeling_tapas.py +1 -1
  879. transformers/models/tapas/tokenization_tapas.py +1 -0
  880. transformers/models/textnet/configuration_textnet.py +4 -6
  881. transformers/models/textnet/image_processing_textnet_fast.py +3 -3
  882. transformers/models/textnet/modeling_textnet.py +15 -14
  883. transformers/models/time_series_transformer/modeling_time_series_transformer.py +3 -3
  884. transformers/models/timesfm/modeling_timesfm.py +5 -6
  885. transformers/models/timesfm/modular_timesfm.py +5 -6
  886. transformers/models/timm_backbone/configuration_timm_backbone.py +33 -7
  887. transformers/models/timm_backbone/modeling_timm_backbone.py +21 -24
  888. transformers/models/timm_wrapper/modeling_timm_wrapper.py +9 -4
  889. transformers/models/trocr/configuration_trocr.py +11 -7
  890. transformers/models/trocr/modeling_trocr.py +4 -2
  891. transformers/models/tvp/configuration_tvp.py +10 -35
  892. transformers/models/tvp/image_processing_tvp_fast.py +6 -5
  893. transformers/models/tvp/modeling_tvp.py +1 -1
  894. transformers/models/udop/configuration_udop.py +16 -7
  895. transformers/models/udop/modeling_udop.py +10 -6
  896. transformers/models/umt5/configuration_umt5.py +8 -6
  897. transformers/models/umt5/modeling_umt5.py +7 -3
  898. transformers/models/unispeech/configuration_unispeech.py +4 -1
  899. transformers/models/unispeech/modeling_unispeech.py +7 -4
  900. transformers/models/unispeech_sat/configuration_unispeech_sat.py +4 -1
  901. transformers/models/unispeech_sat/modeling_unispeech_sat.py +7 -4
  902. transformers/models/upernet/configuration_upernet.py +8 -35
  903. transformers/models/upernet/modeling_upernet.py +1 -1
  904. transformers/models/vaultgemma/configuration_vaultgemma.py +5 -7
  905. transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
  906. transformers/models/video_llama_3/configuration_video_llama_3.py +4 -0
  907. transformers/models/video_llama_3/image_processing_video_llama_3_fast.py +4 -6
  908. transformers/models/video_llama_3/modeling_video_llama_3.py +85 -48
  909. transformers/models/video_llama_3/modular_video_llama_3.py +56 -43
  910. transformers/models/video_llama_3/video_processing_video_llama_3.py +29 -8
  911. transformers/models/video_llava/configuration_video_llava.py +4 -0
  912. transformers/models/video_llava/modeling_video_llava.py +87 -89
  913. transformers/models/videomae/modeling_videomae.py +4 -5
  914. transformers/models/vilt/configuration_vilt.py +4 -1
  915. transformers/models/vilt/image_processing_vilt_fast.py +6 -6
  916. transformers/models/vilt/modeling_vilt.py +27 -12
  917. transformers/models/vipllava/configuration_vipllava.py +4 -0
  918. transformers/models/vipllava/modeling_vipllava.py +57 -31
  919. transformers/models/vipllava/modular_vipllava.py +50 -24
  920. transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +10 -6
  921. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +27 -20
  922. transformers/models/visual_bert/configuration_visual_bert.py +6 -1
  923. transformers/models/vit/configuration_vit.py +2 -2
  924. transformers/models/vit/modeling_vit.py +7 -5
  925. transformers/models/vit_mae/modeling_vit_mae.py +11 -7
  926. transformers/models/vit_msn/modeling_vit_msn.py +11 -7
  927. transformers/models/vitdet/configuration_vitdet.py +2 -4
  928. transformers/models/vitdet/modeling_vitdet.py +2 -3
  929. transformers/models/vitmatte/configuration_vitmatte.py +6 -35
  930. transformers/models/vitmatte/image_processing_vitmatte_fast.py +2 -2
  931. transformers/models/vitmatte/modeling_vitmatte.py +1 -1
  932. transformers/models/vitpose/configuration_vitpose.py +6 -43
  933. transformers/models/vitpose/modeling_vitpose.py +5 -3
  934. transformers/models/vitpose_backbone/configuration_vitpose_backbone.py +2 -4
  935. transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +5 -6
  936. transformers/models/vits/configuration_vits.py +4 -0
  937. transformers/models/vits/modeling_vits.py +9 -7
  938. transformers/models/vivit/modeling_vivit.py +4 -4
  939. transformers/models/vjepa2/modeling_vjepa2.py +9 -9
  940. transformers/models/voxtral/configuration_voxtral.py +0 -1
  941. transformers/models/voxtral/modeling_voxtral.py +25 -24
  942. transformers/models/voxtral/modular_voxtral.py +26 -20
  943. transformers/models/wav2vec2/configuration_wav2vec2.py +4 -1
  944. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -4
  945. transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py +4 -1
  946. transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py +4 -1
  947. transformers/models/wavlm/configuration_wavlm.py +4 -1
  948. transformers/models/wavlm/modeling_wavlm.py +4 -1
  949. transformers/models/whisper/configuration_whisper.py +6 -4
  950. transformers/models/whisper/generation_whisper.py +0 -1
  951. transformers/models/whisper/modeling_whisper.py +3 -3
  952. transformers/models/x_clip/configuration_x_clip.py +4 -1
  953. transformers/models/x_clip/modeling_x_clip.py +26 -27
  954. transformers/models/xglm/configuration_xglm.py +9 -7
  955. transformers/models/xlm/configuration_xlm.py +10 -7
  956. transformers/models/xlm/modeling_xlm.py +1 -1
  957. transformers/models/xlm_roberta/configuration_xlm_roberta.py +11 -2
  958. transformers/models/xlm_roberta/modeling_xlm_roberta.py +6 -6
  959. transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py +10 -1
  960. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +6 -6
  961. transformers/models/xlnet/configuration_xlnet.py +3 -1
  962. transformers/models/xlstm/configuration_xlstm.py +5 -7
  963. transformers/models/xlstm/modeling_xlstm.py +0 -32
  964. transformers/models/xmod/configuration_xmod.py +11 -2
  965. transformers/models/xmod/modeling_xmod.py +13 -16
  966. transformers/models/yolos/image_processing_yolos_fast.py +25 -28
  967. transformers/models/yolos/modeling_yolos.py +7 -7
  968. transformers/models/yolos/modular_yolos.py +16 -16
  969. transformers/models/yoso/configuration_yoso.py +8 -1
  970. transformers/models/youtu/__init__.py +27 -0
  971. transformers/models/youtu/configuration_youtu.py +194 -0
  972. transformers/models/youtu/modeling_youtu.py +619 -0
  973. transformers/models/youtu/modular_youtu.py +254 -0
  974. transformers/models/zamba/configuration_zamba.py +5 -7
  975. transformers/models/zamba/modeling_zamba.py +25 -56
  976. transformers/models/zamba2/configuration_zamba2.py +8 -13
  977. transformers/models/zamba2/modeling_zamba2.py +53 -78
  978. transformers/models/zamba2/modular_zamba2.py +36 -29
  979. transformers/models/zoedepth/configuration_zoedepth.py +17 -40
  980. transformers/models/zoedepth/image_processing_zoedepth_fast.py +9 -9
  981. transformers/models/zoedepth/modeling_zoedepth.py +5 -3
  982. transformers/pipelines/__init__.py +1 -61
  983. transformers/pipelines/any_to_any.py +1 -1
  984. transformers/pipelines/automatic_speech_recognition.py +0 -2
  985. transformers/pipelines/base.py +1 -1
  986. transformers/pipelines/image_text_to_text.py +1 -1
  987. transformers/pipelines/text_to_audio.py +5 -1
  988. transformers/processing_utils.py +35 -44
  989. transformers/pytorch_utils.py +2 -26
  990. transformers/quantizers/quantizer_compressed_tensors.py +7 -5
  991. transformers/quantizers/quantizer_fbgemm_fp8.py +20 -23
  992. transformers/quantizers/quantizer_finegrained_fp8.py +14 -20
  993. transformers/quantizers/quantizer_mxfp4.py +1 -1
  994. transformers/quantizers/quantizer_torchao.py +0 -16
  995. transformers/safetensors_conversion.py +11 -4
  996. transformers/testing_utils.py +3 -28
  997. transformers/tokenization_mistral_common.py +9 -0
  998. transformers/tokenization_python.py +6 -4
  999. transformers/tokenization_utils_base.py +119 -219
  1000. transformers/tokenization_utils_tokenizers.py +31 -2
  1001. transformers/trainer.py +25 -33
  1002. transformers/trainer_seq2seq.py +1 -1
  1003. transformers/training_args.py +411 -417
  1004. transformers/utils/__init__.py +1 -4
  1005. transformers/utils/auto_docstring.py +15 -18
  1006. transformers/utils/backbone_utils.py +13 -373
  1007. transformers/utils/doc.py +4 -36
  1008. transformers/utils/generic.py +69 -33
  1009. transformers/utils/import_utils.py +72 -75
  1010. transformers/utils/loading_report.py +133 -105
  1011. transformers/utils/quantization_config.py +0 -21
  1012. transformers/video_processing_utils.py +5 -5
  1013. transformers/video_utils.py +3 -1
  1014. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/METADATA +118 -237
  1015. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/RECORD +1019 -994
  1016. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/WHEEL +1 -1
  1017. transformers/pipelines/deprecated/text2text_generation.py +0 -408
  1018. transformers/pipelines/image_to_text.py +0 -189
  1019. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/entry_points.txt +0 -0
  1020. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/licenses/LICENSE +0 -0
  1021. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,2083 @@
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/pp_doclayout_v3/modular_pp_doclayout_v3.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_pp_doclayout_v3.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ # Copyright 2026 The PaddlePaddle Team and The HuggingFace Inc. team. All rights reserved.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+
21
+ import math
22
+ import warnings
23
+ from collections.abc import Callable
24
+ from dataclasses import dataclass
25
+
26
+ import numpy as np
27
+ import torch
28
+ import torch.nn.functional as F
29
+ from torch import Tensor, nn
30
+
31
+ from ... import initialization as init
32
+ from ...activations import ACT2CLS, ACT2FN
33
+ from ...backbone_utils import load_backbone
34
+ from ...image_transforms import center_to_corners_format, corners_to_center_format
35
+ from ...integrations import use_kernel_forward_from_hub
36
+ from ...modeling_outputs import BaseModelOutput
37
+ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
38
+ from ...processing_utils import Unpack
39
+ from ...pytorch_utils import compile_compatible_method_lru_cache
40
+ from ...utils import (
41
+ ModelOutput,
42
+ TransformersKwargs,
43
+ auto_docstring,
44
+ torch_compilable_check,
45
+ torch_int,
46
+ )
47
+ from ...utils.generic import can_return_tuple, check_model_inputs
48
+ from .configuration_pp_doclayout_v3 import PPDocLayoutV3Config
49
+
50
+
51
+ class PPDocLayoutV3GlobalPointer(nn.Module):
52
+ def __init__(self, config):
53
+ super().__init__()
54
+ self.head_size = config.global_pointer_head_size
55
+ self.dense = nn.Linear(config.d_model, self.head_size * 2)
56
+ self.dropout = nn.Dropout(config.gp_dropout_value)
57
+
58
+ def forward(self, inputs):
59
+ batch_size, sequence_length, _ = inputs.shape
60
+ query_key_projection = self.dense(inputs).reshape(batch_size, sequence_length, 2, self.head_size)
61
+ query_key_projection = self.dropout(query_key_projection)
62
+ queries, keys = torch.unbind(query_key_projection, dim=2)
63
+
64
+ logits = (queries @ keys.transpose(-2, -1)) / (self.head_size**0.5)
65
+ mask = torch.tril(torch.ones(sequence_length, sequence_length, device=logits.device)).bool()
66
+ logits = logits.masked_fill(mask.unsqueeze(0), -1e4)
67
+
68
+ return logits
69
+
70
+
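A minimal usage sketch of the global pointer above, assuming a config stand-in with the three attributes it reads (d_model, global_pointer_head_size, gp_dropout_value); the real model passes a PPDocLayoutV3Config:

    from types import SimpleNamespace
    import torch

    cfg = SimpleNamespace(d_model=256, global_pointer_head_size=64, gp_dropout_value=0.0)
    pointer = PPDocLayoutV3GlobalPointer(cfg)
    hidden = torch.randn(2, 10, cfg.d_model)   # (batch, num_queries, d_model)
    order_logits = pointer(hidden)             # (batch, num_queries, num_queries)
    # the lower triangle (including the diagonal) is filled with -1e4, so only
    # pairs (i, j) with i < j keep meaningful reading-order scores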
71
+ @use_kernel_forward_from_hub("MultiScaleDeformableAttention")
72
+ class MultiScaleDeformableAttention(nn.Module):
73
+ def forward(
74
+ self,
75
+ value: Tensor,
76
+ value_spatial_shapes: Tensor,
77
+ value_spatial_shapes_list: list[tuple],
78
+ level_start_index: Tensor,
79
+ sampling_locations: Tensor,
80
+ attention_weights: Tensor,
81
+ im2col_step: int,
82
+ ):
83
+ batch_size, _, num_heads, hidden_dim = value.shape
84
+ _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
85
+ value_list = value.split([height * width for height, width in value_spatial_shapes_list], dim=1)
86
+ sampling_grids = 2 * sampling_locations - 1
87
+ sampling_value_list = []
88
+ for level_id, (height, width) in enumerate(value_spatial_shapes_list):
89
+ # batch_size, height*width, num_heads, hidden_dim
90
+ # -> batch_size, height*width, num_heads*hidden_dim
91
+ # -> batch_size, num_heads*hidden_dim, height*width
92
+ # -> batch_size*num_heads, hidden_dim, height, width
93
+ value_l_ = (
94
+ value_list[level_id]
95
+ .flatten(2)
96
+ .transpose(1, 2)
97
+ .reshape(batch_size * num_heads, hidden_dim, height, width)
98
+ )
99
+ # batch_size, num_queries, num_heads, num_points, 2
100
+ # -> batch_size, num_heads, num_queries, num_points, 2
101
+ # -> batch_size*num_heads, num_queries, num_points, 2
102
+ sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1)
103
+ # batch_size*num_heads, hidden_dim, num_queries, num_points
104
+ sampling_value_l_ = nn.functional.grid_sample(
105
+ value_l_,
106
+ sampling_grid_l_,
107
+ mode="bilinear",
108
+ padding_mode="zeros",
109
+ align_corners=False,
110
+ )
111
+ sampling_value_list.append(sampling_value_l_)
112
+ # (batch_size, num_queries, num_heads, num_levels, num_points)
113
+ # -> (batch_size, num_heads, num_queries, num_levels, num_points)
114
+ # -> (batch_size, num_heads, 1, num_queries, num_levels*num_points)
115
+ attention_weights = attention_weights.transpose(1, 2).reshape(
116
+ batch_size * num_heads, 1, num_queries, num_levels * num_points
117
+ )
118
+ output = (
119
+ (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
120
+ .sum(-1)
121
+ .view(batch_size, num_heads * hidden_dim, num_queries)
122
+ )
123
+ return output.transpose(1, 2).contiguous()
124
+
125
+
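A small sketch of the coordinate convention used in the loop above: sampling locations are expressed in [0, 1] per feature level, while torch.nn.functional.grid_sample expects grids in [-1, 1], hence the `2 * sampling_locations - 1` rescaling (values below are illustrative):

    import torch
    import torch.nn.functional as F

    value = torch.arange(16.0).view(1, 1, 4, 4)   # (batch*num_heads, hidden_dim, height, width)
    loc = torch.tensor([[[[0.5, 0.5]]]])          # one sampling point at the map centre, in [0, 1]
    grid = 2 * loc - 1                            # rescale to grid_sample's [-1, 1] convention
    sample = F.grid_sample(value, grid, mode="bilinear", padding_mode="zeros", align_corners=False)
    # sample[0, 0, 0, 0] is the bilinear interpolation at the centre of the 4x4 map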
126
+ class PPDocLayoutV3MultiscaleDeformableAttention(nn.Module):
127
+ """
128
+ Multiscale deformable attention as proposed in Deformable DETR.
129
+ """
130
+
131
+ def __init__(self, config: PPDocLayoutV3Config, num_heads: int, n_points: int):
132
+ super().__init__()
133
+
134
+ self.attn = MultiScaleDeformableAttention()
135
+
136
+ if config.d_model % num_heads != 0:
137
+ raise ValueError(
138
+ f"embed_dim (d_model) must be divisible by num_heads, but got {config.d_model} and {num_heads}"
139
+ )
140
+ dim_per_head = config.d_model // num_heads
141
+         # check if dim_per_head is a power of 2
142
+ if not ((dim_per_head & (dim_per_head - 1) == 0) and dim_per_head != 0):
143
+ warnings.warn(
144
+                 "It is recommended to set embed_dim (d_model) in PPDocLayoutV3MultiscaleDeformableAttention so that"
145
+                 " the dimension of each attention head is a power of 2, which is more efficient in the authors'"
146
+                 " CUDA implementation."
147
+ )
148
+
149
+ self.im2col_step = 64
150
+
151
+ self.d_model = config.d_model
152
+ self.n_levels = config.num_feature_levels
153
+ self.n_heads = num_heads
154
+ self.n_points = n_points
155
+
156
+ self.sampling_offsets = nn.Linear(config.d_model, num_heads * self.n_levels * n_points * 2)
157
+ self.attention_weights = nn.Linear(config.d_model, num_heads * self.n_levels * n_points)
158
+ self.value_proj = nn.Linear(config.d_model, config.d_model)
159
+ self.output_proj = nn.Linear(config.d_model, config.d_model)
160
+
161
+ self.disable_custom_kernels = config.disable_custom_kernels
162
+
163
+ def forward(
164
+ self,
165
+ hidden_states: torch.Tensor,
166
+ attention_mask: torch.Tensor | None = None,
167
+ encoder_hidden_states=None,
168
+ encoder_attention_mask=None,
169
+ position_embeddings: torch.Tensor | None = None,
170
+ reference_points=None,
171
+ spatial_shapes=None,
172
+ spatial_shapes_list=None,
173
+ level_start_index=None,
174
+ **kwargs: Unpack[TransformersKwargs],
175
+ ) -> tuple[torch.Tensor, torch.Tensor]:
176
+ # add position embeddings to the hidden states before projecting to queries and keys
177
+ if position_embeddings is not None:
178
+ hidden_states = hidden_states + position_embeddings
179
+
180
+ batch_size, num_queries, _ = hidden_states.shape
181
+ batch_size, sequence_length, _ = encoder_hidden_states.shape
182
+ total_elements = sum(height * width for height, width in spatial_shapes_list)
183
+ torch_compilable_check(
184
+ total_elements == sequence_length,
185
+ "Make sure to align the spatial shapes with the sequence length of the encoder hidden states",
186
+ )
187
+
188
+ value = self.value_proj(encoder_hidden_states)
189
+ if attention_mask is not None:
190
+ # we invert the attention_mask
191
+ value = value.masked_fill(~attention_mask[..., None], float(0))
192
+ value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads)
193
+ sampling_offsets = self.sampling_offsets(hidden_states).view(
194
+ batch_size, num_queries, self.n_heads, self.n_levels, self.n_points, 2
195
+ )
196
+ attention_weights = self.attention_weights(hidden_states).view(
197
+ batch_size, num_queries, self.n_heads, self.n_levels * self.n_points
198
+ )
199
+ attention_weights = F.softmax(attention_weights, -1).view(
200
+ batch_size, num_queries, self.n_heads, self.n_levels, self.n_points
201
+ )
202
+ # batch_size, num_queries, n_heads, n_levels, n_points, 2
203
+ num_coordinates = reference_points.shape[-1]
204
+ if num_coordinates == 2:
205
+ offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
206
+ sampling_locations = (
207
+ reference_points[:, :, None, :, None, :]
208
+ + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
209
+ )
210
+ elif num_coordinates == 4:
211
+ sampling_locations = (
212
+ reference_points[:, :, None, :, None, :2]
213
+ + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
214
+ )
215
+ else:
216
+ raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")
217
+
218
+ output = self.attn(
219
+ value,
220
+ spatial_shapes,
221
+ spatial_shapes_list,
222
+ level_start_index,
223
+ sampling_locations,
224
+ attention_weights,
225
+ self.im2col_step,
226
+ )
227
+
228
+ output = self.output_proj(output)
229
+
230
+ return output, attention_weights
231
+
232
+
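A worked example of the 4-coordinate branch above (numbers are illustrative): when reference points are boxes (cx, cy, w, h), the predicted offsets are scaled by half the box size per sampling point before being added to the box centre:

    import torch

    reference_point = torch.tensor([0.5, 0.5, 0.2, 0.4])        # (cx, cy, w, h), normalized
    sampling_offsets = torch.tensor([[1.0, 0.0], [-1.0, 0.0]])  # two predicted offsets
    n_points = 2
    locations = reference_point[:2] + sampling_offsets / n_points * reference_point[2:] * 0.5
    # -> [[0.55, 0.50], [0.45, 0.50]]: points shifted by ±0.05 in x around the box centre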
233
+ @auto_docstring
234
+ class PPDocLayoutV3PreTrainedModel(PreTrainedModel):
235
+ config: PPDocLayoutV3Config
236
+ base_model_prefix = "pp_doclayout_v3"
237
+ main_input_name = "pixel_values"
238
+ input_modalities = ("image",)
239
+ _no_split_modules = [r"PPDocLayoutV3HybridEncoder", r"PPDocLayoutV3DecoderLayer"]
240
+ _supports_sdpa = True
241
+ _supports_flash_attn = True
242
+ _supports_attention_backend = True
243
+ _supports_flex_attn = True
244
+
245
+ @torch.no_grad()
246
+ def _init_weights(self, module):
247
+ """Initialize the weights"""
248
+ if isinstance(module, PPDocLayoutV3MultiscaleDeformableAttention):
249
+ init.constant_(module.sampling_offsets.weight, 0.0)
250
+ default_dtype = torch.get_default_dtype()
251
+ thetas = torch.arange(module.n_heads, dtype=torch.int64).to(default_dtype) * (
252
+ 2.0 * math.pi / module.n_heads
253
+ )
254
+ grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
255
+ grid_init = (
256
+ (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
257
+ .view(module.n_heads, 1, 1, 2)
258
+ .repeat(1, module.n_levels, module.n_points, 1)
259
+ )
260
+ for i in range(module.n_points):
261
+ grid_init[:, :, i, :] *= i + 1
262
+
263
+ init.copy_(module.sampling_offsets.bias, grid_init.view(-1))
264
+ init.constant_(module.attention_weights.weight, 0.0)
265
+ init.constant_(module.attention_weights.bias, 0.0)
266
+ init.xavier_uniform_(module.value_proj.weight)
267
+ init.constant_(module.value_proj.bias, 0.0)
268
+ init.xavier_uniform_(module.output_proj.weight)
269
+ init.constant_(module.output_proj.bias, 0.0)
270
+
271
+ elif isinstance(module, PPDocLayoutV3Model):
272
+ prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1)
273
+ bias = float(-math.log((1 - prior_prob) / prior_prob))
274
+ init.xavier_uniform_(module.enc_score_head.weight)
275
+ init.constant_(module.enc_score_head.bias, bias)
276
+ init.xavier_uniform_(module.decoder.class_embed.weight)
277
+ init.constant_(module.decoder.class_embed.bias, bias)
278
+
279
+ elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
280
+ init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
281
+ if module.bias is not None:
282
+ init.zeros_(module.bias)
283
+ if getattr(module, "running_mean", None) is not None:
284
+ init.zeros_(module.running_mean)
285
+ init.ones_(module.running_var)
286
+ init.zeros_(module.num_batches_tracked)
287
+
288
+ elif isinstance(module, nn.LayerNorm):
289
+ init.ones_(module.weight)
290
+ init.zeros_(module.bias)
291
+
292
+ if isinstance(module, nn.Embedding):
293
+ init.normal_(module.weight, mean=0.0, std=self.config.initializer_range)
294
+ if module.padding_idx is not None:
295
+ init.zeros_(module.weight.data[module.padding_idx])
296
+
297
+
298
+ @dataclass
299
+ class PPDocLayoutV3DecoderOutput(ModelOutput):
300
+ r"""
301
+ intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
302
+ Stacked intermediate hidden states (output of each layer of the decoder).
303
+ intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, config.num_labels)`):
304
+ Stacked intermediate logits (logits of each layer of the decoder).
305
+ intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, hidden_size)`):
306
+ Stacked intermediate reference points (reference points of each layer of the decoder).
307
+ intermediate_predicted_corners (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
308
+ Stacked intermediate predicted corners (predicted corners of each layer of the decoder).
309
+ initial_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
310
+ Stacked initial reference points (initial reference points of each layer of the decoder).
311
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
312
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
313
+ sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
314
+ used to compute the weighted average in the cross-attention heads.
315
+ decoder_out_order_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, config.num_queries, config.num_queries)`):
316
+ Stacked order logits (order logits of each layer of the decoder).
317
+ decoder_out_masks (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, config.num_queries, 200, 200)`):
318
+ Stacked masks (masks of each layer of the decoder).
319
+ """
320
+
321
+ last_hidden_state: torch.FloatTensor | None = None
322
+ intermediate_hidden_states: torch.FloatTensor | None = None
323
+ intermediate_logits: torch.FloatTensor | None = None
324
+ intermediate_reference_points: torch.FloatTensor | None = None
325
+ intermediate_predicted_corners: torch.FloatTensor | None = None
326
+ initial_reference_points: torch.FloatTensor | None = None
327
+ hidden_states: tuple[torch.FloatTensor] | None = None
328
+ attentions: tuple[torch.FloatTensor] | None = None
329
+ cross_attentions: tuple[torch.FloatTensor] | None = None
330
+
331
+ decoder_out_order_logits: torch.FloatTensor | None = None
332
+ decoder_out_masks: torch.FloatTensor | None = None
333
+
334
+
335
+ @dataclass
336
+ @auto_docstring(
337
+ custom_intro="""
338
+ Base class for outputs of the PP-DocLayoutV3 model.
339
+ """
340
+ )
341
+ class PPDocLayoutV3ModelOutput(ModelOutput):
342
+ r"""
343
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
344
+ Sequence of hidden-states at the output of the last layer of the decoder of the model.
345
+ intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
346
+ Stacked intermediate hidden states (output of each layer of the decoder).
347
+ intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, sequence_length, config.num_labels)`):
348
+ Stacked intermediate logits (logits of each layer of the decoder).
349
+ intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
350
+ Stacked intermediate reference points (reference points of each layer of the decoder).
351
+ intermediate_predicted_corners (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
352
+ Stacked intermediate predicted corners (predicted corners of each layer of the decoder).
353
+ initial_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
354
+ Initial reference points used for the first decoder layer.
355
+ init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
356
+ Initial reference points sent through the Transformer decoder.
357
+ enc_topk_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`):
358
+ Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are
359
+ picked as region proposals in the encoder stage. Output of bounding box binary classification (i.e.
360
+ foreground and background).
361
+ enc_topk_bboxes (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`):
362
+ Logits of predicted bounding boxes coordinates in the encoder stage.
363
+ enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
364
+ Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are
365
+ picked as region proposals in the first stage. Output of bounding box binary classification (i.e.
366
+ foreground and background).
367
+ enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
368
+ Logits of predicted bounding boxes coordinates in the first stage.
369
+ denoising_meta_values (`dict`):
370
+ Extra dictionary for the denoising related values.
371
+ out_order_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, config.num_queries, config.num_queries)`):
372
+ Stacked order logits (order logits of each layer of the decoder).
373
+ out_masks (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, config.num_queries, 200, 200)`):
374
+ Stacked masks (masks of each layer of the decoder).
375
+ """
376
+
377
+ last_hidden_state: torch.FloatTensor | None = None
378
+ intermediate_hidden_states: torch.FloatTensor | None = None
379
+ intermediate_logits: torch.FloatTensor | None = None
380
+ intermediate_reference_points: torch.FloatTensor | None = None
381
+ intermediate_predicted_corners: torch.FloatTensor | None = None
382
+ initial_reference_points: torch.FloatTensor | None = None
383
+ decoder_hidden_states: tuple[torch.FloatTensor] | None = None
384
+ decoder_attentions: tuple[torch.FloatTensor] | None = None
385
+ cross_attentions: tuple[torch.FloatTensor] | None = None
386
+ encoder_last_hidden_state: torch.FloatTensor | None = None
387
+ encoder_hidden_states: tuple[torch.FloatTensor] | None = None
388
+ encoder_attentions: tuple[torch.FloatTensor] | None = None
389
+ init_reference_points: torch.FloatTensor | None = None
390
+ enc_topk_logits: torch.FloatTensor | None = None
391
+ enc_topk_bboxes: torch.FloatTensor | None = None
392
+ enc_outputs_class: torch.FloatTensor | None = None
393
+ enc_outputs_coord_logits: torch.FloatTensor | None = None
394
+ denoising_meta_values: dict | None = None
395
+
396
+ out_order_logits: torch.FloatTensor | None = None
397
+ out_masks: torch.FloatTensor | None = None
398
+
399
+
400
+ class PPDocLayoutV3MLPPredictionHead(nn.Module):
401
+ """
402
+ Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
403
+ height and width of a bounding box w.r.t. an image.
404
+
405
+ """
406
+
407
+ def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
408
+ super().__init__()
409
+ self.num_layers = num_layers
410
+ h = [hidden_dim] * (num_layers - 1)
411
+ self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
412
+
413
+ def forward(self, x):
414
+ for i, layer in enumerate(self.layers):
415
+ x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
416
+ return x
417
+
418
+
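An illustrative instantiation of the prediction head (the dimensions below are assumptions, not values taken from the config): three linear layers with ReLU in between, mapping decoder features to four box parameters as in DETR-style heads:

    import torch

    head = PPDocLayoutV3MLPPredictionHead(input_dim=256, hidden_dim=256, output_dim=4, num_layers=3)
    queries = torch.randn(2, 300, 256)   # (batch, num_queries, d_model)
    boxes = head(queries).sigmoid()      # (2, 300, 4), interpreted as normalized (cx, cy, w, h)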
419
+ class PPDocLayoutV3ConvLayer(nn.Module):
420
+ def __init__(
421
+ self, in_channels: int, out_channels: int, kernel_size: int = 3, stride: int = 1, activation: str = "relu"
422
+ ):
423
+ super().__init__()
424
+ self.convolution = nn.Conv2d(
425
+ in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=kernel_size // 2, bias=False
426
+ )
427
+ self.normalization = nn.BatchNorm2d(out_channels)
428
+ self.activation = ACT2FN[activation] if activation is not None else nn.Identity()
429
+
430
+ def forward(self, input: Tensor) -> Tensor:
431
+ hidden_state = self.convolution(input)
432
+ hidden_state = self.normalization(hidden_state)
433
+ hidden_state = self.activation(hidden_state)
434
+ return hidden_state
435
+
436
+
437
+ class PPDocLayoutV3ScaleHead(nn.Module):
438
+ def __init__(self, in_channels, feature_channels, fpn_stride, base_stride, align_corners=False):
439
+ super().__init__()
440
+ head_length = max(1, int(np.log2(fpn_stride) - np.log2(base_stride)))
441
+ self.layers = nn.ModuleList()
442
+ for k in range(head_length):
443
+ in_c = in_channels if k == 0 else feature_channels
444
+ self.layers.append(PPDocLayoutV3ConvLayer(in_c, feature_channels, 3, 1, "silu"))
445
+ if fpn_stride != base_stride:
446
+ self.layers.append(nn.Upsample(scale_factor=2, mode="bilinear", align_corners=align_corners))
447
+
448
+ def forward(self, x):
449
+ for layer in self.layers:
450
+ x = layer(x)
451
+ return x
452
+
453
+
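A worked example of the head length computed above, assuming typical FPN strides: bringing a stride-32 level down to a base stride of 8 takes two stages, each a 3x3 conv followed by 2x bilinear upsampling, i.e. 4x total upsampling:

    import numpy as np

    head_length = max(1, int(np.log2(32) - np.log2(8)))   # -> 2 stages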
454
+ class PPDocLayoutV3MaskFeatFPN(nn.Module):
455
+ def __init__(
456
+ self,
457
+ in_channels=[256, 256, 256],
458
+ fpn_strides=[32, 16, 8],
459
+ feature_channels=256,
460
+ dropout_ratio=0.0,
461
+ out_channels=256,
462
+ align_corners=False,
463
+ ):
464
+ super().__init__()
465
+
466
+ reorder_index = np.argsort(fpn_strides, axis=0).tolist()
467
+ in_channels = [in_channels[i] for i in reorder_index]
468
+ fpn_strides = [fpn_strides[i] for i in reorder_index]
469
+
470
+ self.reorder_index = reorder_index
471
+ self.fpn_strides = fpn_strides
472
+ self.dropout_ratio = dropout_ratio
473
+ self.align_corners = align_corners
474
+ if self.dropout_ratio > 0:
475
+ self.dropout = nn.Dropout2d(dropout_ratio)
476
+
477
+ self.scale_heads = nn.ModuleList()
478
+ for i in range(len(fpn_strides)):
479
+ self.scale_heads.append(
480
+ PPDocLayoutV3ScaleHead(
481
+ in_channels=in_channels[i],
482
+ feature_channels=feature_channels,
483
+ fpn_stride=fpn_strides[i],
484
+ base_stride=fpn_strides[0],
485
+ align_corners=align_corners,
486
+ )
487
+ )
488
+ self.output_conv = PPDocLayoutV3ConvLayer(feature_channels, out_channels, 3, 1, "silu")
489
+
490
+ def forward(self, inputs):
491
+ x = [inputs[i] for i in self.reorder_index]
492
+
493
+ output = self.scale_heads[0](x[0])
494
+ for i in range(1, len(self.fpn_strides)):
495
+ output = output + F.interpolate(
496
+ self.scale_heads[i](x[i]), size=output.shape[2:], mode="bilinear", align_corners=self.align_corners
497
+ )
498
+
499
+ if self.dropout_ratio > 0:
500
+ output = self.dropout(output)
501
+ output = self.output_conv(output)
502
+ return output
503
+
504
+
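A quick check of the reordering above: with the default strides [32, 16, 8], np.argsort yields [2, 1, 0], so channels and strides are re-ordered finest-first and the finest level sets the output resolution of the mask feature map:

    import numpy as np

    print(np.argsort([32, 16, 8], axis=0).tolist())   # [2, 1, 0] -> strides become [8, 16, 32]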
505
+ class PPDocLayoutV3EncoderMaskOutput(nn.Module):
506
+ def __init__(self, in_channels, num_prototypes):
507
+ super().__init__()
508
+ self.base_conv = PPDocLayoutV3ConvLayer(in_channels, in_channels, 3, 1, "silu")
509
+ self.conv = nn.Conv2d(in_channels, num_prototypes, kernel_size=1)
510
+
511
+ def forward(self, x):
512
+ x = self.base_conv(x)
513
+ x = self.conv(x)
514
+ return x
515
+
516
+
517
+ class PPDocLayoutV3MLP(nn.Module):
518
+ def __init__(
519
+ self, config: PPDocLayoutV3Config, hidden_size: int, intermediate_size: int, activation_function: str
520
+ ):
521
+ super().__init__()
522
+ self.fc1 = nn.Linear(hidden_size, intermediate_size)
523
+ self.fc2 = nn.Linear(intermediate_size, hidden_size)
524
+ self.activation_fn = ACT2FN[activation_function]
525
+ self.activation_dropout = config.activation_dropout
526
+ self.dropout = config.dropout
527
+
528
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
529
+ hidden_states = self.activation_fn(self.fc1(hidden_states))
530
+ hidden_states = nn.functional.dropout(hidden_states, p=self.activation_dropout, training=self.training)
531
+ hidden_states = self.fc2(hidden_states)
532
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
533
+ return hidden_states
534
+
535
+
536
+ def eager_attention_forward(
537
+ module: nn.Module,
538
+ query: torch.Tensor,
539
+ key: torch.Tensor,
540
+ value: torch.Tensor,
541
+ attention_mask: torch.Tensor | None,
542
+ scaling: float | None = None,
543
+ dropout: float = 0.0,
544
+ **kwargs: Unpack[TransformersKwargs],
545
+ ):
546
+ if scaling is None:
547
+ scaling = query.size(-1) ** -0.5
548
+
549
+ # Take the dot product between "query" and "key" to get the raw attention scores.
550
+ attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
551
+
552
+ if attention_mask is not None:
553
+ attention_mask = attention_mask[:, :, :, : key.shape[-2]]
554
+ attn_weights = attn_weights + attention_mask
555
+
556
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1)
557
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
558
+
559
+ attn_output = torch.matmul(attn_weights, value)
560
+ attn_output = attn_output.transpose(1, 2).contiguous()
561
+
562
+ return attn_output, attn_weights
563
+
564
+
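A minimal shape check for the eager attention path above (tensor sizes are arbitrary): queries, keys and values enter as (batch, num_heads, seq_len, head_dim) and the output is transposed back to (batch, seq_len, num_heads, head_dim) before the caller reshapes it:

    import torch
    from torch import nn

    q = k = v = torch.randn(1, 2, 5, 32)
    out, weights = eager_attention_forward(nn.Module(), q, k, v, attention_mask=None)
    print(out.shape, weights.shape)   # torch.Size([1, 5, 2, 32]) torch.Size([1, 2, 5, 5])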
565
+ class PPDocLayoutV3SelfAttention(nn.Module):
566
+ """
567
+     Multi-headed self-attention from the 'Attention Is All You Need' paper.
568
+
569
+ In PP_DOCLAYOUT_V3, position embeddings are added to both queries and keys (but not values) in self-attention.
570
+ """
571
+
572
+ def __init__(
573
+ self,
574
+ config: PPDocLayoutV3Config,
575
+ hidden_size: int,
576
+ num_attention_heads: int,
577
+ dropout: float = 0.0,
578
+ bias: bool = True,
579
+ ):
580
+ super().__init__()
581
+ self.config = config
582
+ self.head_dim = hidden_size // num_attention_heads
583
+ self.scaling = self.head_dim**-0.5
584
+ self.attention_dropout = dropout
585
+ self.is_causal = False
586
+
587
+ self.k_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
588
+ self.v_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
589
+ self.q_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
590
+ self.o_proj = nn.Linear(hidden_size, hidden_size, bias=bias)
591
+
592
+ def forward(
593
+ self,
594
+ hidden_states: torch.Tensor,
595
+ attention_mask: torch.Tensor | None = None,
596
+ position_embeddings: torch.Tensor | None = None,
597
+ **kwargs: Unpack[TransformersKwargs],
598
+ ) -> tuple[torch.Tensor, torch.Tensor]:
599
+ """
600
+ Position embeddings are added to both queries and keys (but not values).
601
+ """
602
+ input_shape = hidden_states.shape[:-1]
603
+ hidden_shape = (*input_shape, -1, self.head_dim)
604
+
605
+ query_key_input = hidden_states + position_embeddings if position_embeddings is not None else hidden_states
606
+
607
+ query_states = self.q_proj(query_key_input).view(hidden_shape).transpose(1, 2)
608
+ key_states = self.k_proj(query_key_input).view(hidden_shape).transpose(1, 2)
609
+ value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
610
+
611
+ attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
612
+ self.config._attn_implementation, eager_attention_forward
613
+ )
614
+
615
+ attn_output, attn_weights = attention_interface(
616
+ self,
617
+ query_states,
618
+ key_states,
619
+ value_states,
620
+ attention_mask,
621
+ dropout=0.0 if not self.training else self.attention_dropout,
622
+ scaling=self.scaling,
623
+ **kwargs,
624
+ )
625
+
626
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
627
+ attn_output = self.o_proj(attn_output)
628
+ return attn_output, attn_weights
629
+
630
+
631
+ class PPDocLayoutV3ConvNormLayer(nn.Module):
632
+ def __init__(self, config, in_channels, out_channels, kernel_size, stride, padding=None, activation=None):
633
+ super().__init__()
634
+ self.conv = nn.Conv2d(
635
+ in_channels,
636
+ out_channels,
637
+ kernel_size,
638
+ stride,
639
+ padding=(kernel_size - 1) // 2 if padding is None else padding,
640
+ bias=False,
641
+ )
642
+ self.norm = nn.BatchNorm2d(out_channels, config.batch_norm_eps)
643
+ self.activation = nn.Identity() if activation is None else ACT2CLS[activation]()
644
+
645
+ def forward(self, hidden_state):
646
+ hidden_state = self.conv(hidden_state)
647
+ hidden_state = self.norm(hidden_state)
648
+ hidden_state = self.activation(hidden_state)
649
+ return hidden_state
650
+
651
+
652
+ class PPDocLayoutV3EncoderLayer(nn.Module):
653
+ def __init__(self, config: PPDocLayoutV3Config):
654
+ super().__init__()
655
+ self.normalize_before = config.normalize_before
656
+ self.hidden_size = config.encoder_hidden_dim
657
+
658
+ # self-attention
659
+ self.self_attn = PPDocLayoutV3SelfAttention(
660
+ config=config,
661
+ hidden_size=self.hidden_size,
662
+ num_attention_heads=config.num_attention_heads,
663
+ dropout=config.dropout,
664
+ )
665
+ self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
666
+ self.dropout = config.dropout
667
+ self.mlp = PPDocLayoutV3MLP(
668
+ config, self.hidden_size, config.encoder_ffn_dim, config.encoder_activation_function
669
+ )
670
+ self.final_layer_norm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
671
+
672
+ def forward(
673
+ self,
674
+ hidden_states: torch.Tensor,
675
+ attention_mask: torch.Tensor,
676
+ spatial_position_embeddings: torch.Tensor | None = None,
677
+ **kwargs: Unpack[TransformersKwargs],
678
+ ) -> torch.Tensor:
679
+ """
680
+ Args:
681
+ hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, hidden_size)`
682
+ attention_mask (`torch.FloatTensor`): attention mask of size
683
+ `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
684
+ values.
685
+ spatial_position_embeddings (`torch.FloatTensor`, *optional*):
686
+ Spatial position embeddings (2D positional encodings of image locations), to be added to both
687
+ the queries and keys in self-attention (but not to values).
688
+ """
689
+ residual = hidden_states
690
+ if self.normalize_before:
691
+ hidden_states = self.self_attn_layer_norm(hidden_states)
692
+
693
+ hidden_states, _ = self.self_attn(
694
+ hidden_states=hidden_states,
695
+ attention_mask=attention_mask,
696
+ position_embeddings=spatial_position_embeddings,
697
+ **kwargs,
698
+ )
699
+
700
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
701
+ hidden_states = residual + hidden_states
702
+ if not self.normalize_before:
703
+ hidden_states = self.self_attn_layer_norm(hidden_states)
704
+
705
+ if self.normalize_before:
706
+ hidden_states = self.final_layer_norm(hidden_states)
707
+ residual = hidden_states
708
+
709
+ hidden_states = self.mlp(hidden_states)
710
+
711
+ hidden_states = residual + hidden_states
712
+ if not self.normalize_before:
713
+ hidden_states = self.final_layer_norm(hidden_states)
714
+
715
+ if self.training:
716
+ if torch.isinf(hidden_states).any() or torch.isnan(hidden_states).any():
717
+ clamp_value = torch.finfo(hidden_states.dtype).max - 1000
718
+ hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
719
+
720
+ return hidden_states
721
+
722
+
723
+ class PPDocLayoutV3RepVggBlock(nn.Module):
724
+ """
725
+     RepVGG architecture block introduced in the paper "RepVGG: Making VGG-style ConvNets Great Again".
726
+ """
727
+
728
+ def __init__(self, config: PPDocLayoutV3Config):
729
+ super().__init__()
730
+
731
+ activation = config.activation_function
732
+ hidden_channels = int(config.encoder_hidden_dim * config.hidden_expansion)
733
+ self.conv1 = PPDocLayoutV3ConvNormLayer(config, hidden_channels, hidden_channels, 3, 1, padding=1)
734
+ self.conv2 = PPDocLayoutV3ConvNormLayer(config, hidden_channels, hidden_channels, 1, 1, padding=0)
735
+ self.activation = nn.Identity() if activation is None else ACT2CLS[activation]()
736
+
737
+ def forward(self, x):
738
+ y = self.conv1(x) + self.conv2(x)
739
+ return self.activation(y)
740
+
741
+
742
+ class PPDocLayoutV3CSPRepLayer(nn.Module):
743
+ """
744
+ Cross Stage Partial (CSP) network layer with RepVGG blocks.
745
+ """
746
+
747
+ def __init__(self, config: PPDocLayoutV3Config):
748
+ super().__init__()
749
+
750
+ in_channels = config.encoder_hidden_dim * 2
751
+ out_channels = config.encoder_hidden_dim
752
+ num_blocks = 3
753
+ activation = config.activation_function
754
+
755
+ hidden_channels = int(out_channels * config.hidden_expansion)
756
+ self.conv1 = PPDocLayoutV3ConvNormLayer(config, in_channels, hidden_channels, 1, 1, activation=activation)
757
+ self.conv2 = PPDocLayoutV3ConvNormLayer(config, in_channels, hidden_channels, 1, 1, activation=activation)
758
+ self.bottlenecks = nn.Sequential(*[PPDocLayoutV3RepVggBlock(config) for _ in range(num_blocks)])
759
+ if hidden_channels != out_channels:
760
+ self.conv3 = PPDocLayoutV3ConvNormLayer(config, hidden_channels, out_channels, 1, 1, activation=activation)
761
+ else:
762
+ self.conv3 = nn.Identity()
763
+
764
+ def forward(self, hidden_state):
765
+ hidden_state_1 = self.conv1(hidden_state)
766
+ hidden_state_1 = self.bottlenecks(hidden_state_1)
767
+ hidden_state_2 = self.conv2(hidden_state)
768
+ return self.conv3(hidden_state_1 + hidden_state_2)
769
+
770
+
771
+ class PPDocLayoutV3SinePositionEmbedding(nn.Module):
772
+ """
773
+ 2D sinusoidal position embedding used in RT-DETR hybrid encoder.
774
+ """
775
+
776
+ def __init__(self, embed_dim: int = 256, temperature: int = 10000):
777
+ super().__init__()
778
+ self.embed_dim = embed_dim
779
+ self.temperature = temperature
780
+
781
+ @compile_compatible_method_lru_cache(maxsize=32)
782
+ def forward(
783
+ self,
784
+ width: int,
785
+ height: int,
786
+ device: torch.device | str,
787
+ dtype: torch.dtype,
788
+ ) -> torch.Tensor:
789
+ """
790
+ Generate 2D sinusoidal position embeddings.
791
+
792
+ Returns:
793
+ Position embeddings of shape (1, height*width, embed_dim)
794
+ """
795
+ grid_w = torch.arange(torch_int(width), device=device).to(dtype)
796
+ grid_h = torch.arange(torch_int(height), device=device).to(dtype)
797
+ grid_w, grid_h = torch.meshgrid(grid_w, grid_h, indexing="xy")
798
+ if self.embed_dim % 4 != 0:
799
+ raise ValueError("Embed dimension must be divisible by 4 for 2D sin-cos position embedding")
800
+ pos_dim = self.embed_dim // 4
801
+ omega = torch.arange(pos_dim, device=device).to(dtype) / pos_dim
802
+ omega = 1.0 / (self.temperature**omega)
803
+
804
+ out_w = grid_w.flatten()[..., None] @ omega[None]
805
+ out_h = grid_h.flatten()[..., None] @ omega[None]
806
+
807
+ return torch.concat([out_h.sin(), out_h.cos(), out_w.sin(), out_w.cos()], dim=1)[None, :, :]
808
+
809
+
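A small sketch of the embedding above: for a 3x2 feature map and embed_dim=8 (the dimension must be divisible by 4, one quarter each for sin/cos of the row and column coordinates), the result has shape (1, height*width, embed_dim):

    import torch

    embedding = PPDocLayoutV3SinePositionEmbedding(embed_dim=8)
    pos = embedding(width=3, height=2, device="cpu", dtype=torch.float32)
    print(pos.shape)   # torch.Size([1, 6, 8])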
810
+ class PPDocLayoutV3AIFILayer(nn.Module):
811
+ """
812
+ AIFI (Attention-based Intra-scale Feature Interaction) layer used in RT-DETR hybrid encoder.
813
+ """
814
+
815
+ def __init__(self, config: PPDocLayoutV3Config):
816
+ super().__init__()
817
+ self.config = config
818
+ self.encoder_hidden_dim = config.encoder_hidden_dim
819
+ self.eval_size = config.eval_size
820
+
821
+ self.position_embedding = PPDocLayoutV3SinePositionEmbedding(
822
+ embed_dim=self.encoder_hidden_dim,
823
+ temperature=config.positional_encoding_temperature,
824
+ )
825
+ self.layers = nn.ModuleList([PPDocLayoutV3EncoderLayer(config) for _ in range(config.encoder_layers)])
826
+
827
+ def forward(
828
+ self,
829
+ hidden_states: torch.Tensor,
830
+ **kwargs: Unpack[TransformersKwargs],
831
+ ) -> torch.Tensor:
832
+ """
833
+ Args:
834
+ hidden_states (`torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
835
+ Feature map to process.
836
+ """
837
+ batch_size = hidden_states.shape[0]
838
+ height, width = hidden_states.shape[2:]
839
+
840
+ hidden_states = hidden_states.flatten(2).permute(0, 2, 1)
841
+
842
+ if self.training or self.eval_size is None:
843
+ pos_embed = self.position_embedding(
844
+ width=width,
845
+ height=height,
846
+ device=hidden_states.device,
847
+ dtype=hidden_states.dtype,
848
+ )
849
+ else:
850
+ pos_embed = None
851
+
852
+ for layer in self.layers:
853
+ hidden_states = layer(
854
+ hidden_states,
855
+ attention_mask=None,
856
+ spatial_position_embeddings=pos_embed,
857
+ **kwargs,
858
+ )
859
+
860
+ hidden_states = (
861
+ hidden_states.permute(0, 2, 1).reshape(batch_size, self.encoder_hidden_dim, height, width).contiguous()
862
+ )
863
+
864
+ return hidden_states
865
+
866
+
867
+ class PPDocLayoutV3HybridEncoder(PPDocLayoutV3PreTrainedModel):
868
+ """
869
+     Main differences from `RTDetrHybridEncoder`:
870
+     1. Mask Feature Head: Added `PPDocLayoutV3MaskFeatFPN` module (`self.mask_feature_head`) for document-specific mask feature generation.
871
+ 2. Extra Conv Layers: Introduced `self.encoder_mask_lateral` and `self.encoder_mask_output` for mask feature processing and output.
872
+ """
873
+
874
+ _can_record_outputs = {
875
+ "hidden_states": PPDocLayoutV3AIFILayer,
876
+ "attentions": PPDocLayoutV3SelfAttention,
877
+ }
878
+
879
+ def __init__(self, config: PPDocLayoutV3Config):
880
+ super().__init__(config)
881
+ self.config = config
882
+ self.in_channels = config.encoder_in_channels
883
+ self.feat_strides = config.feat_strides
884
+ self.encoder_hidden_dim = config.encoder_hidden_dim
885
+ self.encode_proj_layers = config.encode_proj_layers
886
+ self.positional_encoding_temperature = config.positional_encoding_temperature
887
+ self.eval_size = config.eval_size
888
+ self.out_channels = [self.encoder_hidden_dim for _ in self.in_channels]
889
+ self.out_strides = self.feat_strides
890
+ self.num_fpn_stages = len(self.in_channels) - 1
891
+ self.num_pan_stages = len(self.in_channels) - 1
892
+
893
+ # AIFI (Attention-based Intra-scale Feature Interaction) layers
894
+ self.aifi = nn.ModuleList([PPDocLayoutV3AIFILayer(config) for _ in range(len(self.encode_proj_layers))])
895
+
896
+ # top-down FPN
897
+ self.lateral_convs = nn.ModuleList()
898
+ self.fpn_blocks = nn.ModuleList()
899
+ for _ in range(self.num_fpn_stages):
900
+ lateral_conv = PPDocLayoutV3ConvNormLayer(
901
+ config,
902
+ in_channels=self.encoder_hidden_dim,
903
+ out_channels=self.encoder_hidden_dim,
904
+ kernel_size=1,
905
+ stride=1,
906
+ activation=config.activation_function,
907
+ )
908
+ fpn_block = PPDocLayoutV3CSPRepLayer(config)
909
+ self.lateral_convs.append(lateral_conv)
910
+ self.fpn_blocks.append(fpn_block)
911
+
912
+ # bottom-up PAN
913
+ self.downsample_convs = nn.ModuleList()
914
+ self.pan_blocks = nn.ModuleList()
915
+ for _ in range(self.num_pan_stages):
916
+ downsample_conv = PPDocLayoutV3ConvNormLayer(
917
+ config,
918
+ in_channels=self.encoder_hidden_dim,
919
+ out_channels=self.encoder_hidden_dim,
920
+ kernel_size=3,
921
+ stride=2,
922
+ activation=config.activation_function,
923
+ )
924
+ pan_block = PPDocLayoutV3CSPRepLayer(config)
925
+ self.downsample_convs.append(downsample_conv)
926
+ self.pan_blocks.append(pan_block)
927
+
928
+ feat_strides = config.feat_strides
929
+ mask_feature_channels = config.mask_feature_channels
930
+ self.mask_feature_head = PPDocLayoutV3MaskFeatFPN(
931
+ [self.encoder_hidden_dim] * len(feat_strides),
932
+ feat_strides,
933
+ feature_channels=mask_feature_channels[0],
934
+ out_channels=mask_feature_channels[1],
935
+ )
936
+ self.encoder_mask_lateral = PPDocLayoutV3ConvLayer(config.x4_feat_dim, mask_feature_channels[1], 3, 1, "silu")
937
+ self.encoder_mask_output = PPDocLayoutV3EncoderMaskOutput(
938
+ in_channels=mask_feature_channels[1], num_prototypes=config.num_prototypes
939
+ )
940
+
941
+ self.post_init()
942
+
943
+ @check_model_inputs(tie_last_hidden_states=False)
944
+ def forward(
945
+ self,
946
+ inputs_embeds=None,
947
+ x4_feat=None,
948
+ **kwargs: Unpack[TransformersKwargs],
949
+ ) -> BaseModelOutput:
950
+ r"""
951
+ Args:
952
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`):
953
+             inputs_embeds (list of `torch.FloatTensor` of shape `(batch_size, channels, height, width)`):
953
+                 Multi-scale feature maps (output of the backbone + projection layer) that are passed to the encoder.
955
+ feature_maps = inputs_embeds
956
+
957
+ # AIFI: Apply transformer encoder to specified feature levels
958
+ if self.config.encoder_layers > 0:
959
+ for i, enc_ind in enumerate(self.encode_proj_layers):
960
+ feature_maps[enc_ind] = self.aifi[i](feature_maps[enc_ind], **kwargs)
961
+
962
+ # top-down FPN
963
+ fpn_feature_maps = [feature_maps[-1]]
964
+ for idx, (lateral_conv, fpn_block) in enumerate(zip(self.lateral_convs, self.fpn_blocks)):
965
+ backbone_feature_map = feature_maps[self.num_fpn_stages - idx - 1]
966
+ top_fpn_feature_map = fpn_feature_maps[-1]
967
+ # apply lateral block
968
+ top_fpn_feature_map = lateral_conv(top_fpn_feature_map)
969
+ fpn_feature_maps[-1] = top_fpn_feature_map
970
+ # apply fpn block
971
+ top_fpn_feature_map = F.interpolate(top_fpn_feature_map, scale_factor=2.0, mode="nearest")
972
+ fused_feature_map = torch.concat([top_fpn_feature_map, backbone_feature_map], dim=1)
973
+ new_fpn_feature_map = fpn_block(fused_feature_map)
974
+ fpn_feature_maps.append(new_fpn_feature_map)
975
+
976
+ fpn_feature_maps.reverse()
977
+
978
+ # bottom-up PAN
979
+ pan_feature_maps = [fpn_feature_maps[0]]
980
+ for idx, (downsample_conv, pan_block) in enumerate(zip(self.downsample_convs, self.pan_blocks)):
981
+ top_pan_feature_map = pan_feature_maps[-1]
982
+ fpn_feature_map = fpn_feature_maps[idx + 1]
983
+ downsampled_feature_map = downsample_conv(top_pan_feature_map)
984
+ fused_feature_map = torch.concat([downsampled_feature_map, fpn_feature_map], dim=1)
985
+ new_pan_feature_map = pan_block(fused_feature_map)
986
+ pan_feature_maps.append(new_pan_feature_map)
987
+
988
+ mask_feat = self.mask_feature_head(pan_feature_maps)
989
+ mask_feat = F.interpolate(mask_feat, scale_factor=2, mode="bilinear", align_corners=False)
990
+ mask_feat += self.encoder_mask_lateral(x4_feat[0])
991
+ mask_feat = self.encoder_mask_output(mask_feat)
992
+
993
+ return PPDocLayoutV3HybridEncoderOutput(
994
+ last_hidden_state=pan_feature_maps,
995
+ mask_feat=mask_feat,
996
+ )
997
+
998
+
999
+ class PPDocLayoutV3DecoderLayer(nn.Module):
1000
+ def __init__(self, config: PPDocLayoutV3Config):
1001
+ super().__init__()
1002
+ self.hidden_size = config.d_model
1003
+
1004
+ # self-attention
1005
+ self.self_attn = PPDocLayoutV3SelfAttention(
1006
+ config=config,
1007
+ hidden_size=self.hidden_size,
1008
+ num_attention_heads=config.decoder_attention_heads,
1009
+ dropout=config.attention_dropout,
1010
+ )
1011
+ self.dropout = config.dropout
1012
+
1013
+ self.self_attn_layer_norm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
1014
+ # cross-attention
1015
+ self.encoder_attn = PPDocLayoutV3MultiscaleDeformableAttention(
1016
+ config,
1017
+ num_heads=config.decoder_attention_heads,
1018
+ n_points=config.decoder_n_points,
1019
+ )
1020
+ self.encoder_attn_layer_norm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
1021
+ # feedforward neural networks
1022
+ self.mlp = PPDocLayoutV3MLP(
1023
+ config, self.hidden_size, config.decoder_ffn_dim, config.decoder_activation_function
1024
+ )
1025
+ self.final_layer_norm = nn.LayerNorm(self.hidden_size, eps=config.layer_norm_eps)
1026
+
1027
+ def forward(
1028
+ self,
1029
+ hidden_states: torch.Tensor,
1030
+ object_queries_position_embeddings: torch.Tensor | None = None,
1031
+ reference_points=None,
1032
+ spatial_shapes=None,
1033
+ spatial_shapes_list=None,
1034
+ level_start_index=None,
1035
+ encoder_hidden_states: torch.Tensor | None = None,
1036
+ encoder_attention_mask: torch.Tensor | None = None,
1037
+ **kwargs: Unpack[TransformersKwargs],
1038
+ ) -> torch.Tensor:
1039
+ """
1040
+ Args:
1041
+ hidden_states (`torch.FloatTensor`):
1042
+ Input to the layer of shape `(batch, seq_len, hidden_size)`.
1043
+ object_queries_position_embeddings (`torch.FloatTensor`, *optional*):
1044
+ Position embeddings for the object query slots. These are added to both queries and keys
1045
+ in the self-attention layer (not values).
1046
+ reference_points (`torch.FloatTensor`, *optional*):
1047
+ Reference points.
1048
+ spatial_shapes (`torch.LongTensor`, *optional*):
1049
+ Spatial shapes.
1050
+ level_start_index (`torch.LongTensor`, *optional*):
1051
+ Level start index.
1052
+ encoder_hidden_states (`torch.FloatTensor`):
1053
+                 Cross-attention input to the layer of shape `(batch, seq_len, hidden_size)`.
1054
+ encoder_attention_mask (`torch.FloatTensor`): encoder attention mask of size
1055
+ `(batch, 1, target_len, source_len)` where padding elements are indicated by very large negative
1056
+ values.
1057
+ """
1058
+ residual = hidden_states
1059
+
1060
+ # Self Attention
1061
+ hidden_states, _ = self.self_attn(
1062
+ hidden_states=hidden_states,
1063
+ attention_mask=encoder_attention_mask,
1064
+ position_embeddings=object_queries_position_embeddings,
1065
+ **kwargs,
1066
+ )
1067
+
1068
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
1069
+ hidden_states = residual + hidden_states
1070
+ hidden_states = self.self_attn_layer_norm(hidden_states)
1071
+
1072
+ residual = hidden_states
1073
+
1074
+ # Cross-Attention
1075
+ hidden_states, _ = self.encoder_attn(
1076
+ hidden_states=hidden_states,
1077
+ encoder_hidden_states=encoder_hidden_states,
1078
+ position_embeddings=object_queries_position_embeddings,
1079
+ reference_points=reference_points,
1080
+ spatial_shapes=spatial_shapes,
1081
+ spatial_shapes_list=spatial_shapes_list,
1082
+ level_start_index=level_start_index,
1083
+ **kwargs,
1084
+ )
1085
+
1086
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
1087
+ hidden_states = residual + hidden_states
1088
+
1089
+ hidden_states = self.encoder_attn_layer_norm(hidden_states)
1090
+
1091
+ # Fully Connected
1092
+ residual = hidden_states
1093
+ hidden_states = self.mlp(hidden_states)
1094
+ hidden_states = residual + hidden_states
1095
+ hidden_states = self.final_layer_norm(hidden_states)
1096
+
1097
+ return hidden_states
1098
+
1099
+
1100
+ def inverse_sigmoid(x, eps=1e-5):
1101
+ x = x.clamp(min=0, max=1)
1102
+ x1 = x.clamp(min=eps)
1103
+ x2 = (1 - x).clamp(min=eps)
1104
+ return torch.log(x1 / x2)
1105
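+ # Quick sanity check of the helper above (illustrative doctest, toy values):
+ # inverse_sigmoid is a clamped logit, so composing it with torch.sigmoid
+ # recovers the input away from the clamping boundaries.
+ # >>> x = torch.tensor([0.1, 0.5, 0.9])
+ # >>> torch.allclose(torch.sigmoid(inverse_sigmoid(x)), x, atol=1e-4)
+ # True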
+
1106
+
1107
+ class PPDocLayoutV3Decoder(PPDocLayoutV3PreTrainedModel):
1108
+ """
1109
+ Main difference to `RTDetrDecoder`:
1110
+ A new mask generation process is introduced at each decoder layer.
1111
+ """
1112
+
1113
+ _can_record_outputs = {
1114
+ "hidden_states": PPDocLayoutV3DecoderLayer,
1115
+ "attentions": PPDocLayoutV3SelfAttention,
1116
+ "cross_attentions": PPDocLayoutV3MultiscaleDeformableAttention,
1117
+ }
1118
+
1119
+ def __init__(self, config: PPDocLayoutV3Config):
1120
+ super().__init__(config)
1121
+
1122
+ self.dropout = config.dropout
1123
+ self.layers = nn.ModuleList([PPDocLayoutV3DecoderLayer(config) for _ in range(config.decoder_layers)])
1124
+ self.query_pos_head = PPDocLayoutV3MLPPredictionHead(4, 2 * config.d_model, config.d_model, num_layers=2)
1125
+
1126
+ # hack implementation for iterative bounding box refinement and two-stage Deformable DETR
1127
+ self.bbox_embed = None
1128
+ self.class_embed = None
1129
+
1130
+ self.num_queries = config.num_queries
1131
+
1132
+ # Initialize weights and apply final processing
1133
+ self.post_init()
1134
+
1135
+ @check_model_inputs()
1136
+ def forward(
1137
+ self,
1138
+ inputs_embeds=None,
1139
+ encoder_hidden_states=None,
1140
+ encoder_attention_mask=None,
1141
+ reference_points=None,
1142
+ spatial_shapes=None,
1143
+ spatial_shapes_list=None,
1144
+ level_start_index=None,
1145
+ order_head=None,
1146
+ global_pointer=None,
1147
+ mask_query_head=None,
1148
+ norm=None,
1149
+ mask_feat=None,
1150
+ **kwargs: Unpack[TransformersKwargs],
1151
+ ):
1152
+ r"""
1153
+ Args:
1154
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
1155
+ The query embeddings that are passed into the decoder.
1156
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1157
+ Sequence of hidden-states at the output of the last layer of the encoder. Used in the cross-attention
1158
+ of the decoder.
1159
+ encoder_attention_mask (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1160
+ Mask to avoid performing cross-attention on padding pixel_values of the encoder. Mask values selected
1161
+ in `[0, 1]`:
1162
+ - 1 for pixels that are real (i.e. **not masked**),
1163
+ - 0 for pixels that are padding (i.e. **masked**).
1164
+ reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)` if `as_two_stage`, else of shape `(batch_size, num_queries, 2)`, *optional*):
1165
+ Reference point in range `[0, 1]`, top-left (0,0), bottom-right (1, 1), including padding area.
1166
+ spatial_shapes (`torch.FloatTensor` of shape `(num_feature_levels, 2)`):
1167
+ Spatial shapes of the feature maps.
1168
+ level_start_index (`torch.LongTensor` of shape `(num_feature_levels)`, *optional*):
1169
+ Indexes for the start of each feature level. In range `[0, sequence_length]`.
1170
+ """
1171
+ if inputs_embeds is not None:
1172
+ hidden_states = inputs_embeds
1173
+
1174
+ # decoder layers
1175
+ intermediate = ()
1176
+ intermediate_reference_points = ()
1177
+ intermediate_logits = ()
1178
+ decoder_out_order_logits = ()
1179
+ decoder_out_masks = ()
1180
+
1181
+ reference_points = F.sigmoid(reference_points)
1182
+
1183
+ # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_decoder.py#L252
1184
+ for idx, decoder_layer in enumerate(self.layers):
1185
+ reference_points_input = reference_points.unsqueeze(2)
1186
+ object_queries_position_embeddings = self.query_pos_head(reference_points)
1187
+
1188
+ hidden_states = decoder_layer(
1189
+ hidden_states,
1190
+ object_queries_position_embeddings=object_queries_position_embeddings,
1191
+ encoder_hidden_states=encoder_hidden_states,
1192
+ reference_points=reference_points_input,
1193
+ spatial_shapes=spatial_shapes,
1194
+ spatial_shapes_list=spatial_shapes_list,
1195
+ level_start_index=level_start_index,
1196
+ encoder_attention_mask=encoder_attention_mask,
1197
+ **kwargs,
1198
+ )
1199
+
1200
+ # hack implementation for iterative bounding box refinement
1201
+ if self.bbox_embed is not None:
1202
+ predicted_corners = self.bbox_embed(hidden_states)
1203
+ new_reference_points = F.sigmoid(predicted_corners + inverse_sigmoid(reference_points))
1204
+ reference_points = new_reference_points.detach()
1205
+
1206
+ intermediate += (hidden_states,)
1207
+ intermediate_reference_points += (
1208
+ (new_reference_points,) if self.bbox_embed is not None else (reference_points,)
1209
+ )
1210
+
1211
+ # get_pred_class_order_and_mask
1212
+ out_query = norm(hidden_states)
1213
+ mask_query_embed = mask_query_head(out_query)
1214
+ batch_size, mask_dim, _ = mask_query_embed.shape
1215
+ _, _, mask_h, mask_w = mask_feat.shape
1216
+ out_mask = torch.bmm(mask_query_embed, mask_feat.flatten(start_dim=2)).reshape(
1217
+ batch_size, mask_dim, mask_h, mask_w
1218
+ )
1219
+ decoder_out_masks += (out_mask,)
1220
+
1221
+ if self.class_embed is not None:
1222
+ logits = self.class_embed(out_query)
1223
+ intermediate_logits += (logits,)
1224
+
1225
+ if order_head is not None and global_pointer is not None:
1226
+ valid_query = out_query[:, -self.num_queries :] if self.num_queries is not None else out_query
1227
+ order_logits = global_pointer(order_head[idx](valid_query))
1228
+ decoder_out_order_logits += (order_logits,)
1229
+
1230
+ # Keep batch_size as first dimension
1231
+ intermediate = torch.stack(intermediate, dim=1)
1232
+ intermediate_reference_points = torch.stack(intermediate_reference_points, dim=1)
1233
+ if self.class_embed is not None:
1234
+ intermediate_logits = torch.stack(intermediate_logits, dim=1)
1235
+ if order_head is not None and global_pointer is not None:
1236
+ decoder_out_order_logits = torch.stack(decoder_out_order_logits, dim=1)
1237
+ decoder_out_masks = torch.stack(decoder_out_masks, dim=1)
1238
+
1239
+ return PPDocLayoutV3DecoderOutput(
1240
+ last_hidden_state=hidden_states,
1241
+ intermediate_hidden_states=intermediate,
1242
+ intermediate_logits=intermediate_logits,
1243
+ intermediate_reference_points=intermediate_reference_points,
1244
+ decoder_out_order_logits=decoder_out_order_logits,
1245
+ decoder_out_masks=decoder_out_masks,
1246
+ )
1247
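+ # Illustrative sketch of the iterative box refinement performed in the decoder
+ # above (toy values, no checkpoint involved): each layer predicts an offset in
+ # logit space and the refined box is squashed back into [0, 1] with a sigmoid.
+ # >>> ref = torch.tensor([[0.25, 0.25, 0.50, 0.50]])      # (cx, cy, w, h)
+ # >>> delta = torch.tensor([[0.10, -0.10, 0.00, 0.00]])   # offset from bbox_embed
+ # >>> refined = torch.sigmoid(delta + inverse_sigmoid(ref))
+ # >>> refined.shape
+ # torch.Size([1, 4])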
+
1248
+
1249
+ class PPDocLayoutV3FrozenBatchNorm2d(nn.Module):
1250
+ """
1251
+ BatchNorm2d where the batch statistics and the affine parameters are fixed.
1252
+
1253
+ Copy-paste from torchvision.misc.ops with added eps before rsqrt, without which any other models than
1254
+ torchvision.models.resnet[18,34,50,101] produce nans.
1255
+ """
1256
+
1257
+ def __init__(self, n):
1258
+ super().__init__()
1259
+ self.register_buffer("weight", torch.ones(n))
1260
+ self.register_buffer("bias", torch.zeros(n))
1261
+ self.register_buffer("running_mean", torch.zeros(n))
1262
+ self.register_buffer("running_var", torch.ones(n))
1263
+
1264
+ def _load_from_state_dict(
1265
+ self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
1266
+ ):
1267
+ num_batches_tracked_key = prefix + "num_batches_tracked"
1268
+ if num_batches_tracked_key in state_dict:
1269
+ del state_dict[num_batches_tracked_key]
1270
+
1271
+ super()._load_from_state_dict(
1272
+ state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs
1273
+ )
1274
+
1275
+ def forward(self, x):
1276
+ # move reshapes to the beginning
1277
+ # to make it user-friendly
1278
+ weight = self.weight.reshape(1, -1, 1, 1)
1279
+ bias = self.bias.reshape(1, -1, 1, 1)
1280
+ running_var = self.running_var.reshape(1, -1, 1, 1)
1281
+ running_mean = self.running_mean.reshape(1, -1, 1, 1)
1282
+ epsilon = 1e-5
1283
+ scale = weight * (running_var + epsilon).rsqrt()
1284
+ bias = bias - running_mean * scale
1285
+ return x * scale + bias
1286
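+ # Illustrative sanity check (toy tensor): with the default buffers the frozen
+ # layer reduces to x * 1 / sqrt(1 + 1e-5) + 0, i.e. very nearly the identity.
+ # >>> frozen = PPDocLayoutV3FrozenBatchNorm2d(4)
+ # >>> x = torch.randn(2, 4, 5, 5)
+ # >>> torch.allclose(frozen(x), x, atol=1e-4)
+ # True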
+
1287
+
1288
+ def replace_batch_norm(model):
1289
+ r"""
1290
+ Recursively replace all `torch.nn.BatchNorm2d` with `PPDocLayoutV3FrozenBatchNorm2d`.
1291
+
1292
+ Args:
1293
+ model (torch.nn.Module):
1294
+ input model
1295
+ """
1296
+ for name, module in model.named_children():
1297
+ if isinstance(module, nn.BatchNorm2d):
1298
+ new_module = PPDocLayoutV3FrozenBatchNorm2d(module.num_features)
1299
+
1300
+ if module.weight.device != torch.device("meta"):
1301
+ new_module.weight.copy_(module.weight)
1302
+ new_module.bias.copy_(module.bias)
1303
+ new_module.running_mean.copy_(module.running_mean)
1304
+ new_module.running_var.copy_(module.running_var)
1305
+
1306
+ model._modules[name] = new_module
1307
+
1308
+ if len(list(module.children())) > 0:
1309
+ replace_batch_norm(module)
1310
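+ # Illustrative usage on a hypothetical toy module (not a released checkpoint):
+ # every nn.BatchNorm2d child is swapped in place for the frozen variant.
+ # >>> toy = nn.Sequential(nn.Conv2d(3, 8, kernel_size=3), nn.BatchNorm2d(8))
+ # >>> with torch.no_grad():
+ # ...     replace_batch_norm(toy)
+ # >>> isinstance(toy[1], PPDocLayoutV3FrozenBatchNorm2d)
+ # True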
+
1311
+
1312
+ class PPDocLayoutV3ConvEncoder(nn.Module):
1313
+ """
1314
+ Convolutional backbone based on modeling_pp_doclayout_v3_resnet.py.
1315
+
1316
+ nn.BatchNorm2d layers are replaced by PPDocLayoutV3FrozenBatchNorm2d as defined above.
1317
+ https://github.com/lyuwenyu/RT-DETR/blob/main/rtdetr_pytorch/src/nn/backbone/presnet.py#L142
1318
+ """
1319
+
1320
+ def __init__(self, config):
1321
+ super().__init__()
1322
+
1323
+ backbone = load_backbone(config)
1324
+
1325
+ if config.freeze_backbone_batch_norms:
1326
+ # replace batch norm by frozen batch norm
1327
+ with torch.no_grad():
1328
+ replace_batch_norm(backbone)
1329
+ self.model = backbone
1330
+ self.intermediate_channel_sizes = self.model.channels
1331
+
1332
+ def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor):
1333
+ # send pixel_values through the model to get list of feature maps
1334
+ features = self.model(pixel_values).feature_maps
1335
+
1336
+ out = []
1337
+ for feature_map in features:
1338
+ # downsample pixel_mask to match shape of corresponding feature_map
1339
+ mask = nn.functional.interpolate(pixel_mask[None].float(), size=feature_map.shape[-2:]).to(torch.bool)[0]
1340
+ out.append((feature_map, mask))
1341
+ return out
1342
+
1343
+
1344
+ def get_contrastive_denoising_training_group(
1345
+ targets,
1346
+ num_classes,
1347
+ num_queries,
1348
+ class_embed,
1349
+ num_denoising_queries=100,
1350
+ label_noise_ratio=0.5,
1351
+ box_noise_scale=1.0,
1352
+ ):
1353
+ """
1354
+ Creates a contrastive denoising training group using ground-truth samples. It adds noise to labels and boxes.
1355
+
1356
+ Args:
1357
+ targets (`list[dict]`):
1358
+ The target objects, each containing 'class_labels' and 'boxes' for objects in an image.
1359
+ num_classes (`int`):
1360
+ Total number of classes in the dataset.
1361
+ num_queries (`int`):
1362
+ Number of query slots in the transformer.
1363
+ class_embed (`callable`):
1364
+ A function or a model layer to embed class labels.
1365
+ num_denoising_queries (`int`, *optional*, defaults to 100):
1366
+ Number of denoising queries.
1367
+ label_noise_ratio (`float`, *optional*, defaults to 0.5):
1368
+ Ratio of noise applied to labels.
1369
+ box_noise_scale (`float`, *optional*, defaults to 1.0):
1370
+ Scale of noise applied to bounding boxes.
1371
+ Returns:
1372
+ `tuple` comprising various elements:
1373
+ - **input_query_class** (`torch.FloatTensor`) --
1374
+ Class queries with applied label noise.
1375
+ - **input_query_bbox** (`torch.FloatTensor`) --
1376
+ Bounding box queries with applied box noise.
1377
+ - **attn_mask** (`torch.FloatTensor`) --
1378
+ Attention mask for separating denoising and reconstruction queries.
1379
+ - **denoising_meta_values** (`dict`) --
1380
+ Metadata including denoising positive indices, number of groups, and split sizes.
1381
+ """
1382
+
1383
+ if num_denoising_queries <= 0:
1384
+ return None, None, None, None
1385
+
1386
+ num_ground_truths = [len(t["class_labels"]) for t in targets]
1387
+ device = targets[0]["class_labels"].device
1388
+
1389
+ max_gt_num = max(num_ground_truths)
1390
+ if max_gt_num == 0:
1391
+ return None, None, None, None
1392
+
1393
+ num_groups_denoising_queries = num_denoising_queries // max_gt_num
1394
+ num_groups_denoising_queries = 1 if num_groups_denoising_queries == 0 else num_groups_denoising_queries
1395
+ # pad gt to max_num of a batch
1396
+ batch_size = len(num_ground_truths)
1397
+
1398
+ input_query_class = torch.full([batch_size, max_gt_num], num_classes, dtype=torch.int32, device=device)
1399
+ input_query_bbox = torch.zeros([batch_size, max_gt_num, 4], device=device)
1400
+ pad_gt_mask = torch.zeros([batch_size, max_gt_num], dtype=torch.bool, device=device)
1401
+
1402
+ for i in range(batch_size):
1403
+ num_gt = num_ground_truths[i]
1404
+ if num_gt > 0:
1405
+ input_query_class[i, :num_gt] = targets[i]["class_labels"]
1406
+ input_query_bbox[i, :num_gt] = targets[i]["boxes"]
1407
+ pad_gt_mask[i, :num_gt] = 1
1408
+ # each group has positive and negative queries.
1409
+ input_query_class = input_query_class.tile([1, 2 * num_groups_denoising_queries])
1410
+ input_query_bbox = input_query_bbox.tile([1, 2 * num_groups_denoising_queries, 1])
1411
+ pad_gt_mask = pad_gt_mask.tile([1, 2 * num_groups_denoising_queries])
1412
+ # positive and negative mask
1413
+ negative_gt_mask = torch.zeros([batch_size, max_gt_num * 2, 1], device=device)
1414
+ negative_gt_mask[:, max_gt_num:] = 1
1415
+ negative_gt_mask = negative_gt_mask.tile([1, num_groups_denoising_queries, 1])
1416
+ positive_gt_mask = 1 - negative_gt_mask
1417
+ # contrastive denoising training positive index
1418
+ positive_gt_mask = positive_gt_mask.squeeze(-1) * pad_gt_mask
1419
+ denoise_positive_idx = torch.nonzero(positive_gt_mask)[:, 1]
1420
+ denoise_positive_idx = torch.split(
1421
+ denoise_positive_idx, [n * num_groups_denoising_queries for n in num_ground_truths]
1422
+ )
1423
+ # total denoising queries
1424
+ num_denoising_queries = torch_int(max_gt_num * 2 * num_groups_denoising_queries)
1425
+
1426
+ if label_noise_ratio > 0:
1427
+ mask = torch.rand_like(input_query_class, dtype=torch.float) < (label_noise_ratio * 0.5)
1428
+ # randomly put a new one here
1429
+ new_label = torch.randint_like(mask, 0, num_classes, dtype=input_query_class.dtype)
1430
+ input_query_class = torch.where(mask & pad_gt_mask, new_label, input_query_class)
1431
+
1432
+ if box_noise_scale > 0:
1433
+ known_bbox = center_to_corners_format(input_query_bbox)
1434
+ diff = torch.tile(input_query_bbox[..., 2:] * 0.5, [1, 1, 2]) * box_noise_scale
1435
+ rand_sign = torch.randint_like(input_query_bbox, 0, 2) * 2.0 - 1.0
1436
+ rand_part = torch.rand_like(input_query_bbox)
1437
+ rand_part = (rand_part + 1.0) * negative_gt_mask + rand_part * (1 - negative_gt_mask)
1438
+ rand_part *= rand_sign
1439
+ known_bbox += rand_part * diff
1440
+ known_bbox.clip_(min=0.0, max=1.0)
1441
+ input_query_bbox = corners_to_center_format(known_bbox)
1442
+ input_query_bbox = inverse_sigmoid(input_query_bbox)
1443
+
1444
+ input_query_class = class_embed(input_query_class)
1445
+
1446
+ target_size = num_denoising_queries + num_queries
1447
+ attn_mask = torch.full([target_size, target_size], 0, dtype=torch.float, device=device)
1448
+ # match query cannot see the reconstruction
1449
+ attn_mask[num_denoising_queries:, :num_denoising_queries] = -torch.inf
1450
+
1451
+ # reconstructions cannot see each other
1452
+ for i in range(num_groups_denoising_queries):
1453
+ idx_block_start = max_gt_num * 2 * i
1454
+ idx_block_end = max_gt_num * 2 * (i + 1)
1455
+ attn_mask[idx_block_start:idx_block_end, :idx_block_start] = -torch.inf
1456
+ attn_mask[idx_block_start:idx_block_end, idx_block_end:num_denoising_queries] = -torch.inf
1457
+
1458
+ denoising_meta_values = {
1459
+ "dn_positive_idx": denoise_positive_idx,
1460
+ "dn_num_group": num_groups_denoising_queries,
1461
+ "dn_num_split": [num_denoising_queries, num_queries],
1462
+ }
1463
+
1464
+ return input_query_class, input_query_bbox, attn_mask, denoising_meta_values
1465
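+ # Illustrative call with toy targets (random boxes, a plain nn.Embedding as the
+ # class embedding; only the shapes matter here). With 2 ground-truth boxes and
+ # num_denoising_queries=100 there are 50 groups, each with a positive and a
+ # negative copy, so 200 denoising queries placed in front of the 300 object queries.
+ # >>> targets = [{"class_labels": torch.tensor([1, 3]), "boxes": torch.rand(2, 4)}]
+ # >>> class_embed = nn.Embedding(11, 256, padding_idx=10)  # num_classes + 1 slots
+ # >>> queries, boxes, attn_mask, meta = get_contrastive_denoising_training_group(
+ # ...     targets, num_classes=10, num_queries=300, class_embed=class_embed,
+ # ...     num_denoising_queries=100, label_noise_ratio=0.5, box_noise_scale=1.0,
+ # ... )
+ # >>> queries.shape, boxes.shape, attn_mask.shape
+ # (torch.Size([1, 200, 256]), torch.Size([1, 200, 4]), torch.Size([500, 500]))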
+
1466
+
1467
+ def mask_to_box_coordinate(mask, dtype):
1468
+ mask = mask.bool()
1469
+
1470
+ height, width = mask.shape[-2:]
1471
+
1472
+ y_coords, x_coords = torch.meshgrid(
1473
+ torch.arange(height, device=mask.device), torch.arange(width, device=mask.device), indexing="ij"
1474
+ )
1475
+ x_coords = x_coords.to(dtype)
1476
+ y_coords = y_coords.to(dtype)
1477
+
1478
+ x_coords_masked = x_coords * mask
1479
+ x_max = x_coords_masked.flatten(start_dim=-2).max(dim=-1).values + 1
1480
+ x_min = (
1481
+ torch.where(mask, x_coords_masked, torch.tensor(torch.finfo(dtype).max))
1482
+ .flatten(start_dim=-2)
1483
+ .min(dim=-1)
1484
+ .values
1485
+ )
1486
+
1487
+ y_coords_masked = y_coords * mask
1488
+ y_max = y_coords_masked.flatten(start_dim=-2).max(dim=-1).values + 1
1489
+ y_min = (
1490
+ torch.where(mask, y_coords_masked, torch.tensor(torch.finfo(dtype).max))
1491
+ .flatten(start_dim=-2)
1492
+ .min(dim=-1)
1493
+ .values
1494
+ )
1495
+
1496
+ unnormalized_bbox = torch.stack([x_min, y_min, x_max, y_max], dim=-1)
1497
+
1498
+ is_mask_non_empty = torch.any(mask, dim=(-2, -1)).unsqueeze(-1)
1499
+ unnormalized_bbox = unnormalized_bbox * is_mask_non_empty
1500
+
1501
+ norm_tensor = torch.tensor([width, height, width, height], device=mask.device, dtype=dtype)
1502
+ normalized_bbox_xyxy = unnormalized_bbox / norm_tensor
1503
+
1504
+ x_min_norm, y_min_norm, x_max_norm, y_max_norm = normalized_bbox_xyxy.unbind(dim=-1)
1505
+
1506
+ center_x = (x_min_norm + x_max_norm) / 2
1507
+ center_y = (y_min_norm + y_max_norm) / 2
1508
+ box_width = x_max_norm - x_min_norm
1509
+ box_height = y_max_norm - y_min_norm
1510
+
1511
+ return torch.stack([center_x, center_y, box_width, box_height], dim=-1)
1512
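+ # Illustrative check (toy mask): a mask covering the top-left quarter of an
+ # 8x8 grid maps to the normalized (cx, cy, w, h) box (0.25, 0.25, 0.5, 0.5).
+ # >>> mask = torch.zeros(1, 1, 8, 8)
+ # >>> mask[..., :4, :4] = 1
+ # >>> mask_to_box_coordinate(mask, dtype=torch.float32)
+ # tensor([[[0.2500, 0.2500, 0.5000, 0.5000]]])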
+
1513
+
1514
+ @auto_docstring(
1515
+ custom_intro="""
1516
+ PP-DocLayoutV3 Model (consisting of a backbone and encoder-decoder) outputting raw hidden states without any head on top.
1517
+ """
1518
+ )
1519
+ class PPDocLayoutV3Model(PPDocLayoutV3PreTrainedModel):
1520
+ _tied_weights_keys = {
1521
+ "decoder.class_embed": "enc_score_head",
1522
+ "decoder.bbox_embed": "enc_bbox_head",
1523
+ }
1524
+
1525
+ def __init__(self, config: PPDocLayoutV3Config):
1526
+ super().__init__(config)
1527
+
1528
+ # Create backbone
1529
+ self.backbone = PPDocLayoutV3ConvEncoder(config)
1530
+ intermediate_channel_sizes = self.backbone.intermediate_channel_sizes
1531
+
1532
+ # Create encoder input projection layers
1533
+ # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_pytorch/src/zoo/rtdetr/hybrid_encoder.py#L212
1534
+ num_backbone_outs = len(intermediate_channel_sizes)
1535
+
1536
+ encoder_input_proj_list = []
1537
+ for i in range(num_backbone_outs):
1538
+ in_channels = intermediate_channel_sizes[i]
1539
+ encoder_input_proj_list.append(
1540
+ nn.Sequential(
1541
+ nn.Conv2d(in_channels, config.encoder_hidden_dim, kernel_size=1, bias=False),
1542
+ nn.BatchNorm2d(config.encoder_hidden_dim),
1543
+ )
1544
+ )
1545
+ self.encoder_input_proj = nn.ModuleList(encoder_input_proj_list[1:])
1546
+
1547
+ # Create encoder
1548
+ self.encoder = PPDocLayoutV3HybridEncoder(config)
1549
+
1550
+ # denoising part
1551
+ if config.num_denoising > 0:
1552
+ self.denoising_class_embed = nn.Embedding(
1553
+ config.num_labels + 1, config.d_model, padding_idx=config.num_labels
1554
+ )
1555
+
1556
+ # decoder embedding
1557
+ if config.learn_initial_query:
1558
+ self.weight_embedding = nn.Embedding(config.num_queries, config.d_model)
1559
+
1560
+ # encoder head
1561
+ self.enc_output = nn.Sequential(
1562
+ nn.Linear(config.d_model, config.d_model),
1563
+ nn.LayerNorm(config.d_model, eps=config.layer_norm_eps),
1564
+ )
1565
+ self.enc_score_head = nn.Linear(config.d_model, config.num_labels)
1566
+ self.enc_bbox_head = PPDocLayoutV3MLPPredictionHead(config.d_model, config.d_model, 4, num_layers=3)
1567
+
1568
+ # init encoder output anchors and valid_mask
1569
+ if config.anchor_image_size:
1570
+ self.anchors, self.valid_mask = self.generate_anchors(dtype=self.dtype)
1571
+
1572
+ # Create decoder input projection layers
1573
+ # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_decoder.py#L412
1574
+ num_backbone_outs = len(config.decoder_in_channels)
1575
+ decoder_input_proj_list = []
1576
+ for i in range(num_backbone_outs):
1577
+ in_channels = config.decoder_in_channels[i]
1578
+ decoder_input_proj_list.append(
1579
+ nn.Sequential(
1580
+ nn.Conv2d(in_channels, config.d_model, kernel_size=1, bias=False),
1581
+ nn.BatchNorm2d(config.d_model, config.batch_norm_eps),
1582
+ )
1583
+ )
1584
+ for _ in range(config.num_feature_levels - num_backbone_outs):
1585
+ decoder_input_proj_list.append(
1586
+ nn.Sequential(
1587
+ nn.Conv2d(in_channels, config.d_model, kernel_size=3, stride=2, padding=1, bias=False),
1588
+ nn.BatchNorm2d(config.d_model, config.batch_norm_eps),
1589
+ )
1590
+ )
1591
+ in_channels = config.d_model
1592
+ self.decoder_input_proj = nn.ModuleList(decoder_input_proj_list)
1593
+ self.decoder = PPDocLayoutV3Decoder(config)
1594
+
1595
+ self.decoder_order_head = nn.ModuleList(
1596
+ [nn.Linear(config.d_model, config.d_model) for _ in range(config.decoder_layers)]
1597
+ )
1598
+ self.decoder_global_pointer = PPDocLayoutV3GlobalPointer(config)
1599
+ self.decoder_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
1600
+ self.decoder.class_embed = nn.Linear(config.d_model, config.num_labels)
1601
+ self.decoder.bbox_embed = PPDocLayoutV3MLPPredictionHead(config.d_model, config.d_model, 4, num_layers=3)
1602
+
1603
+ self.mask_enhanced = config.mask_enhanced
1604
+ self.mask_query_head = PPDocLayoutV3MLPPredictionHead(
1605
+ config.d_model, config.d_model, config.num_prototypes, num_layers=3
1606
+ )
1607
+
1608
+ self.post_init()
1609
+
1610
+ def freeze_backbone(self):
1611
+ for param in self.backbone.parameters():
1612
+ param.requires_grad_(False)
1613
+
1614
+ def unfreeze_backbone(self):
1615
+ for param in self.backbone.parameters():
1616
+ param.requires_grad_(True)
1617
+
1618
+ @compile_compatible_method_lru_cache(maxsize=32)
1619
+ def generate_anchors(self, spatial_shapes=None, grid_size=0.05, device="cpu", dtype=torch.float32):
1620
+ if spatial_shapes is None:
1621
+ spatial_shapes = [
1622
+ [int(self.config.anchor_image_size[0] / s), int(self.config.anchor_image_size[1] / s)]
1623
+ for s in self.config.feat_strides
1624
+ ]
1625
+ anchors = []
1626
+ for level, (height, width) in enumerate(spatial_shapes):
1627
+ grid_y, grid_x = torch.meshgrid(
1628
+ torch.arange(end=height, device=device).to(dtype),
1629
+ torch.arange(end=width, device=device).to(dtype),
1630
+ indexing="ij",
1631
+ )
1632
+ grid_xy = torch.stack([grid_x, grid_y], -1)
1633
+ grid_xy = grid_xy.unsqueeze(0) + 0.5
1634
+ grid_xy[..., 0] /= width
1635
+ grid_xy[..., 1] /= height
1636
+ wh = torch.ones_like(grid_xy) * grid_size * (2.0**level)
1637
+ anchors.append(torch.concat([grid_xy, wh], -1).reshape(-1, height * width, 4))
1638
+ # define the valid range for anchor coordinates
1639
+ eps = 1e-2
1640
+ anchors = torch.concat(anchors, 1)
1641
+ valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True)
1642
+ anchors = torch.log(anchors / (1 - anchors))
1643
+ anchors = torch.where(valid_mask, anchors, torch.tensor(torch.finfo(dtype).max, dtype=dtype, device=device))
1644
+
1645
+ return anchors, valid_mask
1646
+
1647
+ @auto_docstring
1648
+ @can_return_tuple
1649
+ def forward(
1650
+ self,
1651
+ pixel_values: torch.FloatTensor,
1652
+ pixel_mask: torch.LongTensor | None = None,
1653
+ encoder_outputs: torch.FloatTensor | None = None,
1654
+ labels: list[dict] | None = None,
1655
+ **kwargs: Unpack[TransformersKwargs],
1656
+ ) -> tuple[torch.FloatTensor] | PPDocLayoutV3ModelOutput:
1657
+ r"""
1658
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1659
+ Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
1660
+ can choose to directly pass a flattened representation of an image.
1661
+ decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
1662
+ Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
1663
+ embedded representation.
1664
+ labels (`list[Dict]` of len `(batch_size,)`, *optional*):
1665
+ Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
1666
+ following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch
1667
+ respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes
1668
+ in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`.
1669
+
1670
+ Examples:
1671
+
1672
+ ```python
1673
+ >>> from transformers import AutoImageProcessor, PPDocLayoutV3Model
1674
+ >>> from PIL import Image
1675
+ >>> import requests
1676
+
1677
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1678
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1679
+
1680
+ >>> image_processor = AutoImageProcessor.from_pretrained("PaddlePaddle/PP-DocLayoutV3_safetensors")
1681
+ >>> model = PPDocLayoutV3Model.from_pretrained("PaddlePaddle/PP-DocLayoutV3_safetensors")
1682
+
1683
+ >>> inputs = image_processor(images=image, return_tensors="pt")
1684
+
1685
+ >>> outputs = model(**inputs)
1686
+
1687
+ >>> last_hidden_states = outputs.last_hidden_state
1688
+ >>> list(last_hidden_states.shape)
1689
+ [1, 300, 256]
1690
+ ```"""
1691
+ batch_size, num_channels, height, width = pixel_values.shape
1692
+ device = pixel_values.device
1693
+
1694
+ if pixel_mask is None:
1695
+ pixel_mask = torch.ones((batch_size, height, width), device=device)
1696
+
1697
+ features = self.backbone(pixel_values, pixel_mask)
1698
+ x4_feat = features.pop(0)
1699
+ proj_feats = [self.encoder_input_proj[level](source) for level, (source, mask) in enumerate(features)]
1700
+
1701
+ if encoder_outputs is None:
1702
+ encoder_outputs = self.encoder(
1703
+ proj_feats,
1704
+ x4_feat,
1705
+ **kwargs,
1706
+ )
1707
+ # If the user passed a tuple for encoder_outputs, we wrap it in a PPDocLayoutV3HybridEncoderOutput when return_dict=True
1708
+ elif not isinstance(encoder_outputs, PPDocLayoutV3HybridEncoderOutput):
1709
+ encoder_outputs = PPDocLayoutV3HybridEncoderOutput(
1710
+ last_hidden_state=encoder_outputs[0],
1711
+ hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
1712
+ attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
1713
+ mask_feat=encoder_outputs[-1],
1714
+ )
1715
+
1716
+ # Equivalent to def _get_encoder_input
1717
+ # https://github.com/lyuwenyu/RT-DETR/blob/94f5e16708329d2f2716426868ec89aa774af016/rtdetr_pytorch/src/zoo/rtdetr/rtdetr_decoder.py#L412
1718
+ sources = []
1719
+ for level, source in enumerate(encoder_outputs.last_hidden_state):
1720
+ sources.append(self.decoder_input_proj[level](source))
1721
+
1722
+ # Lowest resolution feature maps are obtained via 3x3 stride 2 convolutions on the final stage
1723
+ if self.config.num_feature_levels > len(sources):
1724
+ _len_sources = len(sources)
1725
+ sources.append(self.decoder_input_proj[_len_sources](encoder_outputs.last_hidden_state[-1]))
1726
+ for i in range(_len_sources + 1, self.config.num_feature_levels):
1727
+ sources.append(self.decoder_input_proj[i](encoder_outputs.last_hidden_state[-1]))
1728
+
1729
+ # Prepare encoder inputs (by flattening)
1730
+ source_flatten = []
1731
+ spatial_shapes_list = []
1732
+ spatial_shapes = torch.empty((len(sources), 2), device=device, dtype=torch.long)
1733
+ for level, source in enumerate(sources):
1734
+ height, width = source.shape[-2:]
1735
+ spatial_shapes[level, 0] = height
1736
+ spatial_shapes[level, 1] = width
1737
+ spatial_shapes_list.append((height, width))
1738
+ source = source.flatten(2).transpose(1, 2)
1739
+ source_flatten.append(source)
1740
+ source_flatten = torch.cat(source_flatten, 1)
1741
+ level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
1742
+
1743
+ # prepare denoising training
1744
+ if self.training and self.config.num_denoising > 0 and labels is not None:
1745
+ (
1746
+ denoising_class,
1747
+ denoising_bbox_unact,
1748
+ attention_mask,
1749
+ denoising_meta_values,
1750
+ ) = get_contrastive_denoising_training_group(
1751
+ targets=labels,
1752
+ num_classes=self.config.num_labels,
1753
+ num_queries=self.config.num_queries,
1754
+ class_embed=self.denoising_class_embed,
1755
+ num_denoising_queries=self.config.num_denoising,
1756
+ label_noise_ratio=self.config.label_noise_ratio,
1757
+ box_noise_scale=self.config.box_noise_scale,
1758
+ )
1759
+ else:
1760
+ denoising_class, denoising_bbox_unact, attention_mask, denoising_meta_values = None, None, None, None
1761
+
1762
+ batch_size = len(source_flatten)
1763
+ device = source_flatten.device
1764
+ dtype = source_flatten.dtype
1765
+
1766
+ # prepare input for decoder
1767
+ if self.training or self.config.anchor_image_size is None:
1768
+ # Pass spatial_shapes as tuple to make it hashable and make sure
1769
+ # lru_cache is working for generate_anchors()
1770
+ spatial_shapes_tuple = tuple(spatial_shapes_list)
1771
+ anchors, valid_mask = self.generate_anchors(spatial_shapes_tuple, device=device, dtype=dtype)
1772
+ else:
1773
+ anchors, valid_mask = self.anchors, self.valid_mask
1774
+ anchors, valid_mask = anchors.to(device, dtype), valid_mask.to(device, dtype)
1775
+
1776
+ # use the valid_mask to selectively retain values in the feature map where the mask is `True`
1777
+ memory = valid_mask.to(source_flatten.dtype) * source_flatten
1778
+
1779
+ output_memory = self.enc_output(memory)
1780
+
1781
+ enc_outputs_class = self.enc_score_head(output_memory)
1782
+ enc_outputs_coord_logits = self.enc_bbox_head(output_memory) + anchors
1783
+
1784
+ _, topk_ind = torch.topk(enc_outputs_class.max(-1).values, self.config.num_queries, dim=1)
1785
+
1786
+ reference_points_unact = enc_outputs_coord_logits.gather(
1787
+ dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_coord_logits.shape[-1])
1788
+ )
1789
+
1790
+ # _get_pred_class_and_mask
1791
+ batch_ind = torch.arange(memory.shape[0], device=output_memory.device).unsqueeze(1)
1792
+ target = output_memory[batch_ind, topk_ind]
1793
+ out_query = self.decoder_norm(target)
1794
+ mask_query_embed = self.mask_query_head(out_query)
1795
+ batch_size, mask_dim, _ = mask_query_embed.shape
1796
+
1797
+ enc_topk_bboxes = F.sigmoid(reference_points_unact)
1798
+
1799
+ enc_topk_logits = enc_outputs_class.gather(
1800
+ dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, enc_outputs_class.shape[-1])
1801
+ )
1802
+
1803
+ # extract region features
1804
+ if self.config.learn_initial_query:
1805
+ target = self.weight_embedding.tile([batch_size, 1, 1])
1806
+ else:
1807
+ target = output_memory.gather(dim=1, index=topk_ind.unsqueeze(-1).repeat(1, 1, output_memory.shape[-1]))
1808
+ target = target.detach()
1809
+
1810
+ if denoising_class is not None:
1811
+ target = torch.concat([denoising_class, target], 1)
1812
+
1813
+ if self.mask_enhanced:
1814
+ _, _, mask_h, mask_w = encoder_outputs.mask_feat.shape
1815
+ enc_out_masks = torch.bmm(mask_query_embed, encoder_outputs.mask_feat.flatten(start_dim=2)).reshape(
1816
+ batch_size, mask_dim, mask_h, mask_w
1817
+ )
1818
+ reference_points = mask_to_box_coordinate(enc_out_masks > 0, dtype=reference_points_unact.dtype)
1819
+ reference_points_unact = inverse_sigmoid(reference_points)
1820
+
1821
+ if denoising_bbox_unact is not None:
1822
+ reference_points_unact = torch.concat([denoising_bbox_unact, reference_points_unact], 1)
1823
+
1824
+ init_reference_points = reference_points_unact.detach()
1825
+
1826
+ # decoder
1827
+ decoder_outputs = self.decoder(
1828
+ inputs_embeds=target,
1829
+ encoder_hidden_states=source_flatten,
1830
+ encoder_attention_mask=attention_mask,
1831
+ reference_points=init_reference_points,
1832
+ spatial_shapes=spatial_shapes,
1833
+ spatial_shapes_list=spatial_shapes_list,
1834
+ level_start_index=level_start_index,
1835
+ order_head=self.decoder_order_head,
1836
+ global_pointer=self.decoder_global_pointer,
1837
+ mask_query_head=self.mask_query_head,
1838
+ norm=self.decoder_norm,
1839
+ mask_feat=encoder_outputs.mask_feat,
1840
+ **kwargs,
1841
+ )
1842
+
1843
+ return PPDocLayoutV3ModelOutput(
1844
+ last_hidden_state=decoder_outputs.last_hidden_state,
1845
+ intermediate_hidden_states=decoder_outputs.intermediate_hidden_states,
1846
+ intermediate_logits=decoder_outputs.intermediate_logits,
1847
+ intermediate_reference_points=decoder_outputs.intermediate_reference_points,
1848
+ intermediate_predicted_corners=decoder_outputs.intermediate_predicted_corners,
1849
+ initial_reference_points=decoder_outputs.initial_reference_points,
1850
+ decoder_hidden_states=decoder_outputs.hidden_states,
1851
+ decoder_attentions=decoder_outputs.attentions,
1852
+ cross_attentions=decoder_outputs.cross_attentions,
1853
+ out_order_logits=decoder_outputs.decoder_out_order_logits,
1854
+ out_masks=decoder_outputs.decoder_out_masks,
1855
+ encoder_last_hidden_state=encoder_outputs.last_hidden_state,
1856
+ encoder_hidden_states=encoder_outputs.hidden_states,
1857
+ encoder_attentions=encoder_outputs.attentions,
1858
+ init_reference_points=init_reference_points,
1859
+ enc_topk_logits=enc_topk_logits,
1860
+ enc_topk_bboxes=enc_topk_bboxes,
1861
+ enc_outputs_class=enc_outputs_class,
1862
+ enc_outputs_coord_logits=enc_outputs_coord_logits,
1863
+ denoising_meta_values=denoising_meta_values,
1864
+ )
1865
+
1866
+
1867
+ @dataclass
1868
+ @auto_docstring
1869
+ class PPDocLayoutV3HybridEncoderOutput(BaseModelOutput):
1870
+ r"""
1871
+ mask_feat (`torch.FloatTensor` of shape `(batch_size, config.num_queries, 200, 200)`):
1872
+ Mask features for each query in the batch.
1873
+ """
1874
+
1875
+ mask_feat: torch.FloatTensor | None = None
1876
+
1877
+
1878
+ @dataclass
1879
+ @auto_docstring
1880
+ class PPDocLayoutV3ForObjectDetectionOutput(ModelOutput):
1881
+ r"""
1882
+ logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`):
1883
+ Classification logits (including no-object) for all queries.
1884
+ pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
1885
+ Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
1886
+ values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
1887
+ possible padding). You can use [`~PPDocLayoutV3ImageProcessorFast.post_process_object_detection`] to retrieve the
1888
+ unnormalized (absolute) bounding boxes.
1889
+ order_logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_queries)`):
1890
+ Order logits of the final layer of the decoder.
1891
+ out_masks (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, height, width)`):
1892
+ Masks of the final layer of the decoder.
1893
+ last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`):
1894
+ Sequence of hidden-states at the output of the last layer of the decoder of the model.
1895
+ intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
1896
+ Stacked intermediate hidden states (output of each layer of the decoder).
1897
+ intermediate_logits (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, config.num_labels)`):
1898
+ Stacked intermediate logits (logits of each layer of the decoder).
1899
+ intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
1900
+ Stacked intermediate reference points (reference points of each layer of the decoder).
1901
+ intermediate_predicted_corners (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
1902
+ Stacked intermediate predicted corners (predicted corners of each layer of the decoder).
1903
+ initial_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
1904
+ Stacked initial reference points (initial reference points of each layer of the decoder).
1905
+ init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
1906
+ Initial reference points sent through the Transformer decoder.
1907
+ enc_topk_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
1908
+ Classification logits of the top-k bounding box proposals selected in the encoder.
1909
+ enc_topk_bboxes (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
1910
+ Normalized (sigmoid) coordinates of the top-k bounding box proposals selected in the encoder.
1911
+ enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
1912
+ Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are
1913
+ picked as region proposals in the first stage. Output of bounding box binary classification (i.e.
1914
+ foreground and background).
1915
+ enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
1916
+ Logits of predicted bounding boxes coordinates in the first stage.
1917
+ denoising_meta_values (`dict`):
1918
+ Extra dictionary of denoising-related values.
1919
+ """
1920
+
1921
+ logits: torch.FloatTensor | None = None
1922
+ pred_boxes: torch.FloatTensor | None = None
1923
+ order_logits: torch.FloatTensor | None = None
1924
+ out_masks: torch.FloatTensor | None = None
1925
+ last_hidden_state: torch.FloatTensor | None = None
1926
+ intermediate_hidden_states: torch.FloatTensor | None = None
1927
+ intermediate_logits: torch.FloatTensor | None = None
1928
+ intermediate_reference_points: torch.FloatTensor | None = None
1929
+ intermediate_predicted_corners: torch.FloatTensor | None = None
1930
+ initial_reference_points: torch.FloatTensor | None = None
1931
+ decoder_hidden_states: tuple[torch.FloatTensor] | None = None
1932
+ decoder_attentions: tuple[torch.FloatTensor] | None = None
1933
+ cross_attentions: tuple[torch.FloatTensor] | None = None
1934
+ encoder_last_hidden_state: torch.FloatTensor | None = None
1935
+ encoder_hidden_states: tuple[torch.FloatTensor] | None = None
1936
+ encoder_attentions: tuple[torch.FloatTensor] | None = None
1937
+ init_reference_points: tuple[torch.FloatTensor] | None = None
1938
+ enc_topk_logits: torch.FloatTensor | None = None
1939
+ enc_topk_bboxes: torch.FloatTensor | None = None
1940
+ enc_outputs_class: torch.FloatTensor | None = None
1941
+ enc_outputs_coord_logits: torch.FloatTensor | None = None
1942
+ denoising_meta_values: dict | None = None
1943
+
1944
+
1945
+ @auto_docstring(
1946
+ custom_intro="""
1947
+ PP-DocLayoutV3 Model (consisting of a backbone and encoder-decoder) outputting bounding boxes and logits sorted according to reading order,
1948
+ which are further decoded into scores and classes.
1949
+ """
1950
+ )
1951
+ class PPDocLayoutV3ForObjectDetection(PPDocLayoutV3PreTrainedModel):
1952
+ # When using clones, all layers > 0 will be clones, but layer 0 *is* required
1953
+ # We can't initialize the model on meta device as some weights are modified during the initialization
1954
+ _no_split_modules = None
1955
+ _keys_to_ignore_on_load_missing = ["num_batches_tracked", "rel_pos_y_bias", "rel_pos_x_bias"]
1956
+
1957
+ def __init__(self, config: PPDocLayoutV3Config):
1958
+ super().__init__(config)
1959
+ self.model = PPDocLayoutV3Model(config)
1960
+
1961
+ self.model.denoising_class_embed = nn.Embedding(config.num_labels, config.d_model)
1962
+ self.num_queries = config.num_queries
1963
+ # if two-stage, the last class_embed and bbox_embed is for region proposal generation
1964
+ self.post_init()
1965
+
1966
+ def _set_aux_loss(self, outputs_class, outputs_coord):
1967
+ return [{"logits": a, "pred_boxes": b} for a, b in zip(outputs_class, outputs_coord)]
1968
+
1969
+ @auto_docstring
1970
+ @can_return_tuple
1971
+ def forward(
1972
+ self,
1973
+ pixel_values: torch.FloatTensor,
1974
+ pixel_mask: torch.LongTensor | None = None,
1975
+ encoder_outputs: torch.FloatTensor | None = None,
1976
+ labels: list[dict] | None = None,
1977
+ **kwargs: Unpack[TransformersKwargs],
1978
+ ) -> tuple[torch.FloatTensor] | PPDocLayoutV3ForObjectDetectionOutput:
1979
+ r"""
1980
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
1981
+ Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
1982
+ can choose to directly pass a flattened representation of an image.
1983
+ decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
1984
+ Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
1985
+ embedded representation.
1986
+ labels (`list[Dict]` of len `(batch_size,)`, *optional*):
1987
+ Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
1988
+ following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch
1989
+ respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes
1990
+ in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`.
1991
+
1992
+ Examples:
1993
+
1994
+ ```python
1995
+ >>> from transformers import AutoModelForObjectDetection, AutoImageProcessor
1996
+ >>> from PIL import Image
1997
+ >>> import requests
1998
+ >>> import torch
1999
+
2000
+ >>> url = "https://paddle-model-ecology.bj.bcebos.com/paddlex/imgs/demo_image/layout_demo.jpg"
2001
+ >>> image = Image.open(requests.get(url, stream=True).raw)
2002
+
2003
+ >>> model_path = "PaddlePaddle/PP-DocLayoutV3_safetensors"
2004
+ >>> image_processor = AutoImageProcessor.from_pretrained(model_path)
2005
+ >>> model = AutoModelForObjectDetection.from_pretrained(model_path)
2006
+
2007
+ >>> # prepare image for the model
2008
+ >>> inputs = image_processor(images=[image], return_tensors="pt")
2009
+
2010
+ >>> # forward pass
2011
+ >>> outputs = model(**inputs)
2012
+
2013
+ >>> # convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
2014
+ >>> results = image_processor.post_process_object_detection(outputs, target_sizes=torch.tensor([image.size[::-1]]))
2015
+
2016
+ >>> # print outputs
2017
+ >>> for result in results:
2018
+ ... for idx, (score, label_id, box) in enumerate(zip(result["scores"], result["labels"], result["boxes"])):
2019
+ ... score, label = score.item(), label_id.item()
2020
+ ... box = [round(i, 2) for i in box.tolist()]
2021
+ ... print(f"Order {idx + 1}: {model.config.id2label[label]}: {score:.2f} {box}")
2022
+ Order 1: text: 0.99 [334.95, 184.78, 897.25, 654.83]
2023
+ Order 2: paragraph_title: 0.97 [337.28, 683.92, 869.16, 798.35]
2024
+ Order 3: text: 0.99 [335.75, 842.82, 892.13, 1454.32]
2025
+ Order 4: text: 0.99 [920.18, 185.28, 1476.38, 464.49]
2026
+ Order 5: text: 0.98 [920.47, 483.68, 1480.63, 765.72]
2027
+ Order 6: text: 0.98 [920.62, 846.8, 1482.09, 1220.67]
2028
+ Order 7: text: 0.97 [920.92, 1239.41, 1469.55, 1378.02]
2029
+ Order 8: footnote: 0.86 [335.03, 1614.68, 1483.33, 1731.73]
2030
+ Order 9: footnote: 0.83 [334.64, 1756.74, 1471.78, 1845.69]
2031
+ Order 10: text: 0.81 [336.8, 1910.52, 661.64, 1939.92]
2032
+ Order 11: footnote: 0.96 [336.24, 2114.42, 1450.14, 2172.12]
2033
+ Order 12: number: 0.88 [106.0, 2257.5, 135.84, 2282.18]
2034
+ Order 13: footer: 0.93 [338.4, 2255.52, 986.15, 2284.37]
2035
+ ```"""
2036
+ outputs = self.model(
2037
+ pixel_values,
2038
+ pixel_mask=pixel_mask,
2039
+ encoder_outputs=encoder_outputs,
2040
+ labels=labels,
2041
+ **kwargs,
2042
+ )
2043
+
2044
+ intermediate_logits = outputs.intermediate_logits
2045
+ intermediate_reference_points = outputs.intermediate_reference_points
2046
+ order_logits = outputs.out_order_logits
2047
+ out_masks = outputs.out_masks
2048
+
2049
+ pred_boxes = intermediate_reference_points[:, -1]
2050
+ logits = intermediate_logits[:, -1]
2051
+ order_logits = order_logits[:, -1]
2052
+ out_masks = out_masks[:, -1]
2053
+
2054
+ if labels is not None:
2055
+ raise ValueError("PPDocLayoutV3ForObjectDetection does not support training")
2056
+
2057
+ return PPDocLayoutV3ForObjectDetectionOutput(
2058
+ logits=logits,
2059
+ pred_boxes=pred_boxes,
2060
+ order_logits=order_logits,
2061
+ out_masks=out_masks,
2062
+ last_hidden_state=outputs.last_hidden_state,
2063
+ intermediate_hidden_states=outputs.intermediate_hidden_states,
2064
+ intermediate_logits=outputs.intermediate_logits,
2065
+ intermediate_reference_points=outputs.intermediate_reference_points,
2066
+ intermediate_predicted_corners=outputs.intermediate_predicted_corners,
2067
+ initial_reference_points=outputs.initial_reference_points,
2068
+ decoder_hidden_states=outputs.decoder_hidden_states,
2069
+ decoder_attentions=outputs.decoder_attentions,
2070
+ cross_attentions=outputs.cross_attentions,
2071
+ encoder_last_hidden_state=outputs.encoder_last_hidden_state,
2072
+ encoder_hidden_states=outputs.encoder_hidden_states,
2073
+ encoder_attentions=outputs.encoder_attentions,
2074
+ init_reference_points=outputs.init_reference_points,
2075
+ enc_topk_logits=outputs.enc_topk_logits,
2076
+ enc_topk_bboxes=outputs.enc_topk_bboxes,
2077
+ enc_outputs_class=outputs.enc_outputs_class,
2078
+ enc_outputs_coord_logits=outputs.enc_outputs_coord_logits,
2079
+ denoising_meta_values=outputs.denoising_meta_values,
2080
+ )
2081
+
2082
+
2083
+ __all__ = ["PPDocLayoutV3ForObjectDetection", "PPDocLayoutV3Model", "PPDocLayoutV3PreTrainedModel"]