transformers 5.0.0__py3-none-any.whl → 5.0.0rc0__py3-none-any.whl

This diff shows the changes between publicly available package versions as released to their respective public registries, and is provided for informational purposes only.
Files changed (1606)
  1. transformers/__init__.py +36 -55
  2. transformers/activations.py +1 -1
  3. transformers/audio_utils.py +33 -32
  4. transformers/cache_utils.py +139 -32
  5. transformers/cli/chat.py +3 -3
  6. transformers/cli/serve.py +19 -49
  7. transformers/cli/transformers.py +1 -2
  8. transformers/configuration_utils.py +155 -129
  9. transformers/conversion_mapping.py +22 -158
  10. transformers/convert_slow_tokenizer.py +17 -227
  11. transformers/core_model_loading.py +185 -528
  12. transformers/data/data_collator.py +4 -12
  13. transformers/data/processors/glue.py +1 -0
  14. transformers/data/processors/utils.py +1 -0
  15. transformers/data/processors/xnli.py +1 -0
  16. transformers/dependency_versions_check.py +1 -0
  17. transformers/dependency_versions_table.py +7 -5
  18. transformers/distributed/configuration_utils.py +2 -1
  19. transformers/dynamic_module_utils.py +25 -24
  20. transformers/feature_extraction_sequence_utils.py +23 -19
  21. transformers/feature_extraction_utils.py +33 -64
  22. transformers/file_utils.py +1 -0
  23. transformers/generation/__init__.py +1 -11
  24. transformers/generation/candidate_generator.py +33 -80
  25. transformers/generation/configuration_utils.py +133 -189
  26. transformers/generation/continuous_batching/__init__.py +1 -4
  27. transformers/generation/continuous_batching/cache.py +25 -83
  28. transformers/generation/continuous_batching/cache_manager.py +45 -155
  29. transformers/generation/continuous_batching/continuous_api.py +147 -270
  30. transformers/generation/continuous_batching/requests.py +3 -51
  31. transformers/generation/continuous_batching/scheduler.py +105 -160
  32. transformers/generation/logits_process.py +128 -0
  33. transformers/generation/stopping_criteria.py +1 -1
  34. transformers/generation/streamers.py +1 -0
  35. transformers/generation/utils.py +123 -122
  36. transformers/generation/watermarking.py +6 -8
  37. transformers/hf_argparser.py +13 -9
  38. transformers/hyperparameter_search.py +2 -1
  39. transformers/image_processing_base.py +23 -12
  40. transformers/image_processing_utils.py +15 -11
  41. transformers/image_processing_utils_fast.py +75 -85
  42. transformers/image_transforms.py +42 -73
  43. transformers/image_utils.py +32 -30
  44. transformers/initialization.py +0 -37
  45. transformers/integrations/__init__.py +2 -16
  46. transformers/integrations/accelerate.py +113 -58
  47. transformers/integrations/aqlm.py +66 -36
  48. transformers/integrations/awq.py +516 -45
  49. transformers/integrations/bitnet.py +105 -47
  50. transformers/integrations/bitsandbytes.py +202 -91
  51. transformers/integrations/deepspeed.py +4 -161
  52. transformers/integrations/eetq.py +82 -84
  53. transformers/integrations/executorch.py +1 -1
  54. transformers/integrations/fbgemm_fp8.py +145 -190
  55. transformers/integrations/finegrained_fp8.py +215 -249
  56. transformers/integrations/flash_attention.py +3 -3
  57. transformers/integrations/flex_attention.py +1 -1
  58. transformers/integrations/fp_quant.py +0 -90
  59. transformers/integrations/ggml.py +2 -11
  60. transformers/integrations/higgs.py +62 -37
  61. transformers/integrations/hub_kernels.py +8 -65
  62. transformers/integrations/integration_utils.py +3 -47
  63. transformers/integrations/mistral.py +0 -12
  64. transformers/integrations/mxfp4.py +80 -33
  65. transformers/integrations/peft.py +191 -483
  66. transformers/integrations/quanto.py +56 -77
  67. transformers/integrations/spqr.py +90 -42
  68. transformers/integrations/tensor_parallel.py +221 -167
  69. transformers/integrations/torchao.py +43 -35
  70. transformers/integrations/vptq.py +59 -40
  71. transformers/kernels/__init__.py +0 -0
  72. transformers/{models/pe_audio_video/processing_pe_audio_video.py → kernels/falcon_mamba/__init__.py} +3 -12
  73. transformers/kernels/falcon_mamba/selective_scan_with_ln_interface.py +529 -0
  74. transformers/loss/loss_utils.py +0 -2
  75. transformers/masking_utils.py +55 -51
  76. transformers/model_debugging_utils.py +5 -4
  77. transformers/modelcard.py +194 -15
  78. transformers/modeling_attn_mask_utils.py +19 -19
  79. transformers/modeling_flash_attention_utils.py +27 -27
  80. transformers/modeling_gguf_pytorch_utils.py +24 -79
  81. transformers/modeling_layers.py +22 -21
  82. transformers/modeling_outputs.py +253 -242
  83. transformers/modeling_rope_utils.py +117 -138
  84. transformers/modeling_utils.py +739 -850
  85. transformers/models/__init__.py +0 -27
  86. transformers/models/afmoe/configuration_afmoe.py +33 -40
  87. transformers/models/afmoe/modeling_afmoe.py +54 -42
  88. transformers/models/afmoe/modular_afmoe.py +33 -23
  89. transformers/models/aimv2/configuration_aimv2.py +10 -2
  90. transformers/models/aimv2/modeling_aimv2.py +42 -47
  91. transformers/models/aimv2/modular_aimv2.py +19 -17
  92. transformers/models/albert/configuration_albert.py +2 -8
  93. transformers/models/albert/modeling_albert.py +69 -70
  94. transformers/models/albert/tokenization_albert.py +14 -5
  95. transformers/models/align/configuration_align.py +6 -8
  96. transformers/models/align/modeling_align.py +89 -94
  97. transformers/models/align/processing_align.py +30 -2
  98. transformers/models/altclip/configuration_altclip.py +7 -4
  99. transformers/models/altclip/modeling_altclip.py +103 -114
  100. transformers/models/altclip/processing_altclip.py +15 -2
  101. transformers/models/apertus/__init__.py +1 -0
  102. transformers/models/apertus/configuration_apertus.py +28 -23
  103. transformers/models/apertus/modeling_apertus.py +40 -39
  104. transformers/models/apertus/modular_apertus.py +38 -37
  105. transformers/models/arcee/configuration_arcee.py +30 -25
  106. transformers/models/arcee/modeling_arcee.py +39 -36
  107. transformers/models/arcee/modular_arcee.py +23 -20
  108. transformers/models/aria/configuration_aria.py +44 -31
  109. transformers/models/aria/image_processing_aria.py +27 -25
  110. transformers/models/aria/modeling_aria.py +106 -110
  111. transformers/models/aria/modular_aria.py +127 -118
  112. transformers/models/aria/processing_aria.py +35 -28
  113. transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +1 -0
  114. transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py +6 -3
  115. transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +8 -6
  116. transformers/models/audioflamingo3/__init__.py +1 -0
  117. transformers/models/audioflamingo3/configuration_audioflamingo3.py +1 -0
  118. transformers/models/audioflamingo3/modeling_audioflamingo3.py +49 -58
  119. transformers/models/audioflamingo3/modular_audioflamingo3.py +43 -53
  120. transformers/models/audioflamingo3/processing_audioflamingo3.py +30 -33
  121. transformers/models/auto/auto_factory.py +7 -6
  122. transformers/models/auto/configuration_auto.py +5 -66
  123. transformers/models/auto/feature_extraction_auto.py +10 -14
  124. transformers/models/auto/image_processing_auto.py +41 -32
  125. transformers/models/auto/modeling_auto.py +188 -46
  126. transformers/models/auto/processing_auto.py +11 -24
  127. transformers/models/auto/tokenization_auto.py +588 -171
  128. transformers/models/auto/video_processing_auto.py +10 -12
  129. transformers/models/autoformer/configuration_autoformer.py +7 -4
  130. transformers/models/autoformer/modeling_autoformer.py +101 -104
  131. transformers/models/aya_vision/configuration_aya_vision.py +1 -4
  132. transformers/models/aya_vision/modeling_aya_vision.py +102 -71
  133. transformers/models/aya_vision/modular_aya_vision.py +74 -46
  134. transformers/models/aya_vision/processing_aya_vision.py +53 -25
  135. transformers/models/bamba/configuration_bamba.py +39 -34
  136. transformers/models/bamba/modeling_bamba.py +86 -82
  137. transformers/models/bamba/modular_bamba.py +72 -70
  138. transformers/models/bark/configuration_bark.py +8 -6
  139. transformers/models/bark/generation_configuration_bark.py +5 -3
  140. transformers/models/bark/modeling_bark.py +57 -54
  141. transformers/models/bark/processing_bark.py +41 -19
  142. transformers/models/bart/configuration_bart.py +6 -9
  143. transformers/models/bart/modeling_bart.py +126 -135
  144. transformers/models/barthez/tokenization_barthez.py +11 -3
  145. transformers/models/bartpho/tokenization_bartpho.py +7 -6
  146. transformers/models/beit/configuration_beit.py +11 -0
  147. transformers/models/beit/image_processing_beit.py +56 -53
  148. transformers/models/beit/image_processing_beit_fast.py +12 -10
  149. transformers/models/beit/modeling_beit.py +60 -69
  150. transformers/models/bert/configuration_bert.py +2 -12
  151. transformers/models/bert/modeling_bert.py +122 -114
  152. transformers/models/bert/tokenization_bert.py +23 -8
  153. transformers/models/bert/tokenization_bert_legacy.py +5 -3
  154. transformers/models/bert_generation/configuration_bert_generation.py +2 -17
  155. transformers/models/bert_generation/modeling_bert_generation.py +49 -49
  156. transformers/models/bert_generation/tokenization_bert_generation.py +3 -2
  157. transformers/models/bert_japanese/tokenization_bert_japanese.py +6 -5
  158. transformers/models/bertweet/tokenization_bertweet.py +3 -1
  159. transformers/models/big_bird/configuration_big_bird.py +9 -12
  160. transformers/models/big_bird/modeling_big_bird.py +109 -116
  161. transformers/models/big_bird/tokenization_big_bird.py +43 -16
  162. transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py +9 -9
  163. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +117 -130
  164. transformers/models/biogpt/configuration_biogpt.py +2 -8
  165. transformers/models/biogpt/modeling_biogpt.py +76 -72
  166. transformers/models/biogpt/modular_biogpt.py +66 -62
  167. transformers/models/biogpt/tokenization_biogpt.py +5 -3
  168. transformers/models/bit/configuration_bit.py +1 -0
  169. transformers/models/bit/image_processing_bit.py +24 -21
  170. transformers/models/bit/image_processing_bit_fast.py +1 -0
  171. transformers/models/bit/modeling_bit.py +12 -25
  172. transformers/models/bitnet/configuration_bitnet.py +28 -23
  173. transformers/models/bitnet/modeling_bitnet.py +39 -36
  174. transformers/models/bitnet/modular_bitnet.py +6 -4
  175. transformers/models/blenderbot/configuration_blenderbot.py +5 -8
  176. transformers/models/blenderbot/modeling_blenderbot.py +96 -77
  177. transformers/models/blenderbot/tokenization_blenderbot.py +24 -18
  178. transformers/models/blenderbot_small/configuration_blenderbot_small.py +5 -8
  179. transformers/models/blenderbot_small/modeling_blenderbot_small.py +69 -79
  180. transformers/models/blenderbot_small/tokenization_blenderbot_small.py +3 -1
  181. transformers/models/blip/configuration_blip.py +10 -9
  182. transformers/models/blip/image_processing_blip.py +20 -17
  183. transformers/models/blip/image_processing_blip_fast.py +1 -0
  184. transformers/models/blip/modeling_blip.py +108 -117
  185. transformers/models/blip/modeling_blip_text.py +65 -73
  186. transformers/models/blip/processing_blip.py +36 -5
  187. transformers/models/blip_2/configuration_blip_2.py +2 -2
  188. transformers/models/blip_2/modeling_blip_2.py +118 -146
  189. transformers/models/blip_2/processing_blip_2.py +38 -8
  190. transformers/models/bloom/configuration_bloom.py +2 -5
  191. transformers/models/bloom/modeling_bloom.py +104 -77
  192. transformers/models/blt/configuration_blt.py +86 -94
  193. transformers/models/blt/modeling_blt.py +81 -238
  194. transformers/models/blt/modular_blt.py +65 -228
  195. transformers/models/bridgetower/configuration_bridgetower.py +2 -7
  196. transformers/models/bridgetower/image_processing_bridgetower.py +35 -34
  197. transformers/models/bridgetower/image_processing_bridgetower_fast.py +16 -13
  198. transformers/models/bridgetower/modeling_bridgetower.py +119 -141
  199. transformers/models/bridgetower/processing_bridgetower.py +16 -2
  200. transformers/models/bros/configuration_bros.py +18 -24
  201. transformers/models/bros/modeling_bros.py +80 -90
  202. transformers/models/bros/processing_bros.py +12 -2
  203. transformers/models/byt5/tokenization_byt5.py +6 -4
  204. transformers/models/camembert/configuration_camembert.py +2 -8
  205. transformers/models/camembert/modeling_camembert.py +195 -196
  206. transformers/models/camembert/modular_camembert.py +54 -51
  207. transformers/models/camembert/tokenization_camembert.py +13 -6
  208. transformers/models/canine/configuration_canine.py +2 -4
  209. transformers/models/canine/modeling_canine.py +75 -84
  210. transformers/models/canine/tokenization_canine.py +1 -2
  211. transformers/models/chameleon/configuration_chameleon.py +34 -29
  212. transformers/models/chameleon/image_processing_chameleon.py +24 -21
  213. transformers/models/chameleon/image_processing_chameleon_fast.py +6 -5
  214. transformers/models/chameleon/modeling_chameleon.py +93 -142
  215. transformers/models/chameleon/processing_chameleon.py +41 -16
  216. transformers/models/chinese_clip/configuration_chinese_clip.py +8 -10
  217. transformers/models/chinese_clip/image_processing_chinese_clip.py +24 -21
  218. transformers/models/chinese_clip/image_processing_chinese_clip_fast.py +1 -0
  219. transformers/models/chinese_clip/modeling_chinese_clip.py +92 -96
  220. transformers/models/chinese_clip/processing_chinese_clip.py +15 -2
  221. transformers/models/clap/configuration_clap.py +9 -4
  222. transformers/models/clap/feature_extraction_clap.py +12 -11
  223. transformers/models/clap/modeling_clap.py +123 -136
  224. transformers/models/clap/processing_clap.py +15 -2
  225. transformers/models/clip/configuration_clip.py +2 -4
  226. transformers/models/clip/image_processing_clip.py +24 -21
  227. transformers/models/clip/image_processing_clip_fast.py +1 -9
  228. transformers/models/clip/modeling_clip.py +65 -65
  229. transformers/models/clip/processing_clip.py +14 -2
  230. transformers/models/clip/tokenization_clip.py +46 -21
  231. transformers/models/clipseg/configuration_clipseg.py +2 -4
  232. transformers/models/clipseg/modeling_clipseg.py +109 -119
  233. transformers/models/clipseg/processing_clipseg.py +42 -19
  234. transformers/models/clvp/configuration_clvp.py +5 -15
  235. transformers/models/clvp/feature_extraction_clvp.py +10 -7
  236. transformers/models/clvp/modeling_clvp.py +146 -155
  237. transformers/models/clvp/number_normalizer.py +2 -1
  238. transformers/models/clvp/processing_clvp.py +20 -3
  239. transformers/models/clvp/tokenization_clvp.py +64 -1
  240. transformers/models/code_llama/tokenization_code_llama.py +44 -18
  241. transformers/models/codegen/configuration_codegen.py +4 -4
  242. transformers/models/codegen/modeling_codegen.py +53 -63
  243. transformers/models/codegen/tokenization_codegen.py +47 -17
  244. transformers/models/cohere/configuration_cohere.py +30 -25
  245. transformers/models/cohere/modeling_cohere.py +42 -40
  246. transformers/models/cohere/modular_cohere.py +29 -26
  247. transformers/models/cohere/tokenization_cohere.py +46 -15
  248. transformers/models/cohere2/configuration_cohere2.py +32 -31
  249. transformers/models/cohere2/modeling_cohere2.py +44 -42
  250. transformers/models/cohere2/modular_cohere2.py +54 -54
  251. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +14 -13
  252. transformers/models/cohere2_vision/modeling_cohere2_vision.py +58 -59
  253. transformers/models/cohere2_vision/modular_cohere2_vision.py +46 -45
  254. transformers/models/cohere2_vision/processing_cohere2_vision.py +36 -6
  255. transformers/models/colpali/configuration_colpali.py +1 -0
  256. transformers/models/colpali/modeling_colpali.py +16 -14
  257. transformers/models/colpali/modular_colpali.py +51 -11
  258. transformers/models/colpali/processing_colpali.py +52 -14
  259. transformers/models/colqwen2/modeling_colqwen2.py +28 -28
  260. transformers/models/colqwen2/modular_colqwen2.py +74 -37
  261. transformers/models/colqwen2/processing_colqwen2.py +52 -16
  262. transformers/models/conditional_detr/configuration_conditional_detr.py +2 -1
  263. transformers/models/conditional_detr/image_processing_conditional_detr.py +70 -67
  264. transformers/models/conditional_detr/image_processing_conditional_detr_fast.py +36 -36
  265. transformers/models/conditional_detr/modeling_conditional_detr.py +87 -99
  266. transformers/models/conditional_detr/modular_conditional_detr.py +3 -49
  267. transformers/models/convbert/configuration_convbert.py +8 -11
  268. transformers/models/convbert/modeling_convbert.py +87 -94
  269. transformers/models/convbert/tokenization_convbert.py +1 -0
  270. transformers/models/convnext/configuration_convnext.py +1 -0
  271. transformers/models/convnext/image_processing_convnext.py +23 -20
  272. transformers/models/convnext/image_processing_convnext_fast.py +21 -16
  273. transformers/models/convnext/modeling_convnext.py +12 -9
  274. transformers/models/convnextv2/configuration_convnextv2.py +1 -0
  275. transformers/models/convnextv2/modeling_convnextv2.py +12 -9
  276. transformers/models/cpm/tokenization_cpm.py +7 -6
  277. transformers/models/cpm/tokenization_cpm_fast.py +5 -3
  278. transformers/models/cpmant/configuration_cpmant.py +1 -4
  279. transformers/models/cpmant/modeling_cpmant.py +40 -38
  280. transformers/models/cpmant/tokenization_cpmant.py +3 -1
  281. transformers/models/csm/configuration_csm.py +66 -58
  282. transformers/models/csm/generation_csm.py +35 -31
  283. transformers/models/csm/modeling_csm.py +85 -85
  284. transformers/models/csm/modular_csm.py +58 -58
  285. transformers/models/csm/processing_csm.py +68 -25
  286. transformers/models/ctrl/configuration_ctrl.py +1 -16
  287. transformers/models/ctrl/modeling_ctrl.py +44 -54
  288. transformers/models/ctrl/tokenization_ctrl.py +1 -0
  289. transformers/models/cvt/configuration_cvt.py +1 -0
  290. transformers/models/cvt/modeling_cvt.py +16 -20
  291. transformers/models/cwm/__init__.py +1 -0
  292. transformers/models/cwm/configuration_cwm.py +12 -8
  293. transformers/models/cwm/modeling_cwm.py +39 -37
  294. transformers/models/cwm/modular_cwm.py +12 -10
  295. transformers/models/d_fine/configuration_d_fine.py +5 -7
  296. transformers/models/d_fine/modeling_d_fine.py +128 -138
  297. transformers/models/d_fine/modular_d_fine.py +18 -33
  298. transformers/models/dab_detr/configuration_dab_detr.py +3 -6
  299. transformers/models/dab_detr/modeling_dab_detr.py +75 -81
  300. transformers/models/dac/configuration_dac.py +1 -0
  301. transformers/models/dac/feature_extraction_dac.py +9 -6
  302. transformers/models/dac/modeling_dac.py +26 -24
  303. transformers/models/data2vec/configuration_data2vec_audio.py +2 -4
  304. transformers/models/data2vec/configuration_data2vec_text.py +3 -11
  305. transformers/models/data2vec/configuration_data2vec_vision.py +1 -0
  306. transformers/models/data2vec/modeling_data2vec_audio.py +56 -57
  307. transformers/models/data2vec/modeling_data2vec_text.py +93 -98
  308. transformers/models/data2vec/modeling_data2vec_vision.py +45 -49
  309. transformers/models/data2vec/modular_data2vec_audio.py +1 -6
  310. transformers/models/data2vec/modular_data2vec_text.py +54 -58
  311. transformers/models/dbrx/configuration_dbrx.py +22 -36
  312. transformers/models/dbrx/modeling_dbrx.py +45 -42
  313. transformers/models/dbrx/modular_dbrx.py +33 -31
  314. transformers/models/deberta/configuration_deberta.py +1 -6
  315. transformers/models/deberta/modeling_deberta.py +60 -64
  316. transformers/models/deberta/tokenization_deberta.py +21 -9
  317. transformers/models/deberta_v2/configuration_deberta_v2.py +1 -6
  318. transformers/models/deberta_v2/modeling_deberta_v2.py +65 -71
  319. transformers/models/deberta_v2/tokenization_deberta_v2.py +29 -11
  320. transformers/models/decision_transformer/configuration_decision_transformer.py +2 -3
  321. transformers/models/decision_transformer/modeling_decision_transformer.py +56 -60
  322. transformers/models/deepseek_v2/configuration_deepseek_v2.py +44 -39
  323. transformers/models/deepseek_v2/modeling_deepseek_v2.py +43 -43
  324. transformers/models/deepseek_v2/modular_deepseek_v2.py +49 -48
  325. transformers/models/deepseek_v3/configuration_deepseek_v3.py +45 -40
  326. transformers/models/deepseek_v3/modeling_deepseek_v3.py +42 -45
  327. transformers/models/deepseek_v3/modular_deepseek_v3.py +9 -14
  328. transformers/models/deepseek_vl/configuration_deepseek_vl.py +3 -2
  329. transformers/models/deepseek_vl/image_processing_deepseek_vl.py +26 -25
  330. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +10 -10
  331. transformers/models/deepseek_vl/modeling_deepseek_vl.py +48 -57
  332. transformers/models/deepseek_vl/modular_deepseek_vl.py +43 -14
  333. transformers/models/deepseek_vl/processing_deepseek_vl.py +41 -10
  334. transformers/models/deepseek_vl_hybrid/configuration_deepseek_vl_hybrid.py +5 -3
  335. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py +35 -35
  336. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +24 -20
  337. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +61 -109
  338. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +118 -146
  339. transformers/models/deepseek_vl_hybrid/processing_deepseek_vl_hybrid.py +44 -12
  340. transformers/models/deformable_detr/configuration_deformable_detr.py +3 -2
  341. transformers/models/deformable_detr/image_processing_deformable_detr.py +61 -59
  342. transformers/models/deformable_detr/image_processing_deformable_detr_fast.py +28 -28
  343. transformers/models/deformable_detr/modeling_deformable_detr.py +82 -88
  344. transformers/models/deformable_detr/modular_deformable_detr.py +3 -1
  345. transformers/models/deit/configuration_deit.py +1 -0
  346. transformers/models/deit/image_processing_deit.py +21 -18
  347. transformers/models/deit/image_processing_deit_fast.py +1 -0
  348. transformers/models/deit/modeling_deit.py +22 -24
  349. transformers/models/depth_anything/configuration_depth_anything.py +4 -2
  350. transformers/models/depth_anything/modeling_depth_anything.py +10 -10
  351. transformers/models/depth_pro/configuration_depth_pro.py +1 -0
  352. transformers/models/depth_pro/image_processing_depth_pro.py +23 -22
  353. transformers/models/depth_pro/image_processing_depth_pro_fast.py +10 -8
  354. transformers/models/depth_pro/modeling_depth_pro.py +27 -31
  355. transformers/models/detr/configuration_detr.py +2 -1
  356. transformers/models/detr/image_processing_detr.py +66 -64
  357. transformers/models/detr/image_processing_detr_fast.py +34 -33
  358. transformers/models/detr/modeling_detr.py +79 -95
  359. transformers/models/dia/configuration_dia.py +15 -9
  360. transformers/models/dia/feature_extraction_dia.py +9 -6
  361. transformers/models/dia/generation_dia.py +50 -48
  362. transformers/models/dia/modeling_dia.py +69 -78
  363. transformers/models/dia/modular_dia.py +56 -64
  364. transformers/models/dia/processing_dia.py +29 -39
  365. transformers/models/dia/tokenization_dia.py +6 -3
  366. transformers/models/diffllama/configuration_diffllama.py +30 -25
  367. transformers/models/diffllama/modeling_diffllama.py +49 -46
  368. transformers/models/diffllama/modular_diffllama.py +19 -17
  369. transformers/models/dinat/configuration_dinat.py +1 -0
  370. transformers/models/dinat/modeling_dinat.py +44 -47
  371. transformers/models/dinov2/configuration_dinov2.py +1 -0
  372. transformers/models/dinov2/modeling_dinov2.py +15 -15
  373. transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py +1 -1
  374. transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +15 -16
  375. transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py +9 -9
  376. transformers/models/dinov3_convnext/configuration_dinov3_convnext.py +7 -4
  377. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +6 -3
  378. transformers/models/dinov3_vit/configuration_dinov3_vit.py +8 -5
  379. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +9 -7
  380. transformers/models/dinov3_vit/modeling_dinov3_vit.py +18 -19
  381. transformers/models/dinov3_vit/modular_dinov3_vit.py +15 -16
  382. transformers/models/distilbert/configuration_distilbert.py +2 -8
  383. transformers/models/distilbert/modeling_distilbert.py +55 -55
  384. transformers/models/distilbert/tokenization_distilbert.py +1 -13
  385. transformers/models/doge/__init__.py +1 -0
  386. transformers/models/doge/configuration_doge.py +32 -39
  387. transformers/models/doge/modeling_doge.py +49 -45
  388. transformers/models/doge/modular_doge.py +63 -71
  389. transformers/models/donut/configuration_donut_swin.py +1 -0
  390. transformers/models/donut/image_processing_donut.py +29 -26
  391. transformers/models/donut/image_processing_donut_fast.py +15 -9
  392. transformers/models/donut/modeling_donut_swin.py +58 -62
  393. transformers/models/donut/processing_donut.py +26 -5
  394. transformers/models/dots1/configuration_dots1.py +33 -41
  395. transformers/models/dots1/modeling_dots1.py +45 -54
  396. transformers/models/dots1/modular_dots1.py +4 -5
  397. transformers/models/dpr/configuration_dpr.py +2 -19
  398. transformers/models/dpr/modeling_dpr.py +39 -42
  399. transformers/models/dpr/tokenization_dpr.py +9 -19
  400. transformers/models/dpr/tokenization_dpr_fast.py +9 -7
  401. transformers/models/dpt/configuration_dpt.py +2 -1
  402. transformers/models/dpt/image_processing_dpt.py +66 -65
  403. transformers/models/dpt/image_processing_dpt_fast.py +20 -18
  404. transformers/models/dpt/modeling_dpt.py +30 -32
  405. transformers/models/dpt/modular_dpt.py +17 -15
  406. transformers/models/edgetam/configuration_edgetam.py +3 -2
  407. transformers/models/edgetam/modeling_edgetam.py +86 -86
  408. transformers/models/edgetam/modular_edgetam.py +26 -21
  409. transformers/models/edgetam_video/__init__.py +1 -0
  410. transformers/models/edgetam_video/configuration_edgetam_video.py +1 -0
  411. transformers/models/edgetam_video/modeling_edgetam_video.py +158 -169
  412. transformers/models/edgetam_video/modular_edgetam_video.py +37 -30
  413. transformers/models/efficientloftr/configuration_efficientloftr.py +5 -4
  414. transformers/models/efficientloftr/image_processing_efficientloftr.py +16 -14
  415. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +9 -9
  416. transformers/models/efficientloftr/modeling_efficientloftr.py +38 -59
  417. transformers/models/efficientloftr/modular_efficientloftr.py +3 -1
  418. transformers/models/efficientnet/configuration_efficientnet.py +1 -0
  419. transformers/models/efficientnet/image_processing_efficientnet.py +32 -28
  420. transformers/models/efficientnet/image_processing_efficientnet_fast.py +19 -17
  421. transformers/models/efficientnet/modeling_efficientnet.py +15 -19
  422. transformers/models/electra/configuration_electra.py +3 -13
  423. transformers/models/electra/modeling_electra.py +103 -108
  424. transformers/models/emu3/configuration_emu3.py +17 -13
  425. transformers/models/emu3/image_processing_emu3.py +39 -44
  426. transformers/models/emu3/modeling_emu3.py +108 -148
  427. transformers/models/emu3/modular_emu3.py +73 -115
  428. transformers/models/emu3/processing_emu3.py +43 -18
  429. transformers/models/encodec/configuration_encodec.py +4 -2
  430. transformers/models/encodec/feature_extraction_encodec.py +13 -10
  431. transformers/models/encodec/modeling_encodec.py +29 -39
  432. transformers/models/encoder_decoder/configuration_encoder_decoder.py +2 -12
  433. transformers/models/encoder_decoder/modeling_encoder_decoder.py +43 -37
  434. transformers/models/eomt/configuration_eomt.py +1 -0
  435. transformers/models/eomt/image_processing_eomt.py +56 -66
  436. transformers/models/eomt/image_processing_eomt_fast.py +33 -76
  437. transformers/models/eomt/modeling_eomt.py +18 -23
  438. transformers/models/eomt/modular_eomt.py +13 -18
  439. transformers/models/ernie/configuration_ernie.py +3 -24
  440. transformers/models/ernie/modeling_ernie.py +132 -127
  441. transformers/models/ernie/modular_ernie.py +103 -97
  442. transformers/models/ernie4_5/configuration_ernie4_5.py +27 -23
  443. transformers/models/ernie4_5/modeling_ernie4_5.py +38 -36
  444. transformers/models/ernie4_5/modular_ernie4_5.py +4 -3
  445. transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py +36 -32
  446. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +55 -56
  447. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +46 -18
  448. transformers/models/esm/configuration_esm.py +15 -11
  449. transformers/models/esm/modeling_esm.py +34 -38
  450. transformers/models/esm/modeling_esmfold.py +49 -53
  451. transformers/models/esm/openfold_utils/chunk_utils.py +6 -6
  452. transformers/models/esm/openfold_utils/loss.py +2 -1
  453. transformers/models/esm/openfold_utils/protein.py +16 -15
  454. transformers/models/esm/openfold_utils/tensor_utils.py +6 -6
  455. transformers/models/esm/tokenization_esm.py +4 -2
  456. transformers/models/evolla/configuration_evolla.py +40 -50
  457. transformers/models/evolla/modeling_evolla.py +66 -71
  458. transformers/models/evolla/modular_evolla.py +47 -53
  459. transformers/models/evolla/processing_evolla.py +35 -23
  460. transformers/models/exaone4/configuration_exaone4.py +25 -23
  461. transformers/models/exaone4/modeling_exaone4.py +38 -35
  462. transformers/models/exaone4/modular_exaone4.py +46 -44
  463. transformers/models/falcon/configuration_falcon.py +26 -31
  464. transformers/models/falcon/modeling_falcon.py +80 -82
  465. transformers/models/falcon_h1/configuration_falcon_h1.py +51 -45
  466. transformers/models/falcon_h1/modeling_falcon_h1.py +82 -85
  467. transformers/models/falcon_h1/modular_falcon_h1.py +51 -56
  468. transformers/models/falcon_mamba/configuration_falcon_mamba.py +2 -1
  469. transformers/models/falcon_mamba/modeling_falcon_mamba.py +82 -75
  470. transformers/models/falcon_mamba/modular_falcon_mamba.py +45 -28
  471. transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py +6 -2
  472. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +60 -76
  473. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +3 -2
  474. transformers/models/flaubert/configuration_flaubert.py +5 -10
  475. transformers/models/flaubert/modeling_flaubert.py +143 -145
  476. transformers/models/flaubert/tokenization_flaubert.py +5 -3
  477. transformers/models/flava/configuration_flava.py +6 -5
  478. transformers/models/flava/image_processing_flava.py +67 -66
  479. transformers/models/flava/image_processing_flava_fast.py +49 -46
  480. transformers/models/flava/modeling_flava.py +136 -153
  481. transformers/models/flava/processing_flava.py +12 -2
  482. transformers/models/flex_olmo/__init__.py +1 -0
  483. transformers/models/flex_olmo/configuration_flex_olmo.py +32 -28
  484. transformers/models/flex_olmo/modeling_flex_olmo.py +47 -47
  485. transformers/models/flex_olmo/modular_flex_olmo.py +44 -40
  486. transformers/models/florence2/configuration_florence2.py +1 -0
  487. transformers/models/florence2/modeling_florence2.py +69 -111
  488. transformers/models/florence2/modular_florence2.py +101 -104
  489. transformers/models/florence2/processing_florence2.py +47 -18
  490. transformers/models/fnet/configuration_fnet.py +2 -6
  491. transformers/models/fnet/modeling_fnet.py +80 -83
  492. transformers/models/fnet/tokenization_fnet.py +1 -0
  493. transformers/models/focalnet/configuration_focalnet.py +1 -0
  494. transformers/models/focalnet/modeling_focalnet.py +45 -51
  495. transformers/models/fsmt/configuration_fsmt.py +17 -12
  496. transformers/models/fsmt/modeling_fsmt.py +48 -49
  497. transformers/models/fsmt/tokenization_fsmt.py +5 -3
  498. transformers/models/funnel/configuration_funnel.py +1 -8
  499. transformers/models/funnel/modeling_funnel.py +93 -99
  500. transformers/models/funnel/tokenization_funnel.py +27 -17
  501. transformers/models/fuyu/configuration_fuyu.py +34 -28
  502. transformers/models/fuyu/image_processing_fuyu.py +31 -29
  503. transformers/models/fuyu/image_processing_fuyu_fast.py +17 -17
  504. transformers/models/fuyu/modeling_fuyu.py +53 -53
  505. transformers/models/fuyu/processing_fuyu.py +34 -23
  506. transformers/models/gemma/configuration_gemma.py +30 -25
  507. transformers/models/gemma/modeling_gemma.py +50 -46
  508. transformers/models/gemma/modular_gemma.py +47 -42
  509. transformers/models/gemma/tokenization_gemma.py +30 -10
  510. transformers/models/gemma2/configuration_gemma2.py +35 -30
  511. transformers/models/gemma2/modeling_gemma2.py +42 -39
  512. transformers/models/gemma2/modular_gemma2.py +66 -63
  513. transformers/models/gemma3/configuration_gemma3.py +44 -44
  514. transformers/models/gemma3/image_processing_gemma3.py +31 -29
  515. transformers/models/gemma3/image_processing_gemma3_fast.py +13 -11
  516. transformers/models/gemma3/modeling_gemma3.py +207 -159
  517. transformers/models/gemma3/modular_gemma3.py +204 -153
  518. transformers/models/gemma3/processing_gemma3.py +5 -5
  519. transformers/models/gemma3n/configuration_gemma3n.py +26 -36
  520. transformers/models/gemma3n/feature_extraction_gemma3n.py +11 -9
  521. transformers/models/gemma3n/modeling_gemma3n.py +356 -222
  522. transformers/models/gemma3n/modular_gemma3n.py +207 -230
  523. transformers/models/gemma3n/processing_gemma3n.py +26 -12
  524. transformers/models/git/configuration_git.py +8 -5
  525. transformers/models/git/modeling_git.py +204 -266
  526. transformers/models/git/processing_git.py +14 -2
  527. transformers/models/glm/configuration_glm.py +28 -24
  528. transformers/models/glm/modeling_glm.py +40 -37
  529. transformers/models/glm/modular_glm.py +7 -4
  530. transformers/models/glm4/configuration_glm4.py +28 -24
  531. transformers/models/glm4/modeling_glm4.py +42 -40
  532. transformers/models/glm4/modular_glm4.py +10 -8
  533. transformers/models/glm46v/configuration_glm46v.py +1 -0
  534. transformers/models/glm46v/image_processing_glm46v.py +40 -35
  535. transformers/models/glm46v/image_processing_glm46v_fast.py +9 -9
  536. transformers/models/glm46v/modeling_glm46v.py +90 -137
  537. transformers/models/glm46v/modular_glm46v.py +3 -4
  538. transformers/models/glm46v/processing_glm46v.py +41 -7
  539. transformers/models/glm46v/video_processing_glm46v.py +11 -9
  540. transformers/models/glm4_moe/configuration_glm4_moe.py +32 -40
  541. transformers/models/glm4_moe/modeling_glm4_moe.py +42 -45
  542. transformers/models/glm4_moe/modular_glm4_moe.py +34 -42
  543. transformers/models/glm4v/configuration_glm4v.py +20 -18
  544. transformers/models/glm4v/image_processing_glm4v.py +40 -34
  545. transformers/models/glm4v/image_processing_glm4v_fast.py +9 -8
  546. transformers/models/glm4v/modeling_glm4v.py +205 -254
  547. transformers/models/glm4v/modular_glm4v.py +224 -210
  548. transformers/models/glm4v/processing_glm4v.py +41 -7
  549. transformers/models/glm4v/video_processing_glm4v.py +11 -9
  550. transformers/models/glm4v_moe/configuration_glm4v_moe.py +125 -136
  551. transformers/models/glm4v_moe/modeling_glm4v_moe.py +368 -377
  552. transformers/models/glm4v_moe/modular_glm4v_moe.py +169 -83
  553. transformers/models/glpn/configuration_glpn.py +1 -0
  554. transformers/models/glpn/image_processing_glpn.py +12 -11
  555. transformers/models/glpn/image_processing_glpn_fast.py +13 -11
  556. transformers/models/glpn/modeling_glpn.py +14 -16
  557. transformers/models/got_ocr2/configuration_got_ocr2.py +12 -4
  558. transformers/models/got_ocr2/image_processing_got_ocr2.py +24 -22
  559. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +11 -9
  560. transformers/models/got_ocr2/modeling_got_ocr2.py +80 -77
  561. transformers/models/got_ocr2/modular_got_ocr2.py +51 -54
  562. transformers/models/got_ocr2/processing_got_ocr2.py +63 -42
  563. transformers/models/gpt2/configuration_gpt2.py +2 -13
  564. transformers/models/gpt2/modeling_gpt2.py +115 -120
  565. transformers/models/gpt2/tokenization_gpt2.py +46 -15
  566. transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +2 -5
  567. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +89 -79
  568. transformers/models/gpt_neo/configuration_gpt_neo.py +2 -9
  569. transformers/models/gpt_neo/modeling_gpt_neo.py +67 -83
  570. transformers/models/gpt_neox/configuration_gpt_neox.py +25 -25
  571. transformers/models/gpt_neox/modeling_gpt_neox.py +75 -76
  572. transformers/models/gpt_neox/modular_gpt_neox.py +66 -67
  573. transformers/models/gpt_neox/tokenization_gpt_neox.py +51 -9
  574. transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py +19 -24
  575. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +47 -46
  576. transformers/models/gpt_neox_japanese/tokenization_gpt_neox_japanese.py +3 -1
  577. transformers/models/gpt_oss/configuration_gpt_oss.py +28 -46
  578. transformers/models/gpt_oss/modeling_gpt_oss.py +121 -83
  579. transformers/models/gpt_oss/modular_gpt_oss.py +103 -64
  580. transformers/models/gpt_sw3/tokenization_gpt_sw3.py +4 -4
  581. transformers/models/gptj/configuration_gptj.py +4 -4
  582. transformers/models/gptj/modeling_gptj.py +87 -101
  583. transformers/models/granite/configuration_granite.py +33 -28
  584. transformers/models/granite/modeling_granite.py +46 -44
  585. transformers/models/granite/modular_granite.py +31 -29
  586. transformers/models/granite_speech/configuration_granite_speech.py +1 -0
  587. transformers/models/granite_speech/feature_extraction_granite_speech.py +3 -1
  588. transformers/models/granite_speech/modeling_granite_speech.py +52 -82
  589. transformers/models/granite_speech/processing_granite_speech.py +4 -11
  590. transformers/models/granitemoe/configuration_granitemoe.py +36 -31
  591. transformers/models/granitemoe/modeling_granitemoe.py +46 -41
  592. transformers/models/granitemoe/modular_granitemoe.py +27 -22
  593. transformers/models/granitemoehybrid/__init__.py +1 -0
  594. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +47 -46
  595. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +93 -97
  596. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +21 -54
  597. transformers/models/granitemoeshared/configuration_granitemoeshared.py +37 -33
  598. transformers/models/granitemoeshared/modeling_granitemoeshared.py +61 -54
  599. transformers/models/granitemoeshared/modular_granitemoeshared.py +21 -19
  600. transformers/models/grounding_dino/configuration_grounding_dino.py +4 -6
  601. transformers/models/grounding_dino/image_processing_grounding_dino.py +62 -60
  602. transformers/models/grounding_dino/image_processing_grounding_dino_fast.py +29 -28
  603. transformers/models/grounding_dino/modeling_grounding_dino.py +140 -155
  604. transformers/models/grounding_dino/modular_grounding_dino.py +3 -2
  605. transformers/models/grounding_dino/processing_grounding_dino.py +38 -10
  606. transformers/models/groupvit/configuration_groupvit.py +2 -4
  607. transformers/models/groupvit/modeling_groupvit.py +93 -107
  608. transformers/models/helium/configuration_helium.py +29 -25
  609. transformers/models/helium/modeling_helium.py +40 -38
  610. transformers/models/helium/modular_helium.py +7 -3
  611. transformers/models/herbert/tokenization_herbert.py +28 -10
  612. transformers/models/hgnet_v2/configuration_hgnet_v2.py +1 -0
  613. transformers/models/hgnet_v2/modeling_hgnet_v2.py +10 -24
  614. transformers/models/hgnet_v2/modular_hgnet_v2.py +10 -24
  615. transformers/models/hiera/configuration_hiera.py +1 -0
  616. transformers/models/hiera/modeling_hiera.py +66 -72
  617. transformers/models/hubert/configuration_hubert.py +2 -4
  618. transformers/models/hubert/modeling_hubert.py +37 -42
  619. transformers/models/hubert/modular_hubert.py +11 -13
  620. transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py +31 -26
  621. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +38 -35
  622. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +6 -4
  623. transformers/models/hunyuan_v1_moe/__init__.py +1 -1
  624. transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py +36 -31
  625. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +42 -47
  626. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +9 -9
  627. transformers/models/ibert/configuration_ibert.py +2 -4
  628. transformers/models/ibert/modeling_ibert.py +62 -82
  629. transformers/models/ibert/quant_modules.py +1 -0
  630. transformers/models/idefics/configuration_idefics.py +8 -5
  631. transformers/models/idefics/image_processing_idefics.py +15 -13
  632. transformers/models/idefics/modeling_idefics.py +82 -75
  633. transformers/models/idefics/perceiver.py +3 -1
  634. transformers/models/idefics/processing_idefics.py +48 -32
  635. transformers/models/idefics/vision.py +25 -24
  636. transformers/models/idefics2/configuration_idefics2.py +3 -1
  637. transformers/models/idefics2/image_processing_idefics2.py +32 -31
  638. transformers/models/idefics2/image_processing_idefics2_fast.py +8 -8
  639. transformers/models/idefics2/modeling_idefics2.py +101 -127
  640. transformers/models/idefics2/processing_idefics2.py +68 -10
  641. transformers/models/idefics3/configuration_idefics3.py +4 -1
  642. transformers/models/idefics3/image_processing_idefics3.py +43 -42
  643. transformers/models/idefics3/image_processing_idefics3_fast.py +15 -40
  644. transformers/models/idefics3/modeling_idefics3.py +90 -115
  645. transformers/models/idefics3/processing_idefics3.py +69 -15
  646. transformers/models/ijepa/configuration_ijepa.py +1 -0
  647. transformers/models/ijepa/modeling_ijepa.py +11 -10
  648. transformers/models/ijepa/modular_ijepa.py +7 -5
  649. transformers/models/imagegpt/configuration_imagegpt.py +2 -9
  650. transformers/models/imagegpt/image_processing_imagegpt.py +18 -17
  651. transformers/models/imagegpt/image_processing_imagegpt_fast.py +16 -11
  652. transformers/models/imagegpt/modeling_imagegpt.py +65 -76
  653. transformers/models/informer/configuration_informer.py +9 -6
  654. transformers/models/informer/modeling_informer.py +86 -88
  655. transformers/models/informer/modular_informer.py +16 -14
  656. transformers/models/instructblip/configuration_instructblip.py +2 -2
  657. transformers/models/instructblip/modeling_instructblip.py +63 -103
  658. transformers/models/instructblip/processing_instructblip.py +36 -10
  659. transformers/models/instructblipvideo/configuration_instructblipvideo.py +2 -2
  660. transformers/models/instructblipvideo/modeling_instructblipvideo.py +139 -157
  661. transformers/models/instructblipvideo/modular_instructblipvideo.py +64 -73
  662. transformers/models/instructblipvideo/processing_instructblipvideo.py +33 -14
  663. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +8 -6
  664. transformers/models/internvl/configuration_internvl.py +1 -0
  665. transformers/models/internvl/modeling_internvl.py +106 -85
  666. transformers/models/internvl/modular_internvl.py +67 -47
  667. transformers/models/internvl/processing_internvl.py +45 -12
  668. transformers/models/internvl/video_processing_internvl.py +12 -10
  669. transformers/models/jamba/configuration_jamba.py +8 -5
  670. transformers/models/jamba/modeling_jamba.py +66 -68
  671. transformers/models/jamba/modular_jamba.py +55 -54
  672. transformers/models/janus/configuration_janus.py +1 -0
  673. transformers/models/janus/image_processing_janus.py +37 -35
  674. transformers/models/janus/image_processing_janus_fast.py +20 -18
  675. transformers/models/janus/modeling_janus.py +191 -115
  676. transformers/models/janus/modular_janus.py +84 -133
  677. transformers/models/janus/processing_janus.py +43 -17
  678. transformers/models/jetmoe/configuration_jetmoe.py +26 -24
  679. transformers/models/jetmoe/modeling_jetmoe.py +46 -43
  680. transformers/models/jetmoe/modular_jetmoe.py +33 -31
  681. transformers/models/kosmos2/configuration_kosmos2.py +9 -10
  682. transformers/models/kosmos2/modeling_kosmos2.py +173 -208
  683. transformers/models/kosmos2/processing_kosmos2.py +55 -40
  684. transformers/models/kosmos2_5/__init__.py +1 -0
  685. transformers/models/kosmos2_5/configuration_kosmos2_5.py +9 -8
  686. transformers/models/kosmos2_5/image_processing_kosmos2_5.py +12 -10
  687. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +13 -4
  688. transformers/models/kosmos2_5/modeling_kosmos2_5.py +118 -132
  689. transformers/models/kosmos2_5/processing_kosmos2_5.py +29 -8
  690. transformers/models/kyutai_speech_to_text/configuration_kyutai_speech_to_text.py +28 -31
  691. transformers/models/kyutai_speech_to_text/feature_extraction_kyutai_speech_to_text.py +14 -12
  692. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +100 -110
  693. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +22 -28
  694. transformers/models/kyutai_speech_to_text/processing_kyutai_speech_to_text.py +8 -2
  695. transformers/models/layoutlm/configuration_layoutlm.py +2 -14
  696. transformers/models/layoutlm/modeling_layoutlm.py +72 -77
  697. transformers/models/layoutlmv2/configuration_layoutlmv2.py +17 -14
  698. transformers/models/layoutlmv2/image_processing_layoutlmv2.py +21 -18
  699. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +9 -7
  700. transformers/models/layoutlmv2/modeling_layoutlmv2.py +50 -64
  701. transformers/models/layoutlmv2/processing_layoutlmv2.py +44 -14
  702. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +126 -73
  703. transformers/models/layoutlmv3/configuration_layoutlmv3.py +19 -16
  704. transformers/models/layoutlmv3/image_processing_layoutlmv3.py +26 -24
  705. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +11 -9
  706. transformers/models/layoutlmv3/modeling_layoutlmv3.py +56 -82
  707. transformers/models/layoutlmv3/processing_layoutlmv3.py +46 -14
  708. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +134 -74
  709. transformers/models/layoutxlm/configuration_layoutxlm.py +17 -14
  710. transformers/models/layoutxlm/modular_layoutxlm.py +1 -0
  711. transformers/models/layoutxlm/processing_layoutxlm.py +44 -14
  712. transformers/models/layoutxlm/tokenization_layoutxlm.py +113 -77
  713. transformers/models/led/configuration_led.py +12 -8
  714. transformers/models/led/modeling_led.py +266 -124
  715. transformers/models/levit/configuration_levit.py +1 -0
  716. transformers/models/levit/image_processing_levit.py +21 -19
  717. transformers/models/levit/image_processing_levit_fast.py +5 -4
  718. transformers/models/levit/modeling_levit.py +19 -38
  719. transformers/models/lfm2/configuration_lfm2.py +30 -27
  720. transformers/models/lfm2/modeling_lfm2.py +50 -47
  721. transformers/models/lfm2/modular_lfm2.py +30 -29
  722. transformers/models/lfm2_moe/__init__.py +1 -0
  723. transformers/models/lfm2_moe/configuration_lfm2_moe.py +9 -6
  724. transformers/models/lfm2_moe/modeling_lfm2_moe.py +53 -61
  725. transformers/models/lfm2_moe/modular_lfm2_moe.py +37 -13
  726. transformers/models/lfm2_vl/configuration_lfm2_vl.py +1 -4
  727. transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py +12 -41
  728. transformers/models/lfm2_vl/modeling_lfm2_vl.py +66 -84
  729. transformers/models/lfm2_vl/modular_lfm2_vl.py +56 -70
  730. transformers/models/lfm2_vl/processing_lfm2_vl.py +76 -96
  731. transformers/models/lightglue/image_processing_lightglue.py +15 -16
  732. transformers/models/lightglue/image_processing_lightglue_fast.py +9 -9
  733. transformers/models/lightglue/modeling_lightglue.py +31 -31
  734. transformers/models/lightglue/modular_lightglue.py +28 -29
  735. transformers/models/lilt/configuration_lilt.py +2 -6
  736. transformers/models/lilt/modeling_lilt.py +70 -76
  737. transformers/models/llama/configuration_llama.py +31 -26
  738. transformers/models/llama/modeling_llama.py +39 -36
  739. transformers/models/llama/tokenization_llama.py +44 -14
  740. transformers/models/llama4/configuration_llama4.py +30 -27
  741. transformers/models/llama4/image_processing_llama4_fast.py +14 -12
  742. transformers/models/llama4/modeling_llama4.py +113 -120
  743. transformers/models/llama4/processing_llama4.py +57 -33
  744. transformers/models/llava/configuration_llava.py +1 -10
  745. transformers/models/llava/image_processing_llava.py +28 -25
  746. transformers/models/llava/image_processing_llava_fast.py +11 -9
  747. transformers/models/llava/modeling_llava.py +109 -85
  748. transformers/models/llava/processing_llava.py +51 -18
  749. transformers/models/llava_next/configuration_llava_next.py +2 -2
  750. transformers/models/llava_next/image_processing_llava_next.py +45 -43
  751. transformers/models/llava_next/image_processing_llava_next_fast.py +13 -11
  752. transformers/models/llava_next/modeling_llava_next.py +107 -110
  753. transformers/models/llava_next/processing_llava_next.py +47 -18
  754. transformers/models/llava_next_video/configuration_llava_next_video.py +7 -4
  755. transformers/models/llava_next_video/modeling_llava_next_video.py +158 -175
  756. transformers/models/llava_next_video/modular_llava_next_video.py +150 -155
  757. transformers/models/llava_next_video/processing_llava_next_video.py +63 -21
  758. transformers/models/llava_next_video/video_processing_llava_next_video.py +1 -0
  759. transformers/models/llava_onevision/configuration_llava_onevision.py +7 -4
  760. transformers/models/llava_onevision/image_processing_llava_onevision.py +42 -40
  761. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +15 -14
  762. transformers/models/llava_onevision/modeling_llava_onevision.py +169 -177
  763. transformers/models/llava_onevision/modular_llava_onevision.py +156 -163
  764. transformers/models/llava_onevision/processing_llava_onevision.py +53 -21
  765. transformers/models/llava_onevision/video_processing_llava_onevision.py +1 -0
  766. transformers/models/longcat_flash/__init__.py +1 -0
  767. transformers/models/longcat_flash/configuration_longcat_flash.py +42 -37
  768. transformers/models/longcat_flash/modeling_longcat_flash.py +36 -36
  769. transformers/models/longcat_flash/modular_longcat_flash.py +21 -21
  770. transformers/models/longformer/configuration_longformer.py +5 -5
  771. transformers/models/longformer/modeling_longformer.py +101 -105
  772. transformers/models/longt5/configuration_longt5.py +7 -9
  773. transformers/models/longt5/modeling_longt5.py +49 -49
  774. transformers/models/luke/configuration_luke.py +2 -8
  775. transformers/models/luke/modeling_luke.py +181 -188
  776. transformers/models/luke/tokenization_luke.py +140 -107
  777. transformers/models/lxmert/configuration_lxmert.py +1 -16
  778. transformers/models/lxmert/modeling_lxmert.py +74 -65
  779. transformers/models/m2m_100/configuration_m2m_100.py +9 -7
  780. transformers/models/m2m_100/modeling_m2m_100.py +71 -83
  781. transformers/models/m2m_100/tokenization_m2m_100.py +8 -8
  782. transformers/models/mamba/configuration_mamba.py +2 -1
  783. transformers/models/mamba/modeling_mamba.py +66 -58
  784. transformers/models/mamba2/configuration_mamba2.py +8 -5
  785. transformers/models/mamba2/modeling_mamba2.py +69 -68
  786. transformers/models/marian/configuration_marian.py +5 -10
  787. transformers/models/marian/modeling_marian.py +87 -93
  788. transformers/models/marian/tokenization_marian.py +6 -6
  789. transformers/models/markuplm/configuration_markuplm.py +7 -4
  790. transformers/models/markuplm/feature_extraction_markuplm.py +2 -1
  791. transformers/models/markuplm/modeling_markuplm.py +70 -69
  792. transformers/models/markuplm/processing_markuplm.py +38 -31
  793. transformers/models/markuplm/tokenization_markuplm.py +136 -93
  794. transformers/models/mask2former/configuration_mask2former.py +8 -5
  795. transformers/models/mask2former/image_processing_mask2former.py +85 -84
  796. transformers/models/mask2former/image_processing_mask2former_fast.py +40 -37
  797. transformers/models/mask2former/modeling_mask2former.py +103 -118
  798. transformers/models/mask2former/modular_mask2former.py +8 -6
  799. transformers/models/maskformer/configuration_maskformer.py +9 -6
  800. transformers/models/maskformer/configuration_maskformer_swin.py +1 -0
  801. transformers/models/maskformer/image_processing_maskformer.py +85 -84
  802. transformers/models/maskformer/image_processing_maskformer_fast.py +40 -36
  803. transformers/models/maskformer/modeling_maskformer.py +65 -79
  804. transformers/models/maskformer/modeling_maskformer_swin.py +32 -36
  805. transformers/models/mbart/configuration_mbart.py +4 -9
  806. transformers/models/mbart/modeling_mbart.py +116 -131
  807. transformers/models/mbart/tokenization_mbart.py +54 -11
  808. transformers/models/mbart50/tokenization_mbart50.py +13 -8
  809. transformers/models/megatron_bert/configuration_megatron_bert.py +3 -13
  810. transformers/models/megatron_bert/modeling_megatron_bert.py +150 -148
  811. transformers/models/metaclip_2/configuration_metaclip_2.py +1 -4
  812. transformers/models/metaclip_2/modeling_metaclip_2.py +84 -91
  813. transformers/models/metaclip_2/modular_metaclip_2.py +45 -61
  814. transformers/models/mgp_str/configuration_mgp_str.py +1 -0
  815. transformers/models/mgp_str/modeling_mgp_str.py +18 -20
  816. transformers/models/mgp_str/processing_mgp_str.py +20 -3
  817. transformers/models/mgp_str/tokenization_mgp_str.py +3 -1
  818. transformers/models/mimi/configuration_mimi.py +40 -42
  819. transformers/models/mimi/modeling_mimi.py +113 -142
  820. transformers/models/minimax/__init__.py +1 -0
  821. transformers/models/minimax/configuration_minimax.py +43 -37
  822. transformers/models/minimax/modeling_minimax.py +51 -61
  823. transformers/models/minimax/modular_minimax.py +62 -68
  824. transformers/models/ministral/configuration_ministral.py +29 -25
  825. transformers/models/ministral/modeling_ministral.py +38 -36
  826. transformers/models/ministral/modular_ministral.py +37 -32
  827. transformers/models/ministral3/configuration_ministral3.py +27 -24
  828. transformers/models/ministral3/modeling_ministral3.py +37 -36
  829. transformers/models/ministral3/modular_ministral3.py +5 -4
  830. transformers/models/mistral/configuration_mistral.py +29 -24
  831. transformers/models/mistral/modeling_mistral.py +37 -36
  832. transformers/models/mistral/modular_mistral.py +12 -11
  833. transformers/models/mistral3/configuration_mistral3.py +1 -4
  834. transformers/models/mistral3/modeling_mistral3.py +86 -89
  835. transformers/models/mistral3/modular_mistral3.py +68 -69
  836. transformers/models/mixtral/configuration_mixtral.py +34 -29
  837. transformers/models/mixtral/modeling_mixtral.py +45 -50
  838. transformers/models/mixtral/modular_mixtral.py +31 -32
  839. transformers/models/mlcd/configuration_mlcd.py +1 -0
  840. transformers/models/mlcd/modeling_mlcd.py +14 -20
  841. transformers/models/mlcd/modular_mlcd.py +13 -17
  842. transformers/models/mllama/configuration_mllama.py +15 -10
  843. transformers/models/mllama/image_processing_mllama.py +25 -23
  844. transformers/models/mllama/image_processing_mllama_fast.py +11 -11
  845. transformers/models/mllama/modeling_mllama.py +94 -105
  846. transformers/models/mllama/processing_mllama.py +55 -6
  847. transformers/models/mluke/tokenization_mluke.py +107 -101
  848. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +3 -5
  849. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +140 -155
  850. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +3 -5
  851. transformers/models/mobilebert/configuration_mobilebert.py +2 -4
  852. transformers/models/mobilebert/modeling_mobilebert.py +85 -77
  853. transformers/models/mobilebert/tokenization_mobilebert.py +1 -0
  854. transformers/models/mobilenet_v1/configuration_mobilenet_v1.py +1 -0
  855. transformers/models/mobilenet_v1/image_processing_mobilenet_v1.py +23 -20
  856. transformers/models/mobilenet_v1/image_processing_mobilenet_v1_fast.py +1 -0
  857. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +16 -15
  858. transformers/models/mobilenet_v2/configuration_mobilenet_v2.py +1 -0
  859. transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py +51 -48
  860. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +15 -13
  861. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +22 -24
  862. transformers/models/mobilevit/configuration_mobilevit.py +1 -0
  863. transformers/models/mobilevit/image_processing_mobilevit.py +49 -46
  864. transformers/models/mobilevit/image_processing_mobilevit_fast.py +14 -12
  865. transformers/models/mobilevit/modeling_mobilevit.py +21 -28
  866. transformers/models/mobilevitv2/configuration_mobilevitv2.py +1 -0
  867. transformers/models/mobilevitv2/modeling_mobilevitv2.py +22 -28
  868. transformers/models/modernbert/configuration_modernbert.py +42 -44
  869. transformers/models/modernbert/modeling_modernbert.py +133 -145
  870. transformers/models/modernbert/modular_modernbert.py +170 -186
  871. transformers/models/modernbert_decoder/configuration_modernbert_decoder.py +40 -40
  872. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +57 -62
  873. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +86 -94
  874. transformers/models/moonshine/configuration_moonshine.py +31 -34
  875. transformers/models/moonshine/modeling_moonshine.py +71 -71
  876. transformers/models/moonshine/modular_moonshine.py +83 -88
  877. transformers/models/moshi/configuration_moshi.py +23 -46
  878. transformers/models/moshi/modeling_moshi.py +187 -157
  879. transformers/models/mpnet/configuration_mpnet.py +2 -6
  880. transformers/models/mpnet/modeling_mpnet.py +57 -62
  881. transformers/models/mpnet/tokenization_mpnet.py +15 -4
  882. transformers/models/mpt/configuration_mpt.py +9 -5
  883. transformers/models/mpt/modeling_mpt.py +60 -60
  884. transformers/models/mra/configuration_mra.py +2 -8
  885. transformers/models/mra/modeling_mra.py +57 -64
  886. transformers/models/mt5/configuration_mt5.py +8 -10
  887. transformers/models/mt5/modeling_mt5.py +95 -87
  888. transformers/models/musicgen/configuration_musicgen.py +8 -12
  889. transformers/models/musicgen/modeling_musicgen.py +122 -118
  890. transformers/models/musicgen/processing_musicgen.py +21 -3
  891. transformers/models/musicgen_melody/configuration_musicgen_melody.py +8 -15
  892. transformers/models/musicgen_melody/feature_extraction_musicgen_melody.py +9 -8
  893. transformers/models/musicgen_melody/modeling_musicgen_melody.py +123 -117
  894. transformers/models/musicgen_melody/processing_musicgen_melody.py +22 -3
  895. transformers/models/mvp/configuration_mvp.py +5 -8
  896. transformers/models/mvp/modeling_mvp.py +123 -135
  897. transformers/models/myt5/tokenization_myt5.py +10 -8
  898. transformers/models/nanochat/configuration_nanochat.py +8 -5
  899. transformers/models/nanochat/modeling_nanochat.py +40 -37
  900. transformers/models/nanochat/modular_nanochat.py +14 -12
  901. transformers/models/nemotron/configuration_nemotron.py +30 -25
  902. transformers/models/nemotron/modeling_nemotron.py +57 -56
  903. transformers/models/nllb/tokenization_nllb.py +28 -12
  904. transformers/models/nllb_moe/configuration_nllb_moe.py +9 -7
  905. transformers/models/nllb_moe/modeling_nllb_moe.py +69 -77
  906. transformers/models/nougat/image_processing_nougat.py +32 -29
  907. transformers/models/nougat/image_processing_nougat_fast.py +14 -12
  908. transformers/models/nougat/processing_nougat.py +39 -37
  909. transformers/models/nougat/tokenization_nougat.py +73 -18
  910. transformers/models/nystromformer/configuration_nystromformer.py +2 -8
  911. transformers/models/nystromformer/modeling_nystromformer.py +63 -74
  912. transformers/models/olmo/configuration_olmo.py +28 -23
  913. transformers/models/olmo/modeling_olmo.py +39 -36
  914. transformers/models/olmo/modular_olmo.py +11 -7
  915. transformers/models/olmo2/configuration_olmo2.py +28 -23
  916. transformers/models/olmo2/modeling_olmo2.py +41 -37
  917. transformers/models/olmo2/modular_olmo2.py +32 -29
  918. transformers/models/olmo3/__init__.py +1 -0
  919. transformers/models/olmo3/configuration_olmo3.py +30 -26
  920. transformers/models/olmo3/modeling_olmo3.py +39 -36
  921. transformers/models/olmo3/modular_olmo3.py +40 -37
  922. transformers/models/olmoe/configuration_olmoe.py +33 -29
  923. transformers/models/olmoe/modeling_olmoe.py +46 -52
  924. transformers/models/olmoe/modular_olmoe.py +15 -16
  925. transformers/models/omdet_turbo/configuration_omdet_turbo.py +4 -2
  926. transformers/models/omdet_turbo/modeling_omdet_turbo.py +47 -53
  927. transformers/models/omdet_turbo/processing_omdet_turbo.py +67 -19
  928. transformers/models/oneformer/configuration_oneformer.py +8 -5
  929. transformers/models/oneformer/image_processing_oneformer.py +84 -83
  930. transformers/models/oneformer/image_processing_oneformer_fast.py +42 -41
  931. transformers/models/oneformer/modeling_oneformer.py +171 -147
  932. transformers/models/oneformer/processing_oneformer.py +43 -28
  933. transformers/models/openai/configuration_openai.py +1 -16
  934. transformers/models/openai/modeling_openai.py +51 -65
  935. transformers/models/openai/tokenization_openai.py +47 -8
  936. transformers/models/opt/configuration_opt.py +7 -6
  937. transformers/models/opt/modeling_opt.py +76 -78
  938. transformers/models/ovis2/__init__.py +1 -0
  939. transformers/models/ovis2/configuration_ovis2.py +1 -0
  940. transformers/models/ovis2/image_processing_ovis2.py +24 -22
  941. transformers/models/ovis2/image_processing_ovis2_fast.py +11 -9
  942. transformers/models/ovis2/modeling_ovis2.py +142 -111
  943. transformers/models/ovis2/modular_ovis2.py +45 -90
  944. transformers/models/ovis2/processing_ovis2.py +40 -12
  945. transformers/models/owlv2/configuration_owlv2.py +2 -4
  946. transformers/models/owlv2/image_processing_owlv2.py +21 -20
  947. transformers/models/owlv2/image_processing_owlv2_fast.py +15 -12
  948. transformers/models/owlv2/modeling_owlv2.py +117 -133
  949. transformers/models/owlv2/modular_owlv2.py +14 -11
  950. transformers/models/owlv2/processing_owlv2.py +49 -20
  951. transformers/models/owlvit/configuration_owlvit.py +2 -4
  952. transformers/models/owlvit/image_processing_owlvit.py +22 -21
  953. transformers/models/owlvit/image_processing_owlvit_fast.py +3 -2
  954. transformers/models/owlvit/modeling_owlvit.py +116 -132
  955. transformers/models/owlvit/processing_owlvit.py +48 -20
  956. transformers/models/paligemma/configuration_paligemma.py +1 -4
  957. transformers/models/paligemma/modeling_paligemma.py +93 -103
  958. transformers/models/paligemma/processing_paligemma.py +66 -13
  959. transformers/models/parakeet/configuration_parakeet.py +14 -7
  960. transformers/models/parakeet/feature_extraction_parakeet.py +12 -10
  961. transformers/models/parakeet/modeling_parakeet.py +28 -32
  962. transformers/models/parakeet/modular_parakeet.py +20 -23
  963. transformers/models/parakeet/processing_parakeet.py +5 -13
  964. transformers/models/parakeet/{tokenization_parakeet.py → tokenization_parakeet_fast.py} +7 -5
  965. transformers/models/patchtsmixer/configuration_patchtsmixer.py +8 -5
  966. transformers/models/patchtsmixer/modeling_patchtsmixer.py +62 -70
  967. transformers/models/patchtst/configuration_patchtst.py +9 -6
  968. transformers/models/patchtst/modeling_patchtst.py +80 -97
  969. transformers/models/pegasus/configuration_pegasus.py +5 -8
  970. transformers/models/pegasus/modeling_pegasus.py +66 -72
  971. transformers/models/pegasus/tokenization_pegasus.py +45 -15
  972. transformers/models/pegasus_x/configuration_pegasus_x.py +4 -5
  973. transformers/models/pegasus_x/modeling_pegasus_x.py +52 -55
  974. transformers/models/perceiver/configuration_perceiver.py +1 -0
  975. transformers/models/perceiver/image_processing_perceiver.py +25 -22
  976. transformers/models/perceiver/image_processing_perceiver_fast.py +9 -7
  977. transformers/models/perceiver/modeling_perceiver.py +146 -165
  978. transformers/models/perceiver/tokenization_perceiver.py +6 -3
  979. transformers/models/perception_lm/configuration_perception_lm.py +1 -0
  980. transformers/models/perception_lm/image_processing_perception_lm_fast.py +10 -8
  981. transformers/models/perception_lm/modeling_perception_lm.py +70 -71
  982. transformers/models/perception_lm/modular_perception_lm.py +61 -65
  983. transformers/models/perception_lm/processing_perception_lm.py +47 -13
  984. transformers/models/perception_lm/video_processing_perception_lm.py +1 -0
  985. transformers/models/persimmon/configuration_persimmon.py +28 -23
  986. transformers/models/persimmon/modeling_persimmon.py +45 -43
  987. transformers/models/phi/configuration_phi.py +28 -23
  988. transformers/models/phi/modeling_phi.py +43 -40
  989. transformers/models/phi/modular_phi.py +24 -23
  990. transformers/models/phi3/configuration_phi3.py +33 -28
  991. transformers/models/phi3/modeling_phi3.py +38 -36
  992. transformers/models/phi3/modular_phi3.py +17 -13
  993. transformers/models/phi4_multimodal/configuration_phi4_multimodal.py +33 -30
  994. transformers/models/phi4_multimodal/feature_extraction_phi4_multimodal.py +9 -7
  995. transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py +11 -11
  996. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +78 -95
  997. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +80 -98
  998. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +44 -7
  999. transformers/models/phimoe/configuration_phimoe.py +36 -31
  1000. transformers/models/phimoe/modeling_phimoe.py +45 -50
  1001. transformers/models/phimoe/modular_phimoe.py +4 -3
  1002. transformers/models/phobert/tokenization_phobert.py +6 -4
  1003. transformers/models/pix2struct/configuration_pix2struct.py +10 -12
  1004. transformers/models/pix2struct/image_processing_pix2struct.py +19 -15
  1005. transformers/models/pix2struct/image_processing_pix2struct_fast.py +15 -12
  1006. transformers/models/pix2struct/modeling_pix2struct.py +52 -58
  1007. transformers/models/pix2struct/processing_pix2struct.py +30 -5
  1008. transformers/models/pixtral/configuration_pixtral.py +14 -11
  1009. transformers/models/pixtral/image_processing_pixtral.py +28 -26
  1010. transformers/models/pixtral/image_processing_pixtral_fast.py +11 -10
  1011. transformers/models/pixtral/modeling_pixtral.py +34 -28
  1012. transformers/models/pixtral/processing_pixtral.py +53 -21
  1013. transformers/models/plbart/configuration_plbart.py +5 -8
  1014. transformers/models/plbart/modeling_plbart.py +106 -119
  1015. transformers/models/plbart/modular_plbart.py +33 -39
  1016. transformers/models/plbart/tokenization_plbart.py +7 -4
  1017. transformers/models/poolformer/configuration_poolformer.py +1 -0
  1018. transformers/models/poolformer/image_processing_poolformer.py +24 -21
  1019. transformers/models/poolformer/image_processing_poolformer_fast.py +15 -13
  1020. transformers/models/poolformer/modeling_poolformer.py +13 -23
  1021. transformers/models/pop2piano/configuration_pop2piano.py +8 -7
  1022. transformers/models/pop2piano/feature_extraction_pop2piano.py +9 -6
  1023. transformers/models/pop2piano/modeling_pop2piano.py +24 -26
  1024. transformers/models/pop2piano/processing_pop2piano.py +33 -25
  1025. transformers/models/pop2piano/tokenization_pop2piano.py +23 -15
  1026. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +3 -3
  1027. transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py +28 -28
  1028. transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything_fast.py +21 -20
  1029. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +13 -16
  1030. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +13 -16
  1031. transformers/models/prophetnet/configuration_prophetnet.py +38 -37
  1032. transformers/models/prophetnet/modeling_prophetnet.py +131 -114
  1033. transformers/models/prophetnet/tokenization_prophetnet.py +16 -14
  1034. transformers/models/pvt/configuration_pvt.py +1 -0
  1035. transformers/models/pvt/image_processing_pvt.py +27 -24
  1036. transformers/models/pvt/image_processing_pvt_fast.py +2 -1
  1037. transformers/models/pvt/modeling_pvt.py +21 -21
  1038. transformers/models/pvt_v2/configuration_pvt_v2.py +4 -2
  1039. transformers/models/pvt_v2/modeling_pvt_v2.py +25 -28
  1040. transformers/models/qwen2/configuration_qwen2.py +25 -32
  1041. transformers/models/qwen2/modeling_qwen2.py +38 -36
  1042. transformers/models/qwen2/modular_qwen2.py +12 -11
  1043. transformers/models/qwen2/tokenization_qwen2.py +23 -12
  1044. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +26 -32
  1045. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +277 -340
  1046. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +211 -278
  1047. transformers/models/qwen2_5_omni/processing_qwen2_5_omni.py +49 -41
  1048. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +35 -29
  1049. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +148 -203
  1050. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +118 -93
  1051. transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py +43 -7
  1052. transformers/models/qwen2_audio/configuration_qwen2_audio.py +1 -0
  1053. transformers/models/qwen2_audio/modeling_qwen2_audio.py +40 -40
  1054. transformers/models/qwen2_audio/processing_qwen2_audio.py +42 -13
  1055. transformers/models/qwen2_moe/configuration_qwen2_moe.py +35 -42
  1056. transformers/models/qwen2_moe/modeling_qwen2_moe.py +46 -51
  1057. transformers/models/qwen2_moe/modular_qwen2_moe.py +10 -7
  1058. transformers/models/qwen2_vl/configuration_qwen2_vl.py +34 -29
  1059. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +42 -41
  1060. transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py +15 -12
  1061. transformers/models/qwen2_vl/modeling_qwen2_vl.py +153 -199
  1062. transformers/models/qwen2_vl/processing_qwen2_vl.py +44 -7
  1063. transformers/models/qwen2_vl/video_processing_qwen2_vl.py +18 -38
  1064. transformers/models/qwen3/configuration_qwen3.py +27 -34
  1065. transformers/models/qwen3/modeling_qwen3.py +39 -36
  1066. transformers/models/qwen3/modular_qwen3.py +6 -4
  1067. transformers/models/qwen3_moe/configuration_qwen3_moe.py +32 -39
  1068. transformers/models/qwen3_moe/modeling_qwen3_moe.py +46 -51
  1069. transformers/models/qwen3_moe/modular_qwen3_moe.py +13 -10
  1070. transformers/models/qwen3_next/configuration_qwen3_next.py +35 -45
  1071. transformers/models/qwen3_next/modeling_qwen3_next.py +51 -47
  1072. transformers/models/qwen3_next/modular_qwen3_next.py +35 -34
  1073. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +101 -135
  1074. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +252 -355
  1075. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +196 -250
  1076. transformers/models/qwen3_omni_moe/processing_qwen3_omni_moe.py +48 -40
  1077. transformers/models/qwen3_vl/configuration_qwen3_vl.py +29 -27
  1078. transformers/models/qwen3_vl/modeling_qwen3_vl.py +155 -233
  1079. transformers/models/qwen3_vl/modular_qwen3_vl.py +179 -206
  1080. transformers/models/qwen3_vl/processing_qwen3_vl.py +42 -6
  1081. transformers/models/qwen3_vl/video_processing_qwen3_vl.py +12 -10
  1082. transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py +30 -23
  1083. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +303 -358
  1084. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +124 -87
  1085. transformers/models/rag/configuration_rag.py +15 -6
  1086. transformers/models/rag/modeling_rag.py +130 -127
  1087. transformers/models/rag/retrieval_rag.py +5 -3
  1088. transformers/models/rag/tokenization_rag.py +50 -0
  1089. transformers/models/recurrent_gemma/configuration_recurrent_gemma.py +30 -29
  1090. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +42 -53
  1091. transformers/models/reformer/configuration_reformer.py +8 -7
  1092. transformers/models/reformer/modeling_reformer.py +69 -80
  1093. transformers/models/reformer/tokenization_reformer.py +31 -11
  1094. transformers/models/regnet/configuration_regnet.py +1 -0
  1095. transformers/models/regnet/modeling_regnet.py +8 -15
  1096. transformers/models/rembert/configuration_rembert.py +2 -8
  1097. transformers/models/rembert/modeling_rembert.py +111 -121
  1098. transformers/models/rembert/tokenization_rembert.py +12 -2
  1099. transformers/models/resnet/configuration_resnet.py +1 -0
  1100. transformers/models/resnet/modeling_resnet.py +13 -27
  1101. transformers/models/roberta/configuration_roberta.py +3 -11
  1102. transformers/models/roberta/modeling_roberta.py +93 -94
  1103. transformers/models/roberta/modular_roberta.py +58 -58
  1104. transformers/models/roberta/tokenization_roberta.py +29 -17
  1105. transformers/models/roberta/tokenization_roberta_old.py +4 -2
  1106. transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py +3 -11
  1107. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +93 -94
  1108. transformers/models/roc_bert/configuration_roc_bert.py +2 -8
  1109. transformers/models/roc_bert/modeling_roc_bert.py +121 -122
  1110. transformers/models/roc_bert/tokenization_roc_bert.py +94 -88
  1111. transformers/models/roformer/configuration_roformer.py +3 -13
  1112. transformers/models/roformer/modeling_roformer.py +81 -85
  1113. transformers/models/roformer/tokenization_roformer.py +412 -74
  1114. transformers/models/roformer/tokenization_roformer_fast.py +160 -0
  1115. transformers/models/roformer/tokenization_utils.py +1 -0
  1116. transformers/models/rt_detr/configuration_rt_detr.py +2 -1
  1117. transformers/models/rt_detr/configuration_rt_detr_resnet.py +1 -0
  1118. transformers/models/rt_detr/image_processing_rt_detr.py +55 -54
  1119. transformers/models/rt_detr/image_processing_rt_detr_fast.py +26 -26
  1120. transformers/models/rt_detr/modeling_rt_detr.py +90 -99
  1121. transformers/models/rt_detr/modeling_rt_detr_resnet.py +6 -13
  1122. transformers/models/rt_detr/modular_rt_detr.py +16 -16
  1123. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +4 -6
  1124. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +90 -101
  1125. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +12 -19
  1126. transformers/models/rwkv/configuration_rwkv.py +4 -2
  1127. transformers/models/rwkv/modeling_rwkv.py +32 -31
  1128. transformers/models/sam/configuration_sam.py +1 -3
  1129. transformers/models/sam/image_processing_sam.py +60 -59
  1130. transformers/models/sam/image_processing_sam_fast.py +27 -25
  1131. transformers/models/sam/modeling_sam.py +41 -47
  1132. transformers/models/sam/processing_sam.py +27 -39
  1133. transformers/models/sam2/configuration_sam2.py +3 -2
  1134. transformers/models/sam2/image_processing_sam2_fast.py +15 -14
  1135. transformers/models/sam2/modeling_sam2.py +90 -96
  1136. transformers/models/sam2/modular_sam2.py +91 -86
  1137. transformers/models/sam2/processing_sam2.py +47 -31
  1138. transformers/models/sam2_video/configuration_sam2_video.py +1 -0
  1139. transformers/models/sam2_video/modeling_sam2_video.py +144 -151
  1140. transformers/models/sam2_video/modular_sam2_video.py +104 -101
  1141. transformers/models/sam2_video/processing_sam2_video.py +66 -49
  1142. transformers/models/sam2_video/video_processing_sam2_video.py +4 -1
  1143. transformers/models/sam3/configuration_sam3.py +2 -21
  1144. transformers/models/sam3/image_processing_sam3_fast.py +20 -17
  1145. transformers/models/sam3/modeling_sam3.py +170 -184
  1146. transformers/models/sam3/modular_sam3.py +8 -3
  1147. transformers/models/sam3/processing_sam3.py +52 -37
  1148. transformers/models/sam3_tracker/__init__.py +1 -0
  1149. transformers/models/sam3_tracker/configuration_sam3_tracker.py +3 -1
  1150. transformers/models/sam3_tracker/modeling_sam3_tracker.py +77 -82
  1151. transformers/models/sam3_tracker/modular_sam3_tracker.py +3 -8
  1152. transformers/models/sam3_tracker/processing_sam3_tracker.py +48 -31
  1153. transformers/models/sam3_tracker_video/__init__.py +1 -0
  1154. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +1 -25
  1155. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +122 -135
  1156. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +26 -35
  1157. transformers/models/sam3_tracker_video/processing_sam3_tracker_video.py +66 -50
  1158. transformers/models/sam3_video/configuration_sam3_video.py +1 -14
  1159. transformers/models/sam3_video/modeling_sam3_video.py +34 -33
  1160. transformers/models/sam3_video/processing_sam3_video.py +46 -26
  1161. transformers/models/sam_hq/__init__.py +1 -1
  1162. transformers/models/sam_hq/configuration_sam_hq.py +1 -3
  1163. transformers/models/sam_hq/modeling_sam_hq.py +69 -74
  1164. transformers/models/sam_hq/modular_sam_hq.py +25 -23
  1165. transformers/models/sam_hq/{processing_sam_hq.py → processing_samhq.py} +29 -41
  1166. transformers/models/seamless_m4t/configuration_seamless_m4t.py +10 -8
  1167. transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py +11 -8
  1168. transformers/models/seamless_m4t/modeling_seamless_m4t.py +194 -212
  1169. transformers/models/seamless_m4t/processing_seamless_m4t.py +39 -18
  1170. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +77 -40
  1171. transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py +10 -8
  1172. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +196 -204
  1173. transformers/models/seed_oss/configuration_seed_oss.py +32 -28
  1174. transformers/models/seed_oss/modeling_seed_oss.py +35 -33
  1175. transformers/models/seed_oss/modular_seed_oss.py +4 -3
  1176. transformers/models/segformer/configuration_segformer.py +10 -0
  1177. transformers/models/segformer/image_processing_segformer.py +42 -39
  1178. transformers/models/segformer/image_processing_segformer_fast.py +12 -10
  1179. transformers/models/segformer/modeling_segformer.py +31 -34
  1180. transformers/models/segformer/modular_segformer.py +10 -8
  1181. transformers/models/seggpt/configuration_seggpt.py +1 -0
  1182. transformers/models/seggpt/image_processing_seggpt.py +41 -38
  1183. transformers/models/seggpt/modeling_seggpt.py +38 -50
  1184. transformers/models/sew/configuration_sew.py +2 -4
  1185. transformers/models/sew/modeling_sew.py +36 -38
  1186. transformers/models/sew/modular_sew.py +13 -13
  1187. transformers/models/sew_d/configuration_sew_d.py +2 -4
  1188. transformers/models/sew_d/modeling_sew_d.py +30 -31
  1189. transformers/models/shieldgemma2/configuration_shieldgemma2.py +1 -0
  1190. transformers/models/shieldgemma2/modeling_shieldgemma2.py +17 -16
  1191. transformers/models/shieldgemma2/processing_shieldgemma2.py +5 -3
  1192. transformers/models/siglip/configuration_siglip.py +2 -4
  1193. transformers/models/siglip/image_processing_siglip.py +20 -17
  1194. transformers/models/siglip/image_processing_siglip_fast.py +1 -0
  1195. transformers/models/siglip/modeling_siglip.py +75 -84
  1196. transformers/models/siglip/processing_siglip.py +14 -2
  1197. transformers/models/siglip/tokenization_siglip.py +7 -6
  1198. transformers/models/siglip2/configuration_siglip2.py +2 -5
  1199. transformers/models/siglip2/image_processing_siglip2.py +16 -15
  1200. transformers/models/siglip2/image_processing_siglip2_fast.py +7 -6
  1201. transformers/models/siglip2/modeling_siglip2.py +129 -143
  1202. transformers/models/siglip2/modular_siglip2.py +46 -47
  1203. transformers/models/siglip2/processing_siglip2.py +14 -2
  1204. transformers/models/smollm3/configuration_smollm3.py +32 -29
  1205. transformers/models/smollm3/modeling_smollm3.py +39 -36
  1206. transformers/models/smollm3/modular_smollm3.py +35 -33
  1207. transformers/models/smolvlm/configuration_smolvlm.py +4 -2
  1208. transformers/models/smolvlm/image_processing_smolvlm.py +43 -42
  1209. transformers/models/smolvlm/image_processing_smolvlm_fast.py +15 -41
  1210. transformers/models/smolvlm/modeling_smolvlm.py +94 -126
  1211. transformers/models/smolvlm/modular_smolvlm.py +39 -50
  1212. transformers/models/smolvlm/processing_smolvlm.py +83 -15
  1213. transformers/models/smolvlm/video_processing_smolvlm.py +18 -16
  1214. transformers/models/speech_encoder_decoder/configuration_speech_encoder_decoder.py +1 -0
  1215. transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +27 -26
  1216. transformers/models/speech_to_text/configuration_speech_to_text.py +9 -9
  1217. transformers/models/speech_to_text/feature_extraction_speech_to_text.py +13 -10
  1218. transformers/models/speech_to_text/modeling_speech_to_text.py +54 -66
  1219. transformers/models/speech_to_text/processing_speech_to_text.py +30 -4
  1220. transformers/models/speech_to_text/tokenization_speech_to_text.py +6 -5
  1221. transformers/models/speecht5/configuration_speecht5.py +9 -7
  1222. transformers/models/speecht5/feature_extraction_speecht5.py +37 -16
  1223. transformers/models/speecht5/modeling_speecht5.py +175 -213
  1224. transformers/models/speecht5/number_normalizer.py +1 -0
  1225. transformers/models/speecht5/processing_speecht5.py +37 -3
  1226. transformers/models/speecht5/tokenization_speecht5.py +5 -4
  1227. transformers/models/splinter/configuration_splinter.py +7 -6
  1228. transformers/models/splinter/modeling_splinter.py +59 -71
  1229. transformers/models/splinter/tokenization_splinter.py +30 -9
  1230. transformers/models/squeezebert/configuration_squeezebert.py +2 -14
  1231. transformers/models/squeezebert/modeling_squeezebert.py +62 -68
  1232. transformers/models/squeezebert/tokenization_squeezebert.py +1 -0
  1233. transformers/models/stablelm/configuration_stablelm.py +29 -24
  1234. transformers/models/stablelm/modeling_stablelm.py +45 -44
  1235. transformers/models/starcoder2/configuration_starcoder2.py +27 -30
  1236. transformers/models/starcoder2/modeling_starcoder2.py +41 -39
  1237. transformers/models/starcoder2/modular_starcoder2.py +16 -14
  1238. transformers/models/superglue/configuration_superglue.py +3 -7
  1239. transformers/models/superglue/image_processing_superglue.py +15 -15
  1240. transformers/models/superglue/image_processing_superglue_fast.py +10 -9
  1241. transformers/models/superglue/modeling_superglue.py +37 -42
  1242. transformers/models/superpoint/image_processing_superpoint.py +15 -15
  1243. transformers/models/superpoint/image_processing_superpoint_fast.py +11 -8
  1244. transformers/models/superpoint/modeling_superpoint.py +16 -18
  1245. transformers/models/swiftformer/configuration_swiftformer.py +1 -0
  1246. transformers/models/swiftformer/modeling_swiftformer.py +14 -18
  1247. transformers/models/swin/configuration_swin.py +1 -0
  1248. transformers/models/swin/modeling_swin.py +86 -86
  1249. transformers/models/swin2sr/configuration_swin2sr.py +1 -0
  1250. transformers/models/swin2sr/image_processing_swin2sr.py +13 -10
  1251. transformers/models/swin2sr/image_processing_swin2sr_fast.py +8 -4
  1252. transformers/models/swin2sr/modeling_swin2sr.py +63 -81
  1253. transformers/models/swinv2/configuration_swinv2.py +1 -0
  1254. transformers/models/swinv2/modeling_swinv2.py +104 -108
  1255. transformers/models/switch_transformers/configuration_switch_transformers.py +7 -11
  1256. transformers/models/switch_transformers/modeling_switch_transformers.py +44 -37
  1257. transformers/models/switch_transformers/modular_switch_transformers.py +41 -34
  1258. transformers/models/t5/configuration_t5.py +8 -14
  1259. transformers/models/t5/modeling_t5.py +92 -88
  1260. transformers/models/t5/tokenization_t5.py +9 -3
  1261. transformers/models/t5gemma/configuration_t5gemma.py +41 -43
  1262. transformers/models/t5gemma/modeling_t5gemma.py +107 -104
  1263. transformers/models/t5gemma/modular_t5gemma.py +120 -124
  1264. transformers/models/t5gemma2/configuration_t5gemma2.py +120 -80
  1265. transformers/models/t5gemma2/modeling_t5gemma2.py +125 -141
  1266. transformers/models/t5gemma2/modular_t5gemma2.py +104 -393
  1267. transformers/models/table_transformer/configuration_table_transformer.py +2 -1
  1268. transformers/models/table_transformer/modeling_table_transformer.py +49 -51
  1269. transformers/models/tapas/configuration_tapas.py +2 -12
  1270. transformers/models/tapas/modeling_tapas.py +67 -68
  1271. transformers/models/tapas/tokenization_tapas.py +153 -115
  1272. transformers/models/textnet/configuration_textnet.py +1 -0
  1273. transformers/models/textnet/image_processing_textnet.py +25 -22
  1274. transformers/models/textnet/image_processing_textnet_fast.py +10 -8
  1275. transformers/models/textnet/modeling_textnet.py +16 -28
  1276. transformers/models/time_series_transformer/configuration_time_series_transformer.py +8 -5
  1277. transformers/models/time_series_transformer/modeling_time_series_transformer.py +81 -83
  1278. transformers/models/timesfm/configuration_timesfm.py +1 -0
  1279. transformers/models/timesfm/modeling_timesfm.py +22 -33
  1280. transformers/models/timesfm/modular_timesfm.py +21 -32
  1281. transformers/models/timesformer/configuration_timesformer.py +1 -0
  1282. transformers/models/timesformer/modeling_timesformer.py +16 -15
  1283. transformers/models/timm_backbone/configuration_timm_backbone.py +1 -0
  1284. transformers/models/timm_backbone/modeling_timm_backbone.py +15 -17
  1285. transformers/models/timm_wrapper/configuration_timm_wrapper.py +3 -5
  1286. transformers/models/timm_wrapper/image_processing_timm_wrapper.py +5 -4
  1287. transformers/models/timm_wrapper/modeling_timm_wrapper.py +29 -34
  1288. transformers/models/trocr/configuration_trocr.py +8 -11
  1289. transformers/models/trocr/modeling_trocr.py +44 -45
  1290. transformers/models/trocr/processing_trocr.py +25 -5
  1291. transformers/models/tvp/configuration_tvp.py +2 -5
  1292. transformers/models/tvp/image_processing_tvp.py +52 -50
  1293. transformers/models/tvp/image_processing_tvp_fast.py +15 -15
  1294. transformers/models/tvp/modeling_tvp.py +27 -27
  1295. transformers/models/tvp/processing_tvp.py +14 -2
  1296. transformers/models/udop/configuration_udop.py +7 -16
  1297. transformers/models/udop/modeling_udop.py +73 -71
  1298. transformers/models/udop/processing_udop.py +26 -7
  1299. transformers/models/udop/tokenization_udop.py +105 -84
  1300. transformers/models/umt5/configuration_umt5.py +7 -8
  1301. transformers/models/umt5/modeling_umt5.py +90 -94
  1302. transformers/models/unispeech/configuration_unispeech.py +2 -4
  1303. transformers/models/unispeech/modeling_unispeech.py +49 -51
  1304. transformers/models/unispeech/modular_unispeech.py +22 -22
  1305. transformers/models/unispeech_sat/configuration_unispeech_sat.py +2 -4
  1306. transformers/models/unispeech_sat/modeling_unispeech_sat.py +65 -69
  1307. transformers/models/unispeech_sat/modular_unispeech_sat.py +23 -23
  1308. transformers/models/univnet/feature_extraction_univnet.py +14 -14
  1309. transformers/models/univnet/modeling_univnet.py +8 -8
  1310. transformers/models/upernet/configuration_upernet.py +1 -0
  1311. transformers/models/upernet/modeling_upernet.py +13 -11
  1312. transformers/models/vaultgemma/__init__.py +1 -0
  1313. transformers/models/vaultgemma/configuration_vaultgemma.py +33 -29
  1314. transformers/models/vaultgemma/modeling_vaultgemma.py +41 -39
  1315. transformers/models/vaultgemma/modular_vaultgemma.py +31 -29
  1316. transformers/models/video_llama_3/configuration_video_llama_3.py +0 -4
  1317. transformers/models/video_llama_3/image_processing_video_llama_3.py +42 -43
  1318. transformers/models/video_llama_3/image_processing_video_llama_3_fast.py +14 -12
  1319. transformers/models/video_llama_3/modeling_video_llama_3.py +109 -157
  1320. transformers/models/video_llama_3/modular_video_llama_3.py +146 -155
  1321. transformers/models/video_llama_3/processing_video_llama_3.py +39 -5
  1322. transformers/models/video_llama_3/video_processing_video_llama_3.py +23 -42
  1323. transformers/models/video_llava/configuration_video_llava.py +1 -4
  1324. transformers/models/video_llava/image_processing_video_llava.py +38 -35
  1325. transformers/models/video_llava/modeling_video_llava.py +146 -146
  1326. transformers/models/video_llava/processing_video_llava.py +78 -38
  1327. transformers/models/video_llava/video_processing_video_llava.py +1 -0
  1328. transformers/models/videomae/configuration_videomae.py +1 -0
  1329. transformers/models/videomae/image_processing_videomae.py +34 -31
  1330. transformers/models/videomae/modeling_videomae.py +17 -14
  1331. transformers/models/videomae/video_processing_videomae.py +1 -0
  1332. transformers/models/vilt/configuration_vilt.py +4 -6
  1333. transformers/models/vilt/image_processing_vilt.py +30 -29
  1334. transformers/models/vilt/image_processing_vilt_fast.py +16 -15
  1335. transformers/models/vilt/modeling_vilt.py +90 -116
  1336. transformers/models/vilt/processing_vilt.py +14 -2
  1337. transformers/models/vipllava/configuration_vipllava.py +1 -4
  1338. transformers/models/vipllava/modeling_vipllava.py +70 -99
  1339. transformers/models/vipllava/modular_vipllava.py +54 -78
  1340. transformers/models/vision_encoder_decoder/configuration_vision_encoder_decoder.py +1 -0
  1341. transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +27 -28
  1342. transformers/models/vision_text_dual_encoder/configuration_vision_text_dual_encoder.py +1 -0
  1343. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +41 -46
  1344. transformers/models/vision_text_dual_encoder/processing_vision_text_dual_encoder.py +16 -2
  1345. transformers/models/visual_bert/configuration_visual_bert.py +2 -6
  1346. transformers/models/visual_bert/modeling_visual_bert.py +92 -98
  1347. transformers/models/vit/configuration_vit.py +1 -0
  1348. transformers/models/vit/image_processing_vit.py +22 -19
  1349. transformers/models/vit/image_processing_vit_fast.py +1 -0
  1350. transformers/models/vit/modeling_vit.py +17 -17
  1351. transformers/models/vit_mae/configuration_vit_mae.py +1 -0
  1352. transformers/models/vit_mae/modeling_vit_mae.py +27 -29
  1353. transformers/models/vit_msn/configuration_vit_msn.py +1 -0
  1354. transformers/models/vit_msn/modeling_vit_msn.py +16 -18
  1355. transformers/models/vitdet/configuration_vitdet.py +1 -0
  1356. transformers/models/vitdet/modeling_vitdet.py +14 -14
  1357. transformers/models/vitmatte/configuration_vitmatte.py +5 -2
  1358. transformers/models/vitmatte/image_processing_vitmatte.py +18 -15
  1359. transformers/models/vitmatte/image_processing_vitmatte_fast.py +18 -16
  1360. transformers/models/vitmatte/modeling_vitmatte.py +11 -14
  1361. transformers/models/vitpose/configuration_vitpose.py +7 -4
  1362. transformers/models/vitpose/image_processing_vitpose.py +25 -24
  1363. transformers/models/vitpose/image_processing_vitpose_fast.py +11 -9
  1364. transformers/models/vitpose/modeling_vitpose.py +14 -14
  1365. transformers/models/vitpose_backbone/configuration_vitpose_backbone.py +1 -0
  1366. transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +10 -8
  1367. transformers/models/vits/configuration_vits.py +1 -4
  1368. transformers/models/vits/modeling_vits.py +42 -44
  1369. transformers/models/vits/tokenization_vits.py +4 -3
  1370. transformers/models/vivit/configuration_vivit.py +1 -0
  1371. transformers/models/vivit/image_processing_vivit.py +39 -36
  1372. transformers/models/vivit/modeling_vivit.py +8 -6
  1373. transformers/models/vjepa2/__init__.py +1 -0
  1374. transformers/models/vjepa2/configuration_vjepa2.py +1 -0
  1375. transformers/models/vjepa2/modeling_vjepa2.py +32 -31
  1376. transformers/models/vjepa2/video_processing_vjepa2.py +1 -0
  1377. transformers/models/voxtral/__init__.py +1 -0
  1378. transformers/models/voxtral/configuration_voxtral.py +2 -0
  1379. transformers/models/voxtral/modeling_voxtral.py +47 -40
  1380. transformers/models/voxtral/modular_voxtral.py +40 -37
  1381. transformers/models/voxtral/processing_voxtral.py +48 -25
  1382. transformers/models/wav2vec2/configuration_wav2vec2.py +2 -4
  1383. transformers/models/wav2vec2/feature_extraction_wav2vec2.py +10 -7
  1384. transformers/models/wav2vec2/modeling_wav2vec2.py +121 -73
  1385. transformers/models/wav2vec2/processing_wav2vec2.py +35 -6
  1386. transformers/models/wav2vec2/tokenization_wav2vec2.py +332 -20
  1387. transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py +2 -4
  1388. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +62 -70
  1389. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +48 -57
  1390. transformers/models/wav2vec2_bert/processing_wav2vec2_bert.py +35 -6
  1391. transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py +2 -4
  1392. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +77 -90
  1393. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +30 -37
  1394. transformers/models/wav2vec2_phoneme/tokenization_wav2vec2_phoneme.py +17 -16
  1395. transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py +55 -36
  1396. transformers/models/wavlm/configuration_wavlm.py +2 -4
  1397. transformers/models/wavlm/modeling_wavlm.py +48 -50
  1398. transformers/models/wavlm/modular_wavlm.py +5 -4
  1399. transformers/models/whisper/configuration_whisper.py +5 -6
  1400. transformers/models/whisper/english_normalizer.py +4 -3
  1401. transformers/models/whisper/feature_extraction_whisper.py +24 -9
  1402. transformers/models/whisper/generation_whisper.py +48 -26
  1403. transformers/models/whisper/modeling_whisper.py +73 -79
  1404. transformers/models/whisper/processing_whisper.py +20 -3
  1405. transformers/models/whisper/tokenization_whisper.py +43 -11
  1406. transformers/models/x_clip/configuration_x_clip.py +2 -4
  1407. transformers/models/x_clip/modeling_x_clip.py +93 -96
  1408. transformers/models/x_clip/processing_x_clip.py +14 -2
  1409. transformers/models/xcodec/configuration_xcodec.py +6 -4
  1410. transformers/models/xcodec/modeling_xcodec.py +17 -20
  1411. transformers/models/xglm/configuration_xglm.py +8 -9
  1412. transformers/models/xglm/modeling_xglm.py +55 -60
  1413. transformers/models/xglm/tokenization_xglm.py +11 -3
  1414. transformers/models/xlm/configuration_xlm.py +8 -10
  1415. transformers/models/xlm/modeling_xlm.py +144 -144
  1416. transformers/models/xlm/tokenization_xlm.py +5 -3
  1417. transformers/models/xlm_roberta/configuration_xlm_roberta.py +3 -11
  1418. transformers/models/xlm_roberta/modeling_xlm_roberta.py +194 -195
  1419. transformers/models/xlm_roberta/modular_xlm_roberta.py +53 -50
  1420. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +18 -8
  1421. transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py +2 -10
  1422. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +93 -94
  1423. transformers/models/xlm_roberta_xl/modular_xlm_roberta_xl.py +70 -67
  1424. transformers/models/xlnet/configuration_xlnet.py +12 -3
  1425. transformers/models/xlnet/modeling_xlnet.py +163 -152
  1426. transformers/models/xlnet/tokenization_xlnet.py +9 -2
  1427. transformers/models/xlstm/configuration_xlstm.py +12 -8
  1428. transformers/models/xlstm/modeling_xlstm.py +65 -62
  1429. transformers/models/xmod/configuration_xmod.py +3 -11
  1430. transformers/models/xmod/modeling_xmod.py +110 -108
  1431. transformers/models/yolos/configuration_yolos.py +1 -0
  1432. transformers/models/yolos/image_processing_yolos.py +62 -60
  1433. transformers/models/yolos/image_processing_yolos_fast.py +45 -42
  1434. transformers/models/yolos/modeling_yolos.py +16 -16
  1435. transformers/models/yolos/modular_yolos.py +19 -17
  1436. transformers/models/yoso/configuration_yoso.py +2 -8
  1437. transformers/models/yoso/modeling_yoso.py +63 -70
  1438. transformers/models/zamba/configuration_zamba.py +8 -5
  1439. transformers/models/zamba/modeling_zamba.py +78 -81
  1440. transformers/models/zamba2/configuration_zamba2.py +50 -44
  1441. transformers/models/zamba2/modeling_zamba2.py +97 -97
  1442. transformers/models/zamba2/modular_zamba2.py +48 -46
  1443. transformers/models/zoedepth/configuration_zoedepth.py +2 -1
  1444. transformers/models/zoedepth/image_processing_zoedepth.py +29 -28
  1445. transformers/models/zoedepth/image_processing_zoedepth_fast.py +24 -21
  1446. transformers/models/zoedepth/modeling_zoedepth.py +18 -26
  1447. transformers/pipelines/__init__.py +114 -57
  1448. transformers/pipelines/any_to_any.py +22 -14
  1449. transformers/pipelines/audio_utils.py +2 -1
  1450. transformers/pipelines/automatic_speech_recognition.py +12 -20
  1451. transformers/pipelines/base.py +27 -15
  1452. transformers/{models/pe_audio/processing_pe_audio.py → pipelines/deprecated/__init__.py} +3 -10
  1453. transformers/pipelines/deprecated/text2text_generation.py +408 -0
  1454. transformers/pipelines/document_question_answering.py +2 -4
  1455. transformers/pipelines/image_text_to_text.py +1 -0
  1456. transformers/pipelines/image_to_text.py +229 -0
  1457. transformers/pipelines/question_answering.py +44 -5
  1458. transformers/pipelines/text_classification.py +14 -1
  1459. transformers/pipelines/text_generation.py +1 -1
  1460. transformers/pipelines/text_to_audio.py +2 -2
  1461. transformers/pipelines/token_classification.py +22 -1
  1462. transformers/pipelines/video_classification.py +9 -1
  1463. transformers/pipelines/zero_shot_audio_classification.py +1 -0
  1464. transformers/pipelines/zero_shot_classification.py +6 -0
  1465. transformers/pipelines/zero_shot_image_classification.py +7 -0
  1466. transformers/processing_utils.py +145 -230
  1467. transformers/quantizers/auto.py +4 -2
  1468. transformers/quantizers/base.py +173 -53
  1469. transformers/quantizers/quantizer_aqlm.py +23 -2
  1470. transformers/quantizers/quantizer_auto_round.py +12 -2
  1471. transformers/quantizers/quantizer_awq.py +89 -20
  1472. transformers/quantizers/quantizer_bitnet.py +14 -4
  1473. transformers/quantizers/quantizer_bnb_4bit.py +155 -18
  1474. transformers/quantizers/quantizer_bnb_8bit.py +110 -24
  1475. transformers/quantizers/quantizer_compressed_tensors.py +9 -2
  1476. transformers/quantizers/quantizer_eetq.py +74 -16
  1477. transformers/quantizers/quantizer_fbgemm_fp8.py +138 -38
  1478. transformers/quantizers/quantizer_finegrained_fp8.py +113 -26
  1479. transformers/quantizers/quantizer_fp_quant.py +82 -52
  1480. transformers/quantizers/quantizer_gptq.py +28 -8
  1481. transformers/quantizers/quantizer_higgs.py +60 -42
  1482. transformers/quantizers/quantizer_hqq.py +153 -144
  1483. transformers/quantizers/quantizer_mxfp4.py +194 -14
  1484. transformers/quantizers/quantizer_quanto.py +79 -35
  1485. transformers/quantizers/quantizer_quark.py +18 -36
  1486. transformers/quantizers/quantizer_spqr.py +12 -4
  1487. transformers/quantizers/quantizer_torchao.py +325 -50
  1488. transformers/quantizers/quantizer_vptq.py +27 -4
  1489. transformers/quantizers/quantizers_utils.py +0 -20
  1490. transformers/safetensors_conversion.py +3 -9
  1491. transformers/testing_utils.py +82 -326
  1492. transformers/tokenization_mistral_common.py +903 -568
  1493. transformers/tokenization_utils_base.py +340 -220
  1494. transformers/tokenization_utils_sentencepiece.py +6 -5
  1495. transformers/tokenization_utils_tokenizers.py +113 -226
  1496. transformers/trainer.py +53 -60
  1497. transformers/trainer_callback.py +0 -8
  1498. transformers/trainer_seq2seq.py +1 -5
  1499. transformers/trainer_utils.py +1 -1
  1500. transformers/training_args.py +41 -77
  1501. transformers/utils/__init__.py +4 -8
  1502. transformers/utils/attention_visualizer.py +5 -5
  1503. transformers/utils/auto_docstring.py +37 -599
  1504. transformers/utils/doc.py +36 -4
  1505. transformers/utils/dummy_pt_objects.py +42 -0
  1506. transformers/utils/generic.py +28 -111
  1507. transformers/utils/hub.py +15 -5
  1508. transformers/utils/import_utils.py +32 -165
  1509. transformers/utils/kernel_config.py +19 -74
  1510. transformers/utils/loading_report.py +15 -25
  1511. transformers/utils/quantization_config.py +241 -72
  1512. transformers/video_processing_utils.py +39 -41
  1513. transformers/video_utils.py +22 -18
  1514. {transformers-5.0.0.dist-info → transformers-5.0.0rc0.dist-info}/METADATA +236 -284
  1515. transformers-5.0.0rc0.dist-info/RECORD +1987 -0
  1516. {transformers-5.0.0.dist-info → transformers-5.0.0rc0.dist-info}/WHEEL +1 -1
  1517. transformers/integrations/moe.py +0 -360
  1518. transformers/integrations/quark.py +0 -53
  1519. transformers/loss/loss_lw_detr.py +0 -356
  1520. transformers/models/ernie4_5_vl_moe/__init__.py +0 -31
  1521. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +0 -340
  1522. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +0 -455
  1523. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +0 -231
  1524. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +0 -1936
  1525. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +0 -1925
  1526. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +0 -249
  1527. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +0 -593
  1528. transformers/models/fast_vlm/__init__.py +0 -27
  1529. transformers/models/fast_vlm/configuration_fast_vlm.py +0 -137
  1530. transformers/models/fast_vlm/modeling_fast_vlm.py +0 -432
  1531. transformers/models/fast_vlm/modular_fast_vlm.py +0 -373
  1532. transformers/models/glm4_moe_lite/__init__.py +0 -28
  1533. transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py +0 -233
  1534. transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py +0 -740
  1535. transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py +0 -302
  1536. transformers/models/glm_image/__init__.py +0 -31
  1537. transformers/models/glm_image/configuration_glm_image.py +0 -351
  1538. transformers/models/glm_image/image_processing_glm_image.py +0 -503
  1539. transformers/models/glm_image/image_processing_glm_image_fast.py +0 -294
  1540. transformers/models/glm_image/modeling_glm_image.py +0 -1642
  1541. transformers/models/glm_image/modular_glm_image.py +0 -1531
  1542. transformers/models/glm_image/processing_glm_image.py +0 -217
  1543. transformers/models/glmasr/__init__.py +0 -29
  1544. transformers/models/glmasr/configuration_glmasr.py +0 -196
  1545. transformers/models/glmasr/modeling_glmasr.py +0 -517
  1546. transformers/models/glmasr/modular_glmasr.py +0 -443
  1547. transformers/models/glmasr/processing_glmasr.py +0 -331
  1548. transformers/models/jais2/__init__.py +0 -27
  1549. transformers/models/jais2/configuration_jais2.py +0 -148
  1550. transformers/models/jais2/modeling_jais2.py +0 -484
  1551. transformers/models/jais2/modular_jais2.py +0 -194
  1552. transformers/models/lasr/__init__.py +0 -29
  1553. transformers/models/lasr/configuration_lasr.py +0 -244
  1554. transformers/models/lasr/feature_extraction_lasr.py +0 -275
  1555. transformers/models/lasr/modeling_lasr.py +0 -727
  1556. transformers/models/lasr/modular_lasr.py +0 -574
  1557. transformers/models/lasr/processing_lasr.py +0 -100
  1558. transformers/models/lasr/tokenization_lasr.py +0 -184
  1559. transformers/models/lighton_ocr/__init__.py +0 -28
  1560. transformers/models/lighton_ocr/configuration_lighton_ocr.py +0 -128
  1561. transformers/models/lighton_ocr/modeling_lighton_ocr.py +0 -463
  1562. transformers/models/lighton_ocr/modular_lighton_ocr.py +0 -404
  1563. transformers/models/lighton_ocr/processing_lighton_ocr.py +0 -229
  1564. transformers/models/lw_detr/__init__.py +0 -27
  1565. transformers/models/lw_detr/configuration_lw_detr.py +0 -374
  1566. transformers/models/lw_detr/modeling_lw_detr.py +0 -1702
  1567. transformers/models/lw_detr/modular_lw_detr.py +0 -1615
  1568. transformers/models/minimax_m2/__init__.py +0 -28
  1569. transformers/models/minimax_m2/configuration_minimax_m2.py +0 -188
  1570. transformers/models/minimax_m2/modeling_minimax_m2.py +0 -704
  1571. transformers/models/minimax_m2/modular_minimax_m2.py +0 -346
  1572. transformers/models/paddleocr_vl/__init__.py +0 -31
  1573. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +0 -335
  1574. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +0 -503
  1575. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +0 -209
  1576. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +0 -1683
  1577. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +0 -1380
  1578. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +0 -133
  1579. transformers/models/pe_audio/__init__.py +0 -29
  1580. transformers/models/pe_audio/configuration_pe_audio.py +0 -204
  1581. transformers/models/pe_audio/feature_extraction_pe_audio.py +0 -160
  1582. transformers/models/pe_audio/modeling_pe_audio.py +0 -819
  1583. transformers/models/pe_audio/modular_pe_audio.py +0 -298
  1584. transformers/models/pe_audio_video/__init__.py +0 -28
  1585. transformers/models/pe_audio_video/configuration_pe_audio_video.py +0 -223
  1586. transformers/models/pe_audio_video/modeling_pe_audio_video.py +0 -971
  1587. transformers/models/pe_audio_video/modular_pe_audio_video.py +0 -763
  1588. transformers/models/pe_video/__init__.py +0 -29
  1589. transformers/models/pe_video/configuration_pe_video.py +0 -209
  1590. transformers/models/pe_video/modeling_pe_video.py +0 -647
  1591. transformers/models/pe_video/modular_pe_video.py +0 -231
  1592. transformers/models/pe_video/processing_pe_video.py +0 -10
  1593. transformers/models/pe_video/video_processing_pe_video.py +0 -64
  1594. transformers/models/pixio/__init__.py +0 -29
  1595. transformers/models/pixio/configuration_pixio.py +0 -150
  1596. transformers/models/pixio/modeling_pixio.py +0 -507
  1597. transformers/models/pixio/modular_pixio.py +0 -403
  1598. transformers/models/solar_open/__init__.py +0 -27
  1599. transformers/models/solar_open/configuration_solar_open.py +0 -184
  1600. transformers/models/solar_open/modeling_solar_open.py +0 -642
  1601. transformers/models/solar_open/modular_solar_open.py +0 -224
  1602. transformers/trainer_jit_checkpoint.py +0 -125
  1603. transformers-5.0.0.dist-info/RECORD +0 -2068
  1604. {transformers-5.0.0.dist-info/licenses → transformers-5.0.0rc0.dist-info}/LICENSE +0 -0
  1605. {transformers-5.0.0.dist-info → transformers-5.0.0rc0.dist-info}/entry_points.txt +0 -0
  1606. {transformers-5.0.0.dist-info → transformers-5.0.0rc0.dist-info}/top_level.txt +0 -0
transformers/models/glm_image/modeling_glm_image.py
@@ -1,1642 +0,0 @@
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
- # This file was automatically generated from src/transformers/models/glm_image/modular_glm_image.py.
- # Do NOT edit this file manually as any edits will be overwritten by the generation of
- # the file from the modular. If any change should be done, please apply the change to the
- # modular_glm_image.py file directly. One of our CI enforces this.
- # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
- # Copyright 2025 the HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- #     http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
-
- from collections.abc import Callable
- from dataclasses import dataclass
- from typing import Any, Optional
-
- import torch.nn as nn
- import torch.nn.functional as F
-
- from ...activations import ACT2FN
- from ...cache_utils import Cache, DynamicCache
- from ...generation import GenerationMixin
- from ...integrations import use_kernel_forward_from_hub, use_kernelized_func
- from ...masking_utils import create_causal_mask
- from ...modeling_flash_attention_utils import FlashAttentionKwargs
- from ...modeling_layers import GradientCheckpointingLayer
- from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput
- from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
- from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
- from ...processing_utils import Unpack
- from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_available
- from ...utils.generic import check_model_inputs, maybe_autocast
- from .configuration_glm_image import GlmImageConfig, GlmImageTextConfig, GlmImageVisionConfig, GlmImageVQVAEConfig
-
-
- if is_torch_available():
-     import torch
-
-
- class GlmImageVisionMLP(nn.Module):
-     def __init__(self, config):
-         super().__init__()
-         self.config = config
-         self.activation_fn = ACT2FN[config.hidden_act]
-         self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
-         self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
-
-     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
-         hidden_states = self.fc1(hidden_states)
-         hidden_states = self.activation_fn(hidden_states)
-         hidden_states = self.fc2(hidden_states)
-         return hidden_states
-
-
- def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-     """
-     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
-     num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
-     """
-     batch, num_key_value_heads, slen, head_dim = hidden_states.shape
-     if n_rep == 1:
-         return hidden_states
-     hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
-     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
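The `repeat_kv` helper above documents itself as equivalent to `torch.repeat_interleave(x, dim=1, repeats=n_rep)`. A minimal sketch with assumed toy shapes (not part of the removed file) that checks that equivalence:

```python
# Toy check of the repeat_kv trick: expanding and reshaping along the KV-head axis
# matches torch.repeat_interleave on that axis (assumed shapes, illustration only).
import torch

x = torch.randn(2, 4, 7, 16)  # (batch, num_key_value_heads, seq_len, head_dim)
n_rep = 3

repeated = x[:, :, None, :, :].expand(2, 4, n_rep, 7, 16).reshape(2, 4 * n_rep, 7, 16)
reference = torch.repeat_interleave(x, repeats=n_rep, dim=1)
assert torch.equal(repeated, reference)  # both give (2, 12, 7, 16) with identical values
```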
-
-
- def eager_attention_forward(
-     module: nn.Module,
-     query: torch.Tensor,
-     key: torch.Tensor,
-     value: torch.Tensor,
-     attention_mask: torch.Tensor | None,
-     scaling: float,
-     dropout: float = 0.0,
-     **kwargs: Unpack[TransformersKwargs],
- ):
-     key_states = repeat_kv(key, module.num_key_value_groups)
-     value_states = repeat_kv(value, module.num_key_value_groups)
-
-     attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
-     if attention_mask is not None:
-         causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
-         attn_weights = attn_weights + causal_mask
-
-     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
-     attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
-     attn_output = torch.matmul(attn_weights, value_states)
-     attn_output = attn_output.transpose(1, 2).contiguous()
-
-     return attn_output, attn_weights
-
-
- class GlmImageVisionAttention(nn.Module):
-     def __init__(self, config: GlmImageVisionConfig) -> None:
-         super().__init__()
-         self.dim = config.hidden_size
-         self.num_heads = config.num_heads
-         self.head_dim = self.dim // self.num_heads
-         self.num_key_value_groups = 1  # needed for eager attention
-         self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.attention_bias)
-         self.proj = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
-         self.scaling = self.head_dim**-0.5
-         self.config = config
-         self.attention_dropout = config.attention_dropout
-         self.is_causal = False
-
-     def forward(
-         self,
-         hidden_states: torch.Tensor,
-         cu_seqlens: torch.Tensor,
-         **kwargs,
-     ) -> torch.Tensor:
-         seq_length = hidden_states.shape[0]
-         query_states, key_states, value_states = (
-             self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
-         )
-         query_states = query_states.transpose(0, 1).unsqueeze(0)
-         key_states = key_states.transpose(0, 1).unsqueeze(0)
-         value_states = value_states.transpose(0, 1).unsqueeze(0)
-
-         attention_interface: Callable = eager_attention_forward
-         if self.config._attn_implementation != "eager":
-             attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
-
-         if "flash" in self.config._attn_implementation:
-             # Flash Attention: Use cu_seqlens for variable length attention
-             max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
-             attn_output, _ = attention_interface(
-                 self,
-                 query_states,
-                 key_states,
-                 value_states,
-                 attention_mask=None,
-                 scaling=self.scaling,
-                 dropout=0.0 if not self.training else self.attention_dropout,
-                 cu_seq_lens_q=cu_seqlens,
-                 cu_seq_lens_k=cu_seqlens,
-                 max_length_q=max_seqlen,
-                 max_length_k=max_seqlen,
-                 is_causal=False,
-                 **kwargs,
-             )
-         else:
-             # Other implementations: Process each chunk separately
-             lengths = cu_seqlens[1:] - cu_seqlens[:-1]
-             splits = [
-                 torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
-             ]
-
-             attn_outputs = [
-                 attention_interface(
-                     self,
-                     q,
-                     k,
-                     v,
-                     attention_mask=None,
-                     scaling=self.scaling,
-                     dropout=0.0 if not self.training else self.attention_dropout,
-                     is_causal=False,
-                     **kwargs,
-                 )[0]
-                 for q, k, v in zip(*splits)
-             ]
-             attn_output = torch.cat(attn_outputs, dim=1)
-
-         attn_output = attn_output.reshape(seq_length, -1).contiguous()
-         attn_output = self.proj(attn_output)
-         return attn_output
-
178
-
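For context on the non-flash branch above: `cu_seqlens` holds cumulative patch counts per image, and each image's patches attend only within their own chunk. A minimal sketch (illustrative grid sizes, assumed shapes) of how the packed sequence is split:

```python
import torch
import torch.nn.functional as F

grid_thw = torch.tensor([[1, 2, 3], [1, 4, 4]])                  # (t, h, w) per image
lengths = grid_thw[:, 1] * grid_thw[:, 2]                        # patches per image: [6, 16]
cu_seqlens = F.pad(lengths.cumsum(0), (1, 0))                    # [0, 6, 22]
packed = torch.randn(1, 8, int(lengths.sum()), 16)               # (1, heads, total_seq, head_dim)
chunks = torch.split(packed, (cu_seqlens[1:] - cu_seqlens[:-1]).tolist(), dim=2)
# attention runs on each chunk separately, so patches never attend across images
assert [c.shape[2] for c in chunks] == [6, 16]
```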
179
- class GlmImageVisionPatchEmbed(nn.Module):
180
- def __init__(self, config: GlmImageVisionConfig) -> None:
181
- super().__init__()
182
- self.patch_size = config.patch_size
183
- self.in_channels = config.in_channels
184
- self.embed_dim = config.hidden_size
185
- kernel_size = [self.patch_size, self.patch_size]
186
- self.proj = nn.Conv2d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size)
187
-
188
- def forward(self, hidden_states) -> torch.Tensor:
189
- target_dtype = self.proj.weight.dtype
190
- hidden_states = hidden_states.view(-1, self.in_channels, self.patch_size, self.patch_size)
191
- hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
192
- return hidden_states
193
-
194
-
195
- class GlmImageVisionEmbeddings(nn.Module):
196
- def __init__(self, config: GlmImageVisionConfig) -> None:
197
- super().__init__()
198
- self.config = config
199
- self.embed_dim = config.hidden_size
200
- self.image_size = config.image_size
201
- self.patch_size = config.patch_size
202
-
203
- self.num_patches = (self.image_size // self.patch_size) ** 2
204
- self.num_positions = self.num_patches
205
- self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
206
- self.interpolated_method = "bilinear"
207
-
208
- def forward(self, embeddings, lengths, image_shapes, h_coords, w_coords) -> torch.Tensor:
209
- """
210
- Forward pass with integrated position encoding adaptation using 2D interpolation.
211
-
212
- Args:
213
- embeddings: Input embeddings tensor
214
- lengths (torch.Tensor): Sequence lengths for each image in the batch.
215
- image_shapes (torch.Tensor): Tensor of shape [batch_size, 3] representing the image shapes (t, h, w).
216
- h_coords (torch.Tensor): Tensor of shape [total_seq] representing the h coordinate for each patch.
217
- w_coords (torch.Tensor): Tensor of shape [total_seq] representing the w coordinate for each patch.
218
-
219
- Returns:
220
- torch.Tensor: Embeddings with adapted position encoding added.
221
- """
222
- # Get position embedding parameters
223
- pos_embed_weight = self.position_embedding.weight
224
- hidden_size = pos_embed_weight.shape[1]
225
- device = pos_embed_weight.device
226
-
227
- # Convert inputs to tensors if needed
228
- if isinstance(lengths, list):
229
- lengths = torch.tensor(lengths, device=device, dtype=torch.long)
230
-
231
- # Prepare 2D position embedding
232
- orig_size_sq = pos_embed_weight.shape[0]
233
- orig_size = int(orig_size_sq**0.5)
234
- pos_embed_2d = (
235
- pos_embed_weight.view(orig_size, orig_size, hidden_size)
236
- .permute(2, 0, 1)
237
- .unsqueeze(0)
238
- .to(device=device, dtype=torch.float32)
239
- )
240
-
241
- # Calculate target dimensions for each patch
242
- target_h = torch.cat([image_shapes[i, 1].repeat(lengths[i]) for i in range(len(lengths))]).to(
243
- device=device, dtype=torch.float32
244
- )
245
- target_w = torch.cat([image_shapes[i, 2].repeat(lengths[i]) for i in range(len(lengths))]).to(
246
- device=device, dtype=torch.float32
247
- )
248
-
249
- # Normalize coordinates to [-1, 1] range for grid_sample
250
- norm_w = ((w_coords + 0.5) / target_w) * 2 - 1
251
- norm_h = ((h_coords + 0.5) / target_h) * 2 - 1
252
-
253
- # Create sampling grid
254
- grid = torch.stack((norm_w, norm_h), dim=-1).unsqueeze(0).unsqueeze(2)
255
-
256
- # Perform 2D interpolation with the configured mode (bilinear by default)
257
- interpolated_embed_fp32 = F.grid_sample(
258
- pos_embed_2d, grid, mode=self.interpolated_method, align_corners=False, padding_mode="border"
259
- )
260
-
261
- # Reshape and convert back to original dtype
262
- adapted_pos_embed_fp32 = interpolated_embed_fp32.squeeze(0).squeeze(-1).permute(1, 0)
263
- adapted_pos_embed = adapted_pos_embed_fp32.to(pos_embed_weight.dtype).to(embeddings.device)
264
-
265
- # Add adapted position encoding to embeddings
266
- embeddings = embeddings + adapted_pos_embed
267
- return embeddings
268
-
269
-
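The embedding module above resamples a fixed pretrained position grid to each image's actual grid via `F.grid_sample`. A small, self-contained sketch with made-up sizes (not the real config values) of that coordinate normalization and sampling:

```python
import torch
import torch.nn.functional as F

embed_dim, orig_side = 8, 4                                      # toy values
pos_embed_2d = torch.randn(1, embed_dim, orig_side, orig_side)   # (1, C, H, W)
h_coords = torch.tensor([0.0, 0.0, 1.0])                         # per-patch target row / column
w_coords = torch.tensor([0.0, 1.0, 0.0])
target_h = target_w = 2.0                                        # target grid of the current image
norm_h = ((h_coords + 0.5) / target_h) * 2 - 1                   # map to [-1, 1] for grid_sample
norm_w = ((w_coords + 0.5) / target_w) * 2 - 1
grid = torch.stack((norm_w, norm_h), dim=-1).unsqueeze(0).unsqueeze(2)   # (1, N, 1, 2)
sampled = F.grid_sample(pos_embed_2d, grid, mode="bilinear",
                        align_corners=False, padding_mode="border")      # (1, C, N, 1)
per_patch_pos_embed = sampled.squeeze(0).squeeze(-1).permute(1, 0)       # (N, C)
assert per_patch_pos_embed.shape == (3, embed_dim)
```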
270
- class GlmImageVisionBlock(GradientCheckpointingLayer):
271
- def __init__(self, config: GlmImageVisionConfig) -> None:
272
- super().__init__()
273
- self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
274
- self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
275
- self.attn = GlmImageVisionAttention(config)
276
- self.mlp = GlmImageVisionMLP(config)
277
-
278
- def forward(
279
- self,
280
- hidden_states: torch.Tensor,
281
- cu_seqlens: torch.Tensor,
282
- **kwargs: Unpack[TransformersKwargs],
283
- ) -> torch.Tensor:
284
- r"""
285
- cu_seqlens (`torch.Tensor` of shape `(num_images_or_videos + 1,)`):
286
- The cumulative sequence lengths of each image or video feature.
287
- """
290
- residual = hidden_states
291
-
292
- hidden_states = self.norm1(hidden_states)
293
- hidden_states = self.attn(
294
- hidden_states,
295
- cu_seqlens=cu_seqlens,
296
- **kwargs,
297
- )
298
- hidden_states = residual + hidden_states
299
-
300
- residual = hidden_states
301
- hidden_states = self.norm2(hidden_states)
302
- hidden_states = self.mlp(hidden_states)
303
- hidden_states = residual + hidden_states
304
-
305
- return hidden_states
306
-
307
-
308
- def rotate_half(x):
309
- """Rotates half the hidden dims of the input."""
310
- x1 = x[..., : x.shape[-1] // 2]
311
- x2 = x[..., x.shape[-1] // 2 :]
312
- return torch.cat((-x2, x1), dim=-1)
313
-
314
-
315
- def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
316
- """Applies Rotary Position Embedding to the query and key tensors.
317
-
318
- Args:
319
- q (`torch.Tensor`): The query tensor.
320
- k (`torch.Tensor`): The key tensor.
321
- cos (`torch.Tensor`): The cosine part of the rotary embedding.
322
- sin (`torch.Tensor`): The sine part of the rotary embedding.
323
- unsqueeze_dim (`int`, *optional*, defaults to 1):
324
- The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
325
- sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
326
- that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
327
- k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
328
- cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
329
- the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
330
- Returns:
331
- `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
332
- """
333
- cos = cos.unsqueeze(unsqueeze_dim)
334
- sin = sin.unsqueeze(unsqueeze_dim)
335
-
336
- # Keep half or full tensor for later concatenation
337
- rotary_dim = cos.shape[-1]
338
- q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
339
- k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
340
-
341
- # Apply rotary embeddings on the first half or full tensor
342
- q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
343
- k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
344
-
345
- # Concatenate back to full shape
346
- q_embed = torch.cat([q_embed, q_pass], dim=-1)
347
- k_embed = torch.cat([k_embed, k_pass], dim=-1)
348
- return q_embed, k_embed
349
-
350
-
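A tiny illustration (arbitrary toy shapes; the helper is duplicated here only for self-containment) of the partial-rotary split performed by `apply_rotary_pos_emb` above: only the first `rotary_dim` channels are rotated, the remainder is passed through and concatenated back:

```python
import torch

def rotate_half(x):
    x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

bs, heads, seq, head_dim, rotary_dim = 1, 2, 5, 8, 4
q = torch.randn(bs, heads, seq, head_dim)
angles = torch.rand(bs, seq, rotary_dim)
cos, sin = angles.cos().unsqueeze(1), angles.sin().unsqueeze(1)   # unsqueeze_dim=1

q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]          # split rotary / pass-through
q_embed = torch.cat([(q_rot * cos) + (rotate_half(q_rot) * sin), q_pass], dim=-1)
assert q_embed.shape == q.shape
assert torch.equal(q_embed[..., rotary_dim:], q_pass)             # pass-through part unchanged
```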
351
- @use_kernelized_func(apply_rotary_pos_emb)
352
- class GlmImageTextAttention(nn.Module):
353
- """Multi-headed attention from 'Attention Is All You Need' paper"""
354
-
355
- def __init__(self, config: GlmImageTextConfig, layer_idx: int | None = None):
356
- super().__init__()
357
- self.config = config
358
- self.layer_idx = layer_idx
359
- self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
360
- self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
361
- self.scaling = self.head_dim**-0.5
362
- self.attention_dropout = config.attention_dropout
363
- self.is_causal = True
364
-
365
- self.q_proj = nn.Linear(
366
- config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
367
- )
368
- self.k_proj = nn.Linear(
369
- config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
370
- )
371
- self.v_proj = nn.Linear(
372
- config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
373
- )
374
- self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
375
- self.rope_parameters = config.rope_parameters
376
-
377
- def forward(
378
- self,
379
- hidden_states: torch.Tensor,
380
- position_embeddings: tuple[torch.Tensor, torch.Tensor],
381
- attention_mask: torch.Tensor | None,
382
- past_key_values: Cache | None = None,
383
- cache_position: torch.LongTensor | None = None,
384
- **kwargs: Unpack[FlashAttentionKwargs],
385
- ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
386
- input_shape = hidden_states.shape[:-1]
387
- hidden_shape = (*input_shape, -1, self.head_dim)
388
-
389
- query_states = self.q_proj(hidden_states).view(hidden_shape)
390
- key_states = self.k_proj(hidden_states).view(hidden_shape)
391
- value_states = self.v_proj(hidden_states).view(hidden_shape)
392
-
393
- query_states = query_states.transpose(1, 2)
394
- key_states = key_states.transpose(1, 2)
395
- value_states = value_states.transpose(1, 2)
396
-
397
- cos, sin = position_embeddings
398
- query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
399
-
400
- if past_key_values is not None:
401
- # sin and cos are specific to RoPE models; position_ids needed for the static cache
402
- cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
403
- key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
404
-
405
- attention_interface: Callable = eager_attention_forward
406
- if self.config._attn_implementation != "eager":
407
- attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
408
-
409
- attn_output, attn_weights = attention_interface(
410
- self,
411
- query_states,
412
- key_states,
413
- value_states,
414
- attention_mask,
415
- dropout=0.0 if not self.training else self.attention_dropout,
416
- scaling=self.scaling,
417
- **kwargs,
418
- )
419
-
420
- attn_output = attn_output.reshape(*input_shape, -1).contiguous()
421
- attn_output = self.o_proj(attn_output)
422
- return attn_output, attn_weights
423
-
424
-
425
- @use_kernel_forward_from_hub("RMSNorm")
426
- class GlmImageRMSNorm(nn.Module):
427
- def __init__(self, hidden_size, eps=1e-6):
428
- """
429
- GlmImageRMSNorm is equivalent to T5LayerNorm
430
- """
431
- super().__init__()
432
- self.weight = nn.Parameter(torch.ones(hidden_size))
433
- self.variance_epsilon = eps
434
-
435
- def forward(self, hidden_states):
436
- input_dtype = hidden_states.dtype
437
- hidden_states = hidden_states.to(torch.float32)
438
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
439
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
440
- return self.weight * hidden_states.to(input_dtype)
441
-
442
- def extra_repr(self):
443
- return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
444
-
445
-
446
- class GlmImageTextMLP(nn.Module):
447
- def __init__(self, config):
448
- super().__init__()
449
-
450
- self.config = config
451
- self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False)
452
- self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
453
- self.activation_fn = ACT2FN[config.hidden_act]
454
-
455
- def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
456
- up_states = self.gate_up_proj(hidden_states)
457
-
458
- gate, up_states = up_states.chunk(2, dim=-1)
459
- up_states = up_states * self.activation_fn(gate)
460
-
461
- return self.down_proj(up_states)
462
-
463
-
464
- class GlmImageTextDecoderLayer(GradientCheckpointingLayer):
465
- def __init__(self, config: GlmImageTextConfig, layer_idx: int):
466
- super().__init__()
467
- self.hidden_size = config.hidden_size
468
- self.self_attn = GlmImageTextAttention(config, layer_idx)
469
- self.mlp = GlmImageTextMLP(config)
470
- self.input_layernorm = GlmImageRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
471
- self.post_attention_layernorm = GlmImageRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
472
- self.post_self_attn_layernorm = GlmImageRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
473
- self.post_mlp_layernorm = GlmImageRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
474
-
475
- @auto_docstring
476
- def forward(
477
- self,
478
- hidden_states: torch.Tensor,
479
- position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
480
- attention_mask: torch.Tensor | None = None,
481
- position_ids: torch.LongTensor | None = None,
482
- past_key_values: Cache | None = None,
483
- use_cache: bool | None = False,
484
- cache_position: torch.LongTensor | None = None,
485
- **kwargs,
486
- ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
487
- residual = hidden_states
488
-
489
- hidden_states = self.input_layernorm(hidden_states)
490
-
491
- # Self Attention
492
- hidden_states, _ = self.self_attn(
493
- hidden_states=hidden_states,
494
- position_embeddings=position_embeddings,
495
- attention_mask=attention_mask,
496
- position_ids=position_ids,
497
- past_key_values=past_key_values,
498
- use_cache=use_cache,
499
- cache_position=cache_position,
500
- **kwargs,
501
- )
502
-
503
- hidden_states = self.post_self_attn_layernorm(hidden_states)
504
- hidden_states = residual + hidden_states
505
-
506
- # Fully Connected
507
- residual = hidden_states
508
- hidden_states = self.post_attention_layernorm(hidden_states)
509
- hidden_states = self.mlp(hidden_states)
510
- hidden_states = self.post_mlp_layernorm(hidden_states)
511
- hidden_states = residual + hidden_states
512
-
513
- return hidden_states
514
-
515
-
516
- @auto_docstring
517
- class GlmImagePreTrainedModel(PreTrainedModel):
518
- config: GlmImageConfig
519
- base_model_prefix = "model"
520
- input_modalities = ("image", "text")
521
- supports_gradient_checkpointing = True
522
- _no_split_modules = ["GlmImageTextDecoderLayer", "GlmImageVisionBlock"]
523
- _skip_keys_device_placement = "past_key_values"
524
- _supports_flash_attn = True
525
- _supports_sdpa = True
526
-
527
- _can_compile_fullgraph = True
528
- _supports_attention_backend = True
529
- _can_record_outputs = {
530
- "hidden_states": GlmImageTextDecoderLayer,
531
- "attentions": GlmImageTextAttention,
532
- }
533
-
534
- @torch.no_grad()
535
- def _init_weights(self, module):
536
- super()._init_weights(module)
537
-
538
-
539
- @dataclass
540
- @auto_docstring(
541
- custom_intro="""
542
- Base class for GlmImage model outputs, with hidden states and attentions.
543
- """
544
- )
545
- class GlmImageModelOutputWithPast(ModelOutput):
546
- r"""
547
- past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
548
- It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
549
-
550
- Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
551
- `past_key_values` input) to speed up sequential decoding.
552
- rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
553
- The rope index difference between sequence length and multimodal rope.
554
- """
555
-
556
- last_hidden_state: torch.FloatTensor | None = None
557
- past_key_values: Cache | None = None
558
- hidden_states: tuple[torch.FloatTensor] | None = None
559
- attentions: tuple[torch.FloatTensor] | None = None
560
- rope_deltas: torch.LongTensor | None = None
561
-
562
-
563
- class GlmImageVQVAEVectorQuantizer(nn.Module):
564
- """
565
- A module for vector quantization using learned embedding vectors.
566
-
567
- This module implements the quantization process similar to the one described in
568
- the VQ-VAE (Vector Quantized Variational AutoEncoder) paper. It quantizes continuous
569
- input vectors into discrete codebook vectors, which are learned during training.
570
- Current implementation improves over previous ones by avoiding costly matrix multiplications
571
- and allowing for post-hoc remapping of indices.
572
- """
573
-
574
- def __init__(self, config: GlmImageVQVAEConfig):
575
- super().__init__()
576
- self.num_embeddings = config.num_embeddings
577
- self.embedding_dim = config.embed_dim
578
- self.beta = getattr(config, "beta", 0.25)
579
-
580
- self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
581
-
582
- def forward(self, hidden_state: torch.Tensor):
583
- hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
584
- hidden_state_flattened = hidden_state.view(-1, self.embedding_dim)
585
-
586
- # L2 normalize
587
- hidden_state = F.normalize(hidden_state, p=2, dim=-1)
588
- hidden_state_flattened = F.normalize(hidden_state_flattened, p=2, dim=-1)
589
- embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
590
-
591
- # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
592
- distances = (
593
- torch.sum(hidden_state_flattened**2, dim=1, keepdim=True)
594
- + torch.sum(embedding**2, dim=1)
595
- - 2 * torch.einsum("bd,dn->bn", hidden_state_flattened, embedding.transpose(0, 1))
596
- )
597
-
598
- min_encoding_indices = torch.argmin(distances, dim=1)
599
- hidden_state_quant = embedding[min_encoding_indices].view(hidden_state.shape)
600
-
601
- # compute loss for embedding
602
- loss = torch.mean((hidden_state_quant.detach() - hidden_state) ** 2) + self.beta * torch.mean(
603
- (hidden_state_quant - hidden_state.detach()) ** 2
604
- )
605
-
606
- # preserve gradients
607
- hidden_state_quant = hidden_state + (hidden_state_quant - hidden_state).detach()
608
-
609
- # reshape back to match original input shape
610
- hidden_state_quant = hidden_state_quant.permute(0, 3, 1, 2).contiguous()
611
-
612
- return hidden_state_quant, loss, min_encoding_indices
613
-
614
-
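A compact sketch (random toy data, hypothetical sizes) of the two key steps inside the quantizer above: nearest-neighbour lookup against the L2-normalized codebook, and the straight-through estimator that keeps gradients flowing to the encoder:

```python
import torch
import torch.nn.functional as F

codebook = F.normalize(torch.randn(16, 8), p=2, dim=-1)                 # (num_embeddings, embed_dim)
z = F.normalize(torch.randn(4, 8, requires_grad=True), p=2, dim=-1)     # flattened encoder output

# (z - e)^2 = z^2 + e^2 - 2 * z @ e^T, minimized per row
distances = (z**2).sum(1, keepdim=True) + (codebook**2).sum(1) - 2 * z @ codebook.T
indices = distances.argmin(dim=1)                                        # discrete image tokens
z_q = codebook[indices]

# straight-through estimator: forward uses z_q, backward passes gradients to z
z_q = z + (z_q - z).detach()
```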
615
- @dataclass
616
- @auto_docstring
617
- class GlmImageVQVAEModelOutput(BaseModelOutputWithPooling):
618
- r"""
619
- quantized_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
620
- Quantized last hidden state from the VQ-VAE model.
621
- image_tokens (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
622
- Indices of the image tokens predicted by the VQ-VAE model.
623
- embedding_loss (`torch.FloatTensor`):
624
- The embedding loss computed during quantization.
625
- """
626
-
627
- quantized_last_hidden_state: torch.FloatTensor | None = None
628
- image_tokens: torch.FloatTensor | None = None
629
- embedding_loss: torch.FloatTensor | None = None
630
-
631
-
632
- @auto_docstring(
633
- custom_intro="""
634
- The VQ-VAE model used in GlmImage for encoding/decoding images into discrete tokens.
635
- This model follows the "Make-a-scene: Scene-based text-to-image generation with human priors" paper from
636
- [Oran Gafni, Adam Polyak, Oron Ashual, Shelly Sheynin, Devi Parikh, and Yaniv
637
- Taigman](https://huggingface.co/papers/2203.13131).
638
- """
639
- )
640
- class GlmImageVQVAE(GlmImagePreTrainedModel):
641
- config: GlmImageVQVAEConfig
642
- _no_split_modules = [
643
- "GlmImageVQVAEVectorQuantizer",
644
- ]
645
- _can_record_outputs = {}
646
-
647
- def __init__(self, config: GlmImageVQVAEConfig):
648
- super().__init__(config)
649
- self.quantize = GlmImageVQVAEVectorQuantizer(config)
650
- self.quant_conv = torch.nn.Conv2d(config.latent_channels, config.embed_dim, 1)
651
- self.post_quant_conv = torch.nn.Conv2d(config.embed_dim, config.latent_channels, 1)
652
- self.eval() # GlmImage's VQ model is frozen
653
- self.post_init()
654
-
655
- @check_model_inputs
656
- def encode(self, hidden_states) -> GlmImageVQVAEModelOutput:
657
- conv_hidden_states = self.quant_conv(hidden_states)
658
- quantized_last_hidden_state, emb_loss, indices = self.quantize(conv_hidden_states)
659
- return GlmImageVQVAEModelOutput(
660
- last_hidden_state=hidden_states,
661
- quantized_last_hidden_state=quantized_last_hidden_state,
662
- image_tokens=indices,
663
- embedding_loss=emb_loss,
664
- )
665
-
666
-
667
- class GlmImageVisionModel(GlmImagePreTrainedModel):
668
- config: GlmImageVisionConfig
669
- input_modalities = ("image",)
670
- _no_split_modules = ["GlmImageVisionBlock"]
671
- _can_record_outputs = {
672
- "hidden_states": GlmImageVisionBlock,
673
- "attentions": GlmImageVisionAttention,
674
- }
675
- main_input_name = "pixel_values"
676
-
677
- def __init__(self, config: GlmImageVisionConfig) -> None:
678
- super().__init__(config)
679
- self.spatial_merge_size = config.spatial_merge_size
680
- self.patch_size = config.patch_size
681
-
682
- self.embeddings = GlmImageVisionEmbeddings(config)
683
- self.patch_embed = GlmImageVisionPatchEmbed(config)
684
-
685
- head_dim = config.hidden_size // config.num_heads
686
-
687
- self.blocks = nn.ModuleList([GlmImageVisionBlock(config) for _ in range(config.depth)])
688
-
689
- self.gradient_checkpointing = False
690
- self.head_dim = head_dim
691
- self.post_init()
692
-
693
- def rot_pos_emb(self, grid_thw):
694
- pos_ids = []
695
- for t, h, w in grid_thw:
696
- hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
697
- hpos_ids = hpos_ids.reshape(
698
- h // self.spatial_merge_size,
699
- self.spatial_merge_size,
700
- w // self.spatial_merge_size,
701
- self.spatial_merge_size,
702
- )
703
- hpos_ids = hpos_ids.permute(0, 2, 1, 3)
704
- hpos_ids = hpos_ids.flatten()
705
-
706
- wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
707
- wpos_ids = wpos_ids.reshape(
708
- h // self.spatial_merge_size,
709
- self.spatial_merge_size,
710
- w // self.spatial_merge_size,
711
- self.spatial_merge_size,
712
- )
713
- wpos_ids = wpos_ids.permute(0, 2, 1, 3)
714
- wpos_ids = wpos_ids.flatten()
715
- pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
716
- pos_ids = torch.cat(pos_ids, dim=0)
717
- return pos_ids
718
-
719
- @check_model_inputs
720
- @auto_docstring
721
- def forward(
722
- self, pixel_values: torch.Tensor, grid_thw: torch.Tensor, **kwargs: Unpack[TransformersKwargs]
723
- ) -> tuple | BaseModelOutputWithPooling:
724
- r"""
725
- pixel_values (`torch.Tensor` of shape `(total_patches, num_channels * patch_size * patch_size)`):
726
- Packed pixel values.
727
- grid_thw (`torch.Tensor` of shape `(num_images, 3)`):
728
- The temporal, height and width of feature shape of each image.
729
-
730
- Returns:
731
- `torch.Tensor` of shape `(total_patches, hidden_size)`: Hidden states.
732
- """
733
-
734
- hidden_states = self.patch_embed(pixel_values)
735
- image_type_ids = self.rot_pos_emb(grid_thw)
736
-
737
- cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
738
- dim=0,
739
- dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
740
- )
741
- cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
742
- seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
743
- hidden_states = self.embeddings(
744
- hidden_states,
745
- seqlens,
746
- grid_thw,
747
- image_type_ids[:, 0].to(hidden_states.device),
748
- image_type_ids[:, 1].to(hidden_states.device),
749
- )
750
-
751
- # Transformer blocks (no position_embeddings needed, already added above)
752
- for blk in self.blocks:
753
- hidden_states = blk(
754
- hidden_states,
755
- cu_seqlens=cu_seqlens,
756
- )
757
-
758
- return BaseModelOutputWithPooling(last_hidden_state=hidden_states)
759
-
760
-
761
- class GlmImageTextRotaryEmbedding(nn.Module):
762
- inv_freq: torch.Tensor # fix linting for `register_buffer`
763
-
764
- def __init__(self, config: GlmImageTextConfig, device=None):
765
- super().__init__()
766
- self.max_seq_len_cached = config.max_position_embeddings
767
- self.original_max_seq_len = config.max_position_embeddings
768
-
769
- self.config = config
770
-
771
- self.rope_type = self.config.rope_parameters["rope_type"]
772
- rope_init_fn: Callable = self.compute_default_rope_parameters
773
- if self.rope_type != "default":
774
- rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
775
- inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
776
-
777
- self.register_buffer("inv_freq", inv_freq, persistent=False)
778
- self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
779
- self.mrope_section = config.rope_parameters.get("mrope_section", [8, 12, 12])
780
-
781
- @staticmethod
782
- def compute_default_rope_parameters(
783
- config: GlmImageTextConfig | None = None,
784
- device: Optional["torch.device"] = None,
785
- seq_len: int | None = None,
786
- ) -> tuple["torch.Tensor", float]:
787
- """
788
- Computes the inverse frequencies according to the original RoPE implementation
789
- Args:
790
- config ([`~transformers.PreTrainedConfig`]):
791
- The model configuration.
792
- device (`torch.device`):
793
- The device to use for initialization of the inverse frequencies.
794
- seq_len (`int`, *optional*):
795
- The current sequence length. Unused for this type of RoPE.
796
- Returns:
797
- Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
798
- post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
799
- """
800
- base = config.rope_parameters["rope_theta"]
801
- partial_rotary_factor = config.rope_parameters.get("partial_rotary_factor", 1.0)
802
- head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
803
- dim = int(head_dim * partial_rotary_factor)
804
-
805
- attention_factor = 1.0 # Unused in this type of RoPE
806
-
807
- # Compute the inverse frequencies
808
- inv_freq = 1.0 / (
809
- base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
810
- )
811
- return inv_freq, attention_factor
812
-
813
- @torch.no_grad()
814
- @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
815
- def forward(self, x, position_ids):
816
- # In contrast to other models, GLM-V has different position ids for the grids
817
- # So we expand the inv_freq to shape (3, ...)
818
- inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
819
- position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)
820
-
821
- device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
822
- with maybe_autocast(device_type=device_type, enabled=False): # Force float32
823
- freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
824
- freqs = self.apply_mrope(freqs, self.mrope_section)
825
- emb = torch.cat((freqs, freqs), dim=-1)
826
- cos = emb.cos() * self.attention_scaling
827
- sin = emb.sin() * self.attention_scaling
828
-
829
- return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
830
-
831
- def apply_mrope(self, freqs, mrope_section):
832
- section = mrope_section
833
- chunks = freqs.split(section, dim=-1)
834
- result = torch.cat([chunk[i % 3] for i, chunk in enumerate(chunks)], dim=-1)
835
- return result
836
-
837
-
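To make the `apply_mrope` indexing above concrete, here is a toy illustration (made-up section sizes) of how frequency channels are split by `mrope_section` and then taken from the temporal, height, and width grids in turn:

```python
import torch

mrope_section = [2, 3, 3]                              # channels drawn from t / h / w grids
freqs = torch.randn(3, 1, 5, sum(mrope_section))       # (3 grids, batch, seq, dim/2)
chunks = freqs.split(mrope_section, dim=-1)
# chunk 0 comes from the temporal grid, chunk 1 from height, chunk 2 from width
merged = torch.cat([chunk[i % 3] for i, chunk in enumerate(chunks)], dim=-1)
assert merged.shape == (1, 5, sum(mrope_section))      # one combined (batch, seq, dim/2) tensor
```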
838
- @auto_docstring
839
- class GlmImageTextModel(GlmImagePreTrainedModel):
840
- config: GlmImageTextConfig
841
- input_modalities = ("text",)
842
-
843
- def __init__(self, config: GlmImageTextConfig):
844
- super().__init__(config)
845
- self.padding_idx = config.pad_token_id
846
- self.vocab_size = config.vocab_size
847
-
848
- self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
849
- self.layers = nn.ModuleList(
850
- [GlmImageTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
851
- )
852
- self.norm = GlmImageRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
853
- self.rotary_emb = GlmImageTextRotaryEmbedding(config=config)
854
-
855
- self.gradient_checkpointing = False
856
- # Initialize weights and apply final processing
857
- self.post_init()
858
-
859
- @auto_docstring
860
- @check_model_inputs
861
- def forward(
862
- self,
863
- input_ids: torch.LongTensor | None = None,
864
- attention_mask: torch.Tensor | None = None,
865
- position_ids: torch.LongTensor | None = None,
866
- past_key_values: Cache | None = None,
867
- inputs_embeds: torch.FloatTensor | None = None,
868
- use_cache: bool | None = None,
869
- cache_position: torch.LongTensor | None = None,
870
- **kwargs: Unpack[FlashAttentionKwargs],
871
- ) -> tuple | BaseModelOutputWithPast:
872
- if (input_ids is None) ^ (inputs_embeds is not None):
873
- raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
874
-
875
- # torch.jit.trace() doesn't support cache objects in the output
876
- if use_cache and past_key_values is None and not torch.jit.is_tracing():
877
- past_key_values = DynamicCache(config=self.config)
878
-
879
- if inputs_embeds is None:
880
- inputs_embeds = self.embed_tokens(input_ids)
881
-
882
- if cache_position is None:
883
- past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
884
- cache_position = torch.arange(
885
- past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
886
- )
887
-
888
- # the hard coded `3` is for temporal, height and width.
889
- if position_ids is None:
890
- position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
891
- elif position_ids.ndim == 2:
892
- position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
893
-
894
- # NOTE: we need to pass text position ids for packing. Qwen2-VL uses 3D positions
895
- # where each dim indicates visual spatial positions for temporal/height/width grids.
896
- # There are two scenarios when FA2-like packed masking might be activated.
897
- # 1. User specifically passed packed `position_ids` and no attention mask.
898
- # In this case we expect the user to create correct position ids for all 3 grids
899
- # and prepend text-only position ids to it. The final tensor will be [4, bs, seq-len]
900
- # 2. User runs forward with no attention mask and no position ids. In this case, position ids
901
- # are prepared by the model (`get_rope_index`) as `[4, bs, seq-len]` tensor. Text-only positions are
902
- # prepended by us when creating positions so that the mask is constructed correctly. NOTE: failing to pass
903
- # text-only positions will cause incorrect mask construction, do not change `prepare_input_for_generation`
904
- if position_ids.ndim == 3 and position_ids.shape[0] == 4:
905
- text_position_ids = position_ids[0]
906
- position_ids = position_ids[1:]
907
- else:
908
- # If inputs are not packed (usual 3D positions), do not prepare mask from position_ids
909
- text_position_ids = None
910
-
911
- mask_kwargs = {
912
- "config": self.config,
913
- "input_embeds": inputs_embeds,
914
- "attention_mask": attention_mask,
915
- "cache_position": cache_position,
916
- "past_key_values": past_key_values,
917
- "position_ids": text_position_ids,
918
- }
919
- # Create the masks
920
- causal_mask = create_causal_mask(**mask_kwargs)
921
-
922
- hidden_states = inputs_embeds
923
- position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
924
-
925
- for decoder_layer in self.layers:
926
- layer_outputs = decoder_layer(
927
- hidden_states,
928
- attention_mask=causal_mask,
929
- position_ids=text_position_ids,
930
- past_key_values=past_key_values,
931
- cache_position=cache_position,
932
- position_embeddings=position_embeddings,
933
- **kwargs,
934
- )
935
- hidden_states = layer_outputs
936
-
937
- hidden_states = self.norm(hidden_states)
938
-
939
- return BaseModelOutputWithPast(
940
- last_hidden_state=hidden_states,
941
- past_key_values=past_key_values,
942
- )
943
-
944
-
945
- @auto_docstring
946
- class GlmImageModel(GlmImagePreTrainedModel):
947
- base_model_prefix = "model"
948
- _checkpoint_conversion_mapping = {}
949
- # Reference: fix gemma3 grad acc #37208
950
- accepts_loss_kwargs = False
951
- config: GlmImageConfig
952
- _no_split_modules = ["GlmImageTextDecoderLayer", "GlmImageVisionBlock"]
953
-
954
- def __init__(self, config):
955
- super().__init__(config)
956
- self.visual = GlmImageVisionModel._from_config(config.vision_config)
957
- self.language_model = GlmImageTextModel._from_config(config.text_config)
958
-
959
- self.rope_deltas = None # cache rope_deltas here
960
- self.vqmodel = GlmImageVQVAE._from_config(config.vq_config)
961
-
962
- # Initialize weights and apply final processing
963
- self.post_init()
964
-
965
- def get_input_embeddings(self):
966
- return self.language_model.get_input_embeddings()
967
-
968
- def set_input_embeddings(self, value):
969
- self.language_model.set_input_embeddings(value)
970
-
971
- def get_rope_index(
972
- self,
973
- input_ids: torch.LongTensor | None = None,
974
- image_grid_thw: torch.LongTensor | None = None,
975
- attention_mask: torch.LongTensor | None = None,
976
- ) -> tuple[torch.Tensor, torch.Tensor]:
977
- """
978
- Calculate the 3D rope index for image generation task.
979
-
980
- Explanation:
981
- Each embedding sequence may contain image tokens (for generation) and text tokens,
982
- or just text tokens.
983
-
984
- Input format:
985
- - Text-to-Image: [text tokens] + <|dit_token_16384|>
986
- - Image-to-Image: <|dit_token_16384|> [image tokens] <|dit_token_16385|> + [text tokens] + <|dit_token_16384|>
987
-
988
- For a pure-text embedding sequence, the rotary position embedding is the same across all 3 dimensions.
989
- Examples:
990
- input_ids: [T T T T T], here T is for text.
991
- temporal position_ids: [0, 1, 2, 3, 4]
992
- height position_ids: [0, 1, 2, 3, 4]
993
- width position_ids: [0, 1, 2, 3, 4]
994
-
995
- For sequences with image tokens, we use special markers to denote image regions:
996
- - <|dit_token_16384|>: image start marker
997
- - <|dit_token_16385|>: image end marker
998
- - Image tokens between these markers use 2D spatial position encoding.
999
-
1000
- For image tokens:
1001
- - temporal: stays constant at (image_start_pos + 1)
1002
- - height: increments every w tokens, representing row position
1003
- - width: cycles from 0 to w-1, representing column position
1004
-
1005
- After each image region, the next position jumps to: image_start_pos + 1 + max(h, w)
1006
- This ensures sufficient positional separation between images and subsequent tokens.
1007
-
1008
- Examples:
1009
- === Case 1: Image-to-Image Generation ===
1010
-
1011
- Source image with grid [1, 3, 2], followed by text, then generation.
1012
- input_ids: [<|dit_token_16384|> V V V V V V <|dit_token_16385|> T T T T <|dit_token_16384|>]
1013
- image_grid_thw: [[1, 3, 2], [1, 4, 4]] # first is source, second is target
1014
-
1015
- For source image (h=3, w=2, 6 tokens):
1016
- Start marker at position 0
1017
- Image tokens at temporal=1, height=[1,1,2,2,3,3], width=[1,2,1,2,1,2]
1018
- End marker at position 4 (= 0 + 1 + max(3,2))
1019
-
1020
- Text tokens and trailing start marker continue from position 5.
1021
-
1022
- Full prefill position_ids:
1023
- temporal: [0, 1,1,1,1,1,1, 4, 5,6,7,8, 9]
1024
- height: [0, 1,1,2,2,3,3, 4, 5,6,7,8, 9]
1025
- width: [0, 1,2,1,2,1,2, 4, 5,6,7,8, 9]
1026
-
1027
- Decode stage: use image_grid_thw[-1] = [1, 4, 4] to build cached position_ids,
1028
- starting from gen_st_idx = 10.
1029
-
1030
- === Case 2: Text-to-Image Generation (multi-resolution) ===
1031
-
1032
- Pure text input with two image_grids for progressive generation.
1033
- input_ids: [hello<sop>3 3<eop><sop>3 2<eop><|dit_token_16384|>]
1034
- Assume "hello<sop>3 3<eop><sop>3 2<eop>" = 4 tokens (positions 0-3)
1035
- <|dit_token_16384|> at position 4
1036
- image_grid_thw: [[1, 3, 3], [1, 3, 2]]
1037
- - image_grid_thw[-1] = [1, 3, 2]: first generated image (smaller/draft)
1038
- - image_grid_thw[-2] = [1, 3, 3]: second generated image (larger/final)
1039
-
1040
- Prefill position_ids (5 tokens: 4 text + 1 start marker):
1041
- temporal: [0, 1, 2, 3, 4]
1042
- height: [0, 1, 2, 3, 4]
1043
- width: [0, 1, 2, 3, 4]
1044
-
1045
- Decode stage builds position_ids in reverse order of image_grid_thw:
1046
-
1047
- First: image_grid_thw[-1] = [1, 3, 2] (6 tokens), starting at position 5:
1048
- temporal: [5, 5, 5, 5, 5, 5]
1049
- height: [5, 5, 6, 6, 7, 7]
1050
- width: [5, 6, 5, 6, 5, 6]
1051
- next_pos = 5 + max(3, 2) = 8
1052
-
1053
- Then: image_grid_thw[-2] = [1, 3, 3] (9 tokens), starting at position 8:
1054
- temporal: [8, 8, 8, 8, 8, 8, 8, 8, 8]
1055
- height: [8, 8, 8, 9, 9, 9, 10, 10, 10]
1056
- width: [8, 9, 10, 8, 9, 10, 8, 9, 10]
1057
- next_pos = 8 + max(3, 3) = 11
1058
-
1059
- Finally: <|dit_token_16385|> end marker at position 11
1060
-
1061
- Full sequence position_ids (prefill + decode):
1062
- temporal: [0,1,2,3, 4, 5,5,5,5,5,5, 8,8,8,8,8,8,8,8,8, 11]
1063
- height: [0,1,2,3, 4, 5,5,6,6,7,7, 8,8,8,9,9,9,10,10,10, 11]
1064
- width: [0,1,2,3, 4, 5,6,5,6,5,6, 8,9,10,8,9,10,8,9,10, 11]
1065
-
1066
- _cached_decode_position_ids shape: [3, 6 + 9 + 1] = [3, 16]
1067
- (includes all generated image tokens + end marker)
1068
-
1069
- Args:
1070
- input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1071
- Indices of input sequence tokens in the vocabulary. Padding will be ignored by default
1072
- should you provide it.
1073
- image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
1074
- The temporal, height and width of feature shape of each image. For image generation,
1075
- temporal is typically 1.
1076
- - For image-to-image: includes source image grids + target image grid(s)
1077
- - For text-to-image with multi-resolution: includes multiple target grids,
1078
- processed in reverse order (last grid first, second-to-last grid second, etc.)
1079
- attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1080
- Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1081
- - 1 for tokens that are **not masked**,
1082
- - 0 for tokens that are **masked**.
1083
-
1084
- Returns:
1085
- position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`):
1086
- Position IDs for temporal, height, and width dimensions.
1087
- mrope_position_deltas (`torch.Tensor` of shape `(batch_size, 1)`):
1088
- Position deltas for multi-modal rotary position embedding (zeros for this task).
1089
- """
1090
-
1091
- batch_size, seq_len = input_ids.shape
1092
- device = input_ids.device
1093
- dtype = input_ids.dtype
1094
-
1095
- image_start_token_id = self.config.image_start_token_id
1096
- image_end_token_id = self.config.image_end_token_id
1097
- num_complete_images = (input_ids == image_end_token_id).sum().item()
1098
-
1099
- position_ids = torch.ones(
1100
- 3, input_ids.shape[0], input_ids.shape[1], dtype=input_ids.dtype, device=input_ids.device
1101
- )
1102
- text_positions = torch.arange(seq_len)[None, :].repeat(3, 1)
1103
- for batch_idx in range(batch_size):
1104
- curr_input_ids = input_ids[batch_idx]
1105
- if attention_mask is not None:
1106
- curr_input_ids = curr_input_ids[attention_mask[batch_idx] == 1]
1107
-
1108
- image_end = torch.where(curr_input_ids == image_end_token_id)[0]
1109
- image_start = torch.where(curr_input_ids == image_start_token_id)[0] + 1
1110
- current_pos = 0 # track the current position value
1111
- prev_image_end = 0
1112
- curr_position_ids = []
1113
- for start, end, grid in zip(image_start, image_end, image_grid_thw):
1114
- _, num_width_grid, num_height_grid = grid
1115
-
1116
- # Create text position ids first if there are text tokens before image
1117
- llm_pos_length = start - prev_image_end
1118
- llm_position_ids = text_positions[:, current_pos : current_pos + llm_pos_length].to(
1119
- device=input_ids.device
1120
- )
1121
- current_pos += llm_position_ids.shape[-1]
1122
-
1123
- # Now create image position ids for each grid
1124
- image_seq_length = num_height_grid * num_width_grid
1125
- h_grids = image_seq_length // num_height_grid + current_pos
1126
- w_grids = image_seq_length // num_width_grid + current_pos
1127
- position_width = torch.arange(current_pos, w_grids, device=input_ids.device).repeat(num_width_grid)
1128
- position_height = torch.arange(current_pos, h_grids, device=input_ids.device).repeat_interleave(
1129
- num_height_grid
1130
- )
1131
- position_temporal = torch.full(
1132
- (image_seq_length,), current_pos, device=input_ids.device, dtype=torch.long
1133
- )
1134
- vision_position_ids = torch.stack([position_temporal, position_height, position_width], dim=0)
1135
- current_pos += max(num_height_grid, num_width_grid)
1136
-
1137
- prev_image_end = end
1138
- curr_position_ids.append(torch.cat([llm_position_ids, vision_position_ids], dim=-1))
1139
-
1140
- # Add position ids for the last text tokens if any
1141
- end_position = len(curr_input_ids) - prev_image_end
1142
- llm_position_ids = text_positions[:, current_pos : current_pos + end_position].to(device=input_ids.device)
1143
- current_pos += llm_position_ids.shape[-1]
1144
- curr_position_ids.append(llm_position_ids)
1145
- curr_position_ids = torch.cat(curr_position_ids, dim=-1)
1146
- if attention_mask is not None:
1147
- position_ids[:, batch_idx, attention_mask[batch_idx] == 1] = curr_position_ids.to(position_ids.device)
1148
- else:
1149
- position_ids[:, batch_idx, :] = curr_position_ids.to(position_ids.device)
1150
-
1151
- # Build and store position ids for tokens that will be generated. Later we will just
1152
- # slice these instead of computing each decoding step
1153
- self._prefill_len = seq_len
1154
- if image_grid_thw is not None and len(image_grid_thw) > 0:
1155
- num_decode_grids = len(image_grid_thw) - num_complete_images
1156
- num_decode_grids = max(num_decode_grids, 0)
1157
- decode_pos = current_pos
1158
-
1159
- decode_temporal_list = []
1160
- decode_height_list = []
1161
- decode_width_list = []
1162
-
1163
- for i in range(1, num_decode_grids + 1):
1164
- grid_idx = -i
1165
- h = image_grid_thw[grid_idx, 1].item()
1166
- w = image_grid_thw[grid_idx, 2].item()
1167
- total_tokens = h * w
1168
-
1169
- h_indices = torch.arange(h, device=device).unsqueeze(1).expand(h, w).flatten()
1170
- w_indices = torch.arange(w, device=device).unsqueeze(0).expand(h, w).flatten()
1171
-
1172
- decode_temporal_list.append(torch.full((total_tokens,), decode_pos, device=device, dtype=torch.long))
1173
- decode_height_list.append(decode_pos + h_indices)
1174
- decode_width_list.append(decode_pos + w_indices)
1175
- decode_pos = decode_pos + max(h, w)
1176
-
1177
- decode_temporal_list.append(torch.tensor([decode_pos], device=device, dtype=torch.long))
1178
- decode_height_list.append(torch.tensor([decode_pos], device=device, dtype=torch.long))
1179
- decode_width_list.append(torch.tensor([decode_pos], device=device, dtype=torch.long))
1180
-
1181
- self._cached_decode_position_ids = torch.stack(
1182
- [
1183
- torch.cat(decode_temporal_list, dim=0),
1184
- torch.cat(decode_height_list, dim=0),
1185
- torch.cat(decode_width_list, dim=0),
1186
- ],
1187
- dim=0,
1188
- )
1189
- else:
1190
- self._cached_decode_position_ids = None
1191
-
1192
- mrope_position_deltas = torch.zeros([batch_size, 1], dtype=dtype, device=device)
1193
-
1194
- return position_ids, mrope_position_deltas
1195
-
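The pattern described in the docstring's Case 1 (source grid [1, 3, 2]) can be reproduced with a few lines; this is only a hand-rolled sketch of the indexing scheme, not a call into the model:

```python
import torch

h, w = 3, 2
start = 0                                    # position of the <|dit_token_16384|> start marker
first = start + 1
temporal = torch.full((h * w,), first)                       # [1, 1, 1, 1, 1, 1]
height = first + torch.arange(h).repeat_interleave(w)        # [1, 1, 2, 2, 3, 3]
width = first + torch.arange(w).repeat(h)                    # [1, 2, 1, 2, 1, 2]
end = first + max(h, w)                                      # end marker lands at position 4
```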
1196
- @can_return_tuple
1197
- @auto_docstring
1198
- def get_image_features(
1199
- self,
1200
- pixel_values: torch.FloatTensor,
1201
- image_grid_thw: torch.LongTensor | None = None,
1202
- **kwargs: Unpack[TransformersKwargs],
1203
- ) -> tuple | BaseModelOutputWithPooling:
1204
- r"""
1205
- pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
1206
- The tensors corresponding to the input images.
1207
- image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
1208
- The temporal, height and width of feature shape of each image in LLM.
1209
- """
1210
- pixel_values = pixel_values.type(self.visual.dtype)
1211
- vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs)
1212
- split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
1213
- image_embeds = torch.split(vision_outputs.last_hidden_state, split_sizes)
1214
- vision_outputs.pooler_output = image_embeds
1215
-
1216
- return vision_outputs
1217
-
1218
- def get_placeholder_mask(
1219
- self,
1220
- input_ids: torch.LongTensor,
1221
- image_ids: torch.LongTensor,
1222
- ):
1223
- """
1224
- Compute the mask of image placeholder tokens in input_ids that will be replaced with actual image token ids from the VQVAE.
1225
-
1226
- Args:
1227
- input_ids (`torch.LongTensor` of shape `(batch_size, seq_len)`):
1228
- Input token ids with image placeholders.
1229
- image_ids (`torch.LongTensor` of shape `(num_images, num_tokens_per_image)` or flattened):
1230
- Discrete token indices from the VQVAE codebook.
1231
-
1232
- Returns:
1233
- special_image_mask (`torch.LongTensor` of shape `(batch_size, seq_len)`):
1234
- Mask indicating positions in input ids that will be replaced by actual image tokens.
1235
- """
1236
-
1237
- special_image_mask = input_ids == self.config.image_token_id
1238
- n_placeholder_tokens = special_image_mask.sum().item()
1239
- n_image_tokens = image_ids.shape[0]
1240
-
1241
- if n_placeholder_tokens != n_image_tokens:
1242
- raise ValueError(
1243
- f"Number of image placeholder tokens ({n_placeholder_tokens}) does not match "
1244
- f"number of image tokens from VQVAE ({n_image_tokens})"
1245
- )
1246
-
1247
- return special_image_mask
1248
-
1249
- @auto_docstring
1250
- @can_return_tuple
1251
- def forward(
1252
- self,
1253
- input_ids: torch.LongTensor | None = None,
1254
- attention_mask: torch.Tensor | None = None,
1255
- position_ids: torch.LongTensor | None = None,
1256
- past_key_values: Cache | None = None,
1257
- inputs_embeds: torch.FloatTensor | None = None,
1258
- pixel_values: torch.Tensor | None = None,
1259
- image_grid_thw: torch.LongTensor | None = None,
1260
- rope_deltas: torch.LongTensor | None = None,
1261
- cache_position: torch.LongTensor | None = None,
1262
- **kwargs: Unpack[TransformersKwargs],
1263
- ) -> tuple | GlmImageModelOutputWithPast:
1264
- r"""
1265
- image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
1266
- The temporal, height and width of feature shape of each image in LLM.
1267
- rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
1268
- The rope index difference between sequence length and multimodal rope.
1269
- """
1270
- if (input_ids is None) ^ (inputs_embeds is not None):
1271
- raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
1272
-
1273
- if pixel_values is not None:
1274
- image_embeds = self.get_image_features(pixel_values, image_grid_thw[:-1], return_dict=True).pooler_output
1275
- image_embeds = torch.cat(image_embeds, dim=0)
1276
- image_ids = self.get_image_tokens(image_embeds, image_grid_thw[:-1])
1277
- image_ids = image_ids.view(-1).to(input_ids.device)
1278
- special_image_mask = self.get_placeholder_mask(input_ids, image_ids)
1279
- input_ids = input_ids.masked_scatter(special_image_mask, image_ids)
1280
-
1281
- if inputs_embeds is None:
1282
- inputs_embeds = self.get_input_embeddings()(input_ids)
1283
-
1284
- if position_ids is None:
1285
- attention_mask_2d = attention_mask
1286
- if attention_mask is not None and attention_mask.ndim == 4:
1287
- attention_mask_2d = torch.diagonal(attention_mask[:, 0], dim1=1, dim2=2)
1288
- # Only apply conversion for floating point tensors (inverted masks)
1289
- if attention_mask_2d.dtype.is_floating_point:
1290
- attention_mask_2d = attention_mask_2d / torch.finfo(attention_mask_2d.dtype).min
1291
- attention_mask_2d = (1.0 - attention_mask_2d).int()
1292
-
1293
- # Calculate RoPE index once per generation in the pre-fill stage only.
1294
- # It is safe to assume that `length!=1` means we're in pre-fill because the
1295
- # model is used only by the DiT pipeline, without assisted decoding or similar techniques.
1296
- is_prefill_stage = (input_ids is not None and input_ids.shape[1] != 1) or (
1297
- inputs_embeds is not None and inputs_embeds.shape[1] != 1
1298
- )
1299
- if is_prefill_stage or self.rope_deltas is None:
1300
- position_ids, rope_deltas = self.get_rope_index(
1301
- input_ids,
1302
- image_grid_thw,
1303
- attention_mask=attention_mask_2d,
1304
- )
1305
- self.rope_deltas = rope_deltas
1306
- # then use the prev pre-calculated rope-deltas to get the correct position ids
1307
- else:
1308
- batch_size, seq_length, _ = inputs_embeds.shape
1309
- # Use prefill token length, not position value
1310
- step = cache_position[0].item() - self._prefill_len
1311
- # Direct lookup - no tensor creation overhead
1312
- position_ids = self._cached_decode_position_ids[:, step : step + seq_length]
1313
- position_ids = position_ids.unsqueeze(1).expand(-1, batch_size, -1)
1314
-
1315
- outputs = self.language_model(
1316
- input_ids=None,
1317
- position_ids=position_ids,
1318
- attention_mask=attention_mask,
1319
- past_key_values=past_key_values,
1320
- inputs_embeds=inputs_embeds,
1321
- cache_position=cache_position,
1322
- **kwargs,
1323
- )
1324
-
1325
- return GlmImageModelOutputWithPast(
1326
- last_hidden_state=outputs.last_hidden_state,
1327
- past_key_values=outputs.past_key_values,
1328
- hidden_states=outputs.hidden_states,
1329
- attentions=outputs.attentions,
1330
- rope_deltas=self.rope_deltas,
1331
- )
1332
-
1333
- def get_image_tokens(
1334
- self,
1335
- hidden_states: torch.FloatTensor,
1336
- image_grid_thw: torch.LongTensor,
1337
- ) -> torch.LongTensor:
1338
- """
1339
- Tokenizes image features into discrete tokens with VQVAE module.
1340
-
1341
- Args:
1342
- hidden_states (`torch.FloatTensor` of shape `(total_patches, hidden_size)`):
1343
- The packed image features from vision encoder.
1344
- image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`):
1345
- The temporal, height and width of feature shape of each image.
1346
-
1347
- Returns:
1348
- image_tokens (`torch.LongTensor` of shape `(total_patches,)`):
1349
- Discrete token indices from the VQVAE codebook.
1350
- """
1351
- hidden_size = hidden_states.shape[-1]
1352
- split_sizes = (image_grid_thw.prod(dim=-1)).tolist()
1353
- hidden_states_list = torch.split(hidden_states, split_sizes, dim=0)
1354
-
1355
- all_image_toks = []
1356
- for i, hs in enumerate(hidden_states_list):
1357
- grid_t, grid_h, grid_w = image_grid_thw[i].tolist()
1358
- hs = hs.view(grid_t, grid_h, grid_w, hidden_size)
1359
- hs = hs.permute(0, 3, 1, 2).contiguous()
1360
- vqmodel_outputs: GlmImageVQVAEModelOutput = self.vqmodel.encode(hs)
1361
- all_image_toks.append(vqmodel_outputs.image_tokens)
1362
- return torch.cat(all_image_toks, dim=0)
1363
-
1364
-
1365
- @dataclass
1366
- @auto_docstring(
1367
- custom_intro="""
1368
- Base class for GlmImage causal language model (or autoregressive) outputs.
1369
- """
1370
- )
1371
- class GlmImageCausalLMOutputWithPast(ModelOutput):
1372
- r"""
1373
- loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
1374
- Language modeling loss (for next-token prediction).
1375
- logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
1376
- Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
1377
- past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
1378
- It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
1379
-
1380
- Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
1381
- `past_key_values` input) to speed up sequential decoding.
1382
- rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
1383
- The rope index difference between sequence length and multimodal rope.
1384
- """
1385
-
1386
- loss: torch.FloatTensor | None = None
1387
- logits: torch.FloatTensor | None = None
1388
- past_key_values: Cache | None = None
1389
- hidden_states: tuple[torch.FloatTensor] | None = None
1390
- attentions: tuple[torch.FloatTensor] | None = None
1391
- rope_deltas: torch.LongTensor | None = None
1392
-
1393
-
1394
- class GlmImageForConditionalGeneration(GlmImagePreTrainedModel, GenerationMixin):
1395
- _checkpoint_conversion_mapping = {}
1396
- _tied_weights_keys = {}
1397
- # Reference: fix gemma3 grad acc #37208
1398
- accepts_loss_kwargs = False
1399
- base_model_prefix = "model"
1400
- config: GlmImageConfig
1401
-
1402
- def __init__(self, config):
1403
- super().__init__(config)
1404
- self.model = GlmImageModel(config)
1405
- self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vision_vocab_size, bias=False)
1406
-
1407
- # Initialize weights and apply final processing
1408
- self.post_init()
1409
-
1410
- @auto_docstring
1411
- def get_image_features(
1412
- self,
1413
- pixel_values: torch.FloatTensor,
1414
- image_grid_thw: torch.LongTensor | None = None,
1415
- **kwargs: Unpack[TransformersKwargs],
1416
- ) -> tuple | BaseModelOutputWithPooling:
1417
- r"""
1418
- pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
1419
- The tensors corresponding to the input images.
1420
- image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
1421
- The temporal, height and width of feature shape of each image in LLM.
1422
- """
1423
- return self.model.get_image_features(pixel_values, image_grid_thw, **kwargs)
1424
-
1425
- def get_image_tokens(self, hidden_states: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None):
1426
- return self.model.get_image_tokens(hidden_states, image_grid_thw)
1427
-
1428
- def forward(
1429
- self,
1430
- input_ids: torch.LongTensor | None = None,
1431
- attention_mask: torch.Tensor | None = None,
1432
- position_ids: torch.LongTensor | None = None,
1433
- past_key_values: Cache | None = None,
1434
- inputs_embeds: torch.FloatTensor | None = None,
1435
- labels: torch.LongTensor | None = None,
1436
- pixel_values: torch.Tensor | None = None,
1437
- image_grid_thw: torch.LongTensor | None = None,
1438
- cache_position: torch.LongTensor | None = None,
1439
- logits_to_keep: int | torch.Tensor = 0,
1440
- **kwargs: Unpack[TransformersKwargs],
1441
- ) -> tuple | GlmImageCausalLMOutputWithPast:
1442
- r"""
1443
- labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1444
- Labels for computing the language modeling loss. Indices should either be in `[0, ...,
1445
- config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1446
- (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1447
- image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
1448
- The temporal, height and width dimensions of the feature grid for each image in the LLM.
1449
-
1450
- Example:
1451
-
1452
- ```python
1453
- >>> from PIL import Image
1454
- >>> import httpx
1455
- >>> from io import BytesIO
1456
- >>> from transformers import AutoProcessor, GlmImageForConditionalGeneration
1457
-
1458
- >>> model = GlmImageForConditionalGeneration.from_pretrained("zai-org/GLM-Image")
1459
- >>> processor = AutoProcessor.from_pretrained("zai-org/GLM-Image")
1460
-
1461
- >>> messages = [
1462
- ...     {
1463
- ...         "role": "user",
1464
- ...         "content": [
1465
- ...             {"type": "image"},
1466
- ...             {"type": "text", "text": "Add a truck to this photo.<sop>28 40<eop>"},
1467
- ...         ],
1468
- ...     },
1469
- ... ]
1470
- >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
1471
- >>> with httpx.stream("GET", url) as response:
1472
- ... image = Image.open(BytesIO(response.read()))
1473
-
1474
- >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
1475
- >>> inputs = processor(text=[text], images=[image], return_tensors="pt")
1476
-
1477
- >>> # Generate
1478
- >>> generate_ids = model.generate(**inputs, max_length=30)
1479
- >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1480
- "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
1481
- ```"""
1482
- outputs = self.model(
1483
- input_ids=input_ids,
1484
- pixel_values=pixel_values,
1485
- image_grid_thw=image_grid_thw,
1486
- position_ids=position_ids,
1487
- attention_mask=attention_mask,
1488
- past_key_values=past_key_values,
1489
- inputs_embeds=inputs_embeds,
1490
- cache_position=cache_position,
1491
- **kwargs,
1492
- )
1493
-
1494
- hidden_states = outputs[0]
1495
-
1496
- # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
1497
- slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
1498
- logits = self.lm_head(hidden_states[:, slice_indices, :])
1499
-
1500
- loss = None
1501
- if labels is not None:
1502
- loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
1503
-
1504
- return GlmImageCausalLMOutputWithPast(
1505
- loss=loss,
1506
- logits=logits,
1507
- past_key_values=outputs.past_key_values,
1508
- hidden_states=outputs.hidden_states,
1509
- attentions=outputs.attentions,
1510
- rope_deltas=outputs.rope_deltas,
1511
- )
1512
-
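Editor's note: the `logits_to_keep` handling in `forward` above projects only the last N hidden states through the LM head, which avoids materializing full-vocabulary logits for the whole prompt during generation. A minimal sketch of that slicing, with invented shapes and a made-up vision vocabulary size:

```python
# Editorial sketch of the logits_to_keep slicing used in forward():
# only the last N hidden states are projected through the LM head, which
# saves memory during generation. Shapes below are invented.
import torch
from torch import nn

hidden_states = torch.randn(2, 10, 64)      # (batch, seq_len, hidden)
lm_head = nn.Linear(64, 16384, bias=False)  # hidden -> vision vocab (made up)

logits_to_keep = 1  # typical value at decode time: only the last position
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = lm_head(hidden_states[:, slice_indices, :])
print(logits.shape)  # torch.Size([2, 1, 16384])

# logits_to_keep == 0 keeps everything, because slice(0, None) spans the full sequence.
print(hidden_states[:, slice(-0, None), :].shape)  # torch.Size([2, 10, 64])
```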
1513
- def prepare_inputs_for_generation(
1514
- self,
1515
- input_ids,
1516
- past_key_values=None,
1517
- attention_mask=None,
1518
- inputs_embeds=None,
1519
- cache_position=None,
1520
- position_ids=None,
1521
- use_cache=True,
1522
- pixel_values=None,
1523
- image_grid_thw=None,
1524
- is_first_iteration=False,
1525
- **kwargs,
1526
- ):
1527
- model_inputs = super().prepare_inputs_for_generation(
1528
- input_ids,
1529
- past_key_values=past_key_values,
1530
- attention_mask=attention_mask,
1531
- inputs_embeds=inputs_embeds,
1532
- cache_position=cache_position,
1533
- position_ids=position_ids,
1534
- pixel_values=pixel_values,
1535
- image_grid_thw=image_grid_thw,
1536
- is_first_iteration=is_first_iteration,
1537
- use_cache=use_cache,
1538
- **kwargs,
1539
- )
1540
-
1541
- model_inputs["position_ids"] = None
1542
-
1543
- if not is_first_iteration and use_cache:
1544
- model_inputs["pixel_values"] = None
1545
-
1546
- return model_inputs
1547
-
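Editor's note: relative to the base implementation, the override above always clears `position_ids` (the model recomputes multimodal rope positions itself) and drops `pixel_values` once the prompt has been prefilled into the KV cache. The toy function below reproduces only that pruning logic and is not the real method:

```python
# Editorial sketch of the pruning behaviour in prepare_inputs_for_generation:
# position_ids are always recomputed inside the model, and pixel_values are
# only needed on the prefill step. This is a toy reimplementation for clarity.
def prune_model_inputs(model_inputs: dict, is_first_iteration: bool, use_cache: bool) -> dict:
    model_inputs = dict(model_inputs)
    model_inputs["position_ids"] = None
    if not is_first_iteration and use_cache:
        model_inputs["pixel_values"] = None
    return model_inputs


prefill = prune_model_inputs(
    {"input_ids": [[1, 2, 3]], "position_ids": [[0, 1, 2]], "pixel_values": "<features>"},
    is_first_iteration=True,
    use_cache=True,
)
decode = prune_model_inputs(prefill, is_first_iteration=False, use_cache=True)

print(prefill["pixel_values"])  # "<features>"  (images still needed on prefill)
print(decode["pixel_values"])   # None          (already encoded into the KV cache)
```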
1548
- def _get_image_nums(
1549
- self,
1550
- input_ids: torch.LongTensor | None,
1551
- ) -> torch.Tensor:
1552
- """
1553
- Get the number of images for each sample.
1554
- For GLM-Image, only input_ids allow us to get the number of images.
1555
-
1556
- Returns:
1557
- image_counts (`torch.LongTensor` of shape `(batch_size,)`)
1558
- """
1559
- is_image = input_ids == self.config.image_start_token_id
1560
-
1561
- return is_image.sum(dim=1)
1562
-
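Editor's note: `_get_image_nums` simply counts image start tokens per row of `input_ids`. A minimal sketch with an invented token id (the real id comes from the model config):

```python
# Editorial sketch of _get_image_nums: count image start tokens per sample.
# The token id below is a placeholder; the real value lives in the config.
import torch

IMAGE_START_TOKEN_ID = 151339  # made-up id for illustration

input_ids = torch.tensor(
    [
        [1, IMAGE_START_TOKEN_ID, 5, 6, IMAGE_START_TOKEN_ID, 7],  # 2 images
        [1, 2, 3, IMAGE_START_TOKEN_ID, 4, 5],                     # 1 image
    ]
)

image_counts = (input_ids == IMAGE_START_TOKEN_ID).sum(dim=1)
print(image_counts.tolist())  # [2, 1]
```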
1563
- def _expand_inputs_for_generation(
1564
- self,
1565
- expand_size: int = 1,
1566
- is_encoder_decoder: bool = False,
1567
- input_ids: torch.LongTensor | None = None,
1568
- **model_kwargs,
1569
- ) -> tuple[torch.LongTensor, dict[str, Any]]:
1570
- # Overwritten -- Support for expanding tensors without a batch size dimension
1571
- # e.g., pixel_values, image_grid_thw
1572
- # pixel_values.shape[0] is sum(seqlen_images for samples)
1573
- # image_grid_thw.shape[0] is sum(num_images for samples)
1574
-
1575
- if expand_size == 1:
1576
- return input_ids, model_kwargs
1577
-
1578
- visual_keys = ["pixel_values", "image_grid_thw"]
1579
-
1580
- def _expand_dict_for_generation_visual(dict_to_expand):
1581
- image_grid_thw = model_kwargs.get("image_grid_thw", None)
1582
- image_nums = self._get_image_nums(input_ids)
1583
-
1584
- def _repeat_interleave_samples(x, lengths, repeat_times):
1585
- samples = torch.split(x, lengths)
1586
- repeat_args = [repeat_times] + [1] * (x.dim() - 1)
1587
- result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)
1588
- return result
1589
-
1590
- for key in dict_to_expand:
1591
- if key == "pixel_values":
1592
- # split images into samples
1593
- samples = torch.split(image_grid_thw[: sum(image_nums)], list(image_nums))
1594
- # compute the sequence length of images for each sample
1595
- lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
1596
- dict_to_expand[key] = _repeat_interleave_samples(
1597
- dict_to_expand[key], lengths=lengths, repeat_times=expand_size
1598
- )
1599
- elif key == "image_grid_thw":
1600
- # per-sample image counts; the trailing grid row corresponds to the image being generated
1601
- lengths = list(image_nums)
1602
- last_image = dict_to_expand[key][-1:]
1603
- dict_to_expand[key] = _repeat_interleave_samples(
1604
- dict_to_expand[key][: sum(image_nums)], lengths=lengths, repeat_times=expand_size
1605
- )
1606
- dict_to_expand[key] = torch.cat([dict_to_expand[key], last_image], dim=0)
1607
- return dict_to_expand
1608
-
1609
- def _expand_dict_for_generation(dict_to_expand):
1610
- for key in dict_to_expand:
1611
- if (
1612
- key != "cache_position"
1613
- and dict_to_expand[key] is not None
1614
- and isinstance(dict_to_expand[key], torch.Tensor)
1615
- and key not in visual_keys
1616
- ):
1617
- dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
1618
- return dict_to_expand
1619
-
1620
- model_kwargs = _expand_dict_for_generation_visual(model_kwargs)
1621
-
1622
- if input_ids is not None:
1623
- input_ids = input_ids.repeat_interleave(expand_size, dim=0)
1624
-
1625
- model_kwargs = _expand_dict_for_generation(model_kwargs)
1626
-
1627
- if is_encoder_decoder:
1628
- if model_kwargs.get("encoder_outputs") is None:
1629
- raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
1630
- model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
1631
-
1632
- return input_ids, model_kwargs
1633
-
1634
-
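Editor's note: the visual tensors handled above have no batch dimension; their rows are stacked across the whole batch, so a plain `repeat_interleave(dim=0)` would interleave rows belonging to different samples. The override therefore splits the rows into per-sample chunks and repeats each chunk. A toy sketch of that split-then-repeat pattern with invented values:

```python
# Editorial sketch of the per-sample expansion used for tensors that have no
# batch dimension (e.g. image_grid_thw rows are stacked across the batch).
# Values are invented; the point is the split-then-repeat pattern.
import torch

expand_size = 2                      # e.g. num_return_sequences=2
image_nums = [2, 1]                  # sample 0 has 2 images, sample 1 has 1
image_grid_thw = torch.tensor(
    [
        [1, 4, 4],   # sample 0, image 0
        [1, 2, 2],   # sample 0, image 1
        [1, 8, 8],   # sample 1, image 0
    ]
)

# Split rows into per-sample chunks, then repeat each chunk expand_size times
# so every returned sequence of a sample sees that sample's own images.
samples = torch.split(image_grid_thw, image_nums)
expanded = torch.cat([s.repeat(expand_size, 1) for s in samples], dim=0)
print(expanded.shape)         # torch.Size([6, 3])
print(expanded.tolist()[:4])  # sample 0's two grids, duplicated once
```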
1635
- __all__ = [
1636
- "GlmImagePreTrainedModel",
1637
- "GlmImageVQVAE",
1638
- "GlmImageVisionModel",
1639
- "GlmImageTextModel",
1640
- "GlmImageModel",
1641
- "GlmImageForConditionalGeneration",
1642
- ]