transformers-5.0.0rc2-py3-none-any.whl → transformers-5.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (1594)
  1. transformers/__init__.py +11 -37
  2. transformers/activations.py +2 -2
  3. transformers/audio_utils.py +32 -32
  4. transformers/backbone_utils.py +326 -0
  5. transformers/cache_utils.py +26 -126
  6. transformers/cli/chat.py +3 -3
  7. transformers/cli/serve.py +13 -10
  8. transformers/cli/transformers.py +2 -1
  9. transformers/configuration_utils.py +22 -92
  10. transformers/conversion_mapping.py +150 -26
  11. transformers/convert_slow_tokenizer.py +9 -12
  12. transformers/core_model_loading.py +217 -129
  13. transformers/data/processors/glue.py +0 -1
  14. transformers/data/processors/utils.py +0 -1
  15. transformers/data/processors/xnli.py +0 -1
  16. transformers/dependency_versions_check.py +0 -1
  17. transformers/dependency_versions_table.py +10 -11
  18. transformers/distributed/configuration_utils.py +1 -2
  19. transformers/dynamic_module_utils.py +23 -23
  20. transformers/feature_extraction_sequence_utils.py +19 -23
  21. transformers/feature_extraction_utils.py +14 -14
  22. transformers/file_utils.py +0 -2
  23. transformers/generation/candidate_generator.py +2 -4
  24. transformers/generation/configuration_utils.py +54 -39
  25. transformers/generation/continuous_batching/__init__.py +0 -1
  26. transformers/generation/continuous_batching/cache.py +74 -44
  27. transformers/generation/continuous_batching/cache_manager.py +28 -28
  28. transformers/generation/continuous_batching/continuous_api.py +133 -414
  29. transformers/generation/continuous_batching/input_ouputs.py +464 -0
  30. transformers/generation/continuous_batching/requests.py +77 -19
  31. transformers/generation/continuous_batching/scheduler.py +154 -104
  32. transformers/generation/logits_process.py +10 -133
  33. transformers/generation/stopping_criteria.py +1 -2
  34. transformers/generation/streamers.py +0 -1
  35. transformers/generation/utils.py +91 -121
  36. transformers/generation/watermarking.py +2 -3
  37. transformers/hf_argparser.py +9 -13
  38. transformers/hyperparameter_search.py +1 -2
  39. transformers/image_processing_base.py +9 -9
  40. transformers/image_processing_utils.py +11 -15
  41. transformers/image_processing_utils_fast.py +70 -71
  42. transformers/image_transforms.py +73 -42
  43. transformers/image_utils.py +30 -37
  44. transformers/initialization.py +57 -0
  45. transformers/integrations/__init__.py +10 -24
  46. transformers/integrations/accelerate.py +47 -11
  47. transformers/integrations/awq.py +1 -3
  48. transformers/integrations/deepspeed.py +146 -4
  49. transformers/integrations/eetq.py +0 -1
  50. transformers/integrations/executorch.py +2 -6
  51. transformers/integrations/fbgemm_fp8.py +1 -2
  52. transformers/integrations/finegrained_fp8.py +149 -13
  53. transformers/integrations/flash_attention.py +3 -8
  54. transformers/integrations/flex_attention.py +1 -1
  55. transformers/integrations/fp_quant.py +4 -6
  56. transformers/integrations/ggml.py +0 -1
  57. transformers/integrations/hub_kernels.py +18 -7
  58. transformers/integrations/integration_utils.py +2 -3
  59. transformers/integrations/moe.py +226 -106
  60. transformers/integrations/mxfp4.py +52 -40
  61. transformers/integrations/peft.py +488 -176
  62. transformers/integrations/quark.py +2 -4
  63. transformers/integrations/tensor_parallel.py +641 -581
  64. transformers/integrations/torchao.py +4 -6
  65. transformers/loss/loss_lw_detr.py +356 -0
  66. transformers/loss/loss_utils.py +2 -0
  67. transformers/masking_utils.py +199 -59
  68. transformers/model_debugging_utils.py +4 -5
  69. transformers/modelcard.py +14 -192
  70. transformers/modeling_attn_mask_utils.py +19 -19
  71. transformers/modeling_flash_attention_utils.py +28 -29
  72. transformers/modeling_gguf_pytorch_utils.py +5 -5
  73. transformers/modeling_layers.py +21 -22
  74. transformers/modeling_outputs.py +242 -253
  75. transformers/modeling_rope_utils.py +32 -32
  76. transformers/modeling_utils.py +416 -438
  77. transformers/models/__init__.py +10 -0
  78. transformers/models/afmoe/configuration_afmoe.py +40 -33
  79. transformers/models/afmoe/modeling_afmoe.py +38 -41
  80. transformers/models/afmoe/modular_afmoe.py +23 -25
  81. transformers/models/aimv2/configuration_aimv2.py +2 -10
  82. transformers/models/aimv2/modeling_aimv2.py +46 -45
  83. transformers/models/aimv2/modular_aimv2.py +13 -19
  84. transformers/models/albert/configuration_albert.py +8 -2
  85. transformers/models/albert/modeling_albert.py +70 -72
  86. transformers/models/albert/tokenization_albert.py +1 -4
  87. transformers/models/align/configuration_align.py +8 -6
  88. transformers/models/align/modeling_align.py +83 -86
  89. transformers/models/align/processing_align.py +2 -30
  90. transformers/models/altclip/configuration_altclip.py +4 -7
  91. transformers/models/altclip/modeling_altclip.py +106 -103
  92. transformers/models/altclip/processing_altclip.py +2 -15
  93. transformers/models/apertus/__init__.py +0 -1
  94. transformers/models/apertus/configuration_apertus.py +23 -28
  95. transformers/models/apertus/modeling_apertus.py +35 -38
  96. transformers/models/apertus/modular_apertus.py +36 -40
  97. transformers/models/arcee/configuration_arcee.py +25 -30
  98. transformers/models/arcee/modeling_arcee.py +35 -38
  99. transformers/models/arcee/modular_arcee.py +20 -23
  100. transformers/models/aria/configuration_aria.py +31 -44
  101. transformers/models/aria/image_processing_aria.py +25 -27
  102. transformers/models/aria/modeling_aria.py +102 -102
  103. transformers/models/aria/modular_aria.py +111 -124
  104. transformers/models/aria/processing_aria.py +28 -35
  105. transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +0 -1
  106. transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py +3 -6
  107. transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +9 -11
  108. transformers/models/audioflamingo3/__init__.py +0 -1
  109. transformers/models/audioflamingo3/configuration_audioflamingo3.py +0 -1
  110. transformers/models/audioflamingo3/modeling_audioflamingo3.py +60 -52
  111. transformers/models/audioflamingo3/modular_audioflamingo3.py +52 -43
  112. transformers/models/audioflamingo3/processing_audioflamingo3.py +6 -8
  113. transformers/models/auto/auto_factory.py +12 -11
  114. transformers/models/auto/configuration_auto.py +48 -5
  115. transformers/models/auto/feature_extraction_auto.py +5 -7
  116. transformers/models/auto/image_processing_auto.py +30 -39
  117. transformers/models/auto/modeling_auto.py +33 -199
  118. transformers/models/auto/processing_auto.py +11 -19
  119. transformers/models/auto/tokenization_auto.py +38 -37
  120. transformers/models/auto/video_processing_auto.py +7 -8
  121. transformers/models/autoformer/configuration_autoformer.py +4 -7
  122. transformers/models/autoformer/modeling_autoformer.py +100 -101
  123. transformers/models/aya_vision/configuration_aya_vision.py +4 -1
  124. transformers/models/aya_vision/modeling_aya_vision.py +64 -99
  125. transformers/models/aya_vision/modular_aya_vision.py +46 -74
  126. transformers/models/aya_vision/processing_aya_vision.py +25 -53
  127. transformers/models/bamba/configuration_bamba.py +46 -39
  128. transformers/models/bamba/modeling_bamba.py +83 -119
  129. transformers/models/bamba/modular_bamba.py +70 -109
  130. transformers/models/bark/configuration_bark.py +6 -8
  131. transformers/models/bark/generation_configuration_bark.py +3 -5
  132. transformers/models/bark/modeling_bark.py +64 -65
  133. transformers/models/bark/processing_bark.py +19 -41
  134. transformers/models/bart/configuration_bart.py +9 -5
  135. transformers/models/bart/modeling_bart.py +124 -129
  136. transformers/models/barthez/tokenization_barthez.py +1 -4
  137. transformers/models/bartpho/tokenization_bartpho.py +6 -7
  138. transformers/models/beit/configuration_beit.py +2 -15
  139. transformers/models/beit/image_processing_beit.py +53 -56
  140. transformers/models/beit/image_processing_beit_fast.py +11 -12
  141. transformers/models/beit/modeling_beit.py +65 -62
  142. transformers/models/bert/configuration_bert.py +12 -2
  143. transformers/models/bert/modeling_bert.py +117 -152
  144. transformers/models/bert/tokenization_bert.py +2 -4
  145. transformers/models/bert/tokenization_bert_legacy.py +3 -5
  146. transformers/models/bert_generation/configuration_bert_generation.py +17 -2
  147. transformers/models/bert_generation/modeling_bert_generation.py +53 -55
  148. transformers/models/bert_generation/tokenization_bert_generation.py +2 -3
  149. transformers/models/bert_japanese/tokenization_bert_japanese.py +5 -6
  150. transformers/models/bertweet/tokenization_bertweet.py +1 -3
  151. transformers/models/big_bird/configuration_big_bird.py +12 -9
  152. transformers/models/big_bird/modeling_big_bird.py +107 -124
  153. transformers/models/big_bird/tokenization_big_bird.py +1 -4
  154. transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py +9 -9
  155. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +118 -118
  156. transformers/models/biogpt/configuration_biogpt.py +8 -2
  157. transformers/models/biogpt/modeling_biogpt.py +73 -79
  158. transformers/models/biogpt/modular_biogpt.py +60 -66
  159. transformers/models/biogpt/tokenization_biogpt.py +3 -5
  160. transformers/models/bit/configuration_bit.py +2 -5
  161. transformers/models/bit/image_processing_bit.py +21 -24
  162. transformers/models/bit/image_processing_bit_fast.py +0 -1
  163. transformers/models/bit/modeling_bit.py +15 -16
  164. transformers/models/bitnet/configuration_bitnet.py +23 -28
  165. transformers/models/bitnet/modeling_bitnet.py +34 -38
  166. transformers/models/bitnet/modular_bitnet.py +7 -10
  167. transformers/models/blenderbot/configuration_blenderbot.py +8 -5
  168. transformers/models/blenderbot/modeling_blenderbot.py +68 -99
  169. transformers/models/blenderbot/tokenization_blenderbot.py +0 -1
  170. transformers/models/blenderbot_small/configuration_blenderbot_small.py +8 -5
  171. transformers/models/blenderbot_small/modeling_blenderbot_small.py +70 -72
  172. transformers/models/blenderbot_small/tokenization_blenderbot_small.py +1 -3
  173. transformers/models/blip/configuration_blip.py +9 -10
  174. transformers/models/blip/image_processing_blip.py +17 -20
  175. transformers/models/blip/image_processing_blip_fast.py +0 -1
  176. transformers/models/blip/modeling_blip.py +115 -108
  177. transformers/models/blip/modeling_blip_text.py +63 -65
  178. transformers/models/blip/processing_blip.py +5 -36
  179. transformers/models/blip_2/configuration_blip_2.py +2 -2
  180. transformers/models/blip_2/modeling_blip_2.py +145 -121
  181. transformers/models/blip_2/processing_blip_2.py +8 -38
  182. transformers/models/bloom/configuration_bloom.py +5 -2
  183. transformers/models/bloom/modeling_bloom.py +60 -60
  184. transformers/models/blt/configuration_blt.py +94 -86
  185. transformers/models/blt/modeling_blt.py +93 -90
  186. transformers/models/blt/modular_blt.py +127 -69
  187. transformers/models/bridgetower/configuration_bridgetower.py +7 -2
  188. transformers/models/bridgetower/image_processing_bridgetower.py +34 -35
  189. transformers/models/bridgetower/image_processing_bridgetower_fast.py +13 -14
  190. transformers/models/bridgetower/modeling_bridgetower.py +136 -124
  191. transformers/models/bridgetower/processing_bridgetower.py +2 -16
  192. transformers/models/bros/configuration_bros.py +24 -18
  193. transformers/models/bros/modeling_bros.py +78 -80
  194. transformers/models/bros/processing_bros.py +2 -12
  195. transformers/models/byt5/tokenization_byt5.py +4 -6
  196. transformers/models/camembert/configuration_camembert.py +8 -2
  197. transformers/models/camembert/modeling_camembert.py +97 -99
  198. transformers/models/camembert/modular_camembert.py +51 -54
  199. transformers/models/camembert/tokenization_camembert.py +1 -4
  200. transformers/models/canine/configuration_canine.py +4 -2
  201. transformers/models/canine/modeling_canine.py +73 -75
  202. transformers/models/canine/tokenization_canine.py +0 -1
  203. transformers/models/chameleon/configuration_chameleon.py +29 -34
  204. transformers/models/chameleon/image_processing_chameleon.py +21 -24
  205. transformers/models/chameleon/image_processing_chameleon_fast.py +5 -6
  206. transformers/models/chameleon/modeling_chameleon.py +135 -92
  207. transformers/models/chameleon/processing_chameleon.py +16 -41
  208. transformers/models/chinese_clip/configuration_chinese_clip.py +10 -8
  209. transformers/models/chinese_clip/image_processing_chinese_clip.py +21 -24
  210. transformers/models/chinese_clip/image_processing_chinese_clip_fast.py +0 -1
  211. transformers/models/chinese_clip/modeling_chinese_clip.py +93 -95
  212. transformers/models/chinese_clip/processing_chinese_clip.py +2 -15
  213. transformers/models/clap/configuration_clap.py +4 -9
  214. transformers/models/clap/feature_extraction_clap.py +9 -10
  215. transformers/models/clap/modeling_clap.py +109 -111
  216. transformers/models/clap/processing_clap.py +2 -15
  217. transformers/models/clip/configuration_clip.py +4 -2
  218. transformers/models/clip/image_processing_clip.py +21 -24
  219. transformers/models/clip/image_processing_clip_fast.py +9 -1
  220. transformers/models/clip/modeling_clip.py +70 -68
  221. transformers/models/clip/processing_clip.py +2 -14
  222. transformers/models/clip/tokenization_clip.py +2 -5
  223. transformers/models/clipseg/configuration_clipseg.py +4 -2
  224. transformers/models/clipseg/modeling_clipseg.py +113 -112
  225. transformers/models/clipseg/processing_clipseg.py +19 -42
  226. transformers/models/clvp/configuration_clvp.py +15 -5
  227. transformers/models/clvp/feature_extraction_clvp.py +7 -10
  228. transformers/models/clvp/modeling_clvp.py +138 -145
  229. transformers/models/clvp/number_normalizer.py +1 -2
  230. transformers/models/clvp/processing_clvp.py +3 -20
  231. transformers/models/clvp/tokenization_clvp.py +0 -1
  232. transformers/models/code_llama/tokenization_code_llama.py +3 -6
  233. transformers/models/codegen/configuration_codegen.py +4 -4
  234. transformers/models/codegen/modeling_codegen.py +50 -49
  235. transformers/models/codegen/tokenization_codegen.py +5 -6
  236. transformers/models/cohere/configuration_cohere.py +25 -30
  237. transformers/models/cohere/modeling_cohere.py +39 -42
  238. transformers/models/cohere/modular_cohere.py +27 -31
  239. transformers/models/cohere/tokenization_cohere.py +5 -6
  240. transformers/models/cohere2/configuration_cohere2.py +27 -32
  241. transformers/models/cohere2/modeling_cohere2.py +38 -41
  242. transformers/models/cohere2/modular_cohere2.py +48 -52
  243. transformers/models/cohere2_vision/configuration_cohere2_vision.py +5 -1
  244. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +9 -10
  245. transformers/models/cohere2_vision/modeling_cohere2_vision.py +52 -55
  246. transformers/models/cohere2_vision/modular_cohere2_vision.py +41 -43
  247. transformers/models/cohere2_vision/processing_cohere2_vision.py +6 -36
  248. transformers/models/colpali/configuration_colpali.py +0 -1
  249. transformers/models/colpali/modeling_colpali.py +14 -16
  250. transformers/models/colpali/modular_colpali.py +11 -51
  251. transformers/models/colpali/processing_colpali.py +14 -52
  252. transformers/models/colqwen2/modeling_colqwen2.py +27 -28
  253. transformers/models/colqwen2/modular_colqwen2.py +36 -74
  254. transformers/models/colqwen2/processing_colqwen2.py +16 -52
  255. transformers/models/conditional_detr/configuration_conditional_detr.py +19 -47
  256. transformers/models/conditional_detr/image_processing_conditional_detr.py +67 -70
  257. transformers/models/conditional_detr/image_processing_conditional_detr_fast.py +50 -36
  258. transformers/models/conditional_detr/modeling_conditional_detr.py +851 -1001
  259. transformers/models/conditional_detr/modular_conditional_detr.py +901 -5
  260. transformers/models/convbert/configuration_convbert.py +11 -8
  261. transformers/models/convbert/modeling_convbert.py +85 -87
  262. transformers/models/convbert/tokenization_convbert.py +0 -1
  263. transformers/models/convnext/configuration_convnext.py +2 -5
  264. transformers/models/convnext/image_processing_convnext.py +18 -21
  265. transformers/models/convnext/image_processing_convnext_fast.py +7 -8
  266. transformers/models/convnext/modeling_convnext.py +12 -14
  267. transformers/models/convnextv2/configuration_convnextv2.py +2 -5
  268. transformers/models/convnextv2/modeling_convnextv2.py +12 -14
  269. transformers/models/cpm/tokenization_cpm.py +6 -7
  270. transformers/models/cpm/tokenization_cpm_fast.py +3 -5
  271. transformers/models/cpmant/configuration_cpmant.py +4 -1
  272. transformers/models/cpmant/modeling_cpmant.py +38 -40
  273. transformers/models/cpmant/tokenization_cpmant.py +1 -3
  274. transformers/models/csm/configuration_csm.py +58 -66
  275. transformers/models/csm/generation_csm.py +13 -14
  276. transformers/models/csm/modeling_csm.py +81 -84
  277. transformers/models/csm/modular_csm.py +56 -58
  278. transformers/models/csm/processing_csm.py +25 -68
  279. transformers/models/ctrl/configuration_ctrl.py +16 -1
  280. transformers/models/ctrl/modeling_ctrl.py +51 -66
  281. transformers/models/ctrl/tokenization_ctrl.py +0 -1
  282. transformers/models/cvt/configuration_cvt.py +0 -1
  283. transformers/models/cvt/modeling_cvt.py +13 -15
  284. transformers/models/cwm/__init__.py +0 -1
  285. transformers/models/cwm/configuration_cwm.py +8 -12
  286. transformers/models/cwm/modeling_cwm.py +36 -38
  287. transformers/models/cwm/modular_cwm.py +10 -12
  288. transformers/models/d_fine/configuration_d_fine.py +10 -57
  289. transformers/models/d_fine/modeling_d_fine.py +786 -927
  290. transformers/models/d_fine/modular_d_fine.py +339 -417
  291. transformers/models/dab_detr/configuration_dab_detr.py +22 -49
  292. transformers/models/dab_detr/modeling_dab_detr.py +79 -77
  293. transformers/models/dac/configuration_dac.py +0 -1
  294. transformers/models/dac/feature_extraction_dac.py +6 -9
  295. transformers/models/dac/modeling_dac.py +22 -24
  296. transformers/models/data2vec/configuration_data2vec_audio.py +4 -2
  297. transformers/models/data2vec/configuration_data2vec_text.py +11 -3
  298. transformers/models/data2vec/configuration_data2vec_vision.py +0 -1
  299. transformers/models/data2vec/modeling_data2vec_audio.py +55 -59
  300. transformers/models/data2vec/modeling_data2vec_text.py +97 -99
  301. transformers/models/data2vec/modeling_data2vec_vision.py +45 -44
  302. transformers/models/data2vec/modular_data2vec_audio.py +6 -1
  303. transformers/models/data2vec/modular_data2vec_text.py +51 -54
  304. transformers/models/dbrx/configuration_dbrx.py +29 -22
  305. transformers/models/dbrx/modeling_dbrx.py +45 -48
  306. transformers/models/dbrx/modular_dbrx.py +37 -39
  307. transformers/models/deberta/configuration_deberta.py +6 -1
  308. transformers/models/deberta/modeling_deberta.py +57 -60
  309. transformers/models/deberta/tokenization_deberta.py +2 -5
  310. transformers/models/deberta_v2/configuration_deberta_v2.py +6 -1
  311. transformers/models/deberta_v2/modeling_deberta_v2.py +63 -65
  312. transformers/models/deberta_v2/tokenization_deberta_v2.py +1 -4
  313. transformers/models/decision_transformer/configuration_decision_transformer.py +3 -2
  314. transformers/models/decision_transformer/modeling_decision_transformer.py +51 -53
  315. transformers/models/deepseek_v2/configuration_deepseek_v2.py +41 -47
  316. transformers/models/deepseek_v2/modeling_deepseek_v2.py +39 -41
  317. transformers/models/deepseek_v2/modular_deepseek_v2.py +48 -52
  318. transformers/models/deepseek_v3/configuration_deepseek_v3.py +42 -48
  319. transformers/models/deepseek_v3/modeling_deepseek_v3.py +38 -40
  320. transformers/models/deepseek_v3/modular_deepseek_v3.py +10 -10
  321. transformers/models/deepseek_vl/configuration_deepseek_vl.py +6 -3
  322. transformers/models/deepseek_vl/image_processing_deepseek_vl.py +27 -28
  323. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +12 -11
  324. transformers/models/deepseek_vl/modeling_deepseek_vl.py +48 -43
  325. transformers/models/deepseek_vl/modular_deepseek_vl.py +15 -43
  326. transformers/models/deepseek_vl/processing_deepseek_vl.py +10 -41
  327. transformers/models/deepseek_vl_hybrid/configuration_deepseek_vl_hybrid.py +7 -5
  328. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py +37 -37
  329. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +22 -22
  330. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +100 -56
  331. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +141 -109
  332. transformers/models/deepseek_vl_hybrid/processing_deepseek_vl_hybrid.py +12 -44
  333. transformers/models/deformable_detr/configuration_deformable_detr.py +22 -46
  334. transformers/models/deformable_detr/image_processing_deformable_detr.py +59 -61
  335. transformers/models/deformable_detr/image_processing_deformable_detr_fast.py +42 -28
  336. transformers/models/deformable_detr/modeling_deformable_detr.py +454 -652
  337. transformers/models/deformable_detr/modular_deformable_detr.py +1385 -5
  338. transformers/models/deit/configuration_deit.py +0 -1
  339. transformers/models/deit/image_processing_deit.py +18 -21
  340. transformers/models/deit/image_processing_deit_fast.py +0 -1
  341. transformers/models/deit/modeling_deit.py +27 -25
  342. transformers/models/depth_anything/configuration_depth_anything.py +12 -43
  343. transformers/models/depth_anything/modeling_depth_anything.py +10 -11
  344. transformers/models/depth_pro/configuration_depth_pro.py +0 -1
  345. transformers/models/depth_pro/image_processing_depth_pro.py +22 -23
  346. transformers/models/depth_pro/image_processing_depth_pro_fast.py +8 -9
  347. transformers/models/depth_pro/modeling_depth_pro.py +29 -27
  348. transformers/models/detr/configuration_detr.py +18 -50
  349. transformers/models/detr/image_processing_detr.py +64 -66
  350. transformers/models/detr/image_processing_detr_fast.py +33 -34
  351. transformers/models/detr/modeling_detr.py +748 -789
  352. transformers/models/dia/configuration_dia.py +9 -15
  353. transformers/models/dia/feature_extraction_dia.py +6 -9
  354. transformers/models/dia/generation_dia.py +48 -53
  355. transformers/models/dia/modeling_dia.py +68 -71
  356. transformers/models/dia/modular_dia.py +56 -58
  357. transformers/models/dia/processing_dia.py +39 -29
  358. transformers/models/dia/tokenization_dia.py +3 -6
  359. transformers/models/diffllama/configuration_diffllama.py +25 -30
  360. transformers/models/diffllama/modeling_diffllama.py +45 -53
  361. transformers/models/diffllama/modular_diffllama.py +18 -25
  362. transformers/models/dinat/configuration_dinat.py +2 -5
  363. transformers/models/dinat/modeling_dinat.py +47 -48
  364. transformers/models/dinov2/configuration_dinov2.py +2 -5
  365. transformers/models/dinov2/modeling_dinov2.py +20 -21
  366. transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py +3 -5
  367. transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +21 -21
  368. transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py +11 -14
  369. transformers/models/dinov3_convnext/configuration_dinov3_convnext.py +6 -11
  370. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +5 -9
  371. transformers/models/dinov3_vit/configuration_dinov3_vit.py +7 -12
  372. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +7 -8
  373. transformers/models/dinov3_vit/modeling_dinov3_vit.py +19 -22
  374. transformers/models/dinov3_vit/modular_dinov3_vit.py +16 -19
  375. transformers/models/distilbert/configuration_distilbert.py +8 -2
  376. transformers/models/distilbert/modeling_distilbert.py +47 -49
  377. transformers/models/distilbert/tokenization_distilbert.py +0 -1
  378. transformers/models/doge/__init__.py +0 -1
  379. transformers/models/doge/configuration_doge.py +42 -35
  380. transformers/models/doge/modeling_doge.py +46 -49
  381. transformers/models/doge/modular_doge.py +77 -68
  382. transformers/models/donut/configuration_donut_swin.py +0 -1
  383. transformers/models/donut/image_processing_donut.py +26 -29
  384. transformers/models/donut/image_processing_donut_fast.py +9 -14
  385. transformers/models/donut/modeling_donut_swin.py +44 -46
  386. transformers/models/donut/processing_donut.py +5 -26
  387. transformers/models/dots1/configuration_dots1.py +43 -36
  388. transformers/models/dots1/modeling_dots1.py +35 -38
  389. transformers/models/dots1/modular_dots1.py +0 -1
  390. transformers/models/dpr/configuration_dpr.py +19 -2
  391. transformers/models/dpr/modeling_dpr.py +37 -39
  392. transformers/models/dpr/tokenization_dpr.py +7 -9
  393. transformers/models/dpr/tokenization_dpr_fast.py +7 -9
  394. transformers/models/dpt/configuration_dpt.py +23 -66
  395. transformers/models/dpt/image_processing_dpt.py +65 -66
  396. transformers/models/dpt/image_processing_dpt_fast.py +18 -19
  397. transformers/models/dpt/modeling_dpt.py +38 -36
  398. transformers/models/dpt/modular_dpt.py +14 -15
  399. transformers/models/edgetam/configuration_edgetam.py +1 -2
  400. transformers/models/edgetam/modeling_edgetam.py +87 -89
  401. transformers/models/edgetam/modular_edgetam.py +7 -13
  402. transformers/models/edgetam_video/__init__.py +0 -1
  403. transformers/models/edgetam_video/configuration_edgetam_video.py +0 -1
  404. transformers/models/edgetam_video/modeling_edgetam_video.py +126 -128
  405. transformers/models/edgetam_video/modular_edgetam_video.py +25 -27
  406. transformers/models/efficientloftr/configuration_efficientloftr.py +4 -5
  407. transformers/models/efficientloftr/image_processing_efficientloftr.py +14 -16
  408. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +8 -7
  409. transformers/models/efficientloftr/modeling_efficientloftr.py +46 -38
  410. transformers/models/efficientloftr/modular_efficientloftr.py +1 -3
  411. transformers/models/efficientnet/configuration_efficientnet.py +0 -1
  412. transformers/models/efficientnet/image_processing_efficientnet.py +23 -26
  413. transformers/models/efficientnet/image_processing_efficientnet_fast.py +16 -17
  414. transformers/models/efficientnet/modeling_efficientnet.py +12 -14
  415. transformers/models/electra/configuration_electra.py +13 -3
  416. transformers/models/electra/modeling_electra.py +107 -109
  417. transformers/models/emu3/configuration_emu3.py +17 -17
  418. transformers/models/emu3/image_processing_emu3.py +44 -39
  419. transformers/models/emu3/modeling_emu3.py +143 -109
  420. transformers/models/emu3/modular_emu3.py +109 -73
  421. transformers/models/emu3/processing_emu3.py +18 -43
  422. transformers/models/encodec/configuration_encodec.py +2 -4
  423. transformers/models/encodec/feature_extraction_encodec.py +10 -13
  424. transformers/models/encodec/modeling_encodec.py +25 -29
  425. transformers/models/encoder_decoder/configuration_encoder_decoder.py +12 -2
  426. transformers/models/encoder_decoder/modeling_encoder_decoder.py +37 -43
  427. transformers/models/eomt/configuration_eomt.py +12 -14
  428. transformers/models/eomt/image_processing_eomt.py +53 -55
  429. transformers/models/eomt/image_processing_eomt_fast.py +18 -19
  430. transformers/models/eomt/modeling_eomt.py +19 -21
  431. transformers/models/eomt/modular_eomt.py +28 -30
  432. transformers/models/eomt_dinov3/__init__.py +28 -0
  433. transformers/models/eomt_dinov3/configuration_eomt_dinov3.py +204 -0
  434. transformers/models/eomt_dinov3/modeling_eomt_dinov3.py +1376 -0
  435. transformers/models/eomt_dinov3/modular_eomt_dinov3.py +454 -0
  436. transformers/models/ernie/configuration_ernie.py +24 -3
  437. transformers/models/ernie/modeling_ernie.py +127 -162
  438. transformers/models/ernie/modular_ernie.py +91 -103
  439. transformers/models/ernie4_5/configuration_ernie4_5.py +23 -27
  440. transformers/models/ernie4_5/modeling_ernie4_5.py +35 -37
  441. transformers/models/ernie4_5/modular_ernie4_5.py +1 -3
  442. transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py +34 -39
  443. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +40 -42
  444. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +7 -9
  445. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +17 -7
  446. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +34 -35
  447. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +6 -7
  448. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +305 -267
  449. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +163 -142
  450. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +3 -5
  451. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +17 -18
  452. transformers/models/esm/configuration_esm.py +11 -15
  453. transformers/models/esm/modeling_esm.py +35 -37
  454. transformers/models/esm/modeling_esmfold.py +43 -50
  455. transformers/models/esm/openfold_utils/chunk_utils.py +6 -6
  456. transformers/models/esm/openfold_utils/loss.py +1 -2
  457. transformers/models/esm/openfold_utils/protein.py +15 -16
  458. transformers/models/esm/openfold_utils/tensor_utils.py +6 -6
  459. transformers/models/esm/tokenization_esm.py +2 -4
  460. transformers/models/evolla/configuration_evolla.py +50 -40
  461. transformers/models/evolla/modeling_evolla.py +69 -68
  462. transformers/models/evolla/modular_evolla.py +50 -48
  463. transformers/models/evolla/processing_evolla.py +23 -35
  464. transformers/models/exaone4/configuration_exaone4.py +27 -27
  465. transformers/models/exaone4/modeling_exaone4.py +36 -39
  466. transformers/models/exaone4/modular_exaone4.py +51 -50
  467. transformers/models/exaone_moe/__init__.py +27 -0
  468. transformers/models/exaone_moe/configuration_exaone_moe.py +235 -0
  469. transformers/models/exaone_moe/modeling_exaone_moe.py +665 -0
  470. transformers/models/exaone_moe/modular_exaone_moe.py +373 -0
  471. transformers/models/falcon/configuration_falcon.py +31 -26
  472. transformers/models/falcon/modeling_falcon.py +76 -84
  473. transformers/models/falcon_h1/configuration_falcon_h1.py +57 -51
  474. transformers/models/falcon_h1/modeling_falcon_h1.py +74 -109
  475. transformers/models/falcon_h1/modular_falcon_h1.py +68 -100
  476. transformers/models/falcon_mamba/configuration_falcon_mamba.py +5 -2
  477. transformers/models/falcon_mamba/modeling_falcon_mamba.py +64 -73
  478. transformers/models/falcon_mamba/modular_falcon_mamba.py +14 -13
  479. transformers/models/fast_vlm/configuration_fast_vlm.py +10 -0
  480. transformers/models/fast_vlm/modeling_fast_vlm.py +70 -97
  481. transformers/models/fast_vlm/modular_fast_vlm.py +148 -38
  482. transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py +2 -6
  483. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +45 -47
  484. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -3
  485. transformers/models/flaubert/configuration_flaubert.py +10 -5
  486. transformers/models/flaubert/modeling_flaubert.py +125 -129
  487. transformers/models/flaubert/tokenization_flaubert.py +3 -5
  488. transformers/models/flava/configuration_flava.py +9 -9
  489. transformers/models/flava/image_processing_flava.py +66 -67
  490. transformers/models/flava/image_processing_flava_fast.py +46 -47
  491. transformers/models/flava/modeling_flava.py +144 -135
  492. transformers/models/flava/processing_flava.py +2 -12
  493. transformers/models/flex_olmo/__init__.py +0 -1
  494. transformers/models/flex_olmo/configuration_flex_olmo.py +34 -39
  495. transformers/models/flex_olmo/modeling_flex_olmo.py +41 -43
  496. transformers/models/flex_olmo/modular_flex_olmo.py +46 -51
  497. transformers/models/florence2/configuration_florence2.py +4 -1
  498. transformers/models/florence2/modeling_florence2.py +96 -72
  499. transformers/models/florence2/modular_florence2.py +100 -107
  500. transformers/models/florence2/processing_florence2.py +18 -47
  501. transformers/models/fnet/configuration_fnet.py +6 -2
  502. transformers/models/fnet/modeling_fnet.py +69 -80
  503. transformers/models/fnet/tokenization_fnet.py +0 -1
  504. transformers/models/focalnet/configuration_focalnet.py +2 -5
  505. transformers/models/focalnet/modeling_focalnet.py +49 -48
  506. transformers/models/fsmt/configuration_fsmt.py +12 -17
  507. transformers/models/fsmt/modeling_fsmt.py +47 -48
  508. transformers/models/fsmt/tokenization_fsmt.py +3 -5
  509. transformers/models/funnel/configuration_funnel.py +8 -1
  510. transformers/models/funnel/modeling_funnel.py +91 -93
  511. transformers/models/funnel/tokenization_funnel.py +2 -5
  512. transformers/models/fuyu/configuration_fuyu.py +28 -34
  513. transformers/models/fuyu/image_processing_fuyu.py +29 -31
  514. transformers/models/fuyu/image_processing_fuyu_fast.py +17 -17
  515. transformers/models/fuyu/modeling_fuyu.py +50 -52
  516. transformers/models/fuyu/processing_fuyu.py +9 -36
  517. transformers/models/gemma/configuration_gemma.py +25 -30
  518. transformers/models/gemma/modeling_gemma.py +36 -38
  519. transformers/models/gemma/modular_gemma.py +33 -36
  520. transformers/models/gemma/tokenization_gemma.py +3 -6
  521. transformers/models/gemma2/configuration_gemma2.py +30 -35
  522. transformers/models/gemma2/modeling_gemma2.py +38 -41
  523. transformers/models/gemma2/modular_gemma2.py +63 -67
  524. transformers/models/gemma3/configuration_gemma3.py +53 -48
  525. transformers/models/gemma3/image_processing_gemma3.py +29 -31
  526. transformers/models/gemma3/image_processing_gemma3_fast.py +11 -12
  527. transformers/models/gemma3/modeling_gemma3.py +123 -122
  528. transformers/models/gemma3/modular_gemma3.py +128 -125
  529. transformers/models/gemma3/processing_gemma3.py +5 -5
  530. transformers/models/gemma3n/configuration_gemma3n.py +42 -30
  531. transformers/models/gemma3n/feature_extraction_gemma3n.py +9 -11
  532. transformers/models/gemma3n/modeling_gemma3n.py +166 -147
  533. transformers/models/gemma3n/modular_gemma3n.py +176 -148
  534. transformers/models/gemma3n/processing_gemma3n.py +12 -26
  535. transformers/models/git/configuration_git.py +5 -8
  536. transformers/models/git/modeling_git.py +115 -127
  537. transformers/models/git/processing_git.py +2 -14
  538. transformers/models/glm/configuration_glm.py +26 -30
  539. transformers/models/glm/modeling_glm.py +36 -39
  540. transformers/models/glm/modular_glm.py +4 -7
  541. transformers/models/glm4/configuration_glm4.py +26 -30
  542. transformers/models/glm4/modeling_glm4.py +39 -41
  543. transformers/models/glm4/modular_glm4.py +8 -10
  544. transformers/models/glm46v/configuration_glm46v.py +4 -1
  545. transformers/models/glm46v/image_processing_glm46v.py +40 -38
  546. transformers/models/glm46v/image_processing_glm46v_fast.py +9 -9
  547. transformers/models/glm46v/modeling_glm46v.py +138 -93
  548. transformers/models/glm46v/modular_glm46v.py +5 -3
  549. transformers/models/glm46v/processing_glm46v.py +7 -41
  550. transformers/models/glm46v/video_processing_glm46v.py +9 -11
  551. transformers/models/glm4_moe/configuration_glm4_moe.py +42 -35
  552. transformers/models/glm4_moe/modeling_glm4_moe.py +36 -39
  553. transformers/models/glm4_moe/modular_glm4_moe.py +43 -36
  554. transformers/models/glm4_moe_lite/__init__.py +28 -0
  555. transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py +233 -0
  556. transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py +740 -0
  557. transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py +302 -0
  558. transformers/models/glm4v/configuration_glm4v.py +25 -24
  559. transformers/models/glm4v/image_processing_glm4v.py +39 -38
  560. transformers/models/glm4v/image_processing_glm4v_fast.py +8 -9
  561. transformers/models/glm4v/modeling_glm4v.py +249 -210
  562. transformers/models/glm4v/modular_glm4v.py +211 -230
  563. transformers/models/glm4v/processing_glm4v.py +7 -41
  564. transformers/models/glm4v/video_processing_glm4v.py +9 -11
  565. transformers/models/glm4v_moe/configuration_glm4v_moe.py +136 -127
  566. transformers/models/glm4v_moe/modeling_glm4v_moe.py +348 -356
  567. transformers/models/glm4v_moe/modular_glm4v_moe.py +76 -174
  568. transformers/models/glm_image/__init__.py +31 -0
  569. transformers/models/glm_image/configuration_glm_image.py +358 -0
  570. transformers/models/glm_image/image_processing_glm_image.py +503 -0
  571. transformers/models/glm_image/image_processing_glm_image_fast.py +294 -0
  572. transformers/models/glm_image/modeling_glm_image.py +1691 -0
  573. transformers/models/glm_image/modular_glm_image.py +1640 -0
  574. transformers/models/glm_image/processing_glm_image.py +265 -0
  575. transformers/models/glm_ocr/__init__.py +28 -0
  576. transformers/models/glm_ocr/configuration_glm_ocr.py +312 -0
  577. transformers/models/glm_ocr/modeling_glm_ocr.py +1633 -0
  578. transformers/models/glm_ocr/modular_glm_ocr.py +428 -0
  579. transformers/models/glmasr/__init__.py +0 -1
  580. transformers/models/glmasr/configuration_glmasr.py +0 -1
  581. transformers/models/glmasr/modeling_glmasr.py +51 -46
  582. transformers/models/glmasr/modular_glmasr.py +39 -29
  583. transformers/models/glmasr/processing_glmasr.py +7 -8
  584. transformers/models/glpn/configuration_glpn.py +0 -1
  585. transformers/models/glpn/image_processing_glpn.py +11 -12
  586. transformers/models/glpn/image_processing_glpn_fast.py +11 -12
  587. transformers/models/glpn/modeling_glpn.py +14 -14
  588. transformers/models/got_ocr2/configuration_got_ocr2.py +10 -13
  589. transformers/models/got_ocr2/image_processing_got_ocr2.py +22 -24
  590. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +9 -10
  591. transformers/models/got_ocr2/modeling_got_ocr2.py +69 -77
  592. transformers/models/got_ocr2/modular_got_ocr2.py +60 -52
  593. transformers/models/got_ocr2/processing_got_ocr2.py +42 -63
  594. transformers/models/gpt2/configuration_gpt2.py +13 -2
  595. transformers/models/gpt2/modeling_gpt2.py +111 -113
  596. transformers/models/gpt2/tokenization_gpt2.py +6 -9
  597. transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +7 -2
  598. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +78 -84
  599. transformers/models/gpt_neo/configuration_gpt_neo.py +9 -2
  600. transformers/models/gpt_neo/modeling_gpt_neo.py +66 -71
  601. transformers/models/gpt_neox/configuration_gpt_neox.py +27 -25
  602. transformers/models/gpt_neox/modeling_gpt_neox.py +74 -76
  603. transformers/models/gpt_neox/modular_gpt_neox.py +68 -70
  604. transformers/models/gpt_neox/tokenization_gpt_neox.py +2 -5
  605. transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py +24 -19
  606. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +43 -46
  607. transformers/models/gpt_neox_japanese/tokenization_gpt_neox_japanese.py +1 -3
  608. transformers/models/gpt_oss/configuration_gpt_oss.py +31 -30
  609. transformers/models/gpt_oss/modeling_gpt_oss.py +80 -114
  610. transformers/models/gpt_oss/modular_gpt_oss.py +62 -97
  611. transformers/models/gpt_sw3/tokenization_gpt_sw3.py +4 -4
  612. transformers/models/gptj/configuration_gptj.py +4 -5
  613. transformers/models/gptj/modeling_gptj.py +85 -88
  614. transformers/models/granite/configuration_granite.py +28 -33
  615. transformers/models/granite/modeling_granite.py +43 -45
  616. transformers/models/granite/modular_granite.py +29 -31
  617. transformers/models/granite_speech/configuration_granite_speech.py +0 -1
  618. transformers/models/granite_speech/feature_extraction_granite_speech.py +1 -3
  619. transformers/models/granite_speech/modeling_granite_speech.py +84 -60
  620. transformers/models/granite_speech/processing_granite_speech.py +11 -4
  621. transformers/models/granitemoe/configuration_granitemoe.py +31 -36
  622. transformers/models/granitemoe/modeling_granitemoe.py +39 -41
  623. transformers/models/granitemoe/modular_granitemoe.py +21 -23
  624. transformers/models/granitemoehybrid/__init__.py +0 -1
  625. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +55 -48
  626. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +82 -118
  627. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +57 -65
  628. transformers/models/granitemoeshared/configuration_granitemoeshared.py +33 -37
  629. transformers/models/granitemoeshared/modeling_granitemoeshared.py +52 -56
  630. transformers/models/granitemoeshared/modular_granitemoeshared.py +19 -21
  631. transformers/models/grounding_dino/configuration_grounding_dino.py +10 -46
  632. transformers/models/grounding_dino/image_processing_grounding_dino.py +60 -62
  633. transformers/models/grounding_dino/image_processing_grounding_dino_fast.py +28 -29
  634. transformers/models/grounding_dino/modeling_grounding_dino.py +161 -181
  635. transformers/models/grounding_dino/modular_grounding_dino.py +2 -3
  636. transformers/models/grounding_dino/processing_grounding_dino.py +10 -38
  637. transformers/models/groupvit/configuration_groupvit.py +4 -2
  638. transformers/models/groupvit/modeling_groupvit.py +98 -92
  639. transformers/models/helium/configuration_helium.py +25 -29
  640. transformers/models/helium/modeling_helium.py +37 -40
  641. transformers/models/helium/modular_helium.py +3 -7
  642. transformers/models/herbert/tokenization_herbert.py +4 -6
  643. transformers/models/hgnet_v2/configuration_hgnet_v2.py +2 -5
  644. transformers/models/hgnet_v2/modeling_hgnet_v2.py +12 -14
  645. transformers/models/hgnet_v2/modular_hgnet_v2.py +13 -17
  646. transformers/models/hiera/configuration_hiera.py +2 -5
  647. transformers/models/hiera/modeling_hiera.py +71 -70
  648. transformers/models/hubert/configuration_hubert.py +4 -2
  649. transformers/models/hubert/modeling_hubert.py +42 -41
  650. transformers/models/hubert/modular_hubert.py +8 -11
  651. transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py +26 -31
  652. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +58 -37
  653. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +31 -11
  654. transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py +31 -36
  655. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +54 -44
  656. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +27 -15
  657. transformers/models/ibert/configuration_ibert.py +4 -2
  658. transformers/models/ibert/modeling_ibert.py +60 -62
  659. transformers/models/ibert/quant_modules.py +0 -1
  660. transformers/models/idefics/configuration_idefics.py +5 -8
  661. transformers/models/idefics/image_processing_idefics.py +13 -15
  662. transformers/models/idefics/modeling_idefics.py +63 -65
  663. transformers/models/idefics/perceiver.py +1 -3
  664. transformers/models/idefics/processing_idefics.py +32 -48
  665. transformers/models/idefics/vision.py +27 -28
  666. transformers/models/idefics2/configuration_idefics2.py +1 -3
  667. transformers/models/idefics2/image_processing_idefics2.py +31 -32
  668. transformers/models/idefics2/image_processing_idefics2_fast.py +8 -8
  669. transformers/models/idefics2/modeling_idefics2.py +126 -106
  670. transformers/models/idefics2/processing_idefics2.py +10 -68
  671. transformers/models/idefics3/configuration_idefics3.py +1 -4
  672. transformers/models/idefics3/image_processing_idefics3.py +42 -43
  673. transformers/models/idefics3/image_processing_idefics3_fast.py +40 -15
  674. transformers/models/idefics3/modeling_idefics3.py +113 -92
  675. transformers/models/idefics3/processing_idefics3.py +15 -69
  676. transformers/models/ijepa/configuration_ijepa.py +0 -1
  677. transformers/models/ijepa/modeling_ijepa.py +13 -14
  678. transformers/models/ijepa/modular_ijepa.py +5 -7
  679. transformers/models/imagegpt/configuration_imagegpt.py +9 -2
  680. transformers/models/imagegpt/image_processing_imagegpt.py +17 -18
  681. transformers/models/imagegpt/image_processing_imagegpt_fast.py +10 -11
  682. transformers/models/imagegpt/modeling_imagegpt.py +65 -62
  683. transformers/models/informer/configuration_informer.py +6 -9
  684. transformers/models/informer/modeling_informer.py +87 -89
  685. transformers/models/informer/modular_informer.py +13 -16
  686. transformers/models/instructblip/configuration_instructblip.py +2 -2
  687. transformers/models/instructblip/modeling_instructblip.py +104 -79
  688. transformers/models/instructblip/processing_instructblip.py +10 -36
  689. transformers/models/instructblipvideo/configuration_instructblipvideo.py +2 -2
  690. transformers/models/instructblipvideo/modeling_instructblipvideo.py +108 -105
  691. transformers/models/instructblipvideo/modular_instructblipvideo.py +73 -64
  692. transformers/models/instructblipvideo/processing_instructblipvideo.py +14 -33
  693. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +6 -7
  694. transformers/models/internvl/configuration_internvl.py +5 -1
  695. transformers/models/internvl/modeling_internvl.py +76 -98
  696. transformers/models/internvl/modular_internvl.py +45 -59
  697. transformers/models/internvl/processing_internvl.py +12 -45
  698. transformers/models/internvl/video_processing_internvl.py +10 -11
  699. transformers/models/jais2/configuration_jais2.py +25 -29
  700. transformers/models/jais2/modeling_jais2.py +36 -38
  701. transformers/models/jais2/modular_jais2.py +20 -22
  702. transformers/models/jamba/configuration_jamba.py +5 -8
  703. transformers/models/jamba/modeling_jamba.py +47 -50
  704. transformers/models/jamba/modular_jamba.py +40 -41
  705. transformers/models/janus/configuration_janus.py +0 -1
  706. transformers/models/janus/image_processing_janus.py +37 -39
  707. transformers/models/janus/image_processing_janus_fast.py +20 -21
  708. transformers/models/janus/modeling_janus.py +103 -188
  709. transformers/models/janus/modular_janus.py +122 -83
  710. transformers/models/janus/processing_janus.py +17 -43
  711. transformers/models/jetmoe/configuration_jetmoe.py +26 -27
  712. transformers/models/jetmoe/modeling_jetmoe.py +42 -45
  713. transformers/models/jetmoe/modular_jetmoe.py +33 -36
  714. transformers/models/kosmos2/configuration_kosmos2.py +10 -9
  715. transformers/models/kosmos2/modeling_kosmos2.py +199 -178
  716. transformers/models/kosmos2/processing_kosmos2.py +40 -55
  717. transformers/models/kosmos2_5/__init__.py +0 -1
  718. transformers/models/kosmos2_5/configuration_kosmos2_5.py +8 -9
  719. transformers/models/kosmos2_5/image_processing_kosmos2_5.py +10 -12
  720. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -11
  721. transformers/models/kosmos2_5/modeling_kosmos2_5.py +162 -172
  722. transformers/models/kosmos2_5/processing_kosmos2_5.py +8 -29
  723. transformers/models/kyutai_speech_to_text/configuration_kyutai_speech_to_text.py +31 -28
  724. transformers/models/kyutai_speech_to_text/feature_extraction_kyutai_speech_to_text.py +12 -14
  725. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +103 -106
  726. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +20 -22
  727. transformers/models/kyutai_speech_to_text/processing_kyutai_speech_to_text.py +2 -8
  728. transformers/models/lasr/configuration_lasr.py +3 -7
  729. transformers/models/lasr/feature_extraction_lasr.py +10 -12
  730. transformers/models/lasr/modeling_lasr.py +21 -24
  731. transformers/models/lasr/modular_lasr.py +11 -13
  732. transformers/models/lasr/processing_lasr.py +12 -6
  733. transformers/models/lasr/tokenization_lasr.py +2 -4
  734. transformers/models/layoutlm/configuration_layoutlm.py +14 -2
  735. transformers/models/layoutlm/modeling_layoutlm.py +70 -72
  736. transformers/models/layoutlmv2/configuration_layoutlmv2.py +14 -17
  737. transformers/models/layoutlmv2/image_processing_layoutlmv2.py +18 -21
  738. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +7 -8
  739. transformers/models/layoutlmv2/modeling_layoutlmv2.py +48 -50
  740. transformers/models/layoutlmv2/processing_layoutlmv2.py +14 -44
  741. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +63 -74
  742. transformers/models/layoutlmv3/configuration_layoutlmv3.py +16 -19
  743. transformers/models/layoutlmv3/image_processing_layoutlmv3.py +24 -26
  744. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +9 -10
  745. transformers/models/layoutlmv3/modeling_layoutlmv3.py +49 -51
  746. transformers/models/layoutlmv3/processing_layoutlmv3.py +14 -46
  747. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +64 -75
  748. transformers/models/layoutxlm/configuration_layoutxlm.py +14 -17
  749. transformers/models/layoutxlm/modular_layoutxlm.py +0 -1
  750. transformers/models/layoutxlm/processing_layoutxlm.py +14 -44
  751. transformers/models/layoutxlm/tokenization_layoutxlm.py +65 -76
  752. transformers/models/led/configuration_led.py +8 -12
  753. transformers/models/led/modeling_led.py +113 -267
  754. transformers/models/levit/configuration_levit.py +0 -1
  755. transformers/models/levit/image_processing_levit.py +19 -21
  756. transformers/models/levit/image_processing_levit_fast.py +4 -5
  757. transformers/models/levit/modeling_levit.py +17 -19
  758. transformers/models/lfm2/configuration_lfm2.py +27 -30
  759. transformers/models/lfm2/modeling_lfm2.py +46 -48
  760. transformers/models/lfm2/modular_lfm2.py +32 -32
  761. transformers/models/lfm2_moe/__init__.py +0 -1
  762. transformers/models/lfm2_moe/configuration_lfm2_moe.py +6 -9
  763. transformers/models/lfm2_moe/modeling_lfm2_moe.py +48 -49
  764. transformers/models/lfm2_moe/modular_lfm2_moe.py +8 -9
  765. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -1
  766. transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py +43 -20
  767. transformers/models/lfm2_vl/modeling_lfm2_vl.py +73 -61
  768. transformers/models/lfm2_vl/modular_lfm2_vl.py +66 -54
  769. transformers/models/lfm2_vl/processing_lfm2_vl.py +14 -34
  770. transformers/models/lightglue/image_processing_lightglue.py +16 -15
  771. transformers/models/lightglue/image_processing_lightglue_fast.py +8 -7
  772. transformers/models/lightglue/modeling_lightglue.py +31 -33
  773. transformers/models/lightglue/modular_lightglue.py +31 -31
  774. transformers/models/lighton_ocr/__init__.py +28 -0
  775. transformers/models/lighton_ocr/configuration_lighton_ocr.py +128 -0
  776. transformers/models/lighton_ocr/modeling_lighton_ocr.py +463 -0
  777. transformers/models/lighton_ocr/modular_lighton_ocr.py +404 -0
  778. transformers/models/lighton_ocr/processing_lighton_ocr.py +229 -0
  779. transformers/models/lilt/configuration_lilt.py +6 -2
  780. transformers/models/lilt/modeling_lilt.py +53 -55
  781. transformers/models/llama/configuration_llama.py +26 -31
  782. transformers/models/llama/modeling_llama.py +35 -38
  783. transformers/models/llama/tokenization_llama.py +2 -4
  784. transformers/models/llama4/configuration_llama4.py +87 -69
  785. transformers/models/llama4/image_processing_llama4_fast.py +11 -12
  786. transformers/models/llama4/modeling_llama4.py +116 -115
  787. transformers/models/llama4/processing_llama4.py +33 -57
  788. transformers/models/llava/configuration_llava.py +10 -1
  789. transformers/models/llava/image_processing_llava.py +25 -28
  790. transformers/models/llava/image_processing_llava_fast.py +9 -10
  791. transformers/models/llava/modeling_llava.py +73 -102
  792. transformers/models/llava/processing_llava.py +18 -51
  793. transformers/models/llava_next/configuration_llava_next.py +2 -2
  794. transformers/models/llava_next/image_processing_llava_next.py +43 -45
  795. transformers/models/llava_next/image_processing_llava_next_fast.py +11 -12
  796. transformers/models/llava_next/modeling_llava_next.py +103 -104
  797. transformers/models/llava_next/processing_llava_next.py +18 -47
  798. transformers/models/llava_next_video/configuration_llava_next_video.py +10 -7
  799. transformers/models/llava_next_video/modeling_llava_next_video.py +168 -155
  800. transformers/models/llava_next_video/modular_llava_next_video.py +154 -147
  801. transformers/models/llava_next_video/processing_llava_next_video.py +21 -63
  802. transformers/models/llava_next_video/video_processing_llava_next_video.py +0 -1
  803. transformers/models/llava_onevision/configuration_llava_onevision.py +10 -7
  804. transformers/models/llava_onevision/image_processing_llava_onevision.py +40 -42
  805. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +14 -14
  806. transformers/models/llava_onevision/modeling_llava_onevision.py +170 -166
  807. transformers/models/llava_onevision/modular_llava_onevision.py +156 -152
  808. transformers/models/llava_onevision/processing_llava_onevision.py +21 -53
  809. transformers/models/llava_onevision/video_processing_llava_onevision.py +0 -1
  810. transformers/models/longcat_flash/__init__.py +0 -1
  811. transformers/models/longcat_flash/configuration_longcat_flash.py +39 -45
  812. transformers/models/longcat_flash/modeling_longcat_flash.py +37 -38
  813. transformers/models/longcat_flash/modular_longcat_flash.py +23 -24
  814. transformers/models/longformer/configuration_longformer.py +5 -5
  815. transformers/models/longformer/modeling_longformer.py +99 -101
  816. transformers/models/longt5/configuration_longt5.py +9 -7
  817. transformers/models/longt5/modeling_longt5.py +45 -45
  818. transformers/models/luke/configuration_luke.py +8 -2
  819. transformers/models/luke/modeling_luke.py +179 -181
  820. transformers/models/luke/tokenization_luke.py +99 -105
  821. transformers/{pipelines/deprecated → models/lw_detr}/__init__.py +14 -3
  822. transformers/models/lw_detr/configuration_lw_detr.py +362 -0
  823. transformers/models/lw_detr/modeling_lw_detr.py +1697 -0
  824. transformers/models/lw_detr/modular_lw_detr.py +1609 -0
  825. transformers/models/lxmert/configuration_lxmert.py +16 -1
  826. transformers/models/lxmert/modeling_lxmert.py +63 -74
  827. transformers/models/m2m_100/configuration_m2m_100.py +7 -9
  828. transformers/models/m2m_100/modeling_m2m_100.py +72 -74
  829. transformers/models/m2m_100/tokenization_m2m_100.py +8 -8
  830. transformers/models/mamba/configuration_mamba.py +5 -3
  831. transformers/models/mamba/modeling_mamba.py +61 -70
  832. transformers/models/mamba2/configuration_mamba2.py +5 -8
  833. transformers/models/mamba2/modeling_mamba2.py +66 -79
  834. transformers/models/marian/configuration_marian.py +10 -5
  835. transformers/models/marian/modeling_marian.py +88 -90
  836. transformers/models/marian/tokenization_marian.py +6 -6
  837. transformers/models/markuplm/configuration_markuplm.py +4 -7
  838. transformers/models/markuplm/feature_extraction_markuplm.py +1 -2
  839. transformers/models/markuplm/modeling_markuplm.py +63 -65
  840. transformers/models/markuplm/processing_markuplm.py +31 -38
  841. transformers/models/markuplm/tokenization_markuplm.py +67 -77
  842. transformers/models/mask2former/configuration_mask2former.py +14 -52
  843. transformers/models/mask2former/image_processing_mask2former.py +84 -85
  844. transformers/models/mask2former/image_processing_mask2former_fast.py +36 -36
  845. transformers/models/mask2former/modeling_mask2former.py +108 -104
  846. transformers/models/mask2former/modular_mask2former.py +6 -8
  847. transformers/models/maskformer/configuration_maskformer.py +17 -51
  848. transformers/models/maskformer/configuration_maskformer_swin.py +2 -5
  849. transformers/models/maskformer/image_processing_maskformer.py +84 -85
  850. transformers/models/maskformer/image_processing_maskformer_fast.py +35 -36
  851. transformers/models/maskformer/modeling_maskformer.py +71 -67
  852. transformers/models/maskformer/modeling_maskformer_swin.py +20 -23
  853. transformers/models/mbart/configuration_mbart.py +9 -5
  854. transformers/models/mbart/modeling_mbart.py +120 -119
  855. transformers/models/mbart/tokenization_mbart.py +2 -4
  856. transformers/models/mbart50/tokenization_mbart50.py +3 -5
  857. transformers/models/megatron_bert/configuration_megatron_bert.py +13 -3
  858. transformers/models/megatron_bert/modeling_megatron_bert.py +139 -165
  859. transformers/models/metaclip_2/configuration_metaclip_2.py +4 -1
  860. transformers/models/metaclip_2/modeling_metaclip_2.py +94 -87
  861. transformers/models/metaclip_2/modular_metaclip_2.py +59 -45
  862. transformers/models/mgp_str/configuration_mgp_str.py +0 -1
  863. transformers/models/mgp_str/modeling_mgp_str.py +18 -18
  864. transformers/models/mgp_str/processing_mgp_str.py +3 -20
  865. transformers/models/mgp_str/tokenization_mgp_str.py +1 -3
  866. transformers/models/mimi/configuration_mimi.py +42 -40
  867. transformers/models/mimi/modeling_mimi.py +116 -115
  868. transformers/models/minimax/__init__.py +0 -1
  869. transformers/models/minimax/configuration_minimax.py +40 -47
  870. transformers/models/minimax/modeling_minimax.py +46 -49
  871. transformers/models/minimax/modular_minimax.py +59 -65
  872. transformers/models/minimax_m2/__init__.py +28 -0
  873. transformers/models/minimax_m2/configuration_minimax_m2.py +188 -0
  874. transformers/models/minimax_m2/modeling_minimax_m2.py +704 -0
  875. transformers/models/minimax_m2/modular_minimax_m2.py +346 -0
  876. transformers/models/ministral/configuration_ministral.py +25 -29
  877. transformers/models/ministral/modeling_ministral.py +35 -37
  878. transformers/models/ministral/modular_ministral.py +32 -37
  879. transformers/models/ministral3/configuration_ministral3.py +23 -26
  880. transformers/models/ministral3/modeling_ministral3.py +35 -37
  881. transformers/models/ministral3/modular_ministral3.py +7 -8
  882. transformers/models/mistral/configuration_mistral.py +24 -29
  883. transformers/models/mistral/modeling_mistral.py +35 -37
  884. transformers/models/mistral/modular_mistral.py +14 -15
  885. transformers/models/mistral3/configuration_mistral3.py +4 -1
  886. transformers/models/mistral3/modeling_mistral3.py +79 -82
  887. transformers/models/mistral3/modular_mistral3.py +66 -67
  888. transformers/models/mixtral/configuration_mixtral.py +32 -38
  889. transformers/models/mixtral/modeling_mixtral.py +39 -42
  890. transformers/models/mixtral/modular_mixtral.py +26 -29
  891. transformers/models/mlcd/configuration_mlcd.py +0 -1
  892. transformers/models/mlcd/modeling_mlcd.py +17 -17
  893. transformers/models/mlcd/modular_mlcd.py +16 -16
  894. transformers/models/mllama/configuration_mllama.py +10 -15
  895. transformers/models/mllama/image_processing_mllama.py +23 -25
  896. transformers/models/mllama/image_processing_mllama_fast.py +11 -11
  897. transformers/models/mllama/modeling_mllama.py +100 -103
  898. transformers/models/mllama/processing_mllama.py +6 -55
  899. transformers/models/mluke/tokenization_mluke.py +97 -103
  900. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +10 -46
  901. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +159 -179
  902. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +10 -46
  903. transformers/models/mobilebert/configuration_mobilebert.py +4 -2
  904. transformers/models/mobilebert/modeling_mobilebert.py +78 -88
  905. transformers/models/mobilebert/tokenization_mobilebert.py +0 -1
  906. transformers/models/mobilenet_v1/configuration_mobilenet_v1.py +0 -1
  907. transformers/models/mobilenet_v1/image_processing_mobilenet_v1.py +20 -23
  908. transformers/models/mobilenet_v1/image_processing_mobilenet_v1_fast.py +0 -1
  909. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +13 -16
  910. transformers/models/mobilenet_v2/configuration_mobilenet_v2.py +0 -1
  911. transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py +48 -51
  912. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +14 -15
  913. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +21 -22
  914. transformers/models/mobilevit/configuration_mobilevit.py +0 -1
  915. transformers/models/mobilevit/image_processing_mobilevit.py +41 -44
  916. transformers/models/mobilevit/image_processing_mobilevit_fast.py +12 -13
  917. transformers/models/mobilevit/modeling_mobilevit.py +21 -21
  918. transformers/models/mobilevitv2/configuration_mobilevitv2.py +0 -1
  919. transformers/models/mobilevitv2/modeling_mobilevitv2.py +21 -22
  920. transformers/models/modernbert/configuration_modernbert.py +76 -51
  921. transformers/models/modernbert/modeling_modernbert.py +188 -943
  922. transformers/models/modernbert/modular_modernbert.py +255 -978
  923. transformers/models/modernbert_decoder/configuration_modernbert_decoder.py +50 -44
  924. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +54 -64
  925. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +92 -92
  926. transformers/models/moonshine/configuration_moonshine.py +34 -31
  927. transformers/models/moonshine/modeling_moonshine.py +70 -72
  928. transformers/models/moonshine/modular_moonshine.py +91 -86
  929. transformers/models/moshi/configuration_moshi.py +46 -23
  930. transformers/models/moshi/modeling_moshi.py +134 -142
  931. transformers/models/mpnet/configuration_mpnet.py +6 -2
  932. transformers/models/mpnet/modeling_mpnet.py +55 -57
  933. transformers/models/mpnet/tokenization_mpnet.py +1 -4
  934. transformers/models/mpt/configuration_mpt.py +17 -9
  935. transformers/models/mpt/modeling_mpt.py +58 -60
  936. transformers/models/mra/configuration_mra.py +8 -2
  937. transformers/models/mra/modeling_mra.py +54 -56
  938. transformers/models/mt5/configuration_mt5.py +9 -6
  939. transformers/models/mt5/modeling_mt5.py +80 -85
  940. transformers/models/musicgen/configuration_musicgen.py +12 -8
  941. transformers/models/musicgen/modeling_musicgen.py +114 -116
  942. transformers/models/musicgen/processing_musicgen.py +3 -21
  943. transformers/models/musicgen_melody/configuration_musicgen_melody.py +15 -8
  944. transformers/models/musicgen_melody/feature_extraction_musicgen_melody.py +8 -9
  945. transformers/models/musicgen_melody/modeling_musicgen_melody.py +113 -126
  946. transformers/models/musicgen_melody/processing_musicgen_melody.py +3 -22
  947. transformers/models/mvp/configuration_mvp.py +8 -5
  948. transformers/models/mvp/modeling_mvp.py +121 -123
  949. transformers/models/myt5/tokenization_myt5.py +8 -10
  950. transformers/models/nanochat/configuration_nanochat.py +5 -8
  951. transformers/models/nanochat/modeling_nanochat.py +36 -39
  952. transformers/models/nanochat/modular_nanochat.py +16 -18
  953. transformers/models/nemotron/configuration_nemotron.py +25 -30
  954. transformers/models/nemotron/modeling_nemotron.py +53 -66
  955. transformers/models/nllb/tokenization_nllb.py +14 -14
  956. transformers/models/nllb_moe/configuration_nllb_moe.py +7 -10
  957. transformers/models/nllb_moe/modeling_nllb_moe.py +70 -72
  958. transformers/models/nougat/image_processing_nougat.py +29 -32
  959. transformers/models/nougat/image_processing_nougat_fast.py +12 -13
  960. transformers/models/nougat/processing_nougat.py +37 -39
  961. transformers/models/nougat/tokenization_nougat.py +5 -7
  962. transformers/models/nystromformer/configuration_nystromformer.py +8 -2
  963. transformers/models/nystromformer/modeling_nystromformer.py +61 -63
  964. transformers/models/olmo/configuration_olmo.py +23 -28
  965. transformers/models/olmo/modeling_olmo.py +35 -38
  966. transformers/models/olmo/modular_olmo.py +8 -12
  967. transformers/models/olmo2/configuration_olmo2.py +27 -32
  968. transformers/models/olmo2/modeling_olmo2.py +36 -39
  969. transformers/models/olmo2/modular_olmo2.py +36 -38
  970. transformers/models/olmo3/__init__.py +0 -1
  971. transformers/models/olmo3/configuration_olmo3.py +30 -34
  972. transformers/models/olmo3/modeling_olmo3.py +35 -38
  973. transformers/models/olmo3/modular_olmo3.py +44 -47
  974. transformers/models/olmoe/configuration_olmoe.py +29 -33
  975. transformers/models/olmoe/modeling_olmoe.py +41 -43
  976. transformers/models/olmoe/modular_olmoe.py +15 -16
  977. transformers/models/omdet_turbo/configuration_omdet_turbo.py +14 -50
  978. transformers/models/omdet_turbo/modeling_omdet_turbo.py +59 -57
  979. transformers/models/omdet_turbo/processing_omdet_turbo.py +19 -67
  980. transformers/models/oneformer/configuration_oneformer.py +11 -51
  981. transformers/models/oneformer/image_processing_oneformer.py +83 -84
  982. transformers/models/oneformer/image_processing_oneformer_fast.py +41 -42
  983. transformers/models/oneformer/modeling_oneformer.py +137 -133
  984. transformers/models/oneformer/processing_oneformer.py +28 -43
  985. transformers/models/openai/configuration_openai.py +16 -1
  986. transformers/models/openai/modeling_openai.py +50 -51
  987. transformers/models/openai/tokenization_openai.py +2 -5
  988. transformers/models/opt/configuration_opt.py +6 -7
  989. transformers/models/opt/modeling_opt.py +79 -80
  990. transformers/models/ovis2/__init__.py +0 -1
  991. transformers/models/ovis2/configuration_ovis2.py +4 -1
  992. transformers/models/ovis2/image_processing_ovis2.py +22 -24
  993. transformers/models/ovis2/image_processing_ovis2_fast.py +9 -10
  994. transformers/models/ovis2/modeling_ovis2.py +99 -142
  995. transformers/models/ovis2/modular_ovis2.py +82 -45
  996. transformers/models/ovis2/processing_ovis2.py +12 -40
  997. transformers/models/owlv2/configuration_owlv2.py +4 -2
  998. transformers/models/owlv2/image_processing_owlv2.py +20 -21
  999. transformers/models/owlv2/image_processing_owlv2_fast.py +12 -13
  1000. transformers/models/owlv2/modeling_owlv2.py +122 -114
  1001. transformers/models/owlv2/modular_owlv2.py +11 -12
  1002. transformers/models/owlv2/processing_owlv2.py +20 -49
  1003. transformers/models/owlvit/configuration_owlvit.py +4 -2
  1004. transformers/models/owlvit/image_processing_owlvit.py +21 -22
  1005. transformers/models/owlvit/image_processing_owlvit_fast.py +2 -3
  1006. transformers/models/owlvit/modeling_owlvit.py +121 -113
  1007. transformers/models/owlvit/processing_owlvit.py +20 -48
  1008. transformers/models/paddleocr_vl/__init__.py +0 -1
  1009. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +28 -29
  1010. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +34 -35
  1011. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +12 -12
  1012. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +159 -158
  1013. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +148 -119
  1014. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +1 -3
  1015. transformers/models/paligemma/configuration_paligemma.py +4 -1
  1016. transformers/models/paligemma/modeling_paligemma.py +81 -79
  1017. transformers/models/paligemma/processing_paligemma.py +13 -66
  1018. transformers/models/parakeet/configuration_parakeet.py +3 -8
  1019. transformers/models/parakeet/feature_extraction_parakeet.py +10 -12
  1020. transformers/models/parakeet/modeling_parakeet.py +21 -25
  1021. transformers/models/parakeet/modular_parakeet.py +19 -21
  1022. transformers/models/parakeet/processing_parakeet.py +12 -5
  1023. transformers/models/parakeet/tokenization_parakeet.py +2 -4
  1024. transformers/models/patchtsmixer/configuration_patchtsmixer.py +5 -8
  1025. transformers/models/patchtsmixer/modeling_patchtsmixer.py +63 -65
  1026. transformers/models/patchtst/configuration_patchtst.py +6 -9
  1027. transformers/models/patchtst/modeling_patchtst.py +75 -77
  1028. transformers/models/pe_audio/__init__.py +0 -1
  1029. transformers/models/pe_audio/configuration_pe_audio.py +14 -16
  1030. transformers/models/pe_audio/feature_extraction_pe_audio.py +6 -8
  1031. transformers/models/pe_audio/modeling_pe_audio.py +30 -31
  1032. transformers/models/pe_audio/modular_pe_audio.py +17 -18
  1033. transformers/models/pe_audio/processing_pe_audio.py +0 -1
  1034. transformers/models/pe_audio_video/__init__.py +0 -1
  1035. transformers/models/pe_audio_video/configuration_pe_audio_video.py +15 -17
  1036. transformers/models/pe_audio_video/modeling_pe_audio_video.py +64 -65
  1037. transformers/models/pe_audio_video/modular_pe_audio_video.py +56 -57
  1038. transformers/models/pe_audio_video/processing_pe_audio_video.py +0 -1
  1039. transformers/models/pe_video/__init__.py +0 -1
  1040. transformers/models/pe_video/configuration_pe_video.py +14 -16
  1041. transformers/models/pe_video/modeling_pe_video.py +57 -46
  1042. transformers/models/pe_video/modular_pe_video.py +47 -35
  1043. transformers/models/pe_video/video_processing_pe_video.py +2 -4
  1044. transformers/models/pegasus/configuration_pegasus.py +8 -6
  1045. transformers/models/pegasus/modeling_pegasus.py +67 -69
  1046. transformers/models/pegasus/tokenization_pegasus.py +1 -4
  1047. transformers/models/pegasus_x/configuration_pegasus_x.py +5 -4
  1048. transformers/models/pegasus_x/modeling_pegasus_x.py +53 -55
  1049. transformers/models/perceiver/configuration_perceiver.py +0 -1
  1050. transformers/models/perceiver/image_processing_perceiver.py +22 -25
  1051. transformers/models/perceiver/image_processing_perceiver_fast.py +7 -8
  1052. transformers/models/perceiver/modeling_perceiver.py +152 -145
  1053. transformers/models/perceiver/tokenization_perceiver.py +3 -6
  1054. transformers/models/perception_lm/configuration_perception_lm.py +0 -1
  1055. transformers/models/perception_lm/image_processing_perception_lm_fast.py +8 -9
  1056. transformers/models/perception_lm/modeling_perception_lm.py +64 -67
  1057. transformers/models/perception_lm/modular_perception_lm.py +58 -58
  1058. transformers/models/perception_lm/processing_perception_lm.py +13 -47
  1059. transformers/models/perception_lm/video_processing_perception_lm.py +0 -1
  1060. transformers/models/persimmon/configuration_persimmon.py +23 -28
  1061. transformers/models/persimmon/modeling_persimmon.py +44 -47
  1062. transformers/models/phi/configuration_phi.py +27 -28
  1063. transformers/models/phi/modeling_phi.py +39 -41
  1064. transformers/models/phi/modular_phi.py +26 -26
  1065. transformers/models/phi3/configuration_phi3.py +32 -37
  1066. transformers/models/phi3/modeling_phi3.py +37 -40
  1067. transformers/models/phi3/modular_phi3.py +16 -20
  1068. transformers/models/phi4_multimodal/configuration_phi4_multimodal.py +36 -39
  1069. transformers/models/phi4_multimodal/feature_extraction_phi4_multimodal.py +7 -9
  1070. transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py +11 -11
  1071. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +100 -117
  1072. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +103 -90
  1073. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +7 -42
  1074. transformers/models/phimoe/configuration_phimoe.py +31 -36
  1075. transformers/models/phimoe/modeling_phimoe.py +50 -77
  1076. transformers/models/phimoe/modular_phimoe.py +12 -8
  1077. transformers/models/phobert/tokenization_phobert.py +4 -6
  1078. transformers/models/pix2struct/configuration_pix2struct.py +12 -10
  1079. transformers/models/pix2struct/image_processing_pix2struct.py +15 -19
  1080. transformers/models/pix2struct/image_processing_pix2struct_fast.py +12 -15
  1081. transformers/models/pix2struct/modeling_pix2struct.py +56 -52
  1082. transformers/models/pix2struct/processing_pix2struct.py +5 -26
  1083. transformers/models/pixio/__init__.py +0 -1
  1084. transformers/models/pixio/configuration_pixio.py +2 -5
  1085. transformers/models/pixio/modeling_pixio.py +16 -17
  1086. transformers/models/pixio/modular_pixio.py +7 -8
  1087. transformers/models/pixtral/configuration_pixtral.py +11 -14
  1088. transformers/models/pixtral/image_processing_pixtral.py +26 -28
  1089. transformers/models/pixtral/image_processing_pixtral_fast.py +10 -11
  1090. transformers/models/pixtral/modeling_pixtral.py +31 -37
  1091. transformers/models/pixtral/processing_pixtral.py +18 -52
  1092. transformers/models/plbart/configuration_plbart.py +8 -6
  1093. transformers/models/plbart/modeling_plbart.py +109 -109
  1094. transformers/models/plbart/modular_plbart.py +31 -33
  1095. transformers/models/plbart/tokenization_plbart.py +4 -5
  1096. transformers/models/poolformer/configuration_poolformer.py +0 -1
  1097. transformers/models/poolformer/image_processing_poolformer.py +21 -24
  1098. transformers/models/poolformer/image_processing_poolformer_fast.py +13 -14
  1099. transformers/models/poolformer/modeling_poolformer.py +10 -12
  1100. transformers/models/pop2piano/configuration_pop2piano.py +7 -7
  1101. transformers/models/pop2piano/feature_extraction_pop2piano.py +6 -9
  1102. transformers/models/pop2piano/modeling_pop2piano.py +24 -24
  1103. transformers/models/pop2piano/processing_pop2piano.py +25 -33
  1104. transformers/models/pop2piano/tokenization_pop2piano.py +15 -23
  1105. transformers/models/pp_doclayout_v3/__init__.py +30 -0
  1106. transformers/models/pp_doclayout_v3/configuration_pp_doclayout_v3.py +277 -0
  1107. transformers/models/pp_doclayout_v3/image_processing_pp_doclayout_v3_fast.py +305 -0
  1108. transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py +2083 -0
  1109. transformers/models/pp_doclayout_v3/modular_pp_doclayout_v3.py +1549 -0
  1110. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +13 -46
  1111. transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py +28 -28
  1112. transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything_fast.py +20 -21
  1113. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +17 -16
  1114. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +21 -20
  1115. transformers/models/prophetnet/configuration_prophetnet.py +37 -38
  1116. transformers/models/prophetnet/modeling_prophetnet.py +121 -153
  1117. transformers/models/prophetnet/tokenization_prophetnet.py +14 -16
  1118. transformers/models/pvt/configuration_pvt.py +0 -1
  1119. transformers/models/pvt/image_processing_pvt.py +24 -27
  1120. transformers/models/pvt/image_processing_pvt_fast.py +1 -2
  1121. transformers/models/pvt/modeling_pvt.py +19 -21
  1122. transformers/models/pvt_v2/configuration_pvt_v2.py +4 -8
  1123. transformers/models/pvt_v2/modeling_pvt_v2.py +27 -28
  1124. transformers/models/qwen2/configuration_qwen2.py +32 -25
  1125. transformers/models/qwen2/modeling_qwen2.py +35 -37
  1126. transformers/models/qwen2/modular_qwen2.py +14 -15
  1127. transformers/models/qwen2/tokenization_qwen2.py +2 -9
  1128. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +36 -27
  1129. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +241 -214
  1130. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +228 -193
  1131. transformers/models/qwen2_5_omni/processing_qwen2_5_omni.py +41 -49
  1132. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +28 -34
  1133. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +188 -145
  1134. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +64 -91
  1135. transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py +7 -43
  1136. transformers/models/qwen2_audio/configuration_qwen2_audio.py +0 -1
  1137. transformers/models/qwen2_audio/modeling_qwen2_audio.py +39 -41
  1138. transformers/models/qwen2_audio/processing_qwen2_audio.py +13 -42
  1139. transformers/models/qwen2_moe/configuration_qwen2_moe.py +42 -35
  1140. transformers/models/qwen2_moe/modeling_qwen2_moe.py +40 -43
  1141. transformers/models/qwen2_moe/modular_qwen2_moe.py +10 -13
  1142. transformers/models/qwen2_vl/configuration_qwen2_vl.py +28 -33
  1143. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +38 -40
  1144. transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py +12 -15
  1145. transformers/models/qwen2_vl/modeling_qwen2_vl.py +184 -141
  1146. transformers/models/qwen2_vl/processing_qwen2_vl.py +7 -44
  1147. transformers/models/qwen2_vl/video_processing_qwen2_vl.py +38 -18
  1148. transformers/models/qwen3/configuration_qwen3.py +34 -27
  1149. transformers/models/qwen3/modeling_qwen3.py +35 -38
  1150. transformers/models/qwen3/modular_qwen3.py +7 -9
  1151. transformers/models/qwen3_moe/configuration_qwen3_moe.py +45 -35
  1152. transformers/models/qwen3_moe/modeling_qwen3_moe.py +40 -43
  1153. transformers/models/qwen3_moe/modular_qwen3_moe.py +10 -13
  1154. transformers/models/qwen3_next/configuration_qwen3_next.py +47 -38
  1155. transformers/models/qwen3_next/modeling_qwen3_next.py +44 -47
  1156. transformers/models/qwen3_next/modular_qwen3_next.py +37 -38
  1157. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +139 -106
  1158. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +266 -206
  1159. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +228 -181
  1160. transformers/models/qwen3_omni_moe/processing_qwen3_omni_moe.py +40 -48
  1161. transformers/models/qwen3_vl/configuration_qwen3_vl.py +22 -24
  1162. transformers/models/qwen3_vl/modeling_qwen3_vl.py +185 -122
  1163. transformers/models/qwen3_vl/modular_qwen3_vl.py +153 -139
  1164. transformers/models/qwen3_vl/processing_qwen3_vl.py +6 -42
  1165. transformers/models/qwen3_vl/video_processing_qwen3_vl.py +10 -12
  1166. transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py +27 -30
  1167. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +249 -178
  1168. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +55 -42
  1169. transformers/models/rag/configuration_rag.py +6 -7
  1170. transformers/models/rag/modeling_rag.py +119 -121
  1171. transformers/models/rag/retrieval_rag.py +3 -5
  1172. transformers/models/rag/tokenization_rag.py +0 -50
  1173. transformers/models/recurrent_gemma/configuration_recurrent_gemma.py +29 -30
  1174. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +35 -39
  1175. transformers/models/reformer/configuration_reformer.py +7 -8
  1176. transformers/models/reformer/modeling_reformer.py +67 -68
  1177. transformers/models/reformer/tokenization_reformer.py +3 -6
  1178. transformers/models/regnet/configuration_regnet.py +0 -1
  1179. transformers/models/regnet/modeling_regnet.py +7 -9
  1180. transformers/models/rembert/configuration_rembert.py +8 -2
  1181. transformers/models/rembert/modeling_rembert.py +108 -132
  1182. transformers/models/rembert/tokenization_rembert.py +1 -4
  1183. transformers/models/resnet/configuration_resnet.py +2 -5
  1184. transformers/models/resnet/modeling_resnet.py +14 -15
  1185. transformers/models/roberta/configuration_roberta.py +11 -3
  1186. transformers/models/roberta/modeling_roberta.py +97 -99
  1187. transformers/models/roberta/modular_roberta.py +55 -58
  1188. transformers/models/roberta/tokenization_roberta.py +2 -5
  1189. transformers/models/roberta/tokenization_roberta_old.py +2 -4
  1190. transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py +11 -3
  1191. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +97 -99
  1192. transformers/models/roc_bert/configuration_roc_bert.py +8 -2
  1193. transformers/models/roc_bert/modeling_roc_bert.py +125 -162
  1194. transformers/models/roc_bert/tokenization_roc_bert.py +88 -94
  1195. transformers/models/roformer/configuration_roformer.py +13 -3
  1196. transformers/models/roformer/modeling_roformer.py +79 -95
  1197. transformers/models/roformer/tokenization_roformer.py +3 -6
  1198. transformers/models/roformer/tokenization_utils.py +0 -1
  1199. transformers/models/rt_detr/configuration_rt_detr.py +8 -50
  1200. transformers/models/rt_detr/configuration_rt_detr_resnet.py +2 -5
  1201. transformers/models/rt_detr/image_processing_rt_detr.py +54 -55
  1202. transformers/models/rt_detr/image_processing_rt_detr_fast.py +39 -26
  1203. transformers/models/rt_detr/modeling_rt_detr.py +643 -804
  1204. transformers/models/rt_detr/modeling_rt_detr_resnet.py +4 -7
  1205. transformers/models/rt_detr/modular_rt_detr.py +1522 -20
  1206. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +12 -58
  1207. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +384 -521
  1208. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +27 -70
  1209. transformers/models/rwkv/configuration_rwkv.py +2 -4
  1210. transformers/models/rwkv/modeling_rwkv.py +29 -54
  1211. transformers/models/sam/configuration_sam.py +2 -1
  1212. transformers/models/sam/image_processing_sam.py +59 -60
  1213. transformers/models/sam/image_processing_sam_fast.py +25 -26
  1214. transformers/models/sam/modeling_sam.py +46 -43
  1215. transformers/models/sam/processing_sam.py +39 -27
  1216. transformers/models/sam2/configuration_sam2.py +1 -2
  1217. transformers/models/sam2/image_processing_sam2_fast.py +14 -15
  1218. transformers/models/sam2/modeling_sam2.py +96 -94
  1219. transformers/models/sam2/modular_sam2.py +85 -94
  1220. transformers/models/sam2/processing_sam2.py +31 -47
  1221. transformers/models/sam2_video/configuration_sam2_video.py +0 -1
  1222. transformers/models/sam2_video/modeling_sam2_video.py +114 -116
  1223. transformers/models/sam2_video/modular_sam2_video.py +72 -89
  1224. transformers/models/sam2_video/processing_sam2_video.py +49 -66
  1225. transformers/models/sam2_video/video_processing_sam2_video.py +1 -4
  1226. transformers/models/sam3/configuration_sam3.py +0 -1
  1227. transformers/models/sam3/image_processing_sam3_fast.py +17 -20
  1228. transformers/models/sam3/modeling_sam3.py +94 -100
  1229. transformers/models/sam3/modular_sam3.py +3 -8
  1230. transformers/models/sam3/processing_sam3.py +37 -52
  1231. transformers/models/sam3_tracker/__init__.py +0 -1
  1232. transformers/models/sam3_tracker/configuration_sam3_tracker.py +1 -3
  1233. transformers/models/sam3_tracker/modeling_sam3_tracker.py +79 -80
  1234. transformers/models/sam3_tracker/modular_sam3_tracker.py +0 -2
  1235. transformers/models/sam3_tracker/processing_sam3_tracker.py +31 -48
  1236. transformers/models/sam3_tracker_video/__init__.py +0 -1
  1237. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +0 -1
  1238. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +115 -114
  1239. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +10 -24
  1240. transformers/models/sam3_tracker_video/processing_sam3_tracker_video.py +50 -66
  1241. transformers/models/sam3_video/configuration_sam3_video.py +0 -1
  1242. transformers/models/sam3_video/modeling_sam3_video.py +56 -45
  1243. transformers/models/sam3_video/processing_sam3_video.py +25 -45
  1244. transformers/models/sam_hq/__init__.py +1 -1
  1245. transformers/models/sam_hq/configuration_sam_hq.py +2 -1
  1246. transformers/models/sam_hq/modeling_sam_hq.py +52 -50
  1247. transformers/models/sam_hq/modular_sam_hq.py +23 -25
  1248. transformers/models/sam_hq/{processing_samhq.py → processing_sam_hq.py} +41 -29
  1249. transformers/models/seamless_m4t/configuration_seamless_m4t.py +8 -10
  1250. transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py +8 -11
  1251. transformers/models/seamless_m4t/modeling_seamless_m4t.py +180 -182
  1252. transformers/models/seamless_m4t/processing_seamless_m4t.py +18 -39
  1253. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +15 -20
  1254. transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py +8 -10
  1255. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +193 -195
  1256. transformers/models/seed_oss/configuration_seed_oss.py +30 -34
  1257. transformers/models/seed_oss/modeling_seed_oss.py +34 -36
  1258. transformers/models/seed_oss/modular_seed_oss.py +6 -7
  1259. transformers/models/segformer/configuration_segformer.py +0 -10
  1260. transformers/models/segformer/image_processing_segformer.py +39 -42
  1261. transformers/models/segformer/image_processing_segformer_fast.py +11 -12
  1262. transformers/models/segformer/modeling_segformer.py +28 -28
  1263. transformers/models/segformer/modular_segformer.py +8 -9
  1264. transformers/models/seggpt/configuration_seggpt.py +0 -1
  1265. transformers/models/seggpt/image_processing_seggpt.py +38 -41
  1266. transformers/models/seggpt/modeling_seggpt.py +48 -38
  1267. transformers/models/sew/configuration_sew.py +4 -2
  1268. transformers/models/sew/modeling_sew.py +42 -40
  1269. transformers/models/sew/modular_sew.py +12 -13
  1270. transformers/models/sew_d/configuration_sew_d.py +4 -2
  1271. transformers/models/sew_d/modeling_sew_d.py +32 -31
  1272. transformers/models/shieldgemma2/configuration_shieldgemma2.py +0 -1
  1273. transformers/models/shieldgemma2/modeling_shieldgemma2.py +19 -21
  1274. transformers/models/shieldgemma2/processing_shieldgemma2.py +3 -5
  1275. transformers/models/siglip/configuration_siglip.py +4 -2
  1276. transformers/models/siglip/image_processing_siglip.py +17 -20
  1277. transformers/models/siglip/image_processing_siglip_fast.py +0 -1
  1278. transformers/models/siglip/modeling_siglip.py +65 -110
  1279. transformers/models/siglip/processing_siglip.py +2 -14
  1280. transformers/models/siglip/tokenization_siglip.py +6 -7
  1281. transformers/models/siglip2/__init__.py +1 -0
  1282. transformers/models/siglip2/configuration_siglip2.py +4 -2
  1283. transformers/models/siglip2/image_processing_siglip2.py +15 -16
  1284. transformers/models/siglip2/image_processing_siglip2_fast.py +6 -7
  1285. transformers/models/siglip2/modeling_siglip2.py +89 -130
  1286. transformers/models/siglip2/modular_siglip2.py +95 -48
  1287. transformers/models/siglip2/processing_siglip2.py +2 -14
  1288. transformers/models/siglip2/tokenization_siglip2.py +95 -0
  1289. transformers/models/smollm3/configuration_smollm3.py +29 -32
  1290. transformers/models/smollm3/modeling_smollm3.py +35 -38
  1291. transformers/models/smollm3/modular_smollm3.py +36 -38
  1292. transformers/models/smolvlm/configuration_smolvlm.py +2 -4
  1293. transformers/models/smolvlm/image_processing_smolvlm.py +42 -43
  1294. transformers/models/smolvlm/image_processing_smolvlm_fast.py +41 -15
  1295. transformers/models/smolvlm/modeling_smolvlm.py +124 -96
  1296. transformers/models/smolvlm/modular_smolvlm.py +50 -39
  1297. transformers/models/smolvlm/processing_smolvlm.py +15 -76
  1298. transformers/models/smolvlm/video_processing_smolvlm.py +16 -17
  1299. transformers/models/solar_open/__init__.py +27 -0
  1300. transformers/models/solar_open/configuration_solar_open.py +184 -0
  1301. transformers/models/solar_open/modeling_solar_open.py +642 -0
  1302. transformers/models/solar_open/modular_solar_open.py +224 -0
  1303. transformers/models/speech_encoder_decoder/configuration_speech_encoder_decoder.py +0 -1
  1304. transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +26 -27
  1305. transformers/models/speech_to_text/configuration_speech_to_text.py +9 -9
  1306. transformers/models/speech_to_text/feature_extraction_speech_to_text.py +10 -13
  1307. transformers/models/speech_to_text/modeling_speech_to_text.py +55 -57
  1308. transformers/models/speech_to_text/processing_speech_to_text.py +4 -30
  1309. transformers/models/speech_to_text/tokenization_speech_to_text.py +5 -6
  1310. transformers/models/speecht5/configuration_speecht5.py +7 -9
  1311. transformers/models/speecht5/feature_extraction_speecht5.py +16 -37
  1312. transformers/models/speecht5/modeling_speecht5.py +172 -174
  1313. transformers/models/speecht5/number_normalizer.py +0 -1
  1314. transformers/models/speecht5/processing_speecht5.py +3 -37
  1315. transformers/models/speecht5/tokenization_speecht5.py +4 -5
  1316. transformers/models/splinter/configuration_splinter.py +6 -7
  1317. transformers/models/splinter/modeling_splinter.py +62 -59
  1318. transformers/models/splinter/tokenization_splinter.py +2 -4
  1319. transformers/models/squeezebert/configuration_squeezebert.py +14 -2
  1320. transformers/models/squeezebert/modeling_squeezebert.py +60 -62
  1321. transformers/models/squeezebert/tokenization_squeezebert.py +0 -1
  1322. transformers/models/stablelm/configuration_stablelm.py +28 -29
  1323. transformers/models/stablelm/modeling_stablelm.py +44 -47
  1324. transformers/models/starcoder2/configuration_starcoder2.py +30 -27
  1325. transformers/models/starcoder2/modeling_starcoder2.py +38 -41
  1326. transformers/models/starcoder2/modular_starcoder2.py +17 -19
  1327. transformers/models/superglue/configuration_superglue.py +7 -3
  1328. transformers/models/superglue/image_processing_superglue.py +15 -15
  1329. transformers/models/superglue/image_processing_superglue_fast.py +8 -8
  1330. transformers/models/superglue/modeling_superglue.py +41 -37
  1331. transformers/models/superpoint/image_processing_superpoint.py +15 -15
  1332. transformers/models/superpoint/image_processing_superpoint_fast.py +7 -9
  1333. transformers/models/superpoint/modeling_superpoint.py +17 -16
  1334. transformers/models/swiftformer/configuration_swiftformer.py +0 -1
  1335. transformers/models/swiftformer/modeling_swiftformer.py +12 -14
  1336. transformers/models/swin/configuration_swin.py +2 -5
  1337. transformers/models/swin/modeling_swin.py +69 -78
  1338. transformers/models/swin2sr/configuration_swin2sr.py +0 -1
  1339. transformers/models/swin2sr/image_processing_swin2sr.py +10 -13
  1340. transformers/models/swin2sr/image_processing_swin2sr_fast.py +4 -7
  1341. transformers/models/swin2sr/modeling_swin2sr.py +30 -30
  1342. transformers/models/swinv2/configuration_swinv2.py +2 -5
  1343. transformers/models/swinv2/modeling_swinv2.py +65 -74
  1344. transformers/models/switch_transformers/configuration_switch_transformers.py +11 -7
  1345. transformers/models/switch_transformers/modeling_switch_transformers.py +35 -36
  1346. transformers/models/switch_transformers/modular_switch_transformers.py +32 -33
  1347. transformers/models/t5/configuration_t5.py +9 -9
  1348. transformers/models/t5/modeling_t5.py +80 -85
  1349. transformers/models/t5/tokenization_t5.py +1 -3
  1350. transformers/models/t5gemma/configuration_t5gemma.py +43 -59
  1351. transformers/models/t5gemma/modeling_t5gemma.py +105 -108
  1352. transformers/models/t5gemma/modular_t5gemma.py +128 -142
  1353. transformers/models/t5gemma2/configuration_t5gemma2.py +86 -100
  1354. transformers/models/t5gemma2/modeling_t5gemma2.py +234 -194
  1355. transformers/models/t5gemma2/modular_t5gemma2.py +279 -264
  1356. transformers/models/table_transformer/configuration_table_transformer.py +18 -50
  1357. transformers/models/table_transformer/modeling_table_transformer.py +73 -101
  1358. transformers/models/tapas/configuration_tapas.py +12 -2
  1359. transformers/models/tapas/modeling_tapas.py +65 -67
  1360. transformers/models/tapas/tokenization_tapas.py +116 -153
  1361. transformers/models/textnet/configuration_textnet.py +4 -7
  1362. transformers/models/textnet/image_processing_textnet.py +22 -25
  1363. transformers/models/textnet/image_processing_textnet_fast.py +8 -9
  1364. transformers/models/textnet/modeling_textnet.py +28 -28
  1365. transformers/models/time_series_transformer/configuration_time_series_transformer.py +5 -8
  1366. transformers/models/time_series_transformer/modeling_time_series_transformer.py +82 -84
  1367. transformers/models/timesfm/configuration_timesfm.py +0 -1
  1368. transformers/models/timesfm/modeling_timesfm.py +22 -25
  1369. transformers/models/timesfm/modular_timesfm.py +21 -24
  1370. transformers/models/timesformer/configuration_timesformer.py +0 -1
  1371. transformers/models/timesformer/modeling_timesformer.py +13 -16
  1372. transformers/models/timm_backbone/configuration_timm_backbone.py +33 -8
  1373. transformers/models/timm_backbone/modeling_timm_backbone.py +25 -30
  1374. transformers/models/timm_wrapper/configuration_timm_wrapper.py +2 -3
  1375. transformers/models/timm_wrapper/image_processing_timm_wrapper.py +4 -5
  1376. transformers/models/timm_wrapper/modeling_timm_wrapper.py +22 -19
  1377. transformers/models/trocr/configuration_trocr.py +11 -8
  1378. transformers/models/trocr/modeling_trocr.py +42 -42
  1379. transformers/models/trocr/processing_trocr.py +5 -25
  1380. transformers/models/tvp/configuration_tvp.py +10 -36
  1381. transformers/models/tvp/image_processing_tvp.py +50 -52
  1382. transformers/models/tvp/image_processing_tvp_fast.py +15 -15
  1383. transformers/models/tvp/modeling_tvp.py +26 -28
  1384. transformers/models/tvp/processing_tvp.py +2 -14
  1385. transformers/models/udop/configuration_udop.py +16 -8
  1386. transformers/models/udop/modeling_udop.py +73 -72
  1387. transformers/models/udop/processing_udop.py +7 -26
  1388. transformers/models/udop/tokenization_udop.py +80 -93
  1389. transformers/models/umt5/configuration_umt5.py +8 -7
  1390. transformers/models/umt5/modeling_umt5.py +87 -84
  1391. transformers/models/unispeech/configuration_unispeech.py +4 -2
  1392. transformers/models/unispeech/modeling_unispeech.py +54 -53
  1393. transformers/models/unispeech/modular_unispeech.py +20 -22
  1394. transformers/models/unispeech_sat/configuration_unispeech_sat.py +4 -2
  1395. transformers/models/unispeech_sat/modeling_unispeech_sat.py +70 -69
  1396. transformers/models/unispeech_sat/modular_unispeech_sat.py +21 -23
  1397. transformers/models/univnet/feature_extraction_univnet.py +14 -14
  1398. transformers/models/univnet/modeling_univnet.py +7 -8
  1399. transformers/models/upernet/configuration_upernet.py +8 -36
  1400. transformers/models/upernet/modeling_upernet.py +11 -14
  1401. transformers/models/vaultgemma/__init__.py +0 -1
  1402. transformers/models/vaultgemma/configuration_vaultgemma.py +29 -33
  1403. transformers/models/vaultgemma/modeling_vaultgemma.py +38 -40
  1404. transformers/models/vaultgemma/modular_vaultgemma.py +29 -31
  1405. transformers/models/video_llama_3/configuration_video_llama_3.py +4 -0
  1406. transformers/models/video_llama_3/image_processing_video_llama_3.py +40 -40
  1407. transformers/models/video_llama_3/image_processing_video_llama_3_fast.py +12 -14
  1408. transformers/models/video_llama_3/modeling_video_llama_3.py +149 -112
  1409. transformers/models/video_llama_3/modular_video_llama_3.py +152 -150
  1410. transformers/models/video_llama_3/processing_video_llama_3.py +5 -39
  1411. transformers/models/video_llama_3/video_processing_video_llama_3.py +45 -24
  1412. transformers/models/video_llava/configuration_video_llava.py +4 -1
  1413. transformers/models/video_llava/image_processing_video_llava.py +35 -38
  1414. transformers/models/video_llava/modeling_video_llava.py +139 -143
  1415. transformers/models/video_llava/processing_video_llava.py +38 -78
  1416. transformers/models/video_llava/video_processing_video_llava.py +0 -1
  1417. transformers/models/videomae/configuration_videomae.py +0 -1
  1418. transformers/models/videomae/image_processing_videomae.py +31 -34
  1419. transformers/models/videomae/modeling_videomae.py +17 -20
  1420. transformers/models/videomae/video_processing_videomae.py +0 -1
  1421. transformers/models/vilt/configuration_vilt.py +4 -2
  1422. transformers/models/vilt/image_processing_vilt.py +29 -30
  1423. transformers/models/vilt/image_processing_vilt_fast.py +15 -16
  1424. transformers/models/vilt/modeling_vilt.py +103 -90
  1425. transformers/models/vilt/processing_vilt.py +2 -14
  1426. transformers/models/vipllava/configuration_vipllava.py +4 -1
  1427. transformers/models/vipllava/modeling_vipllava.py +92 -67
  1428. transformers/models/vipllava/modular_vipllava.py +78 -54
  1429. transformers/models/vision_encoder_decoder/configuration_vision_encoder_decoder.py +0 -1
  1430. transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +28 -27
  1431. transformers/models/vision_text_dual_encoder/configuration_vision_text_dual_encoder.py +0 -1
  1432. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +45 -41
  1433. transformers/models/vision_text_dual_encoder/processing_vision_text_dual_encoder.py +2 -16
  1434. transformers/models/visual_bert/configuration_visual_bert.py +6 -2
  1435. transformers/models/visual_bert/modeling_visual_bert.py +90 -92
  1436. transformers/models/vit/configuration_vit.py +2 -3
  1437. transformers/models/vit/image_processing_vit.py +19 -22
  1438. transformers/models/vit/image_processing_vit_fast.py +0 -1
  1439. transformers/models/vit/modeling_vit.py +20 -20
  1440. transformers/models/vit_mae/configuration_vit_mae.py +0 -1
  1441. transformers/models/vit_mae/modeling_vit_mae.py +32 -30
  1442. transformers/models/vit_msn/configuration_vit_msn.py +0 -1
  1443. transformers/models/vit_msn/modeling_vit_msn.py +21 -19
  1444. transformers/models/vitdet/configuration_vitdet.py +2 -5
  1445. transformers/models/vitdet/modeling_vitdet.py +14 -17
  1446. transformers/models/vitmatte/configuration_vitmatte.py +7 -39
  1447. transformers/models/vitmatte/image_processing_vitmatte.py +15 -18
  1448. transformers/models/vitmatte/image_processing_vitmatte_fast.py +16 -17
  1449. transformers/models/vitmatte/modeling_vitmatte.py +10 -12
  1450. transformers/models/vitpose/configuration_vitpose.py +7 -47
  1451. transformers/models/vitpose/image_processing_vitpose.py +24 -25
  1452. transformers/models/vitpose/image_processing_vitpose_fast.py +9 -10
  1453. transformers/models/vitpose/modeling_vitpose.py +15 -15
  1454. transformers/models/vitpose_backbone/configuration_vitpose_backbone.py +2 -5
  1455. transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +13 -16
  1456. transformers/models/vits/configuration_vits.py +4 -1
  1457. transformers/models/vits/modeling_vits.py +43 -42
  1458. transformers/models/vits/tokenization_vits.py +3 -4
  1459. transformers/models/vivit/configuration_vivit.py +0 -1
  1460. transformers/models/vivit/image_processing_vivit.py +36 -39
  1461. transformers/models/vivit/modeling_vivit.py +9 -11
  1462. transformers/models/vjepa2/__init__.py +0 -1
  1463. transformers/models/vjepa2/configuration_vjepa2.py +0 -1
  1464. transformers/models/vjepa2/modeling_vjepa2.py +39 -41
  1465. transformers/models/vjepa2/video_processing_vjepa2.py +0 -1
  1466. transformers/models/voxtral/__init__.py +0 -1
  1467. transformers/models/voxtral/configuration_voxtral.py +0 -2
  1468. transformers/models/voxtral/modeling_voxtral.py +41 -48
  1469. transformers/models/voxtral/modular_voxtral.py +35 -38
  1470. transformers/models/voxtral/processing_voxtral.py +25 -48
  1471. transformers/models/wav2vec2/configuration_wav2vec2.py +4 -2
  1472. transformers/models/wav2vec2/feature_extraction_wav2vec2.py +7 -10
  1473. transformers/models/wav2vec2/modeling_wav2vec2.py +74 -126
  1474. transformers/models/wav2vec2/processing_wav2vec2.py +6 -35
  1475. transformers/models/wav2vec2/tokenization_wav2vec2.py +20 -332
  1476. transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py +4 -2
  1477. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +49 -52
  1478. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +45 -48
  1479. transformers/models/wav2vec2_bert/processing_wav2vec2_bert.py +6 -35
  1480. transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py +4 -2
  1481. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +62 -65
  1482. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +15 -18
  1483. transformers/models/wav2vec2_phoneme/tokenization_wav2vec2_phoneme.py +16 -17
  1484. transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py +36 -55
  1485. transformers/models/wavlm/configuration_wavlm.py +4 -2
  1486. transformers/models/wavlm/modeling_wavlm.py +49 -49
  1487. transformers/models/wavlm/modular_wavlm.py +4 -5
  1488. transformers/models/whisper/configuration_whisper.py +6 -5
  1489. transformers/models/whisper/english_normalizer.py +3 -4
  1490. transformers/models/whisper/feature_extraction_whisper.py +9 -24
  1491. transformers/models/whisper/generation_whisper.py +26 -49
  1492. transformers/models/whisper/modeling_whisper.py +71 -73
  1493. transformers/models/whisper/processing_whisper.py +3 -20
  1494. transformers/models/whisper/tokenization_whisper.py +9 -30
  1495. transformers/models/x_clip/configuration_x_clip.py +4 -2
  1496. transformers/models/x_clip/modeling_x_clip.py +94 -96
  1497. transformers/models/x_clip/processing_x_clip.py +2 -14
  1498. transformers/models/xcodec/configuration_xcodec.py +4 -6
  1499. transformers/models/xcodec/modeling_xcodec.py +15 -17
  1500. transformers/models/xglm/configuration_xglm.py +9 -8
  1501. transformers/models/xglm/modeling_xglm.py +49 -55
  1502. transformers/models/xglm/tokenization_xglm.py +1 -4
  1503. transformers/models/xlm/configuration_xlm.py +10 -8
  1504. transformers/models/xlm/modeling_xlm.py +127 -131
  1505. transformers/models/xlm/tokenization_xlm.py +3 -5
  1506. transformers/models/xlm_roberta/configuration_xlm_roberta.py +11 -3
  1507. transformers/models/xlm_roberta/modeling_xlm_roberta.py +96 -98
  1508. transformers/models/xlm_roberta/modular_xlm_roberta.py +50 -53
  1509. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +1 -4
  1510. transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py +10 -2
  1511. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +97 -99
  1512. transformers/models/xlm_roberta_xl/modular_xlm_roberta_xl.py +67 -70
  1513. transformers/models/xlnet/configuration_xlnet.py +3 -12
  1514. transformers/models/xlnet/modeling_xlnet.py +149 -162
  1515. transformers/models/xlnet/tokenization_xlnet.py +1 -4
  1516. transformers/models/xlstm/configuration_xlstm.py +8 -12
  1517. transformers/models/xlstm/modeling_xlstm.py +61 -96
  1518. transformers/models/xmod/configuration_xmod.py +11 -3
  1519. transformers/models/xmod/modeling_xmod.py +111 -116
  1520. transformers/models/yolos/configuration_yolos.py +0 -1
  1521. transformers/models/yolos/image_processing_yolos.py +60 -62
  1522. transformers/models/yolos/image_processing_yolos_fast.py +42 -45
  1523. transformers/models/yolos/modeling_yolos.py +19 -21
  1524. transformers/models/yolos/modular_yolos.py +17 -19
  1525. transformers/models/yoso/configuration_yoso.py +8 -2
  1526. transformers/models/yoso/modeling_yoso.py +60 -62
  1527. transformers/models/youtu/__init__.py +27 -0
  1528. transformers/models/youtu/configuration_youtu.py +194 -0
  1529. transformers/models/youtu/modeling_youtu.py +619 -0
  1530. transformers/models/youtu/modular_youtu.py +254 -0
  1531. transformers/models/zamba/configuration_zamba.py +5 -8
  1532. transformers/models/zamba/modeling_zamba.py +93 -125
  1533. transformers/models/zamba2/configuration_zamba2.py +44 -50
  1534. transformers/models/zamba2/modeling_zamba2.py +137 -165
  1535. transformers/models/zamba2/modular_zamba2.py +79 -74
  1536. transformers/models/zoedepth/configuration_zoedepth.py +17 -41
  1537. transformers/models/zoedepth/image_processing_zoedepth.py +28 -29
  1538. transformers/models/zoedepth/image_processing_zoedepth_fast.py +20 -21
  1539. transformers/models/zoedepth/modeling_zoedepth.py +19 -19
  1540. transformers/pipelines/__init__.py +47 -106
  1541. transformers/pipelines/any_to_any.py +15 -23
  1542. transformers/pipelines/audio_utils.py +1 -2
  1543. transformers/pipelines/automatic_speech_recognition.py +0 -2
  1544. transformers/pipelines/base.py +13 -17
  1545. transformers/pipelines/image_text_to_text.py +1 -2
  1546. transformers/pipelines/question_answering.py +4 -43
  1547. transformers/pipelines/text_classification.py +1 -14
  1548. transformers/pipelines/text_to_audio.py +5 -1
  1549. transformers/pipelines/token_classification.py +1 -22
  1550. transformers/pipelines/video_classification.py +1 -9
  1551. transformers/pipelines/zero_shot_audio_classification.py +0 -1
  1552. transformers/pipelines/zero_shot_classification.py +0 -6
  1553. transformers/pipelines/zero_shot_image_classification.py +0 -7
  1554. transformers/processing_utils.py +128 -137
  1555. transformers/pytorch_utils.py +2 -26
  1556. transformers/quantizers/base.py +10 -0
  1557. transformers/quantizers/quantizer_compressed_tensors.py +7 -5
  1558. transformers/quantizers/quantizer_fbgemm_fp8.py +20 -23
  1559. transformers/quantizers/quantizer_finegrained_fp8.py +14 -20
  1560. transformers/quantizers/quantizer_mxfp4.py +1 -1
  1561. transformers/quantizers/quantizer_quark.py +0 -1
  1562. transformers/quantizers/quantizer_torchao.py +3 -19
  1563. transformers/safetensors_conversion.py +11 -4
  1564. transformers/testing_utils.py +6 -65
  1565. transformers/tokenization_mistral_common.py +563 -903
  1566. transformers/tokenization_python.py +6 -4
  1567. transformers/tokenization_utils_base.py +228 -341
  1568. transformers/tokenization_utils_sentencepiece.py +5 -6
  1569. transformers/tokenization_utils_tokenizers.py +36 -7
  1570. transformers/trainer.py +30 -41
  1571. transformers/trainer_jit_checkpoint.py +1 -2
  1572. transformers/trainer_seq2seq.py +1 -1
  1573. transformers/training_args.py +414 -420
  1574. transformers/utils/__init__.py +1 -4
  1575. transformers/utils/attention_visualizer.py +1 -1
  1576. transformers/utils/auto_docstring.py +567 -18
  1577. transformers/utils/backbone_utils.py +13 -373
  1578. transformers/utils/doc.py +4 -36
  1579. transformers/utils/dummy_pt_objects.py +0 -42
  1580. transformers/utils/generic.py +70 -34
  1581. transformers/utils/import_utils.py +72 -75
  1582. transformers/utils/loading_report.py +135 -107
  1583. transformers/utils/quantization_config.py +8 -31
  1584. transformers/video_processing_utils.py +24 -25
  1585. transformers/video_utils.py +21 -23
  1586. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/METADATA +120 -239
  1587. transformers-5.1.0.dist-info/RECORD +2092 -0
  1588. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/WHEEL +1 -1
  1589. transformers/pipelines/deprecated/text2text_generation.py +0 -408
  1590. transformers/pipelines/image_to_text.py +0 -229
  1591. transformers-5.0.0rc2.dist-info/RECORD +0 -2042
  1592. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/entry_points.txt +0 -0
  1593. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/licenses/LICENSE +0 -0
  1594. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/top_level.txt +0 -0
transformers/models/glm_image/modeling_glm_image.py
@@ -0,0 +1,1691 @@
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+ # This file was automatically generated from src/transformers/models/glm_image/modular_glm_image.py.
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
+ # the file from the modular. If any change should be done, please apply the change to the
+ # modular_glm_image.py file directly. One of our CI enforces this.
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+ # Copyright 2025 the HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from collections.abc import Callable
+ from dataclasses import dataclass
+ from typing import Any, Optional
+
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ from ...activations import ACT2FN
+ from ...cache_utils import Cache, DynamicCache
+ from ...generation import GenerationMixin
+ from ...integrations import use_kernel_forward_from_hub, use_kernelized_func
+ from ...masking_utils import create_causal_mask
+ from ...modeling_flash_attention_utils import FlashAttentionKwargs
+ from ...modeling_layers import GradientCheckpointingLayer
+ from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput
+ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+ from ...processing_utils import Unpack
+ from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_available
+ from ...utils.generic import check_model_inputs, maybe_autocast
+ from .configuration_glm_image import GlmImageConfig, GlmImageTextConfig, GlmImageVisionConfig, GlmImageVQVAEConfig
+
+
+ if is_torch_available():
+     import torch
+
+
+ class GlmImageVisionMLP(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.config = config
+         self.activation_fn = ACT2FN[config.hidden_act]
+         self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
+         self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
+
+     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+         hidden_states = self.fc1(hidden_states)
+         hidden_states = self.activation_fn(hidden_states)
+         hidden_states = self.fc2(hidden_states)
+         return hidden_states
+
+
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+     """
+     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+     num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+     """
+     batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+     if n_rep == 1:
+         return hidden_states
+     hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
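The repeat_kv docstring above claims equivalence with torch.repeat_interleave along the key/value-head dimension; the following standalone sketch (not part of the package diff, with shapes invented for illustration) checks that the expand-and-reshape trick produces exactly that layout.

import torch

# Invented shapes, chosen only for illustration.
batch, num_key_value_heads, seqlen, head_dim, n_rep = 2, 4, 6, 8, 3
kv = torch.randn(batch, num_key_value_heads, seqlen, head_dim)

# Same expand + reshape as repeat_kv above: each KV head is duplicated n_rep times,
# turning (batch, kv_heads, seq, dim) into (batch, kv_heads * n_rep, seq, dim).
repeated = kv[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, seqlen, head_dim)
repeated = repeated.reshape(batch, num_key_value_heads * n_rep, seqlen, head_dim)

assert torch.equal(repeated, torch.repeat_interleave(kv, n_rep, dim=1))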
+ def eager_attention_forward(
+     module: nn.Module,
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     attention_mask: torch.Tensor | None,
+     scaling: float,
+     dropout: float = 0.0,
+     **kwargs: Unpack[TransformersKwargs],
+ ):
+     key_states = repeat_kv(key, module.num_key_value_groups)
+     value_states = repeat_kv(value, module.num_key_value_groups)
+
+     attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+     if attention_mask is not None:
+         causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+         attn_weights = attn_weights + causal_mask
+
+     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+     attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+     attn_output = torch.matmul(attn_weights, value_states)
+     attn_output = attn_output.transpose(1, 2).contiguous()
+
+     return attn_output, attn_weights
+
+
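As a rough check of the eager path above (a sketch under the assumption that eager_attention_forward from this file is in scope; the dummy module and shapes are invented), a single key/value group should reproduce PyTorch's scaled_dot_product_attention up to numerical noise.

import torch
import torch.nn as nn
import torch.nn.functional as F

class DummyAttention(nn.Module):
    # Only the attributes eager_attention_forward reads: the grouping factor and the training flag.
    def __init__(self):
        super().__init__()
        self.num_key_value_groups = 1

module = DummyAttention().eval()
q = torch.randn(1, 4, 10, 16)   # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 4, 10, 16)
v = torch.randn(1, 4, 10, 16)
scaling = 16 ** -0.5

out, weights = eager_attention_forward(module, q, k, v, attention_mask=None, scaling=scaling)
ref = F.scaled_dot_product_attention(q, k, v, scale=scaling)

# eager_attention_forward returns (batch, seq, heads, head_dim); transpose back before comparing.
assert torch.allclose(out.transpose(1, 2), ref, atol=1e-5)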
101
+ class GlmImageVisionAttention(nn.Module):
102
+ def __init__(self, config: GlmImageVisionConfig) -> None:
103
+ super().__init__()
104
+ self.dim = config.hidden_size
105
+ self.num_heads = config.num_heads
106
+ self.head_dim = self.dim // self.num_heads
107
+ self.num_key_value_groups = 1 # needed for eager attention
108
+ self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.attention_bias)
109
+ self.proj = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
110
+ self.scaling = self.head_dim**-0.5
111
+ self.config = config
112
+ self.attention_dropout = config.attention_dropout
113
+ self.is_causal = False
114
+
115
+ def forward(
116
+ self,
117
+ hidden_states: torch.Tensor,
118
+ cu_seqlens: torch.Tensor,
119
+ **kwargs,
120
+ ) -> torch.Tensor:
121
+ seq_length = hidden_states.shape[0]
122
+ query_states, key_states, value_states = (
123
+ self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
124
+ )
125
+ query_states = query_states.transpose(0, 1).unsqueeze(0)
126
+ key_states = key_states.transpose(0, 1).unsqueeze(0)
127
+ value_states = value_states.transpose(0, 1).unsqueeze(0)
128
+
129
+ attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
130
+ self.config._attn_implementation, eager_attention_forward
131
+ )
132
+
133
+ if "flash" in self.config._attn_implementation:
134
+ # Flash Attention: Use cu_seqlens for variable length attention
135
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
136
+ attn_output, _ = attention_interface(
137
+ self,
138
+ query_states,
139
+ key_states,
140
+ value_states,
141
+ attention_mask=None,
142
+ scaling=self.scaling,
143
+ dropout=0.0 if not self.training else self.attention_dropout,
144
+ cu_seq_lens_q=cu_seqlens,
145
+ cu_seq_lens_k=cu_seqlens,
146
+ max_length_q=max_seqlen,
147
+ max_length_k=max_seqlen,
148
+ is_causal=False,
149
+ **kwargs,
150
+ )
151
+ else:
152
+ # Other implementations: Process each chunk separately
153
+ lengths = cu_seqlens[1:] - cu_seqlens[:-1]
154
+ splits = [
155
+ torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
156
+ ]
157
+
158
+ attn_outputs = [
159
+ attention_interface(
160
+ self,
161
+ q,
162
+ k,
163
+ v,
164
+ attention_mask=None,
165
+ scaling=self.scaling,
166
+ dropout=0.0 if not self.training else self.attention_dropout,
167
+ is_causal=False,
168
+ **kwargs,
169
+ )[0]
170
+ for q, k, v in zip(*splits)
171
+ ]
172
+ attn_output = torch.cat(attn_outputs, dim=1)
173
+
174
+ attn_output = attn_output.reshape(seq_length, -1).contiguous()
175
+ attn_output = self.proj(attn_output)
176
+ return attn_output
177
+
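A minimal sketch of the chunked (non-flash) path in `GlmImageVisionAttention.forward` above: packed states are split along the sequence dimension using per-image lengths derived from `cu_seqlens`. Shapes are illustrative:

```python
import torch

# cu_seqlens marks boundaries of packed sequences: here two images of 4 and 6 patches
cu_seqlens = torch.tensor([0, 4, 10], dtype=torch.int32)
lengths = cu_seqlens[1:] - cu_seqlens[:-1]  # tensor([4, 6], dtype=torch.int32)

# Packed states shaped (1, num_heads, total_patches, head_dim), as in the module above
states = torch.randn(1, 2, 10, 8)
chunks = torch.split(states, lengths.tolist(), dim=2)
print([tuple(c.shape) for c in chunks])  # [(1, 2, 4, 8), (1, 2, 6, 8)]
```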
178
+
179
+ class GlmImageVisionPatchEmbed(nn.Module):
180
+ def __init__(self, config: GlmImageVisionConfig) -> None:
181
+ super().__init__()
182
+ self.patch_size = config.patch_size
183
+ self.in_channels = config.in_channels
184
+ self.embed_dim = config.hidden_size
185
+ kernel_size = [self.patch_size, self.patch_size]
186
+ self.proj = nn.Conv2d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size)
187
+
188
+ def forward(self, hidden_states) -> torch.Tensor:
189
+ target_dtype = self.proj.weight.dtype
190
+ hidden_states = hidden_states.view(-1, self.in_channels, self.patch_size, self.patch_size)
191
+ hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
192
+ return hidden_states
193
+
194
+
195
+ class GlmImageVisionEmbeddings(nn.Module):
196
+ def __init__(self, config: GlmImageVisionConfig) -> None:
197
+ super().__init__()
198
+ self.config = config
199
+ self.embed_dim = config.hidden_size
200
+ self.image_size = config.image_size
201
+ self.patch_size = config.patch_size
202
+
203
+ self.num_patches = (self.image_size // self.patch_size) ** 2
204
+ self.num_positions = self.num_patches
205
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
206
+ self.interpolated_method = "bilinear"
207
+
208
+ def forward(self, embeddings, lengths, image_shapes, h_coords, w_coords) -> torch.Tensor:
209
+ """
210
+ Forward pass with integrated position encoding adaptation using 2D interpolation.
211
+
212
+ Args:
213
+ embeddings: Input embeddings tensor
214
+ lengths (torch.Tensor): Sequence lengths for each image in the batch.
215
+ image_shapes (torch.Tensor): Tensor of shape [batch_size, 3] representing the image shapes (t, h, w).
216
+ h_coords (torch.Tensor): Tensor of shape [total_seq] representing the h coordinate for each patch.
217
+ w_coords (torch.Tensor): Tensor of shape [total_seq] representing the w coordinate for each patch.
218
+
219
+ Returns:
220
+ torch.Tensor: Embeddings with adapted position encoding added.
221
+ """
222
+ # Get position embedding parameters
223
+ pos_embed_weight = self.position_embedding.weight
224
+ hidden_size = pos_embed_weight.shape[1]
225
+ device = pos_embed_weight.device
226
+
227
+ # Convert inputs to tensors if needed
228
+ if isinstance(lengths, list):
229
+ lengths = torch.tensor(lengths, device=device, dtype=torch.long)
230
+
231
+ # Prepare 2D position embedding
232
+ orig_size_sq = pos_embed_weight.shape[0]
233
+ orig_size = int(orig_size_sq**0.5)
234
+ pos_embed_2d = (
235
+ pos_embed_weight.view(orig_size, orig_size, hidden_size)
236
+ .permute(2, 0, 1)
237
+ .unsqueeze(0)
238
+ .to(device=device, dtype=torch.float32)
239
+ )
240
+
241
+ # Calculate target dimensions for each patch
242
+ target_h = torch.cat([image_shapes[i, 1].repeat(lengths[i]) for i in range(len(lengths))]).to(
243
+ device=device, dtype=torch.float32
244
+ )
245
+ target_w = torch.cat([image_shapes[i, 2].repeat(lengths[i]) for i in range(len(lengths))]).to(
246
+ device=device, dtype=torch.float32
247
+ )
248
+
249
+ # Normalize coordinates to [-1, 1] range for grid_sample
250
+ norm_w = ((w_coords + 0.5) / target_w) * 2 - 1
251
+ norm_h = ((h_coords + 0.5) / target_h) * 2 - 1
252
+
253
+ # Create sampling grid
254
+ grid = torch.stack((norm_w, norm_h), dim=-1).unsqueeze(0).unsqueeze(2)
255
+
256
+ # Perform 2D interpolation with grid_sample (mode given by self.interpolated_method, i.e. bilinear)
257
+ interpolated_embed_fp32 = F.grid_sample(
258
+ pos_embed_2d, grid, mode=self.interpolated_method, align_corners=False, padding_mode="border"
259
+ )
260
+
261
+ # Reshape and convert back to original dtype
262
+ adapted_pos_embed_fp32 = interpolated_embed_fp32.squeeze(0).squeeze(-1).permute(1, 0)
263
+ adapted_pos_embed = adapted_pos_embed_fp32.to(pos_embed_weight.dtype).to(embeddings.device)
264
+
265
+ # Add adapted position encoding to embeddings
266
+ embeddings = embeddings + adapted_pos_embed
267
+ return embeddings
268
+
269
+
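A rough sketch of the position-embedding resampling in `GlmImageVisionEmbeddings.forward` above: the learned position grid is resampled with `F.grid_sample` at normalized patch coordinates. Grid sizes are toy values:

```python
import torch
import torch.nn.functional as F

hidden_size, orig_size = 16, 4  # learned embedding covers a 4x4 position grid
pos_embed = torch.randn(orig_size * orig_size, hidden_size)
pos_embed_2d = pos_embed.view(orig_size, orig_size, hidden_size).permute(2, 0, 1).unsqueeze(0)

# Resample for a 2x3 patch grid; coordinates are normalized to [-1, 1] as in the module
h_coords = torch.arange(2).repeat_interleave(3).float()  # row index of each patch
w_coords = torch.arange(3).repeat(2).float()             # column index of each patch
norm_h = ((h_coords + 0.5) / 2) * 2 - 1
norm_w = ((w_coords + 0.5) / 3) * 2 - 1
grid = torch.stack((norm_w, norm_h), dim=-1).unsqueeze(0).unsqueeze(2)  # (1, 6, 1, 2)

sampled = F.grid_sample(pos_embed_2d, grid, mode="bilinear", align_corners=False, padding_mode="border")
adapted = sampled.squeeze(0).squeeze(-1).permute(1, 0)
print(adapted.shape)  # torch.Size([6, 16]): one adapted position vector per patch
```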
270
+ class GlmImageVisionBlock(GradientCheckpointingLayer):
271
+ def __init__(self, config: GlmImageVisionConfig) -> None:
272
+ super().__init__()
273
+ self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
274
+ self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
275
+ self.attn = GlmImageVisionAttention(config)
276
+ self.mlp = GlmImageVisionMLP(config)
277
+
278
+ def forward(
279
+ self,
280
+ hidden_states: torch.Tensor,
281
+ cu_seqlens: torch.Tensor,
282
+ **kwargs: Unpack[TransformersKwargs],
283
+ ) -> torch.Tensor:
284
+ r"""
285
+ cu_seqlens (`torch.Tensor` of shape `(num_images_or_videos + 1,)`):
286
+ The cumulative sequence lengths of each image or video feature.
287
289
+ """
290
+ residual = hidden_states
291
+
292
+ hidden_states = self.norm1(hidden_states)
293
+ hidden_states = self.attn(
294
+ hidden_states,
295
+ cu_seqlens=cu_seqlens,
296
+ **kwargs,
297
+ )
298
+ hidden_states = residual + hidden_states
299
+
300
+ residual = hidden_states
301
+ hidden_states = self.norm2(hidden_states)
302
+ hidden_states = self.mlp(hidden_states)
303
+ hidden_states = residual + hidden_states
304
+
305
+ return hidden_states
306
+
307
+
308
+ def rotate_half(x):
309
+ """Rotates half the hidden dims of the input."""
310
+ x1 = x[..., : x.shape[-1] // 2]
311
+ x2 = x[..., x.shape[-1] // 2 :]
312
+ return torch.cat((-x2, x1), dim=-1)
313
+
314
+
315
+ def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
316
+ """Applies Rotary Position Embedding to the query and key tensors.
317
+
318
+ Args:
319
+ q (`torch.Tensor`): The query tensor.
320
+ k (`torch.Tensor`): The key tensor.
321
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
322
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
323
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
324
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
325
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
326
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
327
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
328
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
329
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
330
+ Returns:
331
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
332
+ """
333
+ cos = cos.unsqueeze(unsqueeze_dim)
334
+ sin = sin.unsqueeze(unsqueeze_dim)
335
+
336
+ # Keep half or full tensor for later concatenation
337
+ rotary_dim = cos.shape[-1]
338
+ q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
339
+ k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
340
+
341
+ # Apply rotary embeddings on the first half or full tensor
342
+ q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
343
+ k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
344
+
345
+ # Concatenate back to full shape
346
+ q_embed = torch.cat([q_embed, q_pass], dim=-1)
347
+ k_embed = torch.cat([k_embed, k_pass], dim=-1)
348
+ return q_embed, k_embed
349
+
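A minimal usage sketch of `apply_rotary_pos_emb` defined above, on toy tensors; it also checks that the rotation leaves per-position norms unchanged:

```python
import torch

batch, heads, seq_len, head_dim = 1, 2, 4, 8
q = torch.randn(batch, heads, seq_len, head_dim)
k = torch.randn(batch, heads, seq_len, head_dim)

# cos/sin of shape (batch, seq_len, head_dim), as a rotary embedding module would produce them
inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
angles = torch.arange(seq_len).float()[:, None] * inv_freq[None, :]
cos = torch.cat((angles.cos(), angles.cos()), dim=-1)[None]
sin = torch.cat((angles.sin(), angles.sin()), dim=-1)[None]

q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
print(q_rot.shape)  # torch.Size([1, 2, 4, 8])
print(torch.allclose(q_rot.norm(dim=-1), q.norm(dim=-1), atol=1e-5))  # True: pure rotation
```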
350
+
351
+ @use_kernelized_func(apply_rotary_pos_emb)
352
+ class GlmImageTextAttention(nn.Module):
353
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
354
+
355
+ def __init__(self, config: GlmImageTextConfig, layer_idx: int | None = None):
356
+ super().__init__()
357
+ self.config = config
358
+ self.layer_idx = layer_idx
359
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
360
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
361
+ self.scaling = self.head_dim**-0.5
362
+ self.attention_dropout = config.attention_dropout
363
+ self.is_causal = True
364
+
365
+ self.q_proj = nn.Linear(
366
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
367
+ )
368
+ self.k_proj = nn.Linear(
369
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
370
+ )
371
+ self.v_proj = nn.Linear(
372
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
373
+ )
374
+ self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
375
+ self.rope_parameters = config.rope_parameters
376
+
377
+ def forward(
378
+ self,
379
+ hidden_states: torch.Tensor,
380
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
381
+ attention_mask: torch.Tensor | None,
382
+ past_key_values: Cache | None = None,
383
+ cache_position: torch.LongTensor | None = None,
384
+ **kwargs: Unpack[FlashAttentionKwargs],
385
+ ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
386
+ input_shape = hidden_states.shape[:-1]
387
+ hidden_shape = (*input_shape, -1, self.head_dim)
388
+
389
+ query_states = self.q_proj(hidden_states).view(hidden_shape)
390
+ key_states = self.k_proj(hidden_states).view(hidden_shape)
391
+ value_states = self.v_proj(hidden_states).view(hidden_shape)
392
+
393
+ query_states = query_states.transpose(1, 2)
394
+ key_states = key_states.transpose(1, 2)
395
+ value_states = value_states.transpose(1, 2)
396
+
397
+ cos, sin = position_embeddings
398
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
399
+
400
+ if past_key_values is not None:
401
+ # sin and cos are specific to RoPE models; position_ids needed for the static cache
402
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
403
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
404
+
405
+ attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
406
+ self.config._attn_implementation, eager_attention_forward
407
+ )
408
+
409
+ attn_output, attn_weights = attention_interface(
410
+ self,
411
+ query_states,
412
+ key_states,
413
+ value_states,
414
+ attention_mask,
415
+ dropout=0.0 if not self.training else self.attention_dropout,
416
+ scaling=self.scaling,
417
+ **kwargs,
418
+ )
419
+
420
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
421
+ attn_output = self.o_proj(attn_output)
422
+ return attn_output, attn_weights
423
+
424
+
425
+ @use_kernel_forward_from_hub("RMSNorm")
426
+ class GlmImageRMSNorm(nn.Module):
427
+ def __init__(self, hidden_size, eps=1e-6):
428
+ """
429
+ GlmImageRMSNorm is equivalent to T5LayerNorm
430
+ """
431
+ super().__init__()
432
+ self.weight = nn.Parameter(torch.ones(hidden_size))
433
+ self.variance_epsilon = eps
434
+
435
+ def forward(self, hidden_states):
436
+ input_dtype = hidden_states.dtype
437
+ hidden_states = hidden_states.to(torch.float32)
438
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
439
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
440
+ return self.weight * hidden_states.to(input_dtype)
441
+
442
+ def extra_repr(self):
443
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
444
+
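A quick numeric sketch of what `GlmImageRMSNorm` above computes: with its default unit weights, the output has (approximately) unit root-mean-square over the last dimension:

```python
import torch

norm = GlmImageRMSNorm(hidden_size=8, eps=1e-6)  # weight initializes to ones
x = torch.randn(2, 3, 8) * 5.0
y = norm(x)

rms = y.pow(2).mean(dim=-1).sqrt()
print(torch.allclose(rms, torch.ones_like(rms), atol=1e-3))  # True
```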
445
+
446
+ class GlmImageTextMLP(nn.Module):
447
+ def __init__(self, config):
448
+ super().__init__()
449
+
450
+ self.config = config
451
+ self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False)
452
+ self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
453
+ self.activation_fn = ACT2FN[config.hidden_act]
454
+
455
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
456
+ up_states = self.gate_up_proj(hidden_states)
457
+
458
+ gate, up_states = up_states.chunk(2, dim=-1)
459
+ up_states = up_states * self.activation_fn(gate)
460
+
461
+ return self.down_proj(up_states)
462
+
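A small sketch of the fused gate/up projection used by `GlmImageTextMLP` above; sizes are arbitrary and SiLU stands in for whatever `config.hidden_act` selects:

```python
import torch
import torch.nn as nn

hidden_size, intermediate_size = 8, 16
gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)

x = torch.randn(2, hidden_size)
gate, up = gate_up_proj(x).chunk(2, dim=-1)   # each of shape (2, 16)
gated = up * nn.functional.silu(gate)         # element-wise gating before the down projection
print(gated.shape)                            # torch.Size([2, 16])
```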
463
+
464
+ class GlmImageTextDecoderLayer(GradientCheckpointingLayer):
465
+ def __init__(self, config: GlmImageTextConfig, layer_idx: int):
466
+ super().__init__()
467
+ self.hidden_size = config.hidden_size
468
+ self.self_attn = GlmImageTextAttention(config, layer_idx)
469
+ self.mlp = GlmImageTextMLP(config)
470
+ self.input_layernorm = GlmImageRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
471
+ self.post_attention_layernorm = GlmImageRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
472
+ self.post_self_attn_layernorm = GlmImageRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
473
+ self.post_mlp_layernorm = GlmImageRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
474
+
475
+ @auto_docstring
476
+ def forward(
477
+ self,
478
+ hidden_states: torch.Tensor,
479
+ position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
480
+ attention_mask: torch.Tensor | None = None,
481
+ position_ids: torch.LongTensor | None = None,
482
+ past_key_values: Cache | None = None,
483
+ use_cache: bool | None = False,
484
+ cache_position: torch.LongTensor | None = None,
485
+ **kwargs,
486
+ ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
487
+ residual = hidden_states
488
+
489
+ hidden_states = self.input_layernorm(hidden_states)
490
+
491
+ # Self Attention
492
+ hidden_states, _ = self.self_attn(
493
+ hidden_states=hidden_states,
494
+ position_embeddings=position_embeddings,
495
+ attention_mask=attention_mask,
496
+ position_ids=position_ids,
497
+ past_key_values=past_key_values,
498
+ use_cache=use_cache,
499
+ cache_position=cache_position,
500
+ **kwargs,
501
+ )
502
+
503
+ hidden_states = self.post_self_attn_layernorm(hidden_states)
504
+ hidden_states = residual + hidden_states
505
+
506
+ # Fully Connected
507
+ residual = hidden_states
508
+ hidden_states = self.post_attention_layernorm(hidden_states)
509
+ hidden_states = self.mlp(hidden_states)
510
+ hidden_states = self.post_mlp_layernorm(hidden_states)
511
+ hidden_states = residual + hidden_states
512
+
513
+ return hidden_states
514
+
515
+
516
+ @auto_docstring
517
+ class GlmImagePreTrainedModel(PreTrainedModel):
518
+ config: GlmImageConfig
519
+ base_model_prefix = "model"
520
+ input_modalities = ("image", "text")
521
+ supports_gradient_checkpointing = True
522
+ _no_split_modules = ["GlmImageTextDecoderLayer", "GlmImageVisionBlock"]
523
+ _skip_keys_device_placement = "past_key_values"
524
+ _supports_flash_attn = True
525
+ _supports_sdpa = True
526
+
527
+ _can_compile_fullgraph = True
528
+ _supports_attention_backend = True
529
+ _can_record_outputs = {
530
+ "hidden_states": GlmImageTextDecoderLayer,
531
+ "attentions": GlmImageTextAttention,
532
+ }
533
+
534
+ @torch.no_grad()
535
+ def _init_weights(self, module):
536
+ super()._init_weights(module)
537
+
538
+
539
+ @dataclass
540
+ @auto_docstring(
541
+ custom_intro="""
542
+ Base class for GlmImage outputs, with hidden states and attentions.
543
+ """
544
+ )
545
+ class GlmImageModelOutputWithPast(ModelOutput):
546
+ r"""
547
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
548
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
549
+
550
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
551
+ `past_key_values` input) to speed up sequential decoding.
552
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
553
+ The rope index difference between sequence length and multimodal rope.
554
+ """
555
+
556
+ last_hidden_state: torch.FloatTensor | None = None
557
+ past_key_values: Cache | None = None
558
+ hidden_states: tuple[torch.FloatTensor] | None = None
559
+ attentions: tuple[torch.FloatTensor] | None = None
560
+ rope_deltas: torch.LongTensor | None = None
561
+
562
+
563
+ class GlmImageVQVAEVectorQuantizer(nn.Module):
564
+ """
565
+ A module for vector quantization using learned embedding vectors.
566
+
567
+ This module implements the quantization process similar to the one described in
568
+ the VQ-VAE (Vector Quantized Variational AutoEncoder) paper. It quantizes continuous
569
+ input vectors into discrete codebook vectors, which are learned during training.
570
+ Current implementation improves over previous ones by avoiding costly matrix multiplications
571
+ and allowing for post-hoc remapping of indices.
572
+ """
573
+
574
+ def __init__(self, config: GlmImageVQVAEConfig):
575
+ super().__init__()
576
+ self.num_embeddings = config.num_embeddings
577
+ self.embedding_dim = config.embed_dim
578
+ self.beta = getattr(config, "beta", 0.25)
579
+
580
+ self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
581
+
582
+ def forward(self, hidden_state: torch.Tensor):
583
+ hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
584
+ hidden_state_flattened = hidden_state.view(-1, self.embedding_dim)
585
+
586
+ # L2 normalize
587
+ hidden_state = F.normalize(hidden_state, p=2, dim=-1)
588
+ hidden_state_flattened = F.normalize(hidden_state_flattened, p=2, dim=-1)
589
+ embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
590
+
591
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
592
+ distances = (
593
+ torch.sum(hidden_state_flattened**2, dim=1, keepdim=True)
594
+ + torch.sum(embedding**2, dim=1)
595
+ - 2 * torch.einsum("bd,dn->bn", hidden_state_flattened, embedding.transpose(0, 1))
596
+ )
597
+
598
+ min_encoding_indices = torch.argmin(distances, dim=1)
599
+ hidden_state_quant = embedding[min_encoding_indices].view(hidden_state.shape)
600
+
601
+ # compute loss for embedding
602
+ loss = torch.mean((hidden_state_quant.detach() - hidden_state) ** 2) + self.beta * torch.mean(
603
+ (hidden_state_quant - hidden_state.detach()) ** 2
604
+ )
605
+
606
+ # preserve gradients
607
+ hidden_state_quant = hidden_state + (hidden_state_quant - hidden_state).detach()
608
+
609
+ # reshape back to match original input shape
610
+ hidden_state_quant = hidden_state_quant.permute(0, 3, 1, 2).contiguous()
611
+
612
+ return hidden_state_quant, loss, min_encoding_indices
613
+
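A compact sketch of the three steps performed by the quantizer above (L2 normalization, nearest-code lookup by squared distance, straight-through gradient), using a toy codebook:

```python
import torch
import torch.nn.functional as F

codebook = F.normalize(torch.randn(16, 4), p=2, dim=-1)  # 16 codes of dimension 4
z = F.normalize(torch.randn(10, 4), p=2, dim=-1)          # 10 continuous vectors

# squared L2 distance: ||z||^2 + ||e||^2 - 2 z.e
distances = (z**2).sum(1, keepdim=True) + (codebook**2).sum(1) - 2 * z @ codebook.T
indices = distances.argmin(dim=1)   # discrete token ids
quantized = codebook[indices]

# straight-through estimator: forward pass uses the code, backward pass sees the identity
quantized = z + (quantized - z).detach()
print(indices.shape, quantized.shape)  # torch.Size([10]) torch.Size([10, 4])
```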
614
+
615
+ @dataclass
616
+ @auto_docstring
617
+ class GlmImageVQVAEModelOutput(BaseModelOutputWithPooling):
618
+ r"""
619
+ quantized_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
620
+ Quantized last hidden state from the VQ-VAE model.
621
+ image_tokens (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
622
+ Indices of the image tokens predicted by the VQ-VAE model.
623
+ embedding_loss (`torch.FloatTensor`):
624
+ The embedding loss computed during quantization.
625
+ """
626
+
627
+ quantized_last_hidden_state: torch.FloatTensor | None = None
628
+ image_tokens: torch.FloatTensor | None = None
629
+ embedding_loss: torch.FloatTensor | None = None
630
+
631
+
632
+ @auto_docstring(
633
+ custom_intro="""
634
+ The VQ-VAE model used in GlmImage for encoding/decoding images into discrete tokens.
635
+ This model follows the "Make-a-scene: Scene-based text-to-image generation with human priors" paper from
636
+ [Oran Gafni, Adam Polyak, Oron Ashual, Shelly Sheynin, Devi Parikh, and Yaniv
637
+ Taigman](https://huggingface.co/papers/2203.13131).
638
+ """
639
+ )
640
+ class GlmImageVQVAE(GlmImagePreTrainedModel):
641
+ config: GlmImageVQVAEConfig
642
+ _no_split_modules = [
643
+ "GlmImageVQVAEVectorQuantizer",
644
+ ]
645
+ _can_record_outputs = {}
646
+
647
+ def __init__(self, config: GlmImageVQVAEConfig):
648
+ super().__init__(config)
649
+ self.quantize = GlmImageVQVAEVectorQuantizer(config)
650
+ self.quant_conv = torch.nn.Conv2d(config.latent_channels, config.embed_dim, 1)
651
+ self.post_quant_conv = torch.nn.Conv2d(config.embed_dim, config.latent_channels, 1)
652
+ self.eval() # GlmImage's VQ model is frozen
653
+ self.post_init()
654
+
655
+ @check_model_inputs
656
+ def encode(self, hidden_states) -> GlmImageVQVAEModelOutput:
657
+ conv_hidden_states = self.quant_conv(hidden_states)
658
+ quantized_last_hidden_state, emb_loss, indices = self.quantize(conv_hidden_states)
659
+ return GlmImageVQVAEModelOutput(
660
+ last_hidden_state=hidden_states,
661
+ quantized_last_hidden_state=quantized_last_hidden_state,
662
+ image_tokens=indices,
663
+ embedding_loss=emb_loss,
664
+ )
665
+
666
+
667
+ class GlmImageVisionModel(GlmImagePreTrainedModel):
668
+ config: GlmImageVisionConfig
669
+ input_modalities = ("image",)
670
+ _no_split_modules = ["GlmImageVisionBlock"]
671
+ _can_record_outputs = {
672
+ "hidden_states": GlmImageVisionBlock,
673
+ "attentions": GlmImageVisionAttention,
674
+ }
675
+ main_input_name = "pixel_values"
676
+
677
+ def __init__(self, config: GlmImageVisionConfig) -> None:
678
+ super().__init__(config)
679
+ self.spatial_merge_size = config.spatial_merge_size
680
+ self.patch_size = config.patch_size
681
+
682
+ self.embeddings = GlmImageVisionEmbeddings(config)
683
+ self.patch_embed = GlmImageVisionPatchEmbed(config)
684
+
685
+ head_dim = config.hidden_size // config.num_heads
686
+
687
+ self.blocks = nn.ModuleList([GlmImageVisionBlock(config) for _ in range(config.depth)])
688
+
689
+ self.gradient_checkpointing = False
690
+ self.head_dim = head_dim
691
+ self.post_init()
692
+
693
+ def rot_pos_emb(self, grid_thw):
694
+ pos_ids = []
695
+ for t, h, w in grid_thw:
696
+ hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
697
+ hpos_ids = hpos_ids.reshape(
698
+ h // self.spatial_merge_size,
699
+ self.spatial_merge_size,
700
+ w // self.spatial_merge_size,
701
+ self.spatial_merge_size,
702
+ )
703
+ hpos_ids = hpos_ids.permute(0, 2, 1, 3)
704
+ hpos_ids = hpos_ids.flatten()
705
+
706
+ wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
707
+ wpos_ids = wpos_ids.reshape(
708
+ h // self.spatial_merge_size,
709
+ self.spatial_merge_size,
710
+ w // self.spatial_merge_size,
711
+ self.spatial_merge_size,
712
+ )
713
+ wpos_ids = wpos_ids.permute(0, 2, 1, 3)
714
+ wpos_ids = wpos_ids.flatten()
715
+ pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
716
+ pos_ids = torch.cat(pos_ids, dim=0)
717
+ return pos_ids
718
+
719
+ @check_model_inputs
720
+ @auto_docstring
721
+ def forward(
722
+ self, pixel_values: torch.Tensor, grid_thw: torch.Tensor, **kwargs: Unpack[TransformersKwargs]
723
+ ) -> tuple | BaseModelOutputWithPooling:
724
+ r"""
725
+ pixel_values (`torch.Tensor` of shape `(total_patches, num_channels * patch_size * patch_size)`):
726
+ Packed pixel values.
727
+ grid_thw (`torch.Tensor` of shape `(num_images, 3)`):
728
+ The temporal, height and width of feature shape of each image.
729
+
730
+ Returns:
731
+ `torch.Tensor` of shape `(total_patches, hidden_size)`: Hidden states.
732
+ """
733
+
734
+ hidden_states = self.patch_embed(pixel_values)
735
+ image_type_ids = self.rot_pos_emb(grid_thw)
736
+
737
+ cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
738
+ dim=0,
739
+ dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
740
+ )
741
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
742
+ seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
743
+ hidden_states = self.embeddings(
744
+ hidden_states,
745
+ seqlens,
746
+ grid_thw,
747
+ image_type_ids[:, 0].to(hidden_states.device),
748
+ image_type_ids[:, 1].to(hidden_states.device),
749
+ )
750
+
751
+ # Transformer blocks (no position_embeddings needed, already added above)
752
+ for blk in self.blocks:
753
+ hidden_states = blk(
754
+ hidden_states,
755
+ cu_seqlens=cu_seqlens,
756
+ )
757
+
758
+ return BaseModelOutputWithPooling(last_hidden_state=hidden_states)
759
+
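A worked example of the `cu_seqlens` bookkeeping in the vision model's forward above, for two assumed images of 2x3 and 4x4 patches:

```python
import torch
import torch.nn.functional as F

grid_thw = torch.tensor([[1, 2, 3], [1, 4, 4]])  # (t, h, w) per image -> 6 and 16 patches

cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
    dim=0, dtype=torch.int32
)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
print(cu_seqlens)                                    # tensor([ 0,  6, 22], dtype=torch.int32)
print((cu_seqlens[1:] - cu_seqlens[:-1]).tolist())   # per-image lengths: [6, 16]
```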
760
+
761
+ class GlmImageTextRotaryEmbedding(nn.Module):
762
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
763
+
764
+ def __init__(self, config: GlmImageTextConfig, device=None):
765
+ super().__init__()
766
+ self.max_seq_len_cached = config.max_position_embeddings
767
+ self.original_max_seq_len = config.max_position_embeddings
768
+
769
+ self.config = config
770
+
771
+ self.rope_type = self.config.rope_parameters["rope_type"]
772
+ rope_init_fn: Callable = self.compute_default_rope_parameters
773
+ if self.rope_type != "default":
774
+ rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
775
+ inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
776
+
777
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
778
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
779
+ self.mrope_section = config.rope_parameters.get("mrope_section", [8, 12, 12])
780
+
781
+ @staticmethod
782
+ def compute_default_rope_parameters(
783
+ config: GlmImageTextConfig | None = None,
784
+ device: Optional["torch.device"] = None,
785
+ seq_len: int | None = None,
786
+ ) -> tuple["torch.Tensor", float]:
787
+ """
788
+ Computes the inverse frequencies according to the original RoPE implementation
789
+ Args:
790
+ config ([`~transformers.PreTrainedConfig`]):
791
+ The model configuration.
792
+ device (`torch.device`):
793
+ The device to use for initialization of the inverse frequencies.
794
+ seq_len (`int`, *optional*):
795
+ The current sequence length. Unused for this type of RoPE.
796
+ Returns:
797
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
798
+ post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
799
+ """
800
+ base = config.rope_parameters["rope_theta"]
801
+ partial_rotary_factor = config.rope_parameters.get("partial_rotary_factor", 1.0)
802
+ head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
803
+ dim = int(head_dim * partial_rotary_factor)
804
+
805
+ attention_factor = 1.0 # Unused in this type of RoPE
806
+
807
+ # Compute the inverse frequencies
808
+ inv_freq = 1.0 / (
809
+ base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
810
+ )
811
+ return inv_freq, attention_factor
812
+
813
+ @torch.no_grad()
814
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
815
+ def forward(self, x, position_ids):
816
+ # In contrast to other models, GlmImage has different position ids for the grids
817
+ # So we expand the inv_freq to shape (3, ...)
818
+ inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
819
+ position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)
820
+
821
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
822
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
823
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
824
+ freqs = self.apply_mrope(freqs, self.mrope_section)
825
+ emb = torch.cat((freqs, freqs), dim=-1)
826
+ cos = emb.cos() * self.attention_scaling
827
+ sin = emb.sin() * self.attention_scaling
828
+
829
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
830
+
831
+ def apply_mrope(self, freqs, mrope_section):
832
+ section = mrope_section
833
+ chunks = freqs.split(section, dim=-1)
834
+ result = torch.cat([chunk[i % 3] for i, chunk in enumerate(chunks)], dim=-1)
835
+ return result
836
+
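A toy illustration of `apply_mrope` above: frequencies from the temporal/height/width grids (leading dimension 3) are split into sections and re-interleaved along the last dimension. The section sizes here are chosen for readability, not the model's defaults:

```python
import torch

# one constant row per grid so the interleaving is visible: 0 = temporal, 1 = height, 2 = width
freqs = torch.stack([torch.full((1, 1, 6), float(i)) for i in range(3)])  # (3, 1, 1, 6)
mrope_section = [2, 2, 2]

chunks = freqs.split(mrope_section, dim=-1)
result = torch.cat([chunk[i % 3] for i, chunk in enumerate(chunks)], dim=-1)
print(result.squeeze())  # tensor([0., 0., 1., 1., 2., 2.])
```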
837
+
838
+ @auto_docstring
839
+ class GlmImageTextModel(GlmImagePreTrainedModel):
840
+ config: GlmImageTextConfig
841
+ input_modalities = ("text",)
842
+
843
+ def __init__(self, config: GlmImageTextConfig):
844
+ super().__init__(config)
845
+ self.padding_idx = config.pad_token_id
846
+ self.vocab_size = config.vocab_size
847
+
848
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
849
+ self.layers = nn.ModuleList(
850
+ [GlmImageTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
851
+ )
852
+ self.norm = GlmImageRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
853
+ self.rotary_emb = GlmImageTextRotaryEmbedding(config=config)
854
+
855
+ self.gradient_checkpointing = False
856
+ # Initialize weights and apply final processing
857
+ self.post_init()
858
+
859
+ @auto_docstring
860
+ @check_model_inputs
861
+ def forward(
862
+ self,
863
+ input_ids: torch.LongTensor | None = None,
864
+ attention_mask: torch.Tensor | None = None,
865
+ position_ids: torch.LongTensor | None = None,
866
+ past_key_values: Cache | None = None,
867
+ inputs_embeds: torch.FloatTensor | None = None,
868
+ use_cache: bool | None = None,
869
+ cache_position: torch.LongTensor | None = None,
870
+ **kwargs: Unpack[FlashAttentionKwargs],
871
+ ) -> tuple | BaseModelOutputWithPast:
872
+ if (input_ids is None) ^ (inputs_embeds is not None):
873
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
874
+
875
+ # torch.jit.trace() doesn't support cache objects in the output
876
+ if use_cache and past_key_values is None and not torch.jit.is_tracing():
877
+ past_key_values = DynamicCache(config=self.config)
878
+
879
+ if inputs_embeds is None:
880
+ inputs_embeds = self.embed_tokens(input_ids)
881
+
882
+ if cache_position is None:
883
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
884
+ cache_position = torch.arange(
885
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
886
+ )
887
+
888
+ # the hard coded `3` is for temporal, height and width.
889
+ if position_ids is None:
890
+ position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
891
+ elif position_ids.ndim == 2:
892
+ position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
893
+
894
+ # NOTE: we need to pass text position ids for packing. Qwen2-VL uses 3D positions
895
+ # where each dim indicates visual spatial positions for temporal/height/width grids.
896
+ # There are two scenarios when FA2-like packed masking might be activated.
897
+ # 1. User specifically passed packed `position_ids` and no attention mask.
898
+ # In this case we expect the user to create correct position ids for all 3 grids
899
+ # and prepend text-only position ids to it. The final tensor will be [4, bs, seq-len]
900
+ # 2. User runs forward with no attention mask and no position ids. In this case, position ids
901
+ # are prepared by the model (`get_rope_index`) as `[4, bs, seq-len]` tensor. Text-only positions are
902
+ # prepended by us when creating positions so that the mask is constructed correctly. NOTE: failing to pass
903
+ # text-only positions will cause incorrect mask construction, do not change `prepare_input_for_generation`
904
+ if position_ids.ndim == 3 and position_ids.shape[0] == 4:
905
+ text_position_ids = position_ids[0]
906
+ position_ids = position_ids[1:]
907
+ else:
908
+ # If inputs are not packed (usual 3D positions), do not prepare mask from position_ids
909
+ text_position_ids = None
910
+
911
+ mask_kwargs = {
912
+ "config": self.config,
913
+ "input_embeds": inputs_embeds,
914
+ "attention_mask": attention_mask,
915
+ "cache_position": cache_position,
916
+ "past_key_values": past_key_values,
917
+ "position_ids": text_position_ids,
918
+ }
919
+ # Create the masks
920
+ causal_mask = create_causal_mask(**mask_kwargs)
921
+
922
+ hidden_states = inputs_embeds
923
+ position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
924
+
925
+ for decoder_layer in self.layers:
926
+ layer_outputs = decoder_layer(
927
+ hidden_states,
928
+ attention_mask=causal_mask,
929
+ position_ids=text_position_ids,
930
+ past_key_values=past_key_values,
931
+ cache_position=cache_position,
932
+ position_embeddings=position_embeddings,
933
+ **kwargs,
934
+ )
935
+ hidden_states = layer_outputs
936
+
937
+ hidden_states = self.norm(hidden_states)
938
+
939
+ return BaseModelOutputWithPast(
940
+ last_hidden_state=hidden_states,
941
+ past_key_values=past_key_values,
942
+ )
943
+
944
+
945
+ @auto_docstring
946
+ class GlmImageModel(GlmImagePreTrainedModel):
947
+ base_model_prefix = "model"
948
+ _checkpoint_conversion_mapping = {}
949
+ # Reference: fix gemma3 grad acc #37208
950
+ accepts_loss_kwargs = False
951
+ config: GlmImageConfig
952
+ _no_split_modules = ["GlmImageTextDecoderLayer", "GlmImageVisionBlock"]
953
+
954
+ def __init__(self, config):
955
+ super().__init__(config)
956
+ self.visual = GlmImageVisionModel._from_config(config.vision_config)
957
+ self.language_model = GlmImageTextModel._from_config(config.text_config)
958
+
959
+ self.rope_deltas = None # cache rope_deltas here
960
+ self.vqmodel = GlmImageVQVAE._from_config(config.vq_config)
961
+
962
+ # Per-sample caches for batch processing
963
+ self._cached_decode_position_ids = None # shape: [batch_size, 3, max_decode_len]
964
+ self._prefill_len = None # prefill sequence length (same for all samples in batch)
965
+
966
+ # Initialize weights and apply final processing
967
+ self.post_init()
968
+
969
+ def get_input_embeddings(self):
970
+ return self.language_model.get_input_embeddings()
971
+
972
+ def set_input_embeddings(self, value):
973
+ self.language_model.set_input_embeddings(value)
974
+
975
+ def get_rope_index(
976
+ self,
977
+ input_ids: torch.LongTensor | None = None,
978
+ image_grid_thw: torch.LongTensor | None = None,
979
+ images_per_sample: torch.LongTensor | None = None,
980
+ attention_mask: torch.LongTensor | None = None,
981
+ ) -> tuple[torch.Tensor, torch.Tensor]:
982
+ """
983
+ Calculate the 3D rope index for image generation task with full batch support.
984
+
985
+ Args:
986
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
987
+ Indices of input sequence tokens in the vocabulary.
988
+ image_grid_thw (`torch.LongTensor` of shape `(total_images_in_batch, 3)`, *optional*):
989
+ The temporal, height and width of feature shape of each image.
990
+ Images are packed across all samples in the batch.
991
+ images_per_sample (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
992
+ Number of images (including target grids) for each sample in the batch.
993
+ Used to split image_grid_thw by sample.
994
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
995
+ Mask to avoid performing attention on padding token indices.
996
+
997
+ Returns:
998
+ position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`):
999
+ Position IDs for temporal, height, and width dimensions.
1000
+ mrope_position_deltas (`torch.Tensor` of shape `(batch_size, 1)`):
1001
+ Position deltas for multi-modal rotary position embedding.
1002
+ """
1003
+ batch_size, seq_len = input_ids.shape
1004
+ device = input_ids.device
1005
+ dtype = input_ids.dtype
1006
+
1007
+ image_start_token_id = self.config.image_start_token_id
1008
+ image_end_token_id = self.config.image_end_token_id
1009
+
1010
+ position_ids = torch.ones(3, batch_size, seq_len, dtype=dtype, device=device)
1011
+ text_positions = torch.arange(seq_len, device=device)[None, :].repeat(3, 1)
1012
+
1013
+ # Split image_grid_thw by sample if images_per_sample is provided
1014
+ if image_grid_thw is not None and images_per_sample is not None:
1015
+ grids_per_sample = torch.split(image_grid_thw, images_per_sample.tolist())
1016
+ elif image_grid_thw is not None:
1017
+ # Fallback: assume all grids belong to first sample (batch_size=1)
1018
+ grids_per_sample = [image_grid_thw] * batch_size
1019
+ else:
1020
+ grids_per_sample = [None] * batch_size
1021
+
1022
+ # Per-sample caches for decode stage
1023
+ all_decode_position_ids = []
1024
+
1025
+ for batch_idx in range(batch_size):
1026
+ curr_input_ids = input_ids[batch_idx]
1027
+ curr_grids = grids_per_sample[batch_idx]
1028
+
1029
+ if attention_mask is not None and attention_mask.shape[1] == seq_len:
1030
+ valid_mask = attention_mask[batch_idx] == 1
1031
+ curr_input_ids_valid = curr_input_ids[valid_mask]
1032
+ else:
1033
+ # attention_mask may have different length during assisted decoding
1034
+ curr_input_ids_valid = curr_input_ids
1035
+ valid_mask = None
1036
+
1037
+ # Find image boundaries in this sample
1038
+ image_end_positions = torch.where(curr_input_ids_valid == image_end_token_id)[0]
1039
+ image_start_positions = torch.where(curr_input_ids_valid == image_start_token_id)[0] + 1
1040
+ num_complete_images = len(image_end_positions)
1041
+
1042
+ current_pos = 0
1043
+ prev_image_end = 0
1044
+ curr_position_ids = []
1045
+
1046
+ # Process complete images (source images in image-to-image task)
1047
+ for img_idx, (start, end) in enumerate(zip(image_start_positions, image_end_positions)):
1048
+ if curr_grids is None or img_idx >= len(curr_grids):
1049
+ break
1050
+ grid = curr_grids[img_idx]
1051
+ # grid format is [temporal, height, width]
1052
+ _, height, width = grid.tolist()
1053
+
1054
+ # Text tokens before this image
1055
+ llm_pos_length = start - prev_image_end
1056
+ llm_position_ids = text_positions[:, current_pos : current_pos + llm_pos_length].to(device=device)
1057
+ current_pos += llm_position_ids.shape[-1]
1058
+
1059
+ # Image tokens with 2D spatial encoding
1060
+ # For an image with height H and width W:
1061
+ # - position_width cycles [0, 1, ..., W-1] for each row, repeated H times
1062
+ # - position_height stays constant per row, [0]*W, [1]*W, ..., [H-1]*W
1063
+ image_seq_length = height * width
1064
+ position_width = torch.arange(current_pos, current_pos + width, device=device).repeat(height)
1065
+ position_height = torch.arange(current_pos, current_pos + height, device=device).repeat_interleave(
1066
+ width
1067
+ )
1068
+ position_temporal = torch.full((image_seq_length,), current_pos, device=device, dtype=torch.long)
1069
+ vision_position_ids = torch.stack([position_temporal, position_height, position_width], dim=0)
1070
+ current_pos += max(height, width)
1071
+
1072
+ prev_image_end = end
1073
+ curr_position_ids.append(torch.cat([llm_position_ids, vision_position_ids], dim=-1))
1074
+
1075
+ # Remaining text tokens (including the final image_start token for generation)
1076
+ end_position = len(curr_input_ids_valid) - prev_image_end
1077
+ llm_position_ids = text_positions[:, current_pos : current_pos + end_position].to(device=device)
1078
+ current_pos += llm_position_ids.shape[-1]
1079
+ curr_position_ids.append(llm_position_ids)
1080
+
1081
+ # Concatenate all position ids for this sample
1082
+ curr_position_ids = torch.cat(curr_position_ids, dim=-1)
1083
+
1084
+ # Store in the main position_ids tensor
1085
+ if valid_mask is not None:
1086
+ position_ids[:, batch_idx, valid_mask] = curr_position_ids
1087
+ else:
1088
+ position_ids[:, batch_idx, :] = curr_position_ids
1089
+
1090
+ # Build decode position ids for this sample
1091
+ if curr_grids is not None and len(curr_grids) > 0:
1092
+ num_decode_grids = len(curr_grids) - num_complete_images
1093
+ num_decode_grids = max(num_decode_grids, 0)
1094
+ decode_pos = current_pos
1095
+
1096
+ decode_temporal_list = []
1097
+ decode_height_list = []
1098
+ decode_width_list = []
1099
+
1100
+ for i in range(1, num_decode_grids + 1):
1101
+ grid_idx = -i
1102
+ h = curr_grids[grid_idx, 1].item()
1103
+ w = curr_grids[grid_idx, 2].item()
1104
+ total_tokens = h * w
1105
+
1106
+ h_indices = torch.arange(h, device=device).unsqueeze(1).expand(h, w).flatten()
1107
+ w_indices = torch.arange(w, device=device).unsqueeze(0).expand(h, w).flatten()
1108
+
1109
+ decode_temporal_list.append(
1110
+ torch.full((total_tokens,), decode_pos, device=device, dtype=torch.long)
1111
+ )
1112
+ decode_height_list.append(decode_pos + h_indices)
1113
+ decode_width_list.append(decode_pos + w_indices)
1114
+ decode_pos = decode_pos + max(h, w)
1115
+
1116
+ # End marker
1117
+ decode_temporal_list.append(torch.tensor([decode_pos], device=device, dtype=torch.long))
1118
+ decode_height_list.append(torch.tensor([decode_pos], device=device, dtype=torch.long))
1119
+ decode_width_list.append(torch.tensor([decode_pos], device=device, dtype=torch.long))
1120
+
1121
+ sample_decode_pos_ids = torch.stack(
1122
+ [
1123
+ torch.cat(decode_temporal_list, dim=0),
1124
+ torch.cat(decode_height_list, dim=0),
1125
+ torch.cat(decode_width_list, dim=0),
1126
+ ],
1127
+ dim=0,
1128
+ )
1129
+ all_decode_position_ids.append(sample_decode_pos_ids)
1130
+
1131
+ # Store prefill length (same for all samples since input_ids is padded to same length)
1132
+ self._prefill_len = seq_len
1133
+
1134
+ # Pad decode position ids to same length and stack
1135
+ if all_decode_position_ids:
1136
+ max_decode_len = max(x.shape[1] for x in all_decode_position_ids)
1137
+ padded_decode_pos_ids = [
1138
+ F.pad(pos_ids, (0, max_decode_len - pos_ids.shape[1]), mode="replicate")
1139
+ for pos_ids in all_decode_position_ids
1140
+ ]
1141
+ self._cached_decode_position_ids = torch.stack(padded_decode_pos_ids, dim=0) # [batch, 3, max_decode_len]
1142
+ else:
1143
+ self._cached_decode_position_ids = None
1144
+
1145
+ mrope_position_deltas = torch.zeros([batch_size, 1], dtype=dtype, device=device)
1146
+
1147
+ return position_ids, mrope_position_deltas
1148
+
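A worked example of the 2D spatial position-id pattern built inside `get_rope_index` above, for an assumed 2x3 image grid whose text position counter currently sits at 5:

```python
import torch

height, width, current_pos = 2, 3, 5

position_width = torch.arange(current_pos, current_pos + width).repeat(height)
position_height = torch.arange(current_pos, current_pos + height).repeat_interleave(width)
position_temporal = torch.full((height * width,), current_pos)

print(position_width)     # tensor([5, 6, 7, 5, 6, 7])   cycles within each row
print(position_height)    # tensor([5, 5, 5, 6, 6, 6])   constant per row
print(position_temporal)  # tensor([5, 5, 5, 5, 5, 5])
# the text counter then advances by max(height, width) = 3, as in the loop above
```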
1149
+ @can_return_tuple
1150
+ @auto_docstring
1151
+ def get_image_features(
1152
+ self,
1153
+ pixel_values: torch.FloatTensor,
1154
+ image_grid_thw: torch.LongTensor | None = None,
1155
+ **kwargs: Unpack[TransformersKwargs],
1156
+ ) -> tuple | BaseModelOutputWithPooling:
1157
+ r"""
1158
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
1159
+ The tensors corresponding to the input images.
1160
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
1161
+ The temporal, height and width of feature shape of each image in LLM.
1162
+ """
1163
+ pixel_values = pixel_values.type(self.visual.dtype)
1164
+ vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs)
1165
+ split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
1166
+ image_embeds = torch.split(vision_outputs.last_hidden_state, split_sizes)
1167
+ vision_outputs.pooler_output = image_embeds
1168
+
1169
+ return vision_outputs
1170
+
1171
+ def get_placeholder_mask(
1172
+ self,
1173
+ input_ids: torch.LongTensor,
1174
+ image_ids: torch.LongTensor,
1175
+ ):
1176
+ """
1177
+ Compute the mask of image placeholder tokens in input_ids and check that its size matches the number of image tokens from the VQVAE.
1178
+
1179
+ Args:
1180
+ input_ids (`torch.LongTensor` of shape `(batch_size, seq_len)`):
1181
+ Input token ids with image placeholders.
1182
+ image_ids (`torch.LongTensor` of shape `(num_images, num_tokens_per_image)` or flattened):
1183
+ Discrete token indices from the VQVAE codebook.
1184
+
1185
+ Returns:
1186
+ special_image_mask (`torch.LongTensor` of shape `(batch_size, seq_len)`):
1187
+ Mask indicating positions in input ids that will be replaced by actual image tokens.
1188
+ """
1189
+
1190
+ special_image_mask = input_ids == self.config.image_token_id
1191
+ n_placeholder_tokens = special_image_mask.sum().item()
1192
+ n_image_tokens = image_ids.shape[0]
1193
+
1194
+ if n_placeholder_tokens != n_image_tokens:
1195
+ raise ValueError(
1196
+ f"Number of image placeholder tokens ({n_placeholder_tokens}) does not match "
1197
+ f"number of image tokens from VQVAE ({n_image_tokens})"
1198
+ )
1199
+
1200
+ return special_image_mask
1201
+
1202
+ @auto_docstring
1203
+ @can_return_tuple
1204
+ def forward(
1205
+ self,
1206
+ input_ids: torch.LongTensor | None = None,
1207
+ attention_mask: torch.Tensor | None = None,
1208
+ position_ids: torch.LongTensor | None = None,
1209
+ past_key_values: Cache | None = None,
1210
+ inputs_embeds: torch.FloatTensor | None = None,
1211
+ pixel_values: torch.Tensor | None = None,
1212
+ image_grid_thw: torch.LongTensor | None = None,
1213
+ images_per_sample: torch.LongTensor | None = None,
1214
+ rope_deltas: torch.LongTensor | None = None,
1215
+ cache_position: torch.LongTensor | None = None,
1216
+ **kwargs: Unpack[TransformersKwargs],
1217
+ ) -> tuple | GlmImageModelOutputWithPast:
1218
+ r"""
1219
+ image_grid_thw (`torch.LongTensor` of shape `(total_images_in_batch, 3)`, *optional*):
1220
+ The temporal, height and width of feature shape of each image in LLM.
1221
+ Images are packed across all samples in the batch.
1222
+ images_per_sample (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1223
+ Number of images (including target grids) for each sample in the batch.
1224
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
1225
+ The rope index difference between sequence length and multimodal rope.
1226
+ """
1227
+ if (input_ids is None) ^ (inputs_embeds is not None):
1228
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
1229
+
1230
+ batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
1231
+
1232
+ if pixel_values is not None:
1233
+ # Process source images (image-to-image mode)
1234
+ # Source images are identified by counting image_end_token_id in input_ids
1235
+ # Note: We must exclude padding tokens since pad_token_id == image_end_token_id
1236
+ if images_per_sample is not None:
1237
+ grids_per_sample = torch.split(image_grid_thw, images_per_sample.tolist())
1238
+ # Create mask for non-padding tokens (attention_mask=1 means non-padding)
1239
+ # Handle 4D attention mask (from static cache) by extracting diagonal
1240
+ if attention_mask is not None and attention_mask.ndim == 4:
1241
+ non_pad_mask = torch.diagonal(attention_mask[:, 0], dim1=1, dim2=2)
1242
+ if non_pad_mask.dtype.is_floating_point:
1243
+ non_pad_mask = non_pad_mask / torch.finfo(non_pad_mask.dtype).min
1244
+ non_pad_mask = (1.0 - non_pad_mask).int()
1245
+ # Only keep columns matching input_ids length
1246
+ non_pad_mask = non_pad_mask[:, -input_ids.shape[1] :]
1247
+ else:
1248
+ non_pad_mask = attention_mask if attention_mask is not None else torch.ones_like(input_ids)
1249
+
1250
+ source_grids_list = []
1251
+ for sample_idx in range(batch_size):
1252
+ is_image_end = input_ids[sample_idx] == self.config.image_end_token_id
1253
+ is_non_pad = non_pad_mask[sample_idx] == 1
1254
+ num_source = (is_image_end & is_non_pad).sum().item()
1255
+ if num_source > 0:
1256
+ source_grids_list.append(grids_per_sample[sample_idx][:num_source])
1257
+ if len(source_grids_list) == 0:
1258
+ raise ValueError(
1259
+ "pixel_values provided but no source images found in input_ids. "
1260
+ "Ensure input_ids contains image_end_token_id for each source image."
1261
+ )
1262
+ source_grids = torch.cat(source_grids_list, dim=0)
1263
+ else:
1264
+ # Fallback for batch_size=1: all but last grid are source images
1265
+ source_grids = image_grid_thw[:-1]
1266
+
1267
+ image_features = self.get_image_features(pixel_values, source_grids, return_dict=True)
1268
+ image_embeds = torch.cat(image_features.pooler_output, dim=0)
1269
+ image_ids = self.get_image_tokens(image_embeds, source_grids)
1270
+ image_ids = image_ids.view(-1).to(input_ids.device)
1271
+ special_image_mask = self.get_placeholder_mask(input_ids, image_ids)
1272
+ input_ids = input_ids.masked_scatter(special_image_mask, image_ids)
1273
+
1274
+ if inputs_embeds is None:
1275
+ inputs_embeds = self.get_input_embeddings()(input_ids)
1276
+
1277
+ if position_ids is None:
1278
+ attention_mask_2d = attention_mask
1279
+ if attention_mask is not None and attention_mask.ndim == 4:
1280
+ attention_mask_2d = torch.diagonal(attention_mask[:, 0], dim1=1, dim2=2)
1281
+ # Only apply conversion for floating point tensors (inverted masks)
1282
+ if attention_mask_2d.dtype.is_floating_point:
1283
+ attention_mask_2d = attention_mask_2d / torch.finfo(attention_mask_2d.dtype).min
1284
+ attention_mask_2d = (1.0 - attention_mask_2d).int()
1285
+
1286
+ # Calculate RoPE index once per generation in the pre-fill stage only.
1287
+ is_prefill_stage = (input_ids is not None and input_ids.shape[1] != 1) or (
1288
+ inputs_embeds is not None and inputs_embeds.shape[1] != 1
1289
+ )
1290
+ if is_prefill_stage or self.rope_deltas is None:
1291
+ position_ids, rope_deltas = self.get_rope_index(
1292
+ input_ids,
1293
+ image_grid_thw,
1294
+ images_per_sample=images_per_sample,
1295
+ attention_mask=attention_mask_2d,
1296
+ )
1297
+ self.rope_deltas = rope_deltas
1298
+ # then use the prev pre-calculated rope-deltas to get the correct position ids
1299
+ else:
1300
+ batch_size, seq_length, _ = inputs_embeds.shape
1301
+ # Per-sample decode position lookup
1302
+ # _cached_decode_position_ids shape: [batch_size, 3, max_decode_len]
1303
+ if self._cached_decode_position_ids is not None:
1304
+ step = cache_position[0].item() - self._prefill_len
1305
+ # Get position ids for all samples at once, then transpose to [3, batch_size, seq_length]
1306
+ position_ids = self._cached_decode_position_ids[:, :, step : step + seq_length].permute(1, 0, 2)
1307
+ else:
1308
+ # Fallback for text-to-image or cases without cached decode positions
1309
+ # Use simple incremental positions
1310
+ start_pos = cache_position[0].item()
1311
+ position_ids = torch.arange(
1312
+ start_pos, start_pos + seq_length, device=inputs_embeds.device, dtype=torch.long
1313
+ )
1314
+ position_ids = position_ids.unsqueeze(0).repeat(3, batch_size, 1)
1315
+
1316
+ outputs = self.language_model(
1317
+ input_ids=None,
1318
+ position_ids=position_ids,
1319
+ attention_mask=attention_mask,
1320
+ past_key_values=past_key_values,
1321
+ inputs_embeds=inputs_embeds,
1322
+ cache_position=cache_position,
1323
+ **kwargs,
1324
+ )
1325
+
1326
+ return GlmImageModelOutputWithPast(
1327
+ last_hidden_state=outputs.last_hidden_state,
1328
+ past_key_values=outputs.past_key_values,
1329
+ hidden_states=outputs.hidden_states,
1330
+ attentions=outputs.attentions,
1331
+ rope_deltas=self.rope_deltas,
1332
+ )
1333
+
1334
+ def get_image_tokens(
1335
+ self,
1336
+ hidden_states: torch.FloatTensor,
1337
+ image_grid_thw: torch.LongTensor,
1338
+ ) -> torch.LongTensor:
1339
+ """
1340
+ Tokenizes image features into discrete tokens with VQVAE module.
1341
+
1342
+ Args:
1343
+ hidden_states (`torch.FloatTensor` of shape `(total_patches, hidden_size)`):
1344
+ The packed image features from vision encoder.
1345
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`):
1346
+ The temporal, height and width of feature shape of each image.
1347
+
1348
+ Returns:
1349
+ image_tokens (`torch.LongTensor` of shape `(total_patches,)`):
1350
+ Discrete token indices from the VQVAE codebook.
1351
+ """
1352
+ hidden_size = hidden_states.shape[-1]
1353
+ split_sizes = (image_grid_thw.prod(dim=-1)).tolist()
1354
+ hidden_states_list = torch.split(hidden_states, split_sizes, dim=0)
1355
+
1356
+ all_image_toks = []
1357
+ for i, hs in enumerate(hidden_states_list):
1358
+ grid_t, grid_h, grid_w = image_grid_thw[i].tolist()
1359
+ hs = hs.view(grid_t, grid_h, grid_w, hidden_size)
1360
+ hs = hs.permute(0, 3, 1, 2).contiguous()
1361
+ vqmodel_outputs: GlmImageVQVAEModelOutput = self.vqmodel.encode(hs)
1362
+ all_image_toks.append(vqmodel_outputs.image_tokens)
1363
+ return torch.cat(all_image_toks, dim=0)
1364
+
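A minimal sketch of the split-and-reshape bookkeeping in `get_image_tokens` above, with dummy features and assumed grids:

```python
import torch

hidden_size = 8
image_grid_thw = torch.tensor([[1, 2, 3], [1, 4, 4]])  # 6 + 16 = 22 packed patches
hidden_states = torch.randn(22, hidden_size)

split_sizes = image_grid_thw.prod(dim=-1).tolist()      # [6, 16]
for grid, hs in zip(image_grid_thw, torch.split(hidden_states, split_sizes, dim=0)):
    t, h, w = grid.tolist()
    chw = hs.view(t, h, w, hidden_size).permute(0, 3, 1, 2)  # (t, C, h, w) layout for the VQ encoder
    print(tuple(chw.shape))  # (1, 8, 2, 3) then (1, 8, 4, 4)
```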
1365
+
1366
+ @dataclass
1367
+ @auto_docstring(
1368
+ custom_intro="""
1369
+ Base class for GlmImage causal language model (or autoregressive) outputs.
1370
+ """
1371
+ )
1372
+ class GlmImageCausalLMOutputWithPast(ModelOutput):
1373
+ r"""
1374
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
1375
+ Language modeling loss (for next-token prediction).
1376
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
1377
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
1378
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
1379
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
1380
+
1381
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
1382
+ `past_key_values` input) to speed up sequential decoding.
1383
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
1384
+ The rope index difference between sequence length and multimodal rope.
1385
+ """
1386
+
1387
+ loss: torch.FloatTensor | None = None
1388
+ logits: torch.FloatTensor | None = None
1389
+ past_key_values: Cache | None = None
1390
+ hidden_states: tuple[torch.FloatTensor] | None = None
1391
+ attentions: tuple[torch.FloatTensor] | None = None
1392
+ rope_deltas: torch.LongTensor | None = None
1393
+
1394
+
1395
+ class GlmImageForConditionalGeneration(GlmImagePreTrainedModel, GenerationMixin):
1396
+ _checkpoint_conversion_mapping = {}
1397
+ _tied_weights_keys = {}
1398
+ # Reference: fix gemma3 grad acc #37208
1399
+ accepts_loss_kwargs = False
1400
+ base_model_prefix = "model"
1401
+ config: GlmImageConfig
1402
+
1403
+ def __init__(self, config):
1404
+ super().__init__(config)
1405
+ self.model = GlmImageModel(config)
1406
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vision_vocab_size, bias=False)
1407
+
1408
+ # Initialize weights and apply final processing
1409
+ self.post_init()
1410
+
1411
+ @auto_docstring
1412
+ def get_image_features(
1413
+ self,
1414
+ pixel_values: torch.FloatTensor,
1415
+ image_grid_thw: torch.LongTensor | None = None,
1416
+ **kwargs: Unpack[TransformersKwargs],
1417
+ ) -> tuple | BaseModelOutputWithPooling:
1418
+ r"""
1419
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
1420
+ The tensors corresponding to the input images.
1421
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
1422
+ The temporal, height and width of feature shape of each image in LLM.
1423
+ """
1424
+ return self.model.get_image_features(pixel_values, image_grid_thw, **kwargs)
1425
+
1426
+ def get_image_tokens(self, hidden_states: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None):
1427
+ return self.model.get_image_tokens(hidden_states, image_grid_thw)
1428
+
1429
+ def forward(
1430
+ self,
1431
+ input_ids: torch.LongTensor | None = None,
1432
+ attention_mask: torch.Tensor | None = None,
1433
+ position_ids: torch.LongTensor | None = None,
1434
+ past_key_values: Cache | None = None,
1435
+ inputs_embeds: torch.FloatTensor | None = None,
1436
+ labels: torch.LongTensor | None = None,
1437
+ pixel_values: torch.Tensor | None = None,
1438
+ image_grid_thw: torch.LongTensor | None = None,
1439
+ images_per_sample: torch.LongTensor | None = None,
1440
+ cache_position: torch.LongTensor | None = None,
1441
+ logits_to_keep: int | torch.Tensor = 0,
1442
+ **kwargs: Unpack[TransformersKwargs],
1443
+ ) -> tuple | GlmImageCausalLMOutputWithPast:
1444
+ r"""
+         labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
+             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
+             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
+             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
+         image_grid_thw (`torch.LongTensor` of shape `(total_images_in_batch, 3)`, *optional*):
+             The temporal, height and width of the feature grid for each image passed to the LLM.
+             Images are packed across all samples in the batch.
+         images_per_sample (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+             Number of images (including target grids) for each sample in the batch.
+
+         Example:
+
+         ```python
+         >>> from PIL import Image
+         >>> import httpx
+         >>> from io import BytesIO
+         >>> from transformers import AutoProcessor, GlmImageForConditionalGeneration
+
+         >>> model = GlmImageForConditionalGeneration.from_pretrained("zai-org/GLM-Image")
+         >>> processor = AutoProcessor.from_pretrained("zai-org/GLM-Image")
+
+         >>> messages = [
+             {
+                 "role": "user",
+                 "content": [
+                     {"type": "image"},
+                     {"type": "text", "text": "Add a truck to this photo.<sop>28 40<eop>"},
+                 ],
+             },
+         ]
+         >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
+         >>> with httpx.stream("GET", url) as response:
+         ...     image = Image.open(BytesIO(response.read()))
+
+         >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+         >>> inputs = processor(text=[text], images=[image], return_tensors="pt")
+
+         >>> # Generate
+         >>> generate_ids = model.generate(**inputs, max_length=30)
+         >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
+         "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
+         ```"""
+         outputs = self.model(
+             input_ids=input_ids,
+             pixel_values=pixel_values,
+             image_grid_thw=image_grid_thw,
+             images_per_sample=images_per_sample,
+             position_ids=position_ids,
+             attention_mask=attention_mask,
+             past_key_values=past_key_values,
+             inputs_embeds=inputs_embeds,
+             cache_position=cache_position,
+             **kwargs,
+         )
+
+         hidden_states = outputs[0]
+
+         # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
+         slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
+         logits = self.lm_head(hidden_states[:, slice_indices, :])
+
+         loss = None
+         if labels is not None:
+             loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
+
+         return GlmImageCausalLMOutputWithPast(
+             loss=loss,
+             logits=logits,
+             past_key_values=outputs.past_key_values,
+             hidden_states=outputs.hidden_states,
+             attentions=outputs.attentions,
+             rope_deltas=outputs.rope_deltas,
+         )
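
For context on the `logits_to_keep` slicing in `forward` above: an integer keeps only that many trailing positions (0 keeps all of them), while a tensor selects explicit indices. A minimal, self-contained sketch of the same indexing, with illustrative shapes:

```python
import torch

hidden_states = torch.randn(2, 10, 8)  # (batch, seq_len, hidden); sizes are illustrative
logits_to_keep = 1                      # typical during decoding: only the last position is needed

# Same expression as in forward(): an int becomes a trailing slice, a tensor is used as-is
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
print(hidden_states[:, slice_indices, :].shape)  # torch.Size([2, 1, 8])
```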
+
+     def prepare_inputs_for_generation(
+         self,
+         input_ids,
+         past_key_values=None,
+         attention_mask=None,
+         inputs_embeds=None,
+         cache_position=None,
+         position_ids=None,
+         use_cache=True,
+         pixel_values=None,
+         image_grid_thw=None,
+         images_per_sample=None,
+         is_first_iteration=False,
+         **kwargs,
+     ):
+         model_inputs = super().prepare_inputs_for_generation(
+             input_ids,
+             past_key_values=past_key_values,
+             attention_mask=attention_mask,
+             inputs_embeds=inputs_embeds,
+             cache_position=cache_position,
+             position_ids=position_ids,
+             pixel_values=pixel_values,
+             image_grid_thw=image_grid_thw,
+             is_first_iteration=is_first_iteration,
+             use_cache=use_cache,
+             **kwargs,
+         )
+
+         model_inputs["position_ids"] = None
+         model_inputs["images_per_sample"] = images_per_sample
+
+         if not is_first_iteration and use_cache:
+             model_inputs["pixel_values"] = None
+
+         return model_inputs
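
The override above forwards pixel values only on the first generation step; once the image features have been written into the KV cache there is no need to re-encode them. A minimal sketch of that condition in isolation (the dict contents are placeholders, not real model inputs):

```python
def strip_pixels(model_inputs, is_first_iteration, use_cache=True):
    # Mirrors the condition in prepare_inputs_for_generation: images are only encoded during prefill.
    if not is_first_iteration and use_cache:
        model_inputs["pixel_values"] = None
    return model_inputs

print(strip_pixels({"pixel_values": "patches"}, is_first_iteration=True))   # {'pixel_values': 'patches'}
print(strip_pixels({"pixel_values": "patches"}, is_first_iteration=False))  # {'pixel_values': None}
```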
+
+     def _get_image_nums(
+         self,
+         input_ids: torch.LongTensor | None,
+     ) -> torch.Tensor:
+         """
+         Get the number of images for each sample.
+         For GLM-Image, the number of images can only be derived from `input_ids`.
+
+         Returns:
+             image_counts (`torch.LongTensor` of shape `(batch_size,)`)
+         """
+         is_image = input_ids == self.config.image_start_token_id
+
+         return is_image.sum(dim=1)
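
A short worked example of the counting in `_get_image_nums`, assuming a hypothetical `image_start_token_id` of 7 (the real value comes from the model config):

```python
import torch

image_start_token_id = 7  # hypothetical id, for illustration only
input_ids = torch.tensor([
    [1, 7, 5, 5, 7, 2],  # sample 0: two image blocks
    [1, 5, 7, 5, 2, 2],  # sample 1: one image block
])
print((input_ids == image_start_token_id).sum(dim=1))  # tensor([2, 1])
```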
+
+     def _expand_inputs_for_generation(
+         self,
+         expand_size: int = 1,
+         is_encoder_decoder: bool = False,
+         input_ids: torch.LongTensor | None = None,
+         **model_kwargs,
+     ) -> tuple[torch.LongTensor, dict[str, Any]]:
+         # Overwritten -- Support for expanding tensors without a batch size dimension
+         # e.g., pixel_values, image_grid_thw
+         # pixel_values.shape[0] is sum(seqlen_images for samples)
+         # image_grid_thw.shape[0] is sum(num_images for samples)
+
+         if expand_size == 1:
+             return input_ids, model_kwargs
+
+         visual_keys = ["pixel_values", "image_grid_thw", "images_per_sample"]
+
+         def _expand_dict_for_generation_visual(dict_to_expand):
+             image_grid_thw = model_kwargs.get("image_grid_thw", None)
+             if image_grid_thw is None:
+                 return dict_to_expand
+
+             images_per_sample = model_kwargs.get("images_per_sample", None)
+
+             # Use images_per_sample if available
+             if images_per_sample is not None:
+                 image_nums = images_per_sample.tolist()
+             elif input_ids is not None:
+                 # Try to infer from image_grid_thw / batch_size
+                 batch_size = input_ids.shape[0]
+                 total_grids = image_grid_thw.shape[0]
+                 if total_grids % batch_size == 0:
+                     grids_per_sample = total_grids // batch_size
+                     image_nums = [grids_per_sample] * batch_size
+                 else:
+                     # Cannot evenly distribute grids - fall back to simple repeat_interleave
+                     # This handles test cases where image_grid_thw has (batch_size + 1) rows
+                     dict_to_expand["image_grid_thw"] = image_grid_thw.repeat_interleave(expand_size, dim=0)
+                     if dict_to_expand.get("pixel_values") is not None:
+                         dict_to_expand["pixel_values"] = dict_to_expand["pixel_values"].repeat_interleave(
+                             expand_size, dim=0
+                         )
+                     return dict_to_expand
+             else:
+                 image_nums = self._get_image_nums(input_ids).tolist()
+
+             # Get source image counts per sample from image_end_token_id count
+             source_image_nums = [
+                 (input_ids[batch_idx] == self.config.image_end_token_id).sum().item()
+                 for batch_idx in range(len(image_nums))
+             ]
+
+             def _repeat_interleave_samples(x, lengths, repeat_times):
+                 samples = torch.split(x, lengths)
+                 repeat_args = [repeat_times] + [1] * (x.dim() - 1)
+                 result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)
+                 return result
+
+             for key in dict_to_expand:
+                 if key == "pixel_values":
+                     # Split images into samples based on source image counts
+                     if sum(source_image_nums) > 0:
+                         # Split grids by sample to compute pixel counts
+                         grids_per_sample = torch.split(image_grid_thw, image_nums)
+                         lengths = []
+                         for batch_idx, sample_grids in enumerate(grids_per_sample):
+                             num_source = source_image_nums[batch_idx]
+                             if num_source > 0:
+                                 source_grids = sample_grids[:num_source]
+                                 lengths.append(torch.prod(source_grids, dim=1).sum().item())
+                             else:
+                                 lengths.append(0)
+
+                         dict_to_expand[key] = _repeat_interleave_samples(
+                             dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+                         )
+                 elif key == "image_grid_thw":
+                     # Expand all grids (source + target) per sample
+                     dict_to_expand[key] = _repeat_interleave_samples(
+                         dict_to_expand[key], lengths=image_nums, repeat_times=expand_size
+                     )
+                 elif key == "images_per_sample":
+                     # Simply repeat the counts
+                     if dict_to_expand.get(key) is not None:
+                         dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
+             return dict_to_expand
+
+         def _expand_dict_for_generation(dict_to_expand):
+             for key in dict_to_expand:
+                 if (
+                     key != "cache_position"
+                     and dict_to_expand[key] is not None
+                     and isinstance(dict_to_expand[key], torch.Tensor)
+                     and key not in visual_keys
+                 ):
+                     dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
+             return dict_to_expand
+
+         model_kwargs = _expand_dict_for_generation_visual(model_kwargs)
+
+         if input_ids is not None:
+             input_ids = input_ids.repeat_interleave(expand_size, dim=0)
+
+         model_kwargs = _expand_dict_for_generation(model_kwargs)
+
+         if is_encoder_decoder:
+             if model_kwargs.get("encoder_outputs") is None:
+                 raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
+             model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
+
+         return input_ids, model_kwargs
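
Because `pixel_values` and `image_grid_thw` are packed along a single leading dimension rather than carrying a batch dimension, expansion for `num_return_sequences` or beam search has to split them per sample before repeating, as the helper above does. A minimal, self-contained sketch of that per-sample repeat with made-up grid sizes:

```python
import torch


def repeat_interleave_samples(x, lengths, repeat_times):
    # Split the packed tensor into per-sample chunks, repeat each chunk, then re-pack.
    samples = torch.split(x, lengths)
    repeat_args = [repeat_times] + [1] * (x.dim() - 1)
    return torch.cat([s.repeat(*repeat_args) for s in samples], dim=0)


# Sample 0 owns the first two [t, h, w] grids, sample 1 owns the last one.
image_grid_thw = torch.tensor([[1, 4, 4], [1, 2, 2], [1, 8, 8]])
expanded = repeat_interleave_samples(image_grid_thw, lengths=[2, 1], repeat_times=2)
print(expanded.shape)  # torch.Size([6, 3]): each sample's grids duplicated, sample order preserved
```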
+
+
+ __all__ = [
+     "GlmImagePreTrainedModel",
+     "GlmImageVQVAE",
+     "GlmImageVisionModel",
+     "GlmImageTextModel",
+     "GlmImageModel",
+     "GlmImageForConditionalGeneration",
+ ]