transformers 5.0.0rc2-py3-none-any.whl → 5.0.0rc3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
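For context, a listing like the one below can be reproduced locally by downloading both wheels and unpacking them before comparing the trees. The following is a minimal sketch, not part of the release itself: it assumes pip is available on PATH, that both pre-releases can still be fetched from PyPI, and uses an illustrative wheels/ directory layout.

    # Sketch: fetch and unpack the two transformers pre-release wheels for comparison.
    import subprocess
    import zipfile
    from pathlib import Path

    VERSIONS = ["5.0.0rc2", "5.0.0rc3"]  # the two pre-releases being compared

    for version in VERSIONS:
        dest = Path("wheels") / version  # illustrative layout, not mandated by the diff
        dest.mkdir(parents=True, exist_ok=True)
        # Download only the wheel itself, without resolving dependencies.
        subprocess.run(
            ["pip", "download", f"transformers=={version}",
             "--no-deps", "--only-binary=:all:", "--dest", str(dest)],
            check=True,
        )
        # A wheel is a zip archive, so it can be unpacked directly.
        wheel_path = next(dest.glob("*.whl"))
        with zipfile.ZipFile(wheel_path) as archive:
            archive.extractall(dest / "unpacked")

The two unpacked trees can then be compared with any recursive diff tool (for example `diff -ru wheels/5.0.0rc2/unpacked wheels/5.0.0rc3/unpacked`) to obtain per-file added/removed line counts like those listed below.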
Files changed (1537)
  1. transformers/__init__.py +9 -28
  2. transformers/audio_utils.py +32 -32
  3. transformers/cache_utils.py +15 -124
  4. transformers/cli/chat.py +3 -3
  5. transformers/cli/serve.py +2 -2
  6. transformers/cli/transformers.py +2 -1
  7. transformers/configuration_utils.py +31 -33
  8. transformers/conversion_mapping.py +5 -1
  9. transformers/convert_slow_tokenizer.py +3 -8
  10. transformers/core_model_loading.py +14 -15
  11. transformers/data/processors/glue.py +0 -1
  12. transformers/data/processors/utils.py +0 -1
  13. transformers/data/processors/xnli.py +0 -1
  14. transformers/dependency_versions_table.py +4 -4
  15. transformers/distributed/configuration_utils.py +1 -2
  16. transformers/dynamic_module_utils.py +23 -23
  17. transformers/feature_extraction_sequence_utils.py +19 -23
  18. transformers/feature_extraction_utils.py +14 -14
  19. transformers/generation/candidate_generator.py +1 -2
  20. transformers/generation/configuration_utils.py +54 -39
  21. transformers/generation/continuous_batching/__init__.py +0 -1
  22. transformers/generation/continuous_batching/cache.py +34 -6
  23. transformers/generation/continuous_batching/cache_manager.py +25 -12
  24. transformers/generation/continuous_batching/continuous_api.py +54 -23
  25. transformers/generation/continuous_batching/requests.py +25 -4
  26. transformers/generation/continuous_batching/scheduler.py +117 -49
  27. transformers/generation/logits_process.py +0 -128
  28. transformers/generation/streamers.py +0 -1
  29. transformers/generation/utils.py +16 -26
  30. transformers/generation/watermarking.py +2 -3
  31. transformers/hf_argparser.py +9 -13
  32. transformers/hyperparameter_search.py +1 -2
  33. transformers/image_processing_base.py +9 -9
  34. transformers/image_processing_utils.py +11 -12
  35. transformers/image_processing_utils_fast.py +53 -53
  36. transformers/image_transforms.py +29 -29
  37. transformers/image_utils.py +30 -32
  38. transformers/integrations/awq.py +1 -3
  39. transformers/integrations/deepspeed.py +1 -1
  40. transformers/integrations/eetq.py +0 -1
  41. transformers/integrations/fbgemm_fp8.py +1 -2
  42. transformers/integrations/finegrained_fp8.py +8 -7
  43. transformers/integrations/flash_attention.py +1 -1
  44. transformers/integrations/flex_attention.py +1 -1
  45. transformers/integrations/fp_quant.py +4 -6
  46. transformers/integrations/ggml.py +0 -1
  47. transformers/integrations/integration_utils.py +2 -3
  48. transformers/integrations/mxfp4.py +5 -6
  49. transformers/integrations/quark.py +2 -4
  50. transformers/integrations/torchao.py +4 -6
  51. transformers/loss/loss_lw_detr.py +356 -0
  52. transformers/loss/loss_utils.py +2 -0
  53. transformers/masking_utils.py +47 -51
  54. transformers/model_debugging_utils.py +4 -5
  55. transformers/modelcard.py +14 -192
  56. transformers/modeling_attn_mask_utils.py +19 -19
  57. transformers/modeling_flash_attention_utils.py +27 -27
  58. transformers/modeling_gguf_pytorch_utils.py +5 -5
  59. transformers/modeling_layers.py +21 -22
  60. transformers/modeling_outputs.py +242 -253
  61. transformers/modeling_rope_utils.py +32 -32
  62. transformers/modeling_utils.py +67 -90
  63. transformers/models/__init__.py +4 -0
  64. transformers/models/afmoe/configuration_afmoe.py +26 -29
  65. transformers/models/afmoe/modeling_afmoe.py +30 -33
  66. transformers/models/afmoe/modular_afmoe.py +16 -18
  67. transformers/models/aimv2/configuration_aimv2.py +2 -5
  68. transformers/models/aimv2/modeling_aimv2.py +20 -21
  69. transformers/models/aimv2/modular_aimv2.py +7 -9
  70. transformers/models/albert/configuration_albert.py +0 -1
  71. transformers/models/albert/modeling_albert.py +67 -69
  72. transformers/models/albert/tokenization_albert.py +1 -4
  73. transformers/models/align/configuration_align.py +0 -1
  74. transformers/models/align/modeling_align.py +61 -62
  75. transformers/models/align/processing_align.py +2 -30
  76. transformers/models/altclip/configuration_altclip.py +0 -1
  77. transformers/models/altclip/modeling_altclip.py +76 -77
  78. transformers/models/altclip/processing_altclip.py +2 -15
  79. transformers/models/apertus/__init__.py +0 -1
  80. transformers/models/apertus/configuration_apertus.py +18 -21
  81. transformers/models/apertus/modeling_apertus.py +31 -34
  82. transformers/models/apertus/modular_apertus.py +28 -30
  83. transformers/models/arcee/configuration_arcee.py +20 -23
  84. transformers/models/arcee/modeling_arcee.py +31 -34
  85. transformers/models/arcee/modular_arcee.py +20 -23
  86. transformers/models/aria/configuration_aria.py +20 -23
  87. transformers/models/aria/image_processing_aria.py +25 -27
  88. transformers/models/aria/modeling_aria.py +63 -66
  89. transformers/models/aria/modular_aria.py +78 -85
  90. transformers/models/aria/processing_aria.py +28 -35
  91. transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +0 -1
  92. transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py +3 -6
  93. transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +6 -8
  94. transformers/models/audioflamingo3/__init__.py +0 -1
  95. transformers/models/audioflamingo3/configuration_audioflamingo3.py +0 -1
  96. transformers/models/audioflamingo3/modeling_audioflamingo3.py +22 -23
  97. transformers/models/audioflamingo3/modular_audioflamingo3.py +12 -17
  98. transformers/models/audioflamingo3/processing_audioflamingo3.py +6 -8
  99. transformers/models/auto/auto_factory.py +4 -5
  100. transformers/models/auto/configuration_auto.py +26 -5
  101. transformers/models/auto/feature_extraction_auto.py +5 -7
  102. transformers/models/auto/image_processing_auto.py +13 -26
  103. transformers/models/auto/modeling_auto.py +18 -199
  104. transformers/models/auto/processing_auto.py +2 -1
  105. transformers/models/auto/tokenization_auto.py +21 -22
  106. transformers/models/auto/video_processing_auto.py +7 -8
  107. transformers/models/autoformer/configuration_autoformer.py +4 -7
  108. transformers/models/autoformer/modeling_autoformer.py +98 -100
  109. transformers/models/aya_vision/configuration_aya_vision.py +0 -1
  110. transformers/models/aya_vision/modeling_aya_vision.py +35 -37
  111. transformers/models/aya_vision/modular_aya_vision.py +26 -29
  112. transformers/models/aya_vision/processing_aya_vision.py +25 -53
  113. transformers/models/bamba/configuration_bamba.py +29 -32
  114. transformers/models/bamba/modeling_bamba.py +60 -64
  115. transformers/models/bamba/modular_bamba.py +51 -55
  116. transformers/models/bark/configuration_bark.py +4 -7
  117. transformers/models/bark/generation_configuration_bark.py +3 -5
  118. transformers/models/bark/modeling_bark.py +40 -55
  119. transformers/models/bark/processing_bark.py +19 -41
  120. transformers/models/bart/configuration_bart.py +0 -1
  121. transformers/models/bart/modeling_bart.py +115 -117
  122. transformers/models/barthez/tokenization_barthez.py +1 -4
  123. transformers/models/bartpho/tokenization_bartpho.py +6 -7
  124. transformers/models/beit/configuration_beit.py +0 -11
  125. transformers/models/beit/image_processing_beit.py +53 -56
  126. transformers/models/beit/image_processing_beit_fast.py +8 -9
  127. transformers/models/beit/modeling_beit.py +51 -53
  128. transformers/models/bert/configuration_bert.py +0 -1
  129. transformers/models/bert/modeling_bert.py +111 -122
  130. transformers/models/bert/tokenization_bert.py +2 -4
  131. transformers/models/bert/tokenization_bert_legacy.py +3 -5
  132. transformers/models/bert_generation/configuration_bert_generation.py +0 -1
  133. transformers/models/bert_generation/modeling_bert_generation.py +47 -49
  134. transformers/models/bert_generation/tokenization_bert_generation.py +2 -3
  135. transformers/models/bert_japanese/tokenization_bert_japanese.py +5 -6
  136. transformers/models/bertweet/tokenization_bertweet.py +1 -3
  137. transformers/models/big_bird/configuration_big_bird.py +0 -1
  138. transformers/models/big_bird/modeling_big_bird.py +107 -109
  139. transformers/models/big_bird/tokenization_big_bird.py +1 -4
  140. transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py +0 -1
  141. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +109 -111
  142. transformers/models/biogpt/configuration_biogpt.py +0 -1
  143. transformers/models/biogpt/modeling_biogpt.py +69 -71
  144. transformers/models/biogpt/modular_biogpt.py +59 -61
  145. transformers/models/biogpt/tokenization_biogpt.py +3 -5
  146. transformers/models/bit/configuration_bit.py +0 -1
  147. transformers/models/bit/image_processing_bit.py +21 -24
  148. transformers/models/bit/image_processing_bit_fast.py +0 -1
  149. transformers/models/bit/modeling_bit.py +9 -11
  150. transformers/models/bitnet/configuration_bitnet.py +18 -21
  151. transformers/models/bitnet/modeling_bitnet.py +31 -34
  152. transformers/models/bitnet/modular_bitnet.py +4 -6
  153. transformers/models/blenderbot/configuration_blenderbot.py +0 -1
  154. transformers/models/blenderbot/modeling_blenderbot.py +64 -95
  155. transformers/models/blenderbot/tokenization_blenderbot.py +0 -1
  156. transformers/models/blenderbot_small/configuration_blenderbot_small.py +0 -1
  157. transformers/models/blenderbot_small/modeling_blenderbot_small.py +66 -68
  158. transformers/models/blenderbot_small/tokenization_blenderbot_small.py +1 -3
  159. transformers/models/blip/configuration_blip.py +0 -1
  160. transformers/models/blip/image_processing_blip.py +17 -20
  161. transformers/models/blip/image_processing_blip_fast.py +0 -1
  162. transformers/models/blip/modeling_blip.py +60 -71
  163. transformers/models/blip/modeling_blip_text.py +63 -65
  164. transformers/models/blip/processing_blip.py +5 -36
  165. transformers/models/blip_2/configuration_blip_2.py +0 -1
  166. transformers/models/blip_2/modeling_blip_2.py +70 -71
  167. transformers/models/blip_2/processing_blip_2.py +8 -38
  168. transformers/models/bloom/configuration_bloom.py +0 -1
  169. transformers/models/bloom/modeling_bloom.py +58 -59
  170. transformers/models/blt/configuration_blt.py +71 -74
  171. transformers/models/blt/modeling_blt.py +73 -76
  172. transformers/models/blt/modular_blt.py +57 -59
  173. transformers/models/bridgetower/configuration_bridgetower.py +0 -1
  174. transformers/models/bridgetower/image_processing_bridgetower.py +34 -35
  175. transformers/models/bridgetower/image_processing_bridgetower_fast.py +7 -8
  176. transformers/models/bridgetower/modeling_bridgetower.py +107 -109
  177. transformers/models/bridgetower/processing_bridgetower.py +2 -16
  178. transformers/models/bros/configuration_bros.py +0 -1
  179. transformers/models/bros/modeling_bros.py +78 -80
  180. transformers/models/bros/processing_bros.py +2 -12
  181. transformers/models/byt5/tokenization_byt5.py +4 -6
  182. transformers/models/camembert/configuration_camembert.py +0 -1
  183. transformers/models/camembert/modeling_camembert.py +91 -93
  184. transformers/models/camembert/modular_camembert.py +51 -54
  185. transformers/models/camembert/tokenization_camembert.py +1 -4
  186. transformers/models/canine/configuration_canine.py +0 -1
  187. transformers/models/canine/modeling_canine.py +73 -75
  188. transformers/models/canine/tokenization_canine.py +0 -1
  189. transformers/models/chameleon/configuration_chameleon.py +24 -27
  190. transformers/models/chameleon/image_processing_chameleon.py +21 -24
  191. transformers/models/chameleon/image_processing_chameleon_fast.py +0 -1
  192. transformers/models/chameleon/modeling_chameleon.py +53 -56
  193. transformers/models/chameleon/processing_chameleon.py +16 -41
  194. transformers/models/chinese_clip/configuration_chinese_clip.py +0 -1
  195. transformers/models/chinese_clip/image_processing_chinese_clip.py +21 -24
  196. transformers/models/chinese_clip/image_processing_chinese_clip_fast.py +0 -1
  197. transformers/models/chinese_clip/modeling_chinese_clip.py +65 -66
  198. transformers/models/chinese_clip/processing_chinese_clip.py +2 -15
  199. transformers/models/clap/configuration_clap.py +0 -1
  200. transformers/models/clap/feature_extraction_clap.py +9 -10
  201. transformers/models/clap/modeling_clap.py +88 -89
  202. transformers/models/clap/processing_clap.py +2 -15
  203. transformers/models/clip/configuration_clip.py +0 -1
  204. transformers/models/clip/image_processing_clip.py +21 -24
  205. transformers/models/clip/image_processing_clip_fast.py +0 -1
  206. transformers/models/clip/modeling_clip.py +45 -46
  207. transformers/models/clip/processing_clip.py +2 -14
  208. transformers/models/clip/tokenization_clip.py +2 -5
  209. transformers/models/clipseg/configuration_clipseg.py +0 -1
  210. transformers/models/clipseg/modeling_clipseg.py +86 -87
  211. transformers/models/clipseg/processing_clipseg.py +8 -39
  212. transformers/models/clvp/configuration_clvp.py +1 -3
  213. transformers/models/clvp/feature_extraction_clvp.py +7 -10
  214. transformers/models/clvp/modeling_clvp.py +119 -115
  215. transformers/models/clvp/number_normalizer.py +1 -2
  216. transformers/models/clvp/processing_clvp.py +3 -20
  217. transformers/models/clvp/tokenization_clvp.py +0 -1
  218. transformers/models/code_llama/tokenization_code_llama.py +3 -6
  219. transformers/models/codegen/configuration_codegen.py +0 -1
  220. transformers/models/codegen/modeling_codegen.py +48 -48
  221. transformers/models/codegen/tokenization_codegen.py +5 -6
  222. transformers/models/cohere/configuration_cohere.py +20 -23
  223. transformers/models/cohere/modeling_cohere.py +35 -38
  224. transformers/models/cohere/modular_cohere.py +24 -28
  225. transformers/models/cohere/tokenization_cohere.py +5 -6
  226. transformers/models/cohere2/configuration_cohere2.py +21 -24
  227. transformers/models/cohere2/modeling_cohere2.py +34 -37
  228. transformers/models/cohere2/modular_cohere2.py +39 -41
  229. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +6 -7
  230. transformers/models/cohere2_vision/modeling_cohere2_vision.py +28 -30
  231. transformers/models/cohere2_vision/modular_cohere2_vision.py +21 -23
  232. transformers/models/cohere2_vision/processing_cohere2_vision.py +6 -36
  233. transformers/models/colpali/configuration_colpali.py +0 -1
  234. transformers/models/colpali/modeling_colpali.py +14 -16
  235. transformers/models/colpali/modular_colpali.py +11 -51
  236. transformers/models/colpali/processing_colpali.py +14 -52
  237. transformers/models/colqwen2/modeling_colqwen2.py +20 -22
  238. transformers/models/colqwen2/modular_colqwen2.py +29 -68
  239. transformers/models/colqwen2/processing_colqwen2.py +16 -52
  240. transformers/models/conditional_detr/configuration_conditional_detr.py +0 -1
  241. transformers/models/conditional_detr/image_processing_conditional_detr.py +64 -66
  242. transformers/models/conditional_detr/image_processing_conditional_detr_fast.py +22 -22
  243. transformers/models/conditional_detr/modeling_conditional_detr.py +78 -80
  244. transformers/models/conditional_detr/modular_conditional_detr.py +1 -3
  245. transformers/models/convbert/configuration_convbert.py +0 -1
  246. transformers/models/convbert/modeling_convbert.py +85 -87
  247. transformers/models/convbert/tokenization_convbert.py +0 -1
  248. transformers/models/convnext/configuration_convnext.py +0 -1
  249. transformers/models/convnext/image_processing_convnext.py +18 -21
  250. transformers/models/convnext/image_processing_convnext_fast.py +5 -6
  251. transformers/models/convnext/modeling_convnext.py +5 -8
  252. transformers/models/convnextv2/configuration_convnextv2.py +0 -1
  253. transformers/models/convnextv2/modeling_convnextv2.py +5 -8
  254. transformers/models/cpm/tokenization_cpm.py +6 -7
  255. transformers/models/cpm/tokenization_cpm_fast.py +3 -5
  256. transformers/models/cpmant/configuration_cpmant.py +0 -1
  257. transformers/models/cpmant/modeling_cpmant.py +38 -40
  258. transformers/models/cpmant/tokenization_cpmant.py +1 -3
  259. transformers/models/csm/configuration_csm.py +49 -51
  260. transformers/models/csm/generation_csm.py +13 -14
  261. transformers/models/csm/modeling_csm.py +78 -81
  262. transformers/models/csm/modular_csm.py +56 -58
  263. transformers/models/csm/processing_csm.py +25 -68
  264. transformers/models/ctrl/configuration_ctrl.py +0 -1
  265. transformers/models/ctrl/modeling_ctrl.py +38 -41
  266. transformers/models/ctrl/tokenization_ctrl.py +0 -1
  267. transformers/models/cvt/configuration_cvt.py +0 -1
  268. transformers/models/cvt/modeling_cvt.py +13 -15
  269. transformers/models/cwm/__init__.py +0 -1
  270. transformers/models/cwm/configuration_cwm.py +3 -5
  271. transformers/models/cwm/modeling_cwm.py +32 -34
  272. transformers/models/cwm/modular_cwm.py +10 -12
  273. transformers/models/d_fine/configuration_d_fine.py +0 -1
  274. transformers/models/d_fine/modeling_d_fine.py +81 -82
  275. transformers/models/d_fine/modular_d_fine.py +8 -9
  276. transformers/models/dab_detr/configuration_dab_detr.py +0 -1
  277. transformers/models/dab_detr/modeling_dab_detr.py +68 -70
  278. transformers/models/dac/configuration_dac.py +0 -1
  279. transformers/models/dac/feature_extraction_dac.py +6 -9
  280. transformers/models/dac/modeling_dac.py +21 -23
  281. transformers/models/data2vec/configuration_data2vec_audio.py +0 -1
  282. transformers/models/data2vec/configuration_data2vec_text.py +0 -1
  283. transformers/models/data2vec/configuration_data2vec_vision.py +0 -1
  284. transformers/models/data2vec/modeling_data2vec_audio.py +52 -56
  285. transformers/models/data2vec/modeling_data2vec_text.py +91 -93
  286. transformers/models/data2vec/modeling_data2vec_vision.py +41 -42
  287. transformers/models/data2vec/modular_data2vec_audio.py +6 -1
  288. transformers/models/data2vec/modular_data2vec_text.py +51 -54
  289. transformers/models/dbrx/configuration_dbrx.py +18 -19
  290. transformers/models/dbrx/modeling_dbrx.py +39 -42
  291. transformers/models/dbrx/modular_dbrx.py +31 -33
  292. transformers/models/deberta/configuration_deberta.py +0 -1
  293. transformers/models/deberta/modeling_deberta.py +57 -60
  294. transformers/models/deberta/tokenization_deberta.py +2 -5
  295. transformers/models/deberta_v2/configuration_deberta_v2.py +0 -1
  296. transformers/models/deberta_v2/modeling_deberta_v2.py +63 -65
  297. transformers/models/deberta_v2/tokenization_deberta_v2.py +1 -4
  298. transformers/models/decision_transformer/configuration_decision_transformer.py +0 -1
  299. transformers/models/decision_transformer/modeling_decision_transformer.py +48 -50
  300. transformers/models/deepseek_v2/configuration_deepseek_v2.py +34 -37
  301. transformers/models/deepseek_v2/modeling_deepseek_v2.py +32 -33
  302. transformers/models/deepseek_v2/modular_deepseek_v2.py +40 -42
  303. transformers/models/deepseek_v3/configuration_deepseek_v3.py +35 -38
  304. transformers/models/deepseek_v3/modeling_deepseek_v3.py +31 -33
  305. transformers/models/deepseek_v3/modular_deepseek_v3.py +4 -5
  306. transformers/models/deepseek_vl/configuration_deepseek_vl.py +2 -3
  307. transformers/models/deepseek_vl/image_processing_deepseek_vl.py +25 -26
  308. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +7 -6
  309. transformers/models/deepseek_vl/modeling_deepseek_vl.py +31 -31
  310. transformers/models/deepseek_vl/modular_deepseek_vl.py +11 -43
  311. transformers/models/deepseek_vl/processing_deepseek_vl.py +10 -41
  312. transformers/models/deepseek_vl_hybrid/configuration_deepseek_vl_hybrid.py +3 -5
  313. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py +35 -35
  314. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +16 -16
  315. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +33 -33
  316. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +71 -90
  317. transformers/models/deepseek_vl_hybrid/processing_deepseek_vl_hybrid.py +12 -44
  318. transformers/models/deformable_detr/configuration_deformable_detr.py +0 -1
  319. transformers/models/deformable_detr/image_processing_deformable_detr.py +59 -61
  320. transformers/models/deformable_detr/image_processing_deformable_detr_fast.py +17 -17
  321. transformers/models/deformable_detr/modeling_deformable_detr.py +66 -67
  322. transformers/models/deformable_detr/modular_deformable_detr.py +1 -3
  323. transformers/models/deit/configuration_deit.py +0 -1
  324. transformers/models/deit/image_processing_deit.py +18 -21
  325. transformers/models/deit/image_processing_deit_fast.py +0 -1
  326. transformers/models/deit/modeling_deit.py +16 -18
  327. transformers/models/depth_anything/configuration_depth_anything.py +0 -1
  328. transformers/models/depth_anything/modeling_depth_anything.py +5 -8
  329. transformers/models/depth_pro/configuration_depth_pro.py +0 -1
  330. transformers/models/depth_pro/image_processing_depth_pro.py +22 -23
  331. transformers/models/depth_pro/image_processing_depth_pro_fast.py +6 -7
  332. transformers/models/depth_pro/modeling_depth_pro.py +21 -23
  333. transformers/models/detr/configuration_detr.py +0 -1
  334. transformers/models/detr/image_processing_detr.py +64 -66
  335. transformers/models/detr/image_processing_detr_fast.py +22 -23
  336. transformers/models/detr/modeling_detr.py +70 -72
  337. transformers/models/dia/configuration_dia.py +5 -8
  338. transformers/models/dia/feature_extraction_dia.py +6 -9
  339. transformers/models/dia/generation_dia.py +40 -36
  340. transformers/models/dia/modeling_dia.py +61 -64
  341. transformers/models/dia/modular_dia.py +52 -54
  342. transformers/models/dia/processing_dia.py +39 -29
  343. transformers/models/dia/tokenization_dia.py +3 -6
  344. transformers/models/diffllama/configuration_diffllama.py +20 -23
  345. transformers/models/diffllama/modeling_diffllama.py +42 -45
  346. transformers/models/diffllama/modular_diffllama.py +16 -18
  347. transformers/models/dinat/configuration_dinat.py +0 -1
  348. transformers/models/dinat/modeling_dinat.py +40 -42
  349. transformers/models/dinov2/configuration_dinov2.py +0 -1
  350. transformers/models/dinov2/modeling_dinov2.py +11 -13
  351. transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py +1 -1
  352. transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +12 -13
  353. transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py +5 -7
  354. transformers/models/dinov3_convnext/configuration_dinov3_convnext.py +4 -7
  355. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +3 -6
  356. transformers/models/dinov3_vit/configuration_dinov3_vit.py +5 -8
  357. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +5 -6
  358. transformers/models/dinov3_vit/modeling_dinov3_vit.py +14 -16
  359. transformers/models/dinov3_vit/modular_dinov3_vit.py +11 -13
  360. transformers/models/distilbert/configuration_distilbert.py +0 -1
  361. transformers/models/distilbert/modeling_distilbert.py +44 -46
  362. transformers/models/distilbert/tokenization_distilbert.py +0 -1
  363. transformers/models/doge/__init__.py +0 -1
  364. transformers/models/doge/configuration_doge.py +25 -28
  365. transformers/models/doge/modeling_doge.py +42 -45
  366. transformers/models/doge/modular_doge.py +57 -58
  367. transformers/models/donut/configuration_donut_swin.py +0 -1
  368. transformers/models/donut/image_processing_donut.py +26 -29
  369. transformers/models/donut/image_processing_donut_fast.py +5 -10
  370. transformers/models/donut/modeling_donut_swin.py +44 -46
  371. transformers/models/donut/processing_donut.py +5 -26
  372. transformers/models/dots1/configuration_dots1.py +27 -29
  373. transformers/models/dots1/modeling_dots1.py +31 -34
  374. transformers/models/dots1/modular_dots1.py +0 -1
  375. transformers/models/dpr/configuration_dpr.py +0 -1
  376. transformers/models/dpr/modeling_dpr.py +37 -39
  377. transformers/models/dpr/tokenization_dpr.py +7 -9
  378. transformers/models/dpr/tokenization_dpr_fast.py +7 -9
  379. transformers/models/dpt/configuration_dpt.py +0 -1
  380. transformers/models/dpt/image_processing_dpt.py +65 -66
  381. transformers/models/dpt/image_processing_dpt_fast.py +13 -14
  382. transformers/models/dpt/modeling_dpt.py +19 -21
  383. transformers/models/dpt/modular_dpt.py +10 -11
  384. transformers/models/edgetam/configuration_edgetam.py +0 -1
  385. transformers/models/edgetam/modeling_edgetam.py +39 -41
  386. transformers/models/edgetam/modular_edgetam.py +2 -6
  387. transformers/models/edgetam_video/__init__.py +0 -1
  388. transformers/models/edgetam_video/configuration_edgetam_video.py +0 -1
  389. transformers/models/edgetam_video/modeling_edgetam_video.py +76 -77
  390. transformers/models/edgetam_video/modular_edgetam_video.py +16 -18
  391. transformers/models/efficientloftr/configuration_efficientloftr.py +4 -5
  392. transformers/models/efficientloftr/image_processing_efficientloftr.py +14 -16
  393. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +4 -4
  394. transformers/models/efficientloftr/modeling_efficientloftr.py +27 -29
  395. transformers/models/efficientloftr/modular_efficientloftr.py +1 -3
  396. transformers/models/efficientnet/configuration_efficientnet.py +0 -1
  397. transformers/models/efficientnet/image_processing_efficientnet.py +23 -26
  398. transformers/models/efficientnet/image_processing_efficientnet_fast.py +14 -15
  399. transformers/models/efficientnet/modeling_efficientnet.py +12 -14
  400. transformers/models/electra/configuration_electra.py +0 -1
  401. transformers/models/electra/modeling_electra.py +101 -103
  402. transformers/models/emu3/configuration_emu3.py +5 -7
  403. transformers/models/emu3/image_processing_emu3.py +44 -39
  404. transformers/models/emu3/modeling_emu3.py +59 -62
  405. transformers/models/emu3/modular_emu3.py +32 -34
  406. transformers/models/emu3/processing_emu3.py +18 -43
  407. transformers/models/encodec/configuration_encodec.py +2 -4
  408. transformers/models/encodec/feature_extraction_encodec.py +10 -13
  409. transformers/models/encodec/modeling_encodec.py +25 -29
  410. transformers/models/encoder_decoder/configuration_encoder_decoder.py +0 -1
  411. transformers/models/encoder_decoder/modeling_encoder_decoder.py +17 -19
  412. transformers/models/eomt/configuration_eomt.py +0 -1
  413. transformers/models/eomt/image_processing_eomt.py +53 -55
  414. transformers/models/eomt/image_processing_eomt_fast.py +15 -16
  415. transformers/models/eomt/modeling_eomt.py +16 -18
  416. transformers/models/eomt/modular_eomt.py +11 -13
  417. transformers/models/ernie/configuration_ernie.py +0 -1
  418. transformers/models/ernie/modeling_ernie.py +121 -132
  419. transformers/models/ernie/modular_ernie.py +91 -103
  420. transformers/models/ernie4_5/configuration_ernie4_5.py +18 -20
  421. transformers/models/ernie4_5/modeling_ernie4_5.py +31 -33
  422. transformers/models/ernie4_5/modular_ernie4_5.py +1 -3
  423. transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py +27 -29
  424. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +36 -38
  425. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +7 -9
  426. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +0 -1
  427. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +34 -35
  428. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +6 -7
  429. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +84 -87
  430. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +86 -89
  431. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +3 -5
  432. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +17 -18
  433. transformers/models/esm/configuration_esm.py +2 -4
  434. transformers/models/esm/modeling_esm.py +32 -34
  435. transformers/models/esm/modeling_esmfold.py +42 -44
  436. transformers/models/esm/openfold_utils/chunk_utils.py +6 -6
  437. transformers/models/esm/openfold_utils/loss.py +1 -2
  438. transformers/models/esm/openfold_utils/protein.py +13 -13
  439. transformers/models/esm/openfold_utils/tensor_utils.py +6 -6
  440. transformers/models/esm/tokenization_esm.py +2 -4
  441. transformers/models/evolla/configuration_evolla.py +29 -32
  442. transformers/models/evolla/modeling_evolla.py +58 -61
  443. transformers/models/evolla/modular_evolla.py +45 -47
  444. transformers/models/evolla/processing_evolla.py +23 -35
  445. transformers/models/exaone4/configuration_exaone4.py +19 -22
  446. transformers/models/exaone4/modeling_exaone4.py +32 -35
  447. transformers/models/exaone4/modular_exaone4.py +40 -42
  448. transformers/models/falcon/configuration_falcon.py +22 -25
  449. transformers/models/falcon/modeling_falcon.py +73 -76
  450. transformers/models/falcon_h1/configuration_falcon_h1.py +40 -43
  451. transformers/models/falcon_h1/modeling_falcon_h1.py +52 -55
  452. transformers/models/falcon_h1/modular_falcon_h1.py +47 -48
  453. transformers/models/falcon_mamba/configuration_falcon_mamba.py +0 -1
  454. transformers/models/falcon_mamba/modeling_falcon_mamba.py +46 -47
  455. transformers/models/falcon_mamba/modular_falcon_mamba.py +10 -13
  456. transformers/models/fast_vlm/configuration_fast_vlm.py +1 -0
  457. transformers/models/fast_vlm/modeling_fast_vlm.py +36 -36
  458. transformers/models/fast_vlm/modular_fast_vlm.py +2 -3
  459. transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py +2 -5
  460. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +45 -47
  461. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -3
  462. transformers/models/flaubert/configuration_flaubert.py +0 -1
  463. transformers/models/flaubert/modeling_flaubert.py +124 -128
  464. transformers/models/flaubert/tokenization_flaubert.py +3 -5
  465. transformers/models/flava/configuration_flava.py +5 -6
  466. transformers/models/flava/image_processing_flava.py +66 -67
  467. transformers/models/flava/image_processing_flava_fast.py +42 -43
  468. transformers/models/flava/modeling_flava.py +108 -107
  469. transformers/models/flava/processing_flava.py +2 -12
  470. transformers/models/flex_olmo/__init__.py +0 -1
  471. transformers/models/flex_olmo/configuration_flex_olmo.py +23 -25
  472. transformers/models/flex_olmo/modeling_flex_olmo.py +37 -39
  473. transformers/models/flex_olmo/modular_flex_olmo.py +35 -37
  474. transformers/models/florence2/configuration_florence2.py +0 -1
  475. transformers/models/florence2/modeling_florence2.py +39 -40
  476. transformers/models/florence2/modular_florence2.py +52 -81
  477. transformers/models/florence2/processing_florence2.py +18 -47
  478. transformers/models/fnet/configuration_fnet.py +0 -1
  479. transformers/models/fnet/modeling_fnet.py +69 -80
  480. transformers/models/fnet/tokenization_fnet.py +0 -1
  481. transformers/models/focalnet/configuration_focalnet.py +0 -1
  482. transformers/models/focalnet/modeling_focalnet.py +39 -41
  483. transformers/models/fsmt/configuration_fsmt.py +0 -1
  484. transformers/models/fsmt/modeling_fsmt.py +47 -48
  485. transformers/models/fsmt/tokenization_fsmt.py +3 -5
  486. transformers/models/funnel/configuration_funnel.py +0 -1
  487. transformers/models/funnel/modeling_funnel.py +91 -93
  488. transformers/models/funnel/tokenization_funnel.py +2 -5
  489. transformers/models/fuyu/configuration_fuyu.py +23 -26
  490. transformers/models/fuyu/image_processing_fuyu.py +29 -31
  491. transformers/models/fuyu/image_processing_fuyu_fast.py +12 -13
  492. transformers/models/fuyu/modeling_fuyu.py +26 -29
  493. transformers/models/fuyu/processing_fuyu.py +9 -36
  494. transformers/models/gemma/configuration_gemma.py +20 -23
  495. transformers/models/gemma/modeling_gemma.py +32 -34
  496. transformers/models/gemma/modular_gemma.py +28 -29
  497. transformers/models/gemma/tokenization_gemma.py +3 -6
  498. transformers/models/gemma2/configuration_gemma2.py +25 -28
  499. transformers/models/gemma2/modeling_gemma2.py +34 -37
  500. transformers/models/gemma2/modular_gemma2.py +55 -57
  501. transformers/models/gemma3/configuration_gemma3.py +28 -29
  502. transformers/models/gemma3/image_processing_gemma3.py +29 -31
  503. transformers/models/gemma3/image_processing_gemma3_fast.py +9 -10
  504. transformers/models/gemma3/modeling_gemma3.py +86 -89
  505. transformers/models/gemma3/modular_gemma3.py +85 -86
  506. transformers/models/gemma3/processing_gemma3.py +5 -5
  507. transformers/models/gemma3n/configuration_gemma3n.py +9 -10
  508. transformers/models/gemma3n/feature_extraction_gemma3n.py +9 -11
  509. transformers/models/gemma3n/modeling_gemma3n.py +80 -89
  510. transformers/models/gemma3n/modular_gemma3n.py +66 -75
  511. transformers/models/gemma3n/processing_gemma3n.py +12 -26
  512. transformers/models/git/configuration_git.py +0 -1
  513. transformers/models/git/modeling_git.py +84 -86
  514. transformers/models/git/processing_git.py +2 -14
  515. transformers/models/glm/configuration_glm.py +19 -21
  516. transformers/models/glm/modeling_glm.py +32 -35
  517. transformers/models/glm/modular_glm.py +4 -7
  518. transformers/models/glm4/configuration_glm4.py +19 -21
  519. transformers/models/glm4/modeling_glm4.py +35 -37
  520. transformers/models/glm4/modular_glm4.py +8 -10
  521. transformers/models/glm46v/configuration_glm46v.py +0 -1
  522. transformers/models/glm46v/image_processing_glm46v.py +35 -36
  523. transformers/models/glm46v/image_processing_glm46v_fast.py +7 -7
  524. transformers/models/glm46v/modeling_glm46v.py +51 -51
  525. transformers/models/glm46v/modular_glm46v.py +1 -3
  526. transformers/models/glm46v/processing_glm46v.py +7 -41
  527. transformers/models/glm46v/video_processing_glm46v.py +9 -11
  528. transformers/models/glm4_moe/configuration_glm4_moe.py +25 -28
  529. transformers/models/glm4_moe/modeling_glm4_moe.py +32 -35
  530. transformers/models/glm4_moe/modular_glm4_moe.py +26 -29
  531. transformers/models/glm4_moe_lite/__init__.py +28 -0
  532. transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py +235 -0
  533. transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py +740 -0
  534. transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py +304 -0
  535. transformers/models/glm4v/configuration_glm4v.py +14 -17
  536. transformers/models/glm4v/image_processing_glm4v.py +34 -36
  537. transformers/models/glm4v/image_processing_glm4v_fast.py +6 -7
  538. transformers/models/glm4v/modeling_glm4v.py +133 -151
  539. transformers/models/glm4v/modular_glm4v.py +131 -182
  540. transformers/models/glm4v/processing_glm4v.py +7 -41
  541. transformers/models/glm4v/video_processing_glm4v.py +9 -11
  542. transformers/models/glm4v_moe/configuration_glm4v_moe.py +119 -122
  543. transformers/models/glm4v_moe/modeling_glm4v_moe.py +237 -297
  544. transformers/models/glm4v_moe/modular_glm4v_moe.py +54 -163
  545. transformers/models/glm_image/__init__.py +31 -0
  546. transformers/models/glm_image/configuration_glm_image.py +352 -0
  547. transformers/models/glm_image/image_processing_glm_image.py +503 -0
  548. transformers/models/glm_image/image_processing_glm_image_fast.py +296 -0
  549. transformers/models/glm_image/modeling_glm_image.py +1590 -0
  550. transformers/models/glm_image/modular_glm_image.py +1480 -0
  551. transformers/models/glm_image/processing_glm_image.py +217 -0
  552. transformers/models/glmasr/__init__.py +0 -1
  553. transformers/models/glmasr/configuration_glmasr.py +0 -1
  554. transformers/models/glmasr/modeling_glmasr.py +17 -18
  555. transformers/models/glmasr/modular_glmasr.py +16 -18
  556. transformers/models/glmasr/processing_glmasr.py +7 -8
  557. transformers/models/glpn/configuration_glpn.py +0 -1
  558. transformers/models/glpn/image_processing_glpn.py +11 -12
  559. transformers/models/glpn/image_processing_glpn_fast.py +8 -9
  560. transformers/models/glpn/modeling_glpn.py +10 -12
  561. transformers/models/got_ocr2/configuration_got_ocr2.py +5 -8
  562. transformers/models/got_ocr2/image_processing_got_ocr2.py +22 -24
  563. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +6 -7
  564. transformers/models/got_ocr2/modeling_got_ocr2.py +40 -42
  565. transformers/models/got_ocr2/modular_got_ocr2.py +31 -34
  566. transformers/models/got_ocr2/processing_got_ocr2.py +42 -63
  567. transformers/models/gpt2/configuration_gpt2.py +0 -1
  568. transformers/models/gpt2/modeling_gpt2.py +106 -108
  569. transformers/models/gpt2/tokenization_gpt2.py +6 -9
  570. transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +0 -1
  571. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +73 -80
  572. transformers/models/gpt_neo/configuration_gpt_neo.py +0 -1
  573. transformers/models/gpt_neo/modeling_gpt_neo.py +63 -64
  574. transformers/models/gpt_neox/configuration_gpt_neox.py +19 -22
  575. transformers/models/gpt_neox/modeling_gpt_neox.py +70 -72
  576. transformers/models/gpt_neox/modular_gpt_neox.py +64 -66
  577. transformers/models/gpt_neox/tokenization_gpt_neox.py +2 -5
  578. transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py +15 -18
  579. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +41 -44
  580. transformers/models/gpt_neox_japanese/tokenization_gpt_neox_japanese.py +1 -3
  581. transformers/models/gpt_oss/configuration_gpt_oss.py +21 -24
  582. transformers/models/gpt_oss/modeling_gpt_oss.py +34 -35
  583. transformers/models/gpt_oss/modular_gpt_oss.py +17 -19
  584. transformers/models/gpt_sw3/tokenization_gpt_sw3.py +4 -4
  585. transformers/models/gptj/configuration_gptj.py +0 -1
  586. transformers/models/gptj/modeling_gptj.py +82 -81
  587. transformers/models/granite/configuration_granite.py +23 -26
  588. transformers/models/granite/modeling_granite.py +39 -41
  589. transformers/models/granite/modular_granite.py +29 -31
  590. transformers/models/granite_speech/configuration_granite_speech.py +0 -1
  591. transformers/models/granite_speech/feature_extraction_granite_speech.py +1 -3
  592. transformers/models/granite_speech/modeling_granite_speech.py +21 -23
  593. transformers/models/granite_speech/processing_granite_speech.py +11 -4
  594. transformers/models/granitemoe/configuration_granitemoe.py +26 -29
  595. transformers/models/granitemoe/modeling_granitemoe.py +35 -37
  596. transformers/models/granitemoe/modular_granitemoe.py +21 -23
  597. transformers/models/granitemoehybrid/__init__.py +0 -1
  598. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +38 -41
  599. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +60 -64
  600. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +18 -20
  601. transformers/models/granitemoeshared/configuration_granitemoeshared.py +27 -30
  602. transformers/models/granitemoeshared/modeling_granitemoeshared.py +48 -52
  603. transformers/models/granitemoeshared/modular_granitemoeshared.py +19 -21
  604. transformers/models/grounding_dino/configuration_grounding_dino.py +0 -1
  605. transformers/models/grounding_dino/image_processing_grounding_dino.py +60 -62
  606. transformers/models/grounding_dino/image_processing_grounding_dino_fast.py +17 -18
  607. transformers/models/grounding_dino/modeling_grounding_dino.py +94 -96
  608. transformers/models/grounding_dino/modular_grounding_dino.py +2 -3
  609. transformers/models/grounding_dino/processing_grounding_dino.py +10 -38
  610. transformers/models/groupvit/configuration_groupvit.py +0 -1
  611. transformers/models/groupvit/modeling_groupvit.py +69 -70
  612. transformers/models/helium/configuration_helium.py +20 -22
  613. transformers/models/helium/modeling_helium.py +33 -36
  614. transformers/models/helium/modular_helium.py +3 -7
  615. transformers/models/herbert/tokenization_herbert.py +4 -6
  616. transformers/models/hgnet_v2/configuration_hgnet_v2.py +0 -1
  617. transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -9
  618. transformers/models/hgnet_v2/modular_hgnet_v2.py +6 -9
  619. transformers/models/hiera/configuration_hiera.py +0 -1
  620. transformers/models/hiera/modeling_hiera.py +60 -62
  621. transformers/models/hubert/configuration_hubert.py +0 -1
  622. transformers/models/hubert/modeling_hubert.py +35 -37
  623. transformers/models/hubert/modular_hubert.py +8 -11
  624. transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py +21 -24
  625. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +30 -33
  626. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +3 -5
  627. transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py +25 -28
  628. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +32 -35
  629. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +5 -7
  630. transformers/models/ibert/configuration_ibert.py +0 -1
  631. transformers/models/ibert/modeling_ibert.py +60 -62
  632. transformers/models/ibert/quant_modules.py +0 -1
  633. transformers/models/idefics/configuration_idefics.py +0 -1
  634. transformers/models/idefics/image_processing_idefics.py +13 -15
  635. transformers/models/idefics/modeling_idefics.py +60 -61
  636. transformers/models/idefics/perceiver.py +1 -3
  637. transformers/models/idefics/processing_idefics.py +32 -48
  638. transformers/models/idefics/vision.py +22 -24
  639. transformers/models/idefics2/configuration_idefics2.py +0 -1
  640. transformers/models/idefics2/image_processing_idefics2.py +31 -32
  641. transformers/models/idefics2/image_processing_idefics2_fast.py +7 -8
  642. transformers/models/idefics2/modeling_idefics2.py +56 -58
  643. transformers/models/idefics2/processing_idefics2.py +10 -68
  644. transformers/models/idefics3/configuration_idefics3.py +0 -1
  645. transformers/models/idefics3/image_processing_idefics3.py +42 -43
  646. transformers/models/idefics3/image_processing_idefics3_fast.py +11 -12
  647. transformers/models/idefics3/modeling_idefics3.py +52 -54
  648. transformers/models/idefics3/processing_idefics3.py +15 -69
  649. transformers/models/ijepa/configuration_ijepa.py +0 -1
  650. transformers/models/ijepa/modeling_ijepa.py +10 -11
  651. transformers/models/ijepa/modular_ijepa.py +5 -7
  652. transformers/models/imagegpt/configuration_imagegpt.py +0 -1
  653. transformers/models/imagegpt/image_processing_imagegpt.py +17 -18
  654. transformers/models/imagegpt/image_processing_imagegpt_fast.py +8 -9
  655. transformers/models/imagegpt/modeling_imagegpt.py +57 -58
  656. transformers/models/informer/configuration_informer.py +6 -9
  657. transformers/models/informer/modeling_informer.py +84 -86
  658. transformers/models/informer/modular_informer.py +13 -16
  659. transformers/models/instructblip/configuration_instructblip.py +0 -1
  660. transformers/models/instructblip/modeling_instructblip.py +43 -44
  661. transformers/models/instructblip/processing_instructblip.py +10 -36
  662. transformers/models/instructblipvideo/configuration_instructblipvideo.py +0 -1
  663. transformers/models/instructblipvideo/modeling_instructblipvideo.py +55 -55
  664. transformers/models/instructblipvideo/modular_instructblipvideo.py +34 -36
  665. transformers/models/instructblipvideo/processing_instructblipvideo.py +14 -33
  666. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +4 -5
  667. transformers/models/internvl/configuration_internvl.py +0 -1
  668. transformers/models/internvl/modeling_internvl.py +41 -43
  669. transformers/models/internvl/modular_internvl.py +19 -21
  670. transformers/models/internvl/processing_internvl.py +12 -45
  671. transformers/models/internvl/video_processing_internvl.py +8 -9
  672. transformers/models/jais2/configuration_jais2.py +20 -22
  673. transformers/models/jais2/modeling_jais2.py +32 -34
  674. transformers/models/jais2/modular_jais2.py +20 -22
  675. transformers/models/jamba/configuration_jamba.py +0 -1
  676. transformers/models/jamba/modeling_jamba.py +43 -46
  677. transformers/models/jamba/modular_jamba.py +37 -38
  678. transformers/models/janus/configuration_janus.py +0 -1
  679. transformers/models/janus/image_processing_janus.py +35 -37
  680. transformers/models/janus/image_processing_janus_fast.py +12 -13
  681. transformers/models/janus/modeling_janus.py +41 -43
  682. transformers/models/janus/modular_janus.py +60 -63
  683. transformers/models/janus/processing_janus.py +17 -43
  684. transformers/models/jetmoe/configuration_jetmoe.py +20 -23
  685. transformers/models/jetmoe/modeling_jetmoe.py +39 -42
  686. transformers/models/jetmoe/modular_jetmoe.py +30 -33
  687. transformers/models/kosmos2/configuration_kosmos2.py +0 -1
  688. transformers/models/kosmos2/modeling_kosmos2.py +145 -146
  689. transformers/models/kosmos2/processing_kosmos2.py +40 -55
  690. transformers/models/kosmos2_5/__init__.py +0 -1
  691. transformers/models/kosmos2_5/configuration_kosmos2_5.py +0 -1
  692. transformers/models/kosmos2_5/image_processing_kosmos2_5.py +10 -12
  693. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -11
  694. transformers/models/kosmos2_5/modeling_kosmos2_5.py +108 -109
  695. transformers/models/kosmos2_5/processing_kosmos2_5.py +8 -29
  696. transformers/models/kyutai_speech_to_text/configuration_kyutai_speech_to_text.py +23 -25
  697. transformers/models/kyutai_speech_to_text/feature_extraction_kyutai_speech_to_text.py +12 -14
  698. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +59 -66
  699. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +19 -21
  700. transformers/models/kyutai_speech_to_text/processing_kyutai_speech_to_text.py +2 -8
  701. transformers/models/lasr/configuration_lasr.py +1 -3
  702. transformers/models/lasr/feature_extraction_lasr.py +10 -12
  703. transformers/models/lasr/modeling_lasr.py +18 -21
  704. transformers/models/lasr/modular_lasr.py +8 -10
  705. transformers/models/lasr/processing_lasr.py +12 -6
  706. transformers/models/lasr/tokenization_lasr.py +2 -4
  707. transformers/models/layoutlm/configuration_layoutlm.py +0 -1
  708. transformers/models/layoutlm/modeling_layoutlm.py +67 -69
  709. transformers/models/layoutlmv2/configuration_layoutlmv2.py +0 -1
  710. transformers/models/layoutlmv2/image_processing_layoutlmv2.py +18 -21
  711. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +5 -6
  712. transformers/models/layoutlmv2/modeling_layoutlmv2.py +48 -50
  713. transformers/models/layoutlmv2/processing_layoutlmv2.py +14 -44
  714. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +63 -74
  715. transformers/models/layoutlmv3/configuration_layoutlmv3.py +0 -1
  716. transformers/models/layoutlmv3/image_processing_layoutlmv3.py +24 -26
  717. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +7 -8
  718. transformers/models/layoutlmv3/modeling_layoutlmv3.py +49 -51
  719. transformers/models/layoutlmv3/processing_layoutlmv3.py +14 -46
  720. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +64 -75
  721. transformers/models/layoutxlm/configuration_layoutxlm.py +0 -1
  722. transformers/models/layoutxlm/modular_layoutxlm.py +0 -1
  723. transformers/models/layoutxlm/processing_layoutxlm.py +14 -44
  724. transformers/models/layoutxlm/tokenization_layoutxlm.py +65 -76
  725. transformers/models/led/configuration_led.py +1 -4
  726. transformers/models/led/modeling_led.py +113 -267
  727. transformers/models/levit/configuration_levit.py +0 -1
  728. transformers/models/levit/image_processing_levit.py +19 -21
  729. transformers/models/levit/image_processing_levit_fast.py +0 -1
  730. transformers/models/levit/modeling_levit.py +17 -19
  731. transformers/models/lfm2/configuration_lfm2.py +22 -23
  732. transformers/models/lfm2/modeling_lfm2.py +42 -44
  733. transformers/models/lfm2/modular_lfm2.py +29 -29
  734. transformers/models/lfm2_moe/__init__.py +0 -1
  735. transformers/models/lfm2_moe/configuration_lfm2_moe.py +1 -2
  736. transformers/models/lfm2_moe/modeling_lfm2_moe.py +44 -45
  737. transformers/models/lfm2_moe/modular_lfm2_moe.py +8 -9
  738. transformers/models/lfm2_vl/configuration_lfm2_vl.py +0 -1
  739. transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py +34 -5
  740. transformers/models/lfm2_vl/modeling_lfm2_vl.py +31 -33
  741. transformers/models/lfm2_vl/modular_lfm2_vl.py +24 -27
  742. transformers/models/lfm2_vl/processing_lfm2_vl.py +14 -34
  743. transformers/models/lightglue/image_processing_lightglue.py +16 -15
  744. transformers/models/lightglue/image_processing_lightglue_fast.py +4 -4
  745. transformers/models/lightglue/modeling_lightglue.py +28 -30
  746. transformers/models/lightglue/modular_lightglue.py +28 -28
  747. transformers/models/lighton_ocr/__init__.py +28 -0
  748. transformers/models/lighton_ocr/configuration_lighton_ocr.py +128 -0
  749. transformers/models/lighton_ocr/modeling_lighton_ocr.py +460 -0
  750. transformers/models/lighton_ocr/modular_lighton_ocr.py +403 -0
  751. transformers/models/lighton_ocr/processing_lighton_ocr.py +229 -0
  752. transformers/models/lilt/configuration_lilt.py +0 -1
  753. transformers/models/lilt/modeling_lilt.py +53 -55
  754. transformers/models/llama/configuration_llama.py +21 -24
  755. transformers/models/llama/modeling_llama.py +31 -34
  756. transformers/models/llama/tokenization_llama.py +2 -4
  757. transformers/models/llama4/configuration_llama4.py +20 -22
  758. transformers/models/llama4/image_processing_llama4_fast.py +8 -9
  759. transformers/models/llama4/modeling_llama4.py +70 -71
  760. transformers/models/llama4/processing_llama4.py +33 -57
  761. transformers/models/llava/configuration_llava.py +0 -1
  762. transformers/models/llava/image_processing_llava.py +25 -28
  763. transformers/models/llava/image_processing_llava_fast.py +6 -7
  764. transformers/models/llava/modeling_llava.py +35 -37
  765. transformers/models/llava/processing_llava.py +18 -51
  766. transformers/models/llava_next/configuration_llava_next.py +0 -1
  767. transformers/models/llava_next/image_processing_llava_next.py +43 -45
  768. transformers/models/llava_next/image_processing_llava_next_fast.py +5 -6
  769. transformers/models/llava_next/modeling_llava_next.py +42 -44
  770. transformers/models/llava_next/processing_llava_next.py +18 -47
  771. transformers/models/llava_next_video/configuration_llava_next_video.py +0 -1
  772. transformers/models/llava_next_video/modeling_llava_next_video.py +53 -55
  773. transformers/models/llava_next_video/modular_llava_next_video.py +44 -46
  774. transformers/models/llava_next_video/processing_llava_next_video.py +21 -63
  775. transformers/models/llava_next_video/video_processing_llava_next_video.py +0 -1
  776. transformers/models/llava_onevision/configuration_llava_onevision.py +0 -1
  777. transformers/models/llava_onevision/image_processing_llava_onevision.py +40 -42
  778. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +6 -7
  779. transformers/models/llava_onevision/modeling_llava_onevision.py +60 -62
  780. transformers/models/llava_onevision/modular_llava_onevision.py +51 -52
  781. transformers/models/llava_onevision/processing_llava_onevision.py +21 -53
  782. transformers/models/llava_onevision/video_processing_llava_onevision.py +0 -1
  783. transformers/models/longcat_flash/__init__.py +0 -1
  784. transformers/models/longcat_flash/configuration_longcat_flash.py +32 -35
  785. transformers/models/longcat_flash/modeling_longcat_flash.py +30 -31
  786. transformers/models/longcat_flash/modular_longcat_flash.py +17 -19
  787. transformers/models/longformer/configuration_longformer.py +1 -4
  788. transformers/models/longformer/modeling_longformer.py +99 -101
  789. transformers/models/longt5/configuration_longt5.py +0 -1
  790. transformers/models/longt5/modeling_longt5.py +43 -44
  791. transformers/models/luke/configuration_luke.py +0 -1
  792. transformers/models/luke/modeling_luke.py +179 -181
  793. transformers/models/luke/tokenization_luke.py +99 -105
  794. transformers/models/lw_detr/__init__.py +27 -0
  795. transformers/models/lw_detr/configuration_lw_detr.py +374 -0
  796. transformers/models/lw_detr/modeling_lw_detr.py +1698 -0
  797. transformers/models/lw_detr/modular_lw_detr.py +1611 -0
  798. transformers/models/lxmert/configuration_lxmert.py +0 -1
  799. transformers/models/lxmert/modeling_lxmert.py +63 -74
  800. transformers/models/m2m_100/configuration_m2m_100.py +0 -1
  801. transformers/models/m2m_100/modeling_m2m_100.py +69 -71
  802. transformers/models/m2m_100/tokenization_m2m_100.py +8 -8
  803. transformers/models/mamba/configuration_mamba.py +0 -1
  804. transformers/models/mamba/modeling_mamba.py +43 -44
  805. transformers/models/mamba2/configuration_mamba2.py +0 -1
  806. transformers/models/mamba2/modeling_mamba2.py +44 -46
  807. transformers/models/marian/configuration_marian.py +0 -1
  808. transformers/models/marian/modeling_marian.py +84 -86
  809. transformers/models/marian/tokenization_marian.py +6 -6
  810. transformers/models/markuplm/configuration_markuplm.py +0 -1
  811. transformers/models/markuplm/feature_extraction_markuplm.py +1 -2
  812. transformers/models/markuplm/modeling_markuplm.py +60 -62
  813. transformers/models/markuplm/processing_markuplm.py +31 -38
  814. transformers/models/markuplm/tokenization_markuplm.py +67 -77
  815. transformers/models/mask2former/configuration_mask2former.py +4 -7
  816. transformers/models/mask2former/image_processing_mask2former.py +84 -85
  817. transformers/models/mask2former/image_processing_mask2former_fast.py +29 -29
  818. transformers/models/mask2former/modeling_mask2former.py +90 -92
  819. transformers/models/mask2former/modular_mask2former.py +6 -8
  820. transformers/models/maskformer/configuration_maskformer.py +5 -8
  821. transformers/models/maskformer/configuration_maskformer_swin.py +0 -1
  822. transformers/models/maskformer/image_processing_maskformer.py +84 -85
  823. transformers/models/maskformer/image_processing_maskformer_fast.py +28 -29
  824. transformers/models/maskformer/modeling_maskformer.py +56 -58
  825. transformers/models/maskformer/modeling_maskformer_swin.py +18 -20
  826. transformers/models/mbart/configuration_mbart.py +0 -1
  827. transformers/models/mbart/modeling_mbart.py +111 -113
  828. transformers/models/mbart/tokenization_mbart.py +2 -4
  829. transformers/models/mbart50/tokenization_mbart50.py +3 -5
  830. transformers/models/megatron_bert/configuration_megatron_bert.py +0 -1
  831. transformers/models/megatron_bert/modeling_megatron_bert.py +139 -150
  832. transformers/models/metaclip_2/modeling_metaclip_2.py +46 -46
  833. transformers/models/metaclip_2/modular_metaclip_2.py +19 -21
  834. transformers/models/mgp_str/configuration_mgp_str.py +0 -1
  835. transformers/models/mgp_str/modeling_mgp_str.py +14 -16
  836. transformers/models/mgp_str/processing_mgp_str.py +3 -20
  837. transformers/models/mgp_str/tokenization_mgp_str.py +1 -3
  838. transformers/models/mimi/configuration_mimi.py +38 -40
  839. transformers/models/mimi/modeling_mimi.py +76 -79
  840. transformers/models/minimax/__init__.py +0 -1
  841. transformers/models/minimax/configuration_minimax.py +32 -36
  842. transformers/models/minimax/modeling_minimax.py +41 -44
  843. transformers/models/minimax/modular_minimax.py +50 -53
  844. transformers/models/minimax_m2/__init__.py +28 -0
  845. transformers/models/minimax_m2/configuration_minimax_m2.py +211 -0
  846. transformers/models/minimax_m2/modeling_minimax_m2.py +704 -0
  847. transformers/models/minimax_m2/modular_minimax_m2.py +369 -0
  848. transformers/models/ministral/configuration_ministral.py +20 -22
  849. transformers/models/ministral/modeling_ministral.py +31 -33
  850. transformers/models/ministral/modular_ministral.py +27 -29
  851. transformers/models/ministral3/configuration_ministral3.py +19 -22
  852. transformers/models/ministral3/modeling_ministral3.py +31 -33
  853. transformers/models/ministral3/modular_ministral3.py +4 -5
  854. transformers/models/mistral/configuration_mistral.py +19 -22
  855. transformers/models/mistral/modeling_mistral.py +31 -33
  856. transformers/models/mistral/modular_mistral.py +11 -12
  857. transformers/models/mistral3/configuration_mistral3.py +0 -1
  858. transformers/models/mistral3/modeling_mistral3.py +43 -42
  859. transformers/models/mistral3/modular_mistral3.py +35 -35
  860. transformers/models/mixtral/configuration_mixtral.py +24 -27
  861. transformers/models/mixtral/modeling_mixtral.py +35 -38
  862. transformers/models/mixtral/modular_mixtral.py +26 -29
  863. transformers/models/mlcd/configuration_mlcd.py +0 -1
  864. transformers/models/mlcd/modeling_mlcd.py +10 -12
  865. transformers/models/mlcd/modular_mlcd.py +9 -11
  866. transformers/models/mllama/configuration_mllama.py +5 -8
  867. transformers/models/mllama/image_processing_mllama.py +23 -25
  868. transformers/models/mllama/image_processing_mllama_fast.py +5 -6
  869. transformers/models/mllama/modeling_mllama.py +81 -84
  870. transformers/models/mllama/processing_mllama.py +6 -55
  871. transformers/models/mluke/tokenization_mluke.py +97 -103
  872. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +0 -1
  873. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +94 -96
  874. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +0 -1
  875. transformers/models/mobilebert/configuration_mobilebert.py +0 -1
  876. transformers/models/mobilebert/modeling_mobilebert.py +75 -85
  877. transformers/models/mobilebert/tokenization_mobilebert.py +0 -1
  878. transformers/models/mobilenet_v1/configuration_mobilenet_v1.py +0 -1
  879. transformers/models/mobilenet_v1/image_processing_mobilenet_v1.py +20 -23
  880. transformers/models/mobilenet_v1/image_processing_mobilenet_v1_fast.py +0 -1
  881. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +13 -16
  882. transformers/models/mobilenet_v2/configuration_mobilenet_v2.py +0 -1
  883. transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py +48 -51
  884. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +10 -11
  885. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +17 -20
  886. transformers/models/mobilevit/configuration_mobilevit.py +0 -1
  887. transformers/models/mobilevit/image_processing_mobilevit.py +41 -44
  888. transformers/models/mobilevit/image_processing_mobilevit_fast.py +8 -9
  889. transformers/models/mobilevit/modeling_mobilevit.py +17 -19
  890. transformers/models/mobilevitv2/configuration_mobilevitv2.py +0 -1
  891. transformers/models/mobilevitv2/modeling_mobilevitv2.py +17 -20
  892. transformers/models/modernbert/configuration_modernbert.py +34 -34
  893. transformers/models/modernbert/modeling_modernbert.py +123 -125
  894. transformers/models/modernbert/modular_modernbert.py +155 -155
  895. transformers/models/modernbert_decoder/configuration_modernbert_decoder.py +30 -32
  896. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +45 -47
  897. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +69 -70
  898. transformers/models/moonshine/configuration_moonshine.py +22 -24
  899. transformers/models/moonshine/modeling_moonshine.py +63 -65
  900. transformers/models/moonshine/modular_moonshine.py +72 -73
  901. transformers/models/moshi/configuration_moshi.py +18 -21
  902. transformers/models/moshi/modeling_moshi.py +130 -133
  903. transformers/models/mpnet/configuration_mpnet.py +0 -1
  904. transformers/models/mpnet/modeling_mpnet.py +55 -57
  905. transformers/models/mpnet/tokenization_mpnet.py +1 -4
  906. transformers/models/mpt/configuration_mpt.py +1 -9
  907. transformers/models/mpt/modeling_mpt.py +58 -60
  908. transformers/models/mra/configuration_mra.py +0 -1
  909. transformers/models/mra/modeling_mra.py +54 -56
  910. transformers/models/mt5/configuration_mt5.py +0 -1
  911. transformers/models/mt5/modeling_mt5.py +75 -77
  912. transformers/models/musicgen/configuration_musicgen.py +0 -1
  913. transformers/models/musicgen/modeling_musicgen.py +108 -111
  914. transformers/models/musicgen/processing_musicgen.py +3 -21
  915. transformers/models/musicgen_melody/configuration_musicgen_melody.py +0 -1
  916. transformers/models/musicgen_melody/feature_extraction_musicgen_melody.py +8 -9
  917. transformers/models/musicgen_melody/modeling_musicgen_melody.py +106 -109
  918. transformers/models/musicgen_melody/processing_musicgen_melody.py +3 -22
  919. transformers/models/mvp/configuration_mvp.py +0 -1
  920. transformers/models/mvp/modeling_mvp.py +115 -119
  921. transformers/models/myt5/tokenization_myt5.py +8 -10
  922. transformers/models/nanochat/configuration_nanochat.py +0 -1
  923. transformers/models/nanochat/modeling_nanochat.py +32 -35
  924. transformers/models/nanochat/modular_nanochat.py +12 -14
  925. transformers/models/nemotron/configuration_nemotron.py +20 -23
  926. transformers/models/nemotron/modeling_nemotron.py +49 -52
  927. transformers/models/nllb/tokenization_nllb.py +7 -9
  928. transformers/models/nllb_moe/configuration_nllb_moe.py +0 -1
  929. transformers/models/nllb_moe/modeling_nllb_moe.py +67 -69
  930. transformers/models/nougat/image_processing_nougat.py +29 -32
  931. transformers/models/nougat/image_processing_nougat_fast.py +4 -5
  932. transformers/models/nougat/processing_nougat.py +37 -39
  933. transformers/models/nougat/tokenization_nougat.py +5 -7
  934. transformers/models/nystromformer/configuration_nystromformer.py +0 -1
  935. transformers/models/nystromformer/modeling_nystromformer.py +61 -63
  936. transformers/models/olmo/configuration_olmo.py +18 -21
  937. transformers/models/olmo/modeling_olmo.py +31 -34
  938. transformers/models/olmo/modular_olmo.py +5 -9
  939. transformers/models/olmo2/configuration_olmo2.py +18 -21
  940. transformers/models/olmo2/modeling_olmo2.py +32 -35
  941. transformers/models/olmo2/modular_olmo2.py +29 -31
  942. transformers/models/olmo3/__init__.py +0 -1
  943. transformers/models/olmo3/configuration_olmo3.py +20 -23
  944. transformers/models/olmo3/modeling_olmo3.py +31 -34
  945. transformers/models/olmo3/modular_olmo3.py +31 -33
  946. transformers/models/olmoe/configuration_olmoe.py +24 -26
  947. transformers/models/olmoe/modeling_olmoe.py +37 -39
  948. transformers/models/olmoe/modular_olmoe.py +12 -13
  949. transformers/models/omdet_turbo/configuration_omdet_turbo.py +0 -1
  950. transformers/models/omdet_turbo/modeling_omdet_turbo.py +38 -40
  951. transformers/models/omdet_turbo/processing_omdet_turbo.py +19 -67
  952. transformers/models/oneformer/configuration_oneformer.py +4 -7
  953. transformers/models/oneformer/image_processing_oneformer.py +83 -84
  954. transformers/models/oneformer/image_processing_oneformer_fast.py +33 -34
  955. transformers/models/oneformer/modeling_oneformer.py +123 -124
  956. transformers/models/oneformer/processing_oneformer.py +28 -43
  957. transformers/models/openai/configuration_openai.py +0 -1
  958. transformers/models/openai/modeling_openai.py +50 -51
  959. transformers/models/openai/tokenization_openai.py +2 -5
  960. transformers/models/opt/configuration_opt.py +0 -1
  961. transformers/models/opt/modeling_opt.py +74 -75
  962. transformers/models/ovis2/__init__.py +0 -1
  963. transformers/models/ovis2/configuration_ovis2.py +0 -1
  964. transformers/models/ovis2/image_processing_ovis2.py +22 -24
  965. transformers/models/ovis2/image_processing_ovis2_fast.py +6 -7
  966. transformers/models/ovis2/modeling_ovis2.py +43 -45
  967. transformers/models/ovis2/modular_ovis2.py +30 -32
  968. transformers/models/ovis2/processing_ovis2.py +12 -40
  969. transformers/models/owlv2/configuration_owlv2.py +0 -1
  970. transformers/models/owlv2/image_processing_owlv2.py +20 -21
  971. transformers/models/owlv2/image_processing_owlv2_fast.py +7 -8
  972. transformers/models/owlv2/modeling_owlv2.py +82 -87
  973. transformers/models/owlv2/modular_owlv2.py +6 -7
  974. transformers/models/owlv2/processing_owlv2.py +20 -49
  975. transformers/models/owlvit/configuration_owlvit.py +0 -1
  976. transformers/models/owlvit/image_processing_owlvit.py +21 -22
  977. transformers/models/owlvit/image_processing_owlvit_fast.py +2 -3
  978. transformers/models/owlvit/modeling_owlvit.py +81 -86
  979. transformers/models/owlvit/processing_owlvit.py +20 -48
  980. transformers/models/paddleocr_vl/__init__.py +0 -1
  981. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +19 -19
  982. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +34 -35
  983. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +12 -12
  984. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +76 -76
  985. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +68 -68
  986. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +1 -3
  987. transformers/models/paligemma/configuration_paligemma.py +0 -1
  988. transformers/models/paligemma/modeling_paligemma.py +51 -53
  989. transformers/models/paligemma/processing_paligemma.py +13 -66
  990. transformers/models/parakeet/configuration_parakeet.py +1 -4
  991. transformers/models/parakeet/feature_extraction_parakeet.py +10 -12
  992. transformers/models/parakeet/modeling_parakeet.py +18 -22
  993. transformers/models/parakeet/modular_parakeet.py +16 -18
  994. transformers/models/parakeet/processing_parakeet.py +12 -5
  995. transformers/models/parakeet/tokenization_parakeet.py +2 -4
  996. transformers/models/patchtsmixer/configuration_patchtsmixer.py +5 -8
  997. transformers/models/patchtsmixer/modeling_patchtsmixer.py +60 -62
  998. transformers/models/patchtst/configuration_patchtst.py +6 -9
  999. transformers/models/patchtst/modeling_patchtst.py +72 -74
  1000. transformers/models/pe_audio/__init__.py +0 -1
  1001. transformers/models/pe_audio/configuration_pe_audio.py +14 -16
  1002. transformers/models/pe_audio/feature_extraction_pe_audio.py +6 -8
  1003. transformers/models/pe_audio/modeling_pe_audio.py +26 -27
  1004. transformers/models/pe_audio/modular_pe_audio.py +16 -17
  1005. transformers/models/pe_audio/processing_pe_audio.py +0 -1
  1006. transformers/models/pe_audio_video/__init__.py +0 -1
  1007. transformers/models/pe_audio_video/configuration_pe_audio_video.py +15 -17
  1008. transformers/models/pe_audio_video/modeling_pe_audio_video.py +60 -61
  1009. transformers/models/pe_audio_video/modular_pe_audio_video.py +52 -53
  1010. transformers/models/pe_audio_video/processing_pe_audio_video.py +0 -1
  1011. transformers/models/pe_video/__init__.py +0 -1
  1012. transformers/models/pe_video/configuration_pe_video.py +14 -16
  1013. transformers/models/pe_video/modeling_pe_video.py +21 -22
  1014. transformers/models/pe_video/modular_pe_video.py +11 -12
  1015. transformers/models/pe_video/video_processing_pe_video.py +2 -4
  1016. transformers/models/pegasus/configuration_pegasus.py +0 -1
  1017. transformers/models/pegasus/modeling_pegasus.py +63 -65
  1018. transformers/models/pegasus/tokenization_pegasus.py +1 -4
  1019. transformers/models/pegasus_x/configuration_pegasus_x.py +0 -1
  1020. transformers/models/pegasus_x/modeling_pegasus_x.py +50 -52
  1021. transformers/models/perceiver/configuration_perceiver.py +0 -1
  1022. transformers/models/perceiver/image_processing_perceiver.py +22 -25
  1023. transformers/models/perceiver/image_processing_perceiver_fast.py +5 -6
  1024. transformers/models/perceiver/modeling_perceiver.py +135 -136
  1025. transformers/models/perceiver/tokenization_perceiver.py +3 -6
  1026. transformers/models/perception_lm/configuration_perception_lm.py +0 -1
  1027. transformers/models/perception_lm/image_processing_perception_lm_fast.py +8 -9
  1028. transformers/models/perception_lm/modeling_perception_lm.py +38 -40
  1029. transformers/models/perception_lm/modular_perception_lm.py +31 -33
  1030. transformers/models/perception_lm/processing_perception_lm.py +13 -47
  1031. transformers/models/perception_lm/video_processing_perception_lm.py +0 -1
  1032. transformers/models/persimmon/configuration_persimmon.py +18 -21
  1033. transformers/models/persimmon/modeling_persimmon.py +39 -42
  1034. transformers/models/phi/configuration_phi.py +19 -22
  1035. transformers/models/phi/modeling_phi.py +35 -37
  1036. transformers/models/phi/modular_phi.py +23 -23
  1037. transformers/models/phi3/configuration_phi3.py +23 -26
  1038. transformers/models/phi3/modeling_phi3.py +33 -36
  1039. transformers/models/phi3/modular_phi3.py +13 -17
  1040. transformers/models/phi4_multimodal/configuration_phi4_multimodal.py +25 -26
  1041. transformers/models/phi4_multimodal/feature_extraction_phi4_multimodal.py +7 -9
  1042. transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py +7 -7
  1043. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +54 -56
  1044. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +59 -60
  1045. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +7 -42
  1046. transformers/models/phimoe/configuration_phimoe.py +26 -29
  1047. transformers/models/phimoe/modeling_phimoe.py +35 -38
  1048. transformers/models/phimoe/modular_phimoe.py +0 -1
  1049. transformers/models/phobert/tokenization_phobert.py +4 -6
  1050. transformers/models/pix2struct/configuration_pix2struct.py +0 -1
  1051. transformers/models/pix2struct/image_processing_pix2struct.py +15 -19
  1052. transformers/models/pix2struct/image_processing_pix2struct_fast.py +7 -10
  1053. transformers/models/pix2struct/modeling_pix2struct.py +42 -45
  1054. transformers/models/pix2struct/processing_pix2struct.py +5 -26
  1055. transformers/models/pixio/__init__.py +0 -1
  1056. transformers/models/pixio/configuration_pixio.py +0 -1
  1057. transformers/models/pixio/modeling_pixio.py +7 -9
  1058. transformers/models/pixio/modular_pixio.py +3 -6
  1059. transformers/models/pixtral/configuration_pixtral.py +11 -14
  1060. transformers/models/pixtral/image_processing_pixtral.py +26 -28
  1061. transformers/models/pixtral/image_processing_pixtral_fast.py +5 -6
  1062. transformers/models/pixtral/modeling_pixtral.py +22 -25
  1063. transformers/models/pixtral/processing_pixtral.py +18 -52
  1064. transformers/models/plbart/configuration_plbart.py +0 -1
  1065. transformers/models/plbart/modeling_plbart.py +100 -102
  1066. transformers/models/plbart/modular_plbart.py +30 -32
  1067. transformers/models/plbart/tokenization_plbart.py +4 -5
  1068. transformers/models/poolformer/configuration_poolformer.py +0 -1
  1069. transformers/models/poolformer/image_processing_poolformer.py +21 -24
  1070. transformers/models/poolformer/image_processing_poolformer_fast.py +6 -7
  1071. transformers/models/poolformer/modeling_poolformer.py +10 -12
  1072. transformers/models/pop2piano/configuration_pop2piano.py +0 -1
  1073. transformers/models/pop2piano/feature_extraction_pop2piano.py +6 -9
  1074. transformers/models/pop2piano/modeling_pop2piano.py +22 -23
  1075. transformers/models/pop2piano/processing_pop2piano.py +25 -33
  1076. transformers/models/pop2piano/tokenization_pop2piano.py +15 -23
  1077. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +1 -0
  1078. transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py +28 -28
  1079. transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything_fast.py +14 -15
  1080. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +9 -10
  1081. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +9 -10
  1082. transformers/models/prophetnet/configuration_prophetnet.py +26 -28
  1083. transformers/models/prophetnet/modeling_prophetnet.py +109 -130
  1084. transformers/models/prophetnet/tokenization_prophetnet.py +14 -16
  1085. transformers/models/pvt/configuration_pvt.py +0 -1
  1086. transformers/models/pvt/image_processing_pvt.py +17 -20
  1087. transformers/models/pvt/image_processing_pvt_fast.py +0 -1
  1088. transformers/models/pvt/modeling_pvt.py +19 -21
  1089. transformers/models/pvt_v2/configuration_pvt_v2.py +2 -4
  1090. transformers/models/pvt_v2/modeling_pvt_v2.py +21 -23
  1091. transformers/models/qwen2/configuration_qwen2.py +18 -21
  1092. transformers/models/qwen2/modeling_qwen2.py +31 -33
  1093. transformers/models/qwen2/modular_qwen2.py +11 -12
  1094. transformers/models/qwen2/tokenization_qwen2.py +2 -5
  1095. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +20 -23
  1096. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +135 -128
  1097. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +116 -109
  1098. transformers/models/qwen2_5_omni/processing_qwen2_5_omni.py +41 -49
  1099. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +22 -25
  1100. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +94 -96
  1101. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +46 -85
  1102. transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py +7 -43
  1103. transformers/models/qwen2_audio/configuration_qwen2_audio.py +0 -1
  1104. transformers/models/qwen2_audio/modeling_qwen2_audio.py +27 -29
  1105. transformers/models/qwen2_audio/processing_qwen2_audio.py +13 -42
  1106. transformers/models/qwen2_moe/configuration_qwen2_moe.py +28 -31
  1107. transformers/models/qwen2_moe/modeling_qwen2_moe.py +36 -39
  1108. transformers/models/qwen2_moe/modular_qwen2_moe.py +7 -10
  1109. transformers/models/qwen2_vl/configuration_qwen2_vl.py +22 -24
  1110. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +38 -40
  1111. transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py +8 -9
  1112. transformers/models/qwen2_vl/modeling_qwen2_vl.py +91 -92
  1113. transformers/models/qwen2_vl/processing_qwen2_vl.py +7 -44
  1114. transformers/models/qwen2_vl/video_processing_qwen2_vl.py +35 -13
  1115. transformers/models/qwen3/configuration_qwen3.py +20 -23
  1116. transformers/models/qwen3/modeling_qwen3.py +31 -34
  1117. transformers/models/qwen3/modular_qwen3.py +4 -6
  1118. transformers/models/qwen3_moe/configuration_qwen3_moe.py +25 -28
  1119. transformers/models/qwen3_moe/modeling_qwen3_moe.py +36 -39
  1120. transformers/models/qwen3_moe/modular_qwen3_moe.py +10 -13
  1121. transformers/models/qwen3_next/configuration_qwen3_next.py +31 -34
  1122. transformers/models/qwen3_next/modeling_qwen3_next.py +39 -42
  1123. transformers/models/qwen3_next/modular_qwen3_next.py +33 -34
  1124. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +85 -88
  1125. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +107 -110
  1126. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +122 -148
  1127. transformers/models/qwen3_omni_moe/processing_qwen3_omni_moe.py +40 -48
  1128. transformers/models/qwen3_vl/configuration_qwen3_vl.py +16 -19
  1129. transformers/models/qwen3_vl/modeling_qwen3_vl.py +74 -77
  1130. transformers/models/qwen3_vl/modular_qwen3_vl.py +68 -105
  1131. transformers/models/qwen3_vl/processing_qwen3_vl.py +6 -42
  1132. transformers/models/qwen3_vl/video_processing_qwen3_vl.py +10 -12
  1133. transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py +21 -25
  1134. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +80 -83
  1135. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +33 -36
  1136. transformers/models/rag/configuration_rag.py +0 -1
  1137. transformers/models/rag/modeling_rag.py +116 -118
  1138. transformers/models/rag/retrieval_rag.py +2 -4
  1139. transformers/models/rag/tokenization_rag.py +0 -50
  1140. transformers/models/recurrent_gemma/configuration_recurrent_gemma.py +21 -24
  1141. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +31 -34
  1142. transformers/models/reformer/configuration_reformer.py +0 -1
  1143. transformers/models/reformer/modeling_reformer.py +67 -68
  1144. transformers/models/reformer/tokenization_reformer.py +3 -6
  1145. transformers/models/regnet/configuration_regnet.py +0 -1
  1146. transformers/models/regnet/modeling_regnet.py +7 -9
  1147. transformers/models/rembert/configuration_rembert.py +0 -1
  1148. transformers/models/rembert/modeling_rembert.py +108 -110
  1149. transformers/models/rembert/tokenization_rembert.py +1 -4
  1150. transformers/models/resnet/configuration_resnet.py +0 -1
  1151. transformers/models/resnet/modeling_resnet.py +8 -10
  1152. transformers/models/roberta/configuration_roberta.py +0 -1
  1153. transformers/models/roberta/modeling_roberta.py +91 -93
  1154. transformers/models/roberta/modular_roberta.py +55 -58
  1155. transformers/models/roberta/tokenization_roberta.py +2 -5
  1156. transformers/models/roberta/tokenization_roberta_old.py +2 -4
  1157. transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py +0 -1
  1158. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +91 -93
  1159. transformers/models/roc_bert/configuration_roc_bert.py +0 -1
  1160. transformers/models/roc_bert/modeling_roc_bert.py +119 -121
  1161. transformers/models/roc_bert/tokenization_roc_bert.py +88 -94
  1162. transformers/models/roformer/configuration_roformer.py +0 -1
  1163. transformers/models/roformer/modeling_roformer.py +79 -81
  1164. transformers/models/roformer/tokenization_roformer.py +3 -6
  1165. transformers/models/roformer/tokenization_utils.py +0 -1
  1166. transformers/models/rt_detr/configuration_rt_detr.py +0 -1
  1167. transformers/models/rt_detr/configuration_rt_detr_resnet.py +0 -1
  1168. transformers/models/rt_detr/image_processing_rt_detr.py +54 -55
  1169. transformers/models/rt_detr/image_processing_rt_detr_fast.py +15 -15
  1170. transformers/models/rt_detr/modeling_rt_detr.py +80 -82
  1171. transformers/models/rt_detr/modeling_rt_detr_resnet.py +2 -4
  1172. transformers/models/rt_detr/modular_rt_detr.py +14 -14
  1173. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +0 -1
  1174. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +79 -81
  1175. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +2 -4
  1176. transformers/models/rwkv/configuration_rwkv.py +0 -1
  1177. transformers/models/rwkv/modeling_rwkv.py +29 -31
  1178. transformers/models/sam/configuration_sam.py +0 -1
  1179. transformers/models/sam/image_processing_sam.py +59 -60
  1180. transformers/models/sam/image_processing_sam_fast.py +21 -22
  1181. transformers/models/sam/modeling_sam.py +33 -35
  1182. transformers/models/sam/processing_sam.py +39 -27
  1183. transformers/models/sam2/configuration_sam2.py +0 -1
  1184. transformers/models/sam2/image_processing_sam2_fast.py +14 -15
  1185. transformers/models/sam2/modeling_sam2.py +45 -47
  1186. transformers/models/sam2/modular_sam2.py +43 -44
  1187. transformers/models/sam2/processing_sam2.py +31 -47
  1188. transformers/models/sam2_video/configuration_sam2_video.py +0 -1
  1189. transformers/models/sam2_video/modeling_sam2_video.py +69 -70
  1190. transformers/models/sam2_video/modular_sam2_video.py +60 -79
  1191. transformers/models/sam2_video/processing_sam2_video.py +49 -66
  1192. transformers/models/sam2_video/video_processing_sam2_video.py +1 -4
  1193. transformers/models/sam3/configuration_sam3.py +0 -1
  1194. transformers/models/sam3/image_processing_sam3_fast.py +17 -20
  1195. transformers/models/sam3/modeling_sam3.py +54 -56
  1196. transformers/models/sam3/modular_sam3.py +3 -8
  1197. transformers/models/sam3/processing_sam3.py +29 -48
  1198. transformers/models/sam3_tracker/__init__.py +0 -1
  1199. transformers/models/sam3_tracker/configuration_sam3_tracker.py +0 -1
  1200. transformers/models/sam3_tracker/modeling_sam3_tracker.py +34 -36
  1201. transformers/models/sam3_tracker/modular_sam3_tracker.py +0 -1
  1202. transformers/models/sam3_tracker/processing_sam3_tracker.py +31 -47
  1203. transformers/models/sam3_tracker_video/__init__.py +0 -1
  1204. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +0 -1
  1205. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +70 -70
  1206. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +2 -4
  1207. transformers/models/sam3_tracker_video/processing_sam3_tracker_video.py +50 -66
  1208. transformers/models/sam3_video/configuration_sam3_video.py +0 -1
  1209. transformers/models/sam3_video/modeling_sam3_video.py +29 -31
  1210. transformers/models/sam3_video/processing_sam3_video.py +25 -45
  1211. transformers/models/sam_hq/__init__.py +1 -1
  1212. transformers/models/sam_hq/configuration_sam_hq.py +0 -1
  1213. transformers/models/sam_hq/modeling_sam_hq.py +39 -41
  1214. transformers/models/sam_hq/modular_sam_hq.py +17 -19
  1215. transformers/models/sam_hq/{processing_samhq.py → processing_sam_hq.py} +39 -28
  1216. transformers/models/seamless_m4t/configuration_seamless_m4t.py +0 -1
  1217. transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py +8 -11
  1218. transformers/models/seamless_m4t/modeling_seamless_m4t.py +180 -182
  1219. transformers/models/seamless_m4t/processing_seamless_m4t.py +18 -39
  1220. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +15 -20
  1221. transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py +0 -1
  1222. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +193 -195
  1223. transformers/models/seed_oss/configuration_seed_oss.py +23 -25
  1224. transformers/models/seed_oss/modeling_seed_oss.py +30 -32
  1225. transformers/models/seed_oss/modular_seed_oss.py +3 -4
  1226. transformers/models/segformer/configuration_segformer.py +0 -10
  1227. transformers/models/segformer/image_processing_segformer.py +39 -42
  1228. transformers/models/segformer/image_processing_segformer_fast.py +7 -8
  1229. transformers/models/segformer/modeling_segformer.py +24 -26
  1230. transformers/models/segformer/modular_segformer.py +5 -6
  1231. transformers/models/seggpt/configuration_seggpt.py +0 -1
  1232. transformers/models/seggpt/image_processing_seggpt.py +38 -41
  1233. transformers/models/seggpt/modeling_seggpt.py +28 -30
  1234. transformers/models/sew/configuration_sew.py +0 -1
  1235. transformers/models/sew/modeling_sew.py +33 -35
  1236. transformers/models/sew/modular_sew.py +10 -12
  1237. transformers/models/sew_d/configuration_sew_d.py +0 -1
  1238. transformers/models/sew_d/modeling_sew_d.py +28 -30
  1239. transformers/models/shieldgemma2/configuration_shieldgemma2.py +0 -1
  1240. transformers/models/shieldgemma2/modeling_shieldgemma2.py +15 -17
  1241. transformers/models/shieldgemma2/processing_shieldgemma2.py +3 -5
  1242. transformers/models/siglip/configuration_siglip.py +0 -1
  1243. transformers/models/siglip/image_processing_siglip.py +17 -20
  1244. transformers/models/siglip/image_processing_siglip_fast.py +0 -1
  1245. transformers/models/siglip/modeling_siglip.py +38 -39
  1246. transformers/models/siglip/processing_siglip.py +2 -14
  1247. transformers/models/siglip/tokenization_siglip.py +6 -7
  1248. transformers/models/siglip2/configuration_siglip2.py +1 -1
  1249. transformers/models/siglip2/image_processing_siglip2.py +15 -16
  1250. transformers/models/siglip2/image_processing_siglip2_fast.py +4 -5
  1251. transformers/models/siglip2/modeling_siglip2.py +54 -54
  1252. transformers/models/siglip2/modular_siglip2.py +23 -25
  1253. transformers/models/siglip2/processing_siglip2.py +2 -14
  1254. transformers/models/smollm3/configuration_smollm3.py +23 -26
  1255. transformers/models/smollm3/modeling_smollm3.py +31 -34
  1256. transformers/models/smollm3/modular_smollm3.py +27 -29
  1257. transformers/models/smolvlm/configuration_smolvlm.py +1 -1
  1258. transformers/models/smolvlm/image_processing_smolvlm.py +42 -43
  1259. transformers/models/smolvlm/image_processing_smolvlm_fast.py +12 -12
  1260. transformers/models/smolvlm/modeling_smolvlm.py +51 -52
  1261. transformers/models/smolvlm/modular_smolvlm.py +15 -17
  1262. transformers/models/smolvlm/processing_smolvlm.py +15 -76
  1263. transformers/models/smolvlm/video_processing_smolvlm.py +7 -8
  1264. transformers/models/speech_encoder_decoder/configuration_speech_encoder_decoder.py +0 -1
  1265. transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +20 -23
  1266. transformers/models/speech_to_text/configuration_speech_to_text.py +0 -1
  1267. transformers/models/speech_to_text/feature_extraction_speech_to_text.py +10 -13
  1268. transformers/models/speech_to_text/modeling_speech_to_text.py +52 -54
  1269. transformers/models/speech_to_text/processing_speech_to_text.py +4 -30
  1270. transformers/models/speech_to_text/tokenization_speech_to_text.py +5 -6
  1271. transformers/models/speecht5/configuration_speecht5.py +0 -1
  1272. transformers/models/speecht5/feature_extraction_speecht5.py +16 -37
  1273. transformers/models/speecht5/modeling_speecht5.py +172 -174
  1274. transformers/models/speecht5/number_normalizer.py +0 -1
  1275. transformers/models/speecht5/processing_speecht5.py +3 -37
  1276. transformers/models/speecht5/tokenization_speecht5.py +4 -5
  1277. transformers/models/splinter/configuration_splinter.py +0 -1
  1278. transformers/models/splinter/modeling_splinter.py +54 -56
  1279. transformers/models/splinter/tokenization_splinter.py +2 -4
  1280. transformers/models/squeezebert/configuration_squeezebert.py +0 -1
  1281. transformers/models/squeezebert/modeling_squeezebert.py +60 -62
  1282. transformers/models/squeezebert/tokenization_squeezebert.py +0 -1
  1283. transformers/models/stablelm/configuration_stablelm.py +20 -23
  1284. transformers/models/stablelm/modeling_stablelm.py +39 -42
  1285. transformers/models/starcoder2/configuration_starcoder2.py +19 -22
  1286. transformers/models/starcoder2/modeling_starcoder2.py +33 -36
  1287. transformers/models/starcoder2/modular_starcoder2.py +13 -15
  1288. transformers/models/superglue/configuration_superglue.py +3 -3
  1289. transformers/models/superglue/image_processing_superglue.py +15 -15
  1290. transformers/models/superglue/image_processing_superglue_fast.py +4 -5
  1291. transformers/models/superglue/modeling_superglue.py +32 -33
  1292. transformers/models/superpoint/image_processing_superpoint.py +15 -15
  1293. transformers/models/superpoint/image_processing_superpoint_fast.py +4 -5
  1294. transformers/models/superpoint/modeling_superpoint.py +13 -14
  1295. transformers/models/swiftformer/configuration_swiftformer.py +0 -1
  1296. transformers/models/swiftformer/modeling_swiftformer.py +12 -14
  1297. transformers/models/swin/configuration_swin.py +0 -1
  1298. transformers/models/swin/modeling_swin.py +58 -70
  1299. transformers/models/swin2sr/configuration_swin2sr.py +0 -1
  1300. transformers/models/swin2sr/image_processing_swin2sr.py +10 -13
  1301. transformers/models/swin2sr/image_processing_swin2sr_fast.py +2 -5
  1302. transformers/models/swin2sr/modeling_swin2sr.py +26 -28
  1303. transformers/models/swinv2/configuration_swinv2.py +0 -1
  1304. transformers/models/swinv2/modeling_swinv2.py +55 -67
  1305. transformers/models/switch_transformers/configuration_switch_transformers.py +0 -1
  1306. transformers/models/switch_transformers/modeling_switch_transformers.py +32 -33
  1307. transformers/models/switch_transformers/modular_switch_transformers.py +29 -30
  1308. transformers/models/t5/configuration_t5.py +0 -1
  1309. transformers/models/t5/modeling_t5.py +75 -77
  1310. transformers/models/t5/tokenization_t5.py +1 -3
  1311. transformers/models/t5gemma/configuration_t5gemma.py +33 -34
  1312. transformers/models/t5gemma/modeling_t5gemma.py +96 -99
  1313. transformers/models/t5gemma/modular_t5gemma.py +117 -118
  1314. transformers/models/t5gemma2/configuration_t5gemma2.py +53 -54
  1315. transformers/models/t5gemma2/modeling_t5gemma2.py +96 -99
  1316. transformers/models/t5gemma2/modular_t5gemma2.py +134 -135
  1317. transformers/models/table_transformer/configuration_table_transformer.py +0 -1
  1318. transformers/models/table_transformer/modeling_table_transformer.py +46 -48
  1319. transformers/models/tapas/configuration_tapas.py +0 -1
  1320. transformers/models/tapas/modeling_tapas.py +64 -66
  1321. transformers/models/tapas/tokenization_tapas.py +115 -153
  1322. transformers/models/textnet/configuration_textnet.py +0 -1
  1323. transformers/models/textnet/image_processing_textnet.py +22 -25
  1324. transformers/models/textnet/image_processing_textnet_fast.py +5 -6
  1325. transformers/models/textnet/modeling_textnet.py +13 -14
  1326. transformers/models/time_series_transformer/configuration_time_series_transformer.py +5 -8
  1327. transformers/models/time_series_transformer/modeling_time_series_transformer.py +79 -81
  1328. transformers/models/timesfm/configuration_timesfm.py +0 -1
  1329. transformers/models/timesfm/modeling_timesfm.py +17 -19
  1330. transformers/models/timesfm/modular_timesfm.py +16 -18
  1331. transformers/models/timesformer/configuration_timesformer.py +0 -1
  1332. transformers/models/timesformer/modeling_timesformer.py +13 -16
  1333. transformers/models/timm_backbone/configuration_timm_backbone.py +0 -1
  1334. transformers/models/timm_backbone/modeling_timm_backbone.py +4 -6
  1335. transformers/models/timm_wrapper/configuration_timm_wrapper.py +2 -3
  1336. transformers/models/timm_wrapper/image_processing_timm_wrapper.py +4 -5
  1337. transformers/models/timm_wrapper/modeling_timm_wrapper.py +13 -15
  1338. transformers/models/trocr/configuration_trocr.py +0 -1
  1339. transformers/models/trocr/modeling_trocr.py +38 -40
  1340. transformers/models/trocr/processing_trocr.py +5 -25
  1341. transformers/models/tvp/configuration_tvp.py +0 -1
  1342. transformers/models/tvp/image_processing_tvp.py +50 -52
  1343. transformers/models/tvp/image_processing_tvp_fast.py +9 -10
  1344. transformers/models/tvp/modeling_tvp.py +25 -27
  1345. transformers/models/tvp/processing_tvp.py +2 -14
  1346. transformers/models/udop/configuration_udop.py +0 -1
  1347. transformers/models/udop/modeling_udop.py +63 -66
  1348. transformers/models/udop/processing_udop.py +7 -26
  1349. transformers/models/udop/tokenization_udop.py +80 -93
  1350. transformers/models/umt5/configuration_umt5.py +0 -1
  1351. transformers/models/umt5/modeling_umt5.py +80 -81
  1352. transformers/models/unispeech/configuration_unispeech.py +0 -1
  1353. transformers/models/unispeech/modeling_unispeech.py +47 -49
  1354. transformers/models/unispeech/modular_unispeech.py +20 -22
  1355. transformers/models/unispeech_sat/configuration_unispeech_sat.py +0 -1
  1356. transformers/models/unispeech_sat/modeling_unispeech_sat.py +63 -65
  1357. transformers/models/unispeech_sat/modular_unispeech_sat.py +21 -23
  1358. transformers/models/univnet/feature_extraction_univnet.py +14 -14
  1359. transformers/models/univnet/modeling_univnet.py +7 -8
  1360. transformers/models/upernet/configuration_upernet.py +0 -1
  1361. transformers/models/upernet/modeling_upernet.py +10 -13
  1362. transformers/models/vaultgemma/__init__.py +0 -1
  1363. transformers/models/vaultgemma/configuration_vaultgemma.py +24 -26
  1364. transformers/models/vaultgemma/modeling_vaultgemma.py +34 -36
  1365. transformers/models/vaultgemma/modular_vaultgemma.py +29 -31
  1366. transformers/models/video_llama_3/image_processing_video_llama_3.py +40 -40
  1367. transformers/models/video_llama_3/image_processing_video_llama_3_fast.py +8 -8
  1368. transformers/models/video_llama_3/modeling_video_llama_3.py +66 -66
  1369. transformers/models/video_llama_3/modular_video_llama_3.py +101 -112
  1370. transformers/models/video_llama_3/processing_video_llama_3.py +5 -39
  1371. transformers/models/video_llama_3/video_processing_video_llama_3.py +18 -18
  1372. transformers/models/video_llava/configuration_video_llava.py +0 -1
  1373. transformers/models/video_llava/image_processing_video_llava.py +35 -38
  1374. transformers/models/video_llava/modeling_video_llava.py +52 -54
  1375. transformers/models/video_llava/processing_video_llava.py +38 -78
  1376. transformers/models/video_llava/video_processing_video_llava.py +0 -1
  1377. transformers/models/videomae/configuration_videomae.py +0 -1
  1378. transformers/models/videomae/image_processing_videomae.py +31 -34
  1379. transformers/models/videomae/modeling_videomae.py +13 -15
  1380. transformers/models/videomae/video_processing_videomae.py +0 -1
  1381. transformers/models/vilt/configuration_vilt.py +0 -1
  1382. transformers/models/vilt/image_processing_vilt.py +29 -30
  1383. transformers/models/vilt/image_processing_vilt_fast.py +9 -10
  1384. transformers/models/vilt/modeling_vilt.py +76 -78
  1385. transformers/models/vilt/processing_vilt.py +2 -14
  1386. transformers/models/vipllava/configuration_vipllava.py +0 -1
  1387. transformers/models/vipllava/modeling_vipllava.py +38 -39
  1388. transformers/models/vipllava/modular_vipllava.py +30 -32
  1389. transformers/models/vision_encoder_decoder/configuration_vision_encoder_decoder.py +0 -1
  1390. transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +18 -21
  1391. transformers/models/vision_text_dual_encoder/configuration_vision_text_dual_encoder.py +0 -1
  1392. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +18 -21
  1393. transformers/models/vision_text_dual_encoder/processing_vision_text_dual_encoder.py +2 -16
  1394. transformers/models/visual_bert/configuration_visual_bert.py +0 -1
  1395. transformers/models/visual_bert/modeling_visual_bert.py +90 -92
  1396. transformers/models/vit/configuration_vit.py +0 -1
  1397. transformers/models/vit/image_processing_vit.py +19 -22
  1398. transformers/models/vit/image_processing_vit_fast.py +0 -1
  1399. transformers/models/vit/modeling_vit.py +13 -15
  1400. transformers/models/vit_mae/configuration_vit_mae.py +0 -1
  1401. transformers/models/vit_mae/modeling_vit_mae.py +21 -23
  1402. transformers/models/vit_msn/configuration_vit_msn.py +0 -1
  1403. transformers/models/vit_msn/modeling_vit_msn.py +10 -12
  1404. transformers/models/vitdet/configuration_vitdet.py +0 -1
  1405. transformers/models/vitdet/modeling_vitdet.py +12 -14
  1406. transformers/models/vitmatte/configuration_vitmatte.py +1 -4
  1407. transformers/models/vitmatte/image_processing_vitmatte.py +15 -18
  1408. transformers/models/vitmatte/image_processing_vitmatte_fast.py +14 -15
  1409. transformers/models/vitmatte/modeling_vitmatte.py +9 -11
  1410. transformers/models/vitpose/configuration_vitpose.py +3 -6
  1411. transformers/models/vitpose/image_processing_vitpose.py +24 -25
  1412. transformers/models/vitpose/image_processing_vitpose_fast.py +9 -10
  1413. transformers/models/vitpose/modeling_vitpose.py +10 -12
  1414. transformers/models/vitpose_backbone/configuration_vitpose_backbone.py +0 -1
  1415. transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +8 -10
  1416. transformers/models/vits/configuration_vits.py +0 -1
  1417. transformers/models/vits/modeling_vits.py +34 -35
  1418. transformers/models/vits/tokenization_vits.py +3 -4
  1419. transformers/models/vivit/configuration_vivit.py +0 -1
  1420. transformers/models/vivit/image_processing_vivit.py +36 -39
  1421. transformers/models/vivit/modeling_vivit.py +5 -7
  1422. transformers/models/vjepa2/__init__.py +0 -1
  1423. transformers/models/vjepa2/configuration_vjepa2.py +0 -1
  1424. transformers/models/vjepa2/modeling_vjepa2.py +30 -32
  1425. transformers/models/vjepa2/video_processing_vjepa2.py +0 -1
  1426. transformers/models/voxtral/__init__.py +0 -1
  1427. transformers/models/voxtral/configuration_voxtral.py +0 -1
  1428. transformers/models/voxtral/modeling_voxtral.py +17 -25
  1429. transformers/models/voxtral/modular_voxtral.py +10 -19
  1430. transformers/models/voxtral/processing_voxtral.py +25 -48
  1431. transformers/models/wav2vec2/configuration_wav2vec2.py +0 -1
  1432. transformers/models/wav2vec2/feature_extraction_wav2vec2.py +7 -10
  1433. transformers/models/wav2vec2/modeling_wav2vec2.py +67 -122
  1434. transformers/models/wav2vec2/processing_wav2vec2.py +6 -35
  1435. transformers/models/wav2vec2/tokenization_wav2vec2.py +20 -332
  1436. transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py +0 -1
  1437. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +49 -52
  1438. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +45 -48
  1439. transformers/models/wav2vec2_bert/processing_wav2vec2_bert.py +6 -35
  1440. transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py +0 -1
  1441. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +62 -65
  1442. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +15 -18
  1443. transformers/models/wav2vec2_phoneme/tokenization_wav2vec2_phoneme.py +16 -17
  1444. transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py +36 -55
  1445. transformers/models/wavlm/configuration_wavlm.py +0 -1
  1446. transformers/models/wavlm/modeling_wavlm.py +45 -48
  1447. transformers/models/wavlm/modular_wavlm.py +4 -5
  1448. transformers/models/whisper/configuration_whisper.py +0 -1
  1449. transformers/models/whisper/english_normalizer.py +3 -4
  1450. transformers/models/whisper/feature_extraction_whisper.py +9 -24
  1451. transformers/models/whisper/generation_whisper.py +26 -48
  1452. transformers/models/whisper/modeling_whisper.py +68 -70
  1453. transformers/models/whisper/processing_whisper.py +3 -20
  1454. transformers/models/whisper/tokenization_whisper.py +9 -30
  1455. transformers/models/x_clip/configuration_x_clip.py +0 -1
  1456. transformers/models/x_clip/modeling_x_clip.py +68 -69
  1457. transformers/models/x_clip/processing_x_clip.py +2 -14
  1458. transformers/models/xcodec/configuration_xcodec.py +4 -6
  1459. transformers/models/xcodec/modeling_xcodec.py +15 -17
  1460. transformers/models/xglm/configuration_xglm.py +0 -1
  1461. transformers/models/xglm/modeling_xglm.py +49 -55
  1462. transformers/models/xglm/tokenization_xglm.py +1 -4
  1463. transformers/models/xlm/configuration_xlm.py +0 -1
  1464. transformers/models/xlm/modeling_xlm.py +126 -130
  1465. transformers/models/xlm/tokenization_xlm.py +3 -5
  1466. transformers/models/xlm_roberta/configuration_xlm_roberta.py +0 -1
  1467. transformers/models/xlm_roberta/modeling_xlm_roberta.py +90 -92
  1468. transformers/models/xlm_roberta/modular_xlm_roberta.py +50 -53
  1469. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +1 -4
  1470. transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py +0 -1
  1471. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +91 -93
  1472. transformers/models/xlm_roberta_xl/modular_xlm_roberta_xl.py +67 -70
  1473. transformers/models/xlnet/configuration_xlnet.py +0 -11
  1474. transformers/models/xlnet/modeling_xlnet.py +149 -162
  1475. transformers/models/xlnet/tokenization_xlnet.py +1 -4
  1476. transformers/models/xlstm/configuration_xlstm.py +3 -5
  1477. transformers/models/xlstm/modeling_xlstm.py +62 -65
  1478. transformers/models/xmod/configuration_xmod.py +0 -1
  1479. transformers/models/xmod/modeling_xmod.py +98 -100
  1480. transformers/models/yolos/configuration_yolos.py +0 -1
  1481. transformers/models/yolos/image_processing_yolos.py +60 -62
  1482. transformers/models/yolos/image_processing_yolos_fast.py +18 -18
  1483. transformers/models/yolos/modeling_yolos.py +12 -14
  1484. transformers/models/yolos/modular_yolos.py +2 -4
  1485. transformers/models/yoso/configuration_yoso.py +0 -1
  1486. transformers/models/yoso/modeling_yoso.py +60 -62
  1487. transformers/models/zamba/configuration_zamba.py +0 -1
  1488. transformers/models/zamba/modeling_zamba.py +68 -69
  1489. transformers/models/zamba2/configuration_zamba2.py +36 -37
  1490. transformers/models/zamba2/modeling_zamba2.py +84 -87
  1491. transformers/models/zamba2/modular_zamba2.py +43 -45
  1492. transformers/models/zoedepth/configuration_zoedepth.py +0 -1
  1493. transformers/models/zoedepth/image_processing_zoedepth.py +28 -29
  1494. transformers/models/zoedepth/image_processing_zoedepth_fast.py +11 -12
  1495. transformers/models/zoedepth/modeling_zoedepth.py +14 -16
  1496. transformers/pipelines/__init__.py +50 -49
  1497. transformers/pipelines/any_to_any.py +14 -22
  1498. transformers/pipelines/audio_utils.py +1 -2
  1499. transformers/pipelines/base.py +12 -16
  1500. transformers/pipelines/deprecated/__init__.py +0 -1
  1501. transformers/pipelines/image_text_to_text.py +0 -1
  1502. transformers/pipelines/image_to_text.py +4 -44
  1503. transformers/pipelines/question_answering.py +4 -43
  1504. transformers/pipelines/text_classification.py +1 -14
  1505. transformers/pipelines/token_classification.py +1 -22
  1506. transformers/pipelines/video_classification.py +1 -9
  1507. transformers/pipelines/zero_shot_audio_classification.py +0 -1
  1508. transformers/pipelines/zero_shot_classification.py +0 -6
  1509. transformers/pipelines/zero_shot_image_classification.py +0 -7
  1510. transformers/processing_utils.py +95 -95
  1511. transformers/quantizers/base.py +10 -0
  1512. transformers/quantizers/quantizer_quark.py +0 -1
  1513. transformers/quantizers/quantizer_torchao.py +3 -3
  1514. transformers/testing_utils.py +3 -37
  1515. transformers/tokenization_mistral_common.py +554 -903
  1516. transformers/tokenization_utils_base.py +109 -122
  1517. transformers/tokenization_utils_sentencepiece.py +5 -6
  1518. transformers/tokenization_utils_tokenizers.py +5 -5
  1519. transformers/trainer.py +6 -9
  1520. transformers/trainer_jit_checkpoint.py +1 -2
  1521. transformers/training_args.py +3 -3
  1522. transformers/utils/attention_visualizer.py +1 -1
  1523. transformers/utils/auto_docstring.py +564 -12
  1524. transformers/utils/doc.py +1 -1
  1525. transformers/utils/dummy_pt_objects.py +0 -42
  1526. transformers/utils/generic.py +1 -1
  1527. transformers/utils/loading_report.py +3 -3
  1528. transformers/utils/quantization_config.py +8 -10
  1529. transformers/video_processing_utils.py +19 -20
  1530. transformers/video_utils.py +18 -22
  1531. {transformers-5.0.0rc2.dist-info → transformers-5.0.0rc3.dist-info}/METADATA +19 -19
  1532. transformers-5.0.0rc3.dist-info/RECORD +2067 -0
  1533. transformers-5.0.0rc2.dist-info/RECORD +0 -2042
  1534. {transformers-5.0.0rc2.dist-info → transformers-5.0.0rc3.dist-info}/WHEEL +0 -0
  1535. {transformers-5.0.0rc2.dist-info → transformers-5.0.0rc3.dist-info}/entry_points.txt +0 -0
  1536. {transformers-5.0.0rc2.dist-info → transformers-5.0.0rc3.dist-info}/licenses/LICENSE +0 -0
  1537. {transformers-5.0.0rc2.dist-info → transformers-5.0.0rc3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1590 @@
1
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
2
+ # This file was automatically generated from src/transformers/models/glm_image/modular_glm_image.py.
3
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
4
+ # the file from the modular. If any change should be done, please apply the change to the
5
+ # modular_glm_image.py file directly. One of our CI enforces this.
6
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ # Copyright 2025 the HuggingFace Team. All rights reserved.
8
+ #
9
+ # Licensed under the Apache License, Version 2.0 (the "License");
10
+ # you may not use this file except in compliance with the License.
11
+ # You may obtain a copy of the License at
12
+ #
13
+ # http://www.apache.org/licenses/LICENSE-2.0
14
+ #
15
+ # Unless required by applicable law or agreed to in writing, software
16
+ # distributed under the License is distributed on an "AS IS" BASIS,
17
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
18
+ # See the License for the specific language governing permissions and
19
+ # limitations under the License.
20
+
21
+ from collections.abc import Callable
22
+ from dataclasses import dataclass
23
+ from typing import Any, Optional
24
+
25
+ import torch.nn as nn
26
+ import torch.nn.functional as F
27
+
28
+ from ...activations import ACT2FN
29
+ from ...cache_utils import Cache, DynamicCache
30
+ from ...generation import GenerationMixin
31
+ from ...integrations import use_kernel_forward_from_hub, use_kernelized_func
32
+ from ...masking_utils import create_causal_mask
33
+ from ...modeling_flash_attention_utils import FlashAttentionKwargs
34
+ from ...modeling_layers import GradientCheckpointingLayer
35
+ from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
36
+ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
37
+ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
38
+ from ...processing_utils import Unpack
39
+ from ...utils import TransformersKwargs, auto_docstring, can_return_tuple, is_torch_available
40
+ from ...utils.generic import check_model_inputs, maybe_autocast
41
+ from .configuration_glm_image import GlmImageConfig, GlmImageTextConfig, GlmImageVisionConfig, GlmImageVQVAEConfig
42
+
43
+
44
+ if is_torch_available():
45
+ import torch
46
+
47
+
48
+ class GlmImageVisionMLP(nn.Module):
49
+ def __init__(self, config):
50
+ super().__init__()
51
+ self.config = config
52
+ self.activation_fn = ACT2FN[config.hidden_act]
53
+ self.fc1 = nn.Linear(config.hidden_size, config.intermediate_size)
54
+ self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
55
+
56
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
57
+ hidden_states = self.fc1(hidden_states)
58
+ hidden_states = self.activation_fn(hidden_states)
59
+ hidden_states = self.fc2(hidden_states)
60
+ return hidden_states
61
+
62
+
63
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
64
+ """
65
+ This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
66
+ num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
67
+ """
68
+ batch, num_key_value_heads, slen, head_dim = hidden_states.shape
69
+ if n_rep == 1:
70
+ return hidden_states
71
+ hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
72
+ return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
73
+
74
+
75
+ def eager_attention_forward(
76
+ module: nn.Module,
77
+ query: torch.Tensor,
78
+ key: torch.Tensor,
79
+ value: torch.Tensor,
80
+ attention_mask: torch.Tensor | None,
81
+ scaling: float,
82
+ dropout: float = 0.0,
83
+ **kwargs: Unpack[TransformersKwargs],
84
+ ):
85
+ key_states = repeat_kv(key, module.num_key_value_groups)
86
+ value_states = repeat_kv(value, module.num_key_value_groups)
87
+
88
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
89
+ if attention_mask is not None:
90
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
91
+ attn_weights = attn_weights + causal_mask
92
+
93
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
94
+ attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
95
+ attn_output = torch.matmul(attn_weights, value_states)
96
+ attn_output = attn_output.transpose(1, 2).contiguous()
97
+
98
+ return attn_output, attn_weights
99
+
100
+
101
+ class GlmImageVisionAttention(nn.Module):
102
+ def __init__(self, config: GlmImageVisionConfig) -> None:
103
+ super().__init__()
104
+ self.dim = config.hidden_size
105
+ self.num_heads = config.num_heads
106
+ self.head_dim = self.dim // self.num_heads
107
+ self.num_key_value_groups = 1 # needed for eager attention
108
+ self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.attention_bias)
109
+ self.proj = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
110
+ self.scaling = self.head_dim**-0.5
111
+ self.config = config
112
+ self.attention_dropout = config.attention_dropout
113
+ self.is_causal = False
114
+
115
+ def forward(
116
+ self,
117
+ hidden_states: torch.Tensor,
118
+ cu_seqlens: torch.Tensor,
119
+ **kwargs,
120
+ ) -> torch.Tensor:
121
+ seq_length = hidden_states.shape[0]
122
+ query_states, key_states, value_states = (
123
+ self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
124
+ )
125
+ query_states = query_states.transpose(0, 1).unsqueeze(0)
126
+ key_states = key_states.transpose(0, 1).unsqueeze(0)
127
+ value_states = value_states.transpose(0, 1).unsqueeze(0)
128
+
129
+ attention_interface: Callable = eager_attention_forward
130
+ if self.config._attn_implementation != "eager":
131
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
132
+
133
+ if "flash" in self.config._attn_implementation:
134
+ # Flash Attention: Use cu_seqlens for variable length attention
135
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
136
+ attn_output, _ = attention_interface(
137
+ self,
138
+ query_states,
139
+ key_states,
140
+ value_states,
141
+ attention_mask=None,
142
+ scaling=self.scaling,
143
+ dropout=0.0 if not self.training else self.attention_dropout,
144
+ cu_seq_lens_q=cu_seqlens,
145
+ cu_seq_lens_k=cu_seqlens,
146
+ max_length_q=max_seqlen,
147
+ max_length_k=max_seqlen,
148
+ is_causal=False,
149
+ **kwargs,
150
+ )
151
+ else:
152
+ # Other implementations: Process each chunk separately
153
+ lengths = cu_seqlens[1:] - cu_seqlens[:-1]
154
+ splits = [
155
+ torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
156
+ ]
157
+
158
+ attn_outputs = [
159
+ attention_interface(
160
+ self,
161
+ q,
162
+ k,
163
+ v,
164
+ attention_mask=None,
165
+ scaling=self.scaling,
166
+ dropout=0.0 if not self.training else self.attention_dropout,
167
+ is_causal=False,
168
+ **kwargs,
169
+ )[0]
170
+ for q, k, v in zip(*splits)
171
+ ]
172
+ attn_output = torch.cat(attn_outputs, dim=1)
173
+
174
+ attn_output = attn_output.reshape(seq_length, -1).contiguous()
175
+ attn_output = self.proj(attn_output)
176
+ return attn_output
177
+
178
+
179
+ class GlmImageVisionPatchEmbed(nn.Module):
180
+ def __init__(self, config: GlmImageVisionConfig) -> None:
181
+ super().__init__()
182
+ self.patch_size = config.patch_size
183
+ self.in_channels = config.in_channels
184
+ self.embed_dim = config.hidden_size
185
+ kernel_size = [self.patch_size, self.patch_size]
186
+ self.proj = nn.Conv2d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size)
187
+
188
+ def forward(self, hidden_states) -> torch.Tensor:
189
+ target_dtype = self.proj.weight.dtype
190
+ hidden_states = hidden_states.view(-1, self.in_channels, self.patch_size, self.patch_size)
191
+ hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
192
+ return hidden_states
193
+
194
+
195
+ class GlmImageVisionEmbeddings(nn.Module):
196
+ def __init__(self, config: GlmImageVisionConfig) -> None:
197
+ super().__init__()
198
+ self.config = config
199
+ self.embed_dim = config.hidden_size
200
+ self.image_size = config.image_size
201
+ self.patch_size = config.patch_size
202
+
203
+ self.num_patches = (self.image_size // self.patch_size) ** 2
204
+ self.num_positions = self.num_patches
205
+ self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
206
+ self.interpolated_method = "bilinear"
207
+
208
+ def forward(self, embeddings, lengths, image_shapes, h_coords, w_coords) -> torch.Tensor:
209
+ """
210
+ Forward pass with integrated position encoding adaptation using 2D interpolation.
211
+
212
+ Args:
213
+ embeddings: Input embeddings tensor
214
+ lengths (torch.Tensor): Sequence lengths for each image in the batch.
215
+ image_shapes (torch.Tensor): Tensor of shape [batch_size, 3] representing the image shapes (t, h, w).
216
+ h_coords (torch.Tensor): Tensor of shape [total_seq] representing the h coordinate for each patch.
217
+ w_coords (torch.Tensor): Tensor of shape [total_seq] representing the w coordinate for each patch.
218
+
219
+ Returns:
220
+ torch.Tensor: Embeddings with adapted position encoding added.
221
+ """
222
+ # Get position embedding parameters
223
+ pos_embed_weight = self.position_embedding.weight
224
+ hidden_size = pos_embed_weight.shape[1]
225
+ device = pos_embed_weight.device
226
+
227
+ # Convert inputs to tensors if needed
228
+ if isinstance(lengths, list):
229
+ lengths = torch.tensor(lengths, device=device, dtype=torch.long)
230
+
231
+ # Prepare 2D position embedding
232
+ orig_size_sq = pos_embed_weight.shape[0]
233
+ orig_size = int(orig_size_sq**0.5)
234
+ pos_embed_2d = (
235
+ pos_embed_weight.view(orig_size, orig_size, hidden_size)
236
+ .permute(2, 0, 1)
237
+ .unsqueeze(0)
238
+ .to(device=device, dtype=torch.float32)
239
+ )
240
+
241
+ # Calculate target dimensions for each patch
242
+ target_h = torch.cat([image_shapes[i, 1].repeat(lengths[i]) for i in range(len(lengths))]).to(
243
+ device=device, dtype=torch.float32
244
+ )
245
+ target_w = torch.cat([image_shapes[i, 2].repeat(lengths[i]) for i in range(len(lengths))]).to(
246
+ device=device, dtype=torch.float32
247
+ )
248
+
249
+ # Normalize coordinates to [-1, 1] range for grid_sample
250
+ norm_w = ((w_coords + 0.5) / target_w) * 2 - 1
251
+ norm_h = ((h_coords + 0.5) / target_h) * 2 - 1
252
+
253
+ # Create sampling grid
254
+ grid = torch.stack((norm_w, norm_h), dim=-1).unsqueeze(0).unsqueeze(2)
255
+
256
+ # Perform 2D interpolation using self.interpolated_method (bilinear by default)
257
+ interpolated_embed_fp32 = F.grid_sample(
258
+ pos_embed_2d, grid, mode=self.interpolated_method, align_corners=False, padding_mode="border"
259
+ )
260
+
261
+ # Reshape and convert back to original dtype
262
+ adapted_pos_embed_fp32 = interpolated_embed_fp32.squeeze(0).squeeze(-1).permute(1, 0)
263
+ adapted_pos_embed = adapted_pos_embed_fp32.to(pos_embed_weight.dtype).to(embeddings.device)
264
+
265
+ # Add adapted position encoding to embeddings
266
+ embeddings = embeddings + adapted_pos_embed
267
+ return embeddings
268
+
269
+
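To illustrate what `GlmImageVisionEmbeddings.forward` does with `F.grid_sample`, the snippet below resamples a toy learned position grid at per-patch coordinates; all sizes and values are invented for the example:

```python
import torch
import torch.nn.functional as F

hidden_size, orig_size = 16, 4                       # toy 4x4 learned position grid
pos_embed = torch.randn(orig_size * orig_size, hidden_size)
pos_2d = pos_embed.view(orig_size, orig_size, hidden_size).permute(2, 0, 1).unsqueeze(0)

# two patches of a 2x3 target grid: (h, w) coordinates (0, 0) and (1, 2)
h_coords = torch.tensor([0.0, 1.0])
w_coords = torch.tensor([0.0, 2.0])
target_h, target_w = 2.0, 3.0

norm_w = ((w_coords + 0.5) / target_w) * 2 - 1       # map to [-1, 1] for grid_sample
norm_h = ((h_coords + 0.5) / target_h) * 2 - 1
grid = torch.stack((norm_w, norm_h), dim=-1).unsqueeze(0).unsqueeze(2)  # (1, n, 1, 2)

sampled = F.grid_sample(pos_2d, grid, mode="bilinear", align_corners=False, padding_mode="border")
per_patch = sampled.squeeze(0).squeeze(-1).permute(1, 0)
print(per_patch.shape)                               # (2, 16): one position vector per patch
```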
270
+ class GlmImageVisionBlock(GradientCheckpointingLayer):
271
+ def __init__(self, config: GlmImageVisionConfig) -> None:
272
+ super().__init__()
273
+ self.norm1 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
274
+ self.norm2 = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
275
+ self.attn = GlmImageVisionAttention(config)
276
+ self.mlp = GlmImageVisionMLP(config)
277
+
278
+ def forward(
279
+ self,
280
+ hidden_states: torch.Tensor,
281
+ cu_seqlens: torch.Tensor,
282
+ **kwargs: Unpack[TransformersKwargs],
283
+ ) -> torch.Tensor:
284
+ r"""
285
+ cu_seqlens (`torch.Tensor` of shape `(num_images_or_videos + 1,)`):
286
+ The cumulative sequence lengths of each image or video feature.
287
+ """
290
+ residual = hidden_states
291
+
292
+ hidden_states = self.norm1(hidden_states)
293
+ hidden_states = self.attn(
294
+ hidden_states,
295
+ cu_seqlens=cu_seqlens,
296
+ **kwargs,
297
+ )
298
+ hidden_states = residual + hidden_states
299
+
300
+ residual = hidden_states
301
+ hidden_states = self.norm2(hidden_states)
302
+ hidden_states = self.mlp(hidden_states)
303
+ hidden_states = residual + hidden_states
304
+
305
+ return hidden_states
306
+
307
+
308
+ def rotate_half(x):
309
+ """Rotates half the hidden dims of the input."""
310
+ x1 = x[..., : x.shape[-1] // 2]
311
+ x2 = x[..., x.shape[-1] // 2 :]
312
+ return torch.cat((-x2, x1), dim=-1)
313
+
314
+
315
+ def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
316
+ """Applies Rotary Position Embedding to the query and key tensors.
317
+
318
+ Args:
319
+ q (`torch.Tensor`): The query tensor.
320
+ k (`torch.Tensor`): The key tensor.
321
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
322
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
323
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
324
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
325
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
326
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
327
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
328
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
329
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
330
+ Returns:
331
+ `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
332
+ """
333
+ cos = cos.unsqueeze(unsqueeze_dim)
334
+ sin = sin.unsqueeze(unsqueeze_dim)
335
+
336
+ # Keep half or full tensor for later concatenation
337
+ rotary_dim = cos.shape[-1]
338
+ q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
339
+ k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
340
+
341
+ # Apply rotary embeddings on the first half or full tensor
342
+ q_embed = (q_rot * cos) + (rotate_half(q_rot) * sin)
343
+ k_embed = (k_rot * cos) + (rotate_half(k_rot) * sin)
344
+
345
+ # Concatenate back to full shape
346
+ q_embed = torch.cat([q_embed, q_pass], dim=-1)
347
+ k_embed = torch.cat([k_embed, k_pass], dim=-1)
348
+ return q_embed, k_embed
349
+
350
+
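A small usage check of the partial-rotary behaviour of `apply_rotary_pos_emb` above (assuming the definitions above are in scope): only the first `rotary_dim` channels are rotated, the remainder passes through unchanged.

```python
import torch

# assumes rotate_half / apply_rotary_pos_emb defined above are in scope
batch, heads, seq, head_dim, rotary_dim = 1, 2, 3, 8, 4
q = torch.randn(batch, heads, seq, head_dim)
k = torch.randn(batch, heads, seq, head_dim)
# cos/sin cover only the first rotary_dim channels; unsqueeze_dim=1 broadcasts over heads
cos = torch.randn(batch, seq, rotary_dim)
sin = torch.randn(batch, seq, rotary_dim)

q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin)
# channels beyond rotary_dim pass through untouched
assert torch.equal(q_embed[..., rotary_dim:], q[..., rotary_dim:])
```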
351
+ @use_kernelized_func(apply_rotary_pos_emb)
352
+ class GlmImageTextAttention(nn.Module):
353
+ """Multi-headed attention from 'Attention Is All You Need' paper"""
354
+
355
+ def __init__(self, config: GlmImageTextConfig, layer_idx: int | None = None):
356
+ super().__init__()
357
+ self.config = config
358
+ self.layer_idx = layer_idx
359
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
360
+ self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
361
+ self.scaling = self.head_dim**-0.5
362
+ self.attention_dropout = config.attention_dropout
363
+ self.is_causal = True
364
+
365
+ self.q_proj = nn.Linear(
366
+ config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
367
+ )
368
+ self.k_proj = nn.Linear(
369
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
370
+ )
371
+ self.v_proj = nn.Linear(
372
+ config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
373
+ )
374
+ self.o_proj = nn.Linear(config.num_attention_heads * self.head_dim, config.hidden_size, bias=False)
375
+ self.rope_parameters = config.rope_parameters
376
+
377
+ def forward(
378
+ self,
379
+ hidden_states: torch.Tensor,
380
+ position_embeddings: tuple[torch.Tensor, torch.Tensor],
381
+ attention_mask: torch.Tensor | None,
382
+ past_key_values: Cache | None = None,
383
+ cache_position: torch.LongTensor | None = None,
384
+ **kwargs: Unpack[FlashAttentionKwargs],
385
+ ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
386
+ input_shape = hidden_states.shape[:-1]
387
+ hidden_shape = (*input_shape, -1, self.head_dim)
388
+
389
+ query_states = self.q_proj(hidden_states).view(hidden_shape)
390
+ key_states = self.k_proj(hidden_states).view(hidden_shape)
391
+ value_states = self.v_proj(hidden_states).view(hidden_shape)
392
+
393
+ query_states = query_states.transpose(1, 2)
394
+ key_states = key_states.transpose(1, 2)
395
+ value_states = value_states.transpose(1, 2)
396
+
397
+ cos, sin = position_embeddings
398
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
399
+
400
+ if past_key_values is not None:
401
+ # sin and cos are specific to RoPE models; position_ids needed for the static cache
402
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
403
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
404
+
405
+ attention_interface: Callable = eager_attention_forward
406
+ if self.config._attn_implementation != "eager":
407
+ attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
408
+
409
+ attn_output, attn_weights = attention_interface(
410
+ self,
411
+ query_states,
412
+ key_states,
413
+ value_states,
414
+ attention_mask,
415
+ dropout=0.0 if not self.training else self.attention_dropout,
416
+ scaling=self.scaling,
417
+ **kwargs,
418
+ )
419
+
420
+ attn_output = attn_output.reshape(*input_shape, -1).contiguous()
421
+ attn_output = self.o_proj(attn_output)
422
+ return attn_output, attn_weights
423
+
424
+
425
+ @use_kernel_forward_from_hub("RMSNorm")
426
+ class GlmImageRMSNorm(nn.Module):
427
+ def __init__(self, hidden_size, eps=1e-6):
428
+ """
429
+ GlmImageRMSNorm is equivalent to T5LayerNorm
430
+ """
431
+ super().__init__()
432
+ self.weight = nn.Parameter(torch.ones(hidden_size))
433
+ self.variance_epsilon = eps
434
+
435
+ def forward(self, hidden_states):
436
+ input_dtype = hidden_states.dtype
437
+ hidden_states = hidden_states.to(torch.float32)
438
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
439
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
440
+ return self.weight * hidden_states.to(input_dtype)
441
+
442
+ def extra_repr(self):
443
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
444
+
445
+
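For reference, `GlmImageRMSNorm` computes `weight * x / sqrt(mean(x**2) + eps)`; a standalone re-derivation with toy values (the plain PyTorch path, not the hub kernel):

```python
import torch

eps = 1e-6
weight = torch.ones(4)                          # GlmImageRMSNorm initializes its weight to ones
x = torch.randn(2, 3, 4)

variance = x.float().pow(2).mean(-1, keepdim=True)
y = weight * (x.float() * torch.rsqrt(variance + eps)).to(x.dtype)
print(y.shape)                                  # torch.Size([2, 3, 4]); same computation as forward()
```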
446
+ class GlmImageTextMLP(nn.Module):
447
+ def __init__(self, config):
448
+ super().__init__()
449
+
450
+ self.config = config
451
+ self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False)
452
+ self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
453
+ self.activation_fn = ACT2FN[config.hidden_act]
454
+
455
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
456
+ up_states = self.gate_up_proj(hidden_states)
457
+
458
+ gate, up_states = up_states.chunk(2, dim=-1)
459
+ up_states = up_states * self.activation_fn(gate)
460
+
461
+ return self.down_proj(up_states)
462
+
463
+
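A usage sketch of the gated MLP dataflow used by `GlmImageTextMLP` above, with SiLU assumed as a stand-in activation (the module actually uses `ACT2FN[config.hidden_act]`) and toy sizes:

```python
import torch
import torch.nn as nn

hidden_size, intermediate_size = 8, 16          # toy sizes
gate_up_proj = nn.Linear(hidden_size, 2 * intermediate_size, bias=False)
down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
act = nn.SiLU()                                 # assumption; config.hidden_act decides in the real module

x = torch.randn(2, 3, hidden_size)
gate, up = gate_up_proj(x).chunk(2, dim=-1)     # one fused matmul, then split
y = down_proj(up * act(gate))
print(y.shape)                                  # torch.Size([2, 3, 8])
```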
464
+ class GlmImageTextDecoderLayer(GradientCheckpointingLayer):
465
+ def __init__(self, config: GlmImageTextConfig, layer_idx: int):
466
+ super().__init__()
467
+ self.hidden_size = config.hidden_size
468
+ self.self_attn = GlmImageTextAttention(config, layer_idx)
469
+ self.mlp = GlmImageTextMLP(config)
470
+ self.input_layernorm = GlmImageRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
471
+ self.post_attention_layernorm = GlmImageRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
472
+ self.post_self_attn_layernorm = GlmImageRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
473
+ self.post_mlp_layernorm = GlmImageRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
474
+
475
+ @auto_docstring
476
+ def forward(
477
+ self,
478
+ hidden_states: torch.Tensor,
479
+ position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
480
+ attention_mask: torch.Tensor | None = None,
481
+ position_ids: torch.LongTensor | None = None,
482
+ past_key_values: Cache | None = None,
483
+ use_cache: bool | None = False,
484
+ cache_position: torch.LongTensor | None = None,
485
+ **kwargs,
486
+ ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
487
+ residual = hidden_states
488
+
489
+ hidden_states = self.input_layernorm(hidden_states)
490
+
491
+ # Self Attention
492
+ hidden_states, _ = self.self_attn(
493
+ hidden_states=hidden_states,
494
+ position_embeddings=position_embeddings,
495
+ attention_mask=attention_mask,
496
+ position_ids=position_ids,
497
+ past_key_values=past_key_values,
498
+ use_cache=use_cache,
499
+ cache_position=cache_position,
500
+ **kwargs,
501
+ )
502
+
503
+ hidden_states = self.post_self_attn_layernorm(hidden_states)
504
+ hidden_states = residual + hidden_states
505
+
506
+ # Fully Connected
507
+ residual = hidden_states
508
+ hidden_states = self.post_attention_layernorm(hidden_states)
509
+ hidden_states = self.mlp(hidden_states)
510
+ hidden_states = self.post_mlp_layernorm(hidden_states)
511
+ hidden_states = residual + hidden_states
512
+
513
+ return hidden_states
514
+
515
+
516
+ @auto_docstring
517
+ class GlmImagePreTrainedModel(PreTrainedModel):
518
+ config: GlmImageConfig
519
+ base_model_prefix = "model"
520
+ input_modalities = ("image", "text")
521
+ supports_gradient_checkpointing = True
522
+ _no_split_modules = ["GlmImageTextDecoderLayer", "GlmImageVisionBlock"]
523
+ _skip_keys_device_placement = "past_key_values"
524
+ _supports_flash_attn = True
525
+ _supports_sdpa = True
526
+
527
+ _can_compile_fullgraph = True
528
+ _supports_attention_backend = True
529
+ _can_record_outputs = {
530
+ "hidden_states": GlmImageTextDecoderLayer,
531
+ "attentions": GlmImageTextAttention,
532
+ }
533
+
534
+ @torch.no_grad()
535
+ def _init_weights(self, module):
536
+ super()._init_weights(module)
537
+
538
+
539
+ @dataclass
540
+ @auto_docstring(
541
+ custom_intro="""
542
+ Base class for GlmImage outputs, with hidden states and attentions.
543
+ """
544
+ )
545
+ class GlmImageModelOutputWithPast(ModelOutput):
546
+ r"""
547
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
548
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
549
+
550
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
551
+ `past_key_values` input) to speed up sequential decoding.
552
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
553
+ The rope index difference between sequence length and multimodal rope.
554
+ """
555
+
556
+ last_hidden_state: torch.FloatTensor | None = None
557
+ past_key_values: Cache | None = None
558
+ hidden_states: tuple[torch.FloatTensor] | None = None
559
+ attentions: tuple[torch.FloatTensor] | None = None
560
+ rope_deltas: torch.LongTensor | None = None
561
+
562
+
563
+ class GlmImageVQVAEVectorQuantizer(nn.Module):
564
+ """
565
+ A module for vector quantization using learned embedding vectors.
566
+
567
+ This module implements the quantization process similar to the one described in
568
+ the VQ-VAE (Vector Quantized Variational AutoEncoder) paper. It quantizes continuous
569
+ input vectors into discrete codebook vectors, which are learned during training.
570
+ Current implementation improves over previous ones by avoiding costly matrix multiplications
571
+ and allowing for post-hoc remapping of indices.
572
+ """
573
+
574
+ def __init__(self, config: GlmImageVQVAEConfig):
575
+ super().__init__()
576
+ self.num_embeddings = config.num_embeddings
577
+ self.embedding_dim = config.embed_dim
578
+ self.beta = getattr(config, "beta", 0.25)
579
+
580
+ self.embedding = nn.Embedding(self.num_embeddings, self.embedding_dim)
581
+
582
+ def forward(self, hidden_state: torch.Tensor):
583
+ hidden_state = hidden_state.permute(0, 2, 3, 1).contiguous()
584
+ hidden_state_flattened = hidden_state.view(-1, self.embedding_dim)
585
+
586
+ # L2 normalize
587
+ hidden_state = F.normalize(hidden_state, p=2, dim=-1)
588
+ hidden_state_flattened = F.normalize(hidden_state_flattened, p=2, dim=-1)
589
+ embedding = F.normalize(self.embedding.weight, p=2, dim=-1)
590
+
591
+ # distances from z to embeddings e_j (z - e)^2 = z^2 + e^2 - 2 e * z
592
+ distances = (
593
+ torch.sum(hidden_state_flattened**2, dim=1, keepdim=True)
594
+ + torch.sum(embedding**2, dim=1)
595
+ - 2 * torch.einsum("bd,dn->bn", hidden_state_flattened, embedding.transpose(0, 1))
596
+ )
597
+
598
+ min_encoding_indices = torch.argmin(distances, dim=1)
599
+ hidden_state_quant = embedding[min_encoding_indices].view(hidden_state.shape)
600
+
601
+ # compute loss for embedding
602
+ loss = torch.mean((hidden_state_quant.detach() - hidden_state) ** 2) + self.beta * torch.mean(
603
+ (hidden_state_quant - hidden_state.detach()) ** 2
604
+ )
605
+
606
+ # preserve gradients
607
+ hidden_state_quant = hidden_state + (hidden_state_quant - hidden_state).detach()
608
+
609
+ # reshape back to match original input shape
610
+ hidden_state_quant = hidden_state_quant.permute(0, 3, 1, 2).contiguous()
611
+
612
+ return hidden_state_quant, loss, min_encoding_indices
613
+
614
+
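A stripped-down sketch of the nearest-codebook lookup performed by the quantizer above, with an invented codebook size; it mirrors the L2-normalized distance computation but omits the commitment loss and the straight-through estimator:

```python
import torch
import torch.nn.functional as F

num_embeddings, embed_dim = 8, 4                # invented codebook size
codebook = F.normalize(torch.randn(num_embeddings, embed_dim), p=2, dim=-1)
z = F.normalize(torch.randn(5, embed_dim), p=2, dim=-1)   # 5 flattened latents

# ||z - e||^2 = ||z||^2 + ||e||^2 - 2 z.e
distances = (z**2).sum(1, keepdim=True) + (codebook**2).sum(1) - 2 * z @ codebook.T
indices = distances.argmin(dim=1)               # the discrete image token ids
quantized = codebook[indices]                   # (5, embed_dim)
print(indices)
```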
615
+ @auto_docstring(
616
+ custom_intro="""
617
+ The VQ-VAE model used in GlmImage for encoding/decoding images into discrete tokens.
618
+ This model follows the "Make-a-scene: Scene-based text-to-image generation with human priors" paper from
619
+ [Oran Gafni, Adam Polyak, Oron Ashual, Shelly Sheynin, Devi Parikh, and Yaniv
620
+ Taigman](https://huggingface.co/papers/2203.13131).
621
+ """
622
+ )
623
+ class GlmImageVQVAE(GlmImagePreTrainedModel):
624
+ config: GlmImageVQVAEConfig
625
+ _no_split_modules = [
626
+ "GlmImageVQVAEVectorQuantizer",
627
+ ]
628
+
629
+ def __init__(self, config: GlmImageVQVAEConfig):
630
+ super().__init__(config)
631
+ self.quantize = GlmImageVQVAEVectorQuantizer(config)
632
+ self.quant_conv = torch.nn.Conv2d(config.latent_channels, config.embed_dim, 1)
633
+ self.post_quant_conv = torch.nn.Conv2d(config.embed_dim, config.latent_channels, 1)
634
+ self.eval() # GlmImage's VQ model is frozen
635
+ self.post_init()
636
+
637
+ def encode(self, hidden_states):
638
+ hidden_states = self.quant_conv(hidden_states)
639
+ quant, emb_loss, indices = self.quantize(hidden_states)
640
+ return quant, emb_loss, indices
641
+
642
+
643
+ class GlmImageVisionModel(GlmImagePreTrainedModel):
644
+ config: GlmImageVisionConfig
645
+ input_modalities = ("image",)
646
+ _no_split_modules = ["GlmImageVisionBlock"]
647
+ main_input_name = "pixel_values"
648
+
649
+ def __init__(self, config: GlmImageVisionConfig) -> None:
650
+ super().__init__(config)
651
+ self.spatial_merge_size = config.spatial_merge_size
652
+ self.patch_size = config.patch_size
653
+
654
+ self.embeddings = GlmImageVisionEmbeddings(config)
655
+ self.patch_embed = GlmImageVisionPatchEmbed(config)
656
+
657
+ head_dim = config.hidden_size // config.num_heads
658
+
659
+ self.blocks = nn.ModuleList([GlmImageVisionBlock(config) for _ in range(config.depth)])
660
+
661
+ self.gradient_checkpointing = False
662
+ self.head_dim = head_dim
663
+ self.post_init()
664
+
665
+ def rot_pos_emb(self, grid_thw):
666
+ pos_ids = []
667
+ for t, h, w in grid_thw:
668
+ hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
669
+ hpos_ids = hpos_ids.reshape(
670
+ h // self.spatial_merge_size,
671
+ self.spatial_merge_size,
672
+ w // self.spatial_merge_size,
673
+ self.spatial_merge_size,
674
+ )
675
+ hpos_ids = hpos_ids.permute(0, 2, 1, 3)
676
+ hpos_ids = hpos_ids.flatten()
677
+
678
+ wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
679
+ wpos_ids = wpos_ids.reshape(
680
+ h // self.spatial_merge_size,
681
+ self.spatial_merge_size,
682
+ w // self.spatial_merge_size,
683
+ self.spatial_merge_size,
684
+ )
685
+ wpos_ids = wpos_ids.permute(0, 2, 1, 3)
686
+ wpos_ids = wpos_ids.flatten()
687
+ pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
688
+ pos_ids = torch.cat(pos_ids, dim=0)
689
+ return pos_ids
690
+
691
+ def forward(self, pixel_values: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
692
+ """
693
+ Args:
694
+ pixel_values (`torch.Tensor` of shape `(total_patches, num_channels * patch_size * patch_size)`):
695
+ Packed pixel values.
696
+ grid_thw (`torch.Tensor` of shape `(num_images, 3)`):
697
+ The temporal, height and width of feature shape of each image.
698
+
699
+ Returns:
700
+ `torch.Tensor` of shape `(total_patches, hidden_size)`: Hidden states.
701
+ """
702
+
703
+ hidden_states = self.patch_embed(pixel_values)
704
+ image_type_ids = self.rot_pos_emb(grid_thw)
705
+
706
+ cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
707
+ dim=0,
708
+ dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
709
+ )
710
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
711
+ seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
712
+ hidden_states = self.embeddings(
713
+ hidden_states,
714
+ seqlens,
715
+ grid_thw,
716
+ image_type_ids[:, 0].to(hidden_states.device),
717
+ image_type_ids[:, 1].to(hidden_states.device),
718
+ )
719
+
720
+ # Transformer blocks (no position_embeddings needed, already added above)
721
+ for blk in self.blocks:
722
+ hidden_states = blk(
723
+ hidden_states,
724
+ cu_seqlens=cu_seqlens,
725
+ )
726
+ return hidden_states
727
+
728
+
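The vision tower consumes patches packed along a single sequence dimension, with `cu_seqlens` marking image boundaries. A small sketch of how those offsets follow from `grid_thw` (values invented for illustration):

```python
import torch
import torch.nn.functional as F

# two images with 1x4x6 and 1x8x8 patch grids (toy values)
grid_thw = torch.tensor([[1, 4, 6], [1, 8, 8]])

patches_per_image = grid_thw[:, 1] * grid_thw[:, 2]          # [24, 64]
cu_seqlens = torch.repeat_interleave(patches_per_image, grid_thw[:, 0]).cumsum(0, dtype=torch.int32)
cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
print(cu_seqlens)                                            # tensor([ 0, 24, 88]) -> 88 packed patches
```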
729
+ class GlmImageTextRotaryEmbedding(nn.Module):
730
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
731
+
732
+ def __init__(self, config: GlmImageTextConfig, device=None):
733
+ super().__init__()
734
+ self.max_seq_len_cached = config.max_position_embeddings
735
+ self.original_max_seq_len = config.max_position_embeddings
736
+
737
+ self.config = config
738
+
739
+ self.rope_type = self.config.rope_parameters["rope_type"]
740
+ rope_init_fn: Callable = self.compute_default_rope_parameters
741
+ if self.rope_type != "default":
742
+ rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
743
+ inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
744
+
745
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
746
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
747
+ self.mrope_section = config.rope_parameters.get("mrope_section", [8, 12, 12])
748
+
749
+ @staticmethod
750
+ def compute_default_rope_parameters(
751
+ config: GlmImageTextConfig | None = None,
752
+ device: Optional["torch.device"] = None,
753
+ seq_len: int | None = None,
754
+ ) -> tuple["torch.Tensor", float]:
755
+ """
756
+ Computes the inverse frequencies according to the original RoPE implementation
757
+ Args:
758
+ config ([`~transformers.PreTrainedConfig`]):
759
+ The model configuration.
760
+ device (`torch.device`):
761
+ The device to use for initialization of the inverse frequencies.
762
+ seq_len (`int`, *optional*):
763
+ The current sequence length. Unused for this type of RoPE.
764
+ Returns:
765
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
766
+ post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
767
+ """
768
+ base = config.rope_parameters["rope_theta"]
769
+ partial_rotary_factor = config.rope_parameters.get("partial_rotary_factor", 1.0)
770
+ head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
771
+ dim = int(head_dim * partial_rotary_factor)
772
+
773
+ attention_factor = 1.0 # Unused in this type of RoPE
774
+
775
+ # Compute the inverse frequencies
776
+ inv_freq = 1.0 / (
777
+ base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
778
+ )
779
+ return inv_freq, attention_factor
780
+
781
+ @torch.no_grad()
782
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
783
+ def forward(self, x, position_ids):
784
+ # In contrast to other models, GLM-Image has different position ids for the grids
785
+ # So we expand the inv_freq to shape (3, ...)
786
+ inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
787
+ position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)
788
+
789
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
790
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
791
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
792
+ freqs = self.apply_mrope(freqs, self.mrope_section)
793
+ emb = torch.cat((freqs, freqs), dim=-1)
794
+ cos = emb.cos() * self.attention_scaling
795
+ sin = emb.sin() * self.attention_scaling
796
+
797
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
798
+
799
+ def apply_mrope(self, freqs, mrope_section):
800
+ section = mrope_section
801
+ chunks = freqs.split(section, dim=-1)
802
+ result = torch.cat([chunk[i % 3] for i, chunk in enumerate(chunks)], dim=-1)
803
+ return result
804
+
805
+
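`apply_mrope` above interleaves the temporal/height/width frequency planes along the channel dimension according to `mrope_section`. A toy illustration of the selection pattern (the model's actual section is `[8, 12, 12]`):

```python
import torch

# freqs: (3, batch, seq, rotary_dim) with one constant plane per axis (t=0, h=1, w=2)
freqs = torch.arange(3.0).view(3, 1, 1, 1).expand(3, 1, 2, 6)
mrope_section = [2, 2, 2]                       # toy split for illustration

chunks = freqs.split(mrope_section, dim=-1)
result = torch.cat([chunk[i % 3] for i, chunk in enumerate(chunks)], dim=-1)
print(result[0, 0])                             # tensor([0., 0., 1., 1., 2., 2.])
```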
806
+ @auto_docstring
807
+ class GlmImageTextModel(GlmImagePreTrainedModel):
808
+ config: GlmImageTextConfig
809
+ input_modalities = ("text",)
810
+
811
+ def __init__(self, config: GlmImageTextConfig):
812
+ super().__init__(config)
813
+ self.padding_idx = config.pad_token_id
814
+ self.vocab_size = config.vocab_size
815
+
816
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
817
+ self.layers = nn.ModuleList(
818
+ [GlmImageTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
819
+ )
820
+ self.norm = GlmImageRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
821
+ self.rotary_emb = GlmImageTextRotaryEmbedding(config=config)
822
+
823
+ self.gradient_checkpointing = False
824
+ # Initialize weights and apply final processing
825
+ self.post_init()
826
+
827
+ @auto_docstring
828
+ @check_model_inputs
829
+ def forward(
830
+ self,
831
+ input_ids: torch.LongTensor | None = None,
832
+ attention_mask: torch.Tensor | None = None,
833
+ position_ids: torch.LongTensor | None = None,
834
+ past_key_values: Cache | None = None,
835
+ inputs_embeds: torch.FloatTensor | None = None,
836
+ use_cache: bool | None = None,
837
+ cache_position: torch.LongTensor | None = None,
838
+ **kwargs: Unpack[FlashAttentionKwargs],
839
+ ) -> tuple | BaseModelOutputWithPast:
840
+ if (input_ids is None) ^ (inputs_embeds is not None):
841
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
842
+
843
+ # torch.jit.trace() doesn't support cache objects in the output
844
+ if use_cache and past_key_values is None and not torch.jit.is_tracing():
845
+ past_key_values = DynamicCache(config=self.config)
846
+
847
+ if inputs_embeds is None:
848
+ inputs_embeds = self.embed_tokens(input_ids)
849
+
850
+ if cache_position is None:
851
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
852
+ cache_position = torch.arange(
853
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
854
+ )
855
+
856
+ # the hard coded `3` is for temporal, height and width.
857
+ if position_ids is None:
858
+ position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
859
+ elif position_ids.ndim == 2:
860
+ position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
861
+
862
+ # NOTE: we need to pass text position ids for packing. Qwen2-VL uses 3D positions
863
+ # where each dim indicates visual spatial positions for temporal/height/width grids.
864
+ # There are two scenarios when FA2-like packed masking might be activated.
865
+ # 1. User specifically passed packed `position_ids` and no attention mask.
866
+ # In this case we expect the user to create correct position ids for all 3 grids
867
+ # and prepend text-only position ids to it. The final tensor will be [4, bs, seq-len]
868
+ # 2. User runs forward with no attention mask and no position ids. In this case, position ids
869
+ # are prepared by the model (`get_rope_index`) as `[4, bs, seq-len]` tensor. Text-only positions are
870
+ # prepended by us when creating positions so that the mask is constructed correctly. NOTE: failing to pass
871
+ # text-only positions will cause incorrect mask construction, do not change `prepare_input_for_generation`
872
+ if position_ids.ndim == 3 and position_ids.shape[0] == 4:
873
+ text_position_ids = position_ids[0]
874
+ position_ids = position_ids[1:]
875
+ else:
876
+ # If inputs are not packed (usual 3D positions), do not prepare mask from position_ids
877
+ text_position_ids = None
878
+
879
+ mask_kwargs = {
880
+ "config": self.config,
881
+ "input_embeds": inputs_embeds,
882
+ "attention_mask": attention_mask,
883
+ "cache_position": cache_position,
884
+ "past_key_values": past_key_values,
885
+ "position_ids": text_position_ids,
886
+ }
887
+ # Create the masks
888
+ causal_mask = create_causal_mask(**mask_kwargs)
889
+
890
+ hidden_states = inputs_embeds
891
+ position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
892
+
893
+ for decoder_layer in self.layers:
894
+ layer_outputs = decoder_layer(
895
+ hidden_states,
896
+ attention_mask=causal_mask,
897
+ position_ids=text_position_ids,
898
+ past_key_values=past_key_values,
899
+ cache_position=cache_position,
900
+ position_embeddings=position_embeddings,
901
+ **kwargs,
902
+ )
903
+ hidden_states = layer_outputs
904
+
905
+ hidden_states = self.norm(hidden_states)
906
+
907
+ return BaseModelOutputWithPast(
908
+ last_hidden_state=hidden_states,
909
+ past_key_values=past_key_values,
910
+ )
911
+
912
+
913
+ @auto_docstring
914
+ class GlmImageModel(GlmImagePreTrainedModel):
915
+ base_model_prefix = "model"
916
+ _checkpoint_conversion_mapping = {}
917
+ # Reference: fix gemma3 grad acc #37208
918
+ accepts_loss_kwargs = False
919
+ config: GlmImageConfig
920
+ _no_split_modules = ["GlmImageTextDecoderLayer", "GlmImageVisionBlock"]
921
+
922
+ def __init__(self, config):
923
+ super().__init__(config)
924
+ self.visual = GlmImageVisionModel._from_config(config.vision_config)
925
+ self.language_model = GlmImageTextModel._from_config(config.text_config)
926
+
927
+ self.rope_deltas = None # cache rope_deltas here
928
+ self.vqmodel = GlmImageVQVAE._from_config(config.vq_config)
929
+
930
+ # Initialize weights and apply final processing
931
+ self.post_init()
932
+
933
+ def get_input_embeddings(self):
934
+ return self.language_model.get_input_embeddings()
935
+
936
+ def set_input_embeddings(self, value):
937
+ self.language_model.set_input_embeddings(value)
938
+
939
+ def get_rope_index(
940
+ self,
941
+ input_ids: torch.LongTensor | None = None,
942
+ image_grid_thw: torch.LongTensor | None = None,
943
+ attention_mask: torch.LongTensor | None = None,
944
+ ) -> tuple[torch.Tensor, torch.Tensor]:
945
+ """
946
+ Calculate the 3D rope index for image generation task.
947
+
948
+ Explanation:
949
+ Each embedding sequence may contain image tokens (for generation) and text tokens,
950
+ or just text tokens.
951
+
952
+ Input format:
953
+ - Text-to-Image: [text tokens] + <|dit_token_16384|>
954
+ - Image-to-Image: <|dit_token_16384|> [image tokens] <|dit_token_16385|> + [text tokens] + <|dit_token_16384|>
955
+
956
+ For pure text embedding sequence, the rotary position embedding is the same across all 3 dimensions.
957
+ Examples:
958
+ input_ids: [T T T T T], here T is for text.
959
+ temporal position_ids: [0, 1, 2, 3, 4]
960
+ height position_ids: [0, 1, 2, 3, 4]
961
+ width position_ids: [0, 1, 2, 3, 4]
962
+
963
+ For sequences with image tokens, we use special markers to denote image regions:
964
+ - <|dit_token_16384|>: image start marker
965
+ - <|dit_token_16385|>: image end marker
966
+ - Image tokens between these markers use 2D spatial position encoding.
967
+
968
+ For image tokens:
969
+ - temporal: stays constant at (image_start_pos + 1)
970
+ - height: increments every w tokens, representing row position
971
+ - width: cycles from 0 to w-1, representing column position
972
+
973
+ After each image region, the next position jumps to: image_start_pos + 1 + max(h, w)
974
+ This ensures sufficient positional separation between images and subsequent tokens.
975
+
976
+ Examples:
977
+ === Case 1: Image-to-Image Generation ===
978
+
979
+ Source image with grid [1, 3, 2], followed by text, then generation.
980
+ input_ids: [<|dit_token_16384|> V V V V V V <|dit_token_16385|> T T T T <|dit_token_16384|>]
981
+ image_grid_thw: [[1, 3, 2], [1, 4, 4]] # first is source, second is target
982
+
983
+ For source image (h=3, w=2, 6 tokens):
984
+ Start marker at position 0
985
+ Image tokens at temporal=1, height=[1,1,2,2,3,3], width=[1,2,1,2,1,2]
986
+ End marker at position 4 (= 0 + 1 + max(3,2))
987
+
988
+ Text tokens and trailing start marker continue from position 5.
989
+
990
+ Full prefill position_ids:
991
+ temporal: [0, 1,1,1,1,1,1, 4, 5,6,7,8, 9]
992
+ height: [0, 1,1,2,2,3,3, 4, 5,6,7,8, 9]
993
+ width: [0, 1,2,1,2,1,2, 4, 5,6,7,8, 9]
994
+
995
+ Decode stage: use image_grid_thw[-1] = [1, 4, 4] to build cached position_ids,
996
+ starting from gen_st_idx = 10.
997
+
998
+ === Case 2: Text-to-Image Generation (multi-resolution) ===
999
+
1000
+ Pure text input with two image_grids for progressive generation.
1001
+ input_ids: [hello<sop>3 3<eop><sop>3 2<eop><|dit_token_16384|>]
1002
+ Assume "hello<sop>3 3<eop><sop>3 2<eop>" = 4 tokens (positions 0-3)
1003
+ <|dit_token_16384|> at position 4
1004
+ image_grid_thw: [[1, 3, 3], [1, 3, 2]]
1005
+ - image_grid_thw[-1] = [1, 3, 2]: first generated image (smaller/draft)
1006
+ - image_grid_thw[-2] = [1, 3, 3]: second generated image (larger/final)
1007
+
1008
+ Prefill position_ids (5 tokens: 4 text + 1 start marker):
1009
+ temporal: [0, 1, 2, 3, 4]
1010
+ height: [0, 1, 2, 3, 4]
1011
+ width: [0, 1, 2, 3, 4]
1012
+
1013
+ Decode stage builds position_ids in reverse order of image_grid_thw:
1014
+
1015
+ First: image_grid_thw[-1] = [1, 3, 2] (6 tokens), starting at position 5:
1016
+ temporal: [5, 5, 5, 5, 5, 5]
1017
+ height: [5, 5, 6, 6, 7, 7]
1018
+ width: [5, 6, 5, 6, 5, 6]
1019
+ next_pos = 5 + max(3, 2) = 8
1020
+
1021
+ Then: image_grid_thw[-2] = [1, 3, 3] (9 tokens), starting at position 8:
1022
+ temporal: [8, 8, 8, 8, 8, 8, 8, 8, 8]
1023
+ height: [8, 8, 8, 9, 9, 9, 10, 10, 10]
1024
+ width: [8, 9, 10, 8, 9, 10, 8, 9, 10]
1025
+ next_pos = 8 + max(3, 3) = 11
1026
+
1027
+ Finally: <|dit_token_16385|> end marker at position 11
1028
+
1029
+ Full sequence position_ids (prefill + decode):
1030
+ temporal: [0,1,2,3, 4, 5,5,5,5,5,5, 8,8,8,8,8,8,8,8,8, 11]
1031
+ height: [0,1,2,3, 4, 5,5,6,6,7,7, 8,8,8,9,9,9,10,10,10, 11]
1032
+ width: [0,1,2,3, 4, 5,6,5,6,5,6, 8,9,10,8,9,10,8,9,10, 11]
1033
+
1034
+ _cached_decode_position_ids shape: [3, 6 + 9 + 1] = [3, 16]
1035
+ (includes all generated image tokens + end marker)
1036
+
1037
+ Args:
1038
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1039
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default
1040
+ should you provide it.
1041
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
1042
+ The temporal, height and width of feature shape of each image. For image generation,
1043
+ temporal is typically 1.
1044
+ - For image-to-image: includes source image grids + target image grid(s)
1045
+ - For text-to-image with multi-resolution: includes multiple target grids,
1046
+ processed in reverse order (last grid first, second-to-last grid second, etc.)
1047
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
1048
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
1049
+ - 1 for tokens that are **not masked**,
1050
+ - 0 for tokens that are **masked**.
1051
+
1052
+ Returns:
1053
+ position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`):
1054
+ Position IDs for temporal, height, and width dimensions.
1055
+ mrope_position_deltas (`torch.Tensor` of shape `(batch_size, 1)`):
1056
+ Position deltas for multi-modal rotary position embedding (zeros for this task).
1057
+ """
1058
+
1059
+ batch_size, seq_len = input_ids.shape
1060
+ device = input_ids.device
1061
+ dtype = input_ids.dtype
1062
+
1063
+ image_start_token_id = self.config.image_start_token_id
1064
+ image_end_token_id = self.config.image_end_token_id
1065
+ num_complete_images = (input_ids == image_end_token_id).sum().item()
1066
+
1067
+ position_ids = torch.ones(
1068
+ 3, input_ids.shape[0], input_ids.shape[1], dtype=input_ids.dtype, device=input_ids.device
1069
+ )
1070
+ text_positions = torch.arange(seq_len)[None, :].repeat(3, 1)
1071
+ for batch_idx in range(batch_size):
1072
+ curr_input_ids = input_ids[batch_idx]
1073
+ if attention_mask is not None:
1074
+ curr_input_ids = curr_input_ids[attention_mask[batch_idx] == 1]
1075
+
1076
+ image_end = torch.where(curr_input_ids == image_end_token_id)[0]
1077
+ image_start = torch.where(curr_input_ids == image_start_token_id)[0] + 1
1078
+ current_pos = 0 # track the current position value
1079
+ prev_image_end = 0
1080
+ curr_position_ids = []
1081
+ for start, end, grid in zip(image_start, image_end, image_grid_thw):
1082
+ _, num_width_grid, num_height_grid = grid
1083
+
1084
+ # Create text position ids first if there are text tokens before image
1085
+ llm_pos_length = start - prev_image_end
1086
+ llm_position_ids = text_positions[:, current_pos : current_pos + llm_pos_length].to(
1087
+ device=input_ids.device
1088
+ )
1089
+ current_pos += llm_position_ids.shape[-1]
1090
+
1091
+ # Now create image position ids for each grid
1092
+ image_seq_length = num_height_grid * num_width_grid
1093
+ h_grids = image_seq_length // num_height_grid + current_pos
1094
+ w_grids = image_seq_length // num_width_grid + current_pos
1095
+ position_width = torch.arange(current_pos, w_grids, device=input_ids.device).repeat(num_width_grid)
1096
+ position_height = torch.arange(current_pos, h_grids, device=input_ids.device).repeat_interleave(
1097
+ num_height_grid
1098
+ )
1099
+ position_temporal = torch.full(
1100
+ (image_seq_length,), current_pos, device=input_ids.device, dtype=torch.long
1101
+ )
1102
+ vision_position_ids = torch.stack([position_temporal, position_height, position_width], dim=0)
1103
+ current_pos += max(num_height_grid, num_width_grid)
1104
+
1105
+ prev_image_end = end
1106
+ curr_position_ids.append(torch.cat([llm_position_ids, vision_position_ids], dim=-1))
1107
+
1108
+ # Add position ids for the last text tokens if any
1109
+ end_position = len(curr_input_ids) - prev_image_end
1110
+ llm_position_ids = text_positions[:, current_pos : current_pos + end_position].to(device=input_ids.device)
1111
+ current_pos += llm_position_ids.shape[-1]
1112
+ curr_position_ids.append(llm_position_ids)
1113
+ curr_position_ids = torch.cat(curr_position_ids, dim=-1)
1114
+ if attention_mask is not None:
1115
+ position_ids[:, batch_idx, attention_mask[batch_idx] == 1] = curr_position_ids.to(position_ids.device)
1116
+ else:
1117
+ position_ids[:, batch_idx, :] = curr_position_ids.to(position_ids.device)
1118
+
1119
+ # Build and store position ids for tokens that will be generated. Later we will just
1120
+ # slice these instead of computing each decoding step
1121
+ self._prefill_len = seq_len
1122
+ if image_grid_thw is not None and len(image_grid_thw) > 0:
1123
+ num_decode_grids = len(image_grid_thw) - num_complete_images
1124
+ num_decode_grids = max(num_decode_grids, 0)
1125
+ decode_pos = current_pos
1126
+
1127
+ decode_temporal_list = []
1128
+ decode_height_list = []
1129
+ decode_width_list = []
1130
+
1131
+ for i in range(1, num_decode_grids + 1):
1132
+ grid_idx = -i
1133
+ h = image_grid_thw[grid_idx, 1].item()
1134
+ w = image_grid_thw[grid_idx, 2].item()
1135
+ total_tokens = h * w
1136
+
1137
+ h_indices = torch.arange(h, device=device).unsqueeze(1).expand(h, w).flatten()
1138
+ w_indices = torch.arange(w, device=device).unsqueeze(0).expand(h, w).flatten()
1139
+
1140
+ decode_temporal_list.append(torch.full((total_tokens,), decode_pos, device=device, dtype=torch.long))
1141
+ decode_height_list.append(decode_pos + h_indices)
1142
+ decode_width_list.append(decode_pos + w_indices)
1143
+ decode_pos = decode_pos + max(h, w)
1144
+
1145
+ decode_temporal_list.append(torch.tensor([decode_pos], device=device, dtype=torch.long))
1146
+ decode_height_list.append(torch.tensor([decode_pos], device=device, dtype=torch.long))
1147
+ decode_width_list.append(torch.tensor([decode_pos], device=device, dtype=torch.long))
1148
+
1149
+ self._cached_decode_position_ids = torch.stack(
1150
+ [
1151
+ torch.cat(decode_temporal_list, dim=0),
1152
+ torch.cat(decode_height_list, dim=0),
1153
+ torch.cat(decode_width_list, dim=0),
1154
+ ],
1155
+ dim=0,
1156
+ )
1157
+ else:
1158
+ self._cached_decode_position_ids = None
1159
+
1160
+ mrope_position_deltas = torch.zeros([batch_size, 1], dtype=dtype, device=device)
1161
+
1162
+ return position_ids, mrope_position_deltas
1163
+
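As a standalone check of the decode-stage layout described in Case 2 of the docstring above, the snippet below re-derives the cached decode position ids for `image_grid_thw = [[1, 3, 3], [1, 3, 2]]` with a 5-token prefill; it reproduces the logic of `get_rope_index` for illustration and is not part of the model:

```python
import torch

image_grid_thw = torch.tensor([[1, 3, 3], [1, 3, 2]])
decode_pos = 5                                   # the prefill occupies positions 0..4
t_list, h_list, w_list = [], [], []

for grid in reversed(image_grid_thw.tolist()):   # the last grid is generated first
    _, h, w = grid
    h_idx = torch.arange(h).unsqueeze(1).expand(h, w).flatten()
    w_idx = torch.arange(w).unsqueeze(0).expand(h, w).flatten()
    t_list.append(torch.full((h * w,), decode_pos))
    h_list.append(decode_pos + h_idx)
    w_list.append(decode_pos + w_idx)
    decode_pos += max(h, w)

for lst in (t_list, h_list, w_list):             # trailing <|dit_token_16385|> end marker
    lst.append(torch.tensor([decode_pos]))

decode_ids = torch.stack([torch.cat(parts) for parts in (t_list, h_list, w_list)])
print(decode_ids.shape)                          # torch.Size([3, 16]), as in the docstring
print(decode_ids[1].tolist())                    # [5, 5, 6, 6, 7, 7, 8, 8, 8, 9, 9, 9, 10, 10, 10, 11]
```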
1164
+ def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None):
1165
+ """
1166
+ Encodes images into continuous embeddings that can be forwarded to the language model.
1167
+
1168
+ Args:
1169
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
1170
+ The tensors corresponding to the input images.
1171
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
1172
+ The temporal, height and width of feature shape of each image in LLM.
1173
+ """
1174
+ pixel_values = pixel_values.type(self.visual.dtype)
1175
+ image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
1176
+ split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
1177
+ image_embeds = torch.split(image_embeds, split_sizes)
1178
+ return image_embeds
1179
+
1180
+ def get_placeholder_mask(
1181
+ self,
1182
+ input_ids: torch.LongTensor,
1183
+ image_ids: torch.LongTensor,
1184
+ ):
1185
+ """
1186
+ Replace image placeholder tokens in input_ids with actual image token ids from VQVAE.
1187
+
1188
+ Args:
1189
+ input_ids (`torch.LongTensor` of shape `(batch_size, seq_len)`):
1190
+ Input token ids with image placeholders.
1191
+ image_ids (`torch.LongTensor` of shape `(num_images, num_tokens_per_image)` or flattened):
1192
+ Discrete token indices from the VQVAE codebook.
1193
+
1194
+ Returns:
1195
+ special_image_mask (`torch.LongTensor` of shape `(batch_size, seq_len)`):
1196
+ Mask indicating positions in input ids that will be replaced by actual image tokens.
1197
+ """
1198
+
1199
+ special_image_mask = input_ids == self.config.image_token_id
1200
+ n_placeholder_tokens = special_image_mask.sum().item()
1201
+ n_image_tokens = image_ids.shape[0]
1202
+
1203
+ if n_placeholder_tokens != n_image_tokens:
1204
+ raise ValueError(
1205
+ f"Number of image placeholder tokens ({n_placeholder_tokens}) does not match "
1206
+ f"number of image tokens from VQVAE ({n_image_tokens})"
1207
+ )
1208
+
1209
+ return special_image_mask
1210
+
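The mask returned above is later consumed with `masked_scatter` in `GlmImageModel.forward` to splice the discrete VQVAE ids into the token sequence; a toy illustration with invented ids:

```python
import torch

IMAGE_TOKEN_ID = 999                               # hypothetical placeholder id
input_ids = torch.tensor([[1, 999, 999, 2, 999]])  # three placeholder positions
image_ids = torch.tensor([10, 11, 12])             # discrete VQVAE codes, one per placeholder

mask = input_ids == IMAGE_TOKEN_ID
filled = input_ids.masked_scatter(mask, image_ids)
print(filled)                                      # tensor([[ 1, 10, 11,  2, 12]])
```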
1211
+ @auto_docstring
1212
+ @can_return_tuple
1213
+ def forward(
1214
+ self,
1215
+ input_ids: torch.LongTensor | None = None,
1216
+ attention_mask: torch.Tensor | None = None,
1217
+ position_ids: torch.LongTensor | None = None,
1218
+ past_key_values: Cache | None = None,
1219
+ inputs_embeds: torch.FloatTensor | None = None,
1220
+ pixel_values: torch.Tensor | None = None,
1221
+ image_grid_thw: torch.LongTensor | None = None,
1222
+ rope_deltas: torch.LongTensor | None = None,
1223
+ cache_position: torch.LongTensor | None = None,
1224
+ **kwargs: Unpack[TransformersKwargs],
1225
+ ) -> tuple | GlmImageModelOutputWithPast:
1226
+ r"""
1227
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
1228
+ The temporal, height and width of feature shape of each image in LLM.
1229
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
1230
+ The rope index difference between sequence length and multimodal rope.
1231
+ """
1232
+ if (input_ids is None) ^ (inputs_embeds is not None):
1233
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
1234
+
1235
+ if pixel_values is not None:
1236
+ image_embeds = self.get_image_features(pixel_values, image_grid_thw[:-1])
1237
+ image_embeds = torch.cat(image_embeds, dim=0)
1238
+ image_ids = self.get_image_tokens(image_embeds, image_grid_thw[:-1])
1239
+ image_ids = image_ids.view(-1).to(input_ids.device)
1240
+ special_image_mask = self.get_placeholder_mask(input_ids, image_ids)
1241
+ input_ids = input_ids.masked_scatter(special_image_mask, image_ids)
1242
+
1243
+ if inputs_embeds is None:
1244
+ inputs_embeds = self.get_input_embeddings()(input_ids)
1245
+
1246
+ if position_ids is None:
1247
+ attention_mask_2d = attention_mask
1248
+ if attention_mask is not None and attention_mask.ndim == 4:
1249
+ attention_mask_2d = torch.diagonal(attention_mask[:, 0], dim1=1, dim2=2)
1250
+ # Only apply conversion for floating point tensors (inverted masks)
1251
+ if attention_mask_2d.dtype.is_floating_point:
1252
+ attention_mask_2d = attention_mask_2d / torch.finfo(attention_mask_2d.dtype).min
1253
+ attention_mask_2d = (1.0 - attention_mask_2d).int()
1254
+
1255
+ # Calculate RoPE index once per generation in the pre-fill stage only.
1256
+ # It is safe to assume that `length!=1` means we're in pre-fill because the
1257
+ # model is used only by DiT pipeline without assisted decoding, etc. techniques
1258
+ is_prefill_stage = (input_ids is not None and input_ids.shape[1] != 1) or (
1259
+ inputs_embeds is not None and inputs_embeds.shape[1] != 1
1260
+ )
1261
+ if is_prefill_stage or self.rope_deltas is None:
1262
+ position_ids, rope_deltas = self.get_rope_index(
1263
+ input_ids,
1264
+ image_grid_thw,
1265
+ attention_mask=attention_mask_2d,
1266
+ )
1267
+ self.rope_deltas = rope_deltas
1268
+ # then use the prev pre-calculated rope-deltas to get the correct position ids
1269
+ else:
1270
+ batch_size, seq_length, _ = inputs_embeds.shape
1271
+ # Use prefill token length, not position value
1272
+ step = cache_position[0].item() - self._prefill_len
1273
+ # Direct lookup - no tensor creation overhead
1274
+ position_ids = self._cached_decode_position_ids[:, step : step + seq_length]
1275
+ position_ids = position_ids.unsqueeze(1).expand(-1, batch_size, -1)
1276
+
1277
+ outputs = self.language_model(
1278
+ input_ids=None,
1279
+ position_ids=position_ids,
1280
+ attention_mask=attention_mask,
1281
+ past_key_values=past_key_values,
1282
+ inputs_embeds=inputs_embeds,
1283
+ cache_position=cache_position,
1284
+ **kwargs,
1285
+ )
1286
+
1287
+ return GlmImageModelOutputWithPast(
1288
+ last_hidden_state=outputs.last_hidden_state,
1289
+ past_key_values=outputs.past_key_values,
1290
+ hidden_states=outputs.hidden_states,
1291
+ attentions=outputs.attentions,
1292
+ rope_deltas=self.rope_deltas,
1293
+ )
1294
+
1295
+ def get_image_tokens(
1296
+ self,
1297
+ hidden_states: torch.FloatTensor,
1298
+ image_grid_thw: torch.LongTensor,
1299
+ ) -> torch.LongTensor:
1300
+ """
1301
+ Tokenizes image features into discrete tokens with VQVAE module.
1302
+
1303
+ Args:
1304
+ hidden_states (`torch.FloatTensor` of shape `(total_patches, hidden_size)`):
1305
+ The packed image features from vision encoder.
1306
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`):
1307
+ The temporal, height and width of feature shape of each image.
1308
+
1309
+ Returns:
1310
+ image_tokens (`torch.LongTensor` of shape `(total_patches,)`):
1311
+ Discrete token indices from the VQVAE codebook.
1312
+ """
1313
+ hidden_size = hidden_states.shape[-1]
1314
+ split_sizes = (image_grid_thw.prod(dim=-1)).tolist()
1315
+ hidden_states_list = torch.split(hidden_states, split_sizes, dim=0)
1316
+
1317
+ all_image_toks = []
1318
+ for i, hs in enumerate(hidden_states_list):
1319
+ grid_t, grid_h, grid_w = image_grid_thw[i].tolist()
1320
+ hs = hs.view(grid_t, grid_h, grid_w, hidden_size)
1321
+ hs = hs.permute(0, 3, 1, 2).contiguous()
1322
+ _, _, image_toks = self.vqmodel.encode(hs)
1323
+ all_image_toks.append(image_toks)
1324
+ return torch.cat(all_image_toks, dim=0)
1325
+
1326
+
1327
+ @dataclass
1328
+ @auto_docstring(
1329
+ custom_intro="""
1330
+ Base class for GlmImage causal language model (or autoregressive) outputs.
1331
+ """
1332
+ )
1333
+ class GlmImageCausalLMOutputWithPast(ModelOutput):
1334
+ r"""
1335
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
1336
+ Language modeling loss (for next-token prediction).
1337
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
1338
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
1339
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
1340
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
1341
+
1342
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
1343
+ `past_key_values` input) to speed up sequential decoding.
1344
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
1345
+ The rope index difference between sequence length and multimodal rope.
1346
+ """
1347
+
1348
+ loss: torch.FloatTensor | None = None
1349
+ logits: torch.FloatTensor | None = None
1350
+ past_key_values: Cache | None = None
1351
+ hidden_states: tuple[torch.FloatTensor] | None = None
1352
+ attentions: tuple[torch.FloatTensor] | None = None
1353
+ rope_deltas: torch.LongTensor | None = None
1354
+
1355
+
1356
+ class GlmImageForConditionalGeneration(GlmImagePreTrainedModel, GenerationMixin):
1357
+ _checkpoint_conversion_mapping = {}
1358
+ _tied_weights_keys = {}
1359
+ # Reference: fix gemma3 grad acc #37208
1360
+ accepts_loss_kwargs = False
1361
+ base_model_prefix = "model"
1362
+ config: GlmImageConfig
1363
+
1364
+ def __init__(self, config):
1365
+ super().__init__(config)
1366
+ self.model = GlmImageModel(config)
1367
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vision_vocab_size, bias=False)
1368
+
1369
+ # Initialize weights and apply final processing
1370
+ self.post_init()
1371
+
1372
+ def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None):
1373
+ return self.model.get_image_features(pixel_values, image_grid_thw)
1374
+
1375
+ def get_image_tokens(self, hidden_states: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None):
1376
+ return self.model.get_image_tokens(hidden_states, image_grid_thw)
1377
+
1378
+ def forward(
1379
+ self,
1380
+ input_ids: torch.LongTensor | None = None,
1381
+ attention_mask: torch.Tensor | None = None,
1382
+ position_ids: torch.LongTensor | None = None,
1383
+ past_key_values: Cache | None = None,
1384
+ inputs_embeds: torch.FloatTensor | None = None,
1385
+ labels: torch.LongTensor | None = None,
1386
+ pixel_values: torch.Tensor | None = None,
1387
+ image_grid_thw: torch.LongTensor | None = None,
1388
+ cache_position: torch.LongTensor | None = None,
1389
+ logits_to_keep: int | torch.Tensor = 0,
1390
+ **kwargs: Unpack[TransformersKwargs],
1391
+ ) -> tuple | GlmImageCausalLMOutputWithPast:
1392
+ r"""
1393
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1394
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1395
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1396
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1397
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
1398
+ The temporal, height and width of feature shape of each image in LLM.
1399
+
1400
+ Example:
1401
+
1402
+ ```python
1403
+ >>> from PIL import Image
1404
+ >>> import requests
1405
+ >>> from transformers import AutoProcessor, GlmImageForConditionalGeneration
1406
+
1407
+ >>> model = GlmImageForConditionalGeneration.from_pretrained("zai-org/GLM-Image")
1408
+ >>> processor = AutoProcessor.from_pretrained("zai-org/GLM-Image")
1409
+
1410
+ >>> messages = [
1411
+ {
1412
+ "role": "user",
1413
+ "content": [
1414
+ {"type": "image"},
1415
+ {"type": "text", "text": "Add a truck of this photo.<sop>28 40<eop>"},
1416
+ ],
1417
+ },
1418
+ ]
1419
+ >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
1420
+ >>> image = Image.open(requests.get(url, stream=True).raw)
1421
+
1422
+ >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
1423
+ >>> inputs = processor(text=[text], images=[image], return_tensors="pt")
1424
+
1425
+ >>> # Generate
1426
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1427
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1428
+ "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
1429
+ ```"""
1430
+ outputs = self.model(
1431
+ input_ids=input_ids,
1432
+ pixel_values=pixel_values,
1433
+ image_grid_thw=image_grid_thw,
1434
+ position_ids=position_ids,
1435
+ attention_mask=attention_mask,
1436
+ past_key_values=past_key_values,
1437
+ inputs_embeds=inputs_embeds,
1438
+ cache_position=cache_position,
1439
+ **kwargs,
1440
+ )
1441
+
1442
+ hidden_states = outputs[0]
1443
+
1444
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
1445
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
1446
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
1447
+
1448
+ loss = None
1449
+ if labels is not None:
1450
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
1451
+
1452
+ return GlmImageCausalLMOutputWithPast(
1453
+ loss=loss,
1454
+ logits=logits,
1455
+ past_key_values=outputs.past_key_values,
1456
+ hidden_states=outputs.hidden_states,
1457
+ attentions=outputs.attentions,
1458
+ rope_deltas=outputs.rope_deltas,
1459
+ )
1460
+
1461
+ def prepare_inputs_for_generation(
1462
+ self,
1463
+ input_ids,
1464
+ past_key_values=None,
1465
+ attention_mask=None,
1466
+ inputs_embeds=None,
1467
+ cache_position=None,
1468
+ position_ids=None,
1469
+ use_cache=True,
1470
+ pixel_values=None,
1471
+ image_grid_thw=None,
1472
+ is_first_iteration=False,
1473
+ **kwargs,
1474
+ ):
1475
+ model_inputs = super().prepare_inputs_for_generation(
1476
+ input_ids,
1477
+ past_key_values=past_key_values,
1478
+ attention_mask=attention_mask,
1479
+ inputs_embeds=inputs_embeds,
1480
+ cache_position=cache_position,
1481
+ position_ids=position_ids,
1482
+ pixel_values=pixel_values,
1483
+ image_grid_thw=image_grid_thw,
1484
+ is_first_iteration=is_first_iteration,
1485
+ use_cache=use_cache,
1486
+ **kwargs,
1487
+ )
1488
+
1489
+ model_inputs["position_ids"] = None
1490
+
1491
+ if not is_first_iteration and use_cache:
1492
+ model_inputs["pixel_values"] = None
1493
+
1494
+ return model_inputs
1495
+
1496
+ def _get_image_nums(
1497
+ self,
1498
+ input_ids: torch.LongTensor | None,
1499
+ ) -> torch.Tensor:
1500
+ """
1501
+ Get the number of images for each sample.
1502
+ For GLM-Image, only input_ids allow us to get the number of images.
1503
+
1504
+ Returns:
1505
+ image_counts (`torch.LongTensor` of shape `(batch_size,)`)
1506
+ """
1507
+ is_image = input_ids == self.config.image_start_token_id
1508
+
1509
+ return is_image.sum(dim=1)
1510
+
1511
+ def _expand_inputs_for_generation(
1512
+ self,
1513
+ expand_size: int = 1,
1514
+ is_encoder_decoder: bool = False,
1515
+ input_ids: torch.LongTensor | None = None,
1516
+ **model_kwargs,
1517
+ ) -> tuple[torch.LongTensor, dict[str, Any]]:
1518
+ # Overwritten -- Support for expanding tensors without a batch size dimension
1519
+ # e.g., pixel_values, image_grid_thw
1520
+ # pixel_values.shape[0] is sum(seqlen_images for samples)
1521
+ # image_grid_thw.shape[0] is sum(num_images for samples)
1522
+
1523
+ if expand_size == 1:
1524
+ return input_ids, model_kwargs
1525
+
1526
+ visual_keys = ["pixel_values", "image_grid_thw"]
1527
+
1528
+ def _expand_dict_for_generation_visual(dict_to_expand):
1529
+ image_grid_thw = model_kwargs.get("image_grid_thw", None)
1530
+ image_nums = self._get_image_nums(input_ids)
1531
+
1532
+ def _repeat_interleave_samples(x, lengths, repeat_times):
1533
+ samples = torch.split(x, lengths)
1534
+ repeat_args = [repeat_times] + [1] * (x.dim() - 1)
1535
+ result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)
1536
+ return result
1537
+
1538
+ for key in dict_to_expand:
1539
+ if key == "pixel_values":
1540
+ # split images into samples
1541
+ samples = torch.split(image_grid_thw[: sum(image_nums)], list(image_nums))
1542
+ # compute the sequence length of images for each sample
1543
+ lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
1544
+ dict_to_expand[key] = _repeat_interleave_samples(
1545
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
1546
+ )
1547
+ elif key == "image_grid_thw":
1548
+ # get the num of images for each sample and +1 for the image being generated
1549
+ lengths = list(image_nums)
1550
+ last_image = dict_to_expand[key][:-1]
1551
+ dict_to_expand[key] = _repeat_interleave_samples(
1552
+ dict_to_expand[key][: sum(image_nums)], lengths=lengths, repeat_times=expand_size
1553
+ )
1554
+ dict_to_expand[key] = torch.cat([dict_to_expand[key], last_image], dim=0)
1555
+ return dict_to_expand
1556
+
1557
+ def _expand_dict_for_generation(dict_to_expand):
1558
+ for key in dict_to_expand:
1559
+ if (
1560
+ key != "cache_position"
1561
+ and dict_to_expand[key] is not None
1562
+ and isinstance(dict_to_expand[key], torch.Tensor)
1563
+ and key not in visual_keys
1564
+ ):
1565
+ dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
1566
+ return dict_to_expand
1567
+
1568
+ model_kwargs = _expand_dict_for_generation_visual(model_kwargs)
1569
+
1570
+ if input_ids is not None:
1571
+ input_ids = input_ids.repeat_interleave(expand_size, dim=0)
1572
+
1573
+ model_kwargs = _expand_dict_for_generation(model_kwargs)
1574
+
1575
+ if is_encoder_decoder:
1576
+ if model_kwargs.get("encoder_outputs") is None:
1577
+ raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
1578
+ model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
1579
+
1580
+ return input_ids, model_kwargs
1581
+
1582
+
1583
+ __all__ = [
1584
+ "GlmImagePreTrainedModel",
1585
+ "GlmImageVQVAE",
1586
+ "GlmImageVisionModel",
1587
+ "GlmImageTextModel",
1588
+ "GlmImageModel",
1589
+ "GlmImageForConditionalGeneration",
1590
+ ]