transformers 5.0.0__py3-none-any.whl → 5.0.0rc0__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (1606)
  1. transformers/__init__.py +36 -55
  2. transformers/activations.py +1 -1
  3. transformers/audio_utils.py +33 -32
  4. transformers/cache_utils.py +139 -32
  5. transformers/cli/chat.py +3 -3
  6. transformers/cli/serve.py +19 -49
  7. transformers/cli/transformers.py +1 -2
  8. transformers/configuration_utils.py +155 -129
  9. transformers/conversion_mapping.py +22 -158
  10. transformers/convert_slow_tokenizer.py +17 -227
  11. transformers/core_model_loading.py +185 -528
  12. transformers/data/data_collator.py +4 -12
  13. transformers/data/processors/glue.py +1 -0
  14. transformers/data/processors/utils.py +1 -0
  15. transformers/data/processors/xnli.py +1 -0
  16. transformers/dependency_versions_check.py +1 -0
  17. transformers/dependency_versions_table.py +7 -5
  18. transformers/distributed/configuration_utils.py +2 -1
  19. transformers/dynamic_module_utils.py +25 -24
  20. transformers/feature_extraction_sequence_utils.py +23 -19
  21. transformers/feature_extraction_utils.py +33 -64
  22. transformers/file_utils.py +1 -0
  23. transformers/generation/__init__.py +1 -11
  24. transformers/generation/candidate_generator.py +33 -80
  25. transformers/generation/configuration_utils.py +133 -189
  26. transformers/generation/continuous_batching/__init__.py +1 -4
  27. transformers/generation/continuous_batching/cache.py +25 -83
  28. transformers/generation/continuous_batching/cache_manager.py +45 -155
  29. transformers/generation/continuous_batching/continuous_api.py +147 -270
  30. transformers/generation/continuous_batching/requests.py +3 -51
  31. transformers/generation/continuous_batching/scheduler.py +105 -160
  32. transformers/generation/logits_process.py +128 -0
  33. transformers/generation/stopping_criteria.py +1 -1
  34. transformers/generation/streamers.py +1 -0
  35. transformers/generation/utils.py +123 -122
  36. transformers/generation/watermarking.py +6 -8
  37. transformers/hf_argparser.py +13 -9
  38. transformers/hyperparameter_search.py +2 -1
  39. transformers/image_processing_base.py +23 -12
  40. transformers/image_processing_utils.py +15 -11
  41. transformers/image_processing_utils_fast.py +75 -85
  42. transformers/image_transforms.py +42 -73
  43. transformers/image_utils.py +32 -30
  44. transformers/initialization.py +0 -37
  45. transformers/integrations/__init__.py +2 -16
  46. transformers/integrations/accelerate.py +113 -58
  47. transformers/integrations/aqlm.py +66 -36
  48. transformers/integrations/awq.py +516 -45
  49. transformers/integrations/bitnet.py +105 -47
  50. transformers/integrations/bitsandbytes.py +202 -91
  51. transformers/integrations/deepspeed.py +4 -161
  52. transformers/integrations/eetq.py +82 -84
  53. transformers/integrations/executorch.py +1 -1
  54. transformers/integrations/fbgemm_fp8.py +145 -190
  55. transformers/integrations/finegrained_fp8.py +215 -249
  56. transformers/integrations/flash_attention.py +3 -3
  57. transformers/integrations/flex_attention.py +1 -1
  58. transformers/integrations/fp_quant.py +0 -90
  59. transformers/integrations/ggml.py +2 -11
  60. transformers/integrations/higgs.py +62 -37
  61. transformers/integrations/hub_kernels.py +8 -65
  62. transformers/integrations/integration_utils.py +3 -47
  63. transformers/integrations/mistral.py +0 -12
  64. transformers/integrations/mxfp4.py +80 -33
  65. transformers/integrations/peft.py +191 -483
  66. transformers/integrations/quanto.py +56 -77
  67. transformers/integrations/spqr.py +90 -42
  68. transformers/integrations/tensor_parallel.py +221 -167
  69. transformers/integrations/torchao.py +43 -35
  70. transformers/integrations/vptq.py +59 -40
  71. transformers/kernels/__init__.py +0 -0
  72. transformers/{models/pe_audio_video/processing_pe_audio_video.py → kernels/falcon_mamba/__init__.py} +3 -12
  73. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +529 -0
  74. transformers/loss/loss_utils.py +0 -2
  75. transformers/masking_utils.py +55 -51
  76. transformers/model_debugging_utils.py +5 -4
  77. transformers/modelcard.py +194 -15
  78. transformers/modeling_attn_mask_utils.py +19 -19
  79. transformers/modeling_flash_attention_utils.py +27 -27
  80. transformers/modeling_gguf_pytorch_utils.py +24 -79
  81. transformers/modeling_layers.py +22 -21
  82. transformers/modeling_outputs.py +253 -242
  83. transformers/modeling_rope_utils.py +117 -138
  84. transformers/modeling_utils.py +739 -850
  85. transformers/models/__init__.py +0 -27
  86. transformers/models/afmoe/configuration_afmoe.py +33 -40
  87. transformers/models/afmoe/modeling_afmoe.py +54 -42
  88. transformers/models/afmoe/modular_afmoe.py +33 -23
  89. transformers/models/aimv2/configuration_aimv2.py +10 -2
  90. transformers/models/aimv2/modeling_aimv2.py +42 -47
  91. transformers/models/aimv2/modular_aimv2.py +19 -17
  92. transformers/models/albert/configuration_albert.py +2 -8
  93. transformers/models/albert/modeling_albert.py +69 -70
  94. transformers/models/albert/tokenization_albert.py +14 -5
  95. transformers/models/align/configuration_align.py +6 -8
  96. transformers/models/align/modeling_align.py +89 -94
  97. transformers/models/align/processing_align.py +30 -2
  98. transformers/models/altclip/configuration_altclip.py +7 -4
  99. transformers/models/altclip/modeling_altclip.py +103 -114
  100. transformers/models/altclip/processing_altclip.py +15 -2
  101. transformers/models/apertus/__init__.py +1 -0
  102. transformers/models/apertus/configuration_apertus.py +28 -23
  103. transformers/models/apertus/modeling_apertus.py +40 -39
  104. transformers/models/apertus/modular_apertus.py +38 -37
  105. transformers/models/arcee/configuration_arcee.py +30 -25
  106. transformers/models/arcee/modeling_arcee.py +39 -36
  107. transformers/models/arcee/modular_arcee.py +23 -20
  108. transformers/models/aria/configuration_aria.py +44 -31
  109. transformers/models/aria/image_processing_aria.py +27 -25
  110. transformers/models/aria/modeling_aria.py +106 -110
  111. transformers/models/aria/modular_aria.py +127 -118
  112. transformers/models/aria/processing_aria.py +35 -28
  113. transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +1 -0
  114. transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py +6 -3
  115. transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +8 -6
  116. transformers/models/audioflamingo3/__init__.py +1 -0
  117. transformers/models/audioflamingo3/configuration_audioflamingo3.py +1 -0
  118. transformers/models/audioflamingo3/modeling_audioflamingo3.py +49 -58
  119. transformers/models/audioflamingo3/modular_audioflamingo3.py +43 -53
  120. transformers/models/audioflamingo3/processing_audioflamingo3.py +30 -33
  121. transformers/models/auto/auto_factory.py +7 -6
  122. transformers/models/auto/configuration_auto.py +5 -66
  123. transformers/models/auto/feature_extraction_auto.py +10 -14
  124. transformers/models/auto/image_processing_auto.py +41 -32
  125. transformers/models/auto/modeling_auto.py +188 -46
  126. transformers/models/auto/processing_auto.py +11 -24
  127. transformers/models/auto/tokenization_auto.py +588 -171
  128. transformers/models/auto/video_processing_auto.py +10 -12
  129. transformers/models/autoformer/configuration_autoformer.py +7 -4
  130. transformers/models/autoformer/modeling_autoformer.py +101 -104
  131. transformers/models/aya_vision/configuration_aya_vision.py +1 -4
  132. transformers/models/aya_vision/modeling_aya_vision.py +102 -71
  133. transformers/models/aya_vision/modular_aya_vision.py +74 -46
  134. transformers/models/aya_vision/processing_aya_vision.py +53 -25
  135. transformers/models/bamba/configuration_bamba.py +39 -34
  136. transformers/models/bamba/modeling_bamba.py +86 -82
  137. transformers/models/bamba/modular_bamba.py +72 -70
  138. transformers/models/bark/configuration_bark.py +8 -6
  139. transformers/models/bark/generation_configuration_bark.py +5 -3
  140. transformers/models/bark/modeling_bark.py +57 -54
  141. transformers/models/bark/processing_bark.py +41 -19
  142. transformers/models/bart/configuration_bart.py +6 -9
  143. transformers/models/bart/modeling_bart.py +126 -135
  144. transformers/models/barthez/tokenization_barthez.py +11 -3
  145. transformers/models/bartpho/tokenization_bartpho.py +7 -6
  146. transformers/models/beit/configuration_beit.py +11 -0
  147. transformers/models/beit/image_processing_beit.py +56 -53
  148. transformers/models/beit/image_processing_beit_fast.py +12 -10
  149. transformers/models/beit/modeling_beit.py +60 -69
  150. transformers/models/bert/configuration_bert.py +2 -12
  151. transformers/models/bert/modeling_bert.py +122 -114
  152. transformers/models/bert/tokenization_bert.py +23 -8
  153. transformers/models/bert/tokenization_bert_legacy.py +5 -3
  154. transformers/models/bert_generation/configuration_bert_generation.py +2 -17
  155. transformers/models/bert_generation/modeling_bert_generation.py +49 -49
  156. transformers/models/bert_generation/tokenization_bert_generation.py +3 -2
  157. transformers/models/bert_japanese/tokenization_bert_japanese.py +6 -5
  158. transformers/models/bertweet/tokenization_bertweet.py +3 -1
  159. transformers/models/big_bird/configuration_big_bird.py +9 -12
  160. transformers/models/big_bird/modeling_big_bird.py +109 -116
  161. transformers/models/big_bird/tokenization_big_bird.py +43 -16
  162. transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py +9 -9
  163. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +117 -130
  164. transformers/models/biogpt/configuration_biogpt.py +2 -8
  165. transformers/models/biogpt/modeling_biogpt.py +76 -72
  166. transformers/models/biogpt/modular_biogpt.py +66 -62
  167. transformers/models/biogpt/tokenization_biogpt.py +5 -3
  168. transformers/models/bit/configuration_bit.py +1 -0
  169. transformers/models/bit/image_processing_bit.py +24 -21
  170. transformers/models/bit/image_processing_bit_fast.py +1 -0
  171. transformers/models/bit/modeling_bit.py +12 -25
  172. transformers/models/bitnet/configuration_bitnet.py +28 -23
  173. transformers/models/bitnet/modeling_bitnet.py +39 -36
  174. transformers/models/bitnet/modular_bitnet.py +6 -4
  175. transformers/models/blenderbot/configuration_blenderbot.py +5 -8
  176. transformers/models/blenderbot/modeling_blenderbot.py +96 -77
  177. transformers/models/blenderbot/tokenization_blenderbot.py +24 -18
  178. transformers/models/blenderbot_small/configuration_blenderbot_small.py +5 -8
  179. transformers/models/blenderbot_small/modeling_blenderbot_small.py +69 -79
  180. transformers/models/blenderbot_small/tokenization_blenderbot_small.py +3 -1
  181. transformers/models/blip/configuration_blip.py +10 -9
  182. transformers/models/blip/image_processing_blip.py +20 -17
  183. transformers/models/blip/image_processing_blip_fast.py +1 -0
  184. transformers/models/blip/modeling_blip.py +108 -117
  185. transformers/models/blip/modeling_blip_text.py +65 -73
  186. transformers/models/blip/processing_blip.py +36 -5
  187. transformers/models/blip_2/configuration_blip_2.py +2 -2
  188. transformers/models/blip_2/modeling_blip_2.py +118 -146
  189. transformers/models/blip_2/processing_blip_2.py +38 -8
  190. transformers/models/bloom/configuration_bloom.py +2 -5
  191. transformers/models/bloom/modeling_bloom.py +104 -77
  192. transformers/models/blt/configuration_blt.py +86 -94
  193. transformers/models/blt/modeling_blt.py +81 -238
  194. transformers/models/blt/modular_blt.py +65 -228
  195. transformers/models/bridgetower/configuration_bridgetower.py +2 -7
  196. transformers/models/bridgetower/image_processing_bridgetower.py +35 -34
  197. transformers/models/bridgetower/image_processing_bridgetower_fast.py +16 -13
  198. transformers/models/bridgetower/modeling_bridgetower.py +119 -141
  199. transformers/models/bridgetower/processing_bridgetower.py +16 -2
  200. transformers/models/bros/configuration_bros.py +18 -24
  201. transformers/models/bros/modeling_bros.py +80 -90
  202. transformers/models/bros/processing_bros.py +12 -2
  203. transformers/models/byt5/tokenization_byt5.py +6 -4
  204. transformers/models/camembert/configuration_camembert.py +2 -8
  205. transformers/models/camembert/modeling_camembert.py +195 -196
  206. transformers/models/camembert/modular_camembert.py +54 -51
  207. transformers/models/camembert/tokenization_camembert.py +13 -6
  208. transformers/models/canine/configuration_canine.py +2 -4
  209. transformers/models/canine/modeling_canine.py +75 -84
  210. transformers/models/canine/tokenization_canine.py +1 -2
  211. transformers/models/chameleon/configuration_chameleon.py +34 -29
  212. transformers/models/chameleon/image_processing_chameleon.py +24 -21
  213. transformers/models/chameleon/image_processing_chameleon_fast.py +6 -5
  214. transformers/models/chameleon/modeling_chameleon.py +93 -142
  215. transformers/models/chameleon/processing_chameleon.py +41 -16
  216. transformers/models/chinese_clip/configuration_chinese_clip.py +8 -10
  217. transformers/models/chinese_clip/image_processing_chinese_clip.py +24 -21
  218. transformers/models/chinese_clip/image_processing_chinese_clip_fast.py +1 -0
  219. transformers/models/chinese_clip/modeling_chinese_clip.py +92 -96
  220. transformers/models/chinese_clip/processing_chinese_clip.py +15 -2
  221. transformers/models/clap/configuration_clap.py +9 -4
  222. transformers/models/clap/feature_extraction_clap.py +12 -11
  223. transformers/models/clap/modeling_clap.py +123 -136
  224. transformers/models/clap/processing_clap.py +15 -2
  225. transformers/models/clip/configuration_clip.py +2 -4
  226. transformers/models/clip/image_processing_clip.py +24 -21
  227. transformers/models/clip/image_processing_clip_fast.py +1 -9
  228. transformers/models/clip/modeling_clip.py +65 -65
  229. transformers/models/clip/processing_clip.py +14 -2
  230. transformers/models/clip/tokenization_clip.py +46 -21
  231. transformers/models/clipseg/configuration_clipseg.py +2 -4
  232. transformers/models/clipseg/modeling_clipseg.py +109 -119
  233. transformers/models/clipseg/processing_clipseg.py +42 -19
  234. transformers/models/clvp/configuration_clvp.py +5 -15
  235. transformers/models/clvp/feature_extraction_clvp.py +10 -7
  236. transformers/models/clvp/modeling_clvp.py +146 -155
  237. transformers/models/clvp/number_normalizer.py +2 -1
  238. transformers/models/clvp/processing_clvp.py +20 -3
  239. transformers/models/clvp/tokenization_clvp.py +64 -1
  240. transformers/models/code_llama/tokenization_code_llama.py +44 -18
  241. transformers/models/codegen/configuration_codegen.py +4 -4
  242. transformers/models/codegen/modeling_codegen.py +53 -63
  243. transformers/models/codegen/tokenization_codegen.py +47 -17
  244. transformers/models/cohere/configuration_cohere.py +30 -25
  245. transformers/models/cohere/modeling_cohere.py +42 -40
  246. transformers/models/cohere/modular_cohere.py +29 -26
  247. transformers/models/cohere/tokenization_cohere.py +46 -15
  248. transformers/models/cohere2/configuration_cohere2.py +32 -31
  249. transformers/models/cohere2/modeling_cohere2.py +44 -42
  250. transformers/models/cohere2/modular_cohere2.py +54 -54
  251. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +14 -13
  252. transformers/models/cohere2_vision/modeling_cohere2_vision.py +58 -59
  253. transformers/models/cohere2_vision/modular_cohere2_vision.py +46 -45
  254. transformers/models/cohere2_vision/processing_cohere2_vision.py +36 -6
  255. transformers/models/colpali/configuration_colpali.py +1 -0
  256. transformers/models/colpali/modeling_colpali.py +16 -14
  257. transformers/models/colpali/modular_colpali.py +51 -11
  258. transformers/models/colpali/processing_colpali.py +52 -14
  259. transformers/models/colqwen2/modeling_colqwen2.py +28 -28
  260. transformers/models/colqwen2/modular_colqwen2.py +74 -37
  261. transformers/models/colqwen2/processing_colqwen2.py +52 -16
  262. transformers/models/conditional_detr/configuration_conditional_detr.py +2 -1
  263. transformers/models/conditional_detr/image_processing_conditional_detr.py +70 -67
  264. transformers/models/conditional_detr/image_processing_conditional_detr_fast.py +36 -36
  265. transformers/models/conditional_detr/modeling_conditional_detr.py +87 -99
  266. transformers/models/conditional_detr/modular_conditional_detr.py +3 -49
  267. transformers/models/convbert/configuration_convbert.py +8 -11
  268. transformers/models/convbert/modeling_convbert.py +87 -94
  269. transformers/models/convbert/tokenization_convbert.py +1 -0
  270. transformers/models/convnext/configuration_convnext.py +1 -0
  271. transformers/models/convnext/image_processing_convnext.py +23 -20
  272. transformers/models/convnext/image_processing_convnext_fast.py +21 -16
  273. transformers/models/convnext/modeling_convnext.py +12 -9
  274. transformers/models/convnextv2/configuration_convnextv2.py +1 -0
  275. transformers/models/convnextv2/modeling_convnextv2.py +12 -9
  276. transformers/models/cpm/tokenization_cpm.py +7 -6
  277. transformers/models/cpm/tokenization_cpm_fast.py +5 -3
  278. transformers/models/cpmant/configuration_cpmant.py +1 -4
  279. transformers/models/cpmant/modeling_cpmant.py +40 -38
  280. transformers/models/cpmant/tokenization_cpmant.py +3 -1
  281. transformers/models/csm/configuration_csm.py +66 -58
  282. transformers/models/csm/generation_csm.py +35 -31
  283. transformers/models/csm/modeling_csm.py +85 -85
  284. transformers/models/csm/modular_csm.py +58 -58
  285. transformers/models/csm/processing_csm.py +68 -25
  286. transformers/models/ctrl/configuration_ctrl.py +1 -16
  287. transformers/models/ctrl/modeling_ctrl.py +44 -54
  288. transformers/models/ctrl/tokenization_ctrl.py +1 -0
  289. transformers/models/cvt/configuration_cvt.py +1 -0
  290. transformers/models/cvt/modeling_cvt.py +16 -20
  291. transformers/models/cwm/__init__.py +1 -0
  292. transformers/models/cwm/configuration_cwm.py +12 -8
  293. transformers/models/cwm/modeling_cwm.py +39 -37
  294. transformers/models/cwm/modular_cwm.py +12 -10
  295. transformers/models/d_fine/configuration_d_fine.py +5 -7
  296. transformers/models/d_fine/modeling_d_fine.py +128 -138
  297. transformers/models/d_fine/modular_d_fine.py +18 -33
  298. transformers/models/dab_detr/configuration_dab_detr.py +3 -6
  299. transformers/models/dab_detr/modeling_dab_detr.py +75 -81
  300. transformers/models/dac/configuration_dac.py +1 -0
  301. transformers/models/dac/feature_extraction_dac.py +9 -6
  302. transformers/models/dac/modeling_dac.py +26 -24
  303. transformers/models/data2vec/configuration_data2vec_audio.py +2 -4
  304. transformers/models/data2vec/configuration_data2vec_text.py +3 -11
  305. transformers/models/data2vec/configuration_data2vec_vision.py +1 -0
  306. transformers/models/data2vec/modeling_data2vec_audio.py +56 -57
  307. transformers/models/data2vec/modeling_data2vec_text.py +93 -98
  308. transformers/models/data2vec/modeling_data2vec_vision.py +45 -49
  309. transformers/models/data2vec/modular_data2vec_audio.py +1 -6
  310. transformers/models/data2vec/modular_data2vec_text.py +54 -58
  311. transformers/models/dbrx/configuration_dbrx.py +22 -36
  312. transformers/models/dbrx/modeling_dbrx.py +45 -42
  313. transformers/models/dbrx/modular_dbrx.py +33 -31
  314. transformers/models/deberta/configuration_deberta.py +1 -6
  315. transformers/models/deberta/modeling_deberta.py +60 -64
  316. transformers/models/deberta/tokenization_deberta.py +21 -9
  317. transformers/models/deberta_v2/configuration_deberta_v2.py +1 -6
  318. transformers/models/deberta_v2/modeling_deberta_v2.py +65 -71
  319. transformers/models/deberta_v2/tokenization_deberta_v2.py +29 -11
  320. transformers/models/decision_transformer/configuration_decision_transformer.py +2 -3
  321. transformers/models/decision_transformer/modeling_decision_transformer.py +56 -60
  322. transformers/models/deepseek_v2/configuration_deepseek_v2.py +44 -39
  323. transformers/models/deepseek_v2/modeling_deepseek_v2.py +43 -43
  324. transformers/models/deepseek_v2/modular_deepseek_v2.py +49 -48
  325. transformers/models/deepseek_v3/configuration_deepseek_v3.py +45 -40
  326. transformers/models/deepseek_v3/modeling_deepseek_v3.py +42 -45
  327. transformers/models/deepseek_v3/modular_deepseek_v3.py +9 -14
  328. transformers/models/deepseek_vl/configuration_deepseek_vl.py +3 -2
  329. transformers/models/deepseek_vl/image_processing_deepseek_vl.py +26 -25
  330. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +10 -10
  331. transformers/models/deepseek_vl/modeling_deepseek_vl.py +48 -57
  332. transformers/models/deepseek_vl/modular_deepseek_vl.py +43 -14
  333. transformers/models/deepseek_vl/processing_deepseek_vl.py +41 -10
  334. transformers/models/deepseek_vl_hybrid/configuration_deepseek_vl_hybrid.py +5 -3
  335. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py +35 -35
  336. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +24 -20
  337. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +61 -109
  338. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +118 -146
  339. transformers/models/deepseek_vl_hybrid/processing_deepseek_vl_hybrid.py +44 -12
  340. transformers/models/deformable_detr/configuration_deformable_detr.py +3 -2
  341. transformers/models/deformable_detr/image_processing_deformable_detr.py +61 -59
  342. transformers/models/deformable_detr/image_processing_deformable_detr_fast.py +28 -28
  343. transformers/models/deformable_detr/modeling_deformable_detr.py +82 -88
  344. transformers/models/deformable_detr/modular_deformable_detr.py +3 -1
  345. transformers/models/deit/configuration_deit.py +1 -0
  346. transformers/models/deit/image_processing_deit.py +21 -18
  347. transformers/models/deit/image_processing_deit_fast.py +1 -0
  348. transformers/models/deit/modeling_deit.py +22 -24
  349. transformers/models/depth_anything/configuration_depth_anything.py +4 -2
  350. transformers/models/depth_anything/modeling_depth_anything.py +10 -10
  351. transformers/models/depth_pro/configuration_depth_pro.py +1 -0
  352. transformers/models/depth_pro/image_processing_depth_pro.py +23 -22
  353. transformers/models/depth_pro/image_processing_depth_pro_fast.py +10 -8
  354. transformers/models/depth_pro/modeling_depth_pro.py +27 -31
  355. transformers/models/detr/configuration_detr.py +2 -1
  356. transformers/models/detr/image_processing_detr.py +66 -64
  357. transformers/models/detr/image_processing_detr_fast.py +34 -33
  358. transformers/models/detr/modeling_detr.py +79 -95
  359. transformers/models/dia/configuration_dia.py +15 -9
  360. transformers/models/dia/feature_extraction_dia.py +9 -6
  361. transformers/models/dia/generation_dia.py +50 -48
  362. transformers/models/dia/modeling_dia.py +69 -78
  363. transformers/models/dia/modular_dia.py +56 -64
  364. transformers/models/dia/processing_dia.py +29 -39
  365. transformers/models/dia/tokenization_dia.py +6 -3
  366. transformers/models/diffllama/configuration_diffllama.py +30 -25
  367. transformers/models/diffllama/modeling_diffllama.py +49 -46
  368. transformers/models/diffllama/modular_diffllama.py +19 -17
  369. transformers/models/dinat/configuration_dinat.py +1 -0
  370. transformers/models/dinat/modeling_dinat.py +44 -47
  371. transformers/models/dinov2/configuration_dinov2.py +1 -0
  372. transformers/models/dinov2/modeling_dinov2.py +15 -15
  373. transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py +1 -1
  374. transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +15 -16
  375. transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py +9 -9
  376. transformers/models/dinov3_convnext/configuration_dinov3_convnext.py +7 -4
  377. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +6 -3
  378. transformers/models/dinov3_vit/configuration_dinov3_vit.py +8 -5
  379. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +9 -7
  380. transformers/models/dinov3_vit/modeling_dinov3_vit.py +18 -19
  381. transformers/models/dinov3_vit/modular_dinov3_vit.py +15 -16
  382. transformers/models/distilbert/configuration_distilbert.py +2 -8
  383. transformers/models/distilbert/modeling_distilbert.py +55 -55
  384. transformers/models/distilbert/tokenization_distilbert.py +1 -13
  385. transformers/models/doge/__init__.py +1 -0
  386. transformers/models/doge/configuration_doge.py +32 -39
  387. transformers/models/doge/modeling_doge.py +49 -45
  388. transformers/models/doge/modular_doge.py +63 -71
  389. transformers/models/donut/configuration_donut_swin.py +1 -0
  390. transformers/models/donut/image_processing_donut.py +29 -26
  391. transformers/models/donut/image_processing_donut_fast.py +15 -9
  392. transformers/models/donut/modeling_donut_swin.py +58 -62
  393. transformers/models/donut/processing_donut.py +26 -5
  394. transformers/models/dots1/configuration_dots1.py +33 -41
  395. transformers/models/dots1/modeling_dots1.py +45 -54
  396. transformers/models/dots1/modular_dots1.py +4 -5
  397. transformers/models/dpr/configuration_dpr.py +2 -19
  398. transformers/models/dpr/modeling_dpr.py +39 -42
  399. transformers/models/dpr/tokenization_dpr.py +9 -19
  400. transformers/models/dpr/tokenization_dpr_fast.py +9 -7
  401. transformers/models/dpt/configuration_dpt.py +2 -1
  402. transformers/models/dpt/image_processing_dpt.py +66 -65
  403. transformers/models/dpt/image_processing_dpt_fast.py +20 -18
  404. transformers/models/dpt/modeling_dpt.py +30 -32
  405. transformers/models/dpt/modular_dpt.py +17 -15
  406. transformers/models/edgetam/configuration_edgetam.py +3 -2
  407. transformers/models/edgetam/modeling_edgetam.py +86 -86
  408. transformers/models/edgetam/modular_edgetam.py +26 -21
  409. transformers/models/edgetam_video/__init__.py +1 -0
  410. transformers/models/edgetam_video/configuration_edgetam_video.py +1 -0
  411. transformers/models/edgetam_video/modeling_edgetam_video.py +158 -169
  412. transformers/models/edgetam_video/modular_edgetam_video.py +37 -30
  413. transformers/models/efficientloftr/configuration_efficientloftr.py +5 -4
  414. transformers/models/efficientloftr/image_processing_efficientloftr.py +16 -14
  415. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +9 -9
  416. transformers/models/efficientloftr/modeling_efficientloftr.py +38 -59
  417. transformers/models/efficientloftr/modular_efficientloftr.py +3 -1
  418. transformers/models/efficientnet/configuration_efficientnet.py +1 -0
  419. transformers/models/efficientnet/image_processing_efficientnet.py +32 -28
  420. transformers/models/efficientnet/image_processing_efficientnet_fast.py +19 -17
  421. transformers/models/efficientnet/modeling_efficientnet.py +15 -19
  422. transformers/models/electra/configuration_electra.py +3 -13
  423. transformers/models/electra/modeling_electra.py +103 -108
  424. transformers/models/emu3/configuration_emu3.py +17 -13
  425. transformers/models/emu3/image_processing_emu3.py +39 -44
  426. transformers/models/emu3/modeling_emu3.py +108 -148
  427. transformers/models/emu3/modular_emu3.py +73 -115
  428. transformers/models/emu3/processing_emu3.py +43 -18
  429. transformers/models/encodec/configuration_encodec.py +4 -2
  430. transformers/models/encodec/feature_extraction_encodec.py +13 -10
  431. transformers/models/encodec/modeling_encodec.py +29 -39
  432. transformers/models/encoder_decoder/configuration_encoder_decoder.py +2 -12
  433. transformers/models/encoder_decoder/modeling_encoder_decoder.py +43 -37
  434. transformers/models/eomt/configuration_eomt.py +1 -0
  435. transformers/models/eomt/image_processing_eomt.py +56 -66
  436. transformers/models/eomt/image_processing_eomt_fast.py +33 -76
  437. transformers/models/eomt/modeling_eomt.py +18 -23
  438. transformers/models/eomt/modular_eomt.py +13 -18
  439. transformers/models/ernie/configuration_ernie.py +3 -24
  440. transformers/models/ernie/modeling_ernie.py +132 -127
  441. transformers/models/ernie/modular_ernie.py +103 -97
  442. transformers/models/ernie4_5/configuration_ernie4_5.py +27 -23
  443. transformers/models/ernie4_5/modeling_ernie4_5.py +38 -36
  444. transformers/models/ernie4_5/modular_ernie4_5.py +4 -3
  445. transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py +36 -32
  446. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +55 -56
  447. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +46 -18
  448. transformers/models/esm/configuration_esm.py +15 -11
  449. transformers/models/esm/modeling_esm.py +34 -38
  450. transformers/models/esm/modeling_esmfold.py +49 -53
  451. transformers/models/esm/openfold_utils/chunk_utils.py +6 -6
  452. transformers/models/esm/openfold_utils/loss.py +2 -1
  453. transformers/models/esm/openfold_utils/protein.py +16 -15
  454. transformers/models/esm/openfold_utils/tensor_utils.py +6 -6
  455. transformers/models/esm/tokenization_esm.py +4 -2
  456. transformers/models/evolla/configuration_evolla.py +40 -50
  457. transformers/models/evolla/modeling_evolla.py +66 -71
  458. transformers/models/evolla/modular_evolla.py +47 -53
  459. transformers/models/evolla/processing_evolla.py +35 -23
  460. transformers/models/exaone4/configuration_exaone4.py +25 -23
  461. transformers/models/exaone4/modeling_exaone4.py +38 -35
  462. transformers/models/exaone4/modular_exaone4.py +46 -44
  463. transformers/models/falcon/configuration_falcon.py +26 -31
  464. transformers/models/falcon/modeling_falcon.py +80 -82
  465. transformers/models/falcon_h1/configuration_falcon_h1.py +51 -45
  466. transformers/models/falcon_h1/modeling_falcon_h1.py +82 -85
  467. transformers/models/falcon_h1/modular_falcon_h1.py +51 -56
  468. transformers/models/falcon_mamba/configuration_falcon_mamba.py +2 -1
  469. transformers/models/falcon_mamba/modeling_falcon_mamba.py +82 -75
  470. transformers/models/falcon_mamba/modular_falcon_mamba.py +45 -28
  471. transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py +6 -2
  472. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +60 -76
  473. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +3 -2
  474. transformers/models/flaubert/configuration_flaubert.py +5 -10
  475. transformers/models/flaubert/modeling_flaubert.py +143 -145
  476. transformers/models/flaubert/tokenization_flaubert.py +5 -3
  477. transformers/models/flava/configuration_flava.py +6 -5
  478. transformers/models/flava/image_processing_flava.py +67 -66
  479. transformers/models/flava/image_processing_flava_fast.py +49 -46
  480. transformers/models/flava/modeling_flava.py +136 -153
  481. transformers/models/flava/processing_flava.py +12 -2
  482. transformers/models/flex_olmo/__init__.py +1 -0
  483. transformers/models/flex_olmo/configuration_flex_olmo.py +32 -28
  484. transformers/models/flex_olmo/modeling_flex_olmo.py +47 -47
  485. transformers/models/flex_olmo/modular_flex_olmo.py +44 -40
  486. transformers/models/florence2/configuration_florence2.py +1 -0
  487. transformers/models/florence2/modeling_florence2.py +69 -111
  488. transformers/models/florence2/modular_florence2.py +101 -104
  489. transformers/models/florence2/processing_florence2.py +47 -18
  490. transformers/models/fnet/configuration_fnet.py +2 -6
  491. transformers/models/fnet/modeling_fnet.py +80 -83
  492. transformers/models/fnet/tokenization_fnet.py +1 -0
  493. transformers/models/focalnet/configuration_focalnet.py +1 -0
  494. transformers/models/focalnet/modeling_focalnet.py +45 -51
  495. transformers/models/fsmt/configuration_fsmt.py +17 -12
  496. transformers/models/fsmt/modeling_fsmt.py +48 -49
  497. transformers/models/fsmt/tokenization_fsmt.py +5 -3
  498. transformers/models/funnel/configuration_funnel.py +1 -8
  499. transformers/models/funnel/modeling_funnel.py +93 -99
  500. transformers/models/funnel/tokenization_funnel.py +27 -17
  501. transformers/models/fuyu/configuration_fuyu.py +34 -28
  502. transformers/models/fuyu/image_processing_fuyu.py +31 -29
  503. transformers/models/fuyu/image_processing_fuyu_fast.py +17 -17
  504. transformers/models/fuyu/modeling_fuyu.py +53 -53
  505. transformers/models/fuyu/processing_fuyu.py +34 -23
  506. transformers/models/gemma/configuration_gemma.py +30 -25
  507. transformers/models/gemma/modeling_gemma.py +50 -46
  508. transformers/models/gemma/modular_gemma.py +47 -42
  509. transformers/models/gemma/tokenization_gemma.py +30 -10
  510. transformers/models/gemma2/configuration_gemma2.py +35 -30
  511. transformers/models/gemma2/modeling_gemma2.py +42 -39
  512. transformers/models/gemma2/modular_gemma2.py +66 -63
  513. transformers/models/gemma3/configuration_gemma3.py +44 -44
  514. transformers/models/gemma3/image_processing_gemma3.py +31 -29
  515. transformers/models/gemma3/image_processing_gemma3_fast.py +13 -11
  516. transformers/models/gemma3/modeling_gemma3.py +207 -159
  517. transformers/models/gemma3/modular_gemma3.py +204 -153
  518. transformers/models/gemma3/processing_gemma3.py +5 -5
  519. transformers/models/gemma3n/configuration_gemma3n.py +26 -36
  520. transformers/models/gemma3n/feature_extraction_gemma3n.py +11 -9
  521. transformers/models/gemma3n/modeling_gemma3n.py +356 -222
  522. transformers/models/gemma3n/modular_gemma3n.py +207 -230
  523. transformers/models/gemma3n/processing_gemma3n.py +26 -12
  524. transformers/models/git/configuration_git.py +8 -5
  525. transformers/models/git/modeling_git.py +204 -266
  526. transformers/models/git/processing_git.py +14 -2
  527. transformers/models/glm/configuration_glm.py +28 -24
  528. transformers/models/glm/modeling_glm.py +40 -37
  529. transformers/models/glm/modular_glm.py +7 -4
  530. transformers/models/glm4/configuration_glm4.py +28 -24
  531. transformers/models/glm4/modeling_glm4.py +42 -40
  532. transformers/models/glm4/modular_glm4.py +10 -8
  533. transformers/models/glm46v/configuration_glm46v.py +1 -0
  534. transformers/models/glm46v/image_processing_glm46v.py +40 -35
  535. transformers/models/glm46v/image_processing_glm46v_fast.py +9 -9
  536. transformers/models/glm46v/modeling_glm46v.py +90 -137
  537. transformers/models/glm46v/modular_glm46v.py +3 -4
  538. transformers/models/glm46v/processing_glm46v.py +41 -7
  539. transformers/models/glm46v/video_processing_glm46v.py +11 -9
  540. transformers/models/glm4_moe/configuration_glm4_moe.py +32 -40
  541. transformers/models/glm4_moe/modeling_glm4_moe.py +42 -45
  542. transformers/models/glm4_moe/modular_glm4_moe.py +34 -42
  543. transformers/models/glm4v/configuration_glm4v.py +20 -18
  544. transformers/models/glm4v/image_processing_glm4v.py +40 -34
  545. transformers/models/glm4v/image_processing_glm4v_fast.py +9 -8
  546. transformers/models/glm4v/modeling_glm4v.py +205 -254
  547. transformers/models/glm4v/modular_glm4v.py +224 -210
  548. transformers/models/glm4v/processing_glm4v.py +41 -7
  549. transformers/models/glm4v/video_processing_glm4v.py +11 -9
  550. transformers/models/glm4v_moe/configuration_glm4v_moe.py +125 -136
  551. transformers/models/glm4v_moe/modeling_glm4v_moe.py +368 -377
  552. transformers/models/glm4v_moe/modular_glm4v_moe.py +169 -83
  553. transformers/models/glpn/configuration_glpn.py +1 -0
  554. transformers/models/glpn/image_processing_glpn.py +12 -11
  555. transformers/models/glpn/image_processing_glpn_fast.py +13 -11
  556. transformers/models/glpn/modeling_glpn.py +14 -16
  557. transformers/models/got_ocr2/configuration_got_ocr2.py +12 -4
  558. transformers/models/got_ocr2/image_processing_got_ocr2.py +24 -22
  559. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +11 -9
  560. transformers/models/got_ocr2/modeling_got_ocr2.py +80 -77
  561. transformers/models/got_ocr2/modular_got_ocr2.py +51 -54
  562. transformers/models/got_ocr2/processing_got_ocr2.py +63 -42
  563. transformers/models/gpt2/configuration_gpt2.py +2 -13
  564. transformers/models/gpt2/modeling_gpt2.py +115 -120
  565. transformers/models/gpt2/tokenization_gpt2.py +46 -15
  566. transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +2 -5
  567. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +89 -79
  568. transformers/models/gpt_neo/configuration_gpt_neo.py +2 -9
  569. transformers/models/gpt_neo/modeling_gpt_neo.py +67 -83
  570. transformers/models/gpt_neox/configuration_gpt_neox.py +25 -25
  571. transformers/models/gpt_neox/modeling_gpt_neox.py +75 -76
  572. transformers/models/gpt_neox/modular_gpt_neox.py +66 -67
  573. transformers/models/gpt_neox/tokenization_gpt_neox.py +51 -9
  574. transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py +19 -24
  575. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +47 -46
  576. transformers/models/gpt_neox_japanese/tokenization_gpt_neox_japanese.py +3 -1
  577. transformers/models/gpt_oss/configuration_gpt_oss.py +28 -46
  578. transformers/models/gpt_oss/modeling_gpt_oss.py +121 -83
  579. transformers/models/gpt_oss/modular_gpt_oss.py +103 -64
  580. transformers/models/gpt_sw3/tokenization_gpt_sw3.py +4 -4
  581. transformers/models/gptj/configuration_gptj.py +4 -4
  582. transformers/models/gptj/modeling_gptj.py +87 -101
  583. transformers/models/granite/configuration_granite.py +33 -28
  584. transformers/models/granite/modeling_granite.py +46 -44
  585. transformers/models/granite/modular_granite.py +31 -29
  586. transformers/models/granite_speech/configuration_granite_speech.py +1 -0
  587. transformers/models/granite_speech/feature_extraction_granite_speech.py +3 -1
  588. transformers/models/granite_speech/modeling_granite_speech.py +52 -82
  589. transformers/models/granite_speech/processing_granite_speech.py +4 -11
  590. transformers/models/granitemoe/configuration_granitemoe.py +36 -31
  591. transformers/models/granitemoe/modeling_granitemoe.py +46 -41
  592. transformers/models/granitemoe/modular_granitemoe.py +27 -22
  593. transformers/models/granitemoehybrid/__init__.py +1 -0
  594. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +47 -46
  595. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +93 -97
  596. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +21 -54
  597. transformers/models/granitemoeshared/configuration_granitemoeshared.py +37 -33
  598. transformers/models/granitemoeshared/modeling_granitemoeshared.py +61 -54
  599. transformers/models/granitemoeshared/modular_granitemoeshared.py +21 -19
  600. transformers/models/grounding_dino/configuration_grounding_dino.py +4 -6
  601. transformers/models/grounding_dino/image_processing_grounding_dino.py +62 -60
  602. transformers/models/grounding_dino/image_processing_grounding_dino_fast.py +29 -28
  603. transformers/models/grounding_dino/modeling_grounding_dino.py +140 -155
  604. transformers/models/grounding_dino/modular_grounding_dino.py +3 -2
  605. transformers/models/grounding_dino/processing_grounding_dino.py +38 -10
  606. transformers/models/groupvit/configuration_groupvit.py +2 -4
  607. transformers/models/groupvit/modeling_groupvit.py +93 -107
  608. transformers/models/helium/configuration_helium.py +29 -25
  609. transformers/models/helium/modeling_helium.py +40 -38
  610. transformers/models/helium/modular_helium.py +7 -3
  611. transformers/models/herbert/tokenization_herbert.py +28 -10
  612. transformers/models/hgnet_v2/configuration_hgnet_v2.py +1 -0
  613. transformers/models/hgnet_v2/modeling_hgnet_v2.py +10 -24
  614. transformers/models/hgnet_v2/modular_hgnet_v2.py +10 -24
  615. transformers/models/hiera/configuration_hiera.py +1 -0
  616. transformers/models/hiera/modeling_hiera.py +66 -72
  617. transformers/models/hubert/configuration_hubert.py +2 -4
  618. transformers/models/hubert/modeling_hubert.py +37 -42
  619. transformers/models/hubert/modular_hubert.py +11 -13
  620. transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py +31 -26
  621. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +38 -35
  622. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +6 -4
  623. transformers/models/hunyuan_v1_moe/__init__.py +1 -1
  624. transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py +36 -31
  625. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +42 -47
  626. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +9 -9
  627. transformers/models/ibert/configuration_ibert.py +2 -4
  628. transformers/models/ibert/modeling_ibert.py +62 -82
  629. transformers/models/ibert/quant_modules.py +1 -0
  630. transformers/models/idefics/configuration_idefics.py +8 -5
  631. transformers/models/idefics/image_processing_idefics.py +15 -13
  632. transformers/models/idefics/modeling_idefics.py +82 -75
  633. transformers/models/idefics/perceiver.py +3 -1
  634. transformers/models/idefics/processing_idefics.py +48 -32
  635. transformers/models/idefics/vision.py +25 -24
  636. transformers/models/idefics2/configuration_idefics2.py +3 -1
  637. transformers/models/idefics2/image_processing_idefics2.py +32 -31
  638. transformers/models/idefics2/image_processing_idefics2_fast.py +8 -8
  639. transformers/models/idefics2/modeling_idefics2.py +101 -127
  640. transformers/models/idefics2/processing_idefics2.py +68 -10
  641. transformers/models/idefics3/configuration_idefics3.py +4 -1
  642. transformers/models/idefics3/image_processing_idefics3.py +43 -42
  643. transformers/models/idefics3/image_processing_idefics3_fast.py +15 -40
  644. transformers/models/idefics3/modeling_idefics3.py +90 -115
  645. transformers/models/idefics3/processing_idefics3.py +69 -15
  646. transformers/models/ijepa/configuration_ijepa.py +1 -0
  647. transformers/models/ijepa/modeling_ijepa.py +11 -10
  648. transformers/models/ijepa/modular_ijepa.py +7 -5
  649. transformers/models/imagegpt/configuration_imagegpt.py +2 -9
  650. transformers/models/imagegpt/image_processing_imagegpt.py +18 -17
  651. transformers/models/imagegpt/image_processing_imagegpt_fast.py +16 -11
  652. transformers/models/imagegpt/modeling_imagegpt.py +65 -76
  653. transformers/models/informer/configuration_informer.py +9 -6
  654. transformers/models/informer/modeling_informer.py +86 -88
  655. transformers/models/informer/modular_informer.py +16 -14
  656. transformers/models/instructblip/configuration_instructblip.py +2 -2
  657. transformers/models/instructblip/modeling_instructblip.py +63 -103
  658. transformers/models/instructblip/processing_instructblip.py +36 -10
  659. transformers/models/instructblipvideo/configuration_instructblipvideo.py +2 -2
  660. transformers/models/instructblipvideo/modeling_instructblipvideo.py +139 -157
  661. transformers/models/instructblipvideo/modular_instructblipvideo.py +64 -73
  662. transformers/models/instructblipvideo/processing_instructblipvideo.py +33 -14
  663. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +8 -6
  664. transformers/models/internvl/configuration_internvl.py +1 -0
  665. transformers/models/internvl/modeling_internvl.py +106 -85
  666. transformers/models/internvl/modular_internvl.py +67 -47
  667. transformers/models/internvl/processing_internvl.py +45 -12
  668. transformers/models/internvl/video_processing_internvl.py +12 -10
  669. transformers/models/jamba/configuration_jamba.py +8 -5
  670. transformers/models/jamba/modeling_jamba.py +66 -68
  671. transformers/models/jamba/modular_jamba.py +55 -54
  672. transformers/models/janus/configuration_janus.py +1 -0
  673. transformers/models/janus/image_processing_janus.py +37 -35
  674. transformers/models/janus/image_processing_janus_fast.py +20 -18
  675. transformers/models/janus/modeling_janus.py +191 -115
  676. transformers/models/janus/modular_janus.py +84 -133
  677. transformers/models/janus/processing_janus.py +43 -17
  678. transformers/models/jetmoe/configuration_jetmoe.py +26 -24
  679. transformers/models/jetmoe/modeling_jetmoe.py +46 -43
  680. transformers/models/jetmoe/modular_jetmoe.py +33 -31
  681. transformers/models/kosmos2/configuration_kosmos2.py +9 -10
  682. transformers/models/kosmos2/modeling_kosmos2.py +173 -208
  683. transformers/models/kosmos2/processing_kosmos2.py +55 -40
  684. transformers/models/kosmos2_5/__init__.py +1 -0
  685. transformers/models/kosmos2_5/configuration_kosmos2_5.py +9 -8
  686. transformers/models/kosmos2_5/image_processing_kosmos2_5.py +12 -10
  687. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +13 -4
  688. transformers/models/kosmos2_5/modeling_kosmos2_5.py +118 -132
  689. transformers/models/kosmos2_5/processing_kosmos2_5.py +29 -8
  690. transformers/models/kyutai_speech_to_text/configuration_kyutai_speech_to_text.py +28 -31
  691. transformers/models/kyutai_speech_to_text/feature_extraction_kyutai_speech_to_text.py +14 -12
  692. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +100 -110
  693. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +22 -28
  694. transformers/models/kyutai_speech_to_text/processing_kyutai_speech_to_text.py +8 -2
  695. transformers/models/layoutlm/configuration_layoutlm.py +2 -14
  696. transformers/models/layoutlm/modeling_layoutlm.py +72 -77
  697. transformers/models/layoutlmv2/configuration_layoutlmv2.py +17 -14
  698. transformers/models/layoutlmv2/image_processing_layoutlmv2.py +21 -18
  699. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +9 -7
  700. transformers/models/layoutlmv2/modeling_layoutlmv2.py +50 -64
  701. transformers/models/layoutlmv2/processing_layoutlmv2.py +44 -14
  702. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +126 -73
  703. transformers/models/layoutlmv3/configuration_layoutlmv3.py +19 -16
  704. transformers/models/layoutlmv3/image_processing_layoutlmv3.py +26 -24
  705. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +11 -9
  706. transformers/models/layoutlmv3/modeling_layoutlmv3.py +56 -82
  707. transformers/models/layoutlmv3/processing_layoutlmv3.py +46 -14
  708. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +134 -74
  709. transformers/models/layoutxlm/configuration_layoutxlm.py +17 -14
  710. transformers/models/layoutxlm/modular_layoutxlm.py +1 -0
  711. transformers/models/layoutxlm/processing_layoutxlm.py +44 -14
  712. transformers/models/layoutxlm/tokenization_layoutxlm.py +113 -77
  713. transformers/models/led/configuration_led.py +12 -8
  714. transformers/models/led/modeling_led.py +266 -124
  715. transformers/models/levit/configuration_levit.py +1 -0
  716. transformers/models/levit/image_processing_levit.py +21 -19
  717. transformers/models/levit/image_processing_levit_fast.py +5 -4
  718. transformers/models/levit/modeling_levit.py +19 -38
  719. transformers/models/lfm2/configuration_lfm2.py +30 -27
  720. transformers/models/lfm2/modeling_lfm2.py +50 -47
  721. transformers/models/lfm2/modular_lfm2.py +30 -29
  722. transformers/models/lfm2_moe/__init__.py +1 -0
  723. transformers/models/lfm2_moe/configuration_lfm2_moe.py +9 -6
  724. transformers/models/lfm2_moe/modeling_lfm2_moe.py +53 -61
  725. transformers/models/lfm2_moe/modular_lfm2_moe.py +37 -13
  726. transformers/models/lfm2_vl/configuration_lfm2_vl.py +1 -4
  727. transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py +12 -41
  728. transformers/models/lfm2_vl/modeling_lfm2_vl.py +66 -84
  729. transformers/models/lfm2_vl/modular_lfm2_vl.py +56 -70
  730. transformers/models/lfm2_vl/processing_lfm2_vl.py +76 -96
  731. transformers/models/lightglue/image_processing_lightglue.py +15 -16
  732. transformers/models/lightglue/image_processing_lightglue_fast.py +9 -9
  733. transformers/models/lightglue/modeling_lightglue.py +31 -31
  734. transformers/models/lightglue/modular_lightglue.py +28 -29
  735. transformers/models/lilt/configuration_lilt.py +2 -6
  736. transformers/models/lilt/modeling_lilt.py +70 -76
  737. transformers/models/llama/configuration_llama.py +31 -26
  738. transformers/models/llama/modeling_llama.py +39 -36
  739. transformers/models/llama/tokenization_llama.py +44 -14
  740. transformers/models/llama4/configuration_llama4.py +30 -27
  741. transformers/models/llama4/image_processing_llama4_fast.py +14 -12
  742. transformers/models/llama4/modeling_llama4.py +113 -120
  743. transformers/models/llama4/processing_llama4.py +57 -33
  744. transformers/models/llava/configuration_llava.py +1 -10
  745. transformers/models/llava/image_processing_llava.py +28 -25
  746. transformers/models/llava/image_processing_llava_fast.py +11 -9
  747. transformers/models/llava/modeling_llava.py +109 -85
  748. transformers/models/llava/processing_llava.py +51 -18
  749. transformers/models/llava_next/configuration_llava_next.py +2 -2
  750. transformers/models/llava_next/image_processing_llava_next.py +45 -43
  751. transformers/models/llava_next/image_processing_llava_next_fast.py +13 -11
  752. transformers/models/llava_next/modeling_llava_next.py +107 -110
  753. transformers/models/llava_next/processing_llava_next.py +47 -18
  754. transformers/models/llava_next_video/configuration_llava_next_video.py +7 -4
  755. transformers/models/llava_next_video/modeling_llava_next_video.py +158 -175
  756. transformers/models/llava_next_video/modular_llava_next_video.py +150 -155
  757. transformers/models/llava_next_video/processing_llava_next_video.py +63 -21
  758. transformers/models/llava_next_video/video_processing_llava_next_video.py +1 -0
  759. transformers/models/llava_onevision/configuration_llava_onevision.py +7 -4
  760. transformers/models/llava_onevision/image_processing_llava_onevision.py +42 -40
  761. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +15 -14
  762. transformers/models/llava_onevision/modeling_llava_onevision.py +169 -177
  763. transformers/models/llava_onevision/modular_llava_onevision.py +156 -163
  764. transformers/models/llava_onevision/processing_llava_onevision.py +53 -21
  765. transformers/models/llava_onevision/video_processing_llava_onevision.py +1 -0
  766. transformers/models/longcat_flash/__init__.py +1 -0
  767. transformers/models/longcat_flash/configuration_longcat_flash.py +42 -37
  768. transformers/models/longcat_flash/modeling_longcat_flash.py +36 -36
  769. transformers/models/longcat_flash/modular_longcat_flash.py +21 -21
  770. transformers/models/longformer/configuration_longformer.py +5 -5
  771. transformers/models/longformer/modeling_longformer.py +101 -105
  772. transformers/models/longt5/configuration_longt5.py +7 -9
  773. transformers/models/longt5/modeling_longt5.py +49 -49
  774. transformers/models/luke/configuration_luke.py +2 -8
  775. transformers/models/luke/modeling_luke.py +181 -188
  776. transformers/models/luke/tokenization_luke.py +140 -107
  777. transformers/models/lxmert/configuration_lxmert.py +1 -16
  778. transformers/models/lxmert/modeling_lxmert.py +74 -65
  779. transformers/models/m2m_100/configuration_m2m_100.py +9 -7
  780. transformers/models/m2m_100/modeling_m2m_100.py +71 -83
  781. transformers/models/m2m_100/tokenization_m2m_100.py +8 -8
  782. transformers/models/mamba/configuration_mamba.py +2 -1
  783. transformers/models/mamba/modeling_mamba.py +66 -58
  784. transformers/models/mamba2/configuration_mamba2.py +8 -5
  785. transformers/models/mamba2/modeling_mamba2.py +69 -68
  786. transformers/models/marian/configuration_marian.py +5 -10
  787. transformers/models/marian/modeling_marian.py +87 -93
  788. transformers/models/marian/tokenization_marian.py +6 -6
  789. transformers/models/markuplm/configuration_markuplm.py +7 -4
  790. transformers/models/markuplm/feature_extraction_markuplm.py +2 -1
  791. transformers/models/markuplm/modeling_markuplm.py +70 -69
  792. transformers/models/markuplm/processing_markuplm.py +38 -31
  793. transformers/models/markuplm/tokenization_markuplm.py +136 -93
  794. transformers/models/mask2former/configuration_mask2former.py +8 -5
  795. transformers/models/mask2former/image_processing_mask2former.py +85 -84
  796. transformers/models/mask2former/image_processing_mask2former_fast.py +40 -37
  797. transformers/models/mask2former/modeling_mask2former.py +103 -118
  798. transformers/models/mask2former/modular_mask2former.py +8 -6
  799. transformers/models/maskformer/configuration_maskformer.py +9 -6
  800. transformers/models/maskformer/configuration_maskformer_swin.py +1 -0
  801. transformers/models/maskformer/image_processing_maskformer.py +85 -84
  802. transformers/models/maskformer/image_processing_maskformer_fast.py +40 -36
  803. transformers/models/maskformer/modeling_maskformer.py +65 -79
  804. transformers/models/maskformer/modeling_maskformer_swin.py +32 -36
  805. transformers/models/mbart/configuration_mbart.py +4 -9
  806. transformers/models/mbart/modeling_mbart.py +116 -131
  807. transformers/models/mbart/tokenization_mbart.py +54 -11
  808. transformers/models/mbart50/tokenization_mbart50.py +13 -8
  809. transformers/models/megatron_bert/configuration_megatron_bert.py +3 -13
  810. transformers/models/megatron_bert/modeling_megatron_bert.py +150 -148
  811. transformers/models/metaclip_2/configuration_metaclip_2.py +1 -4
  812. transformers/models/metaclip_2/modeling_metaclip_2.py +84 -91
  813. transformers/models/metaclip_2/modular_metaclip_2.py +45 -61
  814. transformers/models/mgp_str/configuration_mgp_str.py +1 -0
  815. transformers/models/mgp_str/modeling_mgp_str.py +18 -20
  816. transformers/models/mgp_str/processing_mgp_str.py +20 -3
  817. transformers/models/mgp_str/tokenization_mgp_str.py +3 -1
  818. transformers/models/mimi/configuration_mimi.py +40 -42
  819. transformers/models/mimi/modeling_mimi.py +113 -142
  820. transformers/models/minimax/__init__.py +1 -0
  821. transformers/models/minimax/configuration_minimax.py +43 -37
  822. transformers/models/minimax/modeling_minimax.py +51 -61
  823. transformers/models/minimax/modular_minimax.py +62 -68
  824. transformers/models/ministral/configuration_ministral.py +29 -25
  825. transformers/models/ministral/modeling_ministral.py +38 -36
  826. transformers/models/ministral/modular_ministral.py +37 -32
  827. transformers/models/ministral3/configuration_ministral3.py +27 -24
  828. transformers/models/ministral3/modeling_ministral3.py +37 -36
  829. transformers/models/ministral3/modular_ministral3.py +5 -4
  830. transformers/models/mistral/configuration_mistral.py +29 -24
  831. transformers/models/mistral/modeling_mistral.py +37 -36
  832. transformers/models/mistral/modular_mistral.py +12 -11
  833. transformers/models/mistral3/configuration_mistral3.py +1 -4
  834. transformers/models/mistral3/modeling_mistral3.py +86 -89
  835. transformers/models/mistral3/modular_mistral3.py +68 -69
  836. transformers/models/mixtral/configuration_mixtral.py +34 -29
  837. transformers/models/mixtral/modeling_mixtral.py +45 -50
  838. transformers/models/mixtral/modular_mixtral.py +31 -32
  839. transformers/models/mlcd/configuration_mlcd.py +1 -0
  840. transformers/models/mlcd/modeling_mlcd.py +14 -20
  841. transformers/models/mlcd/modular_mlcd.py +13 -17
  842. transformers/models/mllama/configuration_mllama.py +15 -10
  843. transformers/models/mllama/image_processing_mllama.py +25 -23
  844. transformers/models/mllama/image_processing_mllama_fast.py +11 -11
  845. transformers/models/mllama/modeling_mllama.py +94 -105
  846. transformers/models/mllama/processing_mllama.py +55 -6
  847. transformers/models/mluke/tokenization_mluke.py +107 -101
  848. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +3 -5
  849. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +140 -155
  850. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +3 -5
  851. transformers/models/mobilebert/configuration_mobilebert.py +2 -4
  852. transformers/models/mobilebert/modeling_mobilebert.py +85 -77
  853. transformers/models/mobilebert/tokenization_mobilebert.py +1 -0
  854. transformers/models/mobilenet_v1/configuration_mobilenet_v1.py +1 -0
  855. transformers/models/mobilenet_v1/image_processing_mobilenet_v1.py +23 -20
  856. transformers/models/mobilenet_v1/image_processing_mobilenet_v1_fast.py +1 -0
  857. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +16 -15
  858. transformers/models/mobilenet_v2/configuration_mobilenet_v2.py +1 -0
  859. transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py +51 -48
  860. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +15 -13
  861. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +22 -24
  862. transformers/models/mobilevit/configuration_mobilevit.py +1 -0
  863. transformers/models/mobilevit/image_processing_mobilevit.py +49 -46
  864. transformers/models/mobilevit/image_processing_mobilevit_fast.py +14 -12
  865. transformers/models/mobilevit/modeling_mobilevit.py +21 -28
  866. transformers/models/mobilevitv2/configuration_mobilevitv2.py +1 -0
  867. transformers/models/mobilevitv2/modeling_mobilevitv2.py +22 -28
  868. transformers/models/modernbert/configuration_modernbert.py +42 -44
  869. transformers/models/modernbert/modeling_modernbert.py +133 -145
  870. transformers/models/modernbert/modular_modernbert.py +170 -186
  871. transformers/models/modernbert_decoder/configuration_modernbert_decoder.py +40 -40
  872. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +57 -62
  873. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +86 -94
  874. transformers/models/moonshine/configuration_moonshine.py +31 -34
  875. transformers/models/moonshine/modeling_moonshine.py +71 -71
  876. transformers/models/moonshine/modular_moonshine.py +83 -88
  877. transformers/models/moshi/configuration_moshi.py +23 -46
  878. transformers/models/moshi/modeling_moshi.py +187 -157
  879. transformers/models/mpnet/configuration_mpnet.py +2 -6
  880. transformers/models/mpnet/modeling_mpnet.py +57 -62
  881. transformers/models/mpnet/tokenization_mpnet.py +15 -4
  882. transformers/models/mpt/configuration_mpt.py +9 -5
  883. transformers/models/mpt/modeling_mpt.py +60 -60
  884. transformers/models/mra/configuration_mra.py +2 -8
  885. transformers/models/mra/modeling_mra.py +57 -64
  886. transformers/models/mt5/configuration_mt5.py +8 -10
  887. transformers/models/mt5/modeling_mt5.py +95 -87
  888. transformers/models/musicgen/configuration_musicgen.py +8 -12
  889. transformers/models/musicgen/modeling_musicgen.py +122 -118
  890. transformers/models/musicgen/processing_musicgen.py +21 -3
  891. transformers/models/musicgen_melody/configuration_musicgen_melody.py +8 -15
  892. transformers/models/musicgen_melody/feature_extraction_musicgen_melody.py +9 -8
  893. transformers/models/musicgen_melody/modeling_musicgen_melody.py +123 -117
  894. transformers/models/musicgen_melody/processing_musicgen_melody.py +22 -3
  895. transformers/models/mvp/configuration_mvp.py +5 -8
  896. transformers/models/mvp/modeling_mvp.py +123 -135
  897. transformers/models/myt5/tokenization_myt5.py +10 -8
  898. transformers/models/nanochat/configuration_nanochat.py +8 -5
  899. transformers/models/nanochat/modeling_nanochat.py +40 -37
  900. transformers/models/nanochat/modular_nanochat.py +14 -12
  901. transformers/models/nemotron/configuration_nemotron.py +30 -25
  902. transformers/models/nemotron/modeling_nemotron.py +57 -56
  903. transformers/models/nllb/tokenization_nllb.py +28 -12
  904. transformers/models/nllb_moe/configuration_nllb_moe.py +9 -7
  905. transformers/models/nllb_moe/modeling_nllb_moe.py +69 -77
  906. transformers/models/nougat/image_processing_nougat.py +32 -29
  907. transformers/models/nougat/image_processing_nougat_fast.py +14 -12
  908. transformers/models/nougat/processing_nougat.py +39 -37
  909. transformers/models/nougat/tokenization_nougat.py +73 -18
  910. transformers/models/nystromformer/configuration_nystromformer.py +2 -8
  911. transformers/models/nystromformer/modeling_nystromformer.py +63 -74
  912. transformers/models/olmo/configuration_olmo.py +28 -23
  913. transformers/models/olmo/modeling_olmo.py +39 -36
  914. transformers/models/olmo/modular_olmo.py +11 -7
  915. transformers/models/olmo2/configuration_olmo2.py +28 -23
  916. transformers/models/olmo2/modeling_olmo2.py +41 -37
  917. transformers/models/olmo2/modular_olmo2.py +32 -29
  918. transformers/models/olmo3/__init__.py +1 -0
  919. transformers/models/olmo3/configuration_olmo3.py +30 -26
  920. transformers/models/olmo3/modeling_olmo3.py +39 -36
  921. transformers/models/olmo3/modular_olmo3.py +40 -37
  922. transformers/models/olmoe/configuration_olmoe.py +33 -29
  923. transformers/models/olmoe/modeling_olmoe.py +46 -52
  924. transformers/models/olmoe/modular_olmoe.py +15 -16
  925. transformers/models/omdet_turbo/configuration_omdet_turbo.py +4 -2
  926. transformers/models/omdet_turbo/modeling_omdet_turbo.py +47 -53
  927. transformers/models/omdet_turbo/processing_omdet_turbo.py +67 -19
  928. transformers/models/oneformer/configuration_oneformer.py +8 -5
  929. transformers/models/oneformer/image_processing_oneformer.py +84 -83
  930. transformers/models/oneformer/image_processing_oneformer_fast.py +42 -41
  931. transformers/models/oneformer/modeling_oneformer.py +171 -147
  932. transformers/models/oneformer/processing_oneformer.py +43 -28
  933. transformers/models/openai/configuration_openai.py +1 -16
  934. transformers/models/openai/modeling_openai.py +51 -65
  935. transformers/models/openai/tokenization_openai.py +47 -8
  936. transformers/models/opt/configuration_opt.py +7 -6
  937. transformers/models/opt/modeling_opt.py +76 -78
  938. transformers/models/ovis2/__init__.py +1 -0
  939. transformers/models/ovis2/configuration_ovis2.py +1 -0
  940. transformers/models/ovis2/image_processing_ovis2.py +24 -22
  941. transformers/models/ovis2/image_processing_ovis2_fast.py +11 -9
  942. transformers/models/ovis2/modeling_ovis2.py +142 -111
  943. transformers/models/ovis2/modular_ovis2.py +45 -90
  944. transformers/models/ovis2/processing_ovis2.py +40 -12
  945. transformers/models/owlv2/configuration_owlv2.py +2 -4
  946. transformers/models/owlv2/image_processing_owlv2.py +21 -20
  947. transformers/models/owlv2/image_processing_owlv2_fast.py +15 -12
  948. transformers/models/owlv2/modeling_owlv2.py +117 -133
  949. transformers/models/owlv2/modular_owlv2.py +14 -11
  950. transformers/models/owlv2/processing_owlv2.py +49 -20
  951. transformers/models/owlvit/configuration_owlvit.py +2 -4
  952. transformers/models/owlvit/image_processing_owlvit.py +22 -21
  953. transformers/models/owlvit/image_processing_owlvit_fast.py +3 -2
  954. transformers/models/owlvit/modeling_owlvit.py +116 -132
  955. transformers/models/owlvit/processing_owlvit.py +48 -20
  956. transformers/models/paligemma/configuration_paligemma.py +1 -4
  957. transformers/models/paligemma/modeling_paligemma.py +93 -103
  958. transformers/models/paligemma/processing_paligemma.py +66 -13
  959. transformers/models/parakeet/configuration_parakeet.py +14 -7
  960. transformers/models/parakeet/feature_extraction_parakeet.py +12 -10
  961. transformers/models/parakeet/modeling_parakeet.py +28 -32
  962. transformers/models/parakeet/modular_parakeet.py +20 -23
  963. transformers/models/parakeet/processing_parakeet.py +5 -13
  964. transformers/models/parakeet/{tokenization_parakeet.py → tokenization_parakeet_fast.py} +7 -5
  965. transformers/models/patchtsmixer/configuration_patchtsmixer.py +8 -5
  966. transformers/models/patchtsmixer/modeling_patchtsmixer.py +62 -70
  967. transformers/models/patchtst/configuration_patchtst.py +9 -6
  968. transformers/models/patchtst/modeling_patchtst.py +80 -97
  969. transformers/models/pegasus/configuration_pegasus.py +5 -8
  970. transformers/models/pegasus/modeling_pegasus.py +66 -72
  971. transformers/models/pegasus/tokenization_pegasus.py +45 -15
  972. transformers/models/pegasus_x/configuration_pegasus_x.py +4 -5
  973. transformers/models/pegasus_x/modeling_pegasus_x.py +52 -55
  974. transformers/models/perceiver/configuration_perceiver.py +1 -0
  975. transformers/models/perceiver/image_processing_perceiver.py +25 -22
  976. transformers/models/perceiver/image_processing_perceiver_fast.py +9 -7
  977. transformers/models/perceiver/modeling_perceiver.py +146 -165
  978. transformers/models/perceiver/tokenization_perceiver.py +6 -3
  979. transformers/models/perception_lm/configuration_perception_lm.py +1 -0
  980. transformers/models/perception_lm/image_processing_perception_lm_fast.py +10 -8
  981. transformers/models/perception_lm/modeling_perception_lm.py +70 -71
  982. transformers/models/perception_lm/modular_perception_lm.py +61 -65
  983. transformers/models/perception_lm/processing_perception_lm.py +47 -13
  984. transformers/models/perception_lm/video_processing_perception_lm.py +1 -0
  985. transformers/models/persimmon/configuration_persimmon.py +28 -23
  986. transformers/models/persimmon/modeling_persimmon.py +45 -43
  987. transformers/models/phi/configuration_phi.py +28 -23
  988. transformers/models/phi/modeling_phi.py +43 -40
  989. transformers/models/phi/modular_phi.py +24 -23
  990. transformers/models/phi3/configuration_phi3.py +33 -28
  991. transformers/models/phi3/modeling_phi3.py +38 -36
  992. transformers/models/phi3/modular_phi3.py +17 -13
  993. transformers/models/phi4_multimodal/configuration_phi4_multimodal.py +33 -30
  994. transformers/models/phi4_multimodal/feature_extraction_phi4_multimodal.py +9 -7
  995. transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py +11 -11
  996. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +78 -95
  997. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +80 -98
  998. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +44 -7
  999. transformers/models/phimoe/configuration_phimoe.py +36 -31
  1000. transformers/models/phimoe/modeling_phimoe.py +45 -50
  1001. transformers/models/phimoe/modular_phimoe.py +4 -3
  1002. transformers/models/phobert/tokenization_phobert.py +6 -4
  1003. transformers/models/pix2struct/configuration_pix2struct.py +10 -12
  1004. transformers/models/pix2struct/image_processing_pix2struct.py +19 -15
  1005. transformers/models/pix2struct/image_processing_pix2struct_fast.py +15 -12
  1006. transformers/models/pix2struct/modeling_pix2struct.py +52 -58
  1007. transformers/models/pix2struct/processing_pix2struct.py +30 -5
  1008. transformers/models/pixtral/configuration_pixtral.py +14 -11
  1009. transformers/models/pixtral/image_processing_pixtral.py +28 -26
  1010. transformers/models/pixtral/image_processing_pixtral_fast.py +11 -10
  1011. transformers/models/pixtral/modeling_pixtral.py +34 -28
  1012. transformers/models/pixtral/processing_pixtral.py +53 -21
  1013. transformers/models/plbart/configuration_plbart.py +5 -8
  1014. transformers/models/plbart/modeling_plbart.py +106 -119
  1015. transformers/models/plbart/modular_plbart.py +33 -39
  1016. transformers/models/plbart/tokenization_plbart.py +7 -4
  1017. transformers/models/poolformer/configuration_poolformer.py +1 -0
  1018. transformers/models/poolformer/image_processing_poolformer.py +24 -21
  1019. transformers/models/poolformer/image_processing_poolformer_fast.py +15 -13
  1020. transformers/models/poolformer/modeling_poolformer.py +13 -23
  1021. transformers/models/pop2piano/configuration_pop2piano.py +8 -7
  1022. transformers/models/pop2piano/feature_extraction_pop2piano.py +9 -6
  1023. transformers/models/pop2piano/modeling_pop2piano.py +24 -26
  1024. transformers/models/pop2piano/processing_pop2piano.py +33 -25
  1025. transformers/models/pop2piano/tokenization_pop2piano.py +23 -15
  1026. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +3 -3
  1027. transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py +28 -28
  1028. transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything_fast.py +21 -20
  1029. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +13 -16
  1030. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +13 -16
  1031. transformers/models/prophetnet/configuration_prophetnet.py +38 -37
  1032. transformers/models/prophetnet/modeling_prophetnet.py +131 -114
  1033. transformers/models/prophetnet/tokenization_prophetnet.py +16 -14
  1034. transformers/models/pvt/configuration_pvt.py +1 -0
  1035. transformers/models/pvt/image_processing_pvt.py +27 -24
  1036. transformers/models/pvt/image_processing_pvt_fast.py +2 -1
  1037. transformers/models/pvt/modeling_pvt.py +21 -21
  1038. transformers/models/pvt_v2/configuration_pvt_v2.py +4 -2
  1039. transformers/models/pvt_v2/modeling_pvt_v2.py +25 -28
  1040. transformers/models/qwen2/configuration_qwen2.py +25 -32
  1041. transformers/models/qwen2/modeling_qwen2.py +38 -36
  1042. transformers/models/qwen2/modular_qwen2.py +12 -11
  1043. transformers/models/qwen2/tokenization_qwen2.py +23 -12
  1044. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +26 -32
  1045. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +277 -340
  1046. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +211 -278
  1047. transformers/models/qwen2_5_omni/processing_qwen2_5_omni.py +49 -41
  1048. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +35 -29
  1049. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +148 -203
  1050. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +118 -93
  1051. transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py +43 -7
  1052. transformers/models/qwen2_audio/configuration_qwen2_audio.py +1 -0
  1053. transformers/models/qwen2_audio/modeling_qwen2_audio.py +40 -40
  1054. transformers/models/qwen2_audio/processing_qwen2_audio.py +42 -13
  1055. transformers/models/qwen2_moe/configuration_qwen2_moe.py +35 -42
  1056. transformers/models/qwen2_moe/modeling_qwen2_moe.py +46 -51
  1057. transformers/models/qwen2_moe/modular_qwen2_moe.py +10 -7
  1058. transformers/models/qwen2_vl/configuration_qwen2_vl.py +34 -29
  1059. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +42 -41
  1060. transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py +15 -12
  1061. transformers/models/qwen2_vl/modeling_qwen2_vl.py +153 -199
  1062. transformers/models/qwen2_vl/processing_qwen2_vl.py +44 -7
  1063. transformers/models/qwen2_vl/video_processing_qwen2_vl.py +18 -38
  1064. transformers/models/qwen3/configuration_qwen3.py +27 -34
  1065. transformers/models/qwen3/modeling_qwen3.py +39 -36
  1066. transformers/models/qwen3/modular_qwen3.py +6 -4
  1067. transformers/models/qwen3_moe/configuration_qwen3_moe.py +32 -39
  1068. transformers/models/qwen3_moe/modeling_qwen3_moe.py +46 -51
  1069. transformers/models/qwen3_moe/modular_qwen3_moe.py +13 -10
  1070. transformers/models/qwen3_next/configuration_qwen3_next.py +35 -45
  1071. transformers/models/qwen3_next/modeling_qwen3_next.py +51 -47
  1072. transformers/models/qwen3_next/modular_qwen3_next.py +35 -34
  1073. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +101 -135
  1074. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +252 -355
  1075. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +196 -250
  1076. transformers/models/qwen3_omni_moe/processing_qwen3_omni_moe.py +48 -40
  1077. transformers/models/qwen3_vl/configuration_qwen3_vl.py +29 -27
  1078. transformers/models/qwen3_vl/modeling_qwen3_vl.py +155 -233
  1079. transformers/models/qwen3_vl/modular_qwen3_vl.py +179 -206
  1080. transformers/models/qwen3_vl/processing_qwen3_vl.py +42 -6
  1081. transformers/models/qwen3_vl/video_processing_qwen3_vl.py +12 -10
  1082. transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py +30 -23
  1083. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +303 -358
  1084. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +124 -87
  1085. transformers/models/rag/configuration_rag.py +15 -6
  1086. transformers/models/rag/modeling_rag.py +130 -127
  1087. transformers/models/rag/retrieval_rag.py +5 -3
  1088. transformers/models/rag/tokenization_rag.py +50 -0
  1089. transformers/models/recurrent_gemma/configuration_recurrent_gemma.py +30 -29
  1090. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +42 -53
  1091. transformers/models/reformer/configuration_reformer.py +8 -7
  1092. transformers/models/reformer/modeling_reformer.py +69 -80
  1093. transformers/models/reformer/tokenization_reformer.py +31 -11
  1094. transformers/models/regnet/configuration_regnet.py +1 -0
  1095. transformers/models/regnet/modeling_regnet.py +8 -15
  1096. transformers/models/rembert/configuration_rembert.py +2 -8
  1097. transformers/models/rembert/modeling_rembert.py +111 -121
  1098. transformers/models/rembert/tokenization_rembert.py +12 -2
  1099. transformers/models/resnet/configuration_resnet.py +1 -0
  1100. transformers/models/resnet/modeling_resnet.py +13 -27
  1101. transformers/models/roberta/configuration_roberta.py +3 -11
  1102. transformers/models/roberta/modeling_roberta.py +93 -94
  1103. transformers/models/roberta/modular_roberta.py +58 -58
  1104. transformers/models/roberta/tokenization_roberta.py +29 -17
  1105. transformers/models/roberta/tokenization_roberta_old.py +4 -2
  1106. transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py +3 -11
  1107. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +93 -94
  1108. transformers/models/roc_bert/configuration_roc_bert.py +2 -8
  1109. transformers/models/roc_bert/modeling_roc_bert.py +121 -122
  1110. transformers/models/roc_bert/tokenization_roc_bert.py +94 -88
  1111. transformers/models/roformer/configuration_roformer.py +3 -13
  1112. transformers/models/roformer/modeling_roformer.py +81 -85
  1113. transformers/models/roformer/tokenization_roformer.py +412 -74
  1114. transformers/models/roformer/tokenization_roformer_fast.py +160 -0
  1115. transformers/models/roformer/tokenization_utils.py +1 -0
  1116. transformers/models/rt_detr/configuration_rt_detr.py +2 -1
  1117. transformers/models/rt_detr/configuration_rt_detr_resnet.py +1 -0
  1118. transformers/models/rt_detr/image_processing_rt_detr.py +55 -54
  1119. transformers/models/rt_detr/image_processing_rt_detr_fast.py +26 -26
  1120. transformers/models/rt_detr/modeling_rt_detr.py +90 -99
  1121. transformers/models/rt_detr/modeling_rt_detr_resnet.py +6 -13
  1122. transformers/models/rt_detr/modular_rt_detr.py +16 -16
  1123. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +4 -6
  1124. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +90 -101
  1125. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +12 -19
  1126. transformers/models/rwkv/configuration_rwkv.py +4 -2
  1127. transformers/models/rwkv/modeling_rwkv.py +32 -31
  1128. transformers/models/sam/configuration_sam.py +1 -3
  1129. transformers/models/sam/image_processing_sam.py +60 -59
  1130. transformers/models/sam/image_processing_sam_fast.py +27 -25
  1131. transformers/models/sam/modeling_sam.py +41 -47
  1132. transformers/models/sam/processing_sam.py +27 -39
  1133. transformers/models/sam2/configuration_sam2.py +3 -2
  1134. transformers/models/sam2/image_processing_sam2_fast.py +15 -14
  1135. transformers/models/sam2/modeling_sam2.py +90 -96
  1136. transformers/models/sam2/modular_sam2.py +91 -86
  1137. transformers/models/sam2/processing_sam2.py +47 -31
  1138. transformers/models/sam2_video/configuration_sam2_video.py +1 -0
  1139. transformers/models/sam2_video/modeling_sam2_video.py +144 -151
  1140. transformers/models/sam2_video/modular_sam2_video.py +104 -101
  1141. transformers/models/sam2_video/processing_sam2_video.py +66 -49
  1142. transformers/models/sam2_video/video_processing_sam2_video.py +4 -1
  1143. transformers/models/sam3/configuration_sam3.py +2 -21
  1144. transformers/models/sam3/image_processing_sam3_fast.py +20 -17
  1145. transformers/models/sam3/modeling_sam3.py +170 -184
  1146. transformers/models/sam3/modular_sam3.py +8 -3
  1147. transformers/models/sam3/processing_sam3.py +52 -37
  1148. transformers/models/sam3_tracker/__init__.py +1 -0
  1149. transformers/models/sam3_tracker/configuration_sam3_tracker.py +3 -1
  1150. transformers/models/sam3_tracker/modeling_sam3_tracker.py +77 -82
  1151. transformers/models/sam3_tracker/modular_sam3_tracker.py +3 -8
  1152. transformers/models/sam3_tracker/processing_sam3_tracker.py +48 -31
  1153. transformers/models/sam3_tracker_video/__init__.py +1 -0
  1154. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +1 -25
  1155. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +122 -135
  1156. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +26 -35
  1157. transformers/models/sam3_tracker_video/processing_sam3_tracker_video.py +66 -50
  1158. transformers/models/sam3_video/configuration_sam3_video.py +1 -14
  1159. transformers/models/sam3_video/modeling_sam3_video.py +34 -33
  1160. transformers/models/sam3_video/processing_sam3_video.py +46 -26
  1161. transformers/models/sam_hq/__init__.py +1 -1
  1162. transformers/models/sam_hq/configuration_sam_hq.py +1 -3
  1163. transformers/models/sam_hq/modeling_sam_hq.py +69 -74
  1164. transformers/models/sam_hq/modular_sam_hq.py +25 -23
  1165. transformers/models/sam_hq/{processing_sam_hq.py → processing_samhq.py} +29 -41
  1166. transformers/models/seamless_m4t/configuration_seamless_m4t.py +10 -8
  1167. transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py +11 -8
  1168. transformers/models/seamless_m4t/modeling_seamless_m4t.py +194 -212
  1169. transformers/models/seamless_m4t/processing_seamless_m4t.py +39 -18
  1170. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +77 -40
  1171. transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py +10 -8
  1172. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +196 -204
  1173. transformers/models/seed_oss/configuration_seed_oss.py +32 -28
  1174. transformers/models/seed_oss/modeling_seed_oss.py +35 -33
  1175. transformers/models/seed_oss/modular_seed_oss.py +4 -3
  1176. transformers/models/segformer/configuration_segformer.py +10 -0
  1177. transformers/models/segformer/image_processing_segformer.py +42 -39
  1178. transformers/models/segformer/image_processing_segformer_fast.py +12 -10
  1179. transformers/models/segformer/modeling_segformer.py +31 -34
  1180. transformers/models/segformer/modular_segformer.py +10 -8
  1181. transformers/models/seggpt/configuration_seggpt.py +1 -0
  1182. transformers/models/seggpt/image_processing_seggpt.py +41 -38
  1183. transformers/models/seggpt/modeling_seggpt.py +38 -50
  1184. transformers/models/sew/configuration_sew.py +2 -4
  1185. transformers/models/sew/modeling_sew.py +36 -38
  1186. transformers/models/sew/modular_sew.py +13 -13
  1187. transformers/models/sew_d/configuration_sew_d.py +2 -4
  1188. transformers/models/sew_d/modeling_sew_d.py +30 -31
  1189. transformers/models/shieldgemma2/configuration_shieldgemma2.py +1 -0
  1190. transformers/models/shieldgemma2/modeling_shieldgemma2.py +17 -16
  1191. transformers/models/shieldgemma2/processing_shieldgemma2.py +5 -3
  1192. transformers/models/siglip/configuration_siglip.py +2 -4
  1193. transformers/models/siglip/image_processing_siglip.py +20 -17
  1194. transformers/models/siglip/image_processing_siglip_fast.py +1 -0
  1195. transformers/models/siglip/modeling_siglip.py +75 -84
  1196. transformers/models/siglip/processing_siglip.py +14 -2
  1197. transformers/models/siglip/tokenization_siglip.py +7 -6
  1198. transformers/models/siglip2/configuration_siglip2.py +2 -5
  1199. transformers/models/siglip2/image_processing_siglip2.py +16 -15
  1200. transformers/models/siglip2/image_processing_siglip2_fast.py +7 -6
  1201. transformers/models/siglip2/modeling_siglip2.py +129 -143
  1202. transformers/models/siglip2/modular_siglip2.py +46 -47
  1203. transformers/models/siglip2/processing_siglip2.py +14 -2
  1204. transformers/models/smollm3/configuration_smollm3.py +32 -29
  1205. transformers/models/smollm3/modeling_smollm3.py +39 -36
  1206. transformers/models/smollm3/modular_smollm3.py +35 -33
  1207. transformers/models/smolvlm/configuration_smolvlm.py +4 -2
  1208. transformers/models/smolvlm/image_processing_smolvlm.py +43 -42
  1209. transformers/models/smolvlm/image_processing_smolvlm_fast.py +15 -41
  1210. transformers/models/smolvlm/modeling_smolvlm.py +94 -126
  1211. transformers/models/smolvlm/modular_smolvlm.py +39 -50
  1212. transformers/models/smolvlm/processing_smolvlm.py +83 -15
  1213. transformers/models/smolvlm/video_processing_smolvlm.py +18 -16
  1214. transformers/models/speech_encoder_decoder/configuration_speech_encoder_decoder.py +1 -0
  1215. transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +27 -26
  1216. transformers/models/speech_to_text/configuration_speech_to_text.py +9 -9
  1217. transformers/models/speech_to_text/feature_extraction_speech_to_text.py +13 -10
  1218. transformers/models/speech_to_text/modeling_speech_to_text.py +54 -66
  1219. transformers/models/speech_to_text/processing_speech_to_text.py +30 -4
  1220. transformers/models/speech_to_text/tokenization_speech_to_text.py +6 -5
  1221. transformers/models/speecht5/configuration_speecht5.py +9 -7
  1222. transformers/models/speecht5/feature_extraction_speecht5.py +37 -16
  1223. transformers/models/speecht5/modeling_speecht5.py +175 -213
  1224. transformers/models/speecht5/number_normalizer.py +1 -0
  1225. transformers/models/speecht5/processing_speecht5.py +37 -3
  1226. transformers/models/speecht5/tokenization_speecht5.py +5 -4
  1227. transformers/models/splinter/configuration_splinter.py +7 -6
  1228. transformers/models/splinter/modeling_splinter.py +59 -71
  1229. transformers/models/splinter/tokenization_splinter.py +30 -9
  1230. transformers/models/squeezebert/configuration_squeezebert.py +2 -14
  1231. transformers/models/squeezebert/modeling_squeezebert.py +62 -68
  1232. transformers/models/squeezebert/tokenization_squeezebert.py +1 -0
  1233. transformers/models/stablelm/configuration_stablelm.py +29 -24
  1234. transformers/models/stablelm/modeling_stablelm.py +45 -44
  1235. transformers/models/starcoder2/configuration_starcoder2.py +27 -30
  1236. transformers/models/starcoder2/modeling_starcoder2.py +41 -39
  1237. transformers/models/starcoder2/modular_starcoder2.py +16 -14
  1238. transformers/models/superglue/configuration_superglue.py +3 -7
  1239. transformers/models/superglue/image_processing_superglue.py +15 -15
  1240. transformers/models/superglue/image_processing_superglue_fast.py +10 -9
  1241. transformers/models/superglue/modeling_superglue.py +37 -42
  1242. transformers/models/superpoint/image_processing_superpoint.py +15 -15
  1243. transformers/models/superpoint/image_processing_superpoint_fast.py +11 -8
  1244. transformers/models/superpoint/modeling_superpoint.py +16 -18
  1245. transformers/models/swiftformer/configuration_swiftformer.py +1 -0
  1246. transformers/models/swiftformer/modeling_swiftformer.py +14 -18
  1247. transformers/models/swin/configuration_swin.py +1 -0
  1248. transformers/models/swin/modeling_swin.py +86 -86
  1249. transformers/models/swin2sr/configuration_swin2sr.py +1 -0
  1250. transformers/models/swin2sr/image_processing_swin2sr.py +13 -10
  1251. transformers/models/swin2sr/image_processing_swin2sr_fast.py +8 -4
  1252. transformers/models/swin2sr/modeling_swin2sr.py +63 -81
  1253. transformers/models/swinv2/configuration_swinv2.py +1 -0
  1254. transformers/models/swinv2/modeling_swinv2.py +104 -108
  1255. transformers/models/switch_transformers/configuration_switch_transformers.py +7 -11
  1256. transformers/models/switch_transformers/modeling_switch_transformers.py +44 -37
  1257. transformers/models/switch_transformers/modular_switch_transformers.py +41 -34
  1258. transformers/models/t5/configuration_t5.py +8 -14
  1259. transformers/models/t5/modeling_t5.py +92 -88
  1260. transformers/models/t5/tokenization_t5.py +9 -3
  1261. transformers/models/t5gemma/configuration_t5gemma.py +41 -43
  1262. transformers/models/t5gemma/modeling_t5gemma.py +107 -104
  1263. transformers/models/t5gemma/modular_t5gemma.py +120 -124
  1264. transformers/models/t5gemma2/configuration_t5gemma2.py +120 -80
  1265. transformers/models/t5gemma2/modeling_t5gemma2.py +125 -141
  1266. transformers/models/t5gemma2/modular_t5gemma2.py +104 -393
  1267. transformers/models/table_transformer/configuration_table_transformer.py +2 -1
  1268. transformers/models/table_transformer/modeling_table_transformer.py +49 -51
  1269. transformers/models/tapas/configuration_tapas.py +2 -12
  1270. transformers/models/tapas/modeling_tapas.py +67 -68
  1271. transformers/models/tapas/tokenization_tapas.py +153 -115
  1272. transformers/models/textnet/configuration_textnet.py +1 -0
  1273. transformers/models/textnet/image_processing_textnet.py +25 -22
  1274. transformers/models/textnet/image_processing_textnet_fast.py +10 -8
  1275. transformers/models/textnet/modeling_textnet.py +16 -28
  1276. transformers/models/time_series_transformer/configuration_time_series_transformer.py +8 -5
  1277. transformers/models/time_series_transformer/modeling_time_series_transformer.py +81 -83
  1278. transformers/models/timesfm/configuration_timesfm.py +1 -0
  1279. transformers/models/timesfm/modeling_timesfm.py +22 -33
  1280. transformers/models/timesfm/modular_timesfm.py +21 -32
  1281. transformers/models/timesformer/configuration_timesformer.py +1 -0
  1282. transformers/models/timesformer/modeling_timesformer.py +16 -15
  1283. transformers/models/timm_backbone/configuration_timm_backbone.py +1 -0
  1284. transformers/models/timm_backbone/modeling_timm_backbone.py +15 -17
  1285. transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -5
  1286. transformers/models/timm_wrapper/image_processing_timm_wrapper.py +5 -4
  1287. transformers/models/timm_wrapper/modeling_timm_wrapper.py +29 -34
  1288. transformers/models/trocr/configuration_trocr.py +8 -11
  1289. transformers/models/trocr/modeling_trocr.py +44 -45
  1290. transformers/models/trocr/processing_trocr.py +25 -5
  1291. transformers/models/tvp/configuration_tvp.py +2 -5
  1292. transformers/models/tvp/image_processing_tvp.py +52 -50
  1293. transformers/models/tvp/image_processing_tvp_fast.py +15 -15
  1294. transformers/models/tvp/modeling_tvp.py +27 -27
  1295. transformers/models/tvp/processing_tvp.py +14 -2
  1296. transformers/models/udop/configuration_udop.py +7 -16
  1297. transformers/models/udop/modeling_udop.py +73 -71
  1298. transformers/models/udop/processing_udop.py +26 -7
  1299. transformers/models/udop/tokenization_udop.py +105 -84
  1300. transformers/models/umt5/configuration_umt5.py +7 -8
  1301. transformers/models/umt5/modeling_umt5.py +90 -94
  1302. transformers/models/unispeech/configuration_unispeech.py +2 -4
  1303. transformers/models/unispeech/modeling_unispeech.py +49 -51
  1304. transformers/models/unispeech/modular_unispeech.py +22 -22
  1305. transformers/models/unispeech_sat/configuration_unispeech_sat.py +2 -4
  1306. transformers/models/unispeech_sat/modeling_unispeech_sat.py +65 -69
  1307. transformers/models/unispeech_sat/modular_unispeech_sat.py +23 -23
  1308. transformers/models/univnet/feature_extraction_univnet.py +14 -14
  1309. transformers/models/univnet/modeling_univnet.py +8 -8
  1310. transformers/models/upernet/configuration_upernet.py +1 -0
  1311. transformers/models/upernet/modeling_upernet.py +13 -11
  1312. transformers/models/vaultgemma/__init__.py +1 -0
  1313. transformers/models/vaultgemma/configuration_vaultgemma.py +33 -29
  1314. transformers/models/vaultgemma/modeling_vaultgemma.py +41 -39
  1315. transformers/models/vaultgemma/modular_vaultgemma.py +31 -29
  1316. transformers/models/video_llama_3/configuration_video_llama_3.py +0 -4
  1317. transformers/models/video_llama_3/image_processing_video_llama_3.py +42 -43
  1318. transformers/models/video_llama_3/image_processing_video_llama_3_fast.py +14 -12
  1319. transformers/models/video_llama_3/modeling_video_llama_3.py +109 -157
  1320. transformers/models/video_llama_3/modular_video_llama_3.py +146 -155
  1321. transformers/models/video_llama_3/processing_video_llama_3.py +39 -5
  1322. transformers/models/video_llama_3/video_processing_video_llama_3.py +23 -42
  1323. transformers/models/video_llava/configuration_video_llava.py +1 -4
  1324. transformers/models/video_llava/image_processing_video_llava.py +38 -35
  1325. transformers/models/video_llava/modeling_video_llava.py +146 -146
  1326. transformers/models/video_llava/processing_video_llava.py +78 -38
  1327. transformers/models/video_llava/video_processing_video_llava.py +1 -0
  1328. transformers/models/videomae/configuration_videomae.py +1 -0
  1329. transformers/models/videomae/image_processing_videomae.py +34 -31
  1330. transformers/models/videomae/modeling_videomae.py +17 -14
  1331. transformers/models/videomae/video_processing_videomae.py +1 -0
  1332. transformers/models/vilt/configuration_vilt.py +4 -6
  1333. transformers/models/vilt/image_processing_vilt.py +30 -29
  1334. transformers/models/vilt/image_processing_vilt_fast.py +16 -15
  1335. transformers/models/vilt/modeling_vilt.py +90 -116
  1336. transformers/models/vilt/processing_vilt.py +14 -2
  1337. transformers/models/vipllava/configuration_vipllava.py +1 -4
  1338. transformers/models/vipllava/modeling_vipllava.py +70 -99
  1339. transformers/models/vipllava/modular_vipllava.py +54 -78
  1340. transformers/models/vision_encoder_decoder/configuration_vision_encoder_decoder.py +1 -0
  1341. transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +27 -28
  1342. transformers/models/vision_text_dual_encoder/configuration_vision_text_dual_encoder.py +1 -0
  1343. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +41 -46
  1344. transformers/models/vision_text_dual_encoder/processing_vision_text_dual_encoder.py +16 -2
  1345. transformers/models/visual_bert/configuration_visual_bert.py +2 -6
  1346. transformers/models/visual_bert/modeling_visual_bert.py +92 -98
  1347. transformers/models/vit/configuration_vit.py +1 -0
  1348. transformers/models/vit/image_processing_vit.py +22 -19
  1349. transformers/models/vit/image_processing_vit_fast.py +1 -0
  1350. transformers/models/vit/modeling_vit.py +17 -17
  1351. transformers/models/vit_mae/configuration_vit_mae.py +1 -0
  1352. transformers/models/vit_mae/modeling_vit_mae.py +27 -29
  1353. transformers/models/vit_msn/configuration_vit_msn.py +1 -0
  1354. transformers/models/vit_msn/modeling_vit_msn.py +16 -18
  1355. transformers/models/vitdet/configuration_vitdet.py +1 -0
  1356. transformers/models/vitdet/modeling_vitdet.py +14 -14
  1357. transformers/models/vitmatte/configuration_vitmatte.py +5 -2
  1358. transformers/models/vitmatte/image_processing_vitmatte.py +18 -15
  1359. transformers/models/vitmatte/image_processing_vitmatte_fast.py +18 -16
  1360. transformers/models/vitmatte/modeling_vitmatte.py +11 -14
  1361. transformers/models/vitpose/configuration_vitpose.py +7 -4
  1362. transformers/models/vitpose/image_processing_vitpose.py +25 -24
  1363. transformers/models/vitpose/image_processing_vitpose_fast.py +11 -9
  1364. transformers/models/vitpose/modeling_vitpose.py +14 -14
  1365. transformers/models/vitpose_backbone/configuration_vitpose_backbone.py +1 -0
  1366. transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +10 -8
  1367. transformers/models/vits/configuration_vits.py +1 -4
  1368. transformers/models/vits/modeling_vits.py +42 -44
  1369. transformers/models/vits/tokenization_vits.py +4 -3
  1370. transformers/models/vivit/configuration_vivit.py +1 -0
  1371. transformers/models/vivit/image_processing_vivit.py +39 -36
  1372. transformers/models/vivit/modeling_vivit.py +8 -6
  1373. transformers/models/vjepa2/__init__.py +1 -0
  1374. transformers/models/vjepa2/configuration_vjepa2.py +1 -0
  1375. transformers/models/vjepa2/modeling_vjepa2.py +32 -31
  1376. transformers/models/vjepa2/video_processing_vjepa2.py +1 -0
  1377. transformers/models/voxtral/__init__.py +1 -0
  1378. transformers/models/voxtral/configuration_voxtral.py +2 -0
  1379. transformers/models/voxtral/modeling_voxtral.py +47 -40
  1380. transformers/models/voxtral/modular_voxtral.py +40 -37
  1381. transformers/models/voxtral/processing_voxtral.py +48 -25
  1382. transformers/models/wav2vec2/configuration_wav2vec2.py +2 -4
  1383. transformers/models/wav2vec2/feature_extraction_wav2vec2.py +10 -7
  1384. transformers/models/wav2vec2/modeling_wav2vec2.py +121 -73
  1385. transformers/models/wav2vec2/processing_wav2vec2.py +35 -6
  1386. transformers/models/wav2vec2/tokenization_wav2vec2.py +332 -20
  1387. transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py +2 -4
  1388. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +62 -70
  1389. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +48 -57
  1390. transformers/models/wav2vec2_bert/processing_wav2vec2_bert.py +35 -6
  1391. transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py +2 -4
  1392. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +77 -90
  1393. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +30 -37
  1394. transformers/models/wav2vec2_phoneme/tokenization_wav2vec2_phoneme.py +17 -16
  1395. transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py +55 -36
  1396. transformers/models/wavlm/configuration_wavlm.py +2 -4
  1397. transformers/models/wavlm/modeling_wavlm.py +48 -50
  1398. transformers/models/wavlm/modular_wavlm.py +5 -4
  1399. transformers/models/whisper/configuration_whisper.py +5 -6
  1400. transformers/models/whisper/english_normalizer.py +4 -3
  1401. transformers/models/whisper/feature_extraction_whisper.py +24 -9
  1402. transformers/models/whisper/generation_whisper.py +48 -26
  1403. transformers/models/whisper/modeling_whisper.py +73 -79
  1404. transformers/models/whisper/processing_whisper.py +20 -3
  1405. transformers/models/whisper/tokenization_whisper.py +43 -11
  1406. transformers/models/x_clip/configuration_x_clip.py +2 -4
  1407. transformers/models/x_clip/modeling_x_clip.py +93 -96
  1408. transformers/models/x_clip/processing_x_clip.py +14 -2
  1409. transformers/models/xcodec/configuration_xcodec.py +6 -4
  1410. transformers/models/xcodec/modeling_xcodec.py +17 -20
  1411. transformers/models/xglm/configuration_xglm.py +8 -9
  1412. transformers/models/xglm/modeling_xglm.py +55 -60
  1413. transformers/models/xglm/tokenization_xglm.py +11 -3
  1414. transformers/models/xlm/configuration_xlm.py +8 -10
  1415. transformers/models/xlm/modeling_xlm.py +144 -144
  1416. transformers/models/xlm/tokenization_xlm.py +5 -3
  1417. transformers/models/xlm_roberta/configuration_xlm_roberta.py +3 -11
  1418. transformers/models/xlm_roberta/modeling_xlm_roberta.py +194 -195
  1419. transformers/models/xlm_roberta/modular_xlm_roberta.py +53 -50
  1420. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +18 -8
  1421. transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py +2 -10
  1422. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +93 -94
  1423. transformers/models/xlm_roberta_xl/modular_xlm_roberta_xl.py +70 -67
  1424. transformers/models/xlnet/configuration_xlnet.py +12 -3
  1425. transformers/models/xlnet/modeling_xlnet.py +163 -152
  1426. transformers/models/xlnet/tokenization_xlnet.py +9 -2
  1427. transformers/models/xlstm/configuration_xlstm.py +12 -8
  1428. transformers/models/xlstm/modeling_xlstm.py +65 -62
  1429. transformers/models/xmod/configuration_xmod.py +3 -11
  1430. transformers/models/xmod/modeling_xmod.py +110 -108
  1431. transformers/models/yolos/configuration_yolos.py +1 -0
  1432. transformers/models/yolos/image_processing_yolos.py +62 -60
  1433. transformers/models/yolos/image_processing_yolos_fast.py +45 -42
  1434. transformers/models/yolos/modeling_yolos.py +16 -16
  1435. transformers/models/yolos/modular_yolos.py +19 -17
  1436. transformers/models/yoso/configuration_yoso.py +2 -8
  1437. transformers/models/yoso/modeling_yoso.py +63 -70
  1438. transformers/models/zamba/configuration_zamba.py +8 -5
  1439. transformers/models/zamba/modeling_zamba.py +78 -81
  1440. transformers/models/zamba2/configuration_zamba2.py +50 -44
  1441. transformers/models/zamba2/modeling_zamba2.py +97 -97
  1442. transformers/models/zamba2/modular_zamba2.py +48 -46
  1443. transformers/models/zoedepth/configuration_zoedepth.py +2 -1
  1444. transformers/models/zoedepth/image_processing_zoedepth.py +29 -28
  1445. transformers/models/zoedepth/image_processing_zoedepth_fast.py +24 -21
  1446. transformers/models/zoedepth/modeling_zoedepth.py +18 -26
  1447. transformers/pipelines/__init__.py +114 -57
  1448. transformers/pipelines/any_to_any.py +22 -14
  1449. transformers/pipelines/audio_utils.py +2 -1
  1450. transformers/pipelines/automatic_speech_recognition.py +12 -20
  1451. transformers/pipelines/base.py +27 -15
  1452. transformers/{models/pe_audio/processing_pe_audio.py → pipelines/deprecated/__init__.py} +3 -10
  1453. transformers/pipelines/deprecated/text2text_generation.py +408 -0
  1454. transformers/pipelines/document_question_answering.py +2 -4
  1455. transformers/pipelines/image_text_to_text.py +1 -0
  1456. transformers/pipelines/image_to_text.py +229 -0
  1457. transformers/pipelines/question_answering.py +44 -5
  1458. transformers/pipelines/text_classification.py +14 -1
  1459. transformers/pipelines/text_generation.py +1 -1
  1460. transformers/pipelines/text_to_audio.py +2 -2
  1461. transformers/pipelines/token_classification.py +22 -1
  1462. transformers/pipelines/video_classification.py +9 -1
  1463. transformers/pipelines/zero_shot_audio_classification.py +1 -0
  1464. transformers/pipelines/zero_shot_classification.py +6 -0
  1465. transformers/pipelines/zero_shot_image_classification.py +7 -0
  1466. transformers/processing_utils.py +145 -230
  1467. transformers/quantizers/auto.py +4 -2
  1468. transformers/quantizers/base.py +173 -53
  1469. transformers/quantizers/quantizer_aqlm.py +23 -2
  1470. transformers/quantizers/quantizer_auto_round.py +12 -2
  1471. transformers/quantizers/quantizer_awq.py +89 -20
  1472. transformers/quantizers/quantizer_bitnet.py +14 -4
  1473. transformers/quantizers/quantizer_bnb_4bit.py +155 -18
  1474. transformers/quantizers/quantizer_bnb_8bit.py +110 -24
  1475. transformers/quantizers/quantizer_compressed_tensors.py +9 -2
  1476. transformers/quantizers/quantizer_eetq.py +74 -16
  1477. transformers/quantizers/quantizer_fbgemm_fp8.py +138 -38
  1478. transformers/quantizers/quantizer_finegrained_fp8.py +113 -26
  1479. transformers/quantizers/quantizer_fp_quant.py +82 -52
  1480. transformers/quantizers/quantizer_gptq.py +28 -8
  1481. transformers/quantizers/quantizer_higgs.py +60 -42
  1482. transformers/quantizers/quantizer_hqq.py +153 -144
  1483. transformers/quantizers/quantizer_mxfp4.py +194 -14
  1484. transformers/quantizers/quantizer_quanto.py +79 -35
  1485. transformers/quantizers/quantizer_quark.py +18 -36
  1486. transformers/quantizers/quantizer_spqr.py +12 -4
  1487. transformers/quantizers/quantizer_torchao.py +325 -50
  1488. transformers/quantizers/quantizer_vptq.py +27 -4
  1489. transformers/quantizers/quantizers_utils.py +0 -20
  1490. transformers/safetensors_conversion.py +3 -9
  1491. transformers/testing_utils.py +82 -326
  1492. transformers/tokenization_mistral_common.py +903 -568
  1493. transformers/tokenization_utils_base.py +340 -220
  1494. transformers/tokenization_utils_sentencepiece.py +6 -5
  1495. transformers/tokenization_utils_tokenizers.py +113 -226
  1496. transformers/trainer.py +53 -60
  1497. transformers/trainer_callback.py +0 -8
  1498. transformers/trainer_seq2seq.py +1 -5
  1499. transformers/trainer_utils.py +1 -1
  1500. transformers/training_args.py +41 -77
  1501. transformers/utils/__init__.py +4 -8
  1502. transformers/utils/attention_visualizer.py +5 -5
  1503. transformers/utils/auto_docstring.py +37 -599
  1504. transformers/utils/doc.py +36 -4
  1505. transformers/utils/dummy_pt_objects.py +42 -0
  1506. transformers/utils/generic.py +28 -111
  1507. transformers/utils/hub.py +15 -5
  1508. transformers/utils/import_utils.py +32 -165
  1509. transformers/utils/kernel_config.py +19 -74
  1510. transformers/utils/loading_report.py +15 -25
  1511. transformers/utils/quantization_config.py +241 -72
  1512. transformers/video_processing_utils.py +39 -41
  1513. transformers/video_utils.py +22 -18
  1514. {transformers-5.0.0.dist-info → transformers-5.0.0rc0.dist-info}/METADATA +236 -284
  1515. transformers-5.0.0rc0.dist-info/RECORD +1987 -0
  1516. {transformers-5.0.0.dist-info → transformers-5.0.0rc0.dist-info}/WHEEL +1 -1
  1517. transformers/integrations/moe.py +0 -360
  1518. transformers/integrations/quark.py +0 -53
  1519. transformers/loss/loss_lw_detr.py +0 -356
  1520. transformers/models/ernie4_5_vl_moe/__init__.py +0 -31
  1521. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +0 -340
  1522. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +0 -455
  1523. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +0 -231
  1524. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +0 -1936
  1525. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +0 -1925
  1526. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +0 -249
  1527. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +0 -593
  1528. transformers/models/fast_vlm/__init__.py +0 -27
  1529. transformers/models/fast_vlm/configuration_fast_vlm.py +0 -137
  1530. transformers/models/fast_vlm/modeling_fast_vlm.py +0 -432
  1531. transformers/models/fast_vlm/modular_fast_vlm.py +0 -373
  1532. transformers/models/glm4_moe_lite/__init__.py +0 -28
  1533. transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py +0 -233
  1534. transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py +0 -740
  1535. transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py +0 -302
  1536. transformers/models/glm_image/__init__.py +0 -31
  1537. transformers/models/glm_image/configuration_glm_image.py +0 -351
  1538. transformers/models/glm_image/image_processing_glm_image.py +0 -503
  1539. transformers/models/glm_image/image_processing_glm_image_fast.py +0 -294
  1540. transformers/models/glm_image/modeling_glm_image.py +0 -1642
  1541. transformers/models/glm_image/modular_glm_image.py +0 -1531
  1542. transformers/models/glm_image/processing_glm_image.py +0 -217
  1543. transformers/models/glmasr/__init__.py +0 -29
  1544. transformers/models/glmasr/configuration_glmasr.py +0 -196
  1545. transformers/models/glmasr/modeling_glmasr.py +0 -517
  1546. transformers/models/glmasr/modular_glmasr.py +0 -443
  1547. transformers/models/glmasr/processing_glmasr.py +0 -331
  1548. transformers/models/jais2/__init__.py +0 -27
  1549. transformers/models/jais2/configuration_jais2.py +0 -148
  1550. transformers/models/jais2/modeling_jais2.py +0 -484
  1551. transformers/models/jais2/modular_jais2.py +0 -194
  1552. transformers/models/lasr/__init__.py +0 -29
  1553. transformers/models/lasr/configuration_lasr.py +0 -244
  1554. transformers/models/lasr/feature_extraction_lasr.py +0 -275
  1555. transformers/models/lasr/modeling_lasr.py +0 -727
  1556. transformers/models/lasr/modular_lasr.py +0 -574
  1557. transformers/models/lasr/processing_lasr.py +0 -100
  1558. transformers/models/lasr/tokenization_lasr.py +0 -184
  1559. transformers/models/lighton_ocr/__init__.py +0 -28
  1560. transformers/models/lighton_ocr/configuration_lighton_ocr.py +0 -128
  1561. transformers/models/lighton_ocr/modeling_lighton_ocr.py +0 -463
  1562. transformers/models/lighton_ocr/modular_lighton_ocr.py +0 -404
  1563. transformers/models/lighton_ocr/processing_lighton_ocr.py +0 -229
  1564. transformers/models/lw_detr/__init__.py +0 -27
  1565. transformers/models/lw_detr/configuration_lw_detr.py +0 -374
  1566. transformers/models/lw_detr/modeling_lw_detr.py +0 -1702
  1567. transformers/models/lw_detr/modular_lw_detr.py +0 -1615
  1568. transformers/models/minimax_m2/__init__.py +0 -28
  1569. transformers/models/minimax_m2/configuration_minimax_m2.py +0 -188
  1570. transformers/models/minimax_m2/modeling_minimax_m2.py +0 -704
  1571. transformers/models/minimax_m2/modular_minimax_m2.py +0 -346
  1572. transformers/models/paddleocr_vl/__init__.py +0 -31
  1573. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +0 -335
  1574. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +0 -503
  1575. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +0 -209
  1576. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +0 -1683
  1577. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +0 -1380
  1578. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +0 -133
  1579. transformers/models/pe_audio/__init__.py +0 -29
  1580. transformers/models/pe_audio/configuration_pe_audio.py +0 -204
  1581. transformers/models/pe_audio/feature_extraction_pe_audio.py +0 -160
  1582. transformers/models/pe_audio/modeling_pe_audio.py +0 -819
  1583. transformers/models/pe_audio/modular_pe_audio.py +0 -298
  1584. transformers/models/pe_audio_video/__init__.py +0 -28
  1585. transformers/models/pe_audio_video/configuration_pe_audio_video.py +0 -223
  1586. transformers/models/pe_audio_video/modeling_pe_audio_video.py +0 -971
  1587. transformers/models/pe_audio_video/modular_pe_audio_video.py +0 -763
  1588. transformers/models/pe_video/__init__.py +0 -29
  1589. transformers/models/pe_video/configuration_pe_video.py +0 -209
  1590. transformers/models/pe_video/modeling_pe_video.py +0 -647
  1591. transformers/models/pe_video/modular_pe_video.py +0 -231
  1592. transformers/models/pe_video/processing_pe_video.py +0 -10
  1593. transformers/models/pe_video/video_processing_pe_video.py +0 -64
  1594. transformers/models/pixio/__init__.py +0 -29
  1595. transformers/models/pixio/configuration_pixio.py +0 -150
  1596. transformers/models/pixio/modeling_pixio.py +0 -507
  1597. transformers/models/pixio/modular_pixio.py +0 -403
  1598. transformers/models/solar_open/__init__.py +0 -27
  1599. transformers/models/solar_open/configuration_solar_open.py +0 -184
  1600. transformers/models/solar_open/modeling_solar_open.py +0 -642
  1601. transformers/models/solar_open/modular_solar_open.py +0 -224
  1602. transformers/trainer_jit_checkpoint.py +0 -125
  1603. transformers-5.0.0.dist-info/RECORD +0 -2068
  1604. {transformers-5.0.0.dist-info/licenses → transformers-5.0.0rc0.dist-info}/LICENSE +0 -0
  1605. {transformers-5.0.0.dist-info → transformers-5.0.0rc0.dist-info}/entry_points.txt +0 -0
  1606. {transformers-5.0.0.dist-info → transformers-5.0.0rc0.dist-info}/top_level.txt +0 -0
transformers/models/lw_detr/modeling_lw_detr.py
@@ -1,1702 +0,0 @@
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
- # This file was automatically generated from src/transformers/models/lw_detr/modular_lw_detr.py.
- # Do NOT edit this file manually as any edits will be overwritten by the generation of
- # the file from the modular. If any change should be done, please apply the change to the
- # modular_lw_detr.py file directly. One of our CI enforces this.
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
- # Copyright 2026 The HuggingFace Inc. team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import collections.abc
- import math
- import warnings
- from collections.abc import Callable
- from dataclasses import dataclass
- from typing import Any
-
- import torch
- import torch.nn.functional as F
- from torch import Tensor, nn
-
- from ... import initialization as init
- from ...activations import ACT2CLS, ACT2FN
- from ...integrations import use_kernel_forward_from_hub
- from ...modeling_layers import GradientCheckpointingLayer
- from ...modeling_outputs import BackboneOutput
- from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
- from ...processing_utils import Unpack
- from ...pytorch_utils import meshgrid
- from ...utils import ModelOutput, TransformersKwargs, auto_docstring, torch_compilable_check
- from ...utils.backbone_utils import BackboneMixin
- from ...utils.generic import check_model_inputs
- from .configuration_lw_detr import LwDetrConfig, LwDetrViTConfig
-
-
- def eager_attention_forward(
-     module: nn.Module,
-     query: torch.Tensor,
-     key: torch.Tensor,
-     value: torch.Tensor,
-     attention_mask: torch.Tensor | None,
-     scaling: float,
-     dropout: float = 0.0,
-     **kwargs: Unpack[TransformersKwargs],
- ):
-     key_states = repeat_kv(key, module.num_key_value_groups)
-     value_states = repeat_kv(value, module.num_key_value_groups)
-
-     attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
-     if attention_mask is not None:
-         causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
-         attn_weights = attn_weights + causal_mask
-
-     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
-     attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
-     attn_output = torch.matmul(attn_weights, value_states)
-     attn_output = attn_output.transpose(1, 2).contiguous()
-
-     return attn_output, attn_weights
-
- def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
72
- """
73
- This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
74
- num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
75
- """
76
- batch, num_key_value_heads, slen, head_dim = hidden_states.shape
77
- if n_rep == 1:
78
- return hidden_states
79
- hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
80
- return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
81
-
82
-
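The docstring above claims equivalence with `torch.repeat_interleave`; that claim is easy to verify. A quick standalone check with assumed shapes:

import torch

x = torch.randn(2, 3, 5, 8)  # (batch, num_key_value_heads, seq, head_dim)
n_rep = 4

# Same expand + reshape as the deleted repeat_kv
expanded = x[:, :, None, :, :].expand(2, 3, n_rep, 5, 8).reshape(2, 3 * n_rep, 5, 8)
print(torch.equal(expanded, torch.repeat_interleave(x, repeats=n_rep, dim=1)))  # True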
-
- class LwDetrViTSelfAttention(nn.Module):
-     def __init__(self, config: LwDetrViTConfig):
-         super().__init__()
-         if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
-             raise ValueError(
-                 f"The hidden size {config.hidden_size} is not a multiple of the number of attention "
-                 f"heads {config.num_attention_heads}."
-             )
-
-         self.config = config
-         self.num_attention_heads = config.num_attention_heads
-         self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
-         self.all_head_size = self.num_attention_heads * self.attention_head_size
-         self.dropout_prob = config.dropout_prob
-         self.scaling = self.attention_head_size**-0.5
-         self.is_causal = False
-
-         self.query = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
-         self.value = nn.Linear(config.hidden_size, self.all_head_size, bias=config.qkv_bias)
-         self.key = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
-         self.num_key_value_groups = 1
-
-     def forward(
-         self,
-         hidden_states: torch.Tensor,
-         **kwargs: Unpack[TransformersKwargs],
-     ) -> tuple[torch.Tensor, torch.Tensor]:
-         batch_size = hidden_states.shape[0]
-         new_shape = batch_size, -1, self.num_attention_heads, self.attention_head_size
-
-         key_layer = self.key(hidden_states).view(*new_shape).transpose(1, 2)
-         value_layer = self.value(hidden_states).view(*new_shape).transpose(1, 2)
-         query_layer = self.query(hidden_states).view(*new_shape).transpose(1, 2)
-
-         attention_interface: Callable = eager_attention_forward
-         if self.config._attn_implementation != "eager":
-             attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
-
-         context_layer, attention_probs = attention_interface(
-             self,
-             query_layer,
-             key_layer,
-             value_layer,
-             None,
-             is_causal=self.is_causal,
-             scaling=self.scaling,
-             dropout=0.0 if not self.training else self.dropout_prob,
-             **kwargs,
-         )
-
-         new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
-         context_layer = context_layer.reshape(new_context_layer_shape)
-
-         return context_layer, attention_probs
-
139
- class LwDetrViTAttention(nn.Module):
-     def __init__(self, config: LwDetrViTConfig):
-         """
-         Args:
-             config (`LwDetrViTConfig`):
-                 Model configuration.
-         """
-         super().__init__()
-         self.attention = LwDetrViTSelfAttention(config)
-         self.output = nn.Linear(config.hidden_size, config.hidden_size)
-
-     def forward(
-         self,
-         hidden_states: torch.Tensor,
-         **kwargs: Unpack[TransformersKwargs],
-     ) -> torch.Tensor:
-         self_attn_output, _ = self.attention(hidden_states, **kwargs)
-         output = self.output(self_attn_output)
-         return output
-
-
- class LwDetrViTMlp(nn.Module):
-     def __init__(self, config, in_features: int, hidden_features: int) -> None:
-         super().__init__()
-         self.fc1 = nn.Linear(in_features, hidden_features)
-         self.act = ACT2FN[config.hidden_act]
-         self.fc2 = nn.Linear(hidden_features, in_features)
-         self.drop = nn.Dropout(config.dropout_prob)
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         x = self.fc1(x)
-         x = self.act(x)
-         x = self.drop(x)
-         x = self.fc2(x)
-         x = self.drop(x)
-
-         return x
-
-
- class LwDetrViTLayer(GradientCheckpointingLayer):
-     def __init__(
-         self,
-         config: LwDetrViTConfig,
-         layer_idx: int,
-     ) -> None:
-         super().__init__()
-
-         dim = config.hidden_size
-         self.attention = LwDetrViTAttention(config)
-         self.intermediate = LwDetrViTMlp(config=config, in_features=dim, hidden_features=int(dim * config.mlp_ratio))
-         self.layernorm_before = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-         self.layernorm_after = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
-
-         # Layer-scale parameters; their values are set in `_init_weights` from `config.cae_init_values`.
-         self.gamma_1 = nn.Parameter(torch.empty(dim), requires_grad=True)
-         self.gamma_2 = nn.Parameter(torch.empty(dim), requires_grad=True)
-
-         self.window = layer_idx in config.window_block_indices
-         self.num_windows = config.num_windows
-
-     def forward(
-         self,
-         hidden_states: torch.Tensor,
-         **kwargs: Unpack[TransformersKwargs],
-     ) -> torch.Tensor:
-         batch_size, seq_len, channels = hidden_states.shape
-         hidden_states_norm = self.layernorm_before(hidden_states)
-
-         # Inputs arrive window-partitioned; global (non-window) layers first merge the windows
-         # back into one sequence per image so attention spans the full feature map.
-         if not self.window:
-             hidden_states_norm = hidden_states_norm.reshape(
-                 batch_size // self.num_windows, self.num_windows * seq_len, channels
-             )
-
-         attention_output = self.attention(hidden_states_norm, **kwargs)
-         attention_output = attention_output * self.gamma_1
-
-         if not self.window:
-             attention_output = attention_output.reshape(batch_size, seq_len, channels)
-
-         hidden_states = hidden_states + attention_output
-
-         layer_output = self.layernorm_after(hidden_states)
-         layer_output = self.intermediate(layer_output)
-         layer_output = layer_output * self.gamma_2
-
-         hidden_states = hidden_states + layer_output
-
-         return hidden_states
-
-
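- # A shape-only sketch of the merge performed by global (non-window) layers above:
- # window-partitioned activations of shape (batch * num_windows, seq, channels) are viewed
- # as one long sequence per image. Sizes are arbitrary; `num_windows` is assumed to be 4.
- # >>> import torch
- # >>> num_windows, x = 4, torch.randn(8, 49, 192)   # 2 images, 4 windows each
- # >>> merged = x.reshape(8 // num_windows, num_windows * 49, 192)
- # >>> merged.shape
- # torch.Size([2, 196, 192])
- # >>> torch.equal(merged.reshape(8, 49, 192), x)    # the round trip is lossless
- # True
-
-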
- class LwDetrViTEncoder(nn.Module):
-     def __init__(self, config: LwDetrViTConfig) -> None:
-         super().__init__()
-         self.config = config
-         self.layer = nn.ModuleList([LwDetrViTLayer(config, i) for i in range(config.num_hidden_layers)])
-         self.gradient_checkpointing = False
-
-     def forward(
-         self,
-         hidden_states: torch.Tensor,
-         **kwargs: Unpack[TransformersKwargs],
-     ) -> list[torch.Tensor]:
-         list_hidden_states = [hidden_states]
-         for layer_module in self.layer:
-             hidden_states = layer_module(hidden_states, **kwargs)
-             list_hidden_states.append(hidden_states)
-         return list_hidden_states
-
-
- class LwDetrViTEmbeddings(nn.Module):
-     """
-     This class turns `pixel_values` of shape `(batch_size, num_channels, height, width)` into the initial
-     `hidden_states` (patch embeddings) to be consumed by a Transformer.
-     """
-
-     def __init__(self, config):
-         super().__init__()
-         image_size, patch_size = config.pretrain_image_size, config.patch_size
-         num_channels, hidden_size = config.num_channels, config.hidden_size
-
-         image_size = image_size if isinstance(image_size, collections.abc.Iterable) else (image_size, image_size)
-         patch_size = patch_size if isinstance(patch_size, collections.abc.Iterable) else (patch_size, patch_size)
-         num_patches = (image_size[1] // patch_size[1]) * (image_size[0] // patch_size[0])
-         self.image_size = image_size
-         self.patch_size = patch_size
-         self.num_channels = num_channels
-         self.num_patches = num_patches
-
-         if config.use_absolute_position_embeddings:
-             # Initialize absolute positional embedding with pretrain image size.
-             num_positions = num_patches + 1
-             self.position_embeddings = nn.Parameter(torch.zeros(1, num_positions, config.hidden_size))
-         else:
-             self.position_embeddings = None
-
-         self.projection = nn.Conv2d(num_channels, hidden_size, kernel_size=patch_size, stride=patch_size)
-
-     def get_absolute_positions(self, abs_pos_embeddings, has_cls_token, height, width):
-         """
-         Calculate absolute positional embeddings. If needed, resize the embeddings and remove the cls_token
-         embedding from the original embeddings.
-
-         Args:
-             abs_pos_embeddings (`torch.Tensor`):
-                 Absolute positional embeddings of shape (1, num_position, num_channels).
-             has_cls_token (`bool`):
-                 If `True`, `abs_pos_embeddings` contains one extra embedding for the cls token.
-             height (`int`):
-                 Height of input image tokens.
-             width (`int`):
-                 Width of input image tokens.
-
-         Returns:
-             Absolute positional embeddings after processing with shape (1, height, width, num_channels)
-         """
-         if has_cls_token:
-             abs_pos_embeddings = abs_pos_embeddings[:, 1:]
-         num_position = abs_pos_embeddings.shape[1]
-         size = int(math.sqrt(num_position))  # This is a constant and can be recorded as such in the ONNX export.
-         if size * size != num_position:
-             raise ValueError("Absolute position embeddings must be a square number.")
-
-         if torch.jit.is_tracing() or (size != height or size != width):
-             # nn.functional.interpolate is a no-op when size == height and size == width - we need to always capture this path with jit.trace.
-             new_abs_pos_embeddings = nn.functional.interpolate(
-                 abs_pos_embeddings.reshape(1, size, size, -1).permute(0, 3, 1, 2),
-                 size=(height, width),
-                 mode="bicubic",
-                 align_corners=False,
-             )
-
-             return new_abs_pos_embeddings.permute(0, 2, 3, 1)
-         else:
-             return abs_pos_embeddings.reshape(1, height, width, -1)
-
-     def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
-         num_channels = pixel_values.shape[1]
-         if num_channels != self.num_channels:
-             raise ValueError(
-                 "Make sure that the channel dimension of the pixel values match with the one set in the configuration."
-                 f" Expected {self.num_channels} but got {num_channels}."
-             )
-         embeddings = self.projection(pixel_values)
-
-         if self.position_embeddings is not None:
-             # (batch_size, num_channels, height, width) -> (batch_size, height, width, num_channels)
-             embeddings = embeddings.permute(0, 2, 3, 1)
-             # add position embeddings
-             embeddings = embeddings + self.get_absolute_positions(
-                 self.position_embeddings, True, embeddings.shape[1], embeddings.shape[2]
-             )
-             # (batch_size, height, width, num_channels) -> (batch_size, num_channels, height, width)
-             embeddings = embeddings.permute(0, 3, 1, 2)
-
-         return embeddings
-
-
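- # A shape-only sketch of `get_absolute_positions`: a 14x14 grid of pretraining position
- # embeddings is resized to a new token grid with bicubic interpolation. Sizes are arbitrary.
- # >>> import torch
- # >>> pos = torch.randn(1, 197, 768)                # 196 patch positions + 1 cls token
- # >>> grid = pos[:, 1:].reshape(1, 14, 14, -1).permute(0, 3, 1, 2)
- # >>> torch.nn.functional.interpolate(grid, size=(20, 20), mode="bicubic", align_corners=False).shape
- # torch.Size([1, 768, 20, 20])
-
-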
- @auto_docstring
- class LwDetrViTPreTrainedModel(PreTrainedModel):
-     config: LwDetrViTConfig
-     base_model_prefix = "lw_detr_vit"
-     main_input_name = "pixel_values"
-     input_modalities = ("image",)
-     supports_gradient_checkpointing = True
-     _no_split_modules = ["LwDetrViTEmbeddings", "LwDetrViTLayer"]
-     _supports_sdpa = True
-     _supports_flash_attn = True
-     _supports_flex_attn = True
-     _supports_attention_backend = True
-     _can_record_outputs = {
-         "hidden_states": LwDetrViTLayer,
-         "attentions": LwDetrViTSelfAttention,
-     }
-
-     @torch.no_grad()
-     def _init_weights(self, module) -> None:
-         """Initialize the weights"""
-         if isinstance(module, (nn.Linear, nn.Conv2d)):
-             init.trunc_normal_(module.weight, mean=0.0, std=self.config.initializer_range)
-             if module.bias is not None:
-                 init.zeros_(module.bias)
-         elif isinstance(module, nn.LayerNorm):
-             init.zeros_(module.bias)
-             init.ones_(module.weight)
-         elif isinstance(module, LwDetrViTEmbeddings):
-             # `position_embeddings` is None when absolute position embeddings are disabled.
-             if module.position_embeddings is not None:
-                 init.trunc_normal_(module.position_embeddings, mean=0.0, std=self.config.initializer_range)
-         if isinstance(module, LwDetrViTLayer):
-             nn.init.constant_(module.gamma_1, self.config.cae_init_values)
-             nn.init.constant_(module.gamma_2, self.config.cae_init_values)
-
-
- @auto_docstring()
- class LwDetrViTBackbone(LwDetrViTPreTrainedModel, BackboneMixin):
-     def __init__(self, config):
-         super().__init__(config)
-         super()._init_backbone(config)
-
-         self.embeddings = LwDetrViTEmbeddings(config)
-         self.encoder = LwDetrViTEncoder(config)
-         self.num_features = [config.hidden_size for _ in range(config.num_hidden_layers + 1)]
-
-         # initialize weights and apply final processing
-         self.post_init()
-
-     def get_input_embeddings(self) -> nn.Module:
-         return self.embeddings.projection
-
-     @check_model_inputs
-     @auto_docstring
-     def forward(self, pixel_values: torch.Tensor, **kwargs: Unpack[TransformersKwargs]) -> BackboneOutput:
-         r"""
-         Examples:
-
-         ```python
-         >>> from transformers import LwDetrViTConfig, LwDetrViTBackbone
-         >>> import torch
-
-         >>> config = LwDetrViTConfig()
-         >>> model = LwDetrViTBackbone(config)
-
-         >>> pixel_values = torch.randn(1, 3, 224, 224)
-
-         >>> with torch.no_grad():
-         ...     outputs = model(pixel_values)
-
-         >>> feature_maps = outputs.feature_maps
-         >>> list(feature_maps[-1].shape)
-         [1, 768, 14, 14]
-         ```"""
-         embedding_output = self.embeddings(pixel_values)
-
-         batch_size, channels, height, width = embedding_output.shape
-         # (batch_size, channels, height, width) -> (batch_size, height, width, channels)
-         hidden_states = embedding_output.permute(0, 2, 3, 1)
-
-         window_height = height // self.config.num_windows_side
-         window_width = width // self.config.num_windows_side
-         # (batch_size, height, width, channels) -> (batch_size*num_windows_side**2, window_height*window_width, channels)
-         hidden_states = (
-             hidden_states.reshape(
-                 batch_size,
-                 self.config.num_windows_side,
-                 window_height,
-                 self.config.num_windows_side,
-                 window_width,
-                 channels,
-             )
-             .permute(0, 1, 3, 2, 4, 5)
-             .reshape(batch_size * self.config.num_windows_side**2, window_height * window_width, channels)
-         )
-
-         hidden_states = self.encoder(hidden_states, **kwargs)
-
-         feature_maps = ()
-         for stage, hidden_state in zip(self.stage_names, hidden_states):
-             if stage in self.out_features:
-                 hidden_state = (
-                     hidden_state.reshape(
-                         batch_size,
-                         self.config.num_windows_side,
-                         self.config.num_windows_side,
-                         window_height,
-                         window_width,
-                         channels,
-                     )
-                     .permute(0, 5, 1, 3, 2, 4)
-                     .reshape(batch_size, channels, height, width)
-                 )
-                 feature_maps += (hidden_state,)
-
-         return BackboneOutput(feature_maps=feature_maps)
-
-
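- # A shape-only sketch of the window partition used in `LwDetrViTBackbone.forward` above,
- # with an assumed 2x2 window grid (`num_windows_side = 2`); all sizes are illustrative.
- # >>> import torch
- # >>> x = torch.randn(1, 28, 28, 192)               # (batch, height, width, channels)
- # >>> windows = x.reshape(1, 2, 14, 2, 14, 192).permute(0, 1, 3, 2, 4, 5).reshape(4, 196, 192)
- # >>> windows.shape                                 # (batch * 2**2, window_h * window_w, channels)
- # torch.Size([4, 196, 192])
-
-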
- class LwDetrConvNormLayer(nn.Module):
-     def __init__(
-         self,
-         config: LwDetrConfig,
-         in_channels: int,
-         out_channels: int,
-         kernel_size: int,
-         stride: int,
-         activation: str | None = None,
-     ):
-         super().__init__()
-         self.conv = nn.Conv2d(
-             in_channels,
-             out_channels,
-             kernel_size,
-             stride,
-             padding=kernel_size // 2,
-             bias=False,
-         )
-         self.norm = nn.BatchNorm2d(out_channels, config.batch_norm_eps)
-         self.activation = nn.Identity() if activation is None else ACT2CLS[activation]()
-
-     def forward(self, hidden_state):
-         hidden_state = self.conv(hidden_state)
-         hidden_state = self.norm(hidden_state)
-         hidden_state = self.activation(hidden_state)
-         return hidden_state
-
-
- class LwDetrRepVggBlock(nn.Module):
-     def __init__(self, config: LwDetrConfig):
-         super().__init__()
-         hidden_channels = int(config.d_model * config.hidden_expansion)
-         self.conv1 = LwDetrConvNormLayer(
-             config, hidden_channels, hidden_channels, 3, 1, activation=config.activation_function
-         )
-         self.conv2 = LwDetrConvNormLayer(
-             config, hidden_channels, hidden_channels, 3, 1, activation=config.activation_function
-         )
-
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         y = self.conv1(x)
-         y = self.conv2(y)
-         return y
-
-
- class LwDetrC2FLayer(nn.Module):
-     # Inspired by RTDetrCSPRepLayer
-     def __init__(self, config: LwDetrConfig, in_channels: int):
-         super().__init__()
-         num_blocks = config.c2f_num_blocks
-         activation = config.activation_function
-         out_channels = config.d_model
-
-         self.hidden_channels = int(out_channels * config.hidden_expansion)
-
-         conv1_out_channels = 2 * self.hidden_channels
-         self.conv1 = LwDetrConvNormLayer(config, in_channels, conv1_out_channels, 1, 1, activation=activation)
-
-         conv2_in_channels = (2 + num_blocks) * self.hidden_channels
-         self.conv2 = LwDetrConvNormLayer(config, conv2_in_channels, out_channels, 1, 1, activation=activation)
-
-         self.bottlenecks = nn.ModuleList(LwDetrRepVggBlock(config) for _ in range(num_blocks))
-
-     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-         hidden_states = self.conv1(hidden_states)
-         all_hidden_states = list(hidden_states.split(self.hidden_channels, 1))
-         hidden_states = all_hidden_states[-1]
-
-         for bottleneck in self.bottlenecks:
-             hidden_states = bottleneck(hidden_states)
-             all_hidden_states.append(hidden_states)
-
-         hidden_states = torch.cat(all_hidden_states, 1)
-         hidden_states = self.conv2(hidden_states)
-         return hidden_states
-
-
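- # A quick channel-count sketch of the C2F forward pass above: conv1 emits 2 hidden-size
- # chunks, each bottleneck appends one more, so conv2 consumes (2 + num_blocks) * hidden
- # channels. Numbers are illustrative (hidden = 64, num_blocks = 3).
- # >>> chunks = 2
- # >>> chunks += 3                                   # one extra chunk per bottleneck
- # >>> chunks * 64                                   # channels entering conv2
- # 320
-
-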
- class LwDetrLayerNorm(nn.LayerNorm):
-     r"""LayerNorm that supports two data formats: channels_last (default) or channels_first.
-     `data_format` selects the ordering of the dimensions in the inputs: channels_last corresponds to inputs with
-     shape (batch_size, height, width, channels) while channels_first corresponds to inputs with shape
-     (batch_size, channels, height, width).
-     """
-
-     def __init__(self, normalized_shape, *, eps=1e-6, data_format="channels_last", **kwargs):
-         super().__init__(normalized_shape, eps=eps, **kwargs)
-         if data_format not in ["channels_last", "channels_first"]:
-             raise NotImplementedError(f"Unsupported data format: {data_format}")
-         self.data_format = data_format
-
-     def forward(self, features: torch.Tensor) -> torch.Tensor:
-         """
-         Args:
-             features: Tensor of shape (batch_size, channels, height, width) OR (batch_size, height, width, channels)
-         """
-         if self.data_format == "channels_first":
-             features = features.permute(0, 2, 3, 1)
-             features = super().forward(features)
-             features = features.permute(0, 3, 1, 2)
-         else:
-             features = super().forward(features)
-         return features
-
-
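- # A doctest-style sketch of the channels_first path above: the input is permuted to
- # channels_last, normalized over the channel dimension, and permuted back, so the shape
- # is preserved. Sizes are arbitrary.
- # >>> import torch
- # >>> ln = LwDetrLayerNorm(64, data_format="channels_first")
- # >>> ln(torch.randn(2, 64, 7, 7)).shape
- # torch.Size([2, 64, 7, 7])
-
-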
- class LwDetrSamplingLayer(nn.Module):
-     def __init__(self, config: LwDetrConfig, channel_size: int, scale: float):
-         super().__init__()
-
-         self.scale = scale
-         self.channel_size = channel_size
-
-         layers = []
-         if scale == 2.0:
-             if channel_size > 512:
-                 layers.append(LwDetrConvNormLayer(config, channel_size, channel_size // 2, 1, 1, activation="relu"))
-                 layers.append(nn.ConvTranspose2d(channel_size // 2, channel_size // 4, kernel_size=2, stride=2))
-             else:
-                 layers.append(nn.ConvTranspose2d(channel_size, channel_size // 2, 2, 2))
-         elif scale == 0.5:
-             layers.append(LwDetrConvNormLayer(config, channel_size, channel_size, 3, 2, activation="relu"))
-         self.layers = nn.ModuleList(layers)
-
-     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-         for layer in self.layers:
-             hidden_states = layer(hidden_states)
-         return hidden_states
-
-
- class LwDetrScaleProjector(nn.Module):
-     def __init__(self, config: LwDetrConfig, scale: float):
-         super().__init__()
-
-         intermediate_dims = [config.backbone_config.hidden_size] * len(config.backbone_config.out_indices)
-         sampling_layers = []
-         for channel_size in intermediate_dims:
-             sampling_layers.append(LwDetrSamplingLayer(config, channel_size, scale))
-         self.sampling_layers = nn.ModuleList(sampling_layers)
-
-         intermediate_dim = intermediate_dims[-1]
-         if scale == 2.0:
-             if intermediate_dim > 512:
-                 intermediate_dim = intermediate_dim // 4
-             else:
-                 intermediate_dim = intermediate_dim // 2
-         projector_input_dim = intermediate_dim * len(intermediate_dims)
-
-         self.projector_layer = LwDetrC2FLayer(config, projector_input_dim)
-         self.layer_norm = LwDetrLayerNorm(config.d_model, data_format="channels_first")
-
-     def forward(self, hidden_states_tuple: tuple[torch.Tensor]) -> torch.Tensor:
-         sampled_hidden_states = []
-         for sampling_layer, hidden_states in zip(self.sampling_layers, hidden_states_tuple):
-             hidden_states = sampling_layer(hidden_states)
-             sampled_hidden_states.append(hidden_states)
-         hidden_states = torch.cat(sampled_hidden_states, dim=1)
-         hidden_states = self.projector_layer(hidden_states)
-         hidden_states = self.layer_norm(hidden_states)
-         return hidden_states
-
-
- class LwDetrMultiScaleProjector(nn.Module):
-     def __init__(self, config: LwDetrConfig):
-         super().__init__()
-
-         self.config = config
-         scale_factors = config.projector_scale_factors
-
-         self.scale_layers = nn.ModuleList([LwDetrScaleProjector(config, scale) for scale in scale_factors])
-
-     def forward(self, hidden_states: tuple[torch.Tensor]) -> list[torch.Tensor]:
-         output_hidden_states = []
-         for scale_layer in self.scale_layers:
-             output_hidden_states.append(scale_layer(hidden_states))
-         return output_hidden_states
-
-
- class LwDetrConvEncoder(nn.Module):
-     def __init__(self, config: LwDetrConfig):
-         super().__init__()
-         self.backbone = LwDetrViTBackbone(config.backbone_config)
-         self.projector = LwDetrMultiScaleProjector(config)
-
-     def forward(self, pixel_values: torch.Tensor, pixel_mask: torch.Tensor):
-         # send pixel_values through the model to get list of feature maps
-         features = self.backbone(pixel_values).feature_maps
-         features = self.projector(features)
-         out = []
-         for feature_map in features:
-             # downsample pixel_mask to match shape of corresponding feature_map
-             mask = nn.functional.interpolate(pixel_mask[None].float(), size=feature_map.shape[-2:]).to(torch.bool)[0]
-             out.append((feature_map, mask))
-         return out
-
-
- class LwDetrAttention(nn.Module):
-     def __init__(self, config: LwDetrConfig, layer_idx: int):
-         super().__init__()
-         self.config = config
-         self.layer_idx = layer_idx
-         self.head_dim = getattr(config, "head_dim", config.d_model // config.decoder_self_attention_heads)
-         self.scaling = self.head_dim**-0.5
-         self.attention_dropout = config.attention_dropout
-         self.is_causal = False
-         self.num_key_value_groups = 1
-
-         self.q_proj = nn.Linear(
-             config.d_model, config.decoder_self_attention_heads * self.head_dim, bias=config.attention_bias
-         )
-         self.k_proj = nn.Linear(
-             config.d_model, config.decoder_self_attention_heads * self.head_dim, bias=config.attention_bias
-         )
-         self.v_proj = nn.Linear(
-             config.d_model, config.decoder_self_attention_heads * self.head_dim, bias=config.attention_bias
-         )
-         self.o_proj = nn.Linear(
-             config.decoder_self_attention_heads * self.head_dim, config.d_model, bias=config.attention_bias
-         )
-
-     def forward(
-         self,
-         hidden_states: torch.Tensor,
-         position_embeddings: torch.Tensor | None = None,
-         **kwargs: Unpack[TransformersKwargs],
-     ) -> tuple[torch.Tensor, torch.Tensor]:
-         batch_size, seq_len, _ = hidden_states.shape
-
-         hidden_states_original = hidden_states
-         if position_embeddings is not None:
-             hidden_states = hidden_states + position_embeddings
-
-         if self.training:
-             # At training time, the group DETR technique adds supervision through multiple
-             # weight-sharing decoder groups for faster convergence; the groups are folded into
-             # the batch dimension so attention never mixes queries across groups.
-             # At inference, only one decoder group is used.
-             hidden_states_original = torch.cat(
-                 hidden_states_original.split(seq_len // self.config.group_detr, dim=1), dim=0
-             )
-             hidden_states = torch.cat(hidden_states.split(seq_len // self.config.group_detr, dim=1), dim=0)
-
-         input_shape = hidden_states.shape[:-1]
-         hidden_shape = (*input_shape, -1, self.head_dim)
-
-         query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
-         key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
-         value_states = self.v_proj(hidden_states_original).view(hidden_shape).transpose(1, 2)
-
-         attention_interface: Callable = eager_attention_forward
-         if self.config._attn_implementation != "eager":
-             attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
-
-         attn_output, attn_weights = attention_interface(
-             self,
-             query_states,
-             key_states,
-             value_states,
-             attention_mask=None,
-             dropout=0.0 if not self.training else self.attention_dropout,
-             scaling=self.scaling,
-             **kwargs,
-         )
-         attn_output = attn_output.reshape(*input_shape, -1).contiguous()
-         attn_output = self.o_proj(attn_output)
-
-         if self.training:
-             attn_output = torch.cat(torch.split(attn_output, batch_size, dim=0), dim=1)
-
-         return attn_output, attn_weights
-
-
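- # A shape-only sketch of the group DETR bookkeeping above, assuming 2 groups, batch 3,
- # and 600 queries: groups are folded into the batch before attention and unfolded after.
- # >>> import torch
- # >>> x = torch.randn(3, 600, 256)
- # >>> grouped = torch.cat(x.split(600 // 2, dim=1), dim=0)
- # >>> grouped.shape
- # torch.Size([6, 300, 256])
- # >>> torch.equal(torch.cat(torch.split(grouped, 3, dim=0), dim=1), x)
- # True
-
-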
- @use_kernel_forward_from_hub("MultiScaleDeformableAttention")
- class MultiScaleDeformableAttention(nn.Module):
-     def forward(
-         self,
-         value: Tensor,
-         value_spatial_shapes: Tensor,
-         value_spatial_shapes_list: list[tuple],
-         level_start_index: Tensor,
-         sampling_locations: Tensor,
-         attention_weights: Tensor,
-         im2col_step: int,
-     ):
-         batch_size, _, num_heads, hidden_dim = value.shape
-         _, num_queries, num_heads, num_levels, num_points, _ = sampling_locations.shape
-         value_list = value.split([height * width for height, width in value_spatial_shapes_list], dim=1)
-         sampling_grids = 2 * sampling_locations - 1
-         sampling_value_list = []
-         for level_id, (height, width) in enumerate(value_spatial_shapes_list):
-             # batch_size, height*width, num_heads, hidden_dim
-             # -> batch_size, height*width, num_heads*hidden_dim
-             # -> batch_size, num_heads*hidden_dim, height*width
-             # -> batch_size*num_heads, hidden_dim, height, width
-             value_l_ = (
-                 value_list[level_id]
-                 .flatten(2)
-                 .transpose(1, 2)
-                 .reshape(batch_size * num_heads, hidden_dim, height, width)
-             )
-             # batch_size, num_queries, num_heads, num_points, 2
-             # -> batch_size, num_heads, num_queries, num_points, 2
-             # -> batch_size*num_heads, num_queries, num_points, 2
-             sampling_grid_l_ = sampling_grids[:, :, :, level_id].transpose(1, 2).flatten(0, 1)
-             # batch_size*num_heads, hidden_dim, num_queries, num_points
-             sampling_value_l_ = nn.functional.grid_sample(
-                 value_l_,
-                 sampling_grid_l_,
-                 mode="bilinear",
-                 padding_mode="zeros",
-                 align_corners=False,
-             )
-             sampling_value_list.append(sampling_value_l_)
-         # (batch_size, num_queries, num_heads, num_levels, num_points)
-         # -> (batch_size, num_heads, num_queries, num_levels, num_points)
-         # -> (batch_size*num_heads, 1, num_queries, num_levels*num_points)
-         attention_weights = attention_weights.transpose(1, 2).reshape(
-             batch_size * num_heads, 1, num_queries, num_levels * num_points
-         )
-         output = (
-             (torch.stack(sampling_value_list, dim=-2).flatten(-2) * attention_weights)
-             .sum(-1)
-             .view(batch_size, num_heads * hidden_dim, num_queries)
-         )
-         return output.transpose(1, 2).contiguous()
-
-
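- # A minimal sketch of the grid_sample call at the heart of the forward pass above:
- # sampling grids in [-1, 1] gather bilinear samples from each head-split value map.
- # Sizes are arbitrary (1 head-batch, 8 channels, one query, 4 sampling points).
- # >>> import torch
- # >>> value = torch.randn(1, 8, 16, 16)             # batch*heads, hidden_dim, height, width
- # >>> grid = torch.rand(1, 1, 4, 2) * 2 - 1         # batch*heads, num_queries, num_points, 2
- # >>> torch.nn.functional.grid_sample(value, grid, mode="bilinear", padding_mode="zeros", align_corners=False).shape
- # torch.Size([1, 8, 1, 4])
-
-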
- class LwDetrMultiscaleDeformableAttention(nn.Module):
-     """
-     Multiscale deformable attention as proposed in Deformable DETR.
-     """
-
-     def __init__(self, config: LwDetrConfig, num_heads: int, n_points: int):
-         super().__init__()
-
-         self.attn = MultiScaleDeformableAttention()
-
-         if config.d_model % num_heads != 0:
-             raise ValueError(
-                 f"embed_dim (d_model) must be divisible by num_heads, but got {config.d_model} and {num_heads}"
-             )
-         dim_per_head = config.d_model // num_heads
-         # check if dim_per_head is a power of 2
-         if not ((dim_per_head & (dim_per_head - 1) == 0) and dim_per_head != 0):
-             warnings.warn(
-                 "It is recommended to set embed_dim (d_model) in LwDetrMultiscaleDeformableAttention so that the"
-                 " dimension of each attention head is a power of 2, which is more efficient in the authors' CUDA"
-                 " implementation."
-             )
-
-         self.im2col_step = 64
-
-         self.d_model = config.d_model
-         self.n_levels = config.num_feature_levels
-         self.n_heads = num_heads
-         self.n_points = n_points
-
-         self.sampling_offsets = nn.Linear(config.d_model, num_heads * self.n_levels * n_points * 2)
-         self.attention_weights = nn.Linear(config.d_model, num_heads * self.n_levels * n_points)
-         self.value_proj = nn.Linear(config.d_model, config.d_model)
-         self.output_proj = nn.Linear(config.d_model, config.d_model)
-
-         self.disable_custom_kernels = config.disable_custom_kernels
-
-     def with_pos_embed(self, tensor: torch.Tensor, position_embeddings: Tensor | None):
-         return tensor if position_embeddings is None else tensor + position_embeddings
-
-     def forward(
-         self,
-         hidden_states: torch.Tensor,
-         attention_mask: torch.Tensor | None = None,
-         encoder_hidden_states=None,
-         encoder_attention_mask=None,
-         position_embeddings: torch.Tensor | None = None,
-         reference_points=None,
-         spatial_shapes=None,
-         spatial_shapes_list=None,
-         level_start_index=None,
-         **kwargs: Unpack[TransformersKwargs],
-     ):
-         # add position embeddings to the hidden states before projecting to queries and keys
-         if position_embeddings is not None:
-             hidden_states = self.with_pos_embed(hidden_states, position_embeddings)
-
-         batch_size, num_queries, _ = hidden_states.shape
-         batch_size, sequence_length, _ = encoder_hidden_states.shape
-         total_elements = sum(height * width for height, width in spatial_shapes_list)
-         torch_compilable_check(
-             total_elements == sequence_length,
-             "Make sure to align the spatial shapes with the sequence length of the encoder hidden states",
-         )
-
-         value = self.value_proj(encoder_hidden_states)
-         if attention_mask is not None:
-             # we invert the attention_mask
-             value = value.masked_fill(~attention_mask[..., None], float(0))
-         value = value.view(batch_size, sequence_length, self.n_heads, self.d_model // self.n_heads)
-         sampling_offsets = self.sampling_offsets(hidden_states).view(
-             batch_size, num_queries, self.n_heads, self.n_levels, self.n_points, 2
-         )
-         attention_weights = self.attention_weights(hidden_states).view(
-             batch_size, num_queries, self.n_heads, self.n_levels * self.n_points
-         )
-         attention_weights = F.softmax(attention_weights, -1).view(
-             batch_size, num_queries, self.n_heads, self.n_levels, self.n_points
-         )
-         # batch_size, num_queries, n_heads, n_levels, n_points, 2
-         num_coordinates = reference_points.shape[-1]
-         if num_coordinates == 2:
-             offset_normalizer = torch.stack([spatial_shapes[..., 1], spatial_shapes[..., 0]], -1)
-             sampling_locations = (
-                 reference_points[:, :, None, :, None, :]
-                 + sampling_offsets / offset_normalizer[None, None, None, :, None, :]
-             )
-         elif num_coordinates == 4:
-             sampling_locations = (
-                 reference_points[:, :, None, :, None, :2]
-                 + sampling_offsets / self.n_points * reference_points[:, :, None, :, None, 2:] * 0.5
-             )
-         else:
-             raise ValueError(f"Last dim of reference_points must be 2 or 4, but got {reference_points.shape[-1]}")
-
-         output = self.attn(
-             value,
-             spatial_shapes,
-             spatial_shapes_list,
-             level_start_index,
-             sampling_locations,
-             attention_weights,
-             self.im2col_step,
-         )
-
-         output = self.output_proj(output)
-
-         return output, attention_weights
-
-
- class LwDetrMLP(nn.Module):
-     def __init__(self, config: LwDetrConfig):
-         super().__init__()
-         self.dropout = config.dropout
-         self.activation_fn = ACT2FN[config.decoder_activation_function]
-         self.fc1 = nn.Linear(config.d_model, config.decoder_ffn_dim)
-         self.fc2 = nn.Linear(config.decoder_ffn_dim, config.d_model)
-
-     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-         residual = hidden_states
-         hidden_states = self.fc1(hidden_states)
-         hidden_states = self.activation_fn(hidden_states)
-         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
-         hidden_states = self.fc2(hidden_states)
-         hidden_states = nn.functional.dropout(hidden_states, p=self.dropout, training=self.training)
-         hidden_states = residual + hidden_states
-         return hidden_states
-
-
- class LwDetrDecoderLayer(GradientCheckpointingLayer):
-     def __init__(self, config: LwDetrConfig, layer_idx: int):
-         nn.Module.__init__(self)
-
-         # self-attention
-         self.self_attn = LwDetrAttention(config, layer_idx=layer_idx)
-         self.dropout = config.dropout
-         self.activation_fn = ACT2FN[config.decoder_activation_function]
-         self.activation_dropout = config.activation_dropout
-         self.self_attn_layer_norm = nn.LayerNorm(config.d_model)
-
-         # cross-attention
-         self.cross_attn = LwDetrMultiscaleDeformableAttention(
-             config,
-             num_heads=config.decoder_cross_attention_heads,
-             n_points=config.decoder_n_points,
-         )
-         self.cross_attn_layer_norm = nn.LayerNorm(config.d_model)
-
-         # mlp
-         self.mlp = LwDetrMLP(config)
-         self.layer_norm = nn.LayerNorm(config.d_model)
-
-     def forward(
-         self,
-         hidden_states: torch.Tensor,
-         position_embeddings: torch.Tensor | None = None,
-         reference_points=None,
-         spatial_shapes=None,
-         spatial_shapes_list=None,
-         level_start_index=None,
-         encoder_hidden_states: torch.Tensor | None = None,
-         encoder_attention_mask: torch.Tensor | None = None,
-         **kwargs: Unpack[TransformersKwargs],
-     ):
-         self_attention_output, self_attn_weights = self.self_attn(
-             hidden_states, position_embeddings=position_embeddings, **kwargs
-         )
-
-         self_attention_output = nn.functional.dropout(self_attention_output, p=self.dropout, training=self.training)
-         hidden_states = hidden_states + self_attention_output
-         hidden_states = self.self_attn_layer_norm(hidden_states)
-
-         cross_attention_output, cross_attn_weights = self.cross_attn(
-             hidden_states=hidden_states,
-             attention_mask=encoder_attention_mask,
-             encoder_hidden_states=encoder_hidden_states,
-             encoder_attention_mask=encoder_attention_mask,
-             position_embeddings=position_embeddings,
-             reference_points=reference_points,
-             spatial_shapes=spatial_shapes,
-             spatial_shapes_list=spatial_shapes_list,
-             level_start_index=level_start_index,
-             **kwargs,
-         )
-         cross_attention_output = nn.functional.dropout(cross_attention_output, p=self.dropout, training=self.training)
-         hidden_states = hidden_states + cross_attention_output
-         hidden_states = self.cross_attn_layer_norm(hidden_states)
-
-         hidden_states = self.mlp(hidden_states)
-         hidden_states = self.layer_norm(hidden_states)
-
-         return hidden_states
-
-
- @auto_docstring
- class LwDetrPreTrainedModel(PreTrainedModel):
-     config: LwDetrConfig
-     base_model_prefix = "model"
-     main_input_name = "pixel_values"
-     _no_split_modules = [
-         r"LwDetrConvEncoder",
-         r"LwDetrDecoderLayer",
-     ]
-     _supports_sdpa = True
-     _supports_flash_attn = True
-     _supports_flex_attn = True
-     _supports_attention_backend = True
-     _can_record_outputs = {
-         "attentions": [LwDetrAttention, LwDetrMultiscaleDeformableAttention],
-         "hidden_states": [LwDetrDecoderLayer],
-     }
-
-     @torch.no_grad()
-     def _init_weights(self, module):
-         super()._init_weights(module)
-
-         if isinstance(module, LwDetrMultiscaleDeformableAttention):
-             init.constant_(module.sampling_offsets.weight, 0.0)
-             thetas = torch.arange(module.n_heads, dtype=torch.int64).float() * (2.0 * math.pi / module.n_heads)
-             grid_init = torch.stack([thetas.cos(), thetas.sin()], -1)
-             grid_init = (
-                 (grid_init / grid_init.abs().max(-1, keepdim=True)[0])
-                 .view(module.n_heads, 1, 1, 2)
-                 .repeat(1, module.n_levels, module.n_points, 1)
-             )
-             for i in range(module.n_points):
-                 grid_init[:, :, i, :] *= i + 1
-
-             init.copy_(module.sampling_offsets.bias, grid_init.view(-1))
-             init.constant_(module.attention_weights.weight, 0.0)
-             init.constant_(module.attention_weights.bias, 0.0)
-             init.xavier_uniform_(module.value_proj.weight)
-             init.constant_(module.value_proj.bias, 0.0)
-             init.xavier_uniform_(module.output_proj.weight)
-             init.constant_(module.output_proj.bias, 0.0)
-         if hasattr(module, "level_embed"):
-             init.normal_(module.level_embed)
-         if hasattr(module, "refpoint_embed") and module.refpoint_embed is not None:
-             init.constant_(module.refpoint_embed.weight, 0)
-         if hasattr(module, "class_embed") and module.class_embed is not None:
-             prior_prob = 0.01
-             bias_value = -math.log((1 - prior_prob) / prior_prob)
-             init.constant_(module.class_embed.bias, bias_value)
-         if hasattr(module, "bbox_embed") and module.bbox_embed is not None:
-             init.constant_(module.bbox_embed.layers[-1].weight, 0)
-             init.constant_(module.bbox_embed.layers[-1].bias, 0)
-
-
- @dataclass
- @auto_docstring(
-     custom_intro="""
-     Base class for outputs of the LwDetrDecoder. This class adds two attributes to
-     BaseModelOutputWithCrossAttentions, namely:
-     - a stacked tensor of intermediate decoder hidden states (i.e. the output of each decoder layer)
-     - a stacked tensor of intermediate reference points.
-     """
- )
- class LwDetrDecoderOutput(ModelOutput):
-     r"""
-     intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
-         Stacked intermediate hidden states (output of each layer of the decoder).
-     intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
-         Stacked intermediate reference points (reference points of each layer of the decoder).
-     cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` and `config.add_cross_attention=True` is passed or when `config.output_attentions=True`):
-         Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
-         sequence_length)`. Attention weights of the decoder's cross-attention layer, after the attention softmax,
-         used to compute the weighted average in the cross-attention heads.
-     """
-
-     last_hidden_state: torch.FloatTensor | None = None
-     intermediate_hidden_states: torch.FloatTensor | None = None
-     intermediate_reference_points: torch.FloatTensor | None = None
-     hidden_states: tuple[torch.FloatTensor] | None = None
-     attentions: tuple[torch.FloatTensor] | None = None
-     cross_attentions: tuple[torch.FloatTensor] | None = None
-
-
- # function to generate sine positional embeddings for 4d coordinates
- def gen_sine_position_embeddings(pos_tensor, hidden_size=256):
-     """
-     This function computes position embeddings using sine and cosine functions from the input positional tensor,
-     which has a shape of (batch_size, num_queries, 4).
-     The last dimension of `pos_tensor` represents the following coordinates:
-     - 0: x-coord
-     - 1: y-coord
-     - 2: width
-     - 3: height
-
-     The output shape is (batch_size, num_queries, hidden_size * 2), e.g. 512 for the default `hidden_size=256`;
-     the final dimension is obtained by concatenating the sine and cosine values for each coordinate.
-     """
-     scale = 2 * math.pi
-     dim = hidden_size // 2
-     dim_t = torch.arange(dim, dtype=torch.float32, device=pos_tensor.device)
-     dim_t = 10000 ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / dim)
-     x_embed = pos_tensor[:, :, 0] * scale
-     y_embed = pos_tensor[:, :, 1] * scale
-     pos_x = x_embed[:, :, None] / dim_t
-     pos_y = y_embed[:, :, None] / dim_t
-     pos_x = torch.stack((pos_x[:, :, 0::2].sin(), pos_x[:, :, 1::2].cos()), dim=3).flatten(2)
-     pos_y = torch.stack((pos_y[:, :, 0::2].sin(), pos_y[:, :, 1::2].cos()), dim=3).flatten(2)
-     if pos_tensor.size(-1) == 4:
-         w_embed = pos_tensor[:, :, 2] * scale
-         pos_w = w_embed[:, :, None] / dim_t
-         pos_w = torch.stack((pos_w[:, :, 0::2].sin(), pos_w[:, :, 1::2].cos()), dim=3).flatten(2)
-
-         h_embed = pos_tensor[:, :, 3] * scale
-         pos_h = h_embed[:, :, None] / dim_t
-         pos_h = torch.stack((pos_h[:, :, 0::2].sin(), pos_h[:, :, 1::2].cos()), dim=3).flatten(2)
-
-         pos = torch.cat((pos_y, pos_x, pos_w, pos_h), dim=2)
-     else:
-         raise ValueError(f"Unknown pos_tensor shape(-1):{pos_tensor.size(-1)}")
-     return pos.to(pos_tensor.dtype)
-
-
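- # A shape-only sketch of `gen_sine_position_embeddings`: boxes in (cx, cy, w, h) format
- # map to hidden_size // 2 features per coordinate, 2 * hidden_size in total. Sizes are arbitrary.
- # >>> import torch
- # >>> boxes = torch.rand(2, 300, 4)                 # (batch_size, num_queries, 4)
- # >>> gen_sine_position_embeddings(boxes, hidden_size=256).shape
- # torch.Size([2, 300, 512])
-
-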
- class LwDetrDecoder(LwDetrPreTrainedModel):
-     """
-     Transformer decoder consisting of *config.decoder_layers* layers. Each layer is a [`DeformableDetrDecoderLayer`].
-
-     The decoder updates the query embeddings through multiple self-attention and deformable cross-attention layers.
-
-     Some tweaks for LwDetr:
-
-     - it uses the group DETR technique at training time for faster convergence.
-
-     Args:
-         config: LwDetrConfig
-     """
-
-     def __init__(self, config: LwDetrConfig):
-         super().__init__(config)
-         self.dropout = config.dropout
-         self.layers = nn.ModuleList([LwDetrDecoderLayer(config, i) for i in range(config.decoder_layers)])
-         self.layernorm = nn.LayerNorm(config.d_model)
-
-         self.gradient_checkpointing = False
-
-         self.ref_point_head = LwDetrMLPPredictionHead(2 * config.d_model, config.d_model, config.d_model, num_layers=2)
-
-         self.post_init()
-
-     def get_reference(self, reference_points, valid_ratios):
-         # batch_size, num_queries, 4
-         obj_center = reference_points[..., :4]
-
-         # batch_size, num_queries, num_levels, 4
-         reference_points_inputs = obj_center[:, :, None] * torch.cat([valid_ratios, valid_ratios], -1)[:, None]
-
-         # batch_size, num_queries, d_model * 2
-         query_sine_embed = gen_sine_position_embeddings(reference_points_inputs[:, :, 0, :], self.config.d_model)
-
-         # batch_size, num_queries, d_model
-         query_pos = self.ref_point_head(query_sine_embed)
-         return reference_points_inputs, query_pos
-
-     def forward(
-         self,
-         inputs_embeds: torch.Tensor | None = None,
-         reference_points: torch.Tensor | None = None,
-         spatial_shapes: torch.Tensor | None = None,
-         spatial_shapes_list: torch.Tensor | None = None,
-         level_start_index: torch.Tensor | None = None,
-         valid_ratios: torch.Tensor | None = None,
-         encoder_hidden_states: torch.Tensor | None = None,
-         encoder_attention_mask: torch.Tensor | None = None,
-         **kwargs: Unpack[TransformersKwargs],
-     ):
-         intermediate = ()
-         intermediate_reference_points = (reference_points,)
-
-         if inputs_embeds is not None:
-             hidden_states = inputs_embeds
-
-         reference_points_inputs, query_pos = self.get_reference(reference_points, valid_ratios)
-
-         for decoder_layer in self.layers:
-             hidden_states = decoder_layer(
-                 hidden_states,
-                 encoder_hidden_states=encoder_hidden_states,
-                 encoder_attention_mask=encoder_attention_mask,
-                 position_embeddings=query_pos,
-                 reference_points=reference_points_inputs,
-                 spatial_shapes=spatial_shapes,
-                 spatial_shapes_list=spatial_shapes_list,
-                 level_start_index=level_start_index,
-                 **kwargs,
-             )
-             intermediate_hidden_states = self.layernorm(hidden_states)
-             intermediate += (intermediate_hidden_states,)
-
-         intermediate = torch.stack(intermediate)
-         last_hidden_state = intermediate[-1]
-         intermediate_reference_points = torch.stack(intermediate_reference_points)
-
-         return LwDetrDecoderOutput(
-             last_hidden_state=last_hidden_state,
-             intermediate_hidden_states=intermediate,
-             intermediate_reference_points=intermediate_reference_points,
-         )
-
-
- @dataclass
- @auto_docstring(
-     custom_intro="""
-     Base class for outputs of the LwDetr backbone-decoder model.
-     """
- )
- class LwDetrModelOutput(ModelOutput):
-     r"""
-     init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
-         Initial reference points sent through the Transformer decoder.
-     intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
-         Stacked intermediate hidden states (output of each layer of the decoder).
-     intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
-         Stacked intermediate reference points (reference points of each layer of the decoder).
-     enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
-         Predicted bounding boxes scores where the top `config.two_stage_num_proposals` scoring bounding boxes are
-         picked as region proposals in the first stage. Output of bounding box binary classification (i.e.
-         foreground and background).
-     enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
-         Logits of predicted bounding boxes coordinates in the first stage.
-     """
-
-     init_reference_points: torch.FloatTensor | None = None
-     last_hidden_state: torch.FloatTensor | None = None
-     intermediate_hidden_states: torch.FloatTensor | None = None
-     intermediate_reference_points: torch.FloatTensor | None = None
-     enc_outputs_class: torch.FloatTensor | None = None
-     enc_outputs_coord_logits: torch.FloatTensor | None = None
-
-
- def refine_bboxes(reference_points, deltas):
-     # Apply predicted deltas to (cx, cy, w, h) reference boxes: centers shift proportionally
-     # to the reference size, widths/heights are scaled by exponentiated deltas.
-     reference_points = reference_points.to(deltas.device)
-     new_reference_points_cxcy = deltas[..., :2] * reference_points[..., 2:] + reference_points[..., :2]
-     new_reference_points_wh = deltas[..., 2:].exp() * reference_points[..., 2:]
-     new_reference_points = torch.cat((new_reference_points_cxcy, new_reference_points_wh), -1)
-     return new_reference_points
-
-
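- # A tiny worked example of `refine_bboxes` on one box: with zero deltas the reference box
- # is returned unchanged (exp(0) == 1), which is why zero-initialized bbox heads start out
- # at the proposals. Values are illustrative.
- # >>> import torch
- # >>> ref = torch.tensor([[[0.5, 0.5, 0.2, 0.4]]])  # (batch, queries, 4) in (cx, cy, w, h)
- # >>> refine_bboxes(ref, torch.zeros(1, 1, 4))
- # tensor([[[0.5000, 0.5000, 0.2000, 0.4000]]])
-
-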
- @auto_docstring(
-     custom_intro="""
-     The bare LW Detr Model (consisting of a backbone and decoder Transformer) outputting raw
-     hidden-states without any specific head on top.
-     """
- )
- class LwDetrModel(LwDetrPreTrainedModel):
-     def __init__(self, config: LwDetrConfig):
-         super().__init__(config)
-
-         # Create backbone + positional encoding
-         self.backbone = LwDetrConvEncoder(config)
-
-         self.group_detr = config.group_detr
-         self.num_queries = config.num_queries
-         hidden_dim = config.d_model
-         self.reference_point_embed = nn.Embedding(self.num_queries * self.group_detr, 4)
-         self.query_feat = nn.Embedding(self.num_queries * self.group_detr, hidden_dim)
-
-         self.decoder = LwDetrDecoder(config)
-
-         self.enc_output = nn.ModuleList([nn.Linear(hidden_dim, hidden_dim) for _ in range(self.group_detr)])
-         self.enc_output_norm = nn.ModuleList([nn.LayerNorm(hidden_dim) for _ in range(self.group_detr)])
-         # Should normally be None and then instantiated in the ForObjectDetection class
-         self.enc_out_bbox_embed = nn.ModuleList(
-             [LwDetrMLPPredictionHead(config.d_model, config.d_model, 4, num_layers=3) for _ in range(self.group_detr)]
-         )
-         self.enc_out_class_embed = nn.ModuleList(
-             [nn.Linear(config.d_model, config.num_labels) for _ in range(self.group_detr)]
-         )
-
-         self.post_init()
-
-     def freeze_backbone(self):
-         for param in self.backbone.backbone.parameters():
-             param.requires_grad_(False)
-
-     def unfreeze_backbone(self):
-         for param in self.backbone.backbone.parameters():
-             param.requires_grad_(True)
-
-     def get_valid_ratio(self, mask, dtype=torch.float32):
-         """Get the valid ratio of all feature maps."""
-
-         _, height, width = mask.shape
-         valid_height = torch.sum(mask[:, :, 0], 1)
-         valid_width = torch.sum(mask[:, 0, :], 1)
-         valid_ratio_height = valid_height.to(dtype) / height
-         valid_ratio_width = valid_width.to(dtype) / width
-         valid_ratio = torch.stack([valid_ratio_width, valid_ratio_height], -1)
-         return valid_ratio
-
-     def get_proposal_pos_embed(self, proposals):
-         """Get the position embedding of the proposals."""
-
-         num_pos_feats = self.config.d_model // 2
-         temperature = 10000
-         scale = 2 * math.pi
-
-         dim_t = torch.arange(num_pos_feats, dtype=proposals.dtype, device=proposals.device)
-         dim_t = temperature ** (2 * torch.div(dim_t, 2, rounding_mode="floor") / num_pos_feats)
-         # batch_size, num_queries, 4
-         proposals = proposals.sigmoid() * scale
-         # batch_size, num_queries, 4, 128
-         pos = proposals[:, :, :, None] / dim_t
-         # batch_size, num_queries, 4, 64, 2 -> batch_size, num_queries, 512
-         pos = torch.stack((pos[:, :, :, 0::2].sin(), pos[:, :, :, 1::2].cos()), dim=4).flatten(2)
-         return pos
-
-     def gen_encoder_output_proposals(self, enc_output, padding_mask, spatial_shapes):
-         """Generate the encoder output proposals from encoded enc_output.
-
-         Args:
-             enc_output (Tensor[batch_size, sequence_length, hidden_size]): Output of the encoder.
-             padding_mask (Tensor[batch_size, sequence_length]): Padding mask for `enc_output`.
-             spatial_shapes (list[tuple[int, int]]): Spatial shapes of the feature maps.
-
-         Returns:
-             `tuple(torch.FloatTensor)`: A tuple of feature map and bbox prediction.
-             - object_query (Tensor[batch_size, sequence_length, hidden_size]): Object query features. Later used to
-               directly predict a bounding box (without the need of a decoder).
-             - output_proposals (Tensor[batch_size, sequence_length, 4]): Normalized proposals, after an inverse
-               sigmoid.
-         """
-         batch_size = enc_output.shape[0]
-         proposals = []
-         _cur = 0
-         for level, (height, width) in enumerate(spatial_shapes):
-             mask_flatten_ = padding_mask[:, _cur : (_cur + height * width)].view(batch_size, height, width, 1)
-             valid_height = torch.sum(~mask_flatten_[:, :, 0, 0], 1)
-             valid_width = torch.sum(~mask_flatten_[:, 0, :, 0], 1)
-
-             grid_y, grid_x = meshgrid(
-                 torch.linspace(
-                     0,
-                     height - 1,
-                     height,
-                     dtype=enc_output.dtype,
-                     device=enc_output.device,
-                 ),
-                 torch.linspace(
-                     0,
-                     width - 1,
-                     width,
-                     dtype=enc_output.dtype,
-                     device=enc_output.device,
-                 ),
-                 indexing="ij",
-             )
-             grid = torch.cat([grid_x.unsqueeze(-1), grid_y.unsqueeze(-1)], -1)
-
-             scale = torch.cat([valid_width.unsqueeze(-1), valid_height.unsqueeze(-1)], 1).view(batch_size, 1, 1, 2)
-             grid = (grid.unsqueeze(0).expand(batch_size, -1, -1, -1) + 0.5) / scale
-             width_height = torch.ones_like(grid) * 0.05 * (2.0**level)
-             proposal = torch.cat((grid, width_height), -1).view(batch_size, -1, 4)
-             proposals.append(proposal)
-             _cur += height * width
-         output_proposals = torch.cat(proposals, 1)
-         output_proposals_valid = ((output_proposals > 0.01) & (output_proposals < 0.99)).all(-1, keepdim=True)
-         output_proposals = output_proposals.masked_fill(padding_mask.unsqueeze(-1), float("inf"))
-         output_proposals = output_proposals.masked_fill(~output_proposals_valid, float("inf"))
-
-         # assign each pixel as an object query
-         object_query = enc_output
-         object_query = object_query.masked_fill(padding_mask.unsqueeze(-1), float(0))
-         object_query = object_query.masked_fill(~output_proposals_valid, float(0))
-         return object_query, output_proposals
-
-     @check_model_inputs
-     @auto_docstring
-     def forward(
-         self,
-         pixel_values: torch.FloatTensor = None,
-         pixel_mask: torch.LongTensor | None = None,
-         **kwargs: Unpack[TransformersKwargs],
-     ) -> LwDetrModelOutput:
-         r"""
-         Examples:
-
-         ```python
-         >>> from transformers import AutoImageProcessor, LwDetrModel
-         >>> from PIL import Image
-         >>> import httpx
-         >>> from io import BytesIO
-
-         >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
-         >>> with httpx.stream("GET", url) as response:
-         ...     image = Image.open(BytesIO(response.read()))
-
-         >>> image_processor = AutoImageProcessor.from_pretrained("AnnaZhang/lwdetr_small_60e_coco")
-         >>> model = LwDetrModel.from_pretrained("AnnaZhang/lwdetr_small_60e_coco")
-
-         >>> inputs = image_processor(images=image, return_tensors="pt")
-
-         >>> outputs = model(**inputs)
-
-         >>> last_hidden_states = outputs.last_hidden_state
-         >>> list(last_hidden_states.shape)
-         [1, 300, 256]
-         ```"""
- batch_size, num_channels, height, width = pixel_values.shape
1374
- device = pixel_values.device
1375
-
1376
- if pixel_mask is None:
1377
- pixel_mask = torch.ones(((batch_size, height, width)), dtype=torch.long, device=device)
1378
-
1379
- # Extract multi-scale feature maps of same resolution `config.d_model` (cf Figure 4 in paper)
1380
- # First, sent pixel_values + pixel_mask through Backbone to obtain the features
1381
- # which is a list of tuples
1382
- features = self.backbone(pixel_values, pixel_mask)
1383
-
1384
- # Then, apply 1x1 convolution to reduce the channel dimension to d_model (256 by default)
1385
- sources = []
1386
- masks = []
1387
- for level, (source, mask) in enumerate(features):
1388
- sources.append(source)
1389
- masks.append(mask)
1390
- if mask is None:
1391
- raise ValueError("No attention mask was provided")
1392
-
1393
- if self.training:
1394
- reference_points = self.reference_point_embed.weight
1395
-            query_feat = self.query_feat.weight
-        else:
-            # only use one group in inference
-            reference_points = self.reference_point_embed.weight[: self.num_queries]
-            query_feat = self.query_feat.weight[: self.num_queries]
-
-        # Prepare encoder inputs (by flattening)
-        source_flatten = []
-        mask_flatten = []
-        spatial_shapes_list = []
-        for source, mask in zip(sources, masks):
-            batch_size, num_channels, height, width = source.shape
-            spatial_shape = (height, width)
-            spatial_shapes_list.append(spatial_shape)
-            source = source.flatten(2).transpose(1, 2)
-            mask = mask.flatten(1)
-            source_flatten.append(source)
-            mask_flatten.append(mask)
-        source_flatten = torch.cat(source_flatten, 1)
-        mask_flatten = torch.cat(mask_flatten, 1)
-        spatial_shapes = torch.as_tensor(spatial_shapes_list, dtype=torch.long, device=source_flatten.device)
-        level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))
-        valid_ratios = torch.stack([self.get_valid_ratio(m, dtype=source_flatten.dtype) for m in masks], 1)
-
-        target = query_feat.unsqueeze(0).expand(batch_size, -1, -1)
-        reference_points = reference_points.unsqueeze(0).expand(batch_size, -1, -1)
-
-        object_query_embedding, output_proposals = self.gen_encoder_output_proposals(
-            source_flatten, ~mask_flatten, spatial_shapes_list
-        )
-
-        group_detr = self.group_detr if self.training else 1
-        topk = self.num_queries
-        topk_coords_logits = []
-        topk_coords_logits_undetach = []
-        object_query_undetach = []
-
-        for group_id in range(group_detr):
-            group_object_query = self.enc_output[group_id](object_query_embedding)
-            group_object_query = self.enc_output_norm[group_id](group_object_query)
-
-            group_enc_outputs_class = self.enc_out_class_embed[group_id](group_object_query)
-            group_delta_bbox = self.enc_out_bbox_embed[group_id](group_object_query)
-            group_enc_outputs_coord = refine_bboxes(output_proposals, group_delta_bbox)
-
-            group_topk_proposals = torch.topk(group_enc_outputs_class.max(-1)[0], topk, dim=1)[1]
-            group_topk_coords_logits_undetach = torch.gather(
-                group_enc_outputs_coord,
-                1,
-                group_topk_proposals.unsqueeze(-1).repeat(1, 1, 4),
-            )
-            group_topk_coords_logits = group_topk_coords_logits_undetach.detach()
-            group_object_query_undetach = torch.gather(
-                group_object_query, 1, group_topk_proposals.unsqueeze(-1).repeat(1, 1, self.config.d_model)
-            )
-
-            topk_coords_logits.append(group_topk_coords_logits)
-            topk_coords_logits_undetach.append(group_topk_coords_logits_undetach)
-            object_query_undetach.append(group_object_query_undetach)
-
-        topk_coords_logits = torch.cat(topk_coords_logits, 1)
-        topk_coords_logits_undetach = torch.cat(topk_coords_logits_undetach, 1)
-        object_query_undetach = torch.cat(object_query_undetach, 1)
-
-        enc_outputs_class = object_query_undetach
-        enc_outputs_coord_logits = topk_coords_logits
-
-        reference_points = refine_bboxes(topk_coords_logits_undetach, reference_points)
-
-        init_reference_points = reference_points
-        decoder_outputs = self.decoder(
-            inputs_embeds=target,
-            reference_points=reference_points,
-            spatial_shapes=spatial_shapes,
-            spatial_shapes_list=spatial_shapes_list,
-            level_start_index=level_start_index,
-            valid_ratios=valid_ratios,
-            encoder_hidden_states=source_flatten,
-            encoder_attention_mask=mask_flatten,
-            **kwargs,
-        )
-
-        return LwDetrModelOutput(
-            init_reference_points=init_reference_points,
-            last_hidden_state=decoder_outputs.last_hidden_state,
-            intermediate_hidden_states=decoder_outputs.intermediate_hidden_states,
-            intermediate_reference_points=decoder_outputs.intermediate_reference_points,
-            enc_outputs_class=enc_outputs_class,
-            enc_outputs_coord_logits=enc_outputs_coord_logits,
-        )
-
-
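The proposal-selection loop removed above is the two-stage heart of this model: each Group-DETR group scores every encoder proposal, keeps the `num_queries` highest-scoring ones, and gathers their box logits and query embeddings. A minimal, self-contained sketch of the topk-plus-gather mechanics, with purely hypothetical shapes (2 images, 300 proposals, 91 classes, 100 kept; none of these numbers come from the model):

```python
import torch

scores = torch.randn(2, 300, 91)  # stand-in for group_enc_outputs_class
coords = torch.rand(2, 300, 4)    # stand-in for group_enc_outputs_coord

# Rank proposals by their best class score, keep the top 100 per image.
topk_idx = torch.topk(scores.max(-1)[0], k=100, dim=1)[1]  # (2, 100)

# gather needs the index expanded to match the trailing dim being gathered,
# hence the repeat to 4 (one index per box coordinate).
topk_coords = torch.gather(coords, 1, topk_idx.unsqueeze(-1).repeat(1, 1, 4))
print(topk_coords.shape)  # torch.Size([2, 100, 4])
```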
1487
-class LwDetrMLPPredictionHead(nn.Module):
-    """
-    Very simple multi-layer perceptron (MLP, also called FFN), used to predict the normalized center coordinates,
-    height and width of a bounding box w.r.t. an image.
-
-    Copied from https://github.com/facebookresearch/detr/blob/master/models/detr.py
-    """
-
-    def __init__(self, input_dim, hidden_dim, output_dim, num_layers):
-        super().__init__()
-        self.num_layers = num_layers
-        h = [hidden_dim] * (num_layers - 1)
-        self.layers = nn.ModuleList(nn.Linear(n, k) for n, k in zip([input_dim] + h, h + [output_dim]))
-
-    def forward(self, x):
-        for i, layer in enumerate(self.layers):
-            x = nn.functional.relu(layer(x)) if i < self.num_layers - 1 else layer(x)
-        return x
-
-
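The zip trick above pairs each layer's input width with its output width, and ReLU is applied after every layer except the last. For `num_layers=3` the head is equivalent to the hand-rolled stack below (a `d_model` of 256 is a hypothetical choice for illustration, not taken from any checkpoint):

```python
import torch
import torch.nn as nn

# Equivalent to LwDetrMLPPredictionHead(256, 256, 4, num_layers=3):
# linear, ReLU, linear, ReLU, linear, with no activation on the output.
head = nn.Sequential(
    nn.Linear(256, 256), nn.ReLU(),
    nn.Linear(256, 256), nn.ReLU(),
    nn.Linear(256, 4),
)
queries = torch.randn(2, 100, 256)  # (batch_size, num_queries, d_model)
print(head(queries).shape)          # torch.Size([2, 100, 4])
```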
1508
-@dataclass
-@auto_docstring(
-    custom_intro="""
-    Output type of [`LwDetrForObjectDetection`].
-    """
-)
-class LwDetrObjectDetectionOutput(ModelOutput):
-    r"""
-    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` are provided):
-        Total loss as a linear combination of a negative log-likelihood (cross-entropy) for class prediction and a
-        bounding box loss. The latter is defined as a linear combination of the L1 loss and the generalized
-        scale-invariant IoU loss.
-    loss_dict (`Dict`, *optional*):
-        A dictionary containing the individual losses. Useful for logging.
-    logits (`torch.FloatTensor` of shape `(batch_size, num_queries, num_classes + 1)`):
-        Classification logits (including no-object) for all queries.
-    pred_boxes (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
-        Normalized box coordinates for all queries, represented as (center_x, center_y, width, height). These
-        values are normalized in [0, 1], relative to the size of each individual image in the batch (disregarding
-        possible padding). You can use [`~DeformableDetrImageProcessor.post_process_object_detection`] to retrieve
-        the unnormalized bounding boxes.
-    auxiliary_outputs (`list[Dict]`, *optional*):
-        Optional, only returned when auxiliary losses are activated (i.e. `config.auxiliary_loss` is set to `True`)
-        and labels are provided. It is a list of dictionaries containing the two above keys (`logits` and
-        `pred_boxes`) for each decoder layer.
-    init_reference_points (`torch.FloatTensor` of shape `(batch_size, num_queries, 4)`):
-        Initial reference points sent through the Transformer decoder.
-    intermediate_hidden_states (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, hidden_size)`):
-        Stacked intermediate hidden states (output of each layer of the decoder).
-    intermediate_reference_points (`torch.FloatTensor` of shape `(batch_size, config.decoder_layers, num_queries, 4)`):
-        Stacked intermediate reference points (reference points of each layer of the decoder).
-    enc_outputs_class (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.num_labels)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
-        Predicted bounding box scores where the top `config.two_stage_num_proposals` scoring bounding boxes are
-        picked as region proposals in the first stage. Output of bounding box binary classification (i.e.
-        foreground and background).
-    enc_outputs_coord_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, 4)`, *optional*, returned when `config.with_box_refine=True` and `config.two_stage=True`):
-        Logits of predicted bounding box coordinates in the first stage.
-    """
-
-    loss: torch.FloatTensor | None = None
-    loss_dict: dict | None = None
-    logits: torch.FloatTensor | None = None
-    pred_boxes: torch.FloatTensor | None = None
-    auxiliary_outputs: list[dict] | None = None
-    init_reference_points: torch.FloatTensor | None = None
-    last_hidden_state: torch.FloatTensor | None = None
-    intermediate_hidden_states: torch.FloatTensor | None = None
-    intermediate_reference_points: torch.FloatTensor | None = None
-    enc_outputs_class: Any = None
-    enc_outputs_coord_logits: torch.FloatTensor | None = None
-
-
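As the `pred_boxes` description says, boxes come back as normalized `(center_x, center_y, width, height)`. The processor's `post_process_object_detection` handles the conversion for you, but doing it by hand takes only a few lines; the helper below is an illustrative sketch (its name and signature are not part of the library), assuming the image's pixel width and height are known:

```python
import torch

def center_to_corners(boxes: torch.Tensor, width: int, height: int) -> torch.Tensor:
    """Convert normalized (cx, cy, w, h) boxes to absolute (xmin, ymin, xmax, ymax)."""
    center_x, center_y, w, h = boxes.unbind(-1)
    corners = torch.stack(
        [center_x - 0.5 * w, center_y - 0.5 * h, center_x + 0.5 * w, center_y + 0.5 * h],
        dim=-1,
    )
    # scale from [0, 1] into pixel coordinates
    return corners * corners.new_tensor([width, height, width, height])
```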
1560
-@auto_docstring(
-    custom_intro="""
-    LW DETR Model (consisting of a backbone and decoder Transformer) with object detection heads on
-    top, for tasks such as COCO detection.
-    """
-)
-class LwDetrForObjectDetection(LwDetrPreTrainedModel):
-    # When using clones, all layers > 0 will be clones, but layer 0 *is* required
-    # We can't initialize the model on meta device as some weights are modified during the initialization
-    _no_split_modules = None
-    _tied_weights_keys = None
-
-    def __init__(self, config: LwDetrConfig):
-        super().__init__(config)
-        self.model = LwDetrModel(config)
-        self.class_embed = nn.Linear(config.d_model, config.num_labels)
-        self.bbox_embed = LwDetrMLPPredictionHead(config.d_model, config.d_model, 4, num_layers=3)
-
-        self.post_init()
-
-    @check_model_inputs
-    @auto_docstring
-    def forward(
-        self,
-        pixel_values: torch.FloatTensor = None,
-        pixel_mask: torch.LongTensor | None = None,
-        labels: list[dict] | None = None,
-        **kwargs: Unpack[TransformersKwargs],
-    ) -> LwDetrObjectDetectionOutput:
-        r"""
-        decoder_attention_mask (`torch.FloatTensor` of shape `(batch_size, num_queries)`, *optional*):
-            Not used by default. Can be used to mask object queries.
-        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
-            Optionally, instead of passing the flattened feature map (output of the backbone + projection layer), you
-            can choose to directly pass a flattened representation of an image.
-        decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_queries, hidden_size)`, *optional*):
-            Optionally, instead of initializing the queries with a tensor of zeros, you can choose to directly pass an
-            embedded representation.
-        labels (`list[Dict]` of len `(batch_size,)`, *optional*):
-            Labels for computing the bipartite matching loss. List of dicts, each dictionary containing at least the
-            following 2 keys: 'class_labels' and 'boxes' (the class labels and bounding boxes of an image in the batch
-            respectively). The class labels themselves should be a `torch.LongTensor` of len `(number of bounding boxes
-            in the image,)` and the boxes a `torch.FloatTensor` of shape `(number of bounding boxes in the image, 4)`.
-
-        Examples:
-
-        ```python
-        >>> import torch
-        >>> from transformers import AutoImageProcessor, LwDetrForObjectDetection
-        >>> from PIL import Image
-        >>> import httpx
-        >>> from io import BytesIO
-
-        >>> url = "http://images.cocodataset.org/val2017/000000039769.jpg"
-        >>> with httpx.stream("GET", url) as response:
-        ...     image = Image.open(BytesIO(response.read()))
-
-        >>> image_processor = AutoImageProcessor.from_pretrained("AnnaZhang/lwdetr_small_60e_coco")
-        >>> model = LwDetrForObjectDetection.from_pretrained("AnnaZhang/lwdetr_small_60e_coco")
-
-        >>> inputs = image_processor(images=image, return_tensors="pt")
-        >>> outputs = model(**inputs)
-
-        >>> # convert outputs (bounding boxes and class logits) to Pascal VOC format (xmin, ymin, xmax, ymax)
-        >>> target_sizes = torch.tensor([image.size[::-1]])
-        >>> results = image_processor.post_process_object_detection(outputs, threshold=0.5, target_sizes=target_sizes)[
-        ...     0
-        ... ]
-        >>> for score, label, box in zip(results["scores"], results["labels"], results["boxes"]):
-        ...     box = [round(i, 2) for i in box.tolist()]
-        ...     print(
-        ...         f"Detected {model.config.id2label[label.item()]} with confidence "
-        ...         f"{round(score.item(), 3)} at location {box}"
-        ...     )
-        Detected cat with confidence 0.8 at location [16.5, 52.84, 318.25, 470.78]
-        Detected cat with confidence 0.789 at location [342.19, 24.3, 640.02, 372.25]
-        Detected remote with confidence 0.633 at location [40.79, 72.78, 176.76, 117.25]
-        ```"""
-        outputs = self.model(
-            pixel_values,
-            pixel_mask=pixel_mask,
-            **kwargs,
-        )
-
-        last_hidden_states = outputs.last_hidden_state
-        intermediate_reference_points = outputs.intermediate_reference_points
-        enc_outputs_class_logits = outputs.enc_outputs_class
-        enc_outputs_boxes_logits = outputs.enc_outputs_coord_logits
-
-        logits = self.class_embed(last_hidden_states)
-        pred_boxes_delta = self.bbox_embed(last_hidden_states)
-        pred_boxes = refine_bboxes(intermediate_reference_points[-1], pred_boxes_delta)
-
-        enc_outputs_class_logits_list = enc_outputs_class_logits.split(self.config.num_queries, dim=1)
-        pred_class = []
-        group_detr = self.config.group_detr if self.training else 1
-        for group_index in range(group_detr):
-            group_pred_class = self.model.enc_out_class_embed[group_index](enc_outputs_class_logits_list[group_index])
-            pred_class.append(group_pred_class)
-        enc_outputs_class_logits = torch.cat(pred_class, dim=1)
-
-        loss, loss_dict, auxiliary_outputs = None, None, None
-        if labels is not None:
-            outputs_class, outputs_coord = None, None
-            if self.config.auxiliary_loss:
-                intermediate_hidden_states = outputs.intermediate_hidden_states
-                outputs_coord_delta = self.bbox_embed(intermediate_hidden_states)
-                outputs_coord = refine_bboxes(intermediate_reference_points, outputs_coord_delta)
-                outputs_class = self.class_embed(intermediate_hidden_states)
-
-            loss, loss_dict, auxiliary_outputs = self.loss_function(
-                logits,
-                labels,
-                self.device,
-                pred_boxes,
-                self.config,
-                outputs_class,
-                outputs_coord,
-                enc_outputs_class_logits,
-                enc_outputs_boxes_logits,
-            )
-
-        return LwDetrObjectDetectionOutput(
-            loss=loss,
-            loss_dict=loss_dict,
-            logits=logits,
-            pred_boxes=pred_boxes,
-            auxiliary_outputs=auxiliary_outputs,
-            last_hidden_state=outputs.last_hidden_state,
-            intermediate_hidden_states=outputs.intermediate_hidden_states,
-            intermediate_reference_points=outputs.intermediate_reference_points,
-            init_reference_points=outputs.init_reference_points,
-            enc_outputs_class=enc_outputs_class_logits,
-            enc_outputs_coord_logits=enc_outputs_boxes_logits,
-        )
-
-
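Both classes lean on `refine_bboxes`, defined earlier in the file, to fold a predicted delta into a set of reference boxes. Deformable-DETR-style detectors conventionally apply such deltas in inverse-sigmoid (logit) space; the sketch below shows that convention and may differ in detail from this file's actual definition:

```python
import torch

def refine_bboxes_sketch(reference: torch.Tensor, delta: torch.Tensor) -> torch.Tensor:
    # Map normalized boxes into logit space, apply the head's raw output as
    # an additive correction, then squash back into [0, 1].
    eps = 1e-5
    reference = reference.clamp(min=eps, max=1 - eps)
    logits = torch.log(reference / (1 - reference)) + delta
    return logits.sigmoid()
```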
1696
-__all__ = [
-    "LwDetrPreTrainedModel",
-    "LwDetrModel",
-    "LwDetrForObjectDetection",
-    "LwDetrViTPreTrainedModel",
-    "LwDetrViTBackbone",
-]