transformers-5.0.0rc2-py3-none-any.whl → transformers-5.1.0-py3-none-any.whl

This diff compares the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
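The per-file `+N -M` figures in the listing below are added and removed line counts. As a minimal sketch of how such a summary can be reproduced locally (assuming both wheels have already been downloaded next to the script, e.g. via `pip download transformers==5.0.0rc2 --no-deps` and `pip download transformers==5.1.0 --no-deps`; the filenames used are the standard PyPI wheel names):

```python
# Minimal sketch: recompute a per-file change summary from two wheels.
# Wheels are plain zip archives, so the standard library is enough.
import difflib
import zipfile

old = zipfile.ZipFile("transformers-5.0.0rc2-py3-none-any.whl")
new = zipfile.ZipFile("transformers-5.1.0-py3-none-any.whl")

old_files, new_files = set(old.namelist()), set(new.namelist())
print("added files:  ", len(new_files - old_files))
print("removed files:", len(old_files - new_files))

# For files present in both archives, count added/removed lines --
# roughly the "+N -M" numbers shown in the listing below.
for name in sorted(old_files & new_files):
    if not name.endswith(".py"):
        continue
    a = old.read(name).decode("utf-8", errors="replace").splitlines()
    b = new.read(name).decode("utf-8", errors="replace").splitlines()
    if a == b:
        continue
    diff = list(difflib.unified_diff(a, b, lineterm=""))
    plus = sum(1 for l in diff if l.startswith("+") and not l.startswith("+++"))
    minus = sum(1 for l in diff if l.startswith("-") and not l.startswith("---"))
    print(f"{name} +{plus} -{minus}")
```

The counts produced this way will not match the registry's numbers exactly (diff tools differ in how they group changed lines), but the shape of the output is the same as the listing that follows.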
Files changed (1594)
  1. transformers/__init__.py +11 -37
  2. transformers/activations.py +2 -2
  3. transformers/audio_utils.py +32 -32
  4. transformers/backbone_utils.py +326 -0
  5. transformers/cache_utils.py +26 -126
  6. transformers/cli/chat.py +3 -3
  7. transformers/cli/serve.py +13 -10
  8. transformers/cli/transformers.py +2 -1
  9. transformers/configuration_utils.py +22 -92
  10. transformers/conversion_mapping.py +150 -26
  11. transformers/convert_slow_tokenizer.py +9 -12
  12. transformers/core_model_loading.py +217 -129
  13. transformers/data/processors/glue.py +0 -1
  14. transformers/data/processors/utils.py +0 -1
  15. transformers/data/processors/xnli.py +0 -1
  16. transformers/dependency_versions_check.py +0 -1
  17. transformers/dependency_versions_table.py +10 -11
  18. transformers/distributed/configuration_utils.py +1 -2
  19. transformers/dynamic_module_utils.py +23 -23
  20. transformers/feature_extraction_sequence_utils.py +19 -23
  21. transformers/feature_extraction_utils.py +14 -14
  22. transformers/file_utils.py +0 -2
  23. transformers/generation/candidate_generator.py +2 -4
  24. transformers/generation/configuration_utils.py +54 -39
  25. transformers/generation/continuous_batching/__init__.py +0 -1
  26. transformers/generation/continuous_batching/cache.py +74 -44
  27. transformers/generation/continuous_batching/cache_manager.py +28 -28
  28. transformers/generation/continuous_batching/continuous_api.py +133 -414
  29. transformers/generation/continuous_batching/input_ouputs.py +464 -0
  30. transformers/generation/continuous_batching/requests.py +77 -19
  31. transformers/generation/continuous_batching/scheduler.py +154 -104
  32. transformers/generation/logits_process.py +10 -133
  33. transformers/generation/stopping_criteria.py +1 -2
  34. transformers/generation/streamers.py +0 -1
  35. transformers/generation/utils.py +91 -121
  36. transformers/generation/watermarking.py +2 -3
  37. transformers/hf_argparser.py +9 -13
  38. transformers/hyperparameter_search.py +1 -2
  39. transformers/image_processing_base.py +9 -9
  40. transformers/image_processing_utils.py +11 -15
  41. transformers/image_processing_utils_fast.py +70 -71
  42. transformers/image_transforms.py +73 -42
  43. transformers/image_utils.py +30 -37
  44. transformers/initialization.py +57 -0
  45. transformers/integrations/__init__.py +10 -24
  46. transformers/integrations/accelerate.py +47 -11
  47. transformers/integrations/awq.py +1 -3
  48. transformers/integrations/deepspeed.py +146 -4
  49. transformers/integrations/eetq.py +0 -1
  50. transformers/integrations/executorch.py +2 -6
  51. transformers/integrations/fbgemm_fp8.py +1 -2
  52. transformers/integrations/finegrained_fp8.py +149 -13
  53. transformers/integrations/flash_attention.py +3 -8
  54. transformers/integrations/flex_attention.py +1 -1
  55. transformers/integrations/fp_quant.py +4 -6
  56. transformers/integrations/ggml.py +0 -1
  57. transformers/integrations/hub_kernels.py +18 -7
  58. transformers/integrations/integration_utils.py +2 -3
  59. transformers/integrations/moe.py +226 -106
  60. transformers/integrations/mxfp4.py +52 -40
  61. transformers/integrations/peft.py +488 -176
  62. transformers/integrations/quark.py +2 -4
  63. transformers/integrations/tensor_parallel.py +641 -581
  64. transformers/integrations/torchao.py +4 -6
  65. transformers/loss/loss_lw_detr.py +356 -0
  66. transformers/loss/loss_utils.py +2 -0
  67. transformers/masking_utils.py +199 -59
  68. transformers/model_debugging_utils.py +4 -5
  69. transformers/modelcard.py +14 -192
  70. transformers/modeling_attn_mask_utils.py +19 -19
  71. transformers/modeling_flash_attention_utils.py +28 -29
  72. transformers/modeling_gguf_pytorch_utils.py +5 -5
  73. transformers/modeling_layers.py +21 -22
  74. transformers/modeling_outputs.py +242 -253
  75. transformers/modeling_rope_utils.py +32 -32
  76. transformers/modeling_utils.py +416 -438
  77. transformers/models/__init__.py +10 -0
  78. transformers/models/afmoe/configuration_afmoe.py +40 -33
  79. transformers/models/afmoe/modeling_afmoe.py +38 -41
  80. transformers/models/afmoe/modular_afmoe.py +23 -25
  81. transformers/models/aimv2/configuration_aimv2.py +2 -10
  82. transformers/models/aimv2/modeling_aimv2.py +46 -45
  83. transformers/models/aimv2/modular_aimv2.py +13 -19
  84. transformers/models/albert/configuration_albert.py +8 -2
  85. transformers/models/albert/modeling_albert.py +70 -72
  86. transformers/models/albert/tokenization_albert.py +1 -4
  87. transformers/models/align/configuration_align.py +8 -6
  88. transformers/models/align/modeling_align.py +83 -86
  89. transformers/models/align/processing_align.py +2 -30
  90. transformers/models/altclip/configuration_altclip.py +4 -7
  91. transformers/models/altclip/modeling_altclip.py +106 -103
  92. transformers/models/altclip/processing_altclip.py +2 -15
  93. transformers/models/apertus/__init__.py +0 -1
  94. transformers/models/apertus/configuration_apertus.py +23 -28
  95. transformers/models/apertus/modeling_apertus.py +35 -38
  96. transformers/models/apertus/modular_apertus.py +36 -40
  97. transformers/models/arcee/configuration_arcee.py +25 -30
  98. transformers/models/arcee/modeling_arcee.py +35 -38
  99. transformers/models/arcee/modular_arcee.py +20 -23
  100. transformers/models/aria/configuration_aria.py +31 -44
  101. transformers/models/aria/image_processing_aria.py +25 -27
  102. transformers/models/aria/modeling_aria.py +102 -102
  103. transformers/models/aria/modular_aria.py +111 -124
  104. transformers/models/aria/processing_aria.py +28 -35
  105. transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +0 -1
  106. transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py +3 -6
  107. transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +9 -11
  108. transformers/models/audioflamingo3/__init__.py +0 -1
  109. transformers/models/audioflamingo3/configuration_audioflamingo3.py +0 -1
  110. transformers/models/audioflamingo3/modeling_audioflamingo3.py +60 -52
  111. transformers/models/audioflamingo3/modular_audioflamingo3.py +52 -43
  112. transformers/models/audioflamingo3/processing_audioflamingo3.py +6 -8
  113. transformers/models/auto/auto_factory.py +12 -11
  114. transformers/models/auto/configuration_auto.py +48 -5
  115. transformers/models/auto/feature_extraction_auto.py +5 -7
  116. transformers/models/auto/image_processing_auto.py +30 -39
  117. transformers/models/auto/modeling_auto.py +33 -199
  118. transformers/models/auto/processing_auto.py +11 -19
  119. transformers/models/auto/tokenization_auto.py +38 -37
  120. transformers/models/auto/video_processing_auto.py +7 -8
  121. transformers/models/autoformer/configuration_autoformer.py +4 -7
  122. transformers/models/autoformer/modeling_autoformer.py +100 -101
  123. transformers/models/aya_vision/configuration_aya_vision.py +4 -1
  124. transformers/models/aya_vision/modeling_aya_vision.py +64 -99
  125. transformers/models/aya_vision/modular_aya_vision.py +46 -74
  126. transformers/models/aya_vision/processing_aya_vision.py +25 -53
  127. transformers/models/bamba/configuration_bamba.py +46 -39
  128. transformers/models/bamba/modeling_bamba.py +83 -119
  129. transformers/models/bamba/modular_bamba.py +70 -109
  130. transformers/models/bark/configuration_bark.py +6 -8
  131. transformers/models/bark/generation_configuration_bark.py +3 -5
  132. transformers/models/bark/modeling_bark.py +64 -65
  133. transformers/models/bark/processing_bark.py +19 -41
  134. transformers/models/bart/configuration_bart.py +9 -5
  135. transformers/models/bart/modeling_bart.py +124 -129
  136. transformers/models/barthez/tokenization_barthez.py +1 -4
  137. transformers/models/bartpho/tokenization_bartpho.py +6 -7
  138. transformers/models/beit/configuration_beit.py +2 -15
  139. transformers/models/beit/image_processing_beit.py +53 -56
  140. transformers/models/beit/image_processing_beit_fast.py +11 -12
  141. transformers/models/beit/modeling_beit.py +65 -62
  142. transformers/models/bert/configuration_bert.py +12 -2
  143. transformers/models/bert/modeling_bert.py +117 -152
  144. transformers/models/bert/tokenization_bert.py +2 -4
  145. transformers/models/bert/tokenization_bert_legacy.py +3 -5
  146. transformers/models/bert_generation/configuration_bert_generation.py +17 -2
  147. transformers/models/bert_generation/modeling_bert_generation.py +53 -55
  148. transformers/models/bert_generation/tokenization_bert_generation.py +2 -3
  149. transformers/models/bert_japanese/tokenization_bert_japanese.py +5 -6
  150. transformers/models/bertweet/tokenization_bertweet.py +1 -3
  151. transformers/models/big_bird/configuration_big_bird.py +12 -9
  152. transformers/models/big_bird/modeling_big_bird.py +107 -124
  153. transformers/models/big_bird/tokenization_big_bird.py +1 -4
  154. transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py +9 -9
  155. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +118 -118
  156. transformers/models/biogpt/configuration_biogpt.py +8 -2
  157. transformers/models/biogpt/modeling_biogpt.py +73 -79
  158. transformers/models/biogpt/modular_biogpt.py +60 -66
  159. transformers/models/biogpt/tokenization_biogpt.py +3 -5
  160. transformers/models/bit/configuration_bit.py +2 -5
  161. transformers/models/bit/image_processing_bit.py +21 -24
  162. transformers/models/bit/image_processing_bit_fast.py +0 -1
  163. transformers/models/bit/modeling_bit.py +15 -16
  164. transformers/models/bitnet/configuration_bitnet.py +23 -28
  165. transformers/models/bitnet/modeling_bitnet.py +34 -38
  166. transformers/models/bitnet/modular_bitnet.py +7 -10
  167. transformers/models/blenderbot/configuration_blenderbot.py +8 -5
  168. transformers/models/blenderbot/modeling_blenderbot.py +68 -99
  169. transformers/models/blenderbot/tokenization_blenderbot.py +0 -1
  170. transformers/models/blenderbot_small/configuration_blenderbot_small.py +8 -5
  171. transformers/models/blenderbot_small/modeling_blenderbot_small.py +70 -72
  172. transformers/models/blenderbot_small/tokenization_blenderbot_small.py +1 -3
  173. transformers/models/blip/configuration_blip.py +9 -10
  174. transformers/models/blip/image_processing_blip.py +17 -20
  175. transformers/models/blip/image_processing_blip_fast.py +0 -1
  176. transformers/models/blip/modeling_blip.py +115 -108
  177. transformers/models/blip/modeling_blip_text.py +63 -65
  178. transformers/models/blip/processing_blip.py +5 -36
  179. transformers/models/blip_2/configuration_blip_2.py +2 -2
  180. transformers/models/blip_2/modeling_blip_2.py +145 -121
  181. transformers/models/blip_2/processing_blip_2.py +8 -38
  182. transformers/models/bloom/configuration_bloom.py +5 -2
  183. transformers/models/bloom/modeling_bloom.py +60 -60
  184. transformers/models/blt/configuration_blt.py +94 -86
  185. transformers/models/blt/modeling_blt.py +93 -90
  186. transformers/models/blt/modular_blt.py +127 -69
  187. transformers/models/bridgetower/configuration_bridgetower.py +7 -2
  188. transformers/models/bridgetower/image_processing_bridgetower.py +34 -35
  189. transformers/models/bridgetower/image_processing_bridgetower_fast.py +13 -14
  190. transformers/models/bridgetower/modeling_bridgetower.py +136 -124
  191. transformers/models/bridgetower/processing_bridgetower.py +2 -16
  192. transformers/models/bros/configuration_bros.py +24 -18
  193. transformers/models/bros/modeling_bros.py +78 -80
  194. transformers/models/bros/processing_bros.py +2 -12
  195. transformers/models/byt5/tokenization_byt5.py +4 -6
  196. transformers/models/camembert/configuration_camembert.py +8 -2
  197. transformers/models/camembert/modeling_camembert.py +97 -99
  198. transformers/models/camembert/modular_camembert.py +51 -54
  199. transformers/models/camembert/tokenization_camembert.py +1 -4
  200. transformers/models/canine/configuration_canine.py +4 -2
  201. transformers/models/canine/modeling_canine.py +73 -75
  202. transformers/models/canine/tokenization_canine.py +0 -1
  203. transformers/models/chameleon/configuration_chameleon.py +29 -34
  204. transformers/models/chameleon/image_processing_chameleon.py +21 -24
  205. transformers/models/chameleon/image_processing_chameleon_fast.py +5 -6
  206. transformers/models/chameleon/modeling_chameleon.py +135 -92
  207. transformers/models/chameleon/processing_chameleon.py +16 -41
  208. transformers/models/chinese_clip/configuration_chinese_clip.py +10 -8
  209. transformers/models/chinese_clip/image_processing_chinese_clip.py +21 -24
  210. transformers/models/chinese_clip/image_processing_chinese_clip_fast.py +0 -1
  211. transformers/models/chinese_clip/modeling_chinese_clip.py +93 -95
  212. transformers/models/chinese_clip/processing_chinese_clip.py +2 -15
  213. transformers/models/clap/configuration_clap.py +4 -9
  214. transformers/models/clap/feature_extraction_clap.py +9 -10
  215. transformers/models/clap/modeling_clap.py +109 -111
  216. transformers/models/clap/processing_clap.py +2 -15
  217. transformers/models/clip/configuration_clip.py +4 -2
  218. transformers/models/clip/image_processing_clip.py +21 -24
  219. transformers/models/clip/image_processing_clip_fast.py +9 -1
  220. transformers/models/clip/modeling_clip.py +70 -68
  221. transformers/models/clip/processing_clip.py +2 -14
  222. transformers/models/clip/tokenization_clip.py +2 -5
  223. transformers/models/clipseg/configuration_clipseg.py +4 -2
  224. transformers/models/clipseg/modeling_clipseg.py +113 -112
  225. transformers/models/clipseg/processing_clipseg.py +19 -42
  226. transformers/models/clvp/configuration_clvp.py +15 -5
  227. transformers/models/clvp/feature_extraction_clvp.py +7 -10
  228. transformers/models/clvp/modeling_clvp.py +138 -145
  229. transformers/models/clvp/number_normalizer.py +1 -2
  230. transformers/models/clvp/processing_clvp.py +3 -20
  231. transformers/models/clvp/tokenization_clvp.py +0 -1
  232. transformers/models/code_llama/tokenization_code_llama.py +3 -6
  233. transformers/models/codegen/configuration_codegen.py +4 -4
  234. transformers/models/codegen/modeling_codegen.py +50 -49
  235. transformers/models/codegen/tokenization_codegen.py +5 -6
  236. transformers/models/cohere/configuration_cohere.py +25 -30
  237. transformers/models/cohere/modeling_cohere.py +39 -42
  238. transformers/models/cohere/modular_cohere.py +27 -31
  239. transformers/models/cohere/tokenization_cohere.py +5 -6
  240. transformers/models/cohere2/configuration_cohere2.py +27 -32
  241. transformers/models/cohere2/modeling_cohere2.py +38 -41
  242. transformers/models/cohere2/modular_cohere2.py +48 -52
  243. transformers/models/cohere2_vision/configuration_cohere2_vision.py +5 -1
  244. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +9 -10
  245. transformers/models/cohere2_vision/modeling_cohere2_vision.py +52 -55
  246. transformers/models/cohere2_vision/modular_cohere2_vision.py +41 -43
  247. transformers/models/cohere2_vision/processing_cohere2_vision.py +6 -36
  248. transformers/models/colpali/configuration_colpali.py +0 -1
  249. transformers/models/colpali/modeling_colpali.py +14 -16
  250. transformers/models/colpali/modular_colpali.py +11 -51
  251. transformers/models/colpali/processing_colpali.py +14 -52
  252. transformers/models/colqwen2/modeling_colqwen2.py +27 -28
  253. transformers/models/colqwen2/modular_colqwen2.py +36 -74
  254. transformers/models/colqwen2/processing_colqwen2.py +16 -52
  255. transformers/models/conditional_detr/configuration_conditional_detr.py +19 -47
  256. transformers/models/conditional_detr/image_processing_conditional_detr.py +67 -70
  257. transformers/models/conditional_detr/image_processing_conditional_detr_fast.py +50 -36
  258. transformers/models/conditional_detr/modeling_conditional_detr.py +851 -1001
  259. transformers/models/conditional_detr/modular_conditional_detr.py +901 -5
  260. transformers/models/convbert/configuration_convbert.py +11 -8
  261. transformers/models/convbert/modeling_convbert.py +85 -87
  262. transformers/models/convbert/tokenization_convbert.py +0 -1
  263. transformers/models/convnext/configuration_convnext.py +2 -5
  264. transformers/models/convnext/image_processing_convnext.py +18 -21
  265. transformers/models/convnext/image_processing_convnext_fast.py +7 -8
  266. transformers/models/convnext/modeling_convnext.py +12 -14
  267. transformers/models/convnextv2/configuration_convnextv2.py +2 -5
  268. transformers/models/convnextv2/modeling_convnextv2.py +12 -14
  269. transformers/models/cpm/tokenization_cpm.py +6 -7
  270. transformers/models/cpm/tokenization_cpm_fast.py +3 -5
  271. transformers/models/cpmant/configuration_cpmant.py +4 -1
  272. transformers/models/cpmant/modeling_cpmant.py +38 -40
  273. transformers/models/cpmant/tokenization_cpmant.py +1 -3
  274. transformers/models/csm/configuration_csm.py +58 -66
  275. transformers/models/csm/generation_csm.py +13 -14
  276. transformers/models/csm/modeling_csm.py +81 -84
  277. transformers/models/csm/modular_csm.py +56 -58
  278. transformers/models/csm/processing_csm.py +25 -68
  279. transformers/models/ctrl/configuration_ctrl.py +16 -1
  280. transformers/models/ctrl/modeling_ctrl.py +51 -66
  281. transformers/models/ctrl/tokenization_ctrl.py +0 -1
  282. transformers/models/cvt/configuration_cvt.py +0 -1
  283. transformers/models/cvt/modeling_cvt.py +13 -15
  284. transformers/models/cwm/__init__.py +0 -1
  285. transformers/models/cwm/configuration_cwm.py +8 -12
  286. transformers/models/cwm/modeling_cwm.py +36 -38
  287. transformers/models/cwm/modular_cwm.py +10 -12
  288. transformers/models/d_fine/configuration_d_fine.py +10 -57
  289. transformers/models/d_fine/modeling_d_fine.py +786 -927
  290. transformers/models/d_fine/modular_d_fine.py +339 -417
  291. transformers/models/dab_detr/configuration_dab_detr.py +22 -49
  292. transformers/models/dab_detr/modeling_dab_detr.py +79 -77
  293. transformers/models/dac/configuration_dac.py +0 -1
  294. transformers/models/dac/feature_extraction_dac.py +6 -9
  295. transformers/models/dac/modeling_dac.py +22 -24
  296. transformers/models/data2vec/configuration_data2vec_audio.py +4 -2
  297. transformers/models/data2vec/configuration_data2vec_text.py +11 -3
  298. transformers/models/data2vec/configuration_data2vec_vision.py +0 -1
  299. transformers/models/data2vec/modeling_data2vec_audio.py +55 -59
  300. transformers/models/data2vec/modeling_data2vec_text.py +97 -99
  301. transformers/models/data2vec/modeling_data2vec_vision.py +45 -44
  302. transformers/models/data2vec/modular_data2vec_audio.py +6 -1
  303. transformers/models/data2vec/modular_data2vec_text.py +51 -54
  304. transformers/models/dbrx/configuration_dbrx.py +29 -22
  305. transformers/models/dbrx/modeling_dbrx.py +45 -48
  306. transformers/models/dbrx/modular_dbrx.py +37 -39
  307. transformers/models/deberta/configuration_deberta.py +6 -1
  308. transformers/models/deberta/modeling_deberta.py +57 -60
  309. transformers/models/deberta/tokenization_deberta.py +2 -5
  310. transformers/models/deberta_v2/configuration_deberta_v2.py +6 -1
  311. transformers/models/deberta_v2/modeling_deberta_v2.py +63 -65
  312. transformers/models/deberta_v2/tokenization_deberta_v2.py +1 -4
  313. transformers/models/decision_transformer/configuration_decision_transformer.py +3 -2
  314. transformers/models/decision_transformer/modeling_decision_transformer.py +51 -53
  315. transformers/models/deepseek_v2/configuration_deepseek_v2.py +41 -47
  316. transformers/models/deepseek_v2/modeling_deepseek_v2.py +39 -41
  317. transformers/models/deepseek_v2/modular_deepseek_v2.py +48 -52
  318. transformers/models/deepseek_v3/configuration_deepseek_v3.py +42 -48
  319. transformers/models/deepseek_v3/modeling_deepseek_v3.py +38 -40
  320. transformers/models/deepseek_v3/modular_deepseek_v3.py +10 -10
  321. transformers/models/deepseek_vl/configuration_deepseek_vl.py +6 -3
  322. transformers/models/deepseek_vl/image_processing_deepseek_vl.py +27 -28
  323. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +12 -11
  324. transformers/models/deepseek_vl/modeling_deepseek_vl.py +48 -43
  325. transformers/models/deepseek_vl/modular_deepseek_vl.py +15 -43
  326. transformers/models/deepseek_vl/processing_deepseek_vl.py +10 -41
  327. transformers/models/deepseek_vl_hybrid/configuration_deepseek_vl_hybrid.py +7 -5
  328. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py +37 -37
  329. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +22 -22
  330. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +100 -56
  331. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +141 -109
  332. transformers/models/deepseek_vl_hybrid/processing_deepseek_vl_hybrid.py +12 -44
  333. transformers/models/deformable_detr/configuration_deformable_detr.py +22 -46
  334. transformers/models/deformable_detr/image_processing_deformable_detr.py +59 -61
  335. transformers/models/deformable_detr/image_processing_deformable_detr_fast.py +42 -28
  336. transformers/models/deformable_detr/modeling_deformable_detr.py +454 -652
  337. transformers/models/deformable_detr/modular_deformable_detr.py +1385 -5
  338. transformers/models/deit/configuration_deit.py +0 -1
  339. transformers/models/deit/image_processing_deit.py +18 -21
  340. transformers/models/deit/image_processing_deit_fast.py +0 -1
  341. transformers/models/deit/modeling_deit.py +27 -25
  342. transformers/models/depth_anything/configuration_depth_anything.py +12 -43
  343. transformers/models/depth_anything/modeling_depth_anything.py +10 -11
  344. transformers/models/depth_pro/configuration_depth_pro.py +0 -1
  345. transformers/models/depth_pro/image_processing_depth_pro.py +22 -23
  346. transformers/models/depth_pro/image_processing_depth_pro_fast.py +8 -9
  347. transformers/models/depth_pro/modeling_depth_pro.py +29 -27
  348. transformers/models/detr/configuration_detr.py +18 -50
  349. transformers/models/detr/image_processing_detr.py +64 -66
  350. transformers/models/detr/image_processing_detr_fast.py +33 -34
  351. transformers/models/detr/modeling_detr.py +748 -789
  352. transformers/models/dia/configuration_dia.py +9 -15
  353. transformers/models/dia/feature_extraction_dia.py +6 -9
  354. transformers/models/dia/generation_dia.py +48 -53
  355. transformers/models/dia/modeling_dia.py +68 -71
  356. transformers/models/dia/modular_dia.py +56 -58
  357. transformers/models/dia/processing_dia.py +39 -29
  358. transformers/models/dia/tokenization_dia.py +3 -6
  359. transformers/models/diffllama/configuration_diffllama.py +25 -30
  360. transformers/models/diffllama/modeling_diffllama.py +45 -53
  361. transformers/models/diffllama/modular_diffllama.py +18 -25
  362. transformers/models/dinat/configuration_dinat.py +2 -5
  363. transformers/models/dinat/modeling_dinat.py +47 -48
  364. transformers/models/dinov2/configuration_dinov2.py +2 -5
  365. transformers/models/dinov2/modeling_dinov2.py +20 -21
  366. transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py +3 -5
  367. transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +21 -21
  368. transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py +11 -14
  369. transformers/models/dinov3_convnext/configuration_dinov3_convnext.py +6 -11
  370. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +5 -9
  371. transformers/models/dinov3_vit/configuration_dinov3_vit.py +7 -12
  372. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +7 -8
  373. transformers/models/dinov3_vit/modeling_dinov3_vit.py +19 -22
  374. transformers/models/dinov3_vit/modular_dinov3_vit.py +16 -19
  375. transformers/models/distilbert/configuration_distilbert.py +8 -2
  376. transformers/models/distilbert/modeling_distilbert.py +47 -49
  377. transformers/models/distilbert/tokenization_distilbert.py +0 -1
  378. transformers/models/doge/__init__.py +0 -1
  379. transformers/models/doge/configuration_doge.py +42 -35
  380. transformers/models/doge/modeling_doge.py +46 -49
  381. transformers/models/doge/modular_doge.py +77 -68
  382. transformers/models/donut/configuration_donut_swin.py +0 -1
  383. transformers/models/donut/image_processing_donut.py +26 -29
  384. transformers/models/donut/image_processing_donut_fast.py +9 -14
  385. transformers/models/donut/modeling_donut_swin.py +44 -46
  386. transformers/models/donut/processing_donut.py +5 -26
  387. transformers/models/dots1/configuration_dots1.py +43 -36
  388. transformers/models/dots1/modeling_dots1.py +35 -38
  389. transformers/models/dots1/modular_dots1.py +0 -1
  390. transformers/models/dpr/configuration_dpr.py +19 -2
  391. transformers/models/dpr/modeling_dpr.py +37 -39
  392. transformers/models/dpr/tokenization_dpr.py +7 -9
  393. transformers/models/dpr/tokenization_dpr_fast.py +7 -9
  394. transformers/models/dpt/configuration_dpt.py +23 -66
  395. transformers/models/dpt/image_processing_dpt.py +65 -66
  396. transformers/models/dpt/image_processing_dpt_fast.py +18 -19
  397. transformers/models/dpt/modeling_dpt.py +38 -36
  398. transformers/models/dpt/modular_dpt.py +14 -15
  399. transformers/models/edgetam/configuration_edgetam.py +1 -2
  400. transformers/models/edgetam/modeling_edgetam.py +87 -89
  401. transformers/models/edgetam/modular_edgetam.py +7 -13
  402. transformers/models/edgetam_video/__init__.py +0 -1
  403. transformers/models/edgetam_video/configuration_edgetam_video.py +0 -1
  404. transformers/models/edgetam_video/modeling_edgetam_video.py +126 -128
  405. transformers/models/edgetam_video/modular_edgetam_video.py +25 -27
  406. transformers/models/efficientloftr/configuration_efficientloftr.py +4 -5
  407. transformers/models/efficientloftr/image_processing_efficientloftr.py +14 -16
  408. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +8 -7
  409. transformers/models/efficientloftr/modeling_efficientloftr.py +46 -38
  410. transformers/models/efficientloftr/modular_efficientloftr.py +1 -3
  411. transformers/models/efficientnet/configuration_efficientnet.py +0 -1
  412. transformers/models/efficientnet/image_processing_efficientnet.py +23 -26
  413. transformers/models/efficientnet/image_processing_efficientnet_fast.py +16 -17
  414. transformers/models/efficientnet/modeling_efficientnet.py +12 -14
  415. transformers/models/electra/configuration_electra.py +13 -3
  416. transformers/models/electra/modeling_electra.py +107 -109
  417. transformers/models/emu3/configuration_emu3.py +17 -17
  418. transformers/models/emu3/image_processing_emu3.py +44 -39
  419. transformers/models/emu3/modeling_emu3.py +143 -109
  420. transformers/models/emu3/modular_emu3.py +109 -73
  421. transformers/models/emu3/processing_emu3.py +18 -43
  422. transformers/models/encodec/configuration_encodec.py +2 -4
  423. transformers/models/encodec/feature_extraction_encodec.py +10 -13
  424. transformers/models/encodec/modeling_encodec.py +25 -29
  425. transformers/models/encoder_decoder/configuration_encoder_decoder.py +12 -2
  426. transformers/models/encoder_decoder/modeling_encoder_decoder.py +37 -43
  427. transformers/models/eomt/configuration_eomt.py +12 -14
  428. transformers/models/eomt/image_processing_eomt.py +53 -55
  429. transformers/models/eomt/image_processing_eomt_fast.py +18 -19
  430. transformers/models/eomt/modeling_eomt.py +19 -21
  431. transformers/models/eomt/modular_eomt.py +28 -30
  432. transformers/models/eomt_dinov3/__init__.py +28 -0
  433. transformers/models/eomt_dinov3/configuration_eomt_dinov3.py +204 -0
  434. transformers/models/eomt_dinov3/modeling_eomt_dinov3.py +1376 -0
  435. transformers/models/eomt_dinov3/modular_eomt_dinov3.py +454 -0
  436. transformers/models/ernie/configuration_ernie.py +24 -3
  437. transformers/models/ernie/modeling_ernie.py +127 -162
  438. transformers/models/ernie/modular_ernie.py +91 -103
  439. transformers/models/ernie4_5/configuration_ernie4_5.py +23 -27
  440. transformers/models/ernie4_5/modeling_ernie4_5.py +35 -37
  441. transformers/models/ernie4_5/modular_ernie4_5.py +1 -3
  442. transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py +34 -39
  443. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +40 -42
  444. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +7 -9
  445. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +17 -7
  446. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +34 -35
  447. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +6 -7
  448. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +305 -267
  449. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +163 -142
  450. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +3 -5
  451. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +17 -18
  452. transformers/models/esm/configuration_esm.py +11 -15
  453. transformers/models/esm/modeling_esm.py +35 -37
  454. transformers/models/esm/modeling_esmfold.py +43 -50
  455. transformers/models/esm/openfold_utils/chunk_utils.py +6 -6
  456. transformers/models/esm/openfold_utils/loss.py +1 -2
  457. transformers/models/esm/openfold_utils/protein.py +15 -16
  458. transformers/models/esm/openfold_utils/tensor_utils.py +6 -6
  459. transformers/models/esm/tokenization_esm.py +2 -4
  460. transformers/models/evolla/configuration_evolla.py +50 -40
  461. transformers/models/evolla/modeling_evolla.py +69 -68
  462. transformers/models/evolla/modular_evolla.py +50 -48
  463. transformers/models/evolla/processing_evolla.py +23 -35
  464. transformers/models/exaone4/configuration_exaone4.py +27 -27
  465. transformers/models/exaone4/modeling_exaone4.py +36 -39
  466. transformers/models/exaone4/modular_exaone4.py +51 -50
  467. transformers/models/exaone_moe/__init__.py +27 -0
  468. transformers/models/exaone_moe/configuration_exaone_moe.py +235 -0
  469. transformers/models/exaone_moe/modeling_exaone_moe.py +665 -0
  470. transformers/models/exaone_moe/modular_exaone_moe.py +373 -0
  471. transformers/models/falcon/configuration_falcon.py +31 -26
  472. transformers/models/falcon/modeling_falcon.py +76 -84
  473. transformers/models/falcon_h1/configuration_falcon_h1.py +57 -51
  474. transformers/models/falcon_h1/modeling_falcon_h1.py +74 -109
  475. transformers/models/falcon_h1/modular_falcon_h1.py +68 -100
  476. transformers/models/falcon_mamba/configuration_falcon_mamba.py +5 -2
  477. transformers/models/falcon_mamba/modeling_falcon_mamba.py +64 -73
  478. transformers/models/falcon_mamba/modular_falcon_mamba.py +14 -13
  479. transformers/models/fast_vlm/configuration_fast_vlm.py +10 -0
  480. transformers/models/fast_vlm/modeling_fast_vlm.py +70 -97
  481. transformers/models/fast_vlm/modular_fast_vlm.py +148 -38
  482. transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py +2 -6
  483. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +45 -47
  484. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -3
  485. transformers/models/flaubert/configuration_flaubert.py +10 -5
  486. transformers/models/flaubert/modeling_flaubert.py +125 -129
  487. transformers/models/flaubert/tokenization_flaubert.py +3 -5
  488. transformers/models/flava/configuration_flava.py +9 -9
  489. transformers/models/flava/image_processing_flava.py +66 -67
  490. transformers/models/flava/image_processing_flava_fast.py +46 -47
  491. transformers/models/flava/modeling_flava.py +144 -135
  492. transformers/models/flava/processing_flava.py +2 -12
  493. transformers/models/flex_olmo/__init__.py +0 -1
  494. transformers/models/flex_olmo/configuration_flex_olmo.py +34 -39
  495. transformers/models/flex_olmo/modeling_flex_olmo.py +41 -43
  496. transformers/models/flex_olmo/modular_flex_olmo.py +46 -51
  497. transformers/models/florence2/configuration_florence2.py +4 -1
  498. transformers/models/florence2/modeling_florence2.py +96 -72
  499. transformers/models/florence2/modular_florence2.py +100 -107
  500. transformers/models/florence2/processing_florence2.py +18 -47
  501. transformers/models/fnet/configuration_fnet.py +6 -2
  502. transformers/models/fnet/modeling_fnet.py +69 -80
  503. transformers/models/fnet/tokenization_fnet.py +0 -1
  504. transformers/models/focalnet/configuration_focalnet.py +2 -5
  505. transformers/models/focalnet/modeling_focalnet.py +49 -48
  506. transformers/models/fsmt/configuration_fsmt.py +12 -17
  507. transformers/models/fsmt/modeling_fsmt.py +47 -48
  508. transformers/models/fsmt/tokenization_fsmt.py +3 -5
  509. transformers/models/funnel/configuration_funnel.py +8 -1
  510. transformers/models/funnel/modeling_funnel.py +91 -93
  511. transformers/models/funnel/tokenization_funnel.py +2 -5
  512. transformers/models/fuyu/configuration_fuyu.py +28 -34
  513. transformers/models/fuyu/image_processing_fuyu.py +29 -31
  514. transformers/models/fuyu/image_processing_fuyu_fast.py +17 -17
  515. transformers/models/fuyu/modeling_fuyu.py +50 -52
  516. transformers/models/fuyu/processing_fuyu.py +9 -36
  517. transformers/models/gemma/configuration_gemma.py +25 -30
  518. transformers/models/gemma/modeling_gemma.py +36 -38
  519. transformers/models/gemma/modular_gemma.py +33 -36
  520. transformers/models/gemma/tokenization_gemma.py +3 -6
  521. transformers/models/gemma2/configuration_gemma2.py +30 -35
  522. transformers/models/gemma2/modeling_gemma2.py +38 -41
  523. transformers/models/gemma2/modular_gemma2.py +63 -67
  524. transformers/models/gemma3/configuration_gemma3.py +53 -48
  525. transformers/models/gemma3/image_processing_gemma3.py +29 -31
  526. transformers/models/gemma3/image_processing_gemma3_fast.py +11 -12
  527. transformers/models/gemma3/modeling_gemma3.py +123 -122
  528. transformers/models/gemma3/modular_gemma3.py +128 -125
  529. transformers/models/gemma3/processing_gemma3.py +5 -5
  530. transformers/models/gemma3n/configuration_gemma3n.py +42 -30
  531. transformers/models/gemma3n/feature_extraction_gemma3n.py +9 -11
  532. transformers/models/gemma3n/modeling_gemma3n.py +166 -147
  533. transformers/models/gemma3n/modular_gemma3n.py +176 -148
  534. transformers/models/gemma3n/processing_gemma3n.py +12 -26
  535. transformers/models/git/configuration_git.py +5 -8
  536. transformers/models/git/modeling_git.py +115 -127
  537. transformers/models/git/processing_git.py +2 -14
  538. transformers/models/glm/configuration_glm.py +26 -30
  539. transformers/models/glm/modeling_glm.py +36 -39
  540. transformers/models/glm/modular_glm.py +4 -7
  541. transformers/models/glm4/configuration_glm4.py +26 -30
  542. transformers/models/glm4/modeling_glm4.py +39 -41
  543. transformers/models/glm4/modular_glm4.py +8 -10
  544. transformers/models/glm46v/configuration_glm46v.py +4 -1
  545. transformers/models/glm46v/image_processing_glm46v.py +40 -38
  546. transformers/models/glm46v/image_processing_glm46v_fast.py +9 -9
  547. transformers/models/glm46v/modeling_glm46v.py +138 -93
  548. transformers/models/glm46v/modular_glm46v.py +5 -3
  549. transformers/models/glm46v/processing_glm46v.py +7 -41
  550. transformers/models/glm46v/video_processing_glm46v.py +9 -11
  551. transformers/models/glm4_moe/configuration_glm4_moe.py +42 -35
  552. transformers/models/glm4_moe/modeling_glm4_moe.py +36 -39
  553. transformers/models/glm4_moe/modular_glm4_moe.py +43 -36
  554. transformers/models/glm4_moe_lite/__init__.py +28 -0
  555. transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py +233 -0
  556. transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py +740 -0
  557. transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py +302 -0
  558. transformers/models/glm4v/configuration_glm4v.py +25 -24
  559. transformers/models/glm4v/image_processing_glm4v.py +39 -38
  560. transformers/models/glm4v/image_processing_glm4v_fast.py +8 -9
  561. transformers/models/glm4v/modeling_glm4v.py +249 -210
  562. transformers/models/glm4v/modular_glm4v.py +211 -230
  563. transformers/models/glm4v/processing_glm4v.py +7 -41
  564. transformers/models/glm4v/video_processing_glm4v.py +9 -11
  565. transformers/models/glm4v_moe/configuration_glm4v_moe.py +136 -127
  566. transformers/models/glm4v_moe/modeling_glm4v_moe.py +348 -356
  567. transformers/models/glm4v_moe/modular_glm4v_moe.py +76 -174
  568. transformers/models/glm_image/__init__.py +31 -0
  569. transformers/models/glm_image/configuration_glm_image.py +358 -0
  570. transformers/models/glm_image/image_processing_glm_image.py +503 -0
  571. transformers/models/glm_image/image_processing_glm_image_fast.py +294 -0
  572. transformers/models/glm_image/modeling_glm_image.py +1691 -0
  573. transformers/models/glm_image/modular_glm_image.py +1640 -0
  574. transformers/models/glm_image/processing_glm_image.py +265 -0
  575. transformers/models/glm_ocr/__init__.py +28 -0
  576. transformers/models/glm_ocr/configuration_glm_ocr.py +312 -0
  577. transformers/models/glm_ocr/modeling_glm_ocr.py +1633 -0
  578. transformers/models/glm_ocr/modular_glm_ocr.py +428 -0
  579. transformers/models/glmasr/__init__.py +0 -1
  580. transformers/models/glmasr/configuration_glmasr.py +0 -1
  581. transformers/models/glmasr/modeling_glmasr.py +51 -46
  582. transformers/models/glmasr/modular_glmasr.py +39 -29
  583. transformers/models/glmasr/processing_glmasr.py +7 -8
  584. transformers/models/glpn/configuration_glpn.py +0 -1
  585. transformers/models/glpn/image_processing_glpn.py +11 -12
  586. transformers/models/glpn/image_processing_glpn_fast.py +11 -12
  587. transformers/models/glpn/modeling_glpn.py +14 -14
  588. transformers/models/got_ocr2/configuration_got_ocr2.py +10 -13
  589. transformers/models/got_ocr2/image_processing_got_ocr2.py +22 -24
  590. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +9 -10
  591. transformers/models/got_ocr2/modeling_got_ocr2.py +69 -77
  592. transformers/models/got_ocr2/modular_got_ocr2.py +60 -52
  593. transformers/models/got_ocr2/processing_got_ocr2.py +42 -63
  594. transformers/models/gpt2/configuration_gpt2.py +13 -2
  595. transformers/models/gpt2/modeling_gpt2.py +111 -113
  596. transformers/models/gpt2/tokenization_gpt2.py +6 -9
  597. transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +7 -2
  598. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +78 -84
  599. transformers/models/gpt_neo/configuration_gpt_neo.py +9 -2
  600. transformers/models/gpt_neo/modeling_gpt_neo.py +66 -71
  601. transformers/models/gpt_neox/configuration_gpt_neox.py +27 -25
  602. transformers/models/gpt_neox/modeling_gpt_neox.py +74 -76
  603. transformers/models/gpt_neox/modular_gpt_neox.py +68 -70
  604. transformers/models/gpt_neox/tokenization_gpt_neox.py +2 -5
  605. transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py +24 -19
  606. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +43 -46
  607. transformers/models/gpt_neox_japanese/tokenization_gpt_neox_japanese.py +1 -3
  608. transformers/models/gpt_oss/configuration_gpt_oss.py +31 -30
  609. transformers/models/gpt_oss/modeling_gpt_oss.py +80 -114
  610. transformers/models/gpt_oss/modular_gpt_oss.py +62 -97
  611. transformers/models/gpt_sw3/tokenization_gpt_sw3.py +4 -4
  612. transformers/models/gptj/configuration_gptj.py +4 -5
  613. transformers/models/gptj/modeling_gptj.py +85 -88
  614. transformers/models/granite/configuration_granite.py +28 -33
  615. transformers/models/granite/modeling_granite.py +43 -45
  616. transformers/models/granite/modular_granite.py +29 -31
  617. transformers/models/granite_speech/configuration_granite_speech.py +0 -1
  618. transformers/models/granite_speech/feature_extraction_granite_speech.py +1 -3
  619. transformers/models/granite_speech/modeling_granite_speech.py +84 -60
  620. transformers/models/granite_speech/processing_granite_speech.py +11 -4
  621. transformers/models/granitemoe/configuration_granitemoe.py +31 -36
  622. transformers/models/granitemoe/modeling_granitemoe.py +39 -41
  623. transformers/models/granitemoe/modular_granitemoe.py +21 -23
  624. transformers/models/granitemoehybrid/__init__.py +0 -1
  625. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +55 -48
  626. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +82 -118
  627. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +57 -65
  628. transformers/models/granitemoeshared/configuration_granitemoeshared.py +33 -37
  629. transformers/models/granitemoeshared/modeling_granitemoeshared.py +52 -56
  630. transformers/models/granitemoeshared/modular_granitemoeshared.py +19 -21
  631. transformers/models/grounding_dino/configuration_grounding_dino.py +10 -46
  632. transformers/models/grounding_dino/image_processing_grounding_dino.py +60 -62
  633. transformers/models/grounding_dino/image_processing_grounding_dino_fast.py +28 -29
  634. transformers/models/grounding_dino/modeling_grounding_dino.py +161 -181
  635. transformers/models/grounding_dino/modular_grounding_dino.py +2 -3
  636. transformers/models/grounding_dino/processing_grounding_dino.py +10 -38
  637. transformers/models/groupvit/configuration_groupvit.py +4 -2
  638. transformers/models/groupvit/modeling_groupvit.py +98 -92
  639. transformers/models/helium/configuration_helium.py +25 -29
  640. transformers/models/helium/modeling_helium.py +37 -40
  641. transformers/models/helium/modular_helium.py +3 -7
  642. transformers/models/herbert/tokenization_herbert.py +4 -6
  643. transformers/models/hgnet_v2/configuration_hgnet_v2.py +2 -5
  644. transformers/models/hgnet_v2/modeling_hgnet_v2.py +12 -14
  645. transformers/models/hgnet_v2/modular_hgnet_v2.py +13 -17
  646. transformers/models/hiera/configuration_hiera.py +2 -5
  647. transformers/models/hiera/modeling_hiera.py +71 -70
  648. transformers/models/hubert/configuration_hubert.py +4 -2
  649. transformers/models/hubert/modeling_hubert.py +42 -41
  650. transformers/models/hubert/modular_hubert.py +8 -11
  651. transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py +26 -31
  652. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +58 -37
  653. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +31 -11
  654. transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py +31 -36
  655. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +54 -44
  656. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +27 -15
  657. transformers/models/ibert/configuration_ibert.py +4 -2
  658. transformers/models/ibert/modeling_ibert.py +60 -62
  659. transformers/models/ibert/quant_modules.py +0 -1
  660. transformers/models/idefics/configuration_idefics.py +5 -8
  661. transformers/models/idefics/image_processing_idefics.py +13 -15
  662. transformers/models/idefics/modeling_idefics.py +63 -65
  663. transformers/models/idefics/perceiver.py +1 -3
  664. transformers/models/idefics/processing_idefics.py +32 -48
  665. transformers/models/idefics/vision.py +27 -28
  666. transformers/models/idefics2/configuration_idefics2.py +1 -3
  667. transformers/models/idefics2/image_processing_idefics2.py +31 -32
  668. transformers/models/idefics2/image_processing_idefics2_fast.py +8 -8
  669. transformers/models/idefics2/modeling_idefics2.py +126 -106
  670. transformers/models/idefics2/processing_idefics2.py +10 -68
  671. transformers/models/idefics3/configuration_idefics3.py +1 -4
  672. transformers/models/idefics3/image_processing_idefics3.py +42 -43
  673. transformers/models/idefics3/image_processing_idefics3_fast.py +40 -15
  674. transformers/models/idefics3/modeling_idefics3.py +113 -92
  675. transformers/models/idefics3/processing_idefics3.py +15 -69
  676. transformers/models/ijepa/configuration_ijepa.py +0 -1
  677. transformers/models/ijepa/modeling_ijepa.py +13 -14
  678. transformers/models/ijepa/modular_ijepa.py +5 -7
  679. transformers/models/imagegpt/configuration_imagegpt.py +9 -2
  680. transformers/models/imagegpt/image_processing_imagegpt.py +17 -18
  681. transformers/models/imagegpt/image_processing_imagegpt_fast.py +10 -11
  682. transformers/models/imagegpt/modeling_imagegpt.py +65 -62
  683. transformers/models/informer/configuration_informer.py +6 -9
  684. transformers/models/informer/modeling_informer.py +87 -89
  685. transformers/models/informer/modular_informer.py +13 -16
  686. transformers/models/instructblip/configuration_instructblip.py +2 -2
  687. transformers/models/instructblip/modeling_instructblip.py +104 -79
  688. transformers/models/instructblip/processing_instructblip.py +10 -36
  689. transformers/models/instructblipvideo/configuration_instructblipvideo.py +2 -2
  690. transformers/models/instructblipvideo/modeling_instructblipvideo.py +108 -105
  691. transformers/models/instructblipvideo/modular_instructblipvideo.py +73 -64
  692. transformers/models/instructblipvideo/processing_instructblipvideo.py +14 -33
  693. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +6 -7
  694. transformers/models/internvl/configuration_internvl.py +5 -1
  695. transformers/models/internvl/modeling_internvl.py +76 -98
  696. transformers/models/internvl/modular_internvl.py +45 -59
  697. transformers/models/internvl/processing_internvl.py +12 -45
  698. transformers/models/internvl/video_processing_internvl.py +10 -11
  699. transformers/models/jais2/configuration_jais2.py +25 -29
  700. transformers/models/jais2/modeling_jais2.py +36 -38
  701. transformers/models/jais2/modular_jais2.py +20 -22
  702. transformers/models/jamba/configuration_jamba.py +5 -8
  703. transformers/models/jamba/modeling_jamba.py +47 -50
  704. transformers/models/jamba/modular_jamba.py +40 -41
  705. transformers/models/janus/configuration_janus.py +0 -1
  706. transformers/models/janus/image_processing_janus.py +37 -39
  707. transformers/models/janus/image_processing_janus_fast.py +20 -21
  708. transformers/models/janus/modeling_janus.py +103 -188
  709. transformers/models/janus/modular_janus.py +122 -83
  710. transformers/models/janus/processing_janus.py +17 -43
  711. transformers/models/jetmoe/configuration_jetmoe.py +26 -27
  712. transformers/models/jetmoe/modeling_jetmoe.py +42 -45
  713. transformers/models/jetmoe/modular_jetmoe.py +33 -36
  714. transformers/models/kosmos2/configuration_kosmos2.py +10 -9
  715. transformers/models/kosmos2/modeling_kosmos2.py +199 -178
  716. transformers/models/kosmos2/processing_kosmos2.py +40 -55
  717. transformers/models/kosmos2_5/__init__.py +0 -1
  718. transformers/models/kosmos2_5/configuration_kosmos2_5.py +8 -9
  719. transformers/models/kosmos2_5/image_processing_kosmos2_5.py +10 -12
  720. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -11
  721. transformers/models/kosmos2_5/modeling_kosmos2_5.py +162 -172
  722. transformers/models/kosmos2_5/processing_kosmos2_5.py +8 -29
  723. transformers/models/kyutai_speech_to_text/configuration_kyutai_speech_to_text.py +31 -28
  724. transformers/models/kyutai_speech_to_text/feature_extraction_kyutai_speech_to_text.py +12 -14
  725. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +103 -106
  726. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +20 -22
  727. transformers/models/kyutai_speech_to_text/processing_kyutai_speech_to_text.py +2 -8
  728. transformers/models/lasr/configuration_lasr.py +3 -7
  729. transformers/models/lasr/feature_extraction_lasr.py +10 -12
  730. transformers/models/lasr/modeling_lasr.py +21 -24
  731. transformers/models/lasr/modular_lasr.py +11 -13
  732. transformers/models/lasr/processing_lasr.py +12 -6
  733. transformers/models/lasr/tokenization_lasr.py +2 -4
  734. transformers/models/layoutlm/configuration_layoutlm.py +14 -2
  735. transformers/models/layoutlm/modeling_layoutlm.py +70 -72
  736. transformers/models/layoutlmv2/configuration_layoutlmv2.py +14 -17
  737. transformers/models/layoutlmv2/image_processing_layoutlmv2.py +18 -21
  738. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +7 -8
  739. transformers/models/layoutlmv2/modeling_layoutlmv2.py +48 -50
  740. transformers/models/layoutlmv2/processing_layoutlmv2.py +14 -44
  741. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +63 -74
  742. transformers/models/layoutlmv3/configuration_layoutlmv3.py +16 -19
  743. transformers/models/layoutlmv3/image_processing_layoutlmv3.py +24 -26
  744. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +9 -10
  745. transformers/models/layoutlmv3/modeling_layoutlmv3.py +49 -51
  746. transformers/models/layoutlmv3/processing_layoutlmv3.py +14 -46
  747. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +64 -75
  748. transformers/models/layoutxlm/configuration_layoutxlm.py +14 -17
  749. transformers/models/layoutxlm/modular_layoutxlm.py +0 -1
  750. transformers/models/layoutxlm/processing_layoutxlm.py +14 -44
  751. transformers/models/layoutxlm/tokenization_layoutxlm.py +65 -76
  752. transformers/models/led/configuration_led.py +8 -12
  753. transformers/models/led/modeling_led.py +113 -267
  754. transformers/models/levit/configuration_levit.py +0 -1
  755. transformers/models/levit/image_processing_levit.py +19 -21
  756. transformers/models/levit/image_processing_levit_fast.py +4 -5
  757. transformers/models/levit/modeling_levit.py +17 -19
  758. transformers/models/lfm2/configuration_lfm2.py +27 -30
  759. transformers/models/lfm2/modeling_lfm2.py +46 -48
  760. transformers/models/lfm2/modular_lfm2.py +32 -32
  761. transformers/models/lfm2_moe/__init__.py +0 -1
  762. transformers/models/lfm2_moe/configuration_lfm2_moe.py +6 -9
  763. transformers/models/lfm2_moe/modeling_lfm2_moe.py +48 -49
  764. transformers/models/lfm2_moe/modular_lfm2_moe.py +8 -9
  765. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -1
  766. transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py +43 -20
  767. transformers/models/lfm2_vl/modeling_lfm2_vl.py +73 -61
  768. transformers/models/lfm2_vl/modular_lfm2_vl.py +66 -54
  769. transformers/models/lfm2_vl/processing_lfm2_vl.py +14 -34
  770. transformers/models/lightglue/image_processing_lightglue.py +16 -15
  771. transformers/models/lightglue/image_processing_lightglue_fast.py +8 -7
  772. transformers/models/lightglue/modeling_lightglue.py +31 -33
  773. transformers/models/lightglue/modular_lightglue.py +31 -31
  774. transformers/models/lighton_ocr/__init__.py +28 -0
  775. transformers/models/lighton_ocr/configuration_lighton_ocr.py +128 -0
  776. transformers/models/lighton_ocr/modeling_lighton_ocr.py +463 -0
  777. transformers/models/lighton_ocr/modular_lighton_ocr.py +404 -0
  778. transformers/models/lighton_ocr/processing_lighton_ocr.py +229 -0
  779. transformers/models/lilt/configuration_lilt.py +6 -2
  780. transformers/models/lilt/modeling_lilt.py +53 -55
  781. transformers/models/llama/configuration_llama.py +26 -31
  782. transformers/models/llama/modeling_llama.py +35 -38
  783. transformers/models/llama/tokenization_llama.py +2 -4
  784. transformers/models/llama4/configuration_llama4.py +87 -69
  785. transformers/models/llama4/image_processing_llama4_fast.py +11 -12
  786. transformers/models/llama4/modeling_llama4.py +116 -115
  787. transformers/models/llama4/processing_llama4.py +33 -57
  788. transformers/models/llava/configuration_llava.py +10 -1
  789. transformers/models/llava/image_processing_llava.py +25 -28
  790. transformers/models/llava/image_processing_llava_fast.py +9 -10
  791. transformers/models/llava/modeling_llava.py +73 -102
  792. transformers/models/llava/processing_llava.py +18 -51
  793. transformers/models/llava_next/configuration_llava_next.py +2 -2
  794. transformers/models/llava_next/image_processing_llava_next.py +43 -45
  795. transformers/models/llava_next/image_processing_llava_next_fast.py +11 -12
  796. transformers/models/llava_next/modeling_llava_next.py +103 -104
  797. transformers/models/llava_next/processing_llava_next.py +18 -47
  798. transformers/models/llava_next_video/configuration_llava_next_video.py +10 -7
  799. transformers/models/llava_next_video/modeling_llava_next_video.py +168 -155
  800. transformers/models/llava_next_video/modular_llava_next_video.py +154 -147
  801. transformers/models/llava_next_video/processing_llava_next_video.py +21 -63
  802. transformers/models/llava_next_video/video_processing_llava_next_video.py +0 -1
  803. transformers/models/llava_onevision/configuration_llava_onevision.py +10 -7
  804. transformers/models/llava_onevision/image_processing_llava_onevision.py +40 -42
  805. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +14 -14
  806. transformers/models/llava_onevision/modeling_llava_onevision.py +170 -166
  807. transformers/models/llava_onevision/modular_llava_onevision.py +156 -152
  808. transformers/models/llava_onevision/processing_llava_onevision.py +21 -53
  809. transformers/models/llava_onevision/video_processing_llava_onevision.py +0 -1
  810. transformers/models/longcat_flash/__init__.py +0 -1
  811. transformers/models/longcat_flash/configuration_longcat_flash.py +39 -45
  812. transformers/models/longcat_flash/modeling_longcat_flash.py +37 -38
  813. transformers/models/longcat_flash/modular_longcat_flash.py +23 -24
  814. transformers/models/longformer/configuration_longformer.py +5 -5
  815. transformers/models/longformer/modeling_longformer.py +99 -101
  816. transformers/models/longt5/configuration_longt5.py +9 -7
  817. transformers/models/longt5/modeling_longt5.py +45 -45
  818. transformers/models/luke/configuration_luke.py +8 -2
  819. transformers/models/luke/modeling_luke.py +179 -181
  820. transformers/models/luke/tokenization_luke.py +99 -105
  821. transformers/{pipelines/deprecated → models/lw_detr}/__init__.py +14 -3
  822. transformers/models/lw_detr/configuration_lw_detr.py +362 -0
  823. transformers/models/lw_detr/modeling_lw_detr.py +1697 -0
  824. transformers/models/lw_detr/modular_lw_detr.py +1609 -0
  825. transformers/models/lxmert/configuration_lxmert.py +16 -1
  826. transformers/models/lxmert/modeling_lxmert.py +63 -74
  827. transformers/models/m2m_100/configuration_m2m_100.py +7 -9
  828. transformers/models/m2m_100/modeling_m2m_100.py +72 -74
  829. transformers/models/m2m_100/tokenization_m2m_100.py +8 -8
  830. transformers/models/mamba/configuration_mamba.py +5 -3
  831. transformers/models/mamba/modeling_mamba.py +61 -70
  832. transformers/models/mamba2/configuration_mamba2.py +5 -8
  833. transformers/models/mamba2/modeling_mamba2.py +66 -79
  834. transformers/models/marian/configuration_marian.py +10 -5
  835. transformers/models/marian/modeling_marian.py +88 -90
  836. transformers/models/marian/tokenization_marian.py +6 -6
  837. transformers/models/markuplm/configuration_markuplm.py +4 -7
  838. transformers/models/markuplm/feature_extraction_markuplm.py +1 -2
  839. transformers/models/markuplm/modeling_markuplm.py +63 -65
  840. transformers/models/markuplm/processing_markuplm.py +31 -38
  841. transformers/models/markuplm/tokenization_markuplm.py +67 -77
  842. transformers/models/mask2former/configuration_mask2former.py +14 -52
  843. transformers/models/mask2former/image_processing_mask2former.py +84 -85
  844. transformers/models/mask2former/image_processing_mask2former_fast.py +36 -36
  845. transformers/models/mask2former/modeling_mask2former.py +108 -104
  846. transformers/models/mask2former/modular_mask2former.py +6 -8
  847. transformers/models/maskformer/configuration_maskformer.py +17 -51
  848. transformers/models/maskformer/configuration_maskformer_swin.py +2 -5
  849. transformers/models/maskformer/image_processing_maskformer.py +84 -85
  850. transformers/models/maskformer/image_processing_maskformer_fast.py +35 -36
  851. transformers/models/maskformer/modeling_maskformer.py +71 -67
  852. transformers/models/maskformer/modeling_maskformer_swin.py +20 -23
  853. transformers/models/mbart/configuration_mbart.py +9 -5
  854. transformers/models/mbart/modeling_mbart.py +120 -119
  855. transformers/models/mbart/tokenization_mbart.py +2 -4
  856. transformers/models/mbart50/tokenization_mbart50.py +3 -5
  857. transformers/models/megatron_bert/configuration_megatron_bert.py +13 -3
  858. transformers/models/megatron_bert/modeling_megatron_bert.py +139 -165
  859. transformers/models/metaclip_2/configuration_metaclip_2.py +4 -1
  860. transformers/models/metaclip_2/modeling_metaclip_2.py +94 -87
  861. transformers/models/metaclip_2/modular_metaclip_2.py +59 -45
  862. transformers/models/mgp_str/configuration_mgp_str.py +0 -1
  863. transformers/models/mgp_str/modeling_mgp_str.py +18 -18
  864. transformers/models/mgp_str/processing_mgp_str.py +3 -20
  865. transformers/models/mgp_str/tokenization_mgp_str.py +1 -3
  866. transformers/models/mimi/configuration_mimi.py +42 -40
  867. transformers/models/mimi/modeling_mimi.py +116 -115
  868. transformers/models/minimax/__init__.py +0 -1
  869. transformers/models/minimax/configuration_minimax.py +40 -47
  870. transformers/models/minimax/modeling_minimax.py +46 -49
  871. transformers/models/minimax/modular_minimax.py +59 -65
  872. transformers/models/minimax_m2/__init__.py +28 -0
  873. transformers/models/minimax_m2/configuration_minimax_m2.py +188 -0
  874. transformers/models/minimax_m2/modeling_minimax_m2.py +704 -0
  875. transformers/models/minimax_m2/modular_minimax_m2.py +346 -0
  876. transformers/models/ministral/configuration_ministral.py +25 -29
  877. transformers/models/ministral/modeling_ministral.py +35 -37
  878. transformers/models/ministral/modular_ministral.py +32 -37
  879. transformers/models/ministral3/configuration_ministral3.py +23 -26
  880. transformers/models/ministral3/modeling_ministral3.py +35 -37
  881. transformers/models/ministral3/modular_ministral3.py +7 -8
  882. transformers/models/mistral/configuration_mistral.py +24 -29
  883. transformers/models/mistral/modeling_mistral.py +35 -37
  884. transformers/models/mistral/modular_mistral.py +14 -15
  885. transformers/models/mistral3/configuration_mistral3.py +4 -1
  886. transformers/models/mistral3/modeling_mistral3.py +79 -82
  887. transformers/models/mistral3/modular_mistral3.py +66 -67
  888. transformers/models/mixtral/configuration_mixtral.py +32 -38
  889. transformers/models/mixtral/modeling_mixtral.py +39 -42
  890. transformers/models/mixtral/modular_mixtral.py +26 -29
  891. transformers/models/mlcd/configuration_mlcd.py +0 -1
  892. transformers/models/mlcd/modeling_mlcd.py +17 -17
  893. transformers/models/mlcd/modular_mlcd.py +16 -16
  894. transformers/models/mllama/configuration_mllama.py +10 -15
  895. transformers/models/mllama/image_processing_mllama.py +23 -25
  896. transformers/models/mllama/image_processing_mllama_fast.py +11 -11
  897. transformers/models/mllama/modeling_mllama.py +100 -103
  898. transformers/models/mllama/processing_mllama.py +6 -55
  899. transformers/models/mluke/tokenization_mluke.py +97 -103
  900. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +10 -46
  901. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +159 -179
  902. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +10 -46
  903. transformers/models/mobilebert/configuration_mobilebert.py +4 -2
  904. transformers/models/mobilebert/modeling_mobilebert.py +78 -88
  905. transformers/models/mobilebert/tokenization_mobilebert.py +0 -1
  906. transformers/models/mobilenet_v1/configuration_mobilenet_v1.py +0 -1
  907. transformers/models/mobilenet_v1/image_processing_mobilenet_v1.py +20 -23
  908. transformers/models/mobilenet_v1/image_processing_mobilenet_v1_fast.py +0 -1
  909. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +13 -16
  910. transformers/models/mobilenet_v2/configuration_mobilenet_v2.py +0 -1
  911. transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py +48 -51
  912. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +14 -15
  913. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +21 -22
  914. transformers/models/mobilevit/configuration_mobilevit.py +0 -1
  915. transformers/models/mobilevit/image_processing_mobilevit.py +41 -44
  916. transformers/models/mobilevit/image_processing_mobilevit_fast.py +12 -13
  917. transformers/models/mobilevit/modeling_mobilevit.py +21 -21
  918. transformers/models/mobilevitv2/configuration_mobilevitv2.py +0 -1
  919. transformers/models/mobilevitv2/modeling_mobilevitv2.py +21 -22
  920. transformers/models/modernbert/configuration_modernbert.py +76 -51
  921. transformers/models/modernbert/modeling_modernbert.py +188 -943
  922. transformers/models/modernbert/modular_modernbert.py +255 -978
  923. transformers/models/modernbert_decoder/configuration_modernbert_decoder.py +50 -44
  924. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +54 -64
  925. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +92 -92
  926. transformers/models/moonshine/configuration_moonshine.py +34 -31
  927. transformers/models/moonshine/modeling_moonshine.py +70 -72
  928. transformers/models/moonshine/modular_moonshine.py +91 -86
  929. transformers/models/moshi/configuration_moshi.py +46 -23
  930. transformers/models/moshi/modeling_moshi.py +134 -142
  931. transformers/models/mpnet/configuration_mpnet.py +6 -2
  932. transformers/models/mpnet/modeling_mpnet.py +55 -57
  933. transformers/models/mpnet/tokenization_mpnet.py +1 -4
  934. transformers/models/mpt/configuration_mpt.py +17 -9
  935. transformers/models/mpt/modeling_mpt.py +58 -60
  936. transformers/models/mra/configuration_mra.py +8 -2
  937. transformers/models/mra/modeling_mra.py +54 -56
  938. transformers/models/mt5/configuration_mt5.py +9 -6
  939. transformers/models/mt5/modeling_mt5.py +80 -85
  940. transformers/models/musicgen/configuration_musicgen.py +12 -8
  941. transformers/models/musicgen/modeling_musicgen.py +114 -116
  942. transformers/models/musicgen/processing_musicgen.py +3 -21
  943. transformers/models/musicgen_melody/configuration_musicgen_melody.py +15 -8
  944. transformers/models/musicgen_melody/feature_extraction_musicgen_melody.py +8 -9
  945. transformers/models/musicgen_melody/modeling_musicgen_melody.py +113 -126
  946. transformers/models/musicgen_melody/processing_musicgen_melody.py +3 -22
  947. transformers/models/mvp/configuration_mvp.py +8 -5
  948. transformers/models/mvp/modeling_mvp.py +121 -123
  949. transformers/models/myt5/tokenization_myt5.py +8 -10
  950. transformers/models/nanochat/configuration_nanochat.py +5 -8
  951. transformers/models/nanochat/modeling_nanochat.py +36 -39
  952. transformers/models/nanochat/modular_nanochat.py +16 -18
  953. transformers/models/nemotron/configuration_nemotron.py +25 -30
  954. transformers/models/nemotron/modeling_nemotron.py +53 -66
  955. transformers/models/nllb/tokenization_nllb.py +14 -14
  956. transformers/models/nllb_moe/configuration_nllb_moe.py +7 -10
  957. transformers/models/nllb_moe/modeling_nllb_moe.py +70 -72
  958. transformers/models/nougat/image_processing_nougat.py +29 -32
  959. transformers/models/nougat/image_processing_nougat_fast.py +12 -13
  960. transformers/models/nougat/processing_nougat.py +37 -39
  961. transformers/models/nougat/tokenization_nougat.py +5 -7
  962. transformers/models/nystromformer/configuration_nystromformer.py +8 -2
  963. transformers/models/nystromformer/modeling_nystromformer.py +61 -63
  964. transformers/models/olmo/configuration_olmo.py +23 -28
  965. transformers/models/olmo/modeling_olmo.py +35 -38
  966. transformers/models/olmo/modular_olmo.py +8 -12
  967. transformers/models/olmo2/configuration_olmo2.py +27 -32
  968. transformers/models/olmo2/modeling_olmo2.py +36 -39
  969. transformers/models/olmo2/modular_olmo2.py +36 -38
  970. transformers/models/olmo3/__init__.py +0 -1
  971. transformers/models/olmo3/configuration_olmo3.py +30 -34
  972. transformers/models/olmo3/modeling_olmo3.py +35 -38
  973. transformers/models/olmo3/modular_olmo3.py +44 -47
  974. transformers/models/olmoe/configuration_olmoe.py +29 -33
  975. transformers/models/olmoe/modeling_olmoe.py +41 -43
  976. transformers/models/olmoe/modular_olmoe.py +15 -16
  977. transformers/models/omdet_turbo/configuration_omdet_turbo.py +14 -50
  978. transformers/models/omdet_turbo/modeling_omdet_turbo.py +59 -57
  979. transformers/models/omdet_turbo/processing_omdet_turbo.py +19 -67
  980. transformers/models/oneformer/configuration_oneformer.py +11 -51
  981. transformers/models/oneformer/image_processing_oneformer.py +83 -84
  982. transformers/models/oneformer/image_processing_oneformer_fast.py +41 -42
  983. transformers/models/oneformer/modeling_oneformer.py +137 -133
  984. transformers/models/oneformer/processing_oneformer.py +28 -43
  985. transformers/models/openai/configuration_openai.py +16 -1
  986. transformers/models/openai/modeling_openai.py +50 -51
  987. transformers/models/openai/tokenization_openai.py +2 -5
  988. transformers/models/opt/configuration_opt.py +6 -7
  989. transformers/models/opt/modeling_opt.py +79 -80
  990. transformers/models/ovis2/__init__.py +0 -1
  991. transformers/models/ovis2/configuration_ovis2.py +4 -1
  992. transformers/models/ovis2/image_processing_ovis2.py +22 -24
  993. transformers/models/ovis2/image_processing_ovis2_fast.py +9 -10
  994. transformers/models/ovis2/modeling_ovis2.py +99 -142
  995. transformers/models/ovis2/modular_ovis2.py +82 -45
  996. transformers/models/ovis2/processing_ovis2.py +12 -40
  997. transformers/models/owlv2/configuration_owlv2.py +4 -2
  998. transformers/models/owlv2/image_processing_owlv2.py +20 -21
  999. transformers/models/owlv2/image_processing_owlv2_fast.py +12 -13
  1000. transformers/models/owlv2/modeling_owlv2.py +122 -114
  1001. transformers/models/owlv2/modular_owlv2.py +11 -12
  1002. transformers/models/owlv2/processing_owlv2.py +20 -49
  1003. transformers/models/owlvit/configuration_owlvit.py +4 -2
  1004. transformers/models/owlvit/image_processing_owlvit.py +21 -22
  1005. transformers/models/owlvit/image_processing_owlvit_fast.py +2 -3
  1006. transformers/models/owlvit/modeling_owlvit.py +121 -113
  1007. transformers/models/owlvit/processing_owlvit.py +20 -48
  1008. transformers/models/paddleocr_vl/__init__.py +0 -1
  1009. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +28 -29
  1010. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +34 -35
  1011. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +12 -12
  1012. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +159 -158
  1013. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +148 -119
  1014. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +1 -3
  1015. transformers/models/paligemma/configuration_paligemma.py +4 -1
  1016. transformers/models/paligemma/modeling_paligemma.py +81 -79
  1017. transformers/models/paligemma/processing_paligemma.py +13 -66
  1018. transformers/models/parakeet/configuration_parakeet.py +3 -8
  1019. transformers/models/parakeet/feature_extraction_parakeet.py +10 -12
  1020. transformers/models/parakeet/modeling_parakeet.py +21 -25
  1021. transformers/models/parakeet/modular_parakeet.py +19 -21
  1022. transformers/models/parakeet/processing_parakeet.py +12 -5
  1023. transformers/models/parakeet/tokenization_parakeet.py +2 -4
  1024. transformers/models/patchtsmixer/configuration_patchtsmixer.py +5 -8
  1025. transformers/models/patchtsmixer/modeling_patchtsmixer.py +63 -65
  1026. transformers/models/patchtst/configuration_patchtst.py +6 -9
  1027. transformers/models/patchtst/modeling_patchtst.py +75 -77
  1028. transformers/models/pe_audio/__init__.py +0 -1
  1029. transformers/models/pe_audio/configuration_pe_audio.py +14 -16
  1030. transformers/models/pe_audio/feature_extraction_pe_audio.py +6 -8
  1031. transformers/models/pe_audio/modeling_pe_audio.py +30 -31
  1032. transformers/models/pe_audio/modular_pe_audio.py +17 -18
  1033. transformers/models/pe_audio/processing_pe_audio.py +0 -1
  1034. transformers/models/pe_audio_video/__init__.py +0 -1
  1035. transformers/models/pe_audio_video/configuration_pe_audio_video.py +15 -17
  1036. transformers/models/pe_audio_video/modeling_pe_audio_video.py +64 -65
  1037. transformers/models/pe_audio_video/modular_pe_audio_video.py +56 -57
  1038. transformers/models/pe_audio_video/processing_pe_audio_video.py +0 -1
  1039. transformers/models/pe_video/__init__.py +0 -1
  1040. transformers/models/pe_video/configuration_pe_video.py +14 -16
  1041. transformers/models/pe_video/modeling_pe_video.py +57 -46
  1042. transformers/models/pe_video/modular_pe_video.py +47 -35
  1043. transformers/models/pe_video/video_processing_pe_video.py +2 -4
  1044. transformers/models/pegasus/configuration_pegasus.py +8 -6
  1045. transformers/models/pegasus/modeling_pegasus.py +67 -69
  1046. transformers/models/pegasus/tokenization_pegasus.py +1 -4
  1047. transformers/models/pegasus_x/configuration_pegasus_x.py +5 -4
  1048. transformers/models/pegasus_x/modeling_pegasus_x.py +53 -55
  1049. transformers/models/perceiver/configuration_perceiver.py +0 -1
  1050. transformers/models/perceiver/image_processing_perceiver.py +22 -25
  1051. transformers/models/perceiver/image_processing_perceiver_fast.py +7 -8
  1052. transformers/models/perceiver/modeling_perceiver.py +152 -145
  1053. transformers/models/perceiver/tokenization_perceiver.py +3 -6
  1054. transformers/models/perception_lm/configuration_perception_lm.py +0 -1
  1055. transformers/models/perception_lm/image_processing_perception_lm_fast.py +8 -9
  1056. transformers/models/perception_lm/modeling_perception_lm.py +64 -67
  1057. transformers/models/perception_lm/modular_perception_lm.py +58 -58
  1058. transformers/models/perception_lm/processing_perception_lm.py +13 -47
  1059. transformers/models/perception_lm/video_processing_perception_lm.py +0 -1
  1060. transformers/models/persimmon/configuration_persimmon.py +23 -28
  1061. transformers/models/persimmon/modeling_persimmon.py +44 -47
  1062. transformers/models/phi/configuration_phi.py +27 -28
  1063. transformers/models/phi/modeling_phi.py +39 -41
  1064. transformers/models/phi/modular_phi.py +26 -26
  1065. transformers/models/phi3/configuration_phi3.py +32 -37
  1066. transformers/models/phi3/modeling_phi3.py +37 -40
  1067. transformers/models/phi3/modular_phi3.py +16 -20
  1068. transformers/models/phi4_multimodal/configuration_phi4_multimodal.py +36 -39
  1069. transformers/models/phi4_multimodal/feature_extraction_phi4_multimodal.py +7 -9
  1070. transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py +11 -11
  1071. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +100 -117
  1072. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +103 -90
  1073. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +7 -42
  1074. transformers/models/phimoe/configuration_phimoe.py +31 -36
  1075. transformers/models/phimoe/modeling_phimoe.py +50 -77
  1076. transformers/models/phimoe/modular_phimoe.py +12 -8
  1077. transformers/models/phobert/tokenization_phobert.py +4 -6
  1078. transformers/models/pix2struct/configuration_pix2struct.py +12 -10
  1079. transformers/models/pix2struct/image_processing_pix2struct.py +15 -19
  1080. transformers/models/pix2struct/image_processing_pix2struct_fast.py +12 -15
  1081. transformers/models/pix2struct/modeling_pix2struct.py +56 -52
  1082. transformers/models/pix2struct/processing_pix2struct.py +5 -26
  1083. transformers/models/pixio/__init__.py +0 -1
  1084. transformers/models/pixio/configuration_pixio.py +2 -5
  1085. transformers/models/pixio/modeling_pixio.py +16 -17
  1086. transformers/models/pixio/modular_pixio.py +7 -8
  1087. transformers/models/pixtral/configuration_pixtral.py +11 -14
  1088. transformers/models/pixtral/image_processing_pixtral.py +26 -28
  1089. transformers/models/pixtral/image_processing_pixtral_fast.py +10 -11
  1090. transformers/models/pixtral/modeling_pixtral.py +31 -37
  1091. transformers/models/pixtral/processing_pixtral.py +18 -52
  1092. transformers/models/plbart/configuration_plbart.py +8 -6
  1093. transformers/models/plbart/modeling_plbart.py +109 -109
  1094. transformers/models/plbart/modular_plbart.py +31 -33
  1095. transformers/models/plbart/tokenization_plbart.py +4 -5
  1096. transformers/models/poolformer/configuration_poolformer.py +0 -1
  1097. transformers/models/poolformer/image_processing_poolformer.py +21 -24
  1098. transformers/models/poolformer/image_processing_poolformer_fast.py +13 -14
  1099. transformers/models/poolformer/modeling_poolformer.py +10 -12
  1100. transformers/models/pop2piano/configuration_pop2piano.py +7 -7
  1101. transformers/models/pop2piano/feature_extraction_pop2piano.py +6 -9
  1102. transformers/models/pop2piano/modeling_pop2piano.py +24 -24
  1103. transformers/models/pop2piano/processing_pop2piano.py +25 -33
  1104. transformers/models/pop2piano/tokenization_pop2piano.py +15 -23
  1105. transformers/models/pp_doclayout_v3/__init__.py +30 -0
  1106. transformers/models/pp_doclayout_v3/configuration_pp_doclayout_v3.py +277 -0
  1107. transformers/models/pp_doclayout_v3/image_processing_pp_doclayout_v3_fast.py +305 -0
  1108. transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py +2083 -0
  1109. transformers/models/pp_doclayout_v3/modular_pp_doclayout_v3.py +1549 -0
  1110. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +13 -46
  1111. transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py +28 -28
  1112. transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything_fast.py +20 -21
  1113. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +17 -16
  1114. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +21 -20
  1115. transformers/models/prophetnet/configuration_prophetnet.py +37 -38
  1116. transformers/models/prophetnet/modeling_prophetnet.py +121 -153
  1117. transformers/models/prophetnet/tokenization_prophetnet.py +14 -16
  1118. transformers/models/pvt/configuration_pvt.py +0 -1
  1119. transformers/models/pvt/image_processing_pvt.py +24 -27
  1120. transformers/models/pvt/image_processing_pvt_fast.py +1 -2
  1121. transformers/models/pvt/modeling_pvt.py +19 -21
  1122. transformers/models/pvt_v2/configuration_pvt_v2.py +4 -8
  1123. transformers/models/pvt_v2/modeling_pvt_v2.py +27 -28
  1124. transformers/models/qwen2/configuration_qwen2.py +32 -25
  1125. transformers/models/qwen2/modeling_qwen2.py +35 -37
  1126. transformers/models/qwen2/modular_qwen2.py +14 -15
  1127. transformers/models/qwen2/tokenization_qwen2.py +2 -9
  1128. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +36 -27
  1129. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +241 -214
  1130. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +228 -193
  1131. transformers/models/qwen2_5_omni/processing_qwen2_5_omni.py +41 -49
  1132. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +28 -34
  1133. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +188 -145
  1134. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +64 -91
  1135. transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py +7 -43
  1136. transformers/models/qwen2_audio/configuration_qwen2_audio.py +0 -1
  1137. transformers/models/qwen2_audio/modeling_qwen2_audio.py +39 -41
  1138. transformers/models/qwen2_audio/processing_qwen2_audio.py +13 -42
  1139. transformers/models/qwen2_moe/configuration_qwen2_moe.py +42 -35
  1140. transformers/models/qwen2_moe/modeling_qwen2_moe.py +40 -43
  1141. transformers/models/qwen2_moe/modular_qwen2_moe.py +10 -13
  1142. transformers/models/qwen2_vl/configuration_qwen2_vl.py +28 -33
  1143. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +38 -40
  1144. transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py +12 -15
  1145. transformers/models/qwen2_vl/modeling_qwen2_vl.py +184 -141
  1146. transformers/models/qwen2_vl/processing_qwen2_vl.py +7 -44
  1147. transformers/models/qwen2_vl/video_processing_qwen2_vl.py +38 -18
  1148. transformers/models/qwen3/configuration_qwen3.py +34 -27
  1149. transformers/models/qwen3/modeling_qwen3.py +35 -38
  1150. transformers/models/qwen3/modular_qwen3.py +7 -9
  1151. transformers/models/qwen3_moe/configuration_qwen3_moe.py +45 -35
  1152. transformers/models/qwen3_moe/modeling_qwen3_moe.py +40 -43
  1153. transformers/models/qwen3_moe/modular_qwen3_moe.py +10 -13
  1154. transformers/models/qwen3_next/configuration_qwen3_next.py +47 -38
  1155. transformers/models/qwen3_next/modeling_qwen3_next.py +44 -47
  1156. transformers/models/qwen3_next/modular_qwen3_next.py +37 -38
  1157. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +139 -106
  1158. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +266 -206
  1159. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +228 -181
  1160. transformers/models/qwen3_omni_moe/processing_qwen3_omni_moe.py +40 -48
  1161. transformers/models/qwen3_vl/configuration_qwen3_vl.py +22 -24
  1162. transformers/models/qwen3_vl/modeling_qwen3_vl.py +185 -122
  1163. transformers/models/qwen3_vl/modular_qwen3_vl.py +153 -139
  1164. transformers/models/qwen3_vl/processing_qwen3_vl.py +6 -42
  1165. transformers/models/qwen3_vl/video_processing_qwen3_vl.py +10 -12
  1166. transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py +27 -30
  1167. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +249 -178
  1168. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +55 -42
  1169. transformers/models/rag/configuration_rag.py +6 -7
  1170. transformers/models/rag/modeling_rag.py +119 -121
  1171. transformers/models/rag/retrieval_rag.py +3 -5
  1172. transformers/models/rag/tokenization_rag.py +0 -50
  1173. transformers/models/recurrent_gemma/configuration_recurrent_gemma.py +29 -30
  1174. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +35 -39
  1175. transformers/models/reformer/configuration_reformer.py +7 -8
  1176. transformers/models/reformer/modeling_reformer.py +67 -68
  1177. transformers/models/reformer/tokenization_reformer.py +3 -6
  1178. transformers/models/regnet/configuration_regnet.py +0 -1
  1179. transformers/models/regnet/modeling_regnet.py +7 -9
  1180. transformers/models/rembert/configuration_rembert.py +8 -2
  1181. transformers/models/rembert/modeling_rembert.py +108 -132
  1182. transformers/models/rembert/tokenization_rembert.py +1 -4
  1183. transformers/models/resnet/configuration_resnet.py +2 -5
  1184. transformers/models/resnet/modeling_resnet.py +14 -15
  1185. transformers/models/roberta/configuration_roberta.py +11 -3
  1186. transformers/models/roberta/modeling_roberta.py +97 -99
  1187. transformers/models/roberta/modular_roberta.py +55 -58
  1188. transformers/models/roberta/tokenization_roberta.py +2 -5
  1189. transformers/models/roberta/tokenization_roberta_old.py +2 -4
  1190. transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py +11 -3
  1191. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +97 -99
  1192. transformers/models/roc_bert/configuration_roc_bert.py +8 -2
  1193. transformers/models/roc_bert/modeling_roc_bert.py +125 -162
  1194. transformers/models/roc_bert/tokenization_roc_bert.py +88 -94
  1195. transformers/models/roformer/configuration_roformer.py +13 -3
  1196. transformers/models/roformer/modeling_roformer.py +79 -95
  1197. transformers/models/roformer/tokenization_roformer.py +3 -6
  1198. transformers/models/roformer/tokenization_utils.py +0 -1
  1199. transformers/models/rt_detr/configuration_rt_detr.py +8 -50
  1200. transformers/models/rt_detr/configuration_rt_detr_resnet.py +2 -5
  1201. transformers/models/rt_detr/image_processing_rt_detr.py +54 -55
  1202. transformers/models/rt_detr/image_processing_rt_detr_fast.py +39 -26
  1203. transformers/models/rt_detr/modeling_rt_detr.py +643 -804
  1204. transformers/models/rt_detr/modeling_rt_detr_resnet.py +4 -7
  1205. transformers/models/rt_detr/modular_rt_detr.py +1522 -20
  1206. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +12 -58
  1207. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +384 -521
  1208. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +27 -70
  1209. transformers/models/rwkv/configuration_rwkv.py +2 -4
  1210. transformers/models/rwkv/modeling_rwkv.py +29 -54
  1211. transformers/models/sam/configuration_sam.py +2 -1
  1212. transformers/models/sam/image_processing_sam.py +59 -60
  1213. transformers/models/sam/image_processing_sam_fast.py +25 -26
  1214. transformers/models/sam/modeling_sam.py +46 -43
  1215. transformers/models/sam/processing_sam.py +39 -27
  1216. transformers/models/sam2/configuration_sam2.py +1 -2
  1217. transformers/models/sam2/image_processing_sam2_fast.py +14 -15
  1218. transformers/models/sam2/modeling_sam2.py +96 -94
  1219. transformers/models/sam2/modular_sam2.py +85 -94
  1220. transformers/models/sam2/processing_sam2.py +31 -47
  1221. transformers/models/sam2_video/configuration_sam2_video.py +0 -1
  1222. transformers/models/sam2_video/modeling_sam2_video.py +114 -116
  1223. transformers/models/sam2_video/modular_sam2_video.py +72 -89
  1224. transformers/models/sam2_video/processing_sam2_video.py +49 -66
  1225. transformers/models/sam2_video/video_processing_sam2_video.py +1 -4
  1226. transformers/models/sam3/configuration_sam3.py +0 -1
  1227. transformers/models/sam3/image_processing_sam3_fast.py +17 -20
  1228. transformers/models/sam3/modeling_sam3.py +94 -100
  1229. transformers/models/sam3/modular_sam3.py +3 -8
  1230. transformers/models/sam3/processing_sam3.py +37 -52
  1231. transformers/models/sam3_tracker/__init__.py +0 -1
  1232. transformers/models/sam3_tracker/configuration_sam3_tracker.py +1 -3
  1233. transformers/models/sam3_tracker/modeling_sam3_tracker.py +79 -80
  1234. transformers/models/sam3_tracker/modular_sam3_tracker.py +0 -2
  1235. transformers/models/sam3_tracker/processing_sam3_tracker.py +31 -48
  1236. transformers/models/sam3_tracker_video/__init__.py +0 -1
  1237. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +0 -1
  1238. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +115 -114
  1239. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +10 -24
  1240. transformers/models/sam3_tracker_video/processing_sam3_tracker_video.py +50 -66
  1241. transformers/models/sam3_video/configuration_sam3_video.py +0 -1
  1242. transformers/models/sam3_video/modeling_sam3_video.py +56 -45
  1243. transformers/models/sam3_video/processing_sam3_video.py +25 -45
  1244. transformers/models/sam_hq/__init__.py +1 -1
  1245. transformers/models/sam_hq/configuration_sam_hq.py +2 -1
  1246. transformers/models/sam_hq/modeling_sam_hq.py +52 -50
  1247. transformers/models/sam_hq/modular_sam_hq.py +23 -25
  1248. transformers/models/sam_hq/{processing_samhq.py → processing_sam_hq.py} +41 -29
  1249. transformers/models/seamless_m4t/configuration_seamless_m4t.py +8 -10
  1250. transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py +8 -11
  1251. transformers/models/seamless_m4t/modeling_seamless_m4t.py +180 -182
  1252. transformers/models/seamless_m4t/processing_seamless_m4t.py +18 -39
  1253. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +15 -20
  1254. transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py +8 -10
  1255. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +193 -195
  1256. transformers/models/seed_oss/configuration_seed_oss.py +30 -34
  1257. transformers/models/seed_oss/modeling_seed_oss.py +34 -36
  1258. transformers/models/seed_oss/modular_seed_oss.py +6 -7
  1259. transformers/models/segformer/configuration_segformer.py +0 -10
  1260. transformers/models/segformer/image_processing_segformer.py +39 -42
  1261. transformers/models/segformer/image_processing_segformer_fast.py +11 -12
  1262. transformers/models/segformer/modeling_segformer.py +28 -28
  1263. transformers/models/segformer/modular_segformer.py +8 -9
  1264. transformers/models/seggpt/configuration_seggpt.py +0 -1
  1265. transformers/models/seggpt/image_processing_seggpt.py +38 -41
  1266. transformers/models/seggpt/modeling_seggpt.py +48 -38
  1267. transformers/models/sew/configuration_sew.py +4 -2
  1268. transformers/models/sew/modeling_sew.py +42 -40
  1269. transformers/models/sew/modular_sew.py +12 -13
  1270. transformers/models/sew_d/configuration_sew_d.py +4 -2
  1271. transformers/models/sew_d/modeling_sew_d.py +32 -31
  1272. transformers/models/shieldgemma2/configuration_shieldgemma2.py +0 -1
  1273. transformers/models/shieldgemma2/modeling_shieldgemma2.py +19 -21
  1274. transformers/models/shieldgemma2/processing_shieldgemma2.py +3 -5
  1275. transformers/models/siglip/configuration_siglip.py +4 -2
  1276. transformers/models/siglip/image_processing_siglip.py +17 -20
  1277. transformers/models/siglip/image_processing_siglip_fast.py +0 -1
  1278. transformers/models/siglip/modeling_siglip.py +65 -110
  1279. transformers/models/siglip/processing_siglip.py +2 -14
  1280. transformers/models/siglip/tokenization_siglip.py +6 -7
  1281. transformers/models/siglip2/__init__.py +1 -0
  1282. transformers/models/siglip2/configuration_siglip2.py +4 -2
  1283. transformers/models/siglip2/image_processing_siglip2.py +15 -16
  1284. transformers/models/siglip2/image_processing_siglip2_fast.py +6 -7
  1285. transformers/models/siglip2/modeling_siglip2.py +89 -130
  1286. transformers/models/siglip2/modular_siglip2.py +95 -48
  1287. transformers/models/siglip2/processing_siglip2.py +2 -14
  1288. transformers/models/siglip2/tokenization_siglip2.py +95 -0
  1289. transformers/models/smollm3/configuration_smollm3.py +29 -32
  1290. transformers/models/smollm3/modeling_smollm3.py +35 -38
  1291. transformers/models/smollm3/modular_smollm3.py +36 -38
  1292. transformers/models/smolvlm/configuration_smolvlm.py +2 -4
  1293. transformers/models/smolvlm/image_processing_smolvlm.py +42 -43
  1294. transformers/models/smolvlm/image_processing_smolvlm_fast.py +41 -15
  1295. transformers/models/smolvlm/modeling_smolvlm.py +124 -96
  1296. transformers/models/smolvlm/modular_smolvlm.py +50 -39
  1297. transformers/models/smolvlm/processing_smolvlm.py +15 -76
  1298. transformers/models/smolvlm/video_processing_smolvlm.py +16 -17
  1299. transformers/models/solar_open/__init__.py +27 -0
  1300. transformers/models/solar_open/configuration_solar_open.py +184 -0
  1301. transformers/models/solar_open/modeling_solar_open.py +642 -0
  1302. transformers/models/solar_open/modular_solar_open.py +224 -0
  1303. transformers/models/speech_encoder_decoder/configuration_speech_encoder_decoder.py +0 -1
  1304. transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +26 -27
  1305. transformers/models/speech_to_text/configuration_speech_to_text.py +9 -9
  1306. transformers/models/speech_to_text/feature_extraction_speech_to_text.py +10 -13
  1307. transformers/models/speech_to_text/modeling_speech_to_text.py +55 -57
  1308. transformers/models/speech_to_text/processing_speech_to_text.py +4 -30
  1309. transformers/models/speech_to_text/tokenization_speech_to_text.py +5 -6
  1310. transformers/models/speecht5/configuration_speecht5.py +7 -9
  1311. transformers/models/speecht5/feature_extraction_speecht5.py +16 -37
  1312. transformers/models/speecht5/modeling_speecht5.py +172 -174
  1313. transformers/models/speecht5/number_normalizer.py +0 -1
  1314. transformers/models/speecht5/processing_speecht5.py +3 -37
  1315. transformers/models/speecht5/tokenization_speecht5.py +4 -5
  1316. transformers/models/splinter/configuration_splinter.py +6 -7
  1317. transformers/models/splinter/modeling_splinter.py +62 -59
  1318. transformers/models/splinter/tokenization_splinter.py +2 -4
  1319. transformers/models/squeezebert/configuration_squeezebert.py +14 -2
  1320. transformers/models/squeezebert/modeling_squeezebert.py +60 -62
  1321. transformers/models/squeezebert/tokenization_squeezebert.py +0 -1
  1322. transformers/models/stablelm/configuration_stablelm.py +28 -29
  1323. transformers/models/stablelm/modeling_stablelm.py +44 -47
  1324. transformers/models/starcoder2/configuration_starcoder2.py +30 -27
  1325. transformers/models/starcoder2/modeling_starcoder2.py +38 -41
  1326. transformers/models/starcoder2/modular_starcoder2.py +17 -19
  1327. transformers/models/superglue/configuration_superglue.py +7 -3
  1328. transformers/models/superglue/image_processing_superglue.py +15 -15
  1329. transformers/models/superglue/image_processing_superglue_fast.py +8 -8
  1330. transformers/models/superglue/modeling_superglue.py +41 -37
  1331. transformers/models/superpoint/image_processing_superpoint.py +15 -15
  1332. transformers/models/superpoint/image_processing_superpoint_fast.py +7 -9
  1333. transformers/models/superpoint/modeling_superpoint.py +17 -16
  1334. transformers/models/swiftformer/configuration_swiftformer.py +0 -1
  1335. transformers/models/swiftformer/modeling_swiftformer.py +12 -14
  1336. transformers/models/swin/configuration_swin.py +2 -5
  1337. transformers/models/swin/modeling_swin.py +69 -78
  1338. transformers/models/swin2sr/configuration_swin2sr.py +0 -1
  1339. transformers/models/swin2sr/image_processing_swin2sr.py +10 -13
  1340. transformers/models/swin2sr/image_processing_swin2sr_fast.py +4 -7
  1341. transformers/models/swin2sr/modeling_swin2sr.py +30 -30
  1342. transformers/models/swinv2/configuration_swinv2.py +2 -5
  1343. transformers/models/swinv2/modeling_swinv2.py +65 -74
  1344. transformers/models/switch_transformers/configuration_switch_transformers.py +11 -7
  1345. transformers/models/switch_transformers/modeling_switch_transformers.py +35 -36
  1346. transformers/models/switch_transformers/modular_switch_transformers.py +32 -33
  1347. transformers/models/t5/configuration_t5.py +9 -9
  1348. transformers/models/t5/modeling_t5.py +80 -85
  1349. transformers/models/t5/tokenization_t5.py +1 -3
  1350. transformers/models/t5gemma/configuration_t5gemma.py +43 -59
  1351. transformers/models/t5gemma/modeling_t5gemma.py +105 -108
  1352. transformers/models/t5gemma/modular_t5gemma.py +128 -142
  1353. transformers/models/t5gemma2/configuration_t5gemma2.py +86 -100
  1354. transformers/models/t5gemma2/modeling_t5gemma2.py +234 -194
  1355. transformers/models/t5gemma2/modular_t5gemma2.py +279 -264
  1356. transformers/models/table_transformer/configuration_table_transformer.py +18 -50
  1357. transformers/models/table_transformer/modeling_table_transformer.py +73 -101
  1358. transformers/models/tapas/configuration_tapas.py +12 -2
  1359. transformers/models/tapas/modeling_tapas.py +65 -67
  1360. transformers/models/tapas/tokenization_tapas.py +116 -153
  1361. transformers/models/textnet/configuration_textnet.py +4 -7
  1362. transformers/models/textnet/image_processing_textnet.py +22 -25
  1363. transformers/models/textnet/image_processing_textnet_fast.py +8 -9
  1364. transformers/models/textnet/modeling_textnet.py +28 -28
  1365. transformers/models/time_series_transformer/configuration_time_series_transformer.py +5 -8
  1366. transformers/models/time_series_transformer/modeling_time_series_transformer.py +82 -84
  1367. transformers/models/timesfm/configuration_timesfm.py +0 -1
  1368. transformers/models/timesfm/modeling_timesfm.py +22 -25
  1369. transformers/models/timesfm/modular_timesfm.py +21 -24
  1370. transformers/models/timesformer/configuration_timesformer.py +0 -1
  1371. transformers/models/timesformer/modeling_timesformer.py +13 -16
  1372. transformers/models/timm_backbone/configuration_timm_backbone.py +33 -8
  1373. transformers/models/timm_backbone/modeling_timm_backbone.py +25 -30
  1374. transformers/models/timm_wrapper/configuration_timm_wrapper.py +2 -3
  1375. transformers/models/timm_wrapper/image_processing_timm_wrapper.py +4 -5
  1376. transformers/models/timm_wrapper/modeling_timm_wrapper.py +22 -19
  1377. transformers/models/trocr/configuration_trocr.py +11 -8
  1378. transformers/models/trocr/modeling_trocr.py +42 -42
  1379. transformers/models/trocr/processing_trocr.py +5 -25
  1380. transformers/models/tvp/configuration_tvp.py +10 -36
  1381. transformers/models/tvp/image_processing_tvp.py +50 -52
  1382. transformers/models/tvp/image_processing_tvp_fast.py +15 -15
  1383. transformers/models/tvp/modeling_tvp.py +26 -28
  1384. transformers/models/tvp/processing_tvp.py +2 -14
  1385. transformers/models/udop/configuration_udop.py +16 -8
  1386. transformers/models/udop/modeling_udop.py +73 -72
  1387. transformers/models/udop/processing_udop.py +7 -26
  1388. transformers/models/udop/tokenization_udop.py +80 -93
  1389. transformers/models/umt5/configuration_umt5.py +8 -7
  1390. transformers/models/umt5/modeling_umt5.py +87 -84
  1391. transformers/models/unispeech/configuration_unispeech.py +4 -2
  1392. transformers/models/unispeech/modeling_unispeech.py +54 -53
  1393. transformers/models/unispeech/modular_unispeech.py +20 -22
  1394. transformers/models/unispeech_sat/configuration_unispeech_sat.py +4 -2
  1395. transformers/models/unispeech_sat/modeling_unispeech_sat.py +70 -69
  1396. transformers/models/unispeech_sat/modular_unispeech_sat.py +21 -23
  1397. transformers/models/univnet/feature_extraction_univnet.py +14 -14
  1398. transformers/models/univnet/modeling_univnet.py +7 -8
  1399. transformers/models/upernet/configuration_upernet.py +8 -36
  1400. transformers/models/upernet/modeling_upernet.py +11 -14
  1401. transformers/models/vaultgemma/__init__.py +0 -1
  1402. transformers/models/vaultgemma/configuration_vaultgemma.py +29 -33
  1403. transformers/models/vaultgemma/modeling_vaultgemma.py +38 -40
  1404. transformers/models/vaultgemma/modular_vaultgemma.py +29 -31
  1405. transformers/models/video_llama_3/configuration_video_llama_3.py +4 -0
  1406. transformers/models/video_llama_3/image_processing_video_llama_3.py +40 -40
  1407. transformers/models/video_llama_3/image_processing_video_llama_3_fast.py +12 -14
  1408. transformers/models/video_llama_3/modeling_video_llama_3.py +149 -112
  1409. transformers/models/video_llama_3/modular_video_llama_3.py +152 -150
  1410. transformers/models/video_llama_3/processing_video_llama_3.py +5 -39
  1411. transformers/models/video_llama_3/video_processing_video_llama_3.py +45 -24
  1412. transformers/models/video_llava/configuration_video_llava.py +4 -1
  1413. transformers/models/video_llava/image_processing_video_llava.py +35 -38
  1414. transformers/models/video_llava/modeling_video_llava.py +139 -143
  1415. transformers/models/video_llava/processing_video_llava.py +38 -78
  1416. transformers/models/video_llava/video_processing_video_llava.py +0 -1
  1417. transformers/models/videomae/configuration_videomae.py +0 -1
  1418. transformers/models/videomae/image_processing_videomae.py +31 -34
  1419. transformers/models/videomae/modeling_videomae.py +17 -20
  1420. transformers/models/videomae/video_processing_videomae.py +0 -1
  1421. transformers/models/vilt/configuration_vilt.py +4 -2
  1422. transformers/models/vilt/image_processing_vilt.py +29 -30
  1423. transformers/models/vilt/image_processing_vilt_fast.py +15 -16
  1424. transformers/models/vilt/modeling_vilt.py +103 -90
  1425. transformers/models/vilt/processing_vilt.py +2 -14
  1426. transformers/models/vipllava/configuration_vipllava.py +4 -1
  1427. transformers/models/vipllava/modeling_vipllava.py +92 -67
  1428. transformers/models/vipllava/modular_vipllava.py +78 -54
  1429. transformers/models/vision_encoder_decoder/configuration_vision_encoder_decoder.py +0 -1
  1430. transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +28 -27
  1431. transformers/models/vision_text_dual_encoder/configuration_vision_text_dual_encoder.py +0 -1
  1432. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +45 -41
  1433. transformers/models/vision_text_dual_encoder/processing_vision_text_dual_encoder.py +2 -16
  1434. transformers/models/visual_bert/configuration_visual_bert.py +6 -2
  1435. transformers/models/visual_bert/modeling_visual_bert.py +90 -92
  1436. transformers/models/vit/configuration_vit.py +2 -3
  1437. transformers/models/vit/image_processing_vit.py +19 -22
  1438. transformers/models/vit/image_processing_vit_fast.py +0 -1
  1439. transformers/models/vit/modeling_vit.py +20 -20
  1440. transformers/models/vit_mae/configuration_vit_mae.py +0 -1
  1441. transformers/models/vit_mae/modeling_vit_mae.py +32 -30
  1442. transformers/models/vit_msn/configuration_vit_msn.py +0 -1
  1443. transformers/models/vit_msn/modeling_vit_msn.py +21 -19
  1444. transformers/models/vitdet/configuration_vitdet.py +2 -5
  1445. transformers/models/vitdet/modeling_vitdet.py +14 -17
  1446. transformers/models/vitmatte/configuration_vitmatte.py +7 -39
  1447. transformers/models/vitmatte/image_processing_vitmatte.py +15 -18
  1448. transformers/models/vitmatte/image_processing_vitmatte_fast.py +16 -17
  1449. transformers/models/vitmatte/modeling_vitmatte.py +10 -12
  1450. transformers/models/vitpose/configuration_vitpose.py +7 -47
  1451. transformers/models/vitpose/image_processing_vitpose.py +24 -25
  1452. transformers/models/vitpose/image_processing_vitpose_fast.py +9 -10
  1453. transformers/models/vitpose/modeling_vitpose.py +15 -15
  1454. transformers/models/vitpose_backbone/configuration_vitpose_backbone.py +2 -5
  1455. transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +13 -16
  1456. transformers/models/vits/configuration_vits.py +4 -1
  1457. transformers/models/vits/modeling_vits.py +43 -42
  1458. transformers/models/vits/tokenization_vits.py +3 -4
  1459. transformers/models/vivit/configuration_vivit.py +0 -1
  1460. transformers/models/vivit/image_processing_vivit.py +36 -39
  1461. transformers/models/vivit/modeling_vivit.py +9 -11
  1462. transformers/models/vjepa2/__init__.py +0 -1
  1463. transformers/models/vjepa2/configuration_vjepa2.py +0 -1
  1464. transformers/models/vjepa2/modeling_vjepa2.py +39 -41
  1465. transformers/models/vjepa2/video_processing_vjepa2.py +0 -1
  1466. transformers/models/voxtral/__init__.py +0 -1
  1467. transformers/models/voxtral/configuration_voxtral.py +0 -2
  1468. transformers/models/voxtral/modeling_voxtral.py +41 -48
  1469. transformers/models/voxtral/modular_voxtral.py +35 -38
  1470. transformers/models/voxtral/processing_voxtral.py +25 -48
  1471. transformers/models/wav2vec2/configuration_wav2vec2.py +4 -2
  1472. transformers/models/wav2vec2/feature_extraction_wav2vec2.py +7 -10
  1473. transformers/models/wav2vec2/modeling_wav2vec2.py +74 -126
  1474. transformers/models/wav2vec2/processing_wav2vec2.py +6 -35
  1475. transformers/models/wav2vec2/tokenization_wav2vec2.py +20 -332
  1476. transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py +4 -2
  1477. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +49 -52
  1478. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +45 -48
  1479. transformers/models/wav2vec2_bert/processing_wav2vec2_bert.py +6 -35
  1480. transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py +4 -2
  1481. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +62 -65
  1482. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +15 -18
  1483. transformers/models/wav2vec2_phoneme/tokenization_wav2vec2_phoneme.py +16 -17
  1484. transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py +36 -55
  1485. transformers/models/wavlm/configuration_wavlm.py +4 -2
  1486. transformers/models/wavlm/modeling_wavlm.py +49 -49
  1487. transformers/models/wavlm/modular_wavlm.py +4 -5
  1488. transformers/models/whisper/configuration_whisper.py +6 -5
  1489. transformers/models/whisper/english_normalizer.py +3 -4
  1490. transformers/models/whisper/feature_extraction_whisper.py +9 -24
  1491. transformers/models/whisper/generation_whisper.py +26 -49
  1492. transformers/models/whisper/modeling_whisper.py +71 -73
  1493. transformers/models/whisper/processing_whisper.py +3 -20
  1494. transformers/models/whisper/tokenization_whisper.py +9 -30
  1495. transformers/models/x_clip/configuration_x_clip.py +4 -2
  1496. transformers/models/x_clip/modeling_x_clip.py +94 -96
  1497. transformers/models/x_clip/processing_x_clip.py +2 -14
  1498. transformers/models/xcodec/configuration_xcodec.py +4 -6
  1499. transformers/models/xcodec/modeling_xcodec.py +15 -17
  1500. transformers/models/xglm/configuration_xglm.py +9 -8
  1501. transformers/models/xglm/modeling_xglm.py +49 -55
  1502. transformers/models/xglm/tokenization_xglm.py +1 -4
  1503. transformers/models/xlm/configuration_xlm.py +10 -8
  1504. transformers/models/xlm/modeling_xlm.py +127 -131
  1505. transformers/models/xlm/tokenization_xlm.py +3 -5
  1506. transformers/models/xlm_roberta/configuration_xlm_roberta.py +11 -3
  1507. transformers/models/xlm_roberta/modeling_xlm_roberta.py +96 -98
  1508. transformers/models/xlm_roberta/modular_xlm_roberta.py +50 -53
  1509. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +1 -4
  1510. transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py +10 -2
  1511. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +97 -99
  1512. transformers/models/xlm_roberta_xl/modular_xlm_roberta_xl.py +67 -70
  1513. transformers/models/xlnet/configuration_xlnet.py +3 -12
  1514. transformers/models/xlnet/modeling_xlnet.py +149 -162
  1515. transformers/models/xlnet/tokenization_xlnet.py +1 -4
  1516. transformers/models/xlstm/configuration_xlstm.py +8 -12
  1517. transformers/models/xlstm/modeling_xlstm.py +61 -96
  1518. transformers/models/xmod/configuration_xmod.py +11 -3
  1519. transformers/models/xmod/modeling_xmod.py +111 -116
  1520. transformers/models/yolos/configuration_yolos.py +0 -1
  1521. transformers/models/yolos/image_processing_yolos.py +60 -62
  1522. transformers/models/yolos/image_processing_yolos_fast.py +42 -45
  1523. transformers/models/yolos/modeling_yolos.py +19 -21
  1524. transformers/models/yolos/modular_yolos.py +17 -19
  1525. transformers/models/yoso/configuration_yoso.py +8 -2
  1526. transformers/models/yoso/modeling_yoso.py +60 -62
  1527. transformers/models/youtu/__init__.py +27 -0
  1528. transformers/models/youtu/configuration_youtu.py +194 -0
  1529. transformers/models/youtu/modeling_youtu.py +619 -0
  1530. transformers/models/youtu/modular_youtu.py +254 -0
  1531. transformers/models/zamba/configuration_zamba.py +5 -8
  1532. transformers/models/zamba/modeling_zamba.py +93 -125
  1533. transformers/models/zamba2/configuration_zamba2.py +44 -50
  1534. transformers/models/zamba2/modeling_zamba2.py +137 -165
  1535. transformers/models/zamba2/modular_zamba2.py +79 -74
  1536. transformers/models/zoedepth/configuration_zoedepth.py +17 -41
  1537. transformers/models/zoedepth/image_processing_zoedepth.py +28 -29
  1538. transformers/models/zoedepth/image_processing_zoedepth_fast.py +20 -21
  1539. transformers/models/zoedepth/modeling_zoedepth.py +19 -19
  1540. transformers/pipelines/__init__.py +47 -106
  1541. transformers/pipelines/any_to_any.py +15 -23
  1542. transformers/pipelines/audio_utils.py +1 -2
  1543. transformers/pipelines/automatic_speech_recognition.py +0 -2
  1544. transformers/pipelines/base.py +13 -17
  1545. transformers/pipelines/image_text_to_text.py +1 -2
  1546. transformers/pipelines/question_answering.py +4 -43
  1547. transformers/pipelines/text_classification.py +1 -14
  1548. transformers/pipelines/text_to_audio.py +5 -1
  1549. transformers/pipelines/token_classification.py +1 -22
  1550. transformers/pipelines/video_classification.py +1 -9
  1551. transformers/pipelines/zero_shot_audio_classification.py +0 -1
  1552. transformers/pipelines/zero_shot_classification.py +0 -6
  1553. transformers/pipelines/zero_shot_image_classification.py +0 -7
  1554. transformers/processing_utils.py +128 -137
  1555. transformers/pytorch_utils.py +2 -26
  1556. transformers/quantizers/base.py +10 -0
  1557. transformers/quantizers/quantizer_compressed_tensors.py +7 -5
  1558. transformers/quantizers/quantizer_fbgemm_fp8.py +20 -23
  1559. transformers/quantizers/quantizer_finegrained_fp8.py +14 -20
  1560. transformers/quantizers/quantizer_mxfp4.py +1 -1
  1561. transformers/quantizers/quantizer_quark.py +0 -1
  1562. transformers/quantizers/quantizer_torchao.py +3 -19
  1563. transformers/safetensors_conversion.py +11 -4
  1564. transformers/testing_utils.py +6 -65
  1565. transformers/tokenization_mistral_common.py +563 -903
  1566. transformers/tokenization_python.py +6 -4
  1567. transformers/tokenization_utils_base.py +228 -341
  1568. transformers/tokenization_utils_sentencepiece.py +5 -6
  1569. transformers/tokenization_utils_tokenizers.py +36 -7
  1570. transformers/trainer.py +30 -41
  1571. transformers/trainer_jit_checkpoint.py +1 -2
  1572. transformers/trainer_seq2seq.py +1 -1
  1573. transformers/training_args.py +414 -420
  1574. transformers/utils/__init__.py +1 -4
  1575. transformers/utils/attention_visualizer.py +1 -1
  1576. transformers/utils/auto_docstring.py +567 -18
  1577. transformers/utils/backbone_utils.py +13 -373
  1578. transformers/utils/doc.py +4 -36
  1579. transformers/utils/dummy_pt_objects.py +0 -42
  1580. transformers/utils/generic.py +70 -34
  1581. transformers/utils/import_utils.py +72 -75
  1582. transformers/utils/loading_report.py +135 -107
  1583. transformers/utils/quantization_config.py +8 -31
  1584. transformers/video_processing_utils.py +24 -25
  1585. transformers/video_utils.py +21 -23
  1586. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/METADATA +120 -239
  1587. transformers-5.1.0.dist-info/RECORD +2092 -0
  1588. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/WHEEL +1 -1
  1589. transformers/pipelines/deprecated/text2text_generation.py +0 -408
  1590. transformers/pipelines/image_to_text.py +0 -229
  1591. transformers-5.0.0rc2.dist-info/RECORD +0 -2042
  1592. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/entry_points.txt +0 -0
  1593. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/licenses/LICENSE +0 -0
  1594. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/top_level.txt +0 -0
transformers/models/lw_detr/modeling_lw_detr.py (new file)
@@ -0,0 +1,1697 @@
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+ # This file was automatically generated from src/transformers/models/lw_detr/modular_lw_detr.py.
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
+ # the file from the modular. If any change should be done, please apply the change to the
+ # modular_lw_detr.py file directly. One of our CI enforces this.
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+ # Copyright 2026 The HuggingFace Inc. team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import collections.abc
+ import math
+ import warnings
+ from collections.abc import Callable
+ from dataclasses import dataclass
+ from typing import Any
+
+ import torch
+ import torch.nn.functional as F
+ from torch import Tensor, nn
+
+ from ... import initialization as init
+ from ...activations import ACT2CLS, ACT2FN
+ from ...backbone_utils import BackboneMixin
+ from ...integrations import use_kernel_forward_from_hub
+ from ...modeling_layers import GradientCheckpointingLayer
+ from ...modeling_outputs import BackboneOutput, BaseModelOutputWithCrossAttentions
+ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+ from ...processing_utils import Unpack
+ from ...pytorch_utils import meshgrid
+ from ...utils import ModelOutput, TransformersKwargs, auto_docstring, torch_compilable_check
+ from ...utils.generic import check_model_inputs
+ from .configuration_lw_detr import LwDetrConfig, LwDetrViTConfig
+
+
+ def eager_attention_forward(
+     module: nn.Module,
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     attention_mask: torch.Tensor | None,
+     scaling: float,
+     dropout: float = 0.0,
+     **kwargs: Unpack[TransformersKwargs],
+ ):
+     key_states = repeat_kv(key, module.num_key_value_groups)
+     value_states = repeat_kv(value, module.num_key_value_groups)
+
+     attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+     if attention_mask is not None:
+         causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+         attn_weights = attn_weights + causal_mask
+
+     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+     attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+     attn_output = torch.matmul(attn_weights, value_states)
+     attn_output = attn_output.transpose(1, 2).contiguous()
+
+     return attn_output, attn_weights
+
+
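For orientation: with no mask, zero dropout, and `num_key_value_groups == 1`, the eager path above computes the same result as PyTorch's fused `torch.nn.functional.scaled_dot_product_attention`, up to the final `(batch, seq, heads, head_dim)` transpose of the output and the fact that the fused kernel returns no attention weights. A minimal self-contained check — the shapes, the scaling value, and the tolerances are illustrative, not taken from the model, and the `scale=` argument assumes PyTorch 2.1 or newer:

import torch
import torch.nn.functional as F

batch, heads, seq, head_dim = 2, 4, 16, 32
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)
v = torch.randn(batch, heads, seq, head_dim)
scaling = head_dim**-0.5

# Eager path, as in eager_attention_forward (n_rep == 1, no mask, no dropout).
weights = torch.softmax(torch.matmul(q, k.transpose(2, 3)) * scaling, dim=-1)
eager_out = torch.matmul(weights, v).transpose(1, 2)

# Fused path; it returns (batch, heads, seq, head_dim), so transpose to compare.
fused_out = F.scaled_dot_product_attention(q, k, v, scale=scaling).transpose(1, 2)

torch.testing.assert_close(eager_out, fused_out, rtol=1e-4, atol=1e-4)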
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+     """
+     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+     num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+     """
+     batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+     if n_rep == 1:
+         return hidden_states
+     hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
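The docstring's equivalence claim can be verified directly; a short self-contained sketch (the tensor sizes here are made up for illustration):

import torch

x = torch.randn(2, 3, 5, 8)  # (batch, num_key_value_heads, seqlen, head_dim)
n_rep = 4

# expand + reshape, exactly as repeat_kv does it
out = x[:, :, None, :, :].expand(2, 3, n_rep, 5, 8).reshape(2, 3 * n_rep, 5, 8)

# reference: torch.repeat_interleave along the head dimension
torch.testing.assert_close(out, torch.repeat_interleave(x, repeats=n_rep, dim=1))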
+ class LwDetrViTSelfAttention(nn.Module):
+     def __init__(self, config: LwDetrViTConfig):
+         super().__init__()
+         if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
+             raise ValueError(
+                 f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
+                 f"heads {config.num_attention_heads}."
+             )
+
+         self.config = config
+         self.num_attention_heads = config.num_attention_heads
+         self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
+         self.all_head_size = self.num_attention_heads * self.attention_head_size
+         self.dropout_prob = config.dropout_prob
+         self.scaling = self.attention_head_size**-0.5
+         self.is_causal = False
+
+         self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+         self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
+         self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
+         self.num_key_value_groups = 1
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         **kwargs: Unpack[TransformersKwargs],
+     ) -> tuple[torch.Tensor, torch.Tensor]:
+         batch_size = hidden_states.shape[0]
+         new_shape = batch_size, -1, self.num_attention_heads, self.attention_head_size
+
+         key_layer = self.key(hidden_states).view(*new_shape).transpose(1, 2)
+         value_layer = self.value(hidden_states).view(*new_shape).transpose(1, 2)
+         query_layer = self.query(hidden_states).view(*new_shape).transpose(1, 2)
+
+         attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
+             self.config._attn_implementation, eager_attention_forward
+         )
+
+         context_layer, attention_probs = attention_interface(
+             self,
+             query_layer,
+             key_layer,
+             value_layer,
+             None,
+             is_causal=self.is_causal,
+             scaling=self.scaling,
+             dropout=0.0 if not self.training else self.dropout_prob,
+             **kwargs,
+         )
+
+         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
+         context_layer = context_layer.reshape(new_context_layer_shape)
+
+         return context_layer, attention_probs
+
+
139
+ class LwDetrViTAttention(nn.Module):
140
+ def __init__(self, config: LwDetrViTConfig):
141
+ """
142
+ Args:
143
+ config (`LwDetrViTConfig`):
144
+ Model configuration.
145
+ """
146
+ super().__init__()
147
+ self.attention = LwDetrViTSelfAttention(config)
148
+ self.output = nn.Linear(config.hidden_size, config.hidden_size)
149
+
150
+ def forward(
151
+ self,
152
+ hidden_states: torch.Tensor,
153
+ **kwargs: Unpack[TransformersKwargs],
154
+ ) -> torch.Tensor:
155
+ self_attn_output, _ = self.attention(hidden_states, **kwargs)
156
+ output = self.output(self_attn_output)
157
+ return output
158
+
159
+
160
+ class LwDetrViTMlp(nn.Module):
161
+ def __init__(self, config, in_features: int, hidden_features: int) -> None:
162
+ super().__init__()
163
+ self.fc1 = nn.Linear(in_features, hidden_features)
164
+ self.act = ACT2FN[config.hidden_act]
165
+ self.fc2 = nn.Linear(hidden_features, in_features)
166
+ self.drop = nn.Dropout(config.dropout_prob)
167
+
168
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
169
+ x = self.fc1(x)
170
+ x = self.act(x)
171
+ x = self.drop(x)
172
+ x = self.fc2(x)
173
+ x = self.drop(x)
174
+
175
+ return x
176
+
177
+
178
+ class LwDetrViTLayer(GradientCheckpointingLayer):
179
+ def __init__(
180
+ self,
181
+ config: LwDetrViTConfig,
182
+ layer_idx,
183
+ ) -> None:
184
+ super().__init__()
185
+
186
+ dim = config.hidden_size
187
+ self.attention = LwDetrViTAttention(config)
188
+ self.intermediate = LwDetrViTMlp(config=config, in_features=dim, hidden_features=int(dim * config.mlp_ratio))
189
+ self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
190
+ self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
191
+
192
+ self.gamma_1 = nn.Parameter(torch.Tensor(dim), requires_grad=True)
193
+ self.gamma_2 = nn.Parameter(torch.Tensor(dim), requires_grad=True)
194
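+ # per-channel layer-scale parameters for the attention and MLP branches; set to config.cae_init_values in _init_weights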
+
195
+ self.window = layer_idx in config.window_block_indices
196
+ self.num_windows = config.num_windows
197
+
198
+ def forward(
199
+ self,
200
+ hidden_states: torch.Tensor,
201
+ **kwargs: Unpack[TransformersKwargs],
202
+ ) -> torch.Tensor:
203
+ batch_size, seq_len, channels = hidden_states.shape
204
+ hidden_states_norm = self.layernorm_before(hidden_states)
205
+
206
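+ # global layers (self.window == False) merge the windows back into a single sequence so attention spans
+ # the whole image; window layers attend within each window only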
+ if not self.window:
207
+ hidden_states_norm = hidden_states_norm.reshape(
208
+ batch_size // self.num_windows, self.num_windows * seq_len, channels
209
+ )
210
+
211
+ attention_output = self.attention(hidden_states_norm, **kwargs)
212
+ attention_output = attention_output * self.gamma_1
213
+
214
+ if not self.window:
215
+ attention_output = attention_output.reshape(batch_size, seq_len, channels)
216
+
217
+ hidden_states = hidden_states + attention_output
218
+
219
+ layer_output = self.layernorm_after(hidden_states)
220
+ layer_output = self.intermediate(layer_output)
221
+ layer_output = layer_output * self.gamma_2
222
+
223
+ hidden_states = hidden_states + layer_output
224
+
225
+ return hidden_states
226
+
227
+
228
+ class LwDetrViTEncoder(nn.Module):
229
+ def __init__(self, config: LwDetrViTConfig) -> None:
230
+ super().__init__()
231
+ self.config = config
232
+ self.layer = nn.ModuleList([LwDetrViTLayer(config, i) for i in range(config.num_hidden_layers)])
233
+ self.gradient_checkpointing = False
234
+
235
+ def forward(
236
+ self,
237
+ hidden_states: torch.Tensor,
238
+ **kwargs: Unpack[TransformersKwargs],
239
+ ) -> list[torch.Tensor]:
240
+ list_hidden_states = [hidden_states]
241
+ for i, layer_module in enumerate(self.layer):
242
+ hidden_states = layer_module(hidden_states, **kwargs)
243
+ list_hidden_states.append(hidden_states)
244
+ return list_hidden_states
245
+
246
+
247
+ class LwDetrViTEmbeddings(nn.Module):
248
+ """
249
+ This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
250
+ `hidden_states` (patch embeddings) to be consumed by a Transformer.
251
+ """
252
+
253
+ def __init__(self, config):
254
+ super().__init__()
255
+ image_size, patch_size = config.pretrain_image_size, config.patch_size
256
+ num_channels, hidden_size = config.num_channels, config.hidden_size
257
+
258
+ image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
259
+ patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
260
+ num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
261
+ self.image_size = image_size
262
+ self.patch_size = patch_size
263
+ self.num_channels = num_channels
264
+ self.num_patches = num_patches
265
+
266
+ if config.use_absolute_position_embeddings:
267
+ # Initialize absolute positional embedding with pretrain image size.
268
+ num_positions = num_patches + 1
269
+ self.position_embeddings = nn.Parameter(torch.zeros(1, num_positions, config.hidden_size))
270
+ else:
271
+ self.position_embeddings = None
272
+
273
+ self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
274
+
275
+ def get_absolute_positions(self, abs_pos_embeddings, has_cls_token, height, width):
276
+ """
277
+ Calculate absolute positional embeddings. If needed, resize the embeddings to the target grid and drop the
278
+ cls_token entry from the original embeddings.
279
+
280
+ Args:
281
+ abs_pos_embeddings (`torch.Tensor`):
282
+ Absolute positional embeddings of shape (1, num_position, num_channels).
283
+ has_cls_token (`bool`):
284
+ If true, has 1 embedding in abs_pos_embeddings for cls token.
285
+ height (`int`):
286
+ Height of input image tokens.
287
+ width (`int`):
288
+ Width of input image tokens.
289
+
290
+ Returns:
291
+ Absolute positional embeddings after processing with shape (1, height, width, num_channels)
292
+ """
293
+ if has_cls_token:
294
+ abs_pos_embeddings = abs_pos_embeddings[:, 1:]
295
+ num_position = abs_pos_embeddings.shape[1]
296
+ size = int(math.sqrt(num_position)) # This is a constant and can be recorded as such in the ONNX export.
297
+ if size * size != num_position:
298
+ raise ValueError("Absolute position embeddings must be a square number.")
299
+
300
+ if torch.jit.is_tracing() or (size != height or size != width):
301
+ # nn.functional.interpolate is a noop in case size == height and size == width - we need to always capture this path with jit.trace.
302
+ new_abs_pos_embeddings = nn.functional.interpolate(
303
+ abs_pos_embeddings.reshape(1, size, size, -1).permute(0, 3, 1, 2),
304
+ size=(height, width),
305
+ mode="bicubic",
306
+ align_corners=False,
307
+ )
308
+
309
+ return new_abs_pos_embeddings.permute(0, 2, 3, 1)
310
+ else:
311
+ return abs_pos_embeddings.reshape(1, height, width, -1)
312
+
313
+ def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
314
+ num_channels = pixel_values.shape[1]
315
+ if num_channels != self.num_channels:
316
+ raise ValueError(
317
+ "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
318
+ f" Expected {self.num_channels} but got {num_channels}."
319
+ )
320
+ embeddings = self.projection(pixel_values)
321
+
322
+ if self.position_embeddings is not None:
323
+ # (batch_size, num_channels, height, width) -> (batch_size, height, width, num_channels)
324
+ embeddings = embeddings.permute(0, 2, 3, 1)
325
+ # add position embeddings
326
+ embeddings = embeddings + self.get_absolute_positions(
327
+ self.position_embeddings, True, embeddings.shape[1], embeddings.shape[2]
328
+ )
329
+ # (batch_size, height, width, num_channels) -> (batch_size, num_channels, height, width)
330
+ embeddings = embeddings.permute(0, 3, 1, 2)
331
+
332
+ return embeddings
333
+
334
+
335
+ @auto_docstring
336
+ class LwDetrViTPreTrainedModel(PreTrainedModel):
337
+ config: LwDetrViTConfig
338
+ base_model_prefix = "lw_detr_vit"
339
+ main_input_name = "pixel_values"
340
+ input_modalities = ("image",)
341
+ supports_gradient_checkpointing = True
342
+ _no_split_modules = ["LwDetrViTEmbeddings", "LwDetrViTLayer"]
343
+ _supports_sdpa = True
344
+ _supports_flash_attn = True
345
+ _supports_flex_attn = True
346
+ _supports_attention_backend = True
347
+ _can_record_outputs = {
348
+ "hidden_states": LwDetrViTLayer,
349
+ "attentions": LwDetrViTSelfAttention,
350
+ }
351
+
352
+ @torch.no_grad()
353
+ def _init_weights(self, module) -> None:
354
+ """Initialize the weights"""
355
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
356
+ init.trunc_normal_(module.weight, mean=0.0, std=self.config.initializer_range)
357
+ if module.bias is not None:
358
+ init.zeros_(module.bias)
359
+ elif isinstance(module, nn.LayerNorm):
360
+ init.zeros_(module.bias)
361
+ init.ones_(module.weight)
362
+ elif isinstance(module, LwDetrViTEmbeddings):
363
+ init.trunc_normal_(module.position_embeddings, mean=0.0, std=self.config.initializer_range)
364
+ if isinstance(module, LwDetrViTLayer):
365
+ nn.init.constant_(module.gamma_1, self.config.cae_init_values)
366
+ nn.init.constant_(module.gamma_2, self.config.cae_init_values)
367
+
368
+
369
+ @auto_docstring()
370
+ class LwDetrViTBackbone(BackboneMixin, LwDetrViTPreTrainedModel):
371
+ def __init__(self, config):
372
+ super().__init__(config)
373
+
374
+ self.embeddings = LwDetrViTEmbeddings(config)
375
+ self.encoder = LwDetrViTEncoder(config)
376
+ self.num_features = [config.hidden_size for _ in range(config.num_hidden_layers + 1)]
377
+
378
+ # initialize weights and apply final processing
379
+ self.post_init()
380
+
381
+ def get_input_embeddings(self) -> LwDetrViTEmbeddings:
382
+ return self.embeddings.projection
383
+
384
+ @check_model_inputs
385
+ @auto_docstring
386
+ def forward(self, pixel_values: torch.Tensor, **kwargs: Unpack[TransformersKwargs]) -> BackboneOutput:
387
+ r"""
388
+ Examples:
389
+
390
+ ```python
391
+ >>> from transformers import LwDetrViTConfig, LwDetrViTBackbone
392
+ >>> import torch
393
+
394
+ >>> config = LwDetrViTConfig()
395
+ >>> model = LwDetrViTBackbone(config)
396
+
397
+ >>> pixel_values = torch.randn(1, 3, 224, 224)
398
+
399
+ >>> with torch.no_grad():
400
+ ... outputs = model(pixel_values)
401
+
402
+ >>> feature_maps = outputs.feature_maps
403
+ >>> list(feature_maps[-1].shape)
404
+ [1, 768, 14, 14]
405
+ ```"""
406
+ embedding_output = self.embeddings(pixel_values)
407
+
408
+ batch_size, channels, height, width = embedding_output.shape
409
+ # (batch_size, channels, height, width) -> (batch_size, height, width, channels)
410
+ hidden_states = embedding_output.permute(0, 2, 3, 1)
411
+
412
+ window_height = height // self.config.num_windows_side
413
+ window_width = width // self.config.num_windows_side
414
+ # (batch_size, height, width, channels) -> (batch_size*num_windows_side**2, window_height*window_width, channels)
415
+ hidden_states = (
416
+ hidden_states.reshape(
417
+ batch_size,
418
+ self.config.num_windows_side,
419
+ window_height,
420
+ self.config.num_windows_side,
421
+ window_width,
422
+ channels,
423
+ )
424
+ .permute(0, 1, 3, 2, 4, 5)
425
+ .reshape(batch_size * self.config.num_windows_side**2, window_height * window_width, channels)
426
+ )
427
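+ # illustrative example: num_windows_side=2 on a 14x14 patch grid yields 4 windows of 7*7=49 tokens each,
+ # stacked along the batch dimension before entering the encoder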
+
428
+ hidden_states = self.encoder(hidden_states, **kwargs)
429
+
430
+ feature_maps = ()
431
+ for stage, hidden_state in zip(self.stage_names, hidden_states):
432
+ if stage in self.out_features:
433
+ hidden_state = (
434
+ hidden_state.reshape(
435
+ batch_size,
436
+ self.config.num_windows_side,
437
+ self.config.num_windows_side,
438
+ window_height,
439
+ window_width,
440
+ channels,
441
+ )
442
+ .permute(0, 5, 1, 3, 2, 4)
443
+ .reshape(batch_size, channels, height, width)
444
+ )
445
+ feature_maps += (hidden_state,)
446
+
447
+ return BackboneOutput(feature_maps=feature_maps)
448
+
449
+
450
+ class LwDetrConvNormLayer(nn.Module):
451
+ def __init__(
452
+ self,
453
+ config: LwDetrConfig,
454
+ in_channels: int,
455
+ out_channels: int,
456
+ kernel_size: int,
457
+ stride: int,
458
+ activation: str | None = None,
459
+ ):
460
+ super().__init__()
461
+ self.conv = nn.Conv2d(
462
+ in_channels,
463
+ out_channels,
464
+ kernel_size,
465
+ stride,
466
+ padding=kernel_size // 2,
467
+ bias=False,
468
+ )
469
+ self.norm = nn.BatchNorm2d(out_channels, config.batch_norm_eps)
470
+ self.activation = nn.Identity() if activation is None else ACT2CLS[activation]()
471
+
472
+ def forward(self, hidden_state):
473
+ hidden_state = self.conv(hidden_state)
474
+ hidden_state = self.norm(hidden_state)
475
+ hidden_state = self.activation(hidden_state)
476
+ return hidden_state
477
+
478
+
479
+ class LwDetrRepVggBlock(nn.Module):
480
+ def __init__(self, config: LwDetrConfig):
481
+ super().__init__()
482
+ hidden_channels = int(config.d_model * config.hidden_expansion)
483
+ self.conv1 = LwDetrConvNormLayer(
484
+ config, hidden_channels, hidden_channels, 3, 1, activation=config.activation_function
485
+ )
486
+ self.conv2 = LwDetrConvNormLayer(
487
+ config, hidden_channels, hidden_channels, 3, 1, activation=config.activation_function
488
+ )
489
+
490
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
491
+ y = self.conv1(x)
492
+ y = self.conv2(y)
493
+ return y
494
+
495
+
496
+ class LwDetrC2FLayer(nn.Module):
497
+ # Inspired by RTDetrCSPRepLayer
498
+ def __init__(self, config: LwDetrConfig, in_channels: int):
499
+ super().__init__()
500
+ num_blocks = config.c2f_num_blocks
501
+ activation = config.activation_function
502
+ out_channels = config.d_model
503
+
504
+ self.hidden_channels = int(out_channels * config.hidden_expansion)
505
+
506
+ conv1_out_channels = 2 * self.hidden_channels
507
+ self.conv1 = LwDetrConvNormLayer(config, in_channels, conv1_out_channels, 1, 1, activation=activation)
508
+
509
+ conv2_in_channels = (2 + num_blocks) * self.hidden_channels
510
+ self.conv2 = LwDetrConvNormLayer(config, conv2_in_channels, out_channels, 1, 1, activation=activation)
511
+
512
+ self.bottlenecks = nn.ModuleList(LwDetrRepVggBlock(config) for _ in range(num_blocks))
513
+
514
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
515
+ hidden_states = self.conv1(hidden_states)
516
+ all_hidden_states = list(hidden_states.split(self.hidden_channels, 1))
517
+ hidden_states = all_hidden_states[-1]
518
+
519
+ for bottleneck in self.bottlenecks:
520
+ hidden_states = bottleneck(hidden_states)
521
+ all_hidden_states.append(hidden_states)
522
+
523
+ hidden_states = torch.cat(all_hidden_states, 1)
524
+ hidden_states = self.conv2(hidden_states)
525
+ return hidden_states
526
+
527
+
528
+ class LwDetrLayerNorm(nn.LayerNorm):
529
+ r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
530
+ The ordering of the dimensions in the inputs. channels_last corresponds to inputs with shape (batch_size, height,
531
+ width, channels) while channels_first corresponds to inputs with shape (batch_size, channels, height, width).
532
+ """
533
+
534
+ def __init__(self, normalized_shape, *, eps=1e-6, data_format="channels_last", **kwargs):
535
+ super().__init__(normalized_shape, eps=eps, **kwargs)
536
+ if data_format not in ["channels_last", "channels_first"]:
537
+ raise NotImplementedError(f"Unsupported data format: {data_format}")
538
+ self.data_format = data_format
539
+
540
+ def forward(self, features: torch.Tensor) -> torch.Tensor:
541
+ """
542
+ Args:
543
+ features: Tensor of shape (batch_size, channels, height, width) OR (batch_size, height, width, channels)
544
+ """
545
+ if self.data_format == "channels_first":
546
+ features = features.permute(0, 2, 3, 1)
547
+ features = super().forward(features)
548
+ features = features.permute(0, 3, 1, 2)
549
+ else:
550
+ features = super().forward(features)
551
+ return features
552
+
553
+
554
+ class LwDetrSamplingLayer(nn.Module):
555
+ def __init__(self, config: LwDetrConfig, channel_size: int, scale: float):
556
+ super().__init__()
557
+
558
+ self.scale = scale
559
+ self.channel_size = channel_size
560
+
561
+ layers = []
562
+ if scale == 2.0:
563
+ if channel_size > 512:
564
+ layers.append(LwDetrConvNormLayer(config, channel_size, channel_size // 2, 1, 1, activation="relu"))
565
+ layers.append(nn.ConvTranspose2d(channel_size // 2, channel_size // 4, kernel_size=2, stride=2))
566
+ else:
567
+ layers.append(nn.ConvTranspose2d(channel_size, channel_size // 2, 2, 2))
568
+ elif scale == 0.5:
569
+ layers.append(LwDetrConvNormLayer(config, channel_size, channel_size, 3, 2, activation="relu"))
570
+ self.layers = nn.ModuleList(layers)
571
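+ # scale 2.0 upsamples with a transposed conv (reducing channels), scale 0.5 downsamples with a
+ # stride-2 conv, and any other scale (e.g. 1.0) leaves the feature map unchanged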
+
572
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
573
+ for layer in self.layers:
574
+ hidden_states = layer(hidden_states)
575
+ return hidden_states
576
+
577
+
578
+ class LwDetrScaleProjector(nn.Module):
579
+ def __init__(self, config: LwDetrConfig, scale: float):
580
+ super().__init__()
581
+
582
+ intermediate_dims = [config.backbone_config.hidden_size] * len(config.backbone_config.out_indices)
583
+ sampling_layers = []
584
+ for channel_size in intermediate_dims:
585
+ sampling_layers.append(LwDetrSamplingLayer(config, channel_size, scale))
586
+ self.sampling_layers = nn.ModuleList(sampling_layers)
587
+
588
+ intermediate_dim = intermediate_dims[-1]
589
+ if scale == 2.0:
590
+ if intermediate_dim > 512:
591
+ intermediate_dim = intermediate_dim // 4
592
+ else:
593
+ intermediate_dim = intermediate_dim // 2
594
+ projector_input_dim = intermediate_dim * len(intermediate_dims)
595
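+ # for scale 2.0 the sampling layers shrink channels (//4 if wider than 512, else //2); other scales keep
+ # intermediate_dim unchanged, and all levels are concatenated before the C2F projection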
+
596
+ self.projector_layer = LwDetrC2FLayer(config, projector_input_dim)
597
+ self.layer_norm = LwDetrLayerNorm(config.d_model, data_format="channels_first")
598
+
599
+ def forward(self, hidden_states_tuple: tuple[torch.Tensor]) -> torch.Tensor:
600
+ sampled_hidden_states = []
601
+ for sampling_layer, hidden_states in zip(self.sampling_layers, hidden_states_tuple):
602
+ hidden_states = sampling_layer(hidden_states)
603
+ sampled_hidden_states.append(hidden_states)
604
+ hidden_states = torch.cat(sampled_hidden_states, dim=1)
605
+ hidden_states = self.projector_layer(hidden_states)
606
+ hidden_states = self.layer_norm(hidden_states)
607
+ return hidden_states
608
+
609
+
610
+ class LwDetrMultiScaleProjector(nn.Module):
611
+ def __init__(self, config: LwDetrConfig):
612
+ super().__init__()
613
+
614
+ self.config = config
615
+ scale_factors = config.projector_scale_factors
616
+
617
+ self.scale_layers = nn.ModuleList([LwDetrScaleProjector(config, scale) for scale in scale_factors])
618
+
619
+ def forward(self, hidden_states: tuple[torch.Tensor]) -> list[torch.Tensor]:
620
+ output_hidden_states = []
621
+ for scale_layer in self.scale_layers:
622
+ output_hidden_states.append(scale_layer(hidden_states))
623
+ return output_hidden_states
624
+
625
+
626
+ class LwDetrConvEncoder(nn.Module):
627
+ def __init__(self, config: LwDetrConfig):
628
+ super().__init__()
629
+ self.backbone = LwDetrViTBackbone(config.backbone_config)
630
+ self.projector = LwDetrMultiScaleProjector(config)
631
+
632
+ def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor):
633
+ # send pixel_values through the model to get list of feature maps
634
+ features = self.backbone(pixel_values).feature_maps
635
+ features = self.projector(features)
636
+ out = []
637
+ for feature_map in features:
638
+ # downsample pixel_mask to match shape of corresponding feature_map
639
+ mask = nn.functional.interpolate(pixel_mask[None].float(), size=feature_map.shape[-2:]).to(torch.bool)[0]
640
+ out.append((feature_map, mask))
641
+ return out
642
+
643
+
644
+ class LwDetrAttention(nn.Module):
645
+ def __init__(self, config: LwDetrConfig, layer_idx: int):
646
+ super().__init__()
647
+ self.config = config
648
+ self.layer_idx = layer_idx
649
+ self.head_dim = getattr(config, "head_dim", config.d_model // config.decoder_self_attention_heads)
650
+ self.scaling = self.head_dim**-0.5
651
+ self.attention_dropout = config.attention_dropout
652
+ self.is_causal = False
653
+ self.num_key_value_groups = 1
654
+
655
+ self.q_proj = nn.Linear(
656
+ config.d_model, config.decoder_self_attention_heads * self.head_dim, bias=config.attention_bias
657
+ )
658
+ self.k_proj = nn.Linear(
659
+ config.d_model, config.decoder_self_attention_heads * self.head_dim, bias=config.attention_bias
660
+ )
661
+ self.v_proj = nn.Linear(
662
+ config.d_model, config.decoder_self_attention_heads * self.head_dim, bias=config.attention_bias
663
+ )
664
+ self.o_proj = nn.Linear(
665
+ config.decoder_self_attention_heads * self.head_dim, config.d_model, bias=config.attention_bias
666
+ )
667
+
668
+ def forward(
669
+ self,
670
+ hidden_states: torch.Tensor,
671
+ position_embeddings: torch.Tensor | None = None,
672
+ **kwargs: Unpack[TransformersKwargs],
673
+ ) -> tuple[torch.Tensor, torch.Tensor]:
674
+ batch_size, seq_len, _ = hidden_states.shape
675
+ input_shape = hidden_states.shape[:-1]
676
+ hidden_shape = (*input_shape, -1, self.head_dim)
677
+
678
+ hidden_states_original = hidden_states
679
+ if position_embeddings is not None:
680
+ hidden_states = hidden_states + position_embeddings
681
+
682
+ if self.training:
683
+ # during training, the group-DETR technique adds supervision by running multiple weight-sharing decoder groups at once for faster convergence
684
+ # at inference, only one decoder group is used
685
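+ # illustrative example: group_detr=2 with 600 total queries yields two (batch_size, 300, channels) groups
+ # stacked along the batch dimension, so queries in different groups never attend to each other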
+ hidden_states_original = torch.cat(
686
+ hidden_states_original.split(seq_len // self.config.group_detr, dim=1), dim=0
687
+ )
688
+ hidden_states = torch.cat(hidden_states.split(seq_len // self.config.group_detr, dim=1), dim=0)
689
+
690
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
691
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
692
+ value_states = self.v_proj(hidden_states_original).view(hidden_shape).transpose(1, 2)
693
+
694
+ attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
695
+ self.config._attn_implementation, eager_attention_forward
696
+ )
697
+
698
+ attn_output, attn_weights = attention_interface(
699
+ self,
700
+ query_states,
701
+ key_states,
702
+ value_states,
703
+ attention_mask=None,
704
+ dropout=0.0 if not self.training else self.attention_dropout,
705
+ scaling=self.scaling,
706
+ **kwargs,
707
+ )
708
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
709
+ attn_output = self.o_proj(attn_output)
710
+
711
+ if self.training:
712
+ attn_output = torch.cat(torch.split(attn_output, batch_size, dim=0), dim=1)
713
+
714
+ return attn_output, attn_weights
715
+
716
+
717
+ @use_kernel_forward_from_hub("MultiScaleDeformableAttention")
718
+ class MultiScaleDeformableAttention(nn.Module):
719
+ def forward(
720
+ self,
721
+ value: Tensor,
722
+ value_spatial_shapes: Tensor,
723
+ value_spatial_shapes_list: list[tuple],
724
+ level_start_index: Tensor,
725
+ sampling_locations: Tensor,
726
+ attention_weights: Tensor,
727
+ im2col_step: int,
728
+ ):
729
+ batch_size, _, num_heads, hidden_dim = value.shape
730
+ _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
731
+ value_list = value.split([height * width for height, width in value_spatial_shapes_list], dim=1)
732
+ sampling_grids = 2 * sampling_locations - 1
733
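+ # grid_sample expects sampling coordinates in [-1, 1], so the normalized [0, 1] locations are rescaled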
+ sampling_value_list = []
734
+ for level_id, (height, width) in enumerate(value_spatial_shapes_list):
735
+ # batch_size, height*width, num_heads, hidden_dim
736
+ # -> batch_size, height*width, num_heads*hidden_dim
737
+ # -> batch_size, num_heads*hidden_dim, height*width
738
+ # -> batch_size*num_heads, hidden_dim, height, width
739
+ value_l_ = (
740
+ value_list[level_id]
741
+ .flatten(2)
742
+ .transpose(1, 2)
743
+ .reshape(batch_size * num_heads, hidden_dim, height, width)
744
+ )
745
+ # batch_size, num_queries, num_heads, num_points, 2
746
+ # -> batch_size, num_heads, num_queries, num_points, 2
747
+ # -> batch_size*num_heads, num_queries, num_points, 2
748
+ sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1)
749
+ # batch_size*num_heads, hidden_dim, num_queries, num_points
750
+ sampling_value_l_ = nn.functional.grid_sample(
751
+ value_l_,
752
+ sampling_grid_l_,
753
+ mode="bilinear",
754
+ padding_mode="zeros",
755
+ align_corners=False,
756
+ )
757
+ sampling_value_list.append(sampling_value_l_)
758
+ # (batch_size, num_queries, num_heads, num_levels, num_points)
759
+ # -> (batch_size, num_heads, num_queries, num_levels, num_points)
760
+ # -> (batch_size*num_heads, 1, num_queries, num_levels*num_points)
761
+ attention_weights = attention_weights.transpose(1, 2).reshape(
762
+ batch_size * num_heads, 1, num_queries, num_levels * num_points
763
+ )
764
+ output = (
765
+ (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
766
+ .sum(-1)
767
+ .view(batch_size, num_heads * hidden_dim, num_queries)
768
+ )
769
+ return output.transpose(1, 2).contiguous()
770
+
771
+
772
+ class LwDetrMultiscaleDeformableAttention(nn.Module):
773
+ """
774
+ Multiscale deformable attention as proposed in Deformable DETR.
775
+ """
776
+
777
+ def __init__(self, config: LwDetrConfig, num_heads: int, n_points: int):
778
+ super().__init__()
779
+
780
+ self.attn = MultiScaleDeformableAttention()
781
+
782
+ if config.d_model % num_heads != 0:
783
+ raise ValueError(
784
+ f"embed_dim (d_model) must be divisible by num_heads, but got {config.d_model} and {num_heads}"
785
+ )
786
+ dim_per_head = config.d_model // num_heads
787
+ # check if dim_per_head is power of 2
788
+ if not ((dim_per_head & (dim_per_head - 1) == 0) and dim_per_head != 0):
789
+ warnings.warn(
790
+ "You'd better set embed_dim (d_model) in LwDetrMultiscaleDeformableAttention to make the"
791
+ " dimension of each attention head a power of 2 which is more efficient in the authors' CUDA"
792
+ " implementation."
793
+ )
794
+
795
+ self.im2col_step = 64
796
+
797
+ self.d_model = config.d_model
798
+ self.n_levels = config.num_feature_levels
799
+ self.n_heads = num_heads
800
+ self.n_points = n_points
801
+
802
+ self.sampling_offsets = nn.Linear(config.d_model, num_heads * self.n_levels * n_points * 2)
803
+ self.attention_weights = nn.Linear(config.d_model, num_heads * self.n_levels * n_points)
804
+ self.value_proj = nn.Linear(config.d_model, config.d_model)
805
+ self.output_proj = nn.Linear(config.d_model, config.d_model)
806
+
807
+ self.disable_custom_kernels = config.disable_custom_kernels
808
+
809
+ def forward(
810
+ self,
811
+ hidden_states: torch.Tensor,
812
+ attention_mask: torch.Tensor | None = None,
813
+ encoder_hidden_states=None,
814
+ encoder_attention_mask=None,
815
+ position_embeddings: torch.Tensor | None = None,
816
+ reference_points=None,
817
+ spatial_shapes=None,
818
+ spatial_shapes_list=None,
819
+ level_start_index=None,
820
+ **kwargs: Unpack[TransformersKwargs],
821
+ ) -> tuple[torch.Tensor, torch.Tensor]:
822
+ # add position embeddings to the hidden states before projecting to queries and keys
823
+ if position_embeddings is not None:
824
+ hidden_states = hidden_states + position_embeddings
825
+
826
+ batch_size, num_queries, _ = hidden_states.shape
827
+ batch_size, sequence_length, _ = encoder_hidden_states.shape
828
+ total_elements = sum(height * width for height, width in spatial_shapes_list)
829
+ torch_compilable_check(
830
+ total_elements == sequence_length,
831
+ "Make sure to align the spatial shapes with the sequence length of the encoder hidden states",
832
+ )
833
+
834
+ value = self.value_proj(encoder_hidden_states)
835
+ if attention_mask is not None:
836
+ # we invert the attention_mask
837
+ value = value.masked_fill(~attention_mask[..., None], float(0))
838
+ value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads)
839
+ sampling_offsets = self.sampling_offsets(hidden_states).view(
840
+ batch_size, num_queries, self.n_heads, self.n_levels, self.n_points, 2
841
+ )
842
+ attention_weights = self.attention_weights(hidden_states).view(
843
+ batch_size, num_queries, self.n_heads, self.n_levels * self.n_points
844
+ )
845
+ attention_weights = F.softmax(attention_weights, -1).view(
846
+ batch_size, num_queries, self.n_heads, self.n_levels, self.n_points
847
+ )
848
+ # batch_size, num_queries, n_heads, n_levels, n_points, 2
849
+ num_coordinates = reference_points.shape[-1]
850
+ if num_coordinates == 2:
851
+ offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
852
+ sampling_locations = (
853
+ reference_points[:, :, None, :, None, :]
854
+ + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
855
+ )
856
+ elif num_coordinates == 4:
857
+ sampling_locations = (
858
+ reference_points[:, :, None, :, None, :2]
859
+ + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
860
+ )
861
+ else:
862
+ raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")
863
+
864
+ output = self.attn(
865
+ value,
866
+ spatial_shapes,
867
+ spatial_shapes_list,
868
+ level_start_index,
869
+ sampling_locations,
870
+ attention_weights,
871
+ self.im2col_step,
872
+ )
873
+
874
+ output = self.output_proj(output)
875
+
876
+ return output, attention_weights
877
+
878
+
879
+ class LwDetrMLP(nn.Module):
880
+ def __init__(self, config: LwDetrConfig):
881
+ super().__init__()
882
+ self.dropout = config.dropout
883
+ self.activation_fn = ACT2FN[config.decoder_activation_function]
884
+ self.fc1 = nn.Linear(config.d_model, config.decoder_ffn_dim)
885
+ self.fc2 = nn.Linear(config.decoder_ffn_dim, config.d_model)
886
+
887
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
888
+ residual = hidden_states
889
+ hidden_states = self.fc1(hidden_states)
890
+ hidden_states = self.activation_fn(hidden_states)
891
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
892
+ hidden_states = self.fc2(hidden_states)
893
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
894
+ hidden_states = residual + hidden_states
895
+ return hidden_states
896
+
897
+
898
+ class LwDetrDecoderLayer(GradientCheckpointingLayer):
899
+ def __init__(self, config: LwDetrConfig, layer_idx: int):
900
+ nn.Module.__init__(self)
901
+
902
+ # self-attention
903
+ self.self_attn = LwDetrAttention(config, layer_idx=layer_idx)
904
+ self.dropout = config.dropout
905
+ self.activation_fn = ACT2FN[config.decoder_activation_function]
906
+ self.activation_dropout = config.activation_dropout
907
+ self.self_attn_layer_norm = nn.LayerNorm(config.d_model)
908
+
909
+ # cross-attention
910
+ self.cross_attn = LwDetrMultiscaleDeformableAttention(
911
+ config,
912
+ num_heads=config.decoder_cross_attention_heads,
913
+ n_points=config.decoder_n_points,
914
+ )
915
+ self.cross_attn_layer_norm = nn.LayerNorm(config.d_model)
916
+
917
+ # mlp
918
+ self.mlp = LwDetrMLP(config)
919
+ self.layer_norm = nn.LayerNorm(config.d_model)
920
+
921
+ def forward(
922
+ self,
923
+ hidden_states: torch.Tensor,
924
+ position_embeddings: torch.Tensor | None = None,
925
+ reference_points=None,
926
+ spatial_shapes=None,
927
+ spatial_shapes_list=None,
928
+ level_start_index=None,
929
+ encoder_hidden_states: torch.Tensor | None = None,
930
+ encoder_attention_mask: torch.Tensor | None = None,
931
+ **kwargs: Unpack[TransformersKwargs],
932
+ ):
933
+ self_attention_output, self_attn_weights = self.self_attn(
934
+ hidden_states, position_embeddings=position_embeddings, **kwargs
935
+ )
936
+
937
+ self_attention_output = nn.functional.dropout(self_attention_output, p=self.dropout, training=self.training)
938
+ hidden_states = hidden_states + self_attention_output
939
+ hidden_states = self.self_attn_layer_norm(hidden_states)
940
+
941
+ cross_attention_output, cross_attn_weights = self.cross_attn(
942
+ hidden_states=hidden_states,
943
+ attention_mask=encoder_attention_mask,
944
+ encoder_hidden_states=encoder_hidden_states,
945
+ encoder_attention_mask=encoder_attention_mask,
946
+ position_embeddings=position_embeddings,
947
+ reference_points=reference_points,
948
+ spatial_shapes=spatial_shapes,
949
+ spatial_shapes_list=spatial_shapes_list,
950
+ level_start_index=level_start_index,
951
+ **kwargs,
952
+ )
953
+ cross_attention_output = nn.functional.dropout(cross_attention_output, p=self.dropout, training=self.training)
954
+ hidden_states = hidden_states + cross_attention_output
955
+ hidden_states = self.cross_attn_layer_norm(hidden_states)
956
+
957
+ hidden_states = self.mlp(hidden_states)
958
+ hidden_states = self.layer_norm(hidden_states)
959
+
960
+ return hidden_states
961
+
962
+
963
+ @auto_docstring
964
+ class LwDetrPreTrainedModel(PreTrainedModel):
965
+ config: LwDetrConfig
966
+ base_model_prefix = "model"
967
+ main_input_name = "pixel_values"
968
+ _no_split_modules = [
969
+ r"LwDetrConvEncoder",
970
+ r"LwDetrDecoderLayer",
971
+ ]
972
+ _supports_sdpa = True
973
+ _supports_flash_attn = True
974
+ _supports_flex_attn = True
975
+ _supports_attention_backend = True
976
+ _can_record_outputs = {
977
+ "attentions": [LwDetrAttention, LwDetrMultiscaleDeformableAttention],
978
+ "hidden_states": [LwDetrDecoderLayer],
979
+ }
980
+
981
+ @torch.no_grad()
982
+ def _init_weights(self, module):
983
+ super()._init_weights(module)
984
+
985
+ if isinstance(module, LwDetrMultiscaleDeformableAttention):
986
+ init.constant_(module.sampling_offsets.weight, 0.0)
987
+ thetas = torch.arange(module.n_heads, dtype=torch.int64).float() * (2.0 * math.pi / module.n_heads)
988
+ grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
989
+ grid_init = (
990
+ (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
991
+ .view(module.n_heads, 1, 1, 2)
992
+ .repeat(1, module.n_levels, module.n_points, 1)
993
+ )
994
+ for i in range(module.n_points):
995
+ grid_init[:, :, i, :] *= i + 1
996
+
997
+ init.copy_(module.sampling_offsets.bias, grid_init.view(-1))
998
+ init.constant_(module.attention_weights.weight, 0.0)
999
+ init.constant_(module.attention_weights.bias, 0.0)
1000
+ init.xavier_uniform_(module.value_proj.weight)
1001
+ init.constant_(module.value_proj.bias, 0.0)
1002
+ init.xavier_uniform_(module.output_proj.weight)
1003
+ init.constant_(module.output_proj.bias, 0.0)
1004
+ if hasattr(module, "level_embed"):
1005
+ init.normal_(module.level_embed)
1006
+ if hasattr(module, "refpoint_embed") and module.refpoint_embed is not None:
1007
+ init.constant_(module.refpoint_embed.weight, 0)
1008
+ if hasattr(module, "class_embed") and module.class_embed is not None:
1009
+ prior_prob = 0.01
1010
+ bias_value = -math.log((1 - prior_prob) / prior_prob)
1011
+ init.constant_(module.class_embed.bias, bias_value)
1012
+ if hasattr(module, "bbox_embed") and module.bbox_embed is not None:
1013
+ init.constant_(module.bbox_embed.layers[-1].weight, 0)
1014
+ init.constant_(module.bbox_embed.layers[-1].bias, 0)
1015
+
1016
+
1017
+ @dataclass
1018
+ @auto_docstring(
1019
+ custom_intro="""
1020
+ Base class for outputs of the LwDetrDecoder. This class adds two attributes to
1021
+ BaseModelOutputWithCrossAttentions, namely:
1022
+ - a stacked tensor of intermediate decoder hidden states (i.e. the output of each decoder layer)
1023
+ - a stacked tensor of intermediate reference points.
1024
+ """
1025
+ )
1026
+ class LwDetrDecoderOutput(BaseModelOutputWithCrossAttentions):
1027
+ r"""
1028
+ cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
1029
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
1030
+ sequence_length)`. Attentions weights of the decoder's cross-attention layer, after the attention softmax,
1031
+ used to compute the weighted average in the cross-attention heads.
1032
+ intermediate_hidden_states (`torch.FloatTensor` of shape `(config.decoder_layers, batch_size, num_queries, hidden_size)`, *optional*, returned when `config.auxiliary_loss=True`):
1033
+ Intermediate decoder activations, i.e. the output of each decoder layer, each of them gone through a
1034
+ layernorm.
1035
+ intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
1036
+ Stacked intermediate reference points (reference points of each layer of the decoder).
1037
+ """
1038
+
1039
+ intermediate_hidden_states: torch.FloatTensor | None = None
1040
+
1041
+ intermediate_reference_points: torch.FloatTensor | None = None
1042
+
1043
+
1044
+ # function to generate sine positional embedding for 4d coordinates
1045
+ def gen_sine_position_embeddings(pos_tensor, hidden_size=256):
1046
+ """
1047
+ This function computes position embeddings using sine and cosine functions from the input positional tensor,
1048
+ which has a shape of (batch_size, num_queries, 4).
1049
+ The last dimension of `pos_tensor` represents the following coordinates:
1050
+ - 0: x-coord
1051
+ - 1: y-coord
1052
+ - 2: width
1053
+ - 3: height
1054
+
1055
+ The output shape is (batch_size, num_queries, hidden_size*2), i.e. 512 for the default hidden_size=256; the total embedding dimension
1056
+ is obtained by concatenating the sine and cosine values for each coordinate.
1057
+ """
1058
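+ # each coordinate is embedded into hidden_size // 2 sin/cos features (128 for hidden_size=256),
+ # and the four coordinates are concatenated in (y, x, w, h) order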
+ scale = 2 * math.pi
1059
+ dim = hidden_size // 2
1060
+ dim_t = torch.arange(dim, dtype=torch.float32, device=pos_tensor.device)
1061
+ dim_t = 10000 ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / dim)
1062
+ x_embed = pos_tensor[:, :, 0] * scale
1063
+ y_embed = pos_tensor[:, :, 1] * scale
1064
+ pos_x = x_embed[:, :, None] / dim_t
1065
+ pos_y = y_embed[:, :, None] / dim_t
1066
+ pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
1067
+ pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)
1068
+ if pos_tensor.size(-1) == 4:
1069
+ w_embed = pos_tensor[:, :, 2] * scale
1070
+ pos_w = w_embed[:, :, None] / dim_t
1071
+ pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2)
1072
+
1073
+ h_embed = pos_tensor[:, :, 3] * scale
1074
+ pos_h = h_embed[:, :, None] / dim_t
1075
+ pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2)
1076
+
1077
+ pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
1078
+ else:
1079
+ raise ValueError(f"Unknown pos_tensor shape(-1):{pos_tensor.size(-1)}")
1080
+ return pos.to(pos_tensor.dtype)
1081
+
1082
+
1083
+ class LwDetrDecoder(LwDetrPreTrainedModel):
1084
+ """
1085
+ Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`DeformableDetrDecoderLayer`].
1086
+
1087
+ The decoder updates the query embeddings through multiple self-attention and deformable cross-attention layers.
1088
+
1089
+ Some tweaks for LwDetr:
1090
+
1091
+ - it uses group detr technique at training for faster convergence.
1092
+
1093
+ Args:
1094
+ config: LwDetrConfig
1095
+ """
1096
+
1097
+ def __init__(self, config: LwDetrConfig):
1098
+ super().__init__(config)
1099
+ self.dropout = config.dropout
1100
+ self.layers = nn.ModuleList([LwDetrDecoderLayer(config, i) for i in range(config.decoder_layers)])
1101
+ self.layernorm = nn.LayerNorm(config.d_model)
1102
+
1103
+ self.gradient_checkpointing = False
1104
+
1105
+ self.ref_point_head = LwDetrMLPPredictionHead(2 * config.d_model, config.d_model, config.d_model, num_layers=2)
1106
+
1107
+ self.post_init()
1108
+
1109
+ def get_reference(self, reference_points, valid_ratios):
1110
+ # batch_size, num_queries, 4
1111
+ obj_center = reference_points[..., :4]
1112
+
1113
+ # batch_size, num_queries, num_levels, 4
1114
+ reference_points_inputs = obj_center[:, :, None] * torch.cat([valid_ratios, valid_ratios], -1)[:, None]
1115
+
1116
+ # batch_size, num_queries, d_model * 2
1117
+ query_sine_embed = gen_sine_position_embeddings(reference_points_inputs[:, :, 0, :], self.config.d_model)
1118
+
1119
+ # batch_size, num_queries, d_model
1120
+ query_pos = self.ref_point_head(query_sine_embed)
1121
+ return reference_points_inputs, query_pos
1122
+
1123
+ def forward(
1124
+ self,
1125
+ inputs_embeds: torch.Tensor | None = None,
1126
+ reference_points: torch.Tensor | None = None,
1127
+ spatial_shapes: torch.Tensor | None = None,
1128
+ spatial_shapes_list: list[tuple[int, int]] | None = None,
1129
+ level_start_index: torch.Tensor | None = None,
1130
+ valid_ratios: torch.Tensor | None = None,
1131
+ encoder_hidden_states: torch.Tensor | None = None,
1132
+ encoder_attention_mask: torch.Tensor | None = None,
1133
+ **kwargs: Unpack[TransformersKwargs],
1134
+ ):
1135
+ intermediate = ()
1136
+ intermediate_reference_points = (reference_points,)
1137
+
1138
+ if inputs_embeds is not None:
1139
+ hidden_states = inputs_embeds
1140
+
1141
+ reference_points_inputs, query_pos = self.get_reference(reference_points, valid_ratios)
1142
+
1143
+ for idx, decoder_layer in enumerate(self.layers):
1144
+ hidden_states = decoder_layer(
1145
+ hidden_states,
1146
+ encoder_hidden_states=encoder_hidden_states,
1147
+ encoder_attention_mask=encoder_attention_mask,
1148
+ position_embeddings=query_pos,
1149
+ reference_points=reference_points_inputs,
1150
+ spatial_shapes=spatial_shapes,
1151
+ spatial_shapes_list=spatial_shapes_list,
1152
+ level_start_index=level_start_index,
1153
+ **kwargs,
1154
+ )
1155
+ intermediate_hidden_states = self.layernorm(hidden_states)
1156
+ intermediate += (intermediate_hidden_states,)
1157
+
1158
+ intermediate = torch.stack(intermediate)
1159
+ last_hidden_state = intermediate[-1]
1160
+ intermediate_reference_points = torch.stack(intermediate_reference_points)
1161
+
1162
+ return LwDetrDecoderOutput(
1163
+ last_hidden_state=last_hidden_state,
1164
+ intermediate_hidden_states=intermediate,
1165
+ intermediate_reference_points=intermediate_reference_points,
1166
+ )
1167
+
1168
+
1169
+ @dataclass
1170
+ @auto_docstring(
1171
+ custom_intro="""
1172
+ Base class for outputs of the LwDetr backbone-decoder model.
1173
+ """
1174
+ )
1175
+ class LwDetrModelOutput(ModelOutput):
1176
+ r"""
1177
+ init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
1178
+ Initial reference points sent through the Transformer decoder.
1179
+ intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
1180
+ Stacked intermediate hidden states (output of each layer of the decoder).
1181
+ intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
1182
+ Stacked intermediate reference points (reference points of each layer of the decoder).
1183
+ enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
1184
+ Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are
1185
+ picked as region proposals in the first stage. Output of bounding box binary classification (i.e.
1186
+ foreground and background).
1187
+ enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
1188
+ Logits of predicted bounding boxes coordinates in the first stage.
1189
+ """
1190
+
1191
+ init_reference_points: torch.FloatTensor | None = None
1192
+ last_hidden_state: torch.FloatTensor | None = None
1193
+ intermediate_hidden_states: torch.FloatTensor | None = None
1194
+ intermediate_reference_points: torch.FloatTensor | None = None
1195
+ enc_outputs_class: torch.FloatTensor | None = None
1196
+ enc_outputs_coord_logits: torch.FloatTensor | None = None
1197
+
1198
+
1199
+ def refine_bboxes(reference_points, deltas):
1200
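+ # standard (dx, dy, dw, dh) box refinement: center offsets are scaled by the current box size
+ # and width/height are updated multiplicatively via exp()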
+ reference_points = reference_points.to(deltas.device)
1201
+ new_reference_points_cxcy = deltas[..., :2] * reference_points[..., 2:] + reference_points[..., :2]
1202
+ new_reference_points_wh = deltas[..., 2:].exp() * reference_points[..., 2:]
1203
+ new_reference_points = torch.cat((new_reference_points_cxcy, new_reference_points_wh), -1)
1204
+ return new_reference_points
1205
+
1206
+
1207
+ @auto_docstring(
1208
+ custom_intro="""
1209
+ The bare LW Detr Model (consisting of a backbone and decoder Transformer) outputting raw
1210
+ hidden-states without any specific head on top.
1211
+ """
1212
+ )
1213
+ class LwDetrModel(LwDetrPreTrainedModel):
1214
+ def __init__(self, config: LwDetrConfig):
1215
+ super().__init__(config)
1216
+
1217
+ # Create backbone + positional encoding
1218
+ self.backbone = LwDetrConvEncoder(config)
1219
+
1220
+ self.group_detr = config.group_detr
1221
+ self.num_queries = config.num_queries
1222
+ hidden_dim = config.d_model
1223
+ self.reference_point_embed = nn.Embedding(self.num_queries * self.group_detr, 4)
1224
+ self.query_feat = nn.Embedding(self.num_queries * self.group_detr, hidden_dim)
1225
+
1226
+ self.decoder = LwDetrDecoder(config)
1227
+
1228
+ self.enc_output = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(self.group_detr)])
1229
+ self.enc_output_norm = nn.ModuleList([nn.LayerNorm(hidden_dim) for _ in range(self.group_detr)])
1230
+ # Should normally be None and then instantiated in the ForObjectDetection class
1231
+ self.enc_out_bbox_embed = nn.ModuleList(
1232
+ [LwDetrMLPPredictionHead(config.d_model, config.d_model, 4, num_layers=3) for _ in range(self.group_detr)]
1233
+ )
1234
+ self.enc_out_class_embed = nn.ModuleList(
1235
+ [nn.Linear(config.d_model, config.num_labels) for _ in range(self.group_detr)]
1236
+ )
1237
+
1238
+ self.post_init()
1239
+
1240
+ def freeze_backbone(self):
1241
+ for name, param in self.backbone.model.named_parameters():
1242
+ param.requires_grad_(False)
1243
+
1244
+ def unfreeze_backbone(self):
1245
+ for name, param in self.backbone.model.named_parameters():
1246
+ param.requires_grad_(True)
1247
+
1248
+ def get_valid_ratio(self, mask, dtype=torch.float32):
1249
+ """Get the valid ratio of all feature maps."""
1250
+
1251
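+ # the valid ratio is the fraction of each padded feature map covered by the real image, per (width, height)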
+ _, height, width = mask.shape
1252
+ valid_height = torch.sum(mask[:, :, 0], 1)
1253
+ valid_width = torch.sum(mask[:, 0, :], 1)
1254
+ valid_ratio_height = valid_height.to(dtype) / height
1255
+ valid_ratio_width = valid_width.to(dtype) / width
1256
+ valid_ratio = torch.stack([valid_ratio_width, valid_ratio_height], -1)
1257
+ return valid_ratio
1258
+
1259
+ def get_proposal_pos_embed(self, proposals):
1260
+ """Get the position embedding of the proposals."""
1261
+
1262
+ num_pos_feats = self.config.d_model // 2
1263
+ temperature = 10000
1264
+ scale = 2 * math.pi
1265
+
1266
+ # Compute position embeddings in float32 to avoid overflow with large temperature values in fp16
1267
+ proposals_dtype = proposals.dtype
1268
+ dim_t = torch.arange(num_pos_feats, dtype=torch.float32, device=proposals.device)
1269
+ dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats)
1270
+ # batch_size, num_queries, 4
1271
+ proposals = proposals.sigmoid().to(torch.float32) * scale
1272
+ # batch_size, num_queries, 4, 128
1273
+ pos = proposals[:, :, :, None] / dim_t
1274
+ # batch_size, num_queries, 4, 64, 2 -> batch_size, num_queries, 512
1275
+ pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4).flatten(2)
1276
+ # Convert back to target dtype after all computations are done
1277
+ return pos.to(proposals_dtype)
1278
+
1279
+ def gen_encoder_output_proposals(self, enc_output, padding_mask, spatial_shapes):
1280
+ """Generate the encoder output proposals from encoded enc_output.
1281
+
1282
+ Args:
1283
+ enc_output (Tensor[batch_size, sequence_length, hidden_size]): Output of the encoder.
1284
+ padding_mask (Tensor[batch_size, sequence_length]): Padding mask for `enc_output`.
1285
+ spatial_shapes (list[tuple[int, int]]): Spatial shapes of the feature maps.
1286
+
1287
+ Returns:
1288
+ `tuple(torch.FloatTensor)`: A tuple of feature map and bbox prediction.
1289
+ - object_query (Tensor[batch_size, sequence_length, hidden_size]): Object query features. Later used to
1290
+ directly predict a bounding box (without the need for a decoder).
1291
+ - output_proposals (Tensor[batch_size, sequence_length, 4]): Normalized proposals, after an inverse
1292
+ sigmoid.
1293
+ """
1294
+ batch_size = enc_output.shape[0]
1295
+ proposals = []
1296
+ _cur = 0
1297
+ for level, (height, width) in enumerate(spatial_shapes):
1298
+ mask_flatten_ = padding_mask[:, _cur : (_cur + height * width)].view(batch_size, height, width, 1)
1299
+ valid_height = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
1300
+ valid_width = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
1301
+
1302
+ grid_y, grid_x = meshgrid(
1303
+ torch.linspace(
1304
+ 0,
1305
+ height - 1,
1306
+ height,
1307
+ dtype=enc_output.dtype,
1308
+ device=enc_output.device,
1309
+ ),
1310
+ torch.linspace(
1311
+ 0,
1312
+ width - 1,
1313
+ width,
1314
+ dtype=enc_output.dtype,
1315
+ device=enc_output.device,
1316
+ ),
1317
+ indexing="ij",
1318
+ )
1319
+ grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)
1320
+
1321
+ scale = torch.cat([valid_width.unsqueeze(-1), valid_height.unsqueeze(-1)], 1).view(batch_size, 1, 1, 2)
1322
+ grid = (grid.unsqueeze(0).expand(batch_size, -1, -1, -1) + 0.5) / scale
1323
+ width_height = torch.ones_like(grid) * 0.05 * (2.0**level)
1324
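+ # base proposal size is 5% of the image at level 0 and doubles at each coarser level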
+ proposal = torch.cat((grid, width_height), -1).view(batch_size, -1, 4)
1325
+ proposals.append(proposal)
1326
+ _cur += height * width
1327
+ output_proposals = torch.cat(proposals, 1)
1328
+ output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True)
1329
+ output_proposals = output_proposals.masked_fill(padding_mask.unsqueeze(-1), float("inf"))
1330
+ output_proposals = output_proposals.masked_fill(~output_proposals_valid, float("inf"))
1331
+
1332
+ # assign each pixel as an object query
1333
+ object_query = enc_output
1334
+ object_query = object_query.masked_fill(padding_mask.unsqueeze(-1), float(0))
1335
+ object_query = object_query.masked_fill(~output_proposals_valid, float(0))
1336
+ return object_query, output_proposals
1337
+
1338
+ @check_model_inputs
1339
+ @auto_docstring
1340
+ def forward(
1341
+ self,
1342
+ pixel_values: torch.FloatTensor = None,
1343
+ pixel_mask: torch.LongTensor | None = None,
1344
+ **kwargs: Unpack[TransformersKwargs],
1345
+ ) -> LwDetrModelOutput:
1346
+ r"""
1347
+ Examples:
1348
+
1349
+ ```python
1350
+ >>> from transformers import AutoImageProcessor, LwDetrModel
1351
+ >>> from PIL import Image
1352
+ >>> import httpx
1353
+ >>> from io import BytesIO
1354
+
1355
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1356
+ >>> with httpx.stream("GET", url) as response:
1357
+ ... image = Image.open(BytesIO(response.read()))
1358
+
1359
+ >>> image_processor = AutoImageProcessor.from_pretrained("AnnaZhang/lwdetr_small_60e_coco")
1360
+ >>> model = LwDetrModel.from_pretrained("AnnaZhang/lwdetr_small_60e_coco")
1361
+
1362
+ >>> inputs = image_processor(images=image, return_tensors="pt")
1363
+
1364
+ >>> outputs = model(**inputs)
1365
+
1366
+ >>> last_hidden_states = outputs.last_hidden_state
1367
+ >>> list(last_hidden_states.shape)
1368
+ [1, 300, 256]
1369
+ ```"""
1370
+ batch_size, num_channels, height, width = pixel_values.shape
1371
+ device = pixel_values.device
1372
+
1373
+ if pixel_mask is None:
1374
+ pixel_mask = torch.ones((batch_size, height, width), dtype=torch.long, device=device)
1375
+
1376
+ # Extract multi-scale feature maps, all with channel dimension `config.d_model` (cf Figure 4 in the paper)
1377
+ # First, send pixel_values + pixel_mask through the backbone to obtain the features
1378
+ # which is a list of tuples
1379
+ features = self.backbone(pixel_values, pixel_mask)
1380
+
1381
+ # The multi-scale projector inside the backbone has already mapped every feature map to d_model channels (256 by default)
1382
+ sources = []
1383
+ masks = []
1384
+ for level, (source, mask) in enumerate(features):
1385
+ sources.append(source)
1386
+ masks.append(mask)
1387
+ if mask is None:
1388
+ raise ValueError("No attention mask was provided")
1389
+
1390
+ if self.training:
1391
+ reference_points = self.reference_point_embed.weight
1392
+ query_feat = self.query_feat.weight
1393
+ else:
1394
+ # only use one group in inference
1395
+ reference_points = self.reference_point_embed.weight[: self.num_queries]
1396
+ query_feat = self.query_feat.weight[: self.num_queries]
1397
+
1398
+ # Prepare encoder inputs (by flattening)
1399
+ source_flatten = []
1400
+ mask_flatten = []
1401
+ spatial_shapes_list = []
1402
+ for source, mask in zip(sources, masks):
1403
+ batch_size, num_channels, height, width = source.shape
1404
+ spatial_shape = (height, width)
1405
+ spatial_shapes_list.append(spatial_shape)
1406
+ source = source.flatten(2).transpose(1, 2)
1407
+ mask = mask.flatten(1)
1408
+ source_flatten.append(source)
1409
+ mask_flatten.append(mask)
1410
+ source_flatten = torch.cat(source_flatten, 1)
1411
+ mask_flatten = torch.cat(mask_flatten, 1)
1412
+ spatial_shapes = torch.as_tensor(spatial_shapes_list, dtype=torch.long, device=source_flatten.device)
1413
+ level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
1414
+ valid_ratios = torch.stack([self.get_valid_ratio(m, dtype=source_flatten.dtype) for m in masks], 1)
1415
+
1416
+ target = query_feat.unsqueeze(0).expand(batch_size, -1, -1)
1417
+ reference_points = reference_points.unsqueeze(0).expand(batch_size, -1, -1)
1418
+
1419
+ object_query_embedding, output_proposals = self.gen_encoder_output_proposals(
1420
+ source_flatten, ~mask_flatten, spatial_shapes_list
1421
+ )
1422
+
1423
+ group_detr = self.group_detr if self.training else 1
1424
+ topk = self.num_queries
1425
+ topk_coords_logits = []
1426
+ topk_coords_logits_undetach = []
1427
+ object_query_undetach = []
1428
+
1429
+ for group_id in range(group_detr):
1430
+ group_object_query = self.enc_output[group_id](object_query_embedding)
1431
+ group_object_query = self.enc_output_norm[group_id](group_object_query)
1432
+
1433
+ group_enc_outputs_class = self.enc_out_class_embed[group_id](group_object_query)
1434
+ group_delta_bbox = self.enc_out_bbox_embed[group_id](group_object_query)
1435
+ group_enc_outputs_coord = refine_bboxes(output_proposals, group_delta_bbox)
1436
+
1437
+ group_topk_proposals = torch.topk(group_enc_outputs_class.max(-1)[0], topk, dim=1)[1]
1438
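+ # keep the top-k proposals ranked by their best class score across all labels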
+ group_topk_coords_logits_undetach = torch.gather(
1439
+ group_enc_outputs_coord,
1440
+ 1,
1441
+ group_topk_proposals.unsqueeze(-1).repeat(1, 1, 4),
1442
+ )
1443
+ group_topk_coords_logits = group_topk_coords_logits_undetach.detach()
1444
+ group_object_query_undetach = torch.gather(
1445
+ group_object_query, 1, group_topk_proposals.unsqueeze(-1).repeat(1, 1, self.config.d_model)
1446
+ )
1447
+
1448
+ topk_coords_logits.append(group_topk_coords_logits)
1449
+ topk_coords_logits_undetach.append(group_topk_coords_logits_undetach)
1450
+ object_query_undetach.append(group_object_query_undetach)
1451
+
1452
+ topk_coords_logits = torch.cat(topk_coords_logits, 1)
1453
+ topk_coords_logits_undetach = torch.cat(topk_coords_logits_undetach, 1)
1454
+ object_query_undetach = torch.cat(object_query_undetach, 1)
1455
+
1456
+ enc_outputs_class = object_query_undetach
1457
+ enc_outputs_coord_logits = topk_coords_logits
1458
+
1459
+ reference_points = refine_bboxes(topk_coords_logits_undetach, reference_points)
1460
+
+         init_reference_points = reference_points
+         decoder_outputs = self.decoder(
+             inputs_embeds=target,
+             reference_points=reference_points,
+             spatial_shapes=spatial_shapes,
+             spatial_shapes_list=spatial_shapes_list,
+             level_start_index=level_start_index,
+             valid_ratios=valid_ratios,
+             encoder_hidden_states=source_flatten,
+             encoder_attention_mask=mask_flatten,
+             **kwargs,
+         )
+
+         return LwDetrModelOutput(
+             init_reference_points=init_reference_points,
+             last_hidden_state=decoder_outputs.last_hidden_state,
+             intermediate_hidden_states=decoder_outputs.intermediate_hidden_states,
+             intermediate_reference_points=decoder_outputs.intermediate_reference_points,
+             enc_outputs_class=enc_outputs_class,
+             enc_outputs_coord_logits=enc_outputs_coord_logits,
+         )
+
+
+ class LwDetrMLPPredictionHead(nn.Module):
+     """
+     Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
+     height and width of a bounding box w.r.t. an image.
+     """
+
+     def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
+         super().__init__()
+         self.num_layers = num_layers
+         h = [hidden_dim] * (num_layers - 1)
+         self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
+
+     def forward(self, x):
+         for i, layer in enumerate(self.layers):
+             x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
+         return x
+
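+ # Illustrative usage of the prediction head (dimensions assumed, e.g. d_model=256; ReLU
+ # is applied after every layer except the last):
+ #   head = LwDetrMLPPredictionHead(input_dim=256, hidden_dim=256, output_dim=4, num_layers=3)
+ #   deltas = head(torch.randn(2, 100, 256))  # -> torch.Size([2, 100, 4])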
+
+ @dataclass
+ @auto_docstring(
+     custom_intro="""
+     Output type of [`LwDetrForObjectDetection`].
+     """
+ )
+ class LwDetrObjectDetectionOutput(ModelOutput):
+     r"""
+     loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided):
+         Total loss as a linear combination of a negative log-likelihood (cross-entropy) for class prediction and a
+         bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
+         scale-invariant IoU loss.
+     loss_dict (`Dict`, *optional*):
+         A dictionary containing the individual losses. Useful for logging.
+     logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes)`):
+         Classification logits for all queries.
+     pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
+         Normalized box coordinates for all queries, represented as (center_x, center_y, width, height). These
+         values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
+         possible padding). You can use [`~DeformableDetrProcessor.post_process_object_detection`] to retrieve the
+         unnormalized bounding boxes.
+     auxiliary_outputs (`list[Dict]`, *optional*):
+         Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
+         and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
+         `pred_boxes`) for each decoder layer.
+     init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
+         Initial reference points sent through the Transformer decoder.
+     intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
+         Stacked intermediate hidden states (output of each layer of the decoder).
+     intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
+         Stacked intermediate reference points (reference points of each layer of the decoder).
+     enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*):
+         Classification logits of the top `config.num_queries` scoring proposals picked per query group as region
+         proposals in the first (encoder) stage.
+     enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
+         Logits of the bounding box coordinates predicted for those proposals in the first stage.
+     """
+
+     loss: torch.FloatTensor | None = None
+     loss_dict: dict | None = None
+     logits: torch.FloatTensor | None = None
+     pred_boxes: torch.FloatTensor | None = None
+     auxiliary_outputs: list[dict] | None = None
+     init_reference_points: torch.FloatTensor | None = None
+     last_hidden_state: torch.FloatTensor | None = None
+     intermediate_hidden_states: torch.FloatTensor | None = None
+     intermediate_reference_points: torch.FloatTensor | None = None
+     enc_outputs_class: Any = None
+     enc_outputs_coord_logits: torch.FloatTensor | None = None
+
+
+ @auto_docstring(
+     custom_intro="""
+     LW DETR Model (consisting of a backbone and decoder Transformer) with object detection heads on
+     top, for tasks such as COCO detection.
+     """
+ )
+ class LwDetrForObjectDetection(LwDetrPreTrainedModel):
+     # When using clones, all layers > 0 will be clones, but layer 0 *is* required
+     # We can't initialize the model on meta device as some weights are modified during the initialization
+     _no_split_modules = None
+     _tied_weights_keys = None
+
+     def __init__(self, config: LwDetrConfig):
+         super().__init__(config)
+         self.model = LwDetrModel(config)
+         self.class_embed = nn.Linear(config.d_model, config.num_labels)
+         self.bbox_embed = LwDetrMLPPredictionHead(config.d_model, config.d_model, 4, num_layers=3)
+
+         self.post_init()
+
+     @check_model_inputs
+     @auto_docstring
+     def forward(
+         self,
+         pixel_values: torch.FloatTensor = None,
+         pixel_mask: torch.LongTensor | None = None,
+         labels: list[dict] | None = None,
+         **kwargs: Unpack[TransformersKwargs],
+     ) -> LwDetrObjectDetectionOutput:
+         r"""
+         labels (`list[Dict]` of len `(batch_size,)`, *optional*):
+             Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
+             following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch
+             respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes
+             in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`.
+
+         Examples:
+
+         ```python
+         >>> from transformers import AutoImageProcessor, LwDetrForObjectDetection
+         >>> from PIL import Image
+         >>> import httpx
+         >>> import torch
+         >>> from io import BytesIO
+
+         >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+         >>> with httpx.stream("GET", url) as response:
+         ...     image = Image.open(BytesIO(response.read()))
+
+         >>> image_processor = AutoImageProcessor.from_pretrained("AnnaZhang/lwdetr_small_60e_coco")
+         >>> model = LwDetrForObjectDetection.from_pretrained("AnnaZhang/lwdetr_small_60e_coco")
+
+         >>> inputs = image_processor(images=image, return_tensors="pt")
+         >>> outputs = model(**inputs)
+
+         >>> # convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
+         >>> target_sizes = torch.tensor([image.size[::-1]])
+         >>> results = image_processor.post_process_object_detection(outputs, threshold=0.5, target_sizes=target_sizes)[
+         ...     0
+         ... ]
+         >>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
+         ...     box = [round(i, 2) for i in box.tolist()]
+         ...     print(
+         ...         f"Detected {model.config.id2label[label.item()]} with confidence "
+         ...         f"{round(score.item(), 3)} at location {box}"
+         ...     )
+         Detected cat with confidence 0.8 at location [16.5, 52.84, 318.25, 470.78]
+         Detected cat with confidence 0.789 at location [342.19, 24.3, 640.02, 372.25]
+         Detected remote with confidence 0.633 at location [40.79, 72.78, 176.76, 117.25]
+         ```"""
+         outputs = self.model(
+             pixel_values,
+             pixel_mask=pixel_mask,
+             **kwargs,
+         )
+
+         last_hidden_states = outputs.last_hidden_state
+         intermediate_reference_points = outputs.intermediate_reference_points
+         enc_outputs_class_logits = outputs.enc_outputs_class
+         enc_outputs_boxes_logits = outputs.enc_outputs_coord_logits
+
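+         # Final predictions: the class head scores every decoder query, and the bbox head
+         # predicts deltas that refine the last set of decoder reference points.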
+         logits = self.class_embed(last_hidden_states)
+         pred_boxes_delta = self.bbox_embed(last_hidden_states)
+         pred_boxes = refine_bboxes(intermediate_reference_points[-1], pred_boxes_delta)
+
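+         # The base model returns the gathered encoder query embeddings as
+         # `enc_outputs_class`; split them per Group-DETR group and apply each group's
+         # classification head.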
+         enc_outputs_class_logits_list = enc_outputs_class_logits.split(self.config.num_queries, dim=1)
+         pred_class = []
+         group_detr = self.config.group_detr if self.training else 1
+         for group_index in range(group_detr):
+             group_pred_class = self.model.enc_out_class_embed[group_index](enc_outputs_class_logits_list[group_index])
+             pred_class.append(group_pred_class)
+         enc_outputs_class_logits = torch.cat(pred_class, dim=1)
+
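+         # Losses are computed only when labels are provided; with `config.auxiliary_loss`
+         # enabled, per-decoder-layer class/box predictions are supervised as well.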
+         loss, loss_dict, auxiliary_outputs = None, None, None
+         if labels is not None:
+             outputs_class, outputs_coord = None, None
+             if self.config.auxiliary_loss:
+                 intermediate_hidden_states = outputs.intermediate_hidden_states
+                 outputs_coord_delta = self.bbox_embed(intermediate_hidden_states)
+                 outputs_coord = refine_bboxes(intermediate_reference_points, outputs_coord_delta)
+                 outputs_class = self.class_embed(intermediate_hidden_states)
+
+             loss, loss_dict, auxiliary_outputs = self.loss_function(
+                 logits,
+                 labels,
+                 self.device,
+                 pred_boxes,
+                 self.config,
+                 outputs_class,
+                 outputs_coord,
+                 enc_outputs_class_logits,
+                 enc_outputs_boxes_logits,
+             )
+
+         return LwDetrObjectDetectionOutput(
+             loss=loss,
+             loss_dict=loss_dict,
+             logits=logits,
+             pred_boxes=pred_boxes,
+             auxiliary_outputs=auxiliary_outputs,
+             last_hidden_state=outputs.last_hidden_state,
+             intermediate_hidden_states=outputs.intermediate_hidden_states,
+             intermediate_reference_points=outputs.intermediate_reference_points,
+             init_reference_points=outputs.init_reference_points,
+             enc_outputs_class=enc_outputs_class_logits,
+             enc_outputs_coord_logits=enc_outputs_boxes_logits,
+         )
+
+
+ __all__ = [
+     "LwDetrPreTrainedModel",
+     "LwDetrModel",
+     "LwDetrForObjectDetection",
+     "LwDetrViTPreTrainedModel",
+     "LwDetrViTBackbone",
+ ]