transformers-5.0.0rc2-py3-none-any.whl → transformers-5.1.0-py3-none-any.whl

This diff compares the contents of two publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
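The listing below was produced by comparing the two wheels file-by-file. As a rough illustration only, a local approximation of such a comparison might look like the sketch below; the wheel filenames and the size-based "modified" heuristic are assumptions for the example, not the registry's actual tooling (which diffs file contents to produce the +/- line counts).

```python
# Minimal sketch: list files that differ between two wheels (zip archives).
# Assumes both .whl files were downloaded first, e.g. with
#   pip download transformers==5.1.0 --no-deps
import zipfile


def wheel_sizes(path: str) -> dict[str, int]:
    """Map each file inside a wheel to its uncompressed size."""
    with zipfile.ZipFile(path) as whl:
        return {info.filename: info.file_size for info in whl.infolist()}


old = wheel_sizes("transformers-5.0.0rc2-py3-none-any.whl")
new = wheel_sizes("transformers-5.1.0-py3-none-any.whl")

for name in sorted(set(old) | set(new)):
    if old.get(name) != new.get(name):
        # A file present in only one wheel was added or removed; a size
        # change is used here as a crude stand-in for "modified".
        status = "added" if name not in old else "removed" if name not in new else "modified"
        print(f"{status:>8}  {name}")
```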
Files changed (1594)
  1. transformers/__init__.py +11 -37
  2. transformers/activations.py +2 -2
  3. transformers/audio_utils.py +32 -32
  4. transformers/backbone_utils.py +326 -0
  5. transformers/cache_utils.py +26 -126
  6. transformers/cli/chat.py +3 -3
  7. transformers/cli/serve.py +13 -10
  8. transformers/cli/transformers.py +2 -1
  9. transformers/configuration_utils.py +22 -92
  10. transformers/conversion_mapping.py +150 -26
  11. transformers/convert_slow_tokenizer.py +9 -12
  12. transformers/core_model_loading.py +217 -129
  13. transformers/data/processors/glue.py +0 -1
  14. transformers/data/processors/utils.py +0 -1
  15. transformers/data/processors/xnli.py +0 -1
  16. transformers/dependency_versions_check.py +0 -1
  17. transformers/dependency_versions_table.py +10 -11
  18. transformers/distributed/configuration_utils.py +1 -2
  19. transformers/dynamic_module_utils.py +23 -23
  20. transformers/feature_extraction_sequence_utils.py +19 -23
  21. transformers/feature_extraction_utils.py +14 -14
  22. transformers/file_utils.py +0 -2
  23. transformers/generation/candidate_generator.py +2 -4
  24. transformers/generation/configuration_utils.py +54 -39
  25. transformers/generation/continuous_batching/__init__.py +0 -1
  26. transformers/generation/continuous_batching/cache.py +74 -44
  27. transformers/generation/continuous_batching/cache_manager.py +28 -28
  28. transformers/generation/continuous_batching/continuous_api.py +133 -414
  29. transformers/generation/continuous_batching/input_ouputs.py +464 -0
  30. transformers/generation/continuous_batching/requests.py +77 -19
  31. transformers/generation/continuous_batching/scheduler.py +154 -104
  32. transformers/generation/logits_process.py +10 -133
  33. transformers/generation/stopping_criteria.py +1 -2
  34. transformers/generation/streamers.py +0 -1
  35. transformers/generation/utils.py +91 -121
  36. transformers/generation/watermarking.py +2 -3
  37. transformers/hf_argparser.py +9 -13
  38. transformers/hyperparameter_search.py +1 -2
  39. transformers/image_processing_base.py +9 -9
  40. transformers/image_processing_utils.py +11 -15
  41. transformers/image_processing_utils_fast.py +70 -71
  42. transformers/image_transforms.py +73 -42
  43. transformers/image_utils.py +30 -37
  44. transformers/initialization.py +57 -0
  45. transformers/integrations/__init__.py +10 -24
  46. transformers/integrations/accelerate.py +47 -11
  47. transformers/integrations/awq.py +1 -3
  48. transformers/integrations/deepspeed.py +146 -4
  49. transformers/integrations/eetq.py +0 -1
  50. transformers/integrations/executorch.py +2 -6
  51. transformers/integrations/fbgemm_fp8.py +1 -2
  52. transformers/integrations/finegrained_fp8.py +149 -13
  53. transformers/integrations/flash_attention.py +3 -8
  54. transformers/integrations/flex_attention.py +1 -1
  55. transformers/integrations/fp_quant.py +4 -6
  56. transformers/integrations/ggml.py +0 -1
  57. transformers/integrations/hub_kernels.py +18 -7
  58. transformers/integrations/integration_utils.py +2 -3
  59. transformers/integrations/moe.py +226 -106
  60. transformers/integrations/mxfp4.py +52 -40
  61. transformers/integrations/peft.py +488 -176
  62. transformers/integrations/quark.py +2 -4
  63. transformers/integrations/tensor_parallel.py +641 -581
  64. transformers/integrations/torchao.py +4 -6
  65. transformers/loss/loss_lw_detr.py +356 -0
  66. transformers/loss/loss_utils.py +2 -0
  67. transformers/masking_utils.py +199 -59
  68. transformers/model_debugging_utils.py +4 -5
  69. transformers/modelcard.py +14 -192
  70. transformers/modeling_attn_mask_utils.py +19 -19
  71. transformers/modeling_flash_attention_utils.py +28 -29
  72. transformers/modeling_gguf_pytorch_utils.py +5 -5
  73. transformers/modeling_layers.py +21 -22
  74. transformers/modeling_outputs.py +242 -253
  75. transformers/modeling_rope_utils.py +32 -32
  76. transformers/modeling_utils.py +416 -438
  77. transformers/models/__init__.py +10 -0
  78. transformers/models/afmoe/configuration_afmoe.py +40 -33
  79. transformers/models/afmoe/modeling_afmoe.py +38 -41
  80. transformers/models/afmoe/modular_afmoe.py +23 -25
  81. transformers/models/aimv2/configuration_aimv2.py +2 -10
  82. transformers/models/aimv2/modeling_aimv2.py +46 -45
  83. transformers/models/aimv2/modular_aimv2.py +13 -19
  84. transformers/models/albert/configuration_albert.py +8 -2
  85. transformers/models/albert/modeling_albert.py +70 -72
  86. transformers/models/albert/tokenization_albert.py +1 -4
  87. transformers/models/align/configuration_align.py +8 -6
  88. transformers/models/align/modeling_align.py +83 -86
  89. transformers/models/align/processing_align.py +2 -30
  90. transformers/models/altclip/configuration_altclip.py +4 -7
  91. transformers/models/altclip/modeling_altclip.py +106 -103
  92. transformers/models/altclip/processing_altclip.py +2 -15
  93. transformers/models/apertus/__init__.py +0 -1
  94. transformers/models/apertus/configuration_apertus.py +23 -28
  95. transformers/models/apertus/modeling_apertus.py +35 -38
  96. transformers/models/apertus/modular_apertus.py +36 -40
  97. transformers/models/arcee/configuration_arcee.py +25 -30
  98. transformers/models/arcee/modeling_arcee.py +35 -38
  99. transformers/models/arcee/modular_arcee.py +20 -23
  100. transformers/models/aria/configuration_aria.py +31 -44
  101. transformers/models/aria/image_processing_aria.py +25 -27
  102. transformers/models/aria/modeling_aria.py +102 -102
  103. transformers/models/aria/modular_aria.py +111 -124
  104. transformers/models/aria/processing_aria.py +28 -35
  105. transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +0 -1
  106. transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py +3 -6
  107. transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +9 -11
  108. transformers/models/audioflamingo3/__init__.py +0 -1
  109. transformers/models/audioflamingo3/configuration_audioflamingo3.py +0 -1
  110. transformers/models/audioflamingo3/modeling_audioflamingo3.py +60 -52
  111. transformers/models/audioflamingo3/modular_audioflamingo3.py +52 -43
  112. transformers/models/audioflamingo3/processing_audioflamingo3.py +6 -8
  113. transformers/models/auto/auto_factory.py +12 -11
  114. transformers/models/auto/configuration_auto.py +48 -5
  115. transformers/models/auto/feature_extraction_auto.py +5 -7
  116. transformers/models/auto/image_processing_auto.py +30 -39
  117. transformers/models/auto/modeling_auto.py +33 -199
  118. transformers/models/auto/processing_auto.py +11 -19
  119. transformers/models/auto/tokenization_auto.py +38 -37
  120. transformers/models/auto/video_processing_auto.py +7 -8
  121. transformers/models/autoformer/configuration_autoformer.py +4 -7
  122. transformers/models/autoformer/modeling_autoformer.py +100 -101
  123. transformers/models/aya_vision/configuration_aya_vision.py +4 -1
  124. transformers/models/aya_vision/modeling_aya_vision.py +64 -99
  125. transformers/models/aya_vision/modular_aya_vision.py +46 -74
  126. transformers/models/aya_vision/processing_aya_vision.py +25 -53
  127. transformers/models/bamba/configuration_bamba.py +46 -39
  128. transformers/models/bamba/modeling_bamba.py +83 -119
  129. transformers/models/bamba/modular_bamba.py +70 -109
  130. transformers/models/bark/configuration_bark.py +6 -8
  131. transformers/models/bark/generation_configuration_bark.py +3 -5
  132. transformers/models/bark/modeling_bark.py +64 -65
  133. transformers/models/bark/processing_bark.py +19 -41
  134. transformers/models/bart/configuration_bart.py +9 -5
  135. transformers/models/bart/modeling_bart.py +124 -129
  136. transformers/models/barthez/tokenization_barthez.py +1 -4
  137. transformers/models/bartpho/tokenization_bartpho.py +6 -7
  138. transformers/models/beit/configuration_beit.py +2 -15
  139. transformers/models/beit/image_processing_beit.py +53 -56
  140. transformers/models/beit/image_processing_beit_fast.py +11 -12
  141. transformers/models/beit/modeling_beit.py +65 -62
  142. transformers/models/bert/configuration_bert.py +12 -2
  143. transformers/models/bert/modeling_bert.py +117 -152
  144. transformers/models/bert/tokenization_bert.py +2 -4
  145. transformers/models/bert/tokenization_bert_legacy.py +3 -5
  146. transformers/models/bert_generation/configuration_bert_generation.py +17 -2
  147. transformers/models/bert_generation/modeling_bert_generation.py +53 -55
  148. transformers/models/bert_generation/tokenization_bert_generation.py +2 -3
  149. transformers/models/bert_japanese/tokenization_bert_japanese.py +5 -6
  150. transformers/models/bertweet/tokenization_bertweet.py +1 -3
  151. transformers/models/big_bird/configuration_big_bird.py +12 -9
  152. transformers/models/big_bird/modeling_big_bird.py +107 -124
  153. transformers/models/big_bird/tokenization_big_bird.py +1 -4
  154. transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py +9 -9
  155. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +118 -118
  156. transformers/models/biogpt/configuration_biogpt.py +8 -2
  157. transformers/models/biogpt/modeling_biogpt.py +73 -79
  158. transformers/models/biogpt/modular_biogpt.py +60 -66
  159. transformers/models/biogpt/tokenization_biogpt.py +3 -5
  160. transformers/models/bit/configuration_bit.py +2 -5
  161. transformers/models/bit/image_processing_bit.py +21 -24
  162. transformers/models/bit/image_processing_bit_fast.py +0 -1
  163. transformers/models/bit/modeling_bit.py +15 -16
  164. transformers/models/bitnet/configuration_bitnet.py +23 -28
  165. transformers/models/bitnet/modeling_bitnet.py +34 -38
  166. transformers/models/bitnet/modular_bitnet.py +7 -10
  167. transformers/models/blenderbot/configuration_blenderbot.py +8 -5
  168. transformers/models/blenderbot/modeling_blenderbot.py +68 -99
  169. transformers/models/blenderbot/tokenization_blenderbot.py +0 -1
  170. transformers/models/blenderbot_small/configuration_blenderbot_small.py +8 -5
  171. transformers/models/blenderbot_small/modeling_blenderbot_small.py +70 -72
  172. transformers/models/blenderbot_small/tokenization_blenderbot_small.py +1 -3
  173. transformers/models/blip/configuration_blip.py +9 -10
  174. transformers/models/blip/image_processing_blip.py +17 -20
  175. transformers/models/blip/image_processing_blip_fast.py +0 -1
  176. transformers/models/blip/modeling_blip.py +115 -108
  177. transformers/models/blip/modeling_blip_text.py +63 -65
  178. transformers/models/blip/processing_blip.py +5 -36
  179. transformers/models/blip_2/configuration_blip_2.py +2 -2
  180. transformers/models/blip_2/modeling_blip_2.py +145 -121
  181. transformers/models/blip_2/processing_blip_2.py +8 -38
  182. transformers/models/bloom/configuration_bloom.py +5 -2
  183. transformers/models/bloom/modeling_bloom.py +60 -60
  184. transformers/models/blt/configuration_blt.py +94 -86
  185. transformers/models/blt/modeling_blt.py +93 -90
  186. transformers/models/blt/modular_blt.py +127 -69
  187. transformers/models/bridgetower/configuration_bridgetower.py +7 -2
  188. transformers/models/bridgetower/image_processing_bridgetower.py +34 -35
  189. transformers/models/bridgetower/image_processing_bridgetower_fast.py +13 -14
  190. transformers/models/bridgetower/modeling_bridgetower.py +136 -124
  191. transformers/models/bridgetower/processing_bridgetower.py +2 -16
  192. transformers/models/bros/configuration_bros.py +24 -18
  193. transformers/models/bros/modeling_bros.py +78 -80
  194. transformers/models/bros/processing_bros.py +2 -12
  195. transformers/models/byt5/tokenization_byt5.py +4 -6
  196. transformers/models/camembert/configuration_camembert.py +8 -2
  197. transformers/models/camembert/modeling_camembert.py +97 -99
  198. transformers/models/camembert/modular_camembert.py +51 -54
  199. transformers/models/camembert/tokenization_camembert.py +1 -4
  200. transformers/models/canine/configuration_canine.py +4 -2
  201. transformers/models/canine/modeling_canine.py +73 -75
  202. transformers/models/canine/tokenization_canine.py +0 -1
  203. transformers/models/chameleon/configuration_chameleon.py +29 -34
  204. transformers/models/chameleon/image_processing_chameleon.py +21 -24
  205. transformers/models/chameleon/image_processing_chameleon_fast.py +5 -6
  206. transformers/models/chameleon/modeling_chameleon.py +135 -92
  207. transformers/models/chameleon/processing_chameleon.py +16 -41
  208. transformers/models/chinese_clip/configuration_chinese_clip.py +10 -8
  209. transformers/models/chinese_clip/image_processing_chinese_clip.py +21 -24
  210. transformers/models/chinese_clip/image_processing_chinese_clip_fast.py +0 -1
  211. transformers/models/chinese_clip/modeling_chinese_clip.py +93 -95
  212. transformers/models/chinese_clip/processing_chinese_clip.py +2 -15
  213. transformers/models/clap/configuration_clap.py +4 -9
  214. transformers/models/clap/feature_extraction_clap.py +9 -10
  215. transformers/models/clap/modeling_clap.py +109 -111
  216. transformers/models/clap/processing_clap.py +2 -15
  217. transformers/models/clip/configuration_clip.py +4 -2
  218. transformers/models/clip/image_processing_clip.py +21 -24
  219. transformers/models/clip/image_processing_clip_fast.py +9 -1
  220. transformers/models/clip/modeling_clip.py +70 -68
  221. transformers/models/clip/processing_clip.py +2 -14
  222. transformers/models/clip/tokenization_clip.py +2 -5
  223. transformers/models/clipseg/configuration_clipseg.py +4 -2
  224. transformers/models/clipseg/modeling_clipseg.py +113 -112
  225. transformers/models/clipseg/processing_clipseg.py +19 -42
  226. transformers/models/clvp/configuration_clvp.py +15 -5
  227. transformers/models/clvp/feature_extraction_clvp.py +7 -10
  228. transformers/models/clvp/modeling_clvp.py +138 -145
  229. transformers/models/clvp/number_normalizer.py +1 -2
  230. transformers/models/clvp/processing_clvp.py +3 -20
  231. transformers/models/clvp/tokenization_clvp.py +0 -1
  232. transformers/models/code_llama/tokenization_code_llama.py +3 -6
  233. transformers/models/codegen/configuration_codegen.py +4 -4
  234. transformers/models/codegen/modeling_codegen.py +50 -49
  235. transformers/models/codegen/tokenization_codegen.py +5 -6
  236. transformers/models/cohere/configuration_cohere.py +25 -30
  237. transformers/models/cohere/modeling_cohere.py +39 -42
  238. transformers/models/cohere/modular_cohere.py +27 -31
  239. transformers/models/cohere/tokenization_cohere.py +5 -6
  240. transformers/models/cohere2/configuration_cohere2.py +27 -32
  241. transformers/models/cohere2/modeling_cohere2.py +38 -41
  242. transformers/models/cohere2/modular_cohere2.py +48 -52
  243. transformers/models/cohere2_vision/configuration_cohere2_vision.py +5 -1
  244. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +9 -10
  245. transformers/models/cohere2_vision/modeling_cohere2_vision.py +52 -55
  246. transformers/models/cohere2_vision/modular_cohere2_vision.py +41 -43
  247. transformers/models/cohere2_vision/processing_cohere2_vision.py +6 -36
  248. transformers/models/colpali/configuration_colpali.py +0 -1
  249. transformers/models/colpali/modeling_colpali.py +14 -16
  250. transformers/models/colpali/modular_colpali.py +11 -51
  251. transformers/models/colpali/processing_colpali.py +14 -52
  252. transformers/models/colqwen2/modeling_colqwen2.py +27 -28
  253. transformers/models/colqwen2/modular_colqwen2.py +36 -74
  254. transformers/models/colqwen2/processing_colqwen2.py +16 -52
  255. transformers/models/conditional_detr/configuration_conditional_detr.py +19 -47
  256. transformers/models/conditional_detr/image_processing_conditional_detr.py +67 -70
  257. transformers/models/conditional_detr/image_processing_conditional_detr_fast.py +50 -36
  258. transformers/models/conditional_detr/modeling_conditional_detr.py +851 -1001
  259. transformers/models/conditional_detr/modular_conditional_detr.py +901 -5
  260. transformers/models/convbert/configuration_convbert.py +11 -8
  261. transformers/models/convbert/modeling_convbert.py +85 -87
  262. transformers/models/convbert/tokenization_convbert.py +0 -1
  263. transformers/models/convnext/configuration_convnext.py +2 -5
  264. transformers/models/convnext/image_processing_convnext.py +18 -21
  265. transformers/models/convnext/image_processing_convnext_fast.py +7 -8
  266. transformers/models/convnext/modeling_convnext.py +12 -14
  267. transformers/models/convnextv2/configuration_convnextv2.py +2 -5
  268. transformers/models/convnextv2/modeling_convnextv2.py +12 -14
  269. transformers/models/cpm/tokenization_cpm.py +6 -7
  270. transformers/models/cpm/tokenization_cpm_fast.py +3 -5
  271. transformers/models/cpmant/configuration_cpmant.py +4 -1
  272. transformers/models/cpmant/modeling_cpmant.py +38 -40
  273. transformers/models/cpmant/tokenization_cpmant.py +1 -3
  274. transformers/models/csm/configuration_csm.py +58 -66
  275. transformers/models/csm/generation_csm.py +13 -14
  276. transformers/models/csm/modeling_csm.py +81 -84
  277. transformers/models/csm/modular_csm.py +56 -58
  278. transformers/models/csm/processing_csm.py +25 -68
  279. transformers/models/ctrl/configuration_ctrl.py +16 -1
  280. transformers/models/ctrl/modeling_ctrl.py +51 -66
  281. transformers/models/ctrl/tokenization_ctrl.py +0 -1
  282. transformers/models/cvt/configuration_cvt.py +0 -1
  283. transformers/models/cvt/modeling_cvt.py +13 -15
  284. transformers/models/cwm/__init__.py +0 -1
  285. transformers/models/cwm/configuration_cwm.py +8 -12
  286. transformers/models/cwm/modeling_cwm.py +36 -38
  287. transformers/models/cwm/modular_cwm.py +10 -12
  288. transformers/models/d_fine/configuration_d_fine.py +10 -57
  289. transformers/models/d_fine/modeling_d_fine.py +786 -927
  290. transformers/models/d_fine/modular_d_fine.py +339 -417
  291. transformers/models/dab_detr/configuration_dab_detr.py +22 -49
  292. transformers/models/dab_detr/modeling_dab_detr.py +79 -77
  293. transformers/models/dac/configuration_dac.py +0 -1
  294. transformers/models/dac/feature_extraction_dac.py +6 -9
  295. transformers/models/dac/modeling_dac.py +22 -24
  296. transformers/models/data2vec/configuration_data2vec_audio.py +4 -2
  297. transformers/models/data2vec/configuration_data2vec_text.py +11 -3
  298. transformers/models/data2vec/configuration_data2vec_vision.py +0 -1
  299. transformers/models/data2vec/modeling_data2vec_audio.py +55 -59
  300. transformers/models/data2vec/modeling_data2vec_text.py +97 -99
  301. transformers/models/data2vec/modeling_data2vec_vision.py +45 -44
  302. transformers/models/data2vec/modular_data2vec_audio.py +6 -1
  303. transformers/models/data2vec/modular_data2vec_text.py +51 -54
  304. transformers/models/dbrx/configuration_dbrx.py +29 -22
  305. transformers/models/dbrx/modeling_dbrx.py +45 -48
  306. transformers/models/dbrx/modular_dbrx.py +37 -39
  307. transformers/models/deberta/configuration_deberta.py +6 -1
  308. transformers/models/deberta/modeling_deberta.py +57 -60
  309. transformers/models/deberta/tokenization_deberta.py +2 -5
  310. transformers/models/deberta_v2/configuration_deberta_v2.py +6 -1
  311. transformers/models/deberta_v2/modeling_deberta_v2.py +63 -65
  312. transformers/models/deberta_v2/tokenization_deberta_v2.py +1 -4
  313. transformers/models/decision_transformer/configuration_decision_transformer.py +3 -2
  314. transformers/models/decision_transformer/modeling_decision_transformer.py +51 -53
  315. transformers/models/deepseek_v2/configuration_deepseek_v2.py +41 -47
  316. transformers/models/deepseek_v2/modeling_deepseek_v2.py +39 -41
  317. transformers/models/deepseek_v2/modular_deepseek_v2.py +48 -52
  318. transformers/models/deepseek_v3/configuration_deepseek_v3.py +42 -48
  319. transformers/models/deepseek_v3/modeling_deepseek_v3.py +38 -40
  320. transformers/models/deepseek_v3/modular_deepseek_v3.py +10 -10
  321. transformers/models/deepseek_vl/configuration_deepseek_vl.py +6 -3
  322. transformers/models/deepseek_vl/image_processing_deepseek_vl.py +27 -28
  323. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +12 -11
  324. transformers/models/deepseek_vl/modeling_deepseek_vl.py +48 -43
  325. transformers/models/deepseek_vl/modular_deepseek_vl.py +15 -43
  326. transformers/models/deepseek_vl/processing_deepseek_vl.py +10 -41
  327. transformers/models/deepseek_vl_hybrid/configuration_deepseek_vl_hybrid.py +7 -5
  328. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py +37 -37
  329. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +22 -22
  330. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +100 -56
  331. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +141 -109
  332. transformers/models/deepseek_vl_hybrid/processing_deepseek_vl_hybrid.py +12 -44
  333. transformers/models/deformable_detr/configuration_deformable_detr.py +22 -46
  334. transformers/models/deformable_detr/image_processing_deformable_detr.py +59 -61
  335. transformers/models/deformable_detr/image_processing_deformable_detr_fast.py +42 -28
  336. transformers/models/deformable_detr/modeling_deformable_detr.py +454 -652
  337. transformers/models/deformable_detr/modular_deformable_detr.py +1385 -5
  338. transformers/models/deit/configuration_deit.py +0 -1
  339. transformers/models/deit/image_processing_deit.py +18 -21
  340. transformers/models/deit/image_processing_deit_fast.py +0 -1
  341. transformers/models/deit/modeling_deit.py +27 -25
  342. transformers/models/depth_anything/configuration_depth_anything.py +12 -43
  343. transformers/models/depth_anything/modeling_depth_anything.py +10 -11
  344. transformers/models/depth_pro/configuration_depth_pro.py +0 -1
  345. transformers/models/depth_pro/image_processing_depth_pro.py +22 -23
  346. transformers/models/depth_pro/image_processing_depth_pro_fast.py +8 -9
  347. transformers/models/depth_pro/modeling_depth_pro.py +29 -27
  348. transformers/models/detr/configuration_detr.py +18 -50
  349. transformers/models/detr/image_processing_detr.py +64 -66
  350. transformers/models/detr/image_processing_detr_fast.py +33 -34
  351. transformers/models/detr/modeling_detr.py +748 -789
  352. transformers/models/dia/configuration_dia.py +9 -15
  353. transformers/models/dia/feature_extraction_dia.py +6 -9
  354. transformers/models/dia/generation_dia.py +48 -53
  355. transformers/models/dia/modeling_dia.py +68 -71
  356. transformers/models/dia/modular_dia.py +56 -58
  357. transformers/models/dia/processing_dia.py +39 -29
  358. transformers/models/dia/tokenization_dia.py +3 -6
  359. transformers/models/diffllama/configuration_diffllama.py +25 -30
  360. transformers/models/diffllama/modeling_diffllama.py +45 -53
  361. transformers/models/diffllama/modular_diffllama.py +18 -25
  362. transformers/models/dinat/configuration_dinat.py +2 -5
  363. transformers/models/dinat/modeling_dinat.py +47 -48
  364. transformers/models/dinov2/configuration_dinov2.py +2 -5
  365. transformers/models/dinov2/modeling_dinov2.py +20 -21
  366. transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py +3 -5
  367. transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +21 -21
  368. transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py +11 -14
  369. transformers/models/dinov3_convnext/configuration_dinov3_convnext.py +6 -11
  370. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +5 -9
  371. transformers/models/dinov3_vit/configuration_dinov3_vit.py +7 -12
  372. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +7 -8
  373. transformers/models/dinov3_vit/modeling_dinov3_vit.py +19 -22
  374. transformers/models/dinov3_vit/modular_dinov3_vit.py +16 -19
  375. transformers/models/distilbert/configuration_distilbert.py +8 -2
  376. transformers/models/distilbert/modeling_distilbert.py +47 -49
  377. transformers/models/distilbert/tokenization_distilbert.py +0 -1
  378. transformers/models/doge/__init__.py +0 -1
  379. transformers/models/doge/configuration_doge.py +42 -35
  380. transformers/models/doge/modeling_doge.py +46 -49
  381. transformers/models/doge/modular_doge.py +77 -68
  382. transformers/models/donut/configuration_donut_swin.py +0 -1
  383. transformers/models/donut/image_processing_donut.py +26 -29
  384. transformers/models/donut/image_processing_donut_fast.py +9 -14
  385. transformers/models/donut/modeling_donut_swin.py +44 -46
  386. transformers/models/donut/processing_donut.py +5 -26
  387. transformers/models/dots1/configuration_dots1.py +43 -36
  388. transformers/models/dots1/modeling_dots1.py +35 -38
  389. transformers/models/dots1/modular_dots1.py +0 -1
  390. transformers/models/dpr/configuration_dpr.py +19 -2
  391. transformers/models/dpr/modeling_dpr.py +37 -39
  392. transformers/models/dpr/tokenization_dpr.py +7 -9
  393. transformers/models/dpr/tokenization_dpr_fast.py +7 -9
  394. transformers/models/dpt/configuration_dpt.py +23 -66
  395. transformers/models/dpt/image_processing_dpt.py +65 -66
  396. transformers/models/dpt/image_processing_dpt_fast.py +18 -19
  397. transformers/models/dpt/modeling_dpt.py +38 -36
  398. transformers/models/dpt/modular_dpt.py +14 -15
  399. transformers/models/edgetam/configuration_edgetam.py +1 -2
  400. transformers/models/edgetam/modeling_edgetam.py +87 -89
  401. transformers/models/edgetam/modular_edgetam.py +7 -13
  402. transformers/models/edgetam_video/__init__.py +0 -1
  403. transformers/models/edgetam_video/configuration_edgetam_video.py +0 -1
  404. transformers/models/edgetam_video/modeling_edgetam_video.py +126 -128
  405. transformers/models/edgetam_video/modular_edgetam_video.py +25 -27
  406. transformers/models/efficientloftr/configuration_efficientloftr.py +4 -5
  407. transformers/models/efficientloftr/image_processing_efficientloftr.py +14 -16
  408. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +8 -7
  409. transformers/models/efficientloftr/modeling_efficientloftr.py +46 -38
  410. transformers/models/efficientloftr/modular_efficientloftr.py +1 -3
  411. transformers/models/efficientnet/configuration_efficientnet.py +0 -1
  412. transformers/models/efficientnet/image_processing_efficientnet.py +23 -26
  413. transformers/models/efficientnet/image_processing_efficientnet_fast.py +16 -17
  414. transformers/models/efficientnet/modeling_efficientnet.py +12 -14
  415. transformers/models/electra/configuration_electra.py +13 -3
  416. transformers/models/electra/modeling_electra.py +107 -109
  417. transformers/models/emu3/configuration_emu3.py +17 -17
  418. transformers/models/emu3/image_processing_emu3.py +44 -39
  419. transformers/models/emu3/modeling_emu3.py +143 -109
  420. transformers/models/emu3/modular_emu3.py +109 -73
  421. transformers/models/emu3/processing_emu3.py +18 -43
  422. transformers/models/encodec/configuration_encodec.py +2 -4
  423. transformers/models/encodec/feature_extraction_encodec.py +10 -13
  424. transformers/models/encodec/modeling_encodec.py +25 -29
  425. transformers/models/encoder_decoder/configuration_encoder_decoder.py +12 -2
  426. transformers/models/encoder_decoder/modeling_encoder_decoder.py +37 -43
  427. transformers/models/eomt/configuration_eomt.py +12 -14
  428. transformers/models/eomt/image_processing_eomt.py +53 -55
  429. transformers/models/eomt/image_processing_eomt_fast.py +18 -19
  430. transformers/models/eomt/modeling_eomt.py +19 -21
  431. transformers/models/eomt/modular_eomt.py +28 -30
  432. transformers/models/eomt_dinov3/__init__.py +28 -0
  433. transformers/models/eomt_dinov3/configuration_eomt_dinov3.py +204 -0
  434. transformers/models/eomt_dinov3/modeling_eomt_dinov3.py +1376 -0
  435. transformers/models/eomt_dinov3/modular_eomt_dinov3.py +454 -0
  436. transformers/models/ernie/configuration_ernie.py +24 -3
  437. transformers/models/ernie/modeling_ernie.py +127 -162
  438. transformers/models/ernie/modular_ernie.py +91 -103
  439. transformers/models/ernie4_5/configuration_ernie4_5.py +23 -27
  440. transformers/models/ernie4_5/modeling_ernie4_5.py +35 -37
  441. transformers/models/ernie4_5/modular_ernie4_5.py +1 -3
  442. transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py +34 -39
  443. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +40 -42
  444. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +7 -9
  445. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +17 -7
  446. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +34 -35
  447. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +6 -7
  448. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +305 -267
  449. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +163 -142
  450. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +3 -5
  451. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +17 -18
  452. transformers/models/esm/configuration_esm.py +11 -15
  453. transformers/models/esm/modeling_esm.py +35 -37
  454. transformers/models/esm/modeling_esmfold.py +43 -50
  455. transformers/models/esm/openfold_utils/chunk_utils.py +6 -6
  456. transformers/models/esm/openfold_utils/loss.py +1 -2
  457. transformers/models/esm/openfold_utils/protein.py +15 -16
  458. transformers/models/esm/openfold_utils/tensor_utils.py +6 -6
  459. transformers/models/esm/tokenization_esm.py +2 -4
  460. transformers/models/evolla/configuration_evolla.py +50 -40
  461. transformers/models/evolla/modeling_evolla.py +69 -68
  462. transformers/models/evolla/modular_evolla.py +50 -48
  463. transformers/models/evolla/processing_evolla.py +23 -35
  464. transformers/models/exaone4/configuration_exaone4.py +27 -27
  465. transformers/models/exaone4/modeling_exaone4.py +36 -39
  466. transformers/models/exaone4/modular_exaone4.py +51 -50
  467. transformers/models/exaone_moe/__init__.py +27 -0
  468. transformers/models/exaone_moe/configuration_exaone_moe.py +235 -0
  469. transformers/models/exaone_moe/modeling_exaone_moe.py +665 -0
  470. transformers/models/exaone_moe/modular_exaone_moe.py +373 -0
  471. transformers/models/falcon/configuration_falcon.py +31 -26
  472. transformers/models/falcon/modeling_falcon.py +76 -84
  473. transformers/models/falcon_h1/configuration_falcon_h1.py +57 -51
  474. transformers/models/falcon_h1/modeling_falcon_h1.py +74 -109
  475. transformers/models/falcon_h1/modular_falcon_h1.py +68 -100
  476. transformers/models/falcon_mamba/configuration_falcon_mamba.py +5 -2
  477. transformers/models/falcon_mamba/modeling_falcon_mamba.py +64 -73
  478. transformers/models/falcon_mamba/modular_falcon_mamba.py +14 -13
  479. transformers/models/fast_vlm/configuration_fast_vlm.py +10 -0
  480. transformers/models/fast_vlm/modeling_fast_vlm.py +70 -97
  481. transformers/models/fast_vlm/modular_fast_vlm.py +148 -38
  482. transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py +2 -6
  483. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +45 -47
  484. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -3
  485. transformers/models/flaubert/configuration_flaubert.py +10 -5
  486. transformers/models/flaubert/modeling_flaubert.py +125 -129
  487. transformers/models/flaubert/tokenization_flaubert.py +3 -5
  488. transformers/models/flava/configuration_flava.py +9 -9
  489. transformers/models/flava/image_processing_flava.py +66 -67
  490. transformers/models/flava/image_processing_flava_fast.py +46 -47
  491. transformers/models/flava/modeling_flava.py +144 -135
  492. transformers/models/flava/processing_flava.py +2 -12
  493. transformers/models/flex_olmo/__init__.py +0 -1
  494. transformers/models/flex_olmo/configuration_flex_olmo.py +34 -39
  495. transformers/models/flex_olmo/modeling_flex_olmo.py +41 -43
  496. transformers/models/flex_olmo/modular_flex_olmo.py +46 -51
  497. transformers/models/florence2/configuration_florence2.py +4 -1
  498. transformers/models/florence2/modeling_florence2.py +96 -72
  499. transformers/models/florence2/modular_florence2.py +100 -107
  500. transformers/models/florence2/processing_florence2.py +18 -47
  501. transformers/models/fnet/configuration_fnet.py +6 -2
  502. transformers/models/fnet/modeling_fnet.py +69 -80
  503. transformers/models/fnet/tokenization_fnet.py +0 -1
  504. transformers/models/focalnet/configuration_focalnet.py +2 -5
  505. transformers/models/focalnet/modeling_focalnet.py +49 -48
  506. transformers/models/fsmt/configuration_fsmt.py +12 -17
  507. transformers/models/fsmt/modeling_fsmt.py +47 -48
  508. transformers/models/fsmt/tokenization_fsmt.py +3 -5
  509. transformers/models/funnel/configuration_funnel.py +8 -1
  510. transformers/models/funnel/modeling_funnel.py +91 -93
  511. transformers/models/funnel/tokenization_funnel.py +2 -5
  512. transformers/models/fuyu/configuration_fuyu.py +28 -34
  513. transformers/models/fuyu/image_processing_fuyu.py +29 -31
  514. transformers/models/fuyu/image_processing_fuyu_fast.py +17 -17
  515. transformers/models/fuyu/modeling_fuyu.py +50 -52
  516. transformers/models/fuyu/processing_fuyu.py +9 -36
  517. transformers/models/gemma/configuration_gemma.py +25 -30
  518. transformers/models/gemma/modeling_gemma.py +36 -38
  519. transformers/models/gemma/modular_gemma.py +33 -36
  520. transformers/models/gemma/tokenization_gemma.py +3 -6
  521. transformers/models/gemma2/configuration_gemma2.py +30 -35
  522. transformers/models/gemma2/modeling_gemma2.py +38 -41
  523. transformers/models/gemma2/modular_gemma2.py +63 -67
  524. transformers/models/gemma3/configuration_gemma3.py +53 -48
  525. transformers/models/gemma3/image_processing_gemma3.py +29 -31
  526. transformers/models/gemma3/image_processing_gemma3_fast.py +11 -12
  527. transformers/models/gemma3/modeling_gemma3.py +123 -122
  528. transformers/models/gemma3/modular_gemma3.py +128 -125
  529. transformers/models/gemma3/processing_gemma3.py +5 -5
  530. transformers/models/gemma3n/configuration_gemma3n.py +42 -30
  531. transformers/models/gemma3n/feature_extraction_gemma3n.py +9 -11
  532. transformers/models/gemma3n/modeling_gemma3n.py +166 -147
  533. transformers/models/gemma3n/modular_gemma3n.py +176 -148
  534. transformers/models/gemma3n/processing_gemma3n.py +12 -26
  535. transformers/models/git/configuration_git.py +5 -8
  536. transformers/models/git/modeling_git.py +115 -127
  537. transformers/models/git/processing_git.py +2 -14
  538. transformers/models/glm/configuration_glm.py +26 -30
  539. transformers/models/glm/modeling_glm.py +36 -39
  540. transformers/models/glm/modular_glm.py +4 -7
  541. transformers/models/glm4/configuration_glm4.py +26 -30
  542. transformers/models/glm4/modeling_glm4.py +39 -41
  543. transformers/models/glm4/modular_glm4.py +8 -10
  544. transformers/models/glm46v/configuration_glm46v.py +4 -1
  545. transformers/models/glm46v/image_processing_glm46v.py +40 -38
  546. transformers/models/glm46v/image_processing_glm46v_fast.py +9 -9
  547. transformers/models/glm46v/modeling_glm46v.py +138 -93
  548. transformers/models/glm46v/modular_glm46v.py +5 -3
  549. transformers/models/glm46v/processing_glm46v.py +7 -41
  550. transformers/models/glm46v/video_processing_glm46v.py +9 -11
  551. transformers/models/glm4_moe/configuration_glm4_moe.py +42 -35
  552. transformers/models/glm4_moe/modeling_glm4_moe.py +36 -39
  553. transformers/models/glm4_moe/modular_glm4_moe.py +43 -36
  554. transformers/models/glm4_moe_lite/__init__.py +28 -0
  555. transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py +233 -0
  556. transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py +740 -0
  557. transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py +302 -0
  558. transformers/models/glm4v/configuration_glm4v.py +25 -24
  559. transformers/models/glm4v/image_processing_glm4v.py +39 -38
  560. transformers/models/glm4v/image_processing_glm4v_fast.py +8 -9
  561. transformers/models/glm4v/modeling_glm4v.py +249 -210
  562. transformers/models/glm4v/modular_glm4v.py +211 -230
  563. transformers/models/glm4v/processing_glm4v.py +7 -41
  564. transformers/models/glm4v/video_processing_glm4v.py +9 -11
  565. transformers/models/glm4v_moe/configuration_glm4v_moe.py +136 -127
  566. transformers/models/glm4v_moe/modeling_glm4v_moe.py +348 -356
  567. transformers/models/glm4v_moe/modular_glm4v_moe.py +76 -174
  568. transformers/models/glm_image/__init__.py +31 -0
  569. transformers/models/glm_image/configuration_glm_image.py +358 -0
  570. transformers/models/glm_image/image_processing_glm_image.py +503 -0
  571. transformers/models/glm_image/image_processing_glm_image_fast.py +294 -0
  572. transformers/models/glm_image/modeling_glm_image.py +1691 -0
  573. transformers/models/glm_image/modular_glm_image.py +1640 -0
  574. transformers/models/glm_image/processing_glm_image.py +265 -0
  575. transformers/models/glm_ocr/__init__.py +28 -0
  576. transformers/models/glm_ocr/configuration_glm_ocr.py +312 -0
  577. transformers/models/glm_ocr/modeling_glm_ocr.py +1633 -0
  578. transformers/models/glm_ocr/modular_glm_ocr.py +428 -0
  579. transformers/models/glmasr/__init__.py +0 -1
  580. transformers/models/glmasr/configuration_glmasr.py +0 -1
  581. transformers/models/glmasr/modeling_glmasr.py +51 -46
  582. transformers/models/glmasr/modular_glmasr.py +39 -29
  583. transformers/models/glmasr/processing_glmasr.py +7 -8
  584. transformers/models/glpn/configuration_glpn.py +0 -1
  585. transformers/models/glpn/image_processing_glpn.py +11 -12
  586. transformers/models/glpn/image_processing_glpn_fast.py +11 -12
  587. transformers/models/glpn/modeling_glpn.py +14 -14
  588. transformers/models/got_ocr2/configuration_got_ocr2.py +10 -13
  589. transformers/models/got_ocr2/image_processing_got_ocr2.py +22 -24
  590. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +9 -10
  591. transformers/models/got_ocr2/modeling_got_ocr2.py +69 -77
  592. transformers/models/got_ocr2/modular_got_ocr2.py +60 -52
  593. transformers/models/got_ocr2/processing_got_ocr2.py +42 -63
  594. transformers/models/gpt2/configuration_gpt2.py +13 -2
  595. transformers/models/gpt2/modeling_gpt2.py +111 -113
  596. transformers/models/gpt2/tokenization_gpt2.py +6 -9
  597. transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +7 -2
  598. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +78 -84
  599. transformers/models/gpt_neo/configuration_gpt_neo.py +9 -2
  600. transformers/models/gpt_neo/modeling_gpt_neo.py +66 -71
  601. transformers/models/gpt_neox/configuration_gpt_neox.py +27 -25
  602. transformers/models/gpt_neox/modeling_gpt_neox.py +74 -76
  603. transformers/models/gpt_neox/modular_gpt_neox.py +68 -70
  604. transformers/models/gpt_neox/tokenization_gpt_neox.py +2 -5
  605. transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py +24 -19
  606. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +43 -46
  607. transformers/models/gpt_neox_japanese/tokenization_gpt_neox_japanese.py +1 -3
  608. transformers/models/gpt_oss/configuration_gpt_oss.py +31 -30
  609. transformers/models/gpt_oss/modeling_gpt_oss.py +80 -114
  610. transformers/models/gpt_oss/modular_gpt_oss.py +62 -97
  611. transformers/models/gpt_sw3/tokenization_gpt_sw3.py +4 -4
  612. transformers/models/gptj/configuration_gptj.py +4 -5
  613. transformers/models/gptj/modeling_gptj.py +85 -88
  614. transformers/models/granite/configuration_granite.py +28 -33
  615. transformers/models/granite/modeling_granite.py +43 -45
  616. transformers/models/granite/modular_granite.py +29 -31
  617. transformers/models/granite_speech/configuration_granite_speech.py +0 -1
  618. transformers/models/granite_speech/feature_extraction_granite_speech.py +1 -3
  619. transformers/models/granite_speech/modeling_granite_speech.py +84 -60
  620. transformers/models/granite_speech/processing_granite_speech.py +11 -4
  621. transformers/models/granitemoe/configuration_granitemoe.py +31 -36
  622. transformers/models/granitemoe/modeling_granitemoe.py +39 -41
  623. transformers/models/granitemoe/modular_granitemoe.py +21 -23
  624. transformers/models/granitemoehybrid/__init__.py +0 -1
  625. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +55 -48
  626. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +82 -118
  627. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +57 -65
  628. transformers/models/granitemoeshared/configuration_granitemoeshared.py +33 -37
  629. transformers/models/granitemoeshared/modeling_granitemoeshared.py +52 -56
  630. transformers/models/granitemoeshared/modular_granitemoeshared.py +19 -21
  631. transformers/models/grounding_dino/configuration_grounding_dino.py +10 -46
  632. transformers/models/grounding_dino/image_processing_grounding_dino.py +60 -62
  633. transformers/models/grounding_dino/image_processing_grounding_dino_fast.py +28 -29
  634. transformers/models/grounding_dino/modeling_grounding_dino.py +161 -181
  635. transformers/models/grounding_dino/modular_grounding_dino.py +2 -3
  636. transformers/models/grounding_dino/processing_grounding_dino.py +10 -38
  637. transformers/models/groupvit/configuration_groupvit.py +4 -2
  638. transformers/models/groupvit/modeling_groupvit.py +98 -92
  639. transformers/models/helium/configuration_helium.py +25 -29
  640. transformers/models/helium/modeling_helium.py +37 -40
  641. transformers/models/helium/modular_helium.py +3 -7
  642. transformers/models/herbert/tokenization_herbert.py +4 -6
  643. transformers/models/hgnet_v2/configuration_hgnet_v2.py +2 -5
  644. transformers/models/hgnet_v2/modeling_hgnet_v2.py +12 -14
  645. transformers/models/hgnet_v2/modular_hgnet_v2.py +13 -17
  646. transformers/models/hiera/configuration_hiera.py +2 -5
  647. transformers/models/hiera/modeling_hiera.py +71 -70
  648. transformers/models/hubert/configuration_hubert.py +4 -2
  649. transformers/models/hubert/modeling_hubert.py +42 -41
  650. transformers/models/hubert/modular_hubert.py +8 -11
  651. transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py +26 -31
  652. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +58 -37
  653. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +31 -11
  654. transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py +31 -36
  655. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +54 -44
  656. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +27 -15
  657. transformers/models/ibert/configuration_ibert.py +4 -2
  658. transformers/models/ibert/modeling_ibert.py +60 -62
  659. transformers/models/ibert/quant_modules.py +0 -1
  660. transformers/models/idefics/configuration_idefics.py +5 -8
  661. transformers/models/idefics/image_processing_idefics.py +13 -15
  662. transformers/models/idefics/modeling_idefics.py +63 -65
  663. transformers/models/idefics/perceiver.py +1 -3
  664. transformers/models/idefics/processing_idefics.py +32 -48
  665. transformers/models/idefics/vision.py +27 -28
  666. transformers/models/idefics2/configuration_idefics2.py +1 -3
  667. transformers/models/idefics2/image_processing_idefics2.py +31 -32
  668. transformers/models/idefics2/image_processing_idefics2_fast.py +8 -8
  669. transformers/models/idefics2/modeling_idefics2.py +126 -106
  670. transformers/models/idefics2/processing_idefics2.py +10 -68
  671. transformers/models/idefics3/configuration_idefics3.py +1 -4
  672. transformers/models/idefics3/image_processing_idefics3.py +42 -43
  673. transformers/models/idefics3/image_processing_idefics3_fast.py +40 -15
  674. transformers/models/idefics3/modeling_idefics3.py +113 -92
  675. transformers/models/idefics3/processing_idefics3.py +15 -69
  676. transformers/models/ijepa/configuration_ijepa.py +0 -1
  677. transformers/models/ijepa/modeling_ijepa.py +13 -14
  678. transformers/models/ijepa/modular_ijepa.py +5 -7
  679. transformers/models/imagegpt/configuration_imagegpt.py +9 -2
  680. transformers/models/imagegpt/image_processing_imagegpt.py +17 -18
  681. transformers/models/imagegpt/image_processing_imagegpt_fast.py +10 -11
  682. transformers/models/imagegpt/modeling_imagegpt.py +65 -62
  683. transformers/models/informer/configuration_informer.py +6 -9
  684. transformers/models/informer/modeling_informer.py +87 -89
  685. transformers/models/informer/modular_informer.py +13 -16
  686. transformers/models/instructblip/configuration_instructblip.py +2 -2
  687. transformers/models/instructblip/modeling_instructblip.py +104 -79
  688. transformers/models/instructblip/processing_instructblip.py +10 -36
  689. transformers/models/instructblipvideo/configuration_instructblipvideo.py +2 -2
  690. transformers/models/instructblipvideo/modeling_instructblipvideo.py +108 -105
  691. transformers/models/instructblipvideo/modular_instructblipvideo.py +73 -64
  692. transformers/models/instructblipvideo/processing_instructblipvideo.py +14 -33
  693. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +6 -7
  694. transformers/models/internvl/configuration_internvl.py +5 -1
  695. transformers/models/internvl/modeling_internvl.py +76 -98
  696. transformers/models/internvl/modular_internvl.py +45 -59
  697. transformers/models/internvl/processing_internvl.py +12 -45
  698. transformers/models/internvl/video_processing_internvl.py +10 -11
  699. transformers/models/jais2/configuration_jais2.py +25 -29
  700. transformers/models/jais2/modeling_jais2.py +36 -38
  701. transformers/models/jais2/modular_jais2.py +20 -22
  702. transformers/models/jamba/configuration_jamba.py +5 -8
  703. transformers/models/jamba/modeling_jamba.py +47 -50
  704. transformers/models/jamba/modular_jamba.py +40 -41
  705. transformers/models/janus/configuration_janus.py +0 -1
  706. transformers/models/janus/image_processing_janus.py +37 -39
  707. transformers/models/janus/image_processing_janus_fast.py +20 -21
  708. transformers/models/janus/modeling_janus.py +103 -188
  709. transformers/models/janus/modular_janus.py +122 -83
  710. transformers/models/janus/processing_janus.py +17 -43
  711. transformers/models/jetmoe/configuration_jetmoe.py +26 -27
  712. transformers/models/jetmoe/modeling_jetmoe.py +42 -45
  713. transformers/models/jetmoe/modular_jetmoe.py +33 -36
  714. transformers/models/kosmos2/configuration_kosmos2.py +10 -9
  715. transformers/models/kosmos2/modeling_kosmos2.py +199 -178
  716. transformers/models/kosmos2/processing_kosmos2.py +40 -55
  717. transformers/models/kosmos2_5/__init__.py +0 -1
  718. transformers/models/kosmos2_5/configuration_kosmos2_5.py +8 -9
  719. transformers/models/kosmos2_5/image_processing_kosmos2_5.py +10 -12
  720. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -11
  721. transformers/models/kosmos2_5/modeling_kosmos2_5.py +162 -172
  722. transformers/models/kosmos2_5/processing_kosmos2_5.py +8 -29
  723. transformers/models/kyutai_speech_to_text/configuration_kyutai_speech_to_text.py +31 -28
  724. transformers/models/kyutai_speech_to_text/feature_extraction_kyutai_speech_to_text.py +12 -14
  725. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +103 -106
  726. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +20 -22
  727. transformers/models/kyutai_speech_to_text/processing_kyutai_speech_to_text.py +2 -8
  728. transformers/models/lasr/configuration_lasr.py +3 -7
  729. transformers/models/lasr/feature_extraction_lasr.py +10 -12
  730. transformers/models/lasr/modeling_lasr.py +21 -24
  731. transformers/models/lasr/modular_lasr.py +11 -13
  732. transformers/models/lasr/processing_lasr.py +12 -6
  733. transformers/models/lasr/tokenization_lasr.py +2 -4
  734. transformers/models/layoutlm/configuration_layoutlm.py +14 -2
  735. transformers/models/layoutlm/modeling_layoutlm.py +70 -72
  736. transformers/models/layoutlmv2/configuration_layoutlmv2.py +14 -17
  737. transformers/models/layoutlmv2/image_processing_layoutlmv2.py +18 -21
  738. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +7 -8
  739. transformers/models/layoutlmv2/modeling_layoutlmv2.py +48 -50
  740. transformers/models/layoutlmv2/processing_layoutlmv2.py +14 -44
  741. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +63 -74
  742. transformers/models/layoutlmv3/configuration_layoutlmv3.py +16 -19
  743. transformers/models/layoutlmv3/image_processing_layoutlmv3.py +24 -26
  744. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +9 -10
  745. transformers/models/layoutlmv3/modeling_layoutlmv3.py +49 -51
  746. transformers/models/layoutlmv3/processing_layoutlmv3.py +14 -46
  747. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +64 -75
  748. transformers/models/layoutxlm/configuration_layoutxlm.py +14 -17
  749. transformers/models/layoutxlm/modular_layoutxlm.py +0 -1
  750. transformers/models/layoutxlm/processing_layoutxlm.py +14 -44
  751. transformers/models/layoutxlm/tokenization_layoutxlm.py +65 -76
  752. transformers/models/led/configuration_led.py +8 -12
  753. transformers/models/led/modeling_led.py +113 -267
  754. transformers/models/levit/configuration_levit.py +0 -1
  755. transformers/models/levit/image_processing_levit.py +19 -21
  756. transformers/models/levit/image_processing_levit_fast.py +4 -5
  757. transformers/models/levit/modeling_levit.py +17 -19
  758. transformers/models/lfm2/configuration_lfm2.py +27 -30
  759. transformers/models/lfm2/modeling_lfm2.py +46 -48
  760. transformers/models/lfm2/modular_lfm2.py +32 -32
  761. transformers/models/lfm2_moe/__init__.py +0 -1
  762. transformers/models/lfm2_moe/configuration_lfm2_moe.py +6 -9
  763. transformers/models/lfm2_moe/modeling_lfm2_moe.py +48 -49
  764. transformers/models/lfm2_moe/modular_lfm2_moe.py +8 -9
  765. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -1
  766. transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py +43 -20
  767. transformers/models/lfm2_vl/modeling_lfm2_vl.py +73 -61
  768. transformers/models/lfm2_vl/modular_lfm2_vl.py +66 -54
  769. transformers/models/lfm2_vl/processing_lfm2_vl.py +14 -34
  770. transformers/models/lightglue/image_processing_lightglue.py +16 -15
  771. transformers/models/lightglue/image_processing_lightglue_fast.py +8 -7
  772. transformers/models/lightglue/modeling_lightglue.py +31 -33
  773. transformers/models/lightglue/modular_lightglue.py +31 -31
  774. transformers/models/lighton_ocr/__init__.py +28 -0
  775. transformers/models/lighton_ocr/configuration_lighton_ocr.py +128 -0
  776. transformers/models/lighton_ocr/modeling_lighton_ocr.py +463 -0
  777. transformers/models/lighton_ocr/modular_lighton_ocr.py +404 -0
  778. transformers/models/lighton_ocr/processing_lighton_ocr.py +229 -0
  779. transformers/models/lilt/configuration_lilt.py +6 -2
  780. transformers/models/lilt/modeling_lilt.py +53 -55
  781. transformers/models/llama/configuration_llama.py +26 -31
  782. transformers/models/llama/modeling_llama.py +35 -38
  783. transformers/models/llama/tokenization_llama.py +2 -4
  784. transformers/models/llama4/configuration_llama4.py +87 -69
  785. transformers/models/llama4/image_processing_llama4_fast.py +11 -12
  786. transformers/models/llama4/modeling_llama4.py +116 -115
  787. transformers/models/llama4/processing_llama4.py +33 -57
  788. transformers/models/llava/configuration_llava.py +10 -1
  789. transformers/models/llava/image_processing_llava.py +25 -28
  790. transformers/models/llava/image_processing_llava_fast.py +9 -10
  791. transformers/models/llava/modeling_llava.py +73 -102
  792. transformers/models/llava/processing_llava.py +18 -51
  793. transformers/models/llava_next/configuration_llava_next.py +2 -2
  794. transformers/models/llava_next/image_processing_llava_next.py +43 -45
  795. transformers/models/llava_next/image_processing_llava_next_fast.py +11 -12
  796. transformers/models/llava_next/modeling_llava_next.py +103 -104
  797. transformers/models/llava_next/processing_llava_next.py +18 -47
  798. transformers/models/llava_next_video/configuration_llava_next_video.py +10 -7
  799. transformers/models/llava_next_video/modeling_llava_next_video.py +168 -155
  800. transformers/models/llava_next_video/modular_llava_next_video.py +154 -147
  801. transformers/models/llava_next_video/processing_llava_next_video.py +21 -63
  802. transformers/models/llava_next_video/video_processing_llava_next_video.py +0 -1
  803. transformers/models/llava_onevision/configuration_llava_onevision.py +10 -7
  804. transformers/models/llava_onevision/image_processing_llava_onevision.py +40 -42
  805. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +14 -14
  806. transformers/models/llava_onevision/modeling_llava_onevision.py +170 -166
  807. transformers/models/llava_onevision/modular_llava_onevision.py +156 -152
  808. transformers/models/llava_onevision/processing_llava_onevision.py +21 -53
  809. transformers/models/llava_onevision/video_processing_llava_onevision.py +0 -1
  810. transformers/models/longcat_flash/__init__.py +0 -1
  811. transformers/models/longcat_flash/configuration_longcat_flash.py +39 -45
  812. transformers/models/longcat_flash/modeling_longcat_flash.py +37 -38
  813. transformers/models/longcat_flash/modular_longcat_flash.py +23 -24
  814. transformers/models/longformer/configuration_longformer.py +5 -5
  815. transformers/models/longformer/modeling_longformer.py +99 -101
  816. transformers/models/longt5/configuration_longt5.py +9 -7
  817. transformers/models/longt5/modeling_longt5.py +45 -45
  818. transformers/models/luke/configuration_luke.py +8 -2
  819. transformers/models/luke/modeling_luke.py +179 -181
  820. transformers/models/luke/tokenization_luke.py +99 -105
  821. transformers/{pipelines/deprecated → models/lw_detr}/__init__.py +14 -3
  822. transformers/models/lw_detr/configuration_lw_detr.py +362 -0
  823. transformers/models/lw_detr/modeling_lw_detr.py +1697 -0
  824. transformers/models/lw_detr/modular_lw_detr.py +1609 -0
  825. transformers/models/lxmert/configuration_lxmert.py +16 -1
  826. transformers/models/lxmert/modeling_lxmert.py +63 -74
  827. transformers/models/m2m_100/configuration_m2m_100.py +7 -9
  828. transformers/models/m2m_100/modeling_m2m_100.py +72 -74
  829. transformers/models/m2m_100/tokenization_m2m_100.py +8 -8
  830. transformers/models/mamba/configuration_mamba.py +5 -3
  831. transformers/models/mamba/modeling_mamba.py +61 -70
  832. transformers/models/mamba2/configuration_mamba2.py +5 -8
  833. transformers/models/mamba2/modeling_mamba2.py +66 -79
  834. transformers/models/marian/configuration_marian.py +10 -5
  835. transformers/models/marian/modeling_marian.py +88 -90
  836. transformers/models/marian/tokenization_marian.py +6 -6
  837. transformers/models/markuplm/configuration_markuplm.py +4 -7
  838. transformers/models/markuplm/feature_extraction_markuplm.py +1 -2
  839. transformers/models/markuplm/modeling_markuplm.py +63 -65
  840. transformers/models/markuplm/processing_markuplm.py +31 -38
  841. transformers/models/markuplm/tokenization_markuplm.py +67 -77
  842. transformers/models/mask2former/configuration_mask2former.py +14 -52
  843. transformers/models/mask2former/image_processing_mask2former.py +84 -85
  844. transformers/models/mask2former/image_processing_mask2former_fast.py +36 -36
  845. transformers/models/mask2former/modeling_mask2former.py +108 -104
  846. transformers/models/mask2former/modular_mask2former.py +6 -8
  847. transformers/models/maskformer/configuration_maskformer.py +17 -51
  848. transformers/models/maskformer/configuration_maskformer_swin.py +2 -5
  849. transformers/models/maskformer/image_processing_maskformer.py +84 -85
  850. transformers/models/maskformer/image_processing_maskformer_fast.py +35 -36
  851. transformers/models/maskformer/modeling_maskformer.py +71 -67
  852. transformers/models/maskformer/modeling_maskformer_swin.py +20 -23
  853. transformers/models/mbart/configuration_mbart.py +9 -5
  854. transformers/models/mbart/modeling_mbart.py +120 -119
  855. transformers/models/mbart/tokenization_mbart.py +2 -4
  856. transformers/models/mbart50/tokenization_mbart50.py +3 -5
  857. transformers/models/megatron_bert/configuration_megatron_bert.py +13 -3
  858. transformers/models/megatron_bert/modeling_megatron_bert.py +139 -165
  859. transformers/models/metaclip_2/configuration_metaclip_2.py +4 -1
  860. transformers/models/metaclip_2/modeling_metaclip_2.py +94 -87
  861. transformers/models/metaclip_2/modular_metaclip_2.py +59 -45
  862. transformers/models/mgp_str/configuration_mgp_str.py +0 -1
  863. transformers/models/mgp_str/modeling_mgp_str.py +18 -18
  864. transformers/models/mgp_str/processing_mgp_str.py +3 -20
  865. transformers/models/mgp_str/tokenization_mgp_str.py +1 -3
  866. transformers/models/mimi/configuration_mimi.py +42 -40
  867. transformers/models/mimi/modeling_mimi.py +116 -115
  868. transformers/models/minimax/__init__.py +0 -1
  869. transformers/models/minimax/configuration_minimax.py +40 -47
  870. transformers/models/minimax/modeling_minimax.py +46 -49
  871. transformers/models/minimax/modular_minimax.py +59 -65
  872. transformers/models/minimax_m2/__init__.py +28 -0
  873. transformers/models/minimax_m2/configuration_minimax_m2.py +188 -0
  874. transformers/models/minimax_m2/modeling_minimax_m2.py +704 -0
  875. transformers/models/minimax_m2/modular_minimax_m2.py +346 -0
  876. transformers/models/ministral/configuration_ministral.py +25 -29
  877. transformers/models/ministral/modeling_ministral.py +35 -37
  878. transformers/models/ministral/modular_ministral.py +32 -37
  879. transformers/models/ministral3/configuration_ministral3.py +23 -26
  880. transformers/models/ministral3/modeling_ministral3.py +35 -37
  881. transformers/models/ministral3/modular_ministral3.py +7 -8
  882. transformers/models/mistral/configuration_mistral.py +24 -29
  883. transformers/models/mistral/modeling_mistral.py +35 -37
  884. transformers/models/mistral/modular_mistral.py +14 -15
  885. transformers/models/mistral3/configuration_mistral3.py +4 -1
  886. transformers/models/mistral3/modeling_mistral3.py +79 -82
  887. transformers/models/mistral3/modular_mistral3.py +66 -67
  888. transformers/models/mixtral/configuration_mixtral.py +32 -38
  889. transformers/models/mixtral/modeling_mixtral.py +39 -42
  890. transformers/models/mixtral/modular_mixtral.py +26 -29
  891. transformers/models/mlcd/configuration_mlcd.py +0 -1
  892. transformers/models/mlcd/modeling_mlcd.py +17 -17
  893. transformers/models/mlcd/modular_mlcd.py +16 -16
  894. transformers/models/mllama/configuration_mllama.py +10 -15
  895. transformers/models/mllama/image_processing_mllama.py +23 -25
  896. transformers/models/mllama/image_processing_mllama_fast.py +11 -11
  897. transformers/models/mllama/modeling_mllama.py +100 -103
  898. transformers/models/mllama/processing_mllama.py +6 -55
  899. transformers/models/mluke/tokenization_mluke.py +97 -103
  900. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +10 -46
  901. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +159 -179
  902. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +10 -46
  903. transformers/models/mobilebert/configuration_mobilebert.py +4 -2
  904. transformers/models/mobilebert/modeling_mobilebert.py +78 -88
  905. transformers/models/mobilebert/tokenization_mobilebert.py +0 -1
  906. transformers/models/mobilenet_v1/configuration_mobilenet_v1.py +0 -1
  907. transformers/models/mobilenet_v1/image_processing_mobilenet_v1.py +20 -23
  908. transformers/models/mobilenet_v1/image_processing_mobilenet_v1_fast.py +0 -1
  909. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +13 -16
  910. transformers/models/mobilenet_v2/configuration_mobilenet_v2.py +0 -1
  911. transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py +48 -51
  912. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +14 -15
  913. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +21 -22
  914. transformers/models/mobilevit/configuration_mobilevit.py +0 -1
  915. transformers/models/mobilevit/image_processing_mobilevit.py +41 -44
  916. transformers/models/mobilevit/image_processing_mobilevit_fast.py +12 -13
  917. transformers/models/mobilevit/modeling_mobilevit.py +21 -21
  918. transformers/models/mobilevitv2/configuration_mobilevitv2.py +0 -1
  919. transformers/models/mobilevitv2/modeling_mobilevitv2.py +21 -22
  920. transformers/models/modernbert/configuration_modernbert.py +76 -51
  921. transformers/models/modernbert/modeling_modernbert.py +188 -943
  922. transformers/models/modernbert/modular_modernbert.py +255 -978
  923. transformers/models/modernbert_decoder/configuration_modernbert_decoder.py +50 -44
  924. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +54 -64
  925. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +92 -92
  926. transformers/models/moonshine/configuration_moonshine.py +34 -31
  927. transformers/models/moonshine/modeling_moonshine.py +70 -72
  928. transformers/models/moonshine/modular_moonshine.py +91 -86
  929. transformers/models/moshi/configuration_moshi.py +46 -23
  930. transformers/models/moshi/modeling_moshi.py +134 -142
  931. transformers/models/mpnet/configuration_mpnet.py +6 -2
  932. transformers/models/mpnet/modeling_mpnet.py +55 -57
  933. transformers/models/mpnet/tokenization_mpnet.py +1 -4
  934. transformers/models/mpt/configuration_mpt.py +17 -9
  935. transformers/models/mpt/modeling_mpt.py +58 -60
  936. transformers/models/mra/configuration_mra.py +8 -2
  937. transformers/models/mra/modeling_mra.py +54 -56
  938. transformers/models/mt5/configuration_mt5.py +9 -6
  939. transformers/models/mt5/modeling_mt5.py +80 -85
  940. transformers/models/musicgen/configuration_musicgen.py +12 -8
  941. transformers/models/musicgen/modeling_musicgen.py +114 -116
  942. transformers/models/musicgen/processing_musicgen.py +3 -21
  943. transformers/models/musicgen_melody/configuration_musicgen_melody.py +15 -8
  944. transformers/models/musicgen_melody/feature_extraction_musicgen_melody.py +8 -9
  945. transformers/models/musicgen_melody/modeling_musicgen_melody.py +113 -126
  946. transformers/models/musicgen_melody/processing_musicgen_melody.py +3 -22
  947. transformers/models/mvp/configuration_mvp.py +8 -5
  948. transformers/models/mvp/modeling_mvp.py +121 -123
  949. transformers/models/myt5/tokenization_myt5.py +8 -10
  950. transformers/models/nanochat/configuration_nanochat.py +5 -8
  951. transformers/models/nanochat/modeling_nanochat.py +36 -39
  952. transformers/models/nanochat/modular_nanochat.py +16 -18
  953. transformers/models/nemotron/configuration_nemotron.py +25 -30
  954. transformers/models/nemotron/modeling_nemotron.py +53 -66
  955. transformers/models/nllb/tokenization_nllb.py +14 -14
  956. transformers/models/nllb_moe/configuration_nllb_moe.py +7 -10
  957. transformers/models/nllb_moe/modeling_nllb_moe.py +70 -72
  958. transformers/models/nougat/image_processing_nougat.py +29 -32
  959. transformers/models/nougat/image_processing_nougat_fast.py +12 -13
  960. transformers/models/nougat/processing_nougat.py +37 -39
  961. transformers/models/nougat/tokenization_nougat.py +5 -7
  962. transformers/models/nystromformer/configuration_nystromformer.py +8 -2
  963. transformers/models/nystromformer/modeling_nystromformer.py +61 -63
  964. transformers/models/olmo/configuration_olmo.py +23 -28
  965. transformers/models/olmo/modeling_olmo.py +35 -38
  966. transformers/models/olmo/modular_olmo.py +8 -12
  967. transformers/models/olmo2/configuration_olmo2.py +27 -32
  968. transformers/models/olmo2/modeling_olmo2.py +36 -39
  969. transformers/models/olmo2/modular_olmo2.py +36 -38
  970. transformers/models/olmo3/__init__.py +0 -1
  971. transformers/models/olmo3/configuration_olmo3.py +30 -34
  972. transformers/models/olmo3/modeling_olmo3.py +35 -38
  973. transformers/models/olmo3/modular_olmo3.py +44 -47
  974. transformers/models/olmoe/configuration_olmoe.py +29 -33
  975. transformers/models/olmoe/modeling_olmoe.py +41 -43
  976. transformers/models/olmoe/modular_olmoe.py +15 -16
  977. transformers/models/omdet_turbo/configuration_omdet_turbo.py +14 -50
  978. transformers/models/omdet_turbo/modeling_omdet_turbo.py +59 -57
  979. transformers/models/omdet_turbo/processing_omdet_turbo.py +19 -67
  980. transformers/models/oneformer/configuration_oneformer.py +11 -51
  981. transformers/models/oneformer/image_processing_oneformer.py +83 -84
  982. transformers/models/oneformer/image_processing_oneformer_fast.py +41 -42
  983. transformers/models/oneformer/modeling_oneformer.py +137 -133
  984. transformers/models/oneformer/processing_oneformer.py +28 -43
  985. transformers/models/openai/configuration_openai.py +16 -1
  986. transformers/models/openai/modeling_openai.py +50 -51
  987. transformers/models/openai/tokenization_openai.py +2 -5
  988. transformers/models/opt/configuration_opt.py +6 -7
  989. transformers/models/opt/modeling_opt.py +79 -80
  990. transformers/models/ovis2/__init__.py +0 -1
  991. transformers/models/ovis2/configuration_ovis2.py +4 -1
  992. transformers/models/ovis2/image_processing_ovis2.py +22 -24
  993. transformers/models/ovis2/image_processing_ovis2_fast.py +9 -10
  994. transformers/models/ovis2/modeling_ovis2.py +99 -142
  995. transformers/models/ovis2/modular_ovis2.py +82 -45
  996. transformers/models/ovis2/processing_ovis2.py +12 -40
  997. transformers/models/owlv2/configuration_owlv2.py +4 -2
  998. transformers/models/owlv2/image_processing_owlv2.py +20 -21
  999. transformers/models/owlv2/image_processing_owlv2_fast.py +12 -13
  1000. transformers/models/owlv2/modeling_owlv2.py +122 -114
  1001. transformers/models/owlv2/modular_owlv2.py +11 -12
  1002. transformers/models/owlv2/processing_owlv2.py +20 -49
  1003. transformers/models/owlvit/configuration_owlvit.py +4 -2
  1004. transformers/models/owlvit/image_processing_owlvit.py +21 -22
  1005. transformers/models/owlvit/image_processing_owlvit_fast.py +2 -3
  1006. transformers/models/owlvit/modeling_owlvit.py +121 -113
  1007. transformers/models/owlvit/processing_owlvit.py +20 -48
  1008. transformers/models/paddleocr_vl/__init__.py +0 -1
  1009. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +28 -29
  1010. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +34 -35
  1011. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +12 -12
  1012. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +159 -158
  1013. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +148 -119
  1014. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +1 -3
  1015. transformers/models/paligemma/configuration_paligemma.py +4 -1
  1016. transformers/models/paligemma/modeling_paligemma.py +81 -79
  1017. transformers/models/paligemma/processing_paligemma.py +13 -66
  1018. transformers/models/parakeet/configuration_parakeet.py +3 -8
  1019. transformers/models/parakeet/feature_extraction_parakeet.py +10 -12
  1020. transformers/models/parakeet/modeling_parakeet.py +21 -25
  1021. transformers/models/parakeet/modular_parakeet.py +19 -21
  1022. transformers/models/parakeet/processing_parakeet.py +12 -5
  1023. transformers/models/parakeet/tokenization_parakeet.py +2 -4
  1024. transformers/models/patchtsmixer/configuration_patchtsmixer.py +5 -8
  1025. transformers/models/patchtsmixer/modeling_patchtsmixer.py +63 -65
  1026. transformers/models/patchtst/configuration_patchtst.py +6 -9
  1027. transformers/models/patchtst/modeling_patchtst.py +75 -77
  1028. transformers/models/pe_audio/__init__.py +0 -1
  1029. transformers/models/pe_audio/configuration_pe_audio.py +14 -16
  1030. transformers/models/pe_audio/feature_extraction_pe_audio.py +6 -8
  1031. transformers/models/pe_audio/modeling_pe_audio.py +30 -31
  1032. transformers/models/pe_audio/modular_pe_audio.py +17 -18
  1033. transformers/models/pe_audio/processing_pe_audio.py +0 -1
  1034. transformers/models/pe_audio_video/__init__.py +0 -1
  1035. transformers/models/pe_audio_video/configuration_pe_audio_video.py +15 -17
  1036. transformers/models/pe_audio_video/modeling_pe_audio_video.py +64 -65
  1037. transformers/models/pe_audio_video/modular_pe_audio_video.py +56 -57
  1038. transformers/models/pe_audio_video/processing_pe_audio_video.py +0 -1
  1039. transformers/models/pe_video/__init__.py +0 -1
  1040. transformers/models/pe_video/configuration_pe_video.py +14 -16
  1041. transformers/models/pe_video/modeling_pe_video.py +57 -46
  1042. transformers/models/pe_video/modular_pe_video.py +47 -35
  1043. transformers/models/pe_video/video_processing_pe_video.py +2 -4
  1044. transformers/models/pegasus/configuration_pegasus.py +8 -6
  1045. transformers/models/pegasus/modeling_pegasus.py +67 -69
  1046. transformers/models/pegasus/tokenization_pegasus.py +1 -4
  1047. transformers/models/pegasus_x/configuration_pegasus_x.py +5 -4
  1048. transformers/models/pegasus_x/modeling_pegasus_x.py +53 -55
  1049. transformers/models/perceiver/configuration_perceiver.py +0 -1
  1050. transformers/models/perceiver/image_processing_perceiver.py +22 -25
  1051. transformers/models/perceiver/image_processing_perceiver_fast.py +7 -8
  1052. transformers/models/perceiver/modeling_perceiver.py +152 -145
  1053. transformers/models/perceiver/tokenization_perceiver.py +3 -6
  1054. transformers/models/perception_lm/configuration_perception_lm.py +0 -1
  1055. transformers/models/perception_lm/image_processing_perception_lm_fast.py +8 -9
  1056. transformers/models/perception_lm/modeling_perception_lm.py +64 -67
  1057. transformers/models/perception_lm/modular_perception_lm.py +58 -58
  1058. transformers/models/perception_lm/processing_perception_lm.py +13 -47
  1059. transformers/models/perception_lm/video_processing_perception_lm.py +0 -1
  1060. transformers/models/persimmon/configuration_persimmon.py +23 -28
  1061. transformers/models/persimmon/modeling_persimmon.py +44 -47
  1062. transformers/models/phi/configuration_phi.py +27 -28
  1063. transformers/models/phi/modeling_phi.py +39 -41
  1064. transformers/models/phi/modular_phi.py +26 -26
  1065. transformers/models/phi3/configuration_phi3.py +32 -37
  1066. transformers/models/phi3/modeling_phi3.py +37 -40
  1067. transformers/models/phi3/modular_phi3.py +16 -20
  1068. transformers/models/phi4_multimodal/configuration_phi4_multimodal.py +36 -39
  1069. transformers/models/phi4_multimodal/feature_extraction_phi4_multimodal.py +7 -9
  1070. transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py +11 -11
  1071. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +100 -117
  1072. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +103 -90
  1073. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +7 -42
  1074. transformers/models/phimoe/configuration_phimoe.py +31 -36
  1075. transformers/models/phimoe/modeling_phimoe.py +50 -77
  1076. transformers/models/phimoe/modular_phimoe.py +12 -8
  1077. transformers/models/phobert/tokenization_phobert.py +4 -6
  1078. transformers/models/pix2struct/configuration_pix2struct.py +12 -10
  1079. transformers/models/pix2struct/image_processing_pix2struct.py +15 -19
  1080. transformers/models/pix2struct/image_processing_pix2struct_fast.py +12 -15
  1081. transformers/models/pix2struct/modeling_pix2struct.py +56 -52
  1082. transformers/models/pix2struct/processing_pix2struct.py +5 -26
  1083. transformers/models/pixio/__init__.py +0 -1
  1084. transformers/models/pixio/configuration_pixio.py +2 -5
  1085. transformers/models/pixio/modeling_pixio.py +16 -17
  1086. transformers/models/pixio/modular_pixio.py +7 -8
  1087. transformers/models/pixtral/configuration_pixtral.py +11 -14
  1088. transformers/models/pixtral/image_processing_pixtral.py +26 -28
  1089. transformers/models/pixtral/image_processing_pixtral_fast.py +10 -11
  1090. transformers/models/pixtral/modeling_pixtral.py +31 -37
  1091. transformers/models/pixtral/processing_pixtral.py +18 -52
  1092. transformers/models/plbart/configuration_plbart.py +8 -6
  1093. transformers/models/plbart/modeling_plbart.py +109 -109
  1094. transformers/models/plbart/modular_plbart.py +31 -33
  1095. transformers/models/plbart/tokenization_plbart.py +4 -5
  1096. transformers/models/poolformer/configuration_poolformer.py +0 -1
  1097. transformers/models/poolformer/image_processing_poolformer.py +21 -24
  1098. transformers/models/poolformer/image_processing_poolformer_fast.py +13 -14
  1099. transformers/models/poolformer/modeling_poolformer.py +10 -12
  1100. transformers/models/pop2piano/configuration_pop2piano.py +7 -7
  1101. transformers/models/pop2piano/feature_extraction_pop2piano.py +6 -9
  1102. transformers/models/pop2piano/modeling_pop2piano.py +24 -24
  1103. transformers/models/pop2piano/processing_pop2piano.py +25 -33
  1104. transformers/models/pop2piano/tokenization_pop2piano.py +15 -23
  1105. transformers/models/pp_doclayout_v3/__init__.py +30 -0
  1106. transformers/models/pp_doclayout_v3/configuration_pp_doclayout_v3.py +277 -0
  1107. transformers/models/pp_doclayout_v3/image_processing_pp_doclayout_v3_fast.py +305 -0
  1108. transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py +2083 -0
  1109. transformers/models/pp_doclayout_v3/modular_pp_doclayout_v3.py +1549 -0
  1110. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +13 -46
  1111. transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py +28 -28
  1112. transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything_fast.py +20 -21
  1113. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +17 -16
  1114. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +21 -20
  1115. transformers/models/prophetnet/configuration_prophetnet.py +37 -38
  1116. transformers/models/prophetnet/modeling_prophetnet.py +121 -153
  1117. transformers/models/prophetnet/tokenization_prophetnet.py +14 -16
  1118. transformers/models/pvt/configuration_pvt.py +0 -1
  1119. transformers/models/pvt/image_processing_pvt.py +24 -27
  1120. transformers/models/pvt/image_processing_pvt_fast.py +1 -2
  1121. transformers/models/pvt/modeling_pvt.py +19 -21
  1122. transformers/models/pvt_v2/configuration_pvt_v2.py +4 -8
  1123. transformers/models/pvt_v2/modeling_pvt_v2.py +27 -28
  1124. transformers/models/qwen2/configuration_qwen2.py +32 -25
  1125. transformers/models/qwen2/modeling_qwen2.py +35 -37
  1126. transformers/models/qwen2/modular_qwen2.py +14 -15
  1127. transformers/models/qwen2/tokenization_qwen2.py +2 -9
  1128. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +36 -27
  1129. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +241 -214
  1130. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +228 -193
  1131. transformers/models/qwen2_5_omni/processing_qwen2_5_omni.py +41 -49
  1132. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +28 -34
  1133. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +188 -145
  1134. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +64 -91
  1135. transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py +7 -43
  1136. transformers/models/qwen2_audio/configuration_qwen2_audio.py +0 -1
  1137. transformers/models/qwen2_audio/modeling_qwen2_audio.py +39 -41
  1138. transformers/models/qwen2_audio/processing_qwen2_audio.py +13 -42
  1139. transformers/models/qwen2_moe/configuration_qwen2_moe.py +42 -35
  1140. transformers/models/qwen2_moe/modeling_qwen2_moe.py +40 -43
  1141. transformers/models/qwen2_moe/modular_qwen2_moe.py +10 -13
  1142. transformers/models/qwen2_vl/configuration_qwen2_vl.py +28 -33
  1143. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +38 -40
  1144. transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py +12 -15
  1145. transformers/models/qwen2_vl/modeling_qwen2_vl.py +184 -141
  1146. transformers/models/qwen2_vl/processing_qwen2_vl.py +7 -44
  1147. transformers/models/qwen2_vl/video_processing_qwen2_vl.py +38 -18
  1148. transformers/models/qwen3/configuration_qwen3.py +34 -27
  1149. transformers/models/qwen3/modeling_qwen3.py +35 -38
  1150. transformers/models/qwen3/modular_qwen3.py +7 -9
  1151. transformers/models/qwen3_moe/configuration_qwen3_moe.py +45 -35
  1152. transformers/models/qwen3_moe/modeling_qwen3_moe.py +40 -43
  1153. transformers/models/qwen3_moe/modular_qwen3_moe.py +10 -13
  1154. transformers/models/qwen3_next/configuration_qwen3_next.py +47 -38
  1155. transformers/models/qwen3_next/modeling_qwen3_next.py +44 -47
  1156. transformers/models/qwen3_next/modular_qwen3_next.py +37 -38
  1157. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +139 -106
  1158. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +266 -206
  1159. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +228 -181
  1160. transformers/models/qwen3_omni_moe/processing_qwen3_omni_moe.py +40 -48
  1161. transformers/models/qwen3_vl/configuration_qwen3_vl.py +22 -24
  1162. transformers/models/qwen3_vl/modeling_qwen3_vl.py +185 -122
  1163. transformers/models/qwen3_vl/modular_qwen3_vl.py +153 -139
  1164. transformers/models/qwen3_vl/processing_qwen3_vl.py +6 -42
  1165. transformers/models/qwen3_vl/video_processing_qwen3_vl.py +10 -12
  1166. transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py +27 -30
  1167. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +249 -178
  1168. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +55 -42
  1169. transformers/models/rag/configuration_rag.py +6 -7
  1170. transformers/models/rag/modeling_rag.py +119 -121
  1171. transformers/models/rag/retrieval_rag.py +3 -5
  1172. transformers/models/rag/tokenization_rag.py +0 -50
  1173. transformers/models/recurrent_gemma/configuration_recurrent_gemma.py +29 -30
  1174. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +35 -39
  1175. transformers/models/reformer/configuration_reformer.py +7 -8
  1176. transformers/models/reformer/modeling_reformer.py +67 -68
  1177. transformers/models/reformer/tokenization_reformer.py +3 -6
  1178. transformers/models/regnet/configuration_regnet.py +0 -1
  1179. transformers/models/regnet/modeling_regnet.py +7 -9
  1180. transformers/models/rembert/configuration_rembert.py +8 -2
  1181. transformers/models/rembert/modeling_rembert.py +108 -132
  1182. transformers/models/rembert/tokenization_rembert.py +1 -4
  1183. transformers/models/resnet/configuration_resnet.py +2 -5
  1184. transformers/models/resnet/modeling_resnet.py +14 -15
  1185. transformers/models/roberta/configuration_roberta.py +11 -3
  1186. transformers/models/roberta/modeling_roberta.py +97 -99
  1187. transformers/models/roberta/modular_roberta.py +55 -58
  1188. transformers/models/roberta/tokenization_roberta.py +2 -5
  1189. transformers/models/roberta/tokenization_roberta_old.py +2 -4
  1190. transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py +11 -3
  1191. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +97 -99
  1192. transformers/models/roc_bert/configuration_roc_bert.py +8 -2
  1193. transformers/models/roc_bert/modeling_roc_bert.py +125 -162
  1194. transformers/models/roc_bert/tokenization_roc_bert.py +88 -94
  1195. transformers/models/roformer/configuration_roformer.py +13 -3
  1196. transformers/models/roformer/modeling_roformer.py +79 -95
  1197. transformers/models/roformer/tokenization_roformer.py +3 -6
  1198. transformers/models/roformer/tokenization_utils.py +0 -1
  1199. transformers/models/rt_detr/configuration_rt_detr.py +8 -50
  1200. transformers/models/rt_detr/configuration_rt_detr_resnet.py +2 -5
  1201. transformers/models/rt_detr/image_processing_rt_detr.py +54 -55
  1202. transformers/models/rt_detr/image_processing_rt_detr_fast.py +39 -26
  1203. transformers/models/rt_detr/modeling_rt_detr.py +643 -804
  1204. transformers/models/rt_detr/modeling_rt_detr_resnet.py +4 -7
  1205. transformers/models/rt_detr/modular_rt_detr.py +1522 -20
  1206. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +12 -58
  1207. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +384 -521
  1208. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +27 -70
  1209. transformers/models/rwkv/configuration_rwkv.py +2 -4
  1210. transformers/models/rwkv/modeling_rwkv.py +29 -54
  1211. transformers/models/sam/configuration_sam.py +2 -1
  1212. transformers/models/sam/image_processing_sam.py +59 -60
  1213. transformers/models/sam/image_processing_sam_fast.py +25 -26
  1214. transformers/models/sam/modeling_sam.py +46 -43
  1215. transformers/models/sam/processing_sam.py +39 -27
  1216. transformers/models/sam2/configuration_sam2.py +1 -2
  1217. transformers/models/sam2/image_processing_sam2_fast.py +14 -15
  1218. transformers/models/sam2/modeling_sam2.py +96 -94
  1219. transformers/models/sam2/modular_sam2.py +85 -94
  1220. transformers/models/sam2/processing_sam2.py +31 -47
  1221. transformers/models/sam2_video/configuration_sam2_video.py +0 -1
  1222. transformers/models/sam2_video/modeling_sam2_video.py +114 -116
  1223. transformers/models/sam2_video/modular_sam2_video.py +72 -89
  1224. transformers/models/sam2_video/processing_sam2_video.py +49 -66
  1225. transformers/models/sam2_video/video_processing_sam2_video.py +1 -4
  1226. transformers/models/sam3/configuration_sam3.py +0 -1
  1227. transformers/models/sam3/image_processing_sam3_fast.py +17 -20
  1228. transformers/models/sam3/modeling_sam3.py +94 -100
  1229. transformers/models/sam3/modular_sam3.py +3 -8
  1230. transformers/models/sam3/processing_sam3.py +37 -52
  1231. transformers/models/sam3_tracker/__init__.py +0 -1
  1232. transformers/models/sam3_tracker/configuration_sam3_tracker.py +1 -3
  1233. transformers/models/sam3_tracker/modeling_sam3_tracker.py +79 -80
  1234. transformers/models/sam3_tracker/modular_sam3_tracker.py +0 -2
  1235. transformers/models/sam3_tracker/processing_sam3_tracker.py +31 -48
  1236. transformers/models/sam3_tracker_video/__init__.py +0 -1
  1237. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +0 -1
  1238. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +115 -114
  1239. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +10 -24
  1240. transformers/models/sam3_tracker_video/processing_sam3_tracker_video.py +50 -66
  1241. transformers/models/sam3_video/configuration_sam3_video.py +0 -1
  1242. transformers/models/sam3_video/modeling_sam3_video.py +56 -45
  1243. transformers/models/sam3_video/processing_sam3_video.py +25 -45
  1244. transformers/models/sam_hq/__init__.py +1 -1
  1245. transformers/models/sam_hq/configuration_sam_hq.py +2 -1
  1246. transformers/models/sam_hq/modeling_sam_hq.py +52 -50
  1247. transformers/models/sam_hq/modular_sam_hq.py +23 -25
  1248. transformers/models/sam_hq/{processing_samhq.py → processing_sam_hq.py} +41 -29
  1249. transformers/models/seamless_m4t/configuration_seamless_m4t.py +8 -10
  1250. transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py +8 -11
  1251. transformers/models/seamless_m4t/modeling_seamless_m4t.py +180 -182
  1252. transformers/models/seamless_m4t/processing_seamless_m4t.py +18 -39
  1253. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +15 -20
  1254. transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py +8 -10
  1255. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +193 -195
  1256. transformers/models/seed_oss/configuration_seed_oss.py +30 -34
  1257. transformers/models/seed_oss/modeling_seed_oss.py +34 -36
  1258. transformers/models/seed_oss/modular_seed_oss.py +6 -7
  1259. transformers/models/segformer/configuration_segformer.py +0 -10
  1260. transformers/models/segformer/image_processing_segformer.py +39 -42
  1261. transformers/models/segformer/image_processing_segformer_fast.py +11 -12
  1262. transformers/models/segformer/modeling_segformer.py +28 -28
  1263. transformers/models/segformer/modular_segformer.py +8 -9
  1264. transformers/models/seggpt/configuration_seggpt.py +0 -1
  1265. transformers/models/seggpt/image_processing_seggpt.py +38 -41
  1266. transformers/models/seggpt/modeling_seggpt.py +48 -38
  1267. transformers/models/sew/configuration_sew.py +4 -2
  1268. transformers/models/sew/modeling_sew.py +42 -40
  1269. transformers/models/sew/modular_sew.py +12 -13
  1270. transformers/models/sew_d/configuration_sew_d.py +4 -2
  1271. transformers/models/sew_d/modeling_sew_d.py +32 -31
  1272. transformers/models/shieldgemma2/configuration_shieldgemma2.py +0 -1
  1273. transformers/models/shieldgemma2/modeling_shieldgemma2.py +19 -21
  1274. transformers/models/shieldgemma2/processing_shieldgemma2.py +3 -5
  1275. transformers/models/siglip/configuration_siglip.py +4 -2
  1276. transformers/models/siglip/image_processing_siglip.py +17 -20
  1277. transformers/models/siglip/image_processing_siglip_fast.py +0 -1
  1278. transformers/models/siglip/modeling_siglip.py +65 -110
  1279. transformers/models/siglip/processing_siglip.py +2 -14
  1280. transformers/models/siglip/tokenization_siglip.py +6 -7
  1281. transformers/models/siglip2/__init__.py +1 -0
  1282. transformers/models/siglip2/configuration_siglip2.py +4 -2
  1283. transformers/models/siglip2/image_processing_siglip2.py +15 -16
  1284. transformers/models/siglip2/image_processing_siglip2_fast.py +6 -7
  1285. transformers/models/siglip2/modeling_siglip2.py +89 -130
  1286. transformers/models/siglip2/modular_siglip2.py +95 -48
  1287. transformers/models/siglip2/processing_siglip2.py +2 -14
  1288. transformers/models/siglip2/tokenization_siglip2.py +95 -0
  1289. transformers/models/smollm3/configuration_smollm3.py +29 -32
  1290. transformers/models/smollm3/modeling_smollm3.py +35 -38
  1291. transformers/models/smollm3/modular_smollm3.py +36 -38
  1292. transformers/models/smolvlm/configuration_smolvlm.py +2 -4
  1293. transformers/models/smolvlm/image_processing_smolvlm.py +42 -43
  1294. transformers/models/smolvlm/image_processing_smolvlm_fast.py +41 -15
  1295. transformers/models/smolvlm/modeling_smolvlm.py +124 -96
  1296. transformers/models/smolvlm/modular_smolvlm.py +50 -39
  1297. transformers/models/smolvlm/processing_smolvlm.py +15 -76
  1298. transformers/models/smolvlm/video_processing_smolvlm.py +16 -17
  1299. transformers/models/solar_open/__init__.py +27 -0
  1300. transformers/models/solar_open/configuration_solar_open.py +184 -0
  1301. transformers/models/solar_open/modeling_solar_open.py +642 -0
  1302. transformers/models/solar_open/modular_solar_open.py +224 -0
  1303. transformers/models/speech_encoder_decoder/configuration_speech_encoder_decoder.py +0 -1
  1304. transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +26 -27
  1305. transformers/models/speech_to_text/configuration_speech_to_text.py +9 -9
  1306. transformers/models/speech_to_text/feature_extraction_speech_to_text.py +10 -13
  1307. transformers/models/speech_to_text/modeling_speech_to_text.py +55 -57
  1308. transformers/models/speech_to_text/processing_speech_to_text.py +4 -30
  1309. transformers/models/speech_to_text/tokenization_speech_to_text.py +5 -6
  1310. transformers/models/speecht5/configuration_speecht5.py +7 -9
  1311. transformers/models/speecht5/feature_extraction_speecht5.py +16 -37
  1312. transformers/models/speecht5/modeling_speecht5.py +172 -174
  1313. transformers/models/speecht5/number_normalizer.py +0 -1
  1314. transformers/models/speecht5/processing_speecht5.py +3 -37
  1315. transformers/models/speecht5/tokenization_speecht5.py +4 -5
  1316. transformers/models/splinter/configuration_splinter.py +6 -7
  1317. transformers/models/splinter/modeling_splinter.py +62 -59
  1318. transformers/models/splinter/tokenization_splinter.py +2 -4
  1319. transformers/models/squeezebert/configuration_squeezebert.py +14 -2
  1320. transformers/models/squeezebert/modeling_squeezebert.py +60 -62
  1321. transformers/models/squeezebert/tokenization_squeezebert.py +0 -1
  1322. transformers/models/stablelm/configuration_stablelm.py +28 -29
  1323. transformers/models/stablelm/modeling_stablelm.py +44 -47
  1324. transformers/models/starcoder2/configuration_starcoder2.py +30 -27
  1325. transformers/models/starcoder2/modeling_starcoder2.py +38 -41
  1326. transformers/models/starcoder2/modular_starcoder2.py +17 -19
  1327. transformers/models/superglue/configuration_superglue.py +7 -3
  1328. transformers/models/superglue/image_processing_superglue.py +15 -15
  1329. transformers/models/superglue/image_processing_superglue_fast.py +8 -8
  1330. transformers/models/superglue/modeling_superglue.py +41 -37
  1331. transformers/models/superpoint/image_processing_superpoint.py +15 -15
  1332. transformers/models/superpoint/image_processing_superpoint_fast.py +7 -9
  1333. transformers/models/superpoint/modeling_superpoint.py +17 -16
  1334. transformers/models/swiftformer/configuration_swiftformer.py +0 -1
  1335. transformers/models/swiftformer/modeling_swiftformer.py +12 -14
  1336. transformers/models/swin/configuration_swin.py +2 -5
  1337. transformers/models/swin/modeling_swin.py +69 -78
  1338. transformers/models/swin2sr/configuration_swin2sr.py +0 -1
  1339. transformers/models/swin2sr/image_processing_swin2sr.py +10 -13
  1340. transformers/models/swin2sr/image_processing_swin2sr_fast.py +4 -7
  1341. transformers/models/swin2sr/modeling_swin2sr.py +30 -30
  1342. transformers/models/swinv2/configuration_swinv2.py +2 -5
  1343. transformers/models/swinv2/modeling_swinv2.py +65 -74
  1344. transformers/models/switch_transformers/configuration_switch_transformers.py +11 -7
  1345. transformers/models/switch_transformers/modeling_switch_transformers.py +35 -36
  1346. transformers/models/switch_transformers/modular_switch_transformers.py +32 -33
  1347. transformers/models/t5/configuration_t5.py +9 -9
  1348. transformers/models/t5/modeling_t5.py +80 -85
  1349. transformers/models/t5/tokenization_t5.py +1 -3
  1350. transformers/models/t5gemma/configuration_t5gemma.py +43 -59
  1351. transformers/models/t5gemma/modeling_t5gemma.py +105 -108
  1352. transformers/models/t5gemma/modular_t5gemma.py +128 -142
  1353. transformers/models/t5gemma2/configuration_t5gemma2.py +86 -100
  1354. transformers/models/t5gemma2/modeling_t5gemma2.py +234 -194
  1355. transformers/models/t5gemma2/modular_t5gemma2.py +279 -264
  1356. transformers/models/table_transformer/configuration_table_transformer.py +18 -50
  1357. transformers/models/table_transformer/modeling_table_transformer.py +73 -101
  1358. transformers/models/tapas/configuration_tapas.py +12 -2
  1359. transformers/models/tapas/modeling_tapas.py +65 -67
  1360. transformers/models/tapas/tokenization_tapas.py +116 -153
  1361. transformers/models/textnet/configuration_textnet.py +4 -7
  1362. transformers/models/textnet/image_processing_textnet.py +22 -25
  1363. transformers/models/textnet/image_processing_textnet_fast.py +8 -9
  1364. transformers/models/textnet/modeling_textnet.py +28 -28
  1365. transformers/models/time_series_transformer/configuration_time_series_transformer.py +5 -8
  1366. transformers/models/time_series_transformer/modeling_time_series_transformer.py +82 -84
  1367. transformers/models/timesfm/configuration_timesfm.py +0 -1
  1368. transformers/models/timesfm/modeling_timesfm.py +22 -25
  1369. transformers/models/timesfm/modular_timesfm.py +21 -24
  1370. transformers/models/timesformer/configuration_timesformer.py +0 -1
  1371. transformers/models/timesformer/modeling_timesformer.py +13 -16
  1372. transformers/models/timm_backbone/configuration_timm_backbone.py +33 -8
  1373. transformers/models/timm_backbone/modeling_timm_backbone.py +25 -30
  1374. transformers/models/timm_wrapper/configuration_timm_wrapper.py +2 -3
  1375. transformers/models/timm_wrapper/image_processing_timm_wrapper.py +4 -5
  1376. transformers/models/timm_wrapper/modeling_timm_wrapper.py +22 -19
  1377. transformers/models/trocr/configuration_trocr.py +11 -8
  1378. transformers/models/trocr/modeling_trocr.py +42 -42
  1379. transformers/models/trocr/processing_trocr.py +5 -25
  1380. transformers/models/tvp/configuration_tvp.py +10 -36
  1381. transformers/models/tvp/image_processing_tvp.py +50 -52
  1382. transformers/models/tvp/image_processing_tvp_fast.py +15 -15
  1383. transformers/models/tvp/modeling_tvp.py +26 -28
  1384. transformers/models/tvp/processing_tvp.py +2 -14
  1385. transformers/models/udop/configuration_udop.py +16 -8
  1386. transformers/models/udop/modeling_udop.py +73 -72
  1387. transformers/models/udop/processing_udop.py +7 -26
  1388. transformers/models/udop/tokenization_udop.py +80 -93
  1389. transformers/models/umt5/configuration_umt5.py +8 -7
  1390. transformers/models/umt5/modeling_umt5.py +87 -84
  1391. transformers/models/unispeech/configuration_unispeech.py +4 -2
  1392. transformers/models/unispeech/modeling_unispeech.py +54 -53
  1393. transformers/models/unispeech/modular_unispeech.py +20 -22
  1394. transformers/models/unispeech_sat/configuration_unispeech_sat.py +4 -2
  1395. transformers/models/unispeech_sat/modeling_unispeech_sat.py +70 -69
  1396. transformers/models/unispeech_sat/modular_unispeech_sat.py +21 -23
  1397. transformers/models/univnet/feature_extraction_univnet.py +14 -14
  1398. transformers/models/univnet/modeling_univnet.py +7 -8
  1399. transformers/models/upernet/configuration_upernet.py +8 -36
  1400. transformers/models/upernet/modeling_upernet.py +11 -14
  1401. transformers/models/vaultgemma/__init__.py +0 -1
  1402. transformers/models/vaultgemma/configuration_vaultgemma.py +29 -33
  1403. transformers/models/vaultgemma/modeling_vaultgemma.py +38 -40
  1404. transformers/models/vaultgemma/modular_vaultgemma.py +29 -31
  1405. transformers/models/video_llama_3/configuration_video_llama_3.py +4 -0
  1406. transformers/models/video_llama_3/image_processing_video_llama_3.py +40 -40
  1407. transformers/models/video_llama_3/image_processing_video_llama_3_fast.py +12 -14
  1408. transformers/models/video_llama_3/modeling_video_llama_3.py +149 -112
  1409. transformers/models/video_llama_3/modular_video_llama_3.py +152 -150
  1410. transformers/models/video_llama_3/processing_video_llama_3.py +5 -39
  1411. transformers/models/video_llama_3/video_processing_video_llama_3.py +45 -24
  1412. transformers/models/video_llava/configuration_video_llava.py +4 -1
  1413. transformers/models/video_llava/image_processing_video_llava.py +35 -38
  1414. transformers/models/video_llava/modeling_video_llava.py +139 -143
  1415. transformers/models/video_llava/processing_video_llava.py +38 -78
  1416. transformers/models/video_llava/video_processing_video_llava.py +0 -1
  1417. transformers/models/videomae/configuration_videomae.py +0 -1
  1418. transformers/models/videomae/image_processing_videomae.py +31 -34
  1419. transformers/models/videomae/modeling_videomae.py +17 -20
  1420. transformers/models/videomae/video_processing_videomae.py +0 -1
  1421. transformers/models/vilt/configuration_vilt.py +4 -2
  1422. transformers/models/vilt/image_processing_vilt.py +29 -30
  1423. transformers/models/vilt/image_processing_vilt_fast.py +15 -16
  1424. transformers/models/vilt/modeling_vilt.py +103 -90
  1425. transformers/models/vilt/processing_vilt.py +2 -14
  1426. transformers/models/vipllava/configuration_vipllava.py +4 -1
  1427. transformers/models/vipllava/modeling_vipllava.py +92 -67
  1428. transformers/models/vipllava/modular_vipllava.py +78 -54
  1429. transformers/models/vision_encoder_decoder/configuration_vision_encoder_decoder.py +0 -1
  1430. transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +28 -27
  1431. transformers/models/vision_text_dual_encoder/configuration_vision_text_dual_encoder.py +0 -1
  1432. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +45 -41
  1433. transformers/models/vision_text_dual_encoder/processing_vision_text_dual_encoder.py +2 -16
  1434. transformers/models/visual_bert/configuration_visual_bert.py +6 -2
  1435. transformers/models/visual_bert/modeling_visual_bert.py +90 -92
  1436. transformers/models/vit/configuration_vit.py +2 -3
  1437. transformers/models/vit/image_processing_vit.py +19 -22
  1438. transformers/models/vit/image_processing_vit_fast.py +0 -1
  1439. transformers/models/vit/modeling_vit.py +20 -20
  1440. transformers/models/vit_mae/configuration_vit_mae.py +0 -1
  1441. transformers/models/vit_mae/modeling_vit_mae.py +32 -30
  1442. transformers/models/vit_msn/configuration_vit_msn.py +0 -1
  1443. transformers/models/vit_msn/modeling_vit_msn.py +21 -19
  1444. transformers/models/vitdet/configuration_vitdet.py +2 -5
  1445. transformers/models/vitdet/modeling_vitdet.py +14 -17
  1446. transformers/models/vitmatte/configuration_vitmatte.py +7 -39
  1447. transformers/models/vitmatte/image_processing_vitmatte.py +15 -18
  1448. transformers/models/vitmatte/image_processing_vitmatte_fast.py +16 -17
  1449. transformers/models/vitmatte/modeling_vitmatte.py +10 -12
  1450. transformers/models/vitpose/configuration_vitpose.py +7 -47
  1451. transformers/models/vitpose/image_processing_vitpose.py +24 -25
  1452. transformers/models/vitpose/image_processing_vitpose_fast.py +9 -10
  1453. transformers/models/vitpose/modeling_vitpose.py +15 -15
  1454. transformers/models/vitpose_backbone/configuration_vitpose_backbone.py +2 -5
  1455. transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +13 -16
  1456. transformers/models/vits/configuration_vits.py +4 -1
  1457. transformers/models/vits/modeling_vits.py +43 -42
  1458. transformers/models/vits/tokenization_vits.py +3 -4
  1459. transformers/models/vivit/configuration_vivit.py +0 -1
  1460. transformers/models/vivit/image_processing_vivit.py +36 -39
  1461. transformers/models/vivit/modeling_vivit.py +9 -11
  1462. transformers/models/vjepa2/__init__.py +0 -1
  1463. transformers/models/vjepa2/configuration_vjepa2.py +0 -1
  1464. transformers/models/vjepa2/modeling_vjepa2.py +39 -41
  1465. transformers/models/vjepa2/video_processing_vjepa2.py +0 -1
  1466. transformers/models/voxtral/__init__.py +0 -1
  1467. transformers/models/voxtral/configuration_voxtral.py +0 -2
  1468. transformers/models/voxtral/modeling_voxtral.py +41 -48
  1469. transformers/models/voxtral/modular_voxtral.py +35 -38
  1470. transformers/models/voxtral/processing_voxtral.py +25 -48
  1471. transformers/models/wav2vec2/configuration_wav2vec2.py +4 -2
  1472. transformers/models/wav2vec2/feature_extraction_wav2vec2.py +7 -10
  1473. transformers/models/wav2vec2/modeling_wav2vec2.py +74 -126
  1474. transformers/models/wav2vec2/processing_wav2vec2.py +6 -35
  1475. transformers/models/wav2vec2/tokenization_wav2vec2.py +20 -332
  1476. transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py +4 -2
  1477. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +49 -52
  1478. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +45 -48
  1479. transformers/models/wav2vec2_bert/processing_wav2vec2_bert.py +6 -35
  1480. transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py +4 -2
  1481. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +62 -65
  1482. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +15 -18
  1483. transformers/models/wav2vec2_phoneme/tokenization_wav2vec2_phoneme.py +16 -17
  1484. transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py +36 -55
  1485. transformers/models/wavlm/configuration_wavlm.py +4 -2
  1486. transformers/models/wavlm/modeling_wavlm.py +49 -49
  1487. transformers/models/wavlm/modular_wavlm.py +4 -5
  1488. transformers/models/whisper/configuration_whisper.py +6 -5
  1489. transformers/models/whisper/english_normalizer.py +3 -4
  1490. transformers/models/whisper/feature_extraction_whisper.py +9 -24
  1491. transformers/models/whisper/generation_whisper.py +26 -49
  1492. transformers/models/whisper/modeling_whisper.py +71 -73
  1493. transformers/models/whisper/processing_whisper.py +3 -20
  1494. transformers/models/whisper/tokenization_whisper.py +9 -30
  1495. transformers/models/x_clip/configuration_x_clip.py +4 -2
  1496. transformers/models/x_clip/modeling_x_clip.py +94 -96
  1497. transformers/models/x_clip/processing_x_clip.py +2 -14
  1498. transformers/models/xcodec/configuration_xcodec.py +4 -6
  1499. transformers/models/xcodec/modeling_xcodec.py +15 -17
  1500. transformers/models/xglm/configuration_xglm.py +9 -8
  1501. transformers/models/xglm/modeling_xglm.py +49 -55
  1502. transformers/models/xglm/tokenization_xglm.py +1 -4
  1503. transformers/models/xlm/configuration_xlm.py +10 -8
  1504. transformers/models/xlm/modeling_xlm.py +127 -131
  1505. transformers/models/xlm/tokenization_xlm.py +3 -5
  1506. transformers/models/xlm_roberta/configuration_xlm_roberta.py +11 -3
  1507. transformers/models/xlm_roberta/modeling_xlm_roberta.py +96 -98
  1508. transformers/models/xlm_roberta/modular_xlm_roberta.py +50 -53
  1509. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +1 -4
  1510. transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py +10 -2
  1511. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +97 -99
  1512. transformers/models/xlm_roberta_xl/modular_xlm_roberta_xl.py +67 -70
  1513. transformers/models/xlnet/configuration_xlnet.py +3 -12
  1514. transformers/models/xlnet/modeling_xlnet.py +149 -162
  1515. transformers/models/xlnet/tokenization_xlnet.py +1 -4
  1516. transformers/models/xlstm/configuration_xlstm.py +8 -12
  1517. transformers/models/xlstm/modeling_xlstm.py +61 -96
  1518. transformers/models/xmod/configuration_xmod.py +11 -3
  1519. transformers/models/xmod/modeling_xmod.py +111 -116
  1520. transformers/models/yolos/configuration_yolos.py +0 -1
  1521. transformers/models/yolos/image_processing_yolos.py +60 -62
  1522. transformers/models/yolos/image_processing_yolos_fast.py +42 -45
  1523. transformers/models/yolos/modeling_yolos.py +19 -21
  1524. transformers/models/yolos/modular_yolos.py +17 -19
  1525. transformers/models/yoso/configuration_yoso.py +8 -2
  1526. transformers/models/yoso/modeling_yoso.py +60 -62
  1527. transformers/models/youtu/__init__.py +27 -0
  1528. transformers/models/youtu/configuration_youtu.py +194 -0
  1529. transformers/models/youtu/modeling_youtu.py +619 -0
  1530. transformers/models/youtu/modular_youtu.py +254 -0
  1531. transformers/models/zamba/configuration_zamba.py +5 -8
  1532. transformers/models/zamba/modeling_zamba.py +93 -125
  1533. transformers/models/zamba2/configuration_zamba2.py +44 -50
  1534. transformers/models/zamba2/modeling_zamba2.py +137 -165
  1535. transformers/models/zamba2/modular_zamba2.py +79 -74
  1536. transformers/models/zoedepth/configuration_zoedepth.py +17 -41
  1537. transformers/models/zoedepth/image_processing_zoedepth.py +28 -29
  1538. transformers/models/zoedepth/image_processing_zoedepth_fast.py +20 -21
  1539. transformers/models/zoedepth/modeling_zoedepth.py +19 -19
  1540. transformers/pipelines/__init__.py +47 -106
  1541. transformers/pipelines/any_to_any.py +15 -23
  1542. transformers/pipelines/audio_utils.py +1 -2
  1543. transformers/pipelines/automatic_speech_recognition.py +0 -2
  1544. transformers/pipelines/base.py +13 -17
  1545. transformers/pipelines/image_text_to_text.py +1 -2
  1546. transformers/pipelines/question_answering.py +4 -43
  1547. transformers/pipelines/text_classification.py +1 -14
  1548. transformers/pipelines/text_to_audio.py +5 -1
  1549. transformers/pipelines/token_classification.py +1 -22
  1550. transformers/pipelines/video_classification.py +1 -9
  1551. transformers/pipelines/zero_shot_audio_classification.py +0 -1
  1552. transformers/pipelines/zero_shot_classification.py +0 -6
  1553. transformers/pipelines/zero_shot_image_classification.py +0 -7
  1554. transformers/processing_utils.py +128 -137
  1555. transformers/pytorch_utils.py +2 -26
  1556. transformers/quantizers/base.py +10 -0
  1557. transformers/quantizers/quantizer_compressed_tensors.py +7 -5
  1558. transformers/quantizers/quantizer_fbgemm_fp8.py +20 -23
  1559. transformers/quantizers/quantizer_finegrained_fp8.py +14 -20
  1560. transformers/quantizers/quantizer_mxfp4.py +1 -1
  1561. transformers/quantizers/quantizer_quark.py +0 -1
  1562. transformers/quantizers/quantizer_torchao.py +3 -19
  1563. transformers/safetensors_conversion.py +11 -4
  1564. transformers/testing_utils.py +6 -65
  1565. transformers/tokenization_mistral_common.py +563 -903
  1566. transformers/tokenization_python.py +6 -4
  1567. transformers/tokenization_utils_base.py +228 -341
  1568. transformers/tokenization_utils_sentencepiece.py +5 -6
  1569. transformers/tokenization_utils_tokenizers.py +36 -7
  1570. transformers/trainer.py +30 -41
  1571. transformers/trainer_jit_checkpoint.py +1 -2
  1572. transformers/trainer_seq2seq.py +1 -1
  1573. transformers/training_args.py +414 -420
  1574. transformers/utils/__init__.py +1 -4
  1575. transformers/utils/attention_visualizer.py +1 -1
  1576. transformers/utils/auto_docstring.py +567 -18
  1577. transformers/utils/backbone_utils.py +13 -373
  1578. transformers/utils/doc.py +4 -36
  1579. transformers/utils/dummy_pt_objects.py +0 -42
  1580. transformers/utils/generic.py +70 -34
  1581. transformers/utils/import_utils.py +72 -75
  1582. transformers/utils/loading_report.py +135 -107
  1583. transformers/utils/quantization_config.py +8 -31
  1584. transformers/video_processing_utils.py +24 -25
  1585. transformers/video_utils.py +21 -23
  1586. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/METADATA +120 -239
  1587. transformers-5.1.0.dist-info/RECORD +2092 -0
  1588. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/WHEEL +1 -1
  1589. transformers/pipelines/deprecated/text2text_generation.py +0 -408
  1590. transformers/pipelines/image_to_text.py +0 -229
  1591. transformers-5.0.0rc2.dist-info/RECORD +0 -2042
  1592. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/entry_points.txt +0 -0
  1593. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/licenses/LICENSE +0 -0
  1594. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/top_level.txt +0 -0
transformers/models/lw_detr/modular_lw_detr.py
@@ -0,0 +1,1609 @@
1
+ # Copyright 2026 The HuggingFace Inc. team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ import math
15
+ from collections.abc import Callable
16
+ from dataclasses import dataclass
17
+ from typing import Any
18
+
19
+ import torch
20
+ from torch import nn
21
+
22
+ from ... import initialization as init
23
+ from ...activations import ACT2FN
24
+ from ...backbone_utils import consolidate_backbone_kwargs_to_config
25
+ from ...configuration_utils import PreTrainedConfig
26
+ from ...modeling_layers import GradientCheckpointingLayer
27
+ from ...modeling_outputs import BackboneOutput
28
+ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
29
+ from ...processing_utils import Unpack
30
+ from ...pytorch_utils import meshgrid
31
+ from ...utils import ModelOutput, TransformersKwargs, auto_docstring, logging
32
+ from ...utils.generic import check_model_inputs
33
+ from ..auto import AutoConfig
34
+ from ..convnext.modeling_convnext import ConvNextLayerNorm
35
+ from ..dab_detr.modeling_dab_detr import gen_sine_position_embeddings
36
+ from ..deformable_detr.modeling_deformable_detr import (
37
+ DeformableDetrDecoderOutput,
38
+ DeformableDetrForObjectDetection,
39
+ DeformableDetrMLPPredictionHead,
40
+ DeformableDetrModel,
41
+ DeformableDetrMultiscaleDeformableAttention,
42
+ )
43
+ from ..llama.modeling_llama import eager_attention_forward
44
+ from ..rt_detr.modeling_rt_detr import RTDetrConvNormLayer
45
+ from ..vit.modeling_vit import ViTAttention, ViTEncoder, ViTSelfAttention
46
+ from ..vitdet.configuration_vitdet import VitDetConfig
47
+ from ..vitdet.modeling_vitdet import (
48
+ VitDetBackbone,
49
+ VitDetEmbeddings,
50
+ VitDetMlp,
51
+ VitDetPreTrainedModel,
52
+ )
53
+
54
+
55
+ logger = logging.get_logger(__name__)
56
+
57
+
58
+ class LwDetrViTConfig(VitDetConfig):
59
+ r"""
60
+ This is the configuration class to store the configuration of a [`LwDetrViTModel`]. It is used to instantiate an
61
+ LW-DETR ViT model according to the specified arguments, defining the model architecture. Instantiating a configuration
62
+ with the defaults will yield a similar configuration to that of the LW-DETR ViT
63
+ [AnnaZhang/lwdetr_small_60e_coco](https://huggingface.co/AnnaZhang/lwdetr_small_60e_coco) architecture.
64
+
65
+ LW-DETR ViT is the Vision Transformer backbone used in the LW-DETR model for real-time object detection. It features
66
+ interleaved window and global attention mechanisms to reduce computational complexity while maintaining high performance.
67
+ The model uses a window-major feature map organization for efficient attention computation.
68
+
69
+ Configuration objects inherit from [`VitDetConfig`] and can be used to control the model outputs. Read the
70
+ documentation from [`VitDetConfig`] for more information.
71
+
72
+ Args:
73
+ hidden_size (`int`, *optional*, defaults to 768):
74
+ Dimensionality of the encoder layers and the pooler layer.
75
+ num_hidden_layers (`int`, *optional*, defaults to 12):
76
+ Number of hidden layers in the Transformer encoder.
77
+ num_attention_heads (`int`, *optional*, defaults to 12):
78
+ Number of attention heads for each attention layer in the Transformer encoder.
79
+ mlp_ratio (`int`, *optional*, defaults to 4):
80
+ Ratio of mlp hidden dim to embedding dim.
81
+ hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
82
+ The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
83
+ `"relu"`, `"selu"` and `"gelu_new"` are supported.
84
+ dropout_prob (`float`, *optional*, defaults to 0.0):
85
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
86
+ initializer_range (`float`, *optional*, defaults to 0.02):
87
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
88
+ layer_norm_eps (`float`, *optional*, defaults to 1e-06):
89
+ The epsilon used by the layer normalization layers.
90
+ image_size (`int`, *optional*, defaults to 256):
91
+ The size (resolution) of each image.
92
+ pretrain_image_size (`int`, *optional*, defaults to 224):
93
+ The size (resolution) of each image during pretraining.
94
+ patch_size (`int`, *optional*, defaults to 16):
95
+ The size (resolution) of each patch.
96
+ num_channels (`int`, *optional*, defaults to 3):
97
+ The number of input channels.
98
+ qkv_bias (`bool`, *optional*, defaults to `True`):
99
+ Whether to add a bias to the queries, keys and values.
100
+ window_block_indices (`list[int]`, *optional*, defaults to `[]`):
101
+ List of indices of blocks that should have window attention instead of regular global self-attention.
102
+ use_absolute_position_embeddings (`bool`, *optional*, defaults to `True`):
103
+ Whether to add absolute position embeddings to the patch embeddings.
104
+ out_features (`list[str]`, *optional*):
105
+ If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
106
+ (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
107
+ corresponding stages. If unset and `out_indices` is unset, will default to the last stage. Must be in the
108
+ same order as defined in the `stage_names` attribute.
109
+ out_indices (`list[int]`, *optional*):
110
+ If used as backbone, list of indices of features to output. Can be any of 0, 1, 2, etc. (depending on how
111
+ many stages the model has). If unset and `out_features` is set, will default to the corresponding stages.
112
+ If unset and `out_features` is unset, will default to the last stage. Must be in the
113
+ same order as defined in the `stage_names` attribute.
114
+ cae_init_values (`float`, *optional*, defaults to 0.1):
115
+ Initialization value for the layer scale parameters (`gamma_1` and `gamma_2`) of each encoder layer (CAE-style initialization).
116
+ num_windows (`int`, *optional*, defaults to 16):
117
+ Number of windows for window-based attention. Must be a perfect square, and `image_size / num_windows`
118
+ must be divisible by `sqrt(num_windows)`. This enables the efficient window-major feature map organization.
119
+
120
+ Example:
121
+
122
+ ```python
123
+ >>> from transformers import LwDetrViTConfig, LwDetrViTModel
124
+
125
+ >>> # Initializing an LW-DETR ViT configuration
126
+ >>> configuration = LwDetrViTConfig()
127
+
128
+ >>> # Initializing a model (with random weights) from the configuration
129
+ >>> model = LwDetrViTModel(configuration)
130
+
131
+ >>> # Accessing the model configuration
132
+ >>> configuration = model.config
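+
+ >>> # Sketch of the derived window attributes: `num_windows_side` is sqrt(num_windows)
+ >>> configuration.num_windows, configuration.num_windows_side
+ (16, 4)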
133
+ ```"""
134
+
135
+ model_type = "lw_detr_vit"
136
+
137
+ def __init__(
138
+ self,
139
+ hidden_size=768,
140
+ num_hidden_layers=12,
141
+ num_attention_heads=12,
142
+ mlp_ratio=4,
143
+ hidden_act="gelu",
144
+ dropout_prob=0.0,
145
+ initializer_range=0.02,
146
+ layer_norm_eps=1e-6,
147
+ image_size=256,
148
+ pretrain_image_size=224,
149
+ patch_size=16,
150
+ num_channels=3,
151
+ qkv_bias=True,
152
+ window_block_indices=[],
153
+ use_absolute_position_embeddings=True,
154
+ out_features=None,
155
+ out_indices=None,
156
+ cae_init_values: float = 0.1,
157
+ num_windows=16,
158
+ **kwargs,
159
+ ):
160
+ super().__init__(
161
+ hidden_size=hidden_size,
162
+ num_hidden_layers=num_hidden_layers,
163
+ num_attention_heads=num_attention_heads,
164
+ mlp_ratio=mlp_ratio,
165
+ hidden_act=hidden_act,
166
+ dropout_prob=dropout_prob,
167
+ initializer_range=initializer_range,
168
+ layer_norm_eps=layer_norm_eps,
169
+ image_size=image_size,
170
+ pretrain_image_size=pretrain_image_size,
171
+ patch_size=patch_size,
172
+ num_channels=num_channels,
173
+ qkv_bias=qkv_bias,
174
+ window_block_indices=window_block_indices,
175
+ use_absolute_position_embeddings=use_absolute_position_embeddings,
176
+ out_features=out_features,
177
+ out_indices=out_indices,
178
+ **kwargs,
179
+ )
180
+ del self.residual_block_indices
181
+ del self.use_relative_position_embeddings
182
+ del self.window_size
183
+ del self.drop_path_rate
184
+
185
+ self.cae_init_values = cae_init_values
186
+ if int(math.sqrt(num_windows)) ** 2 != num_windows:
187
+ raise ValueError(
188
+ f"`num_windows` has to be a perfect square, where num_windows % math.sqrt(num_windows) != 0, but got {num_windows}."
189
+ )
190
+ if image_size / num_windows % math.sqrt(num_windows) != 0:
191
+ raise ValueError(
192
+ f"`image_size` has to be divisible by `num_windows`, where image_size / num_windows % math.sqrt(num_windows) != 0,but got {image_size} and {num_windows}."
193
+ )
194
+ self.num_windows = num_windows
195
+ self.num_windows_side = int(math.sqrt(num_windows))
196
+
197
+
198
+ class LwDetrConfig(PreTrainedConfig):
199
+ r"""
200
+ This is the configuration class to store the configuration of a [`LwDetrModel`]. It is used to instantiate
201
+ an LW-DETR model according to the specified arguments, defining the model architecture. Instantiating a
202
+ configuration with the defaults will yield a similar configuration to that of the LW-DETR
203
+ [AnnaZhang/lwdetr_small_60e_coco](https://huggingface.co/AnnaZhang/lwdetr_small_60e_coco) architecture.
204
+
205
+ LW-DETR (Lightweight Detection Transformer) is a transformer-based object detection model designed for real-time
206
+ detection tasks. It replaces traditional CNN-based detectors like YOLO with a more efficient transformer architecture
207
+ that achieves competitive performance while being computationally lightweight.
208
+
209
+ Configuration objects inherit from [`PreTrainedConfig`] and can be used to control the model outputs. Read the
210
+ documentation from [`PreTrainedConfig`] for more information.
211
+
212
+ Args:
213
+ backbone_config (`PretrainedConfig` or `dict`, *optional*):
214
+ The configuration of the backbone model. If not provided, will default to `LwDetrViTConfig` with
215
+ a small ViT architecture optimized for detection tasks.
216
+ projector_scale_factors (`list[float]`, *optional*, defaults to `[]`):
217
+ Scale factors for the feature pyramid network. Each scale factor determines the resolution of features
218
+ at different levels. Supported values are 0.5, 1.0, and 2.0.
219
+ hidden_expansion (`float`, *optional*, defaults to 0.5):
220
+ Expansion factor for hidden dimensions in the projector layers.
221
+ c2f_num_blocks (`int`, *optional*, defaults to 3):
222
+ Number of blocks in the C2F layer.
223
+ activation_function (`str`, *optional*, defaults to `"silu"`):
224
+ The non-linear activation function in the projector. Supported values are `"silu"`, `"relu"`, `"gelu"`.
225
+ batch_norm_eps (`float`, *optional*, defaults to 1e-05):
226
+ The epsilon value for batch normalization layers.
227
+ d_model (`int`, *optional*, defaults to 256):
228
+ Dimension of the model layers and the number of expected features in the decoder inputs.
229
+ dropout (`float`, *optional*, defaults to 0.1):
230
+ The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
231
+ decoder_ffn_dim (`int`, *optional*, defaults to 2048):
232
+ Dimension of the "intermediate" (often named feed-forward) layer in decoder.
233
+ decoder_n_points (`int`, *optional*, defaults to 4):
234
+ The number of sampled keys in each feature level for each attention head in the decoder.
235
+ decoder_layers (`int`, *optional*, defaults to 3):
236
+ Number of decoder layers in the transformer.
237
+ decoder_self_attention_heads (`int`, *optional*, defaults to 8):
238
+ Number of attention heads for each attention layer in the decoder self-attention.
239
+ decoder_cross_attention_heads (`int`, *optional*, defaults to 16):
240
+ Number of attention heads for each attention layer in the decoder cross-attention.
241
+ decoder_activation_function (`str`, *optional*, defaults to `"relu"`):
242
+ The non-linear activation function in the decoder. Supported values are `"relu"`, `"silu"`, `"gelu"`.
243
+ num_queries (`int`, *optional*, defaults to 300):
244
+ Number of object queries, i.e. detection slots. This is the maximal number of objects
245
+ [`LwDetrModel`] can detect in a single image.
246
+ attention_bias (`bool`, *optional*, defaults to `True`):
247
+ Whether to add bias to the attention layers.
248
+ attention_dropout (`float`, *optional*, defaults to 0.0):
249
+ The dropout ratio for the attention probabilities.
250
+ activation_dropout (`float`, *optional*, defaults to 0.0):
251
+ The dropout ratio for activations inside the fully connected layer.
252
+ group_detr (`int`, *optional*, defaults to 13):
253
+ Number of query groups for the Group DETR technique, which adds extra supervision during training for faster convergence (only one group is used at inference).
254
+ init_std (`float`, *optional*, defaults to 0.02):
255
+ The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
256
+ disable_custom_kernels (`bool`, *optional*, defaults to `True`):
257
+ Disable the use of custom CUDA and CPU kernels. This option is necessary for the ONNX export, as custom
258
+ kernels are not supported by PyTorch ONNX export.
259
+ class_cost (`float`, *optional*, defaults to 2):
260
+ Relative weight of the classification error in the Hungarian matching cost.
261
+ bbox_cost (`float`, *optional*, defaults to 5):
262
+ Relative weight of the L1 error of the bounding box coordinates in the Hungarian matching cost.
263
+ giou_cost (`float`, *optional*, defaults to 2):
264
+ Relative weight of the generalized IoU loss of the bounding box in the Hungarian matching cost.
265
+ mask_loss_coefficient (`float`, *optional*, defaults to 1):
266
+ Relative weight of the Focal loss in the panoptic segmentation loss.
267
+ dice_loss_coefficient (`float`, *optional*, defaults to 1):
268
+ Relative weight of the DICE/F-1 loss in the panoptic segmentation loss.
269
+ bbox_loss_coefficient (`float`, *optional*, defaults to 5):
270
+ Relative weight of the L1 bounding box loss in the object detection loss.
271
+ giou_loss_coefficient (`float`, *optional*, defaults to 2):
272
+ Relative weight of the generalized IoU loss in the object detection loss.
273
+ eos_coefficient (`float`, *optional*, defaults to 0.1):
274
+ Relative classification weight of the 'no-object' class in the object detection loss.
275
+ focal_alpha (`float`, *optional*, defaults to 0.25):
276
+ Alpha parameter in the focal loss.
277
+ auxiliary_loss (`bool`, *optional*, defaults to `True`):
278
+ Whether auxiliary decoding losses (loss at each decoder layer) are to be used.
279
+
280
+ Examples:
281
+
282
+ ```python
283
+ >>> from transformers import LwDetrConfig, LwDetrModel
284
+
285
+ >>> # Initializing an LW-DETR AnnaZhang/lwdetr_small_60e_coco style configuration
286
+ >>> configuration = LwDetrConfig()
287
+
288
+ >>> # Initializing a model (with random weights) from the AnnaZhang/lwdetr_small_60e_coco style configuration
289
+ >>> model = LwDetrModel(configuration)
290
+
291
+ >>> # Accessing the model configuration
292
+ >>> configuration = model.config
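+
+ >>> # Sketch of a derived attribute: one feature level per projector scale factor
+ >>> LwDetrConfig(projector_scale_factors=[1.0, 0.5]).num_feature_levels
+ 2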
293
+ ```"""
294
+
295
+ model_type = "lw_detr"
296
+ sub_configs = {"backbone_config": AutoConfig}
297
+
298
+ def __init__(
299
+ self,
300
+ # backbone
301
+ backbone_config=None,
302
+ # projector
303
+ projector_scale_factors: list[float] = [],
304
+ hidden_expansion=0.5,
305
+ c2f_num_blocks=3,
306
+ activation_function="silu",
307
+ batch_norm_eps=1e-5,
308
+ # decoder
309
+ d_model=256,
310
+ dropout=0.1,
311
+ decoder_ffn_dim=2048,
312
+ decoder_n_points=4,
313
+ decoder_layers: int = 3,
314
+ decoder_self_attention_heads: int = 8,
315
+ decoder_cross_attention_heads: int = 16,
316
+ decoder_activation_function="relu",
317
+ # model
318
+ num_queries=300,
319
+ attention_bias=True,
320
+ attention_dropout=0.0,
321
+ activation_dropout=0.0,
322
+ group_detr: int = 13,
323
+ init_std=0.02,
324
+ disable_custom_kernels=True,
325
+ # loss
326
+ class_cost=2,
327
+ bbox_cost=5,
328
+ giou_cost=2,
329
+ mask_loss_coefficient=1,
330
+ dice_loss_coefficient=1,
331
+ bbox_loss_coefficient=5,
332
+ giou_loss_coefficient=2,
333
+ eos_coefficient=0.1,
334
+ focal_alpha=0.25,
335
+ auxiliary_loss=True,
336
+ **kwargs,
337
+ ):
338
+ self.batch_norm_eps = batch_norm_eps
339
+
340
+ backbone_config, kwargs = consolidate_backbone_kwargs_to_config(
341
+ backbone_config=backbone_config,
342
+ default_config_type="lw_detr_vit",
343
+ default_config_kwargs={
344
+ "image_size": 1024,
345
+ "hidden_size": 192,
346
+ "num_hidden_layers": 10,
347
+ "window_block_indices": [0, 1, 3, 6, 7, 9],
348
+ "out_indices": [2, 4, 5, 9],
349
+ },
350
+ **kwargs,
351
+ )
352
+
353
+ self.backbone_config = backbone_config
354
+ # projector
355
+ self.projector_scale_factors = projector_scale_factors
356
+ for scale in projector_scale_factors:
357
+ if scale not in [0.5, 1.0, 2.0]:
358
+ raise ValueError(f"Unsupported scale factor: {scale}")
359
+ self.projector_in_channels = [d_model] * len(projector_scale_factors)
360
+ self.projector_out_channels = d_model
361
+ self.activation_function = activation_function
362
+ self.hidden_expansion = hidden_expansion
363
+ self.c2f_num_blocks = c2f_num_blocks
364
+ # decoder
365
+ self.d_model = d_model
366
+ self.dropout = dropout
367
+ self.num_queries = num_queries
368
+ self.decoder_ffn_dim = decoder_ffn_dim
369
+ self.num_feature_levels = len(self.projector_scale_factors)
370
+ self.decoder_n_points = decoder_n_points
371
+ self.decoder_layers = decoder_layers
372
+ self.decoder_activation_function = decoder_activation_function
373
+ self.decoder_self_attention_heads = decoder_self_attention_heads
374
+ self.decoder_cross_attention_heads = decoder_cross_attention_heads
375
+ self.attention_bias = attention_bias
376
+ self.attention_dropout = attention_dropout
377
+ self.activation_dropout = activation_dropout
378
+ # model
379
+ self.init_std = init_std
380
+ self.group_detr = group_detr
381
+ # Loss
382
+ self.auxiliary_loss = auxiliary_loss
383
+ # Hungarian matcher
384
+ self.class_cost = class_cost
385
+ self.bbox_cost = bbox_cost
386
+ self.giou_cost = giou_cost
387
+ # Loss coefficients
388
+ self.mask_loss_coefficient = mask_loss_coefficient
+ self.dice_loss_coefficient = dice_loss_coefficient
389
+ self.bbox_loss_coefficient = bbox_loss_coefficient
390
+ self.giou_loss_coefficient = giou_loss_coefficient
391
+ self.eos_coefficient = eos_coefficient
392
+ self.focal_alpha = focal_alpha
393
+ self.disable_custom_kernels = disable_custom_kernels
394
+ super().__init__(**kwargs)
395
+
396
+
397
+ class LwDetrViTSelfAttention(ViTSelfAttention):
398
+ def __init__(self, config: LwDetrViTConfig):
399
+ super().__init__(config)
400
+ del self.key
401
+ self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
402
+ self.num_key_value_groups = 1
403
+ self.dropout_prob = config.dropout_prob
404
+
405
+ def forward(
406
+ self,
407
+ hidden_states: torch.Tensor,
408
+ **kwargs: Unpack[TransformersKwargs],
409
+ ) -> tuple[torch.Tensor, torch.Tensor]:
410
+ batch_size = hidden_states.shape[0]
411
+ new_shape = batch_size, -1, self.num_attention_heads, self.attention_head_size
412
+
413
+ key_layer = self.key(hidden_states).view(*new_shape).transpose(1, 2)
414
+ value_layer = self.value(hidden_states).view(*new_shape).transpose(1, 2)
415
+ query_layer = self.query(hidden_states).view(*new_shape).transpose(1, 2)
416
+
417
+ attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
418
+ self.config._attn_implementation, eager_attention_forward
419
+ )
420
+
421
+ context_layer, attention_probs = attention_interface(
422
+ self,
423
+ query_layer,
424
+ key_layer,
425
+ value_layer,
426
+ None,
427
+ is_causal=self.is_causal,
428
+ scaling=self.scaling,
429
+ dropout=0.0 if not self.training else self.dropout_prob,
430
+ **kwargs,
431
+ )
432
+
433
+ new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
434
+ context_layer = context_layer.reshape(new_context_layer_shape)
435
+
436
+ return context_layer, attention_probs
437
+
438
+
439
+ class LwDetrViTAttention(ViTAttention):
440
+ def __init__(self, config: LwDetrViTConfig):
441
+ """
442
+ Args:
443
+ config (`LwDetrViTConfig`):
444
+ Model configuration.
445
+ """
446
+ super().__init__(config)
447
+ self.attention = LwDetrViTSelfAttention(config)
448
+ self.output = nn.Linear(config.hidden_size, config.hidden_size)
449
+
450
+ def forward(
451
+ self,
452
+ hidden_states: torch.Tensor,
453
+ **kwargs: Unpack[TransformersKwargs],
454
+ ) -> torch.Tensor:
455
+ self_attn_output, _ = self.attention(hidden_states, **kwargs)
456
+ output = self.output(self_attn_output)
457
+ return output
458
+
459
+
460
+ class LwDetrViTMlp(VitDetMlp):
461
+ pass
462
+
463
+
464
+ class LwDetrViTLayer(GradientCheckpointingLayer):
465
+ def __init__(
466
+ self,
467
+ config: LwDetrViTConfig,
468
+ layer_idx,
469
+ ) -> None:
470
+ super().__init__()
471
+
472
+ dim = config.hidden_size
473
+ self.attention = LwDetrViTAttention(config)
474
+ self.intermediate = LwDetrViTMlp(config=config, in_features=dim, hidden_features=int(dim * config.mlp_ratio))
475
+ self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
476
+ self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
477
+
478
+ self.gamma_1 = nn.Parameter(torch.empty(dim), requires_grad=True)
479
+ self.gamma_2 = nn.Parameter(torch.empty(dim), requires_grad=True)
480
+
481
+ self.window = layer_idx in config.window_block_indices
482
+ self.num_windows = config.num_windows
483
+
484
+ def forward(
485
+ self,
486
+ hidden_states: torch.Tensor,
487
+ **kwargs: Unpack[TransformersKwargs],
488
+ ) -> torch.Tensor:
489
+ batch_size, seq_len, channels = hidden_states.shape
490
+ hidden_states_norm = self.layernorm_before(hidden_states)
491
+
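+ # Inputs arrive window-major from the backbone: (batch * num_windows, window_tokens, channels).
+ # Window-attention layers attend within each window as-is; global-attention layers first merge
+ # the windows back into one sequence per image, then restore the window layout afterwards.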
492
+ if not self.window:
493
+ hidden_states_norm = hidden_states_norm.reshape(
494
+ batch_size // self.num_windows, self.num_windows * seq_len, channels
495
+ )
496
+
497
+ attention_output = self.attention(hidden_states_norm, **kwargs)
498
+ attention_output = attention_output * self.gamma_1
499
+
500
+ if not self.window:
501
+ attention_output = attention_output.reshape(batch_size, seq_len, channels)
502
+
503
+ hidden_states = hidden_states + attention_output
504
+
505
+ layer_output = self.layernorm_after(hidden_states)
506
+ layer_output = self.intermediate(layer_output)
507
+ layer_output = layer_output * self.gamma_2
508
+
509
+ hidden_states = hidden_states + layer_output
510
+
511
+ return hidden_states
512
+
513
+
514
+ class LwDetrViTEncoder(ViTEncoder):
515
+ def __init__(self, config: LwDetrViTConfig) -> None:
516
+ super().__init__(config)
517
+ self.layer = nn.ModuleList([LwDetrViTLayer(config, i) for i in range(config.num_hidden_layers)])
518
+
519
+ def forward(
520
+ self,
521
+ hidden_states: torch.Tensor,
522
+ **kwargs: Unpack[TransformersKwargs],
523
+ ) -> list[torch.Tensor]:
524
+ list_hidden_states = [hidden_states]
525
+ for i, layer_module in enumerate(self.layer):
526
+ hidden_states = layer_module(hidden_states, **kwargs)
527
+ list_hidden_states.append(hidden_states)
528
+ return list_hidden_states
529
+
530
+
531
+ class LwDetrViTEmbeddings(VitDetEmbeddings):
532
+ pass
533
+
534
+
535
+ class LwDetrViTPreTrainedModel(VitDetPreTrainedModel):
536
+ config: LwDetrViTConfig
537
+ base_model_prefix = "lw_detr_vit"
538
+ main_input_name = "pixel_values"
539
+ supports_gradient_checkpointing = True
540
+ _no_split_modules = ["LwDetrViTEmbeddings", "LwDetrViTLayer"]
541
+ _supports_sdpa = True
542
+ _supports_flash_attn = True
543
+ _supports_flex_attn = True
544
+ _supports_attention_backend = True
545
+ _can_record_outputs = {
546
+ "hidden_states": LwDetrViTLayer,
547
+ "attentions": LwDetrViTSelfAttention,
548
+ }
549
+
550
+ def _init_weights(self, module) -> None:
551
+ """Initialize the weights"""
552
+ if isinstance(module, (nn.Linear, nn.Conv2d)):
553
+ init.trunc_normal_(module.weight, mean=0.0, std=self.config.initializer_range)
554
+ if module.bias is not None:
555
+ init.zeros_(module.bias)
556
+ elif isinstance(module, nn.LayerNorm):
557
+ init.zeros_(module.bias)
558
+ init.ones_(module.weight)
559
+ elif isinstance(module, LwDetrViTEmbeddings):
560
+ init.trunc_normal_(module.position_embeddings, mean=0.0, std=self.config.initializer_range)
561
+ if isinstance(module, LwDetrViTLayer):
562
+ nn.init.constant_(module.gamma_1, self.config.cae_init_values)
563
+ nn.init.constant_(module.gamma_2, self.config.cae_init_values)
564
+
565
+
566
+ @auto_docstring()
567
+ class LwDetrViTBackbone(VitDetBackbone):
568
+ @check_model_inputs
569
+ @auto_docstring
570
+ def forward(self, pixel_values: torch.Tensor, **kwargs: Unpack[TransformersKwargs]) -> BackboneOutput:
571
+ r"""
572
+ Examples:
573
+
574
+ ```python
575
+ >>> from transformers import LwDetrViTConfig, LwDetrViTBackbone
576
+ >>> import torch
577
+
578
+ >>> config = LwDetrViTConfig()
579
+ >>> model = LwDetrViTBackbone(config)
580
+
581
+ >>> pixel_values = torch.randn(1, 3, 256, 256)
582
+
583
+ >>> with torch.no_grad():
584
+ ... outputs = model(pixel_values)
585
+
586
+ >>> feature_maps = outputs.feature_maps
587
+ >>> list(feature_maps[-1].shape)
588
+ [1, 768, 16, 16]
589
+ ```"""
590
+ embedding_output = self.embeddings(pixel_values)
591
+
592
+ batch_size, channels, height, width = embedding_output.shape
593
+ # (batch_size, channels, height, width) -> (batch_size, height, width, channels)
594
+ hidden_states = embedding_output.permute(0, 2, 3, 1)
595
+
596
+ window_height = height // self.config.num_windows_side
597
+ window_width = width // self.config.num_windows_side
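+ # e.g. with the default config (image_size=256, patch_size=16, num_windows=16):
+ # height = width = 16 patches and each window covers 4 x 4 patches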
598
+ # (batch_size, height, width, channels) -> (batch_size*num_windows_side**2, window_height*window_width, channels)
599
+ hidden_states = (
600
+ hidden_states.reshape(
601
+ batch_size,
602
+ self.config.num_windows_side,
603
+ window_height,
604
+ self.config.num_windows_side,
605
+ window_width,
606
+ channels,
607
+ )
608
+ .permute(0, 1, 3, 2, 4, 5)
609
+ .reshape(batch_size * self.config.num_windows_side**2, window_height * window_width, channels)
610
+ )
611
+
612
+ hidden_states = self.encoder(hidden_states, **kwargs)
613
+
614
+ feature_maps = ()
615
+ for stage, hidden_state in zip(self.stage_names, hidden_states):
616
+ if stage in self.out_features:
617
+ hidden_state = (
618
+ hidden_state.reshape(
619
+ batch_size,
620
+ self.config.num_windows_side,
621
+ self.config.num_windows_side,
622
+ window_height,
623
+ window_width,
624
+ channels,
625
+ )
626
+ .permute(0, 5, 1, 3, 2, 4)
627
+ .reshape(batch_size, channels, height, width)
628
+ )
629
+ feature_maps += (hidden_state,)
630
+
631
+ return BackboneOutput(feature_maps=feature_maps)
632
+
633
+
634
+ class LwDetrConvNormLayer(RTDetrConvNormLayer):
635
+ def __init__(
636
+ self,
637
+ config: LwDetrConfig,
638
+ in_channels: int,
639
+ out_channels: int,
640
+ kernel_size: int,
641
+ stride: int,
642
+ activation: str | None = None,
643
+ ):
644
+ super().__init__(config, in_channels, out_channels, kernel_size, stride, activation)
645
+ self.conv = nn.Conv2d(
646
+ in_channels,
647
+ out_channels,
648
+ kernel_size,
649
+ stride,
650
+ padding=kernel_size // 2,
651
+ bias=False,
652
+ )
653
+
654
+
655
+ class LwDetrRepVggBlock(nn.Module):
656
+ def __init__(self, config: LwDetrConfig):
657
+ super().__init__()
658
+ hidden_channels = int(config.d_model * config.hidden_expansion)
659
+ self.conv1 = LwDetrConvNormLayer(
660
+ config, hidden_channels, hidden_channels, 3, 1, activation=config.activation_function
661
+ )
662
+ self.conv2 = LwDetrConvNormLayer(
663
+ config, hidden_channels, hidden_channels, 3, 1, activation=config.activation_function
664
+ )
665
+
666
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
667
+ y = self.conv1(x)
668
+ y = self.conv2(y)
669
+ return y
670
+
671
+
672
+ class LwDetrC2FLayer(nn.Module):
673
+ # Inspired by RTDetrCSPRepLayer
674
+ def __init__(self, config: LwDetrConfig, in_channels: int):
675
+ super().__init__()
676
+ num_blocks = config.c2f_num_blocks
677
+ activation = config.activation_function
678
+ out_channels = config.d_model
679
+
680
+ self.hidden_channels = int(out_channels * config.hidden_expansion)
681
+
682
+ conv1_out_channels = 2 * self.hidden_channels
683
+ self.conv1 = LwDetrConvNormLayer(config, in_channels, conv1_out_channels, 1, 1, activation=activation)
684
+
685
+ conv2_in_channels = (2 + num_blocks) * self.hidden_channels
686
+ self.conv2 = LwDetrConvNormLayer(config, conv2_in_channels, out_channels, 1, 1, activation=activation)
687
+
688
+ self.bottlenecks = nn.ModuleList(LwDetrRepVggBlock(config) for _ in range(num_blocks))
689
+
690
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
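+ # CSP-style fusion: conv1 doubles the hidden channels, the result is split into two halves,
+ # the last half is refined by each bottleneck in turn, and all (2 + num_blocks) chunks are
+ # concatenated before conv2 fuses them back to `d_model` channels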
691
+ hidden_states = self.conv1(hidden_states)
692
+ all_hidden_states = list(hidden_states.split(self.hidden_channels, 1))
693
+ hidden_states = all_hidden_states[-1]
694
+
695
+ for bottleneck in self.bottlenecks:
696
+ hidden_states = bottleneck(hidden_states)
697
+ all_hidden_states.append(hidden_states)
698
+
699
+ hidden_states = torch.cat(all_hidden_states, 1)
700
+ hidden_states = self.conv2(hidden_states)
701
+ return hidden_states
702
+
703
+
704
+ class LwDetrLayerNorm(ConvNextLayerNorm):
705
+ pass
706
+
707
+
708
+ class LwDetrSamplingLayer(nn.Module):
709
+ def __init__(self, config: LwDetrConfig, channel_size: int, scale: float):
710
+ super().__init__()
711
+
712
+ self.scale = scale
713
+ self.channel_size = channel_size
714
+
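+ # scale 2.0 upsamples with a transposed conv (reducing channels 2x, or 4x for wide inputs),
+ # scale 0.5 downsamples with a stride-2 conv, and scale 1.0 keeps the features unchanged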
715
+ layers = []
716
+ if scale == 2.0:
717
+ if channel_size > 512:
718
+ layers.append(LwDetrConvNormLayer(config, channel_size, channel_size // 2, 1, 1, activation="relu"))
719
+ layers.append(nn.ConvTranspose2d(channel_size // 2, channel_size // 4, kernel_size=2, stride=2))
720
+ else:
721
+ layers.append(nn.ConvTranspose2d(channel_size, channel_size // 2, 2, 2))
722
+ elif scale == 0.5:
723
+ layers.append(LwDetrConvNormLayer(config, channel_size, channel_size, 3, 2, activation="relu"))
724
+ self.layers = nn.ModuleList(layers)
725
+
726
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
727
+ for layer in self.layers:
728
+ hidden_states = layer(hidden_states)
729
+ return hidden_states
730
+
731
+
732
+ class LwDetrScaleProjector(nn.Module):
733
+ def __init__(self, config: LwDetrConfig, scale: float):
734
+ super().__init__()
735
+
736
+ intermediate_dims = [config.backbone_config.hidden_size] * len(config.backbone_config.out_indices)
737
+ sampling_layers = []
738
+ for channel_size in intermediate_dims:
739
+ sampling_layers.append(LwDetrSamplingLayer(config, channel_size, scale))
740
+ self.sampling_layers = nn.ModuleList(sampling_layers)
741
+
742
+ intermediate_dim = intermediate_dims[-1]
743
+ if scale == 2.0:
744
+ if intermediate_dim > 512:
745
+ intermediate_dim = intermediate_dim // 4
746
+ else:
747
+ intermediate_dim = intermediate_dim // 2
748
+ projector_input_dim = intermediate_dim * len(intermediate_dims)
749
+
750
+ self.projector_layer = LwDetrC2FLayer(config, projector_input_dim)
751
+ self.layer_norm = LwDetrLayerNorm(config.d_model, data_format="channels_first")
752
+
753
+ def forward(self, hidden_states_tuple: tuple[torch.Tensor]) -> torch.Tensor:
754
+ sampled_hidden_states = []
755
+ for sampling_layer, hidden_states in zip(self.sampling_layers, hidden_states_tuple):
756
+ hidden_states = sampling_layer(hidden_states)
757
+ sampled_hidden_states.append(hidden_states)
758
+ hidden_states = torch.cat(sampled_hidden_states, dim=1)
759
+ hidden_states = self.projector_layer(hidden_states)
760
+ hidden_states = self.layer_norm(hidden_states)
761
+ return hidden_states
762
+
763
+
764
+ class LwDetrMultiScaleProjector(nn.Module):
765
+ def __init__(self, config: LwDetrConfig):
766
+ super().__init__()
767
+
768
+ self.config = config
769
+ scale_factors = config.projector_scale_factors
770
+
771
+ self.scale_layers = nn.ModuleList([LwDetrScaleProjector(config, scale) for scale in scale_factors])
772
+
773
+ def forward(self, hidden_states: tuple[torch.Tensor]) -> list[torch.Tensor]:
774
+ output_hidden_states = []
775
+ for scale_layer in self.scale_layers:
776
+ output_hidden_states.append(scale_layer(hidden_states))
777
+ return output_hidden_states
778
+
779
+
780
+ class LwDetrConvEncoder(nn.Module):
781
+ def __init__(self, config: LwDetrConfig):
782
+ super().__init__()
783
+ self.backbone = LwDetrViTBackbone(config.backbone_config)
784
+ self.projector = LwDetrMultiScaleProjector(config)
785
+
786
+ def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor):
787
+ # send pixel_values through the model to get list of feature maps
788
+ features = self.backbone(pixel_values).feature_maps
789
+ features = self.projector(features)
790
+ out = []
791
+ for feature_map in features:
792
+ # downsample pixel_mask to match shape of corresponding feature_map
793
+ mask = nn.functional.interpolate(pixel_mask[None].float(), size=feature_map.shape[-2:]).to(torch.bool)[0]
794
+ out.append((feature_map, mask))
795
+ return out
796
+
797
+
798
+ class LwDetrAttention(nn.Module):
799
+ def __init__(self, config: LwDetrConfig, layer_idx: int):
800
+ super().__init__()
801
+ self.config = config
802
+ self.layer_idx = layer_idx
803
+ self.head_dim = getattr(config, "head_dim", config.d_model // config.decoder_self_attention_heads)
804
+ self.scaling = self.head_dim**-0.5
805
+ self.attention_dropout = config.attention_dropout
806
+ self.is_causal = False
807
+ self.num_key_value_groups = 1
808
+
809
+ self.q_proj = nn.Linear(
810
+ config.d_model, config.decoder_self_attention_heads * self.head_dim, bias=config.attention_bias
811
+ )
812
+ self.k_proj = nn.Linear(
813
+ config.d_model, config.decoder_self_attention_heads * self.head_dim, bias=config.attention_bias
814
+ )
815
+ self.v_proj = nn.Linear(
816
+ config.d_model, config.decoder_self_attention_heads * self.head_dim, bias=config.attention_bias
817
+ )
818
+ self.o_proj = nn.Linear(
819
+ config.decoder_self_attention_heads * self.head_dim, config.d_model, bias=config.attention_bias
820
+ )
821
+
822
+ def forward(
823
+ self,
824
+ hidden_states: torch.Tensor,
825
+ position_embeddings: torch.Tensor | None = None,
826
+ **kwargs: Unpack[TransformersKwargs],
827
+ ) -> tuple[torch.Tensor, torch.Tensor]:
828
+ batch_size, seq_len, _ = hidden_states.shape
829
+ input_shape = hidden_states.shape[:-1]
830
+ hidden_shape = (*input_shape, -1, self.head_dim)
831
+
832
+ hidden_states_original = hidden_states
833
+ if position_embeddings is not None:
834
+ hidden_states = hidden_states + position_embeddings
835
+
836
+ if self.training:
837
+ # at training, we use the Group DETR technique: multiple weight-sharing query groups add extra supervision for faster convergence
838
+ # at inference, we only use one decoder
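+ # queries are split into groups and folded into the batch dimension:
+ # (batch_size, seq_len, d_model) -> (batch_size * group_detr, seq_len // group_detr, d_model)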
839
+ hidden_states_original = torch.cat(
840
+ hidden_states_original.split(seq_len // self.config.group_detr, dim=1), dim=0
841
+ )
842
+ hidden_states = torch.cat(hidden_states.split(seq_len // self.config.group_detr, dim=1), dim=0)
843
+
844
+ query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
845
+ key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
846
+ value_states = self.v_proj(hidden_states_original).view(hidden_shape).transpose(1, 2)
847
+
848
+ attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
849
+ self.config._attn_implementation, eager_attention_forward
850
+ )
851
+
852
+ attn_output, attn_weights = attention_interface(
853
+ self,
854
+ query_states,
855
+ key_states,
856
+ value_states,
857
+ attention_mask=None,
858
+ dropout=0.0 if not self.training else self.attention_dropout,
859
+ scaling=self.scaling,
860
+ **kwargs,
861
+ )
862
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
863
+ attn_output = self.o_proj(attn_output)
864
+
865
+ if self.training:
866
+ attn_output = torch.cat(torch.split(attn_output, batch_size, dim=0), dim=1)
867
+
868
+ return attn_output, attn_weights
869
+
870
+
871
+ class LwDetrMultiscaleDeformableAttention(DeformableDetrMultiscaleDeformableAttention):
872
+ def forward(
873
+ self,
874
+ hidden_states: torch.Tensor,
875
+ attention_mask: torch.Tensor | None = None,
876
+ encoder_hidden_states=None,
877
+ encoder_attention_mask=None,
878
+ position_embeddings: torch.Tensor | None = None,
879
+ reference_points=None,
880
+ spatial_shapes=None,
881
+ spatial_shapes_list=None,
882
+ level_start_index=None,
883
+ **kwargs: Unpack[TransformersKwargs],
884
+ ):
885
+ return super().forward(
886
+ hidden_states=hidden_states,
887
+ attention_mask=attention_mask,
888
+ encoder_hidden_states=encoder_hidden_states,
889
+ encoder_attention_mask=encoder_attention_mask,
890
+ position_embeddings=position_embeddings,
891
+ reference_points=reference_points,
892
+ spatial_shapes=spatial_shapes,
893
+ spatial_shapes_list=spatial_shapes_list,
894
+ level_start_index=level_start_index,
895
+ **kwargs,
896
+ )
897
+
898
+
899
+ class LwDetrMLP(nn.Module):
900
+ def __init__(self, config: LwDetrConfig):
901
+ super().__init__()
902
+ self.dropout = config.dropout
903
+ self.activation_fn = ACT2FN[config.decoder_activation_function]
904
+ self.fc1 = nn.Linear(config.d_model, config.decoder_ffn_dim)
905
+ self.fc2 = nn.Linear(config.decoder_ffn_dim, config.d_model)
906
+
907
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
908
+ residual = hidden_states
909
+ hidden_states = self.fc1(hidden_states)
910
+ hidden_states = self.activation_fn(hidden_states)
911
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
912
+ hidden_states = self.fc2(hidden_states)
913
+ hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
914
+ hidden_states = residual + hidden_states
915
+ return hidden_states
916
+
917
+
918
+ class LwDetrDecoderLayer(GradientCheckpointingLayer):
919
+ def __init__(self, config: LwDetrConfig, layer_idx: int):
920
+ nn.Module.__init__(self)
921
+
922
+ # self-attention
923
+ self.self_attn = LwDetrAttention(config, layer_idx=layer_idx)
924
+ self.dropout = config.dropout
925
+ self.activation_fn = ACT2FN[config.decoder_activation_function]
926
+ self.activation_dropout = config.activation_dropout
927
+ self.self_attn_layer_norm = nn.LayerNorm(config.d_model)
928
+
929
+ # cross-attention
930
+ self.cross_attn = LwDetrMultiscaleDeformableAttention(
931
+ config,
932
+ num_heads=config.decoder_cross_attention_heads,
933
+ n_points=config.decoder_n_points,
934
+ )
935
+ self.cross_attn_layer_norm = nn.LayerNorm(config.d_model)
936
+
937
+ # mlp
938
+ self.mlp = LwDetrMLP(config)
939
+ self.layer_norm = nn.LayerNorm(config.d_model)
940
+
941
+ def forward(
942
+ self,
943
+ hidden_states: torch.Tensor,
944
+ position_embeddings: torch.Tensor | None = None,
945
+ reference_points=None,
946
+ spatial_shapes=None,
947
+ spatial_shapes_list=None,
948
+ level_start_index=None,
949
+ encoder_hidden_states: torch.Tensor | None = None,
950
+ encoder_attention_mask: torch.Tensor | None = None,
951
+ **kwargs: Unpack[TransformersKwargs],
952
+ ):
953
+ self_attention_output, self_attn_weights = self.self_attn(
954
+ hidden_states, position_embeddings=position_embeddings, **kwargs
955
+ )
956
+
957
+ self_attention_output = nn.functional.dropout(self_attention_output, p=self.dropout, training=self.training)
958
+ hidden_states = hidden_states + self_attention_output
959
+ hidden_states = self.self_attn_layer_norm(hidden_states)
960
+
961
+ cross_attention_output, cross_attn_weights = self.cross_attn(
962
+ hidden_states=hidden_states,
963
+ attention_mask=encoder_attention_mask,
964
+ encoder_hidden_states=encoder_hidden_states,
965
+ encoder_attention_mask=encoder_attention_mask,
966
+ position_embeddings=position_embeddings,
967
+ reference_points=reference_points,
968
+ spatial_shapes=spatial_shapes,
969
+ spatial_shapes_list=spatial_shapes_list,
970
+ level_start_index=level_start_index,
971
+ **kwargs,
972
+ )
973
+ cross_attention_output = nn.functional.dropout(cross_attention_output, p=self.dropout, training=self.training)
974
+ hidden_states = hidden_states + cross_attention_output
975
+ hidden_states = self.cross_attn_layer_norm(hidden_states)
976
+
977
+ hidden_states = self.mlp(hidden_states)
978
+ hidden_states = self.layer_norm(hidden_states)
979
+
980
+ return hidden_states
981
+
982
+
983
+ @auto_docstring
984
+ class LwDetrPreTrainedModel(PreTrainedModel):
985
+ config: LwDetrConfig
986
+ base_model_prefix = "model"
987
+ main_input_name = "pixel_values"
988
+ _no_split_modules = [
989
+ r"LwDetrConvEncoder",
990
+ r"LwDetrDecoderLayer",
991
+ ]
992
+ _supports_sdpa = True
993
+ _supports_flash_attn = True
994
+ _supports_flex_attn = True
995
+ _supports_attention_backend = True
996
+ _can_record_outputs = {
997
+ "attentions": [LwDetrAttention, LwDetrMultiscaleDeformableAttention],
998
+ "hidden_states": [LwDetrDecoderLayer],
999
+ }
1000
+
1001
+ @torch.no_grad()
1002
+ def _init_weights(self, module):
1003
+ super()._init_weights(module)
1004
+
1005
+ if isinstance(module, LwDetrMultiscaleDeformableAttention):
1006
+ init.constant_(module.sampling_offsets.weight, 0.0)
1007
+ thetas = torch.arange(module.n_heads, dtype=torch.int64).float() * (2.0 * math.pi / module.n_heads)
1008
+ grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
1009
+ grid_init = (
1010
+ (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
1011
+ .view(module.n_heads, 1, 1, 2)
1012
+ .repeat(1, module.n_levels, module.n_points, 1)
1013
+ )
1014
+ for i in range(module.n_points):
1015
+ grid_init[:, :, i, :] *= i + 1
1016
+
1017
+ init.copy_(module.sampling_offsets.bias, grid_init.view(-1))
1018
+ init.constant_(module.attention_weights.weight, 0.0)
1019
+ init.constant_(module.attention_weights.bias, 0.0)
1020
+ init.xavier_uniform_(module.value_proj.weight)
1021
+ init.constant_(module.value_proj.bias, 0.0)
1022
+ init.xavier_uniform_(module.output_proj.weight)
1023
+ init.constant_(module.output_proj.bias, 0.0)
1024
+ if hasattr(module, "level_embed"):
1025
+ init.normal_(module.level_embed)
1026
+ if hasattr(module, "refpoint_embed") and module.refpoint_embed is not None:
1027
+ init.constant_(module.refpoint_embed.weight, 0)
1028
+ if hasattr(module, "class_embed") and module.class_embed is not None:
1029
+ prior_prob = 0.01
1030
+ bias_value = -math.log((1 - prior_prob) / prior_prob)
1031
+ init.constant_(module.class_embed.bias, bias_value)
1032
+ if hasattr(module, "bbox_embed") and module.bbox_embed is not None:
1033
+ init.constant_(module.bbox_embed.layers[-1].weight, 0)
1034
+ init.constant_(module.bbox_embed.layers[-1].bias, 0)
1035
+
1036
+
1037
+ def refine_bboxes(reference_points, deltas):
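+ # Boxes are (center_x, center_y, width, height): the deltas shift the centers by
+ # delta_xy * wh and rescale the sizes multiplicatively via exp(delta_wh)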
1038
+ reference_points = reference_points.to(deltas.device)
1039
+ new_reference_points_cxcy = deltas[..., :2] * reference_points[..., 2:] + reference_points[..., :2]
1040
+ new_reference_points_wh = deltas[..., 2:].exp() * reference_points[..., 2:]
1041
+ new_reference_points = torch.cat((new_reference_points_cxcy, new_reference_points_wh), -1)
1042
+ return new_reference_points
1043
+
1044
+
1045
+ @dataclass
1046
+ @auto_docstring(
1047
+ custom_intro="""
1048
+ Base class for outputs of the LwDetrDecoder. This class adds two attributes to
1049
+ BaseModelOutputWithCrossAttentions, namely:
1050
+ - a stacked tensor of intermediate decoder hidden states (i.e. the output of each decoder layer)
1051
+ - a stacked tensor of intermediate reference points.
1052
+ """
1053
+ )
1054
+ class LwDetrDecoderOutput(DeformableDetrDecoderOutput):
1055
+ pass
1056
+
1057
+
1058
+ class LwDetrDecoder(LwDetrPreTrainedModel):
1059
+ """
1060
+ Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`LwDetrDecoderLayer`].
1061
+
1062
+ The decoder updates the query embeddings through multiple self-attention and deformable cross-attention layers.
1063
+
1064
+ Some tweaks for LwDetr:
1065
+
1066
+ - it uses the Group DETR technique during training for faster convergence.
1067
+
1068
+ Args:
1069
+ config: LwDetrConfig
1070
+ """
1071
+
1072
+ def __init__(self, config: LwDetrConfig):
1073
+ super().__init__(config)
1074
+ self.dropout = config.dropout
1075
+ self.layers = nn.ModuleList([LwDetrDecoderLayer(config, i) for i in range(config.decoder_layers)])
1076
+ self.layernorm = nn.LayerNorm(config.d_model)
1077
+
1078
+ self.gradient_checkpointing = False
1079
+
1080
+ self.ref_point_head = LwDetrMLPPredictionHead(2 * config.d_model, config.d_model, config.d_model, num_layers=2)
1081
+
1082
+ self.post_init()
1083
+
1084
+ def get_reference(self, reference_points, valid_ratios):
1085
+ # batch_size, num_queries, 4
1086
+ obj_center = reference_points[..., :4]
1087
+
1088
+ # batch_size, num_queries, num_levels, 4
1089
+ reference_points_inputs = obj_center[:, :, None] * torch.cat([valid_ratios, valid_ratios], -1)[:, None]
1090
+
1091
+ # batch_size, num_queries, d_model * 2
1092
+ query_sine_embed = gen_sine_position_embeddings(reference_points_inputs[:, :, 0, :], self.config.d_model)
1093
+
1094
+ # batch_size, num_queries, d_model
1095
+ query_pos = self.ref_point_head(query_sine_embed)
1096
+ return reference_points_inputs, query_pos
1097
+
1098
+ def forward(
1099
+ self,
1100
+ inputs_embeds: torch.Tensor | None = None,
1101
+ reference_points: torch.Tensor | None = None,
1102
+ spatial_shapes: torch.Tensor | None = None,
1103
+ spatial_shapes_list: torch.Tensor | None = None,
1104
+ level_start_index: torch.Tensor | None = None,
1105
+ valid_ratios: torch.Tensor | None = None,
1106
+ encoder_hidden_states: torch.Tensor | None = None,
1107
+ encoder_attention_mask: torch.Tensor | None = None,
1108
+ **kwargs: Unpack[TransformersKwargs],
1109
+ ):
1110
+ intermediate = ()
1111
+ intermediate_reference_points = (reference_points,)
1112
+
1113
+ if inputs_embeds is not None:
1114
+ hidden_states = inputs_embeds
1115
+
1116
+ reference_points_inputs, query_pos = self.get_reference(reference_points, valid_ratios)
1117
+
1118
+ for idx, decoder_layer in enumerate(self.layers):
1119
+ hidden_states = decoder_layer(
1120
+ hidden_states,
1121
+ encoder_hidden_states=encoder_hidden_states,
1122
+ encoder_attention_mask=encoder_attention_mask,
1123
+ position_embeddings=query_pos,
1124
+ reference_points=reference_points_inputs,
1125
+ spatial_shapes=spatial_shapes,
1126
+ spatial_shapes_list=spatial_shapes_list,
1127
+ level_start_index=level_start_index,
1128
+ **kwargs,
1129
+ )
1130
+ intermediate_hidden_states = self.layernorm(hidden_states)
1131
+ intermediate += (intermediate_hidden_states,)
1132
+
1133
+ intermediate = torch.stack(intermediate)
1134
+ last_hidden_state = intermediate[-1]
1135
+ intermediate_reference_points = torch.stack(intermediate_reference_points)
1136
+
1137
+ return LwDetrDecoderOutput(
1138
+ last_hidden_state=last_hidden_state,
1139
+ intermediate_hidden_states=intermediate,
1140
+ intermediate_reference_points=intermediate_reference_points,
1141
+ )
1142
+
1143
+
1144
+ @dataclass
1145
+ @auto_docstring(
1146
+ custom_intro="""
1147
+ Base class for outputs of the LwDetr backbone-decoder model.
1148
+ """
1149
+ )
1150
+ class LwDetrModelOutput(ModelOutput):
1151
+ r"""
1152
+ init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
1153
+ Initial reference points sent through the Transformer decoder.
1154
+ intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
1155
+ Stacked intermediate hidden states (output of each layer of the decoder).
1156
+ intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
1157
+ Stacked intermediate reference points (reference points of each layer of the decoder).
1158
+ enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*):
1159
+ Predicted bounding boxes scores where the top `config.num_queries` scoring bounding boxes per group are
1160
+ picked as region proposals in the first stage. Output of bounding box binary classification (i.e.
1161
+ foreground and background).
1162
+ enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
1163
+ Logits of predicted bounding boxes coordinates in the first stage.
1164
+ """
1165
+
1166
+ init_reference_points: torch.FloatTensor | None = None
1167
+ last_hidden_state: torch.FloatTensor | None = None
1168
+ intermediate_hidden_states: torch.FloatTensor | None = None
1169
+ intermediate_reference_points: torch.FloatTensor | None = None
1170
+ enc_outputs_class: torch.FloatTensor | None = None
1171
+ enc_outputs_coord_logits: torch.FloatTensor | None = None
1172
+
1173
+
1174
+ @auto_docstring(
1175
+ custom_intro="""
1176
+ The bare LW-DETR Model (consisting of a backbone and decoder Transformer) outputting raw
1177
+ hidden-states without any specific head on top.
1178
+ """
1179
+ )
1180
+ class LwDetrModel(DeformableDetrModel):
1181
+ def __init__(self, config: LwDetrConfig):
1182
+ LwDetrPreTrainedModel.__init__(self, config)
1183
+
1184
+ # Create backbone + positional encoding
1185
+ self.backbone = LwDetrConvEncoder(config)
1186
+
1187
+ self.group_detr = config.group_detr
1188
+ self.num_queries = config.num_queries
1189
+ hidden_dim = config.d_model
1190
+ self.reference_point_embed = nn.Embedding(self.num_queries * self.group_detr, 4)
1191
+ self.query_feat = nn.Embedding(self.num_queries * self.group_detr, hidden_dim)
1192
+
1193
+ self.decoder = LwDetrDecoder(config)
1194
+
1195
+ self.enc_output = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(self.group_detr)])
1196
+ self.enc_output_norm = nn.ModuleList([nn.LayerNorm(hidden_dim) for _ in range(self.group_detr)])
1197
+ # Should normally be None and then instantiated in the ForObjectDetection class
1198
+ self.enc_out_bbox_embed = nn.ModuleList(
1199
+ [LwDetrMLPPredictionHead(config.d_model, config.d_model, 4, num_layers=3) for _ in range(self.group_detr)]
1200
+ )
1201
+ self.enc_out_class_embed = nn.ModuleList(
1202
+ [nn.Linear(config.d_model, config.num_labels) for _ in range(self.group_detr)]
1203
+ )
1204
+
1205
+ self.post_init()
1206
+
1207
+ def gen_encoder_output_proposals(self, enc_output, padding_mask, spatial_shapes):
1208
+ """Generate the encoder output proposals from encoded enc_output.
1209
+
1210
+ Args:
1211
+ enc_output (Tensor[batch_size, sequence_length, hidden_size]): Output of the encoder.
1212
+ padding_mask (Tensor[batch_size, sequence_length]): Padding mask for `enc_output`.
1213
+ spatial_shapes (list[tuple[int, int]]): Spatial shapes of the feature maps.
1214
+
1215
+ Returns:
1216
+ `tuple(torch.FloatTensor)`: A tuple of feature map and bbox prediction.
1217
+ - object_query (Tensor[batch_size, sequence_length, hidden_size]): Object query features. Later used to
1218
+ directly predict a bounding box. (without the need of a decoder)
1219
+ - output_proposals (Tensor[batch_size, sequence_length, 4]): Normalized proposals, after an inverse
1220
+ sigmoid.
1221
+ """
1222
+ batch_size = enc_output.shape[0]
1223
+ proposals = []
1224
+ _cur = 0
1225
+ for level, (height, width) in enumerate(spatial_shapes):
1226
+ mask_flatten_ = padding_mask[:, _cur : (_cur + height * width)].view(batch_size, height, width, 1)
1227
+ valid_height = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
1228
+ valid_width = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
1229
+
1230
+ grid_y, grid_x = meshgrid(
1231
+ torch.linspace(
1232
+ 0,
1233
+ height - 1,
1234
+ height,
1235
+ dtype=enc_output.dtype,
1236
+ device=enc_output.device,
1237
+ ),
1238
+ torch.linspace(
1239
+ 0,
1240
+ width - 1,
1241
+ width,
1242
+ dtype=enc_output.dtype,
1243
+ device=enc_output.device,
1244
+ ),
1245
+ indexing="ij",
1246
+ )
1247
+ grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)
1248
+
1249
+ scale = torch.cat([valid_width.unsqueeze(-1), valid_height.unsqueeze(-1)], 1).view(batch_size, 1, 1, 2)
1250
+ grid = (grid.unsqueeze(0).expand(batch_size, -1, -1, -1) + 0.5) / scale
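+ # anchor-like proposal size: 0.05 * 2**level in normalized coordinates, doubling at each feature level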
1251
+ width_height = torch.ones_like(grid) * 0.05 * (2.0**level)
1252
+ proposal = torch.cat((grid, width_height), -1).view(batch_size, -1, 4)
1253
+ proposals.append(proposal)
1254
+ _cur += height * width
1255
+ output_proposals = torch.cat(proposals, 1)
1256
+ output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True)
1257
+ output_proposals = output_proposals.masked_fill(padding_mask.unsqueeze(-1), float("inf"))
1258
+ output_proposals = output_proposals.masked_fill(~output_proposals_valid, float("inf"))
1259
+
1260
+ # assign each pixel as an object query
1261
+ object_query = enc_output
1262
+ object_query = object_query.masked_fill(padding_mask.unsqueeze(-1), float(0))
1263
+ object_query = object_query.masked_fill(~output_proposals_valid, float(0))
1264
+ return object_query, output_proposals
1265
+
1266
+ @check_model_inputs
1267
+ @auto_docstring
1268
+ def forward(
1269
+ self,
1270
+ pixel_values: torch.FloatTensor = None,
1271
+ pixel_mask: torch.LongTensor | None = None,
1272
+ **kwargs: Unpack[TransformersKwargs],
1273
+ ) -> LwDetrModelOutput:
1274
+ r"""
1275
+ Examples:
1276
+
1277
+ ```python
1278
+ >>> from transformers import AutoImageProcessor, LwDetrModel
1279
+ >>> from PIL import Image
1280
+ >>> import httpx
1281
+ >>> from io import BytesIO
1282
+
1283
+ >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
1284
+ >>> with httpx.stream("GET", url) as response:
1285
+ ... image = Image.open(BytesIO(response.read()))
1286
+
1287
+ >>> image_processor = AutoImageProcessor.from_pretrained("AnnaZhang/lwdetr_small_60e_coco")
1288
+ >>> model = LwDetrModel.from_pretrained("AnnaZhang/lwdetr_small_60e_coco")
1289
+
1290
+ >>> inputs = image_processor(images=image, return_tensors="pt")
1291
+
1292
+ >>> outputs = model(**inputs)
1293
+
1294
+ >>> last_hidden_states = outputs.last_hidden_state
1295
+ >>> list(last_hidden_states.shape)
1296
+ [1, 300, 256]
1297
+ ```"""
1298
+ batch_size, num_channels, height, width = pixel_values.shape
1299
+ device = pixel_values.device
1300
+
1301
+ if pixel_mask is None:
1302
+ pixel_mask = torch.ones(((batch_size, height, width)), dtype=torch.long, device=device)
1303
+
1304
+ # Extract multi-scale feature maps sharing the channel dimension `config.d_model` (cf Figure 4 in paper)
1305
+ # First, send pixel_values + pixel_mask through the backbone to obtain the features,
1306
+ # which is a list of (feature_map, mask) tuples
1307
+ features = self.backbone(pixel_values, pixel_mask)
1308
+
1309
+ # The projector inside the backbone already mapped every level to `d_model` channels; collect the maps and masks
1310
+ sources = []
1311
+ masks = []
1312
+ for level, (source, mask) in enumerate(features):
1313
+ sources.append(source)
1314
+ masks.append(mask)
1315
+ if mask is None:
1316
+ raise ValueError("No attention mask was provided")
1317
+
1318
+ if self.training:
1319
+ reference_points = self.reference_point_embed.weight
1320
+ query_feat = self.query_feat.weight
1321
+ else:
1322
+ # only use one group in inference
1323
+ reference_points = self.reference_point_embed.weight[: self.num_queries]
1324
+ query_feat = self.query_feat.weight[: self.num_queries]
1325
+
1326
+ # Prepare encoder inputs (by flattening)
1327
+ source_flatten = []
1328
+ mask_flatten = []
1329
+ spatial_shapes_list = []
1330
+ for source, mask in zip(sources, masks):
1331
+ batch_size, num_channels, height, width = source.shape
1332
+ spatial_shape = (height, width)
1333
+ spatial_shapes_list.append(spatial_shape)
1334
+ source = source.flatten(2).transpose(1, 2)
1335
+ mask = mask.flatten(1)
1336
+ source_flatten.append(source)
1337
+ mask_flatten.append(mask)
1338
+ source_flatten = torch.cat(source_flatten, 1)
1339
+ mask_flatten = torch.cat(mask_flatten, 1)
1340
+ spatial_shapes = torch.as_tensor(spatial_shapes_list, dtype=torch.long, device=source_flatten.device)
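+ # start offset of each feature level inside the flattened (sum of H*W) token sequence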
1341
+ level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
1342
+ valid_ratios = torch.stack([self.get_valid_ratio(m, dtype=source_flatten.dtype) for m in masks], 1)
1343
+
1344
+ target = query_feat.unsqueeze(0).expand(batch_size, -1, -1)
1345
+ reference_points = reference_points.unsqueeze(0).expand(batch_size, -1, -1)
1346
+
1347
+ object_query_embedding, output_proposals = self.gen_encoder_output_proposals(
1348
+ source_flatten, ~mask_flatten, spatial_shapes_list
1349
+ )
1350
+
1351
+ group_detr = self.group_detr if self.training else 1
1352
+ topk = self.num_queries
1353
+ topk_coords_logits = []
1354
+ topk_coords_logits_undetach = []
1355
+ object_query_undetach = []
1356
+
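+ # Each group has its own two-stage proposal head: score every encoder token, keep the top
+ # `num_queries` proposals per group (by max class logit), and gather their boxes and features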
1357
+ for group_id in range(group_detr):
1358
+ group_object_query = self.enc_output[group_id](object_query_embedding)
1359
+ group_object_query = self.enc_output_norm[group_id](group_object_query)
1360
+
1361
+ group_enc_outputs_class = self.enc_out_class_embed[group_id](group_object_query)
1362
+ group_delta_bbox = self.enc_out_bbox_embed[group_id](group_object_query)
1363
+ group_enc_outputs_coord = refine_bboxes(output_proposals, group_delta_bbox)
1364
+
1365
+ group_topk_proposals = torch.topk(group_enc_outputs_class.max(-1)[0], topk, dim=1)[1]
1366
+ group_topk_coords_logits_undetach = torch.gather(
1367
+ group_enc_outputs_coord,
1368
+ 1,
1369
+ group_topk_proposals.unsqueeze(-1).repeat(1, 1, 4),
1370
+ )
1371
+ group_topk_coords_logits = group_topk_coords_logits_undetach.detach()
1372
+ group_object_query_undetach = torch.gather(
1373
+ group_object_query, 1, group_topk_proposals.unsqueeze(-1).repeat(1, 1, self.config.d_model)
1374
+ )
1375
+
1376
+ topk_coords_logits.append(group_topk_coords_logits)
1377
+ topk_coords_logits_undetach.append(group_topk_coords_logits_undetach)
1378
+ object_query_undetach.append(group_object_query_undetach)
1379
+
1380
+ topk_coords_logits = torch.cat(topk_coords_logits, 1)
1381
+ topk_coords_logits_undetach = torch.cat(topk_coords_logits_undetach, 1)
1382
+ object_query_undetach = torch.cat(object_query_undetach, 1)
1383
+
1384
+ enc_outputs_class = object_query_undetach
1385
+ enc_outputs_coord_logits = topk_coords_logits
1386
+
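+ # the learned query reference points act as (cx, cy, w, h) deltas refining the selected
+ # proposals; the undetached coordinates keep gradients flowing into the proposal heads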
1387
+ reference_points = refine_bboxes(topk_coords_logits_undetach, reference_points)
1388
+
1389
+ init_reference_points = reference_points
1390
+ decoder_outputs = self.decoder(
1391
+ inputs_embeds=target,
1392
+ reference_points=reference_points,
1393
+ spatial_shapes=spatial_shapes,
1394
+ spatial_shapes_list=spatial_shapes_list,
1395
+ level_start_index=level_start_index,
1396
+ valid_ratios=valid_ratios,
1397
+ encoder_hidden_states=source_flatten,
1398
+ encoder_attention_mask=mask_flatten,
1399
+ **kwargs,
1400
+ )
1401
+
1402
+ return LwDetrModelOutput(
1403
+ init_reference_points=init_reference_points,
1404
+ last_hidden_state=decoder_outputs.last_hidden_state,
1405
+ intermediate_hidden_states=decoder_outputs.intermediate_hidden_states,
1406
+ intermediate_reference_points=decoder_outputs.intermediate_reference_points,
1407
+ enc_outputs_class=enc_outputs_class,
1408
+ enc_outputs_coord_logits=enc_outputs_coord_logits,
1409
+ )
1410
+
1411
+
1412
+ class LwDetrMLPPredictionHead(DeformableDetrMLPPredictionHead):
1413
+ pass
1414
+
1415
+
1416
+ @dataclass
1417
+ @auto_docstring(
1418
+ custom_intro="""
1419
+ Output type of [`LwDetrForObjectDetection`].
1420
+ """
1421
+ )
1422
+ class LwDetrObjectDetectionOutput(ModelOutput):
1423
+ r"""
1424
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided):
1425
+ Total loss as a linear combination of a negative log-likelihood (cross-entropy) for class prediction and a
1426
+ bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
1427
+ scale-invariant IoU loss.
1428
+ loss_dict (`Dict`, *optional*):
1429
+ A dictionary containing the individual losses. Useful for logging.
1430
+ logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`):
1431
+ Classification logits (including no-object) for all queries.
1432
+ pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
1433
+ Normalized boxes coordinates for all queries, represented as (center_x, center_y, width, height). These
1434
+ values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
1435
+ possible padding). You can use [`~DeformableDetrImageProcessor.post_process_object_detection`] to retrieve the
1436
+ unnormalized bounding boxes.
1437
+ auxiliary_outputs (`list[Dict]`, *optional*):
1438
+ Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
1439
+ and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
1440
+ `pred_boxes`) for each decoder layer.
1441
+ init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
1442
+ Initial reference points sent through the Transformer decoder.
1443
+ intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
1444
+ Stacked intermediate hidden states (output of each layer of the decoder).
1445
+ intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
1446
+ Stacked intermediate reference points (reference points of each layer of the decoder).
1447
+ enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*):
1448
+ Predicted bounding boxes scores where the top `config.num_queries` scoring bounding boxes per group are
1449
+ picked as region proposals in the first stage. Output of bounding box binary classification (i.e.
1450
+ foreground and background).
1451
+ enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*):
1452
+ Logits of predicted bounding boxes coordinates in the first stage.
1453
+ """
1454
+
+    loss: torch.FloatTensor | None = None
+    loss_dict: dict | None = None
+    logits: torch.FloatTensor | None = None
+    pred_boxes: torch.FloatTensor | None = None
+    auxiliary_outputs: list[dict] | None = None
+    init_reference_points: torch.FloatTensor | None = None
+    last_hidden_state: torch.FloatTensor | None = None
+    intermediate_hidden_states: torch.FloatTensor | None = None
+    intermediate_reference_points: torch.FloatTensor | None = None
+    enc_outputs_class: Any = None
+    enc_outputs_coord_logits: torch.FloatTensor | None = None
+
+
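As the docstring notes, `pred_boxes` holds normalized `(center_x, center_y, width, height)` boxes, so consumers that bypass `post_process_object_detection` must convert them by hand. A minimal sketch of that conversion (the helper name is ours, not part of the library):

```python
import torch


def center_to_corners(boxes: torch.Tensor, image_height: int, image_width: int) -> torch.Tensor:
    """Convert normalized (cx, cy, w, h) boxes to absolute (xmin, ymin, xmax, ymax)."""
    cx, cy, w, h = boxes.unbind(-1)
    corners = torch.stack(
        [cx - 0.5 * w, cy - 0.5 * h, cx + 0.5 * w, cy + 0.5 * h], dim=-1
    )
    # Scale x coordinates by image width and y coordinates by image height.
    scale = torch.tensor([image_width, image_height, image_width, image_height], dtype=boxes.dtype)
    return corners * scale


# e.g. a single query box covering the middle of a 480x640 image:
print(center_to_corners(torch.tensor([0.5, 0.5, 0.25, 0.5]), 480, 640))
# tensor([240., 120., 400., 360.])
```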
+@auto_docstring(
+    custom_intro="""
+    LW DETR Model (consisting of a backbone and decoder Transformer) with object detection heads on
+    top, for tasks such as COCO detection.
+    """
+)
+class LwDetrForObjectDetection(DeformableDetrForObjectDetection):
+    _tied_weights_keys = None
+
+    def __init__(self, config: LwDetrConfig):
+        PreTrainedModel.__init__(self, config)
+        self.model = LwDetrModel(config)
+        self.class_embed = nn.Linear(config.d_model, config.num_labels)
+        self.bbox_embed = LwDetrMLPPredictionHead(config.d_model, config.d_model, 4, num_layers=3)
+
+        self.post_init()
+
+    @check_model_inputs
+    @auto_docstring
+    def forward(
+        self,
+        pixel_values: torch.FloatTensor = None,
+        pixel_mask: torch.LongTensor | None = None,
+        labels: list[dict] | None = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> LwDetrObjectDetectionOutput:
+        r"""
+        decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
+            Not used by default. Can be used to mask object queries.
+        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
+            Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
+            can choose to directly pass a flattened representation of an image.
+        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
+            Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
+            embedded representation.
+        labels (`list[Dict]` of len `(batch_size,)`, *optional*):
+            Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
+            following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch
+            respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes
+            in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`.
+
+        Examples:
+
+        ```python
+        >>> from transformers import AutoImageProcessor, LwDetrForObjectDetection
+        >>> from PIL import Image
+        >>> import httpx
+        >>> import torch
+        >>> from io import BytesIO
+
+        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
+        >>> with httpx.stream("GET", url) as response:
+        ...     image = Image.open(BytesIO(response.read()))
+
+        >>> image_processor = AutoImageProcessor.from_pretrained("AnnaZhang/lwdetr_small_60e_coco")
+        >>> model = LwDetrForObjectDetection.from_pretrained("AnnaZhang/lwdetr_small_60e_coco")
+
+        >>> inputs = image_processor(images=image, return_tensors="pt")
+        >>> outputs = model(**inputs)
+
+        >>> # convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
+        >>> target_sizes = torch.tensor([image.size[::-1]])
+        >>> results = image_processor.post_process_object_detection(outputs, threshold=0.5, target_sizes=target_sizes)[
+        ...     0
+        ... ]
+        >>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
+        ...     box = [round(i, 2) for i in box.tolist()]
+        ...     print(
+        ...         f"Detected {model.config.id2label[label.item()]} with confidence "
+        ...         f"{round(score.item(), 3)} at location {box}"
+        ...     )
+        Detected cat with confidence 0.8 at location [16.5, 52.84, 318.25, 470.78]
+        Detected cat with confidence 0.789 at location [342.19, 24.3, 640.02, 372.25]
+        Detected remote with confidence 0.633 at location [40.79, 72.78, 176.76, 117.25]
+        ```"""
+        outputs = self.model(
+            pixel_values,
+            pixel_mask=pixel_mask,
+            **kwargs,
+        )
+
+        last_hidden_states = outputs.last_hidden_state
+        intermediate_reference_points = outputs.intermediate_reference_points
+        enc_outputs_class_logits = outputs.enc_outputs_class
+        enc_outputs_boxes_logits = outputs.enc_outputs_coord_logits
+
+        logits = self.class_embed(last_hidden_states)
+        pred_boxes_delta = self.bbox_embed(last_hidden_states)
+        pred_boxes = refine_bboxes(intermediate_reference_points[-1], pred_boxes_delta)
+
+        enc_outputs_class_logits_list = enc_outputs_class_logits.split(self.config.num_queries, dim=1)
+        pred_class = []
+        group_detr = self.config.group_detr if self.training else 1
+        for group_index in range(group_detr):
+            group_pred_class = self.model.enc_out_class_embed[group_index](enc_outputs_class_logits_list[group_index])
+            pred_class.append(group_pred_class)
+        enc_outputs_class_logits = torch.cat(pred_class, dim=1)
+
+        loss, loss_dict, auxiliary_outputs = None, None, None
+        if labels is not None:
+            outputs_class, outputs_coord = None, None
+            if self.config.auxiliary_loss:
+                intermediate_hidden_states = outputs.intermediate_hidden_states
+                outputs_coord_delta = self.bbox_embed(intermediate_hidden_states)
+                outputs_coord = refine_bboxes(intermediate_reference_points, outputs_coord_delta)
+                outputs_class = self.class_embed(intermediate_hidden_states)
+
+            loss, loss_dict, auxiliary_outputs = self.loss_function(
+                logits,
+                labels,
+                self.device,
+                pred_boxes,
+                self.config,
+                outputs_class,
+                outputs_coord,
+                enc_outputs_class_logits,
+                enc_outputs_boxes_logits,
+            )
+
+        return LwDetrObjectDetectionOutput(
+            loss=loss,
+            loss_dict=loss_dict,
+            logits=logits,
+            pred_boxes=pred_boxes,
+            auxiliary_outputs=auxiliary_outputs,
+            last_hidden_state=outputs.last_hidden_state,
+            intermediate_hidden_states=outputs.intermediate_hidden_states,
+            intermediate_reference_points=outputs.intermediate_reference_points,
+            init_reference_points=outputs.init_reference_points,
+            enc_outputs_class=enc_outputs_class_logits,
+            enc_outputs_coord_logits=enc_outputs_boxes_logits,
+        )
+
+
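The `refine_bboxes` helper called above is defined earlier in the file and not visible in this hunk. In Deformable-DETR-style models the refinement step typically adds the predicted deltas to the reference points in inverse-sigmoid (logit) space before squashing back to `[0, 1]`; a sketch under that assumption (names and details are ours, not the library's):

```python
import torch


def inverse_sigmoid(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Clamp away from 0 and 1 so the logit stays finite.
    x = x.clamp(min=eps, max=1 - eps)
    return torch.log(x / (1 - x))


def refine_bboxes_sketch(reference_points: torch.Tensor, deltas: torch.Tensor) -> torch.Tensor:
    """Apply predicted box deltas to reference boxes in logit space,
    then map back to normalized [0, 1] coordinates."""
    return (inverse_sigmoid(reference_points) + deltas).sigmoid()
```

Working in logit space keeps the refined boxes inside the image regardless of the delta magnitude, which is the usual rationale for this formulation.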
+__all__ = [
+    "LwDetrConfig",
+    "LwDetrPreTrainedModel",
+    "LwDetrModel",
+    "LwDetrForObjectDetection",
+    "LwDetrViTConfig",
+    "LwDetrViTPreTrainedModel",
+    "LwDetrViTBackbone",
+]
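For training, the `labels` argument follows the format spelled out in the forward docstring: one dict per image with `class_labels` and normalized `boxes`. A minimal illustration (the class ids and box values below are made up):

```python
import torch

# One dict per image in the batch; values are illustrative only.
labels = [
    {
        "class_labels": torch.tensor([17, 17, 75]),  # one class id per box
        "boxes": torch.tensor(  # normalized (cx, cy, w, h), one row per box
            [
                [0.55, 0.30, 0.90, 0.60],
                [0.25, 0.45, 0.40, 0.80],
                [0.17, 0.20, 0.21, 0.10],
            ]
        ),
    }
]
# outputs = model(**inputs, labels=labels)
# outputs.loss and outputs.loss_dict are then populated for backprop and logging.
```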