transformers-5.0.0rc3-py3-none-any.whl → transformers-5.1.0-py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
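A comparison like this can be reproduced locally from the published wheels. Below is a minimal sketch, assuming both versions are available on PyPI and that the local environment has `pip`; the `old`/`new` directory names and the `fetch` helper are illustrative choices, not part of any registry tooling. It downloads each wheel without dependencies and reports which files were added or removed (new modules such as `eomt_dinov3` or `glm_ocr` in the listing below show up as added files).

```python
import pathlib
import subprocess
import sys
import zipfile

def fetch(version: str, dest: str) -> pathlib.Path:
    """Download the transformers wheel for `version` into `dest` and return its path."""
    pathlib.Path(dest).mkdir(exist_ok=True)
    subprocess.run(
        [sys.executable, "-m", "pip", "download", f"transformers=={version}",
         "--no-deps", "-d", dest],
        check=True,
    )
    return next(pathlib.Path(dest).glob("*.whl"))

old_whl = fetch("5.0.0rc3", "old")  # exact pre-release pins need no --pre flag
new_whl = fetch("5.1.0", "new")

# Wheels are zip archives, so the file inventories can be compared directly.
with zipfile.ZipFile(old_whl) as old, zipfile.ZipFile(new_whl) as new:
    old_names, new_names = set(old.namelist()), set(new.namelist())

print("files added:  ", len(new_names - old_names))
print("files removed:", len(old_names - new_names))
for name in sorted(new_names - old_names)[:5]:
    print("  +", name)
```

Per-file +/- line counts like those in the listing would additionally require extracting both archives and diffing matching files (for example with `difflib.unified_diff`).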
Files changed (1021)
  1. transformers/__init__.py +4 -11
  2. transformers/activations.py +2 -2
  3. transformers/backbone_utils.py +326 -0
  4. transformers/cache_utils.py +11 -2
  5. transformers/cli/serve.py +11 -8
  6. transformers/configuration_utils.py +1 -69
  7. transformers/conversion_mapping.py +146 -26
  8. transformers/convert_slow_tokenizer.py +6 -4
  9. transformers/core_model_loading.py +207 -118
  10. transformers/dependency_versions_check.py +0 -1
  11. transformers/dependency_versions_table.py +7 -8
  12. transformers/file_utils.py +0 -2
  13. transformers/generation/candidate_generator.py +1 -2
  14. transformers/generation/continuous_batching/cache.py +40 -38
  15. transformers/generation/continuous_batching/cache_manager.py +3 -16
  16. transformers/generation/continuous_batching/continuous_api.py +94 -406
  17. transformers/generation/continuous_batching/input_ouputs.py +464 -0
  18. transformers/generation/continuous_batching/requests.py +54 -17
  19. transformers/generation/continuous_batching/scheduler.py +77 -95
  20. transformers/generation/logits_process.py +10 -5
  21. transformers/generation/stopping_criteria.py +1 -2
  22. transformers/generation/utils.py +75 -95
  23. transformers/image_processing_utils.py +0 -3
  24. transformers/image_processing_utils_fast.py +17 -18
  25. transformers/image_transforms.py +44 -13
  26. transformers/image_utils.py +0 -5
  27. transformers/initialization.py +57 -0
  28. transformers/integrations/__init__.py +10 -24
  29. transformers/integrations/accelerate.py +47 -11
  30. transformers/integrations/deepspeed.py +145 -3
  31. transformers/integrations/executorch.py +2 -6
  32. transformers/integrations/finegrained_fp8.py +142 -7
  33. transformers/integrations/flash_attention.py +2 -7
  34. transformers/integrations/hub_kernels.py +18 -7
  35. transformers/integrations/moe.py +226 -106
  36. transformers/integrations/mxfp4.py +47 -34
  37. transformers/integrations/peft.py +488 -176
  38. transformers/integrations/tensor_parallel.py +641 -581
  39. transformers/masking_utils.py +153 -9
  40. transformers/modeling_flash_attention_utils.py +1 -2
  41. transformers/modeling_utils.py +359 -358
  42. transformers/models/__init__.py +6 -0
  43. transformers/models/afmoe/configuration_afmoe.py +14 -4
  44. transformers/models/afmoe/modeling_afmoe.py +8 -8
  45. transformers/models/afmoe/modular_afmoe.py +7 -7
  46. transformers/models/aimv2/configuration_aimv2.py +2 -7
  47. transformers/models/aimv2/modeling_aimv2.py +26 -24
  48. transformers/models/aimv2/modular_aimv2.py +8 -12
  49. transformers/models/albert/configuration_albert.py +8 -1
  50. transformers/models/albert/modeling_albert.py +3 -3
  51. transformers/models/align/configuration_align.py +8 -5
  52. transformers/models/align/modeling_align.py +22 -24
  53. transformers/models/altclip/configuration_altclip.py +4 -6
  54. transformers/models/altclip/modeling_altclip.py +30 -26
  55. transformers/models/apertus/configuration_apertus.py +5 -7
  56. transformers/models/apertus/modeling_apertus.py +4 -4
  57. transformers/models/apertus/modular_apertus.py +8 -10
  58. transformers/models/arcee/configuration_arcee.py +5 -7
  59. transformers/models/arcee/modeling_arcee.py +4 -4
  60. transformers/models/aria/configuration_aria.py +11 -21
  61. transformers/models/aria/modeling_aria.py +39 -36
  62. transformers/models/aria/modular_aria.py +33 -39
  63. transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +3 -3
  64. transformers/models/audioflamingo3/modeling_audioflamingo3.py +39 -30
  65. transformers/models/audioflamingo3/modular_audioflamingo3.py +41 -27
  66. transformers/models/auto/auto_factory.py +8 -6
  67. transformers/models/auto/configuration_auto.py +22 -0
  68. transformers/models/auto/image_processing_auto.py +17 -13
  69. transformers/models/auto/modeling_auto.py +15 -0
  70. transformers/models/auto/processing_auto.py +9 -18
  71. transformers/models/auto/tokenization_auto.py +17 -15
  72. transformers/models/autoformer/modeling_autoformer.py +2 -1
  73. transformers/models/aya_vision/configuration_aya_vision.py +4 -0
  74. transformers/models/aya_vision/modeling_aya_vision.py +29 -62
  75. transformers/models/aya_vision/modular_aya_vision.py +20 -45
  76. transformers/models/bamba/configuration_bamba.py +17 -7
  77. transformers/models/bamba/modeling_bamba.py +23 -55
  78. transformers/models/bamba/modular_bamba.py +19 -54
  79. transformers/models/bark/configuration_bark.py +2 -1
  80. transformers/models/bark/modeling_bark.py +24 -10
  81. transformers/models/bart/configuration_bart.py +9 -4
  82. transformers/models/bart/modeling_bart.py +9 -12
  83. transformers/models/beit/configuration_beit.py +2 -4
  84. transformers/models/beit/image_processing_beit_fast.py +3 -3
  85. transformers/models/beit/modeling_beit.py +14 -9
  86. transformers/models/bert/configuration_bert.py +12 -1
  87. transformers/models/bert/modeling_bert.py +6 -30
  88. transformers/models/bert_generation/configuration_bert_generation.py +17 -1
  89. transformers/models/bert_generation/modeling_bert_generation.py +6 -6
  90. transformers/models/big_bird/configuration_big_bird.py +12 -8
  91. transformers/models/big_bird/modeling_big_bird.py +0 -15
  92. transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py +9 -8
  93. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +9 -7
  94. transformers/models/biogpt/configuration_biogpt.py +8 -1
  95. transformers/models/biogpt/modeling_biogpt.py +4 -8
  96. transformers/models/biogpt/modular_biogpt.py +1 -5
  97. transformers/models/bit/configuration_bit.py +2 -4
  98. transformers/models/bit/modeling_bit.py +6 -5
  99. transformers/models/bitnet/configuration_bitnet.py +5 -7
  100. transformers/models/bitnet/modeling_bitnet.py +3 -4
  101. transformers/models/bitnet/modular_bitnet.py +3 -4
  102. transformers/models/blenderbot/configuration_blenderbot.py +8 -4
  103. transformers/models/blenderbot/modeling_blenderbot.py +4 -4
  104. transformers/models/blenderbot_small/configuration_blenderbot_small.py +8 -4
  105. transformers/models/blenderbot_small/modeling_blenderbot_small.py +4 -4
  106. transformers/models/blip/configuration_blip.py +9 -9
  107. transformers/models/blip/modeling_blip.py +55 -37
  108. transformers/models/blip_2/configuration_blip_2.py +2 -1
  109. transformers/models/blip_2/modeling_blip_2.py +81 -56
  110. transformers/models/bloom/configuration_bloom.py +5 -1
  111. transformers/models/bloom/modeling_bloom.py +2 -1
  112. transformers/models/blt/configuration_blt.py +23 -12
  113. transformers/models/blt/modeling_blt.py +20 -14
  114. transformers/models/blt/modular_blt.py +70 -10
  115. transformers/models/bridgetower/configuration_bridgetower.py +7 -1
  116. transformers/models/bridgetower/image_processing_bridgetower_fast.py +6 -6
  117. transformers/models/bridgetower/modeling_bridgetower.py +29 -15
  118. transformers/models/bros/configuration_bros.py +24 -17
  119. transformers/models/camembert/configuration_camembert.py +8 -1
  120. transformers/models/camembert/modeling_camembert.py +6 -6
  121. transformers/models/canine/configuration_canine.py +4 -1
  122. transformers/models/chameleon/configuration_chameleon.py +5 -7
  123. transformers/models/chameleon/image_processing_chameleon_fast.py +5 -5
  124. transformers/models/chameleon/modeling_chameleon.py +82 -36
  125. transformers/models/chinese_clip/configuration_chinese_clip.py +10 -7
  126. transformers/models/chinese_clip/modeling_chinese_clip.py +28 -29
  127. transformers/models/clap/configuration_clap.py +4 -8
  128. transformers/models/clap/modeling_clap.py +21 -22
  129. transformers/models/clip/configuration_clip.py +4 -1
  130. transformers/models/clip/image_processing_clip_fast.py +9 -0
  131. transformers/models/clip/modeling_clip.py +25 -22
  132. transformers/models/clipseg/configuration_clipseg.py +4 -1
  133. transformers/models/clipseg/modeling_clipseg.py +27 -25
  134. transformers/models/clipseg/processing_clipseg.py +11 -3
  135. transformers/models/clvp/configuration_clvp.py +14 -2
  136. transformers/models/clvp/modeling_clvp.py +19 -30
  137. transformers/models/codegen/configuration_codegen.py +4 -3
  138. transformers/models/codegen/modeling_codegen.py +2 -1
  139. transformers/models/cohere/configuration_cohere.py +5 -7
  140. transformers/models/cohere/modeling_cohere.py +4 -4
  141. transformers/models/cohere/modular_cohere.py +3 -3
  142. transformers/models/cohere2/configuration_cohere2.py +6 -8
  143. transformers/models/cohere2/modeling_cohere2.py +4 -4
  144. transformers/models/cohere2/modular_cohere2.py +9 -11
  145. transformers/models/cohere2_vision/configuration_cohere2_vision.py +5 -1
  146. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +3 -3
  147. transformers/models/cohere2_vision/modeling_cohere2_vision.py +24 -25
  148. transformers/models/cohere2_vision/modular_cohere2_vision.py +20 -20
  149. transformers/models/colqwen2/modeling_colqwen2.py +7 -6
  150. transformers/models/colqwen2/modular_colqwen2.py +7 -6
  151. transformers/models/conditional_detr/configuration_conditional_detr.py +19 -46
  152. transformers/models/conditional_detr/image_processing_conditional_detr.py +3 -4
  153. transformers/models/conditional_detr/image_processing_conditional_detr_fast.py +28 -14
  154. transformers/models/conditional_detr/modeling_conditional_detr.py +794 -942
  155. transformers/models/conditional_detr/modular_conditional_detr.py +901 -3
  156. transformers/models/convbert/configuration_convbert.py +11 -7
  157. transformers/models/convnext/configuration_convnext.py +2 -4
  158. transformers/models/convnext/image_processing_convnext_fast.py +2 -2
  159. transformers/models/convnext/modeling_convnext.py +7 -6
  160. transformers/models/convnextv2/configuration_convnextv2.py +2 -4
  161. transformers/models/convnextv2/modeling_convnextv2.py +7 -6
  162. transformers/models/cpmant/configuration_cpmant.py +4 -0
  163. transformers/models/csm/configuration_csm.py +9 -15
  164. transformers/models/csm/modeling_csm.py +3 -3
  165. transformers/models/ctrl/configuration_ctrl.py +16 -0
  166. transformers/models/ctrl/modeling_ctrl.py +13 -25
  167. transformers/models/cwm/configuration_cwm.py +5 -7
  168. transformers/models/cwm/modeling_cwm.py +4 -4
  169. transformers/models/d_fine/configuration_d_fine.py +10 -56
  170. transformers/models/d_fine/modeling_d_fine.py +728 -868
  171. transformers/models/d_fine/modular_d_fine.py +335 -412
  172. transformers/models/dab_detr/configuration_dab_detr.py +22 -48
  173. transformers/models/dab_detr/modeling_dab_detr.py +11 -7
  174. transformers/models/dac/modeling_dac.py +1 -1
  175. transformers/models/data2vec/configuration_data2vec_audio.py +4 -1
  176. transformers/models/data2vec/configuration_data2vec_text.py +11 -2
  177. transformers/models/data2vec/modeling_data2vec_audio.py +3 -3
  178. transformers/models/data2vec/modeling_data2vec_text.py +6 -6
  179. transformers/models/data2vec/modeling_data2vec_vision.py +4 -2
  180. transformers/models/dbrx/configuration_dbrx.py +11 -3
  181. transformers/models/dbrx/modeling_dbrx.py +6 -6
  182. transformers/models/dbrx/modular_dbrx.py +6 -6
  183. transformers/models/deberta/configuration_deberta.py +6 -0
  184. transformers/models/deberta_v2/configuration_deberta_v2.py +6 -0
  185. transformers/models/decision_transformer/configuration_decision_transformer.py +3 -1
  186. transformers/models/decision_transformer/modeling_decision_transformer.py +3 -3
  187. transformers/models/deepseek_v2/configuration_deepseek_v2.py +7 -10
  188. transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -8
  189. transformers/models/deepseek_v2/modular_deepseek_v2.py +8 -10
  190. transformers/models/deepseek_v3/configuration_deepseek_v3.py +7 -10
  191. transformers/models/deepseek_v3/modeling_deepseek_v3.py +7 -7
  192. transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -5
  193. transformers/models/deepseek_vl/configuration_deepseek_vl.py +4 -0
  194. transformers/models/deepseek_vl/image_processing_deepseek_vl.py +2 -2
  195. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +5 -5
  196. transformers/models/deepseek_vl/modeling_deepseek_vl.py +17 -12
  197. transformers/models/deepseek_vl/modular_deepseek_vl.py +4 -0
  198. transformers/models/deepseek_vl_hybrid/configuration_deepseek_vl_hybrid.py +4 -0
  199. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py +2 -2
  200. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +6 -6
  201. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +68 -24
  202. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +70 -19
  203. transformers/models/deformable_detr/configuration_deformable_detr.py +22 -45
  204. transformers/models/deformable_detr/image_processing_deformable_detr_fast.py +25 -11
  205. transformers/models/deformable_detr/modeling_deformable_detr.py +410 -607
  206. transformers/models/deformable_detr/modular_deformable_detr.py +1385 -3
  207. transformers/models/deit/modeling_deit.py +11 -7
  208. transformers/models/depth_anything/configuration_depth_anything.py +12 -42
  209. transformers/models/depth_anything/modeling_depth_anything.py +5 -3
  210. transformers/models/depth_pro/image_processing_depth_pro_fast.py +2 -2
  211. transformers/models/depth_pro/modeling_depth_pro.py +8 -4
  212. transformers/models/detr/configuration_detr.py +18 -49
  213. transformers/models/detr/image_processing_detr_fast.py +11 -11
  214. transformers/models/detr/modeling_detr.py +695 -734
  215. transformers/models/dia/configuration_dia.py +4 -7
  216. transformers/models/dia/generation_dia.py +8 -17
  217. transformers/models/dia/modeling_dia.py +7 -7
  218. transformers/models/dia/modular_dia.py +4 -4
  219. transformers/models/diffllama/configuration_diffllama.py +5 -7
  220. transformers/models/diffllama/modeling_diffllama.py +3 -8
  221. transformers/models/diffllama/modular_diffllama.py +2 -7
  222. transformers/models/dinat/configuration_dinat.py +2 -4
  223. transformers/models/dinat/modeling_dinat.py +7 -6
  224. transformers/models/dinov2/configuration_dinov2.py +2 -4
  225. transformers/models/dinov2/modeling_dinov2.py +9 -8
  226. transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py +2 -4
  227. transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +9 -8
  228. transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py +6 -7
  229. transformers/models/dinov3_convnext/configuration_dinov3_convnext.py +2 -4
  230. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +2 -3
  231. transformers/models/dinov3_vit/configuration_dinov3_vit.py +2 -4
  232. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +2 -2
  233. transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -6
  234. transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -6
  235. transformers/models/distilbert/configuration_distilbert.py +8 -1
  236. transformers/models/distilbert/modeling_distilbert.py +3 -3
  237. transformers/models/doge/configuration_doge.py +17 -7
  238. transformers/models/doge/modeling_doge.py +4 -4
  239. transformers/models/doge/modular_doge.py +20 -10
  240. transformers/models/donut/image_processing_donut_fast.py +4 -4
  241. transformers/models/dots1/configuration_dots1.py +16 -7
  242. transformers/models/dots1/modeling_dots1.py +4 -4
  243. transformers/models/dpr/configuration_dpr.py +19 -1
  244. transformers/models/dpt/configuration_dpt.py +23 -65
  245. transformers/models/dpt/image_processing_dpt_fast.py +5 -5
  246. transformers/models/dpt/modeling_dpt.py +19 -15
  247. transformers/models/dpt/modular_dpt.py +4 -4
  248. transformers/models/edgetam/configuration_edgetam.py +1 -1
  249. transformers/models/edgetam/modeling_edgetam.py +53 -53
  250. transformers/models/edgetam/modular_edgetam.py +5 -7
  251. transformers/models/edgetam_video/modeling_edgetam_video.py +55 -56
  252. transformers/models/edgetam_video/modular_edgetam_video.py +9 -9
  253. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +4 -3
  254. transformers/models/efficientloftr/modeling_efficientloftr.py +19 -9
  255. transformers/models/efficientnet/image_processing_efficientnet_fast.py +2 -2
  256. transformers/models/electra/configuration_electra.py +13 -2
  257. transformers/models/electra/modeling_electra.py +6 -6
  258. transformers/models/emu3/configuration_emu3.py +12 -10
  259. transformers/models/emu3/modeling_emu3.py +84 -47
  260. transformers/models/emu3/modular_emu3.py +77 -39
  261. transformers/models/encoder_decoder/configuration_encoder_decoder.py +12 -1
  262. transformers/models/encoder_decoder/modeling_encoder_decoder.py +20 -24
  263. transformers/models/eomt/configuration_eomt.py +12 -13
  264. transformers/models/eomt/image_processing_eomt_fast.py +3 -3
  265. transformers/models/eomt/modeling_eomt.py +3 -3
  266. transformers/models/eomt/modular_eomt.py +17 -17
  267. transformers/models/eomt_dinov3/__init__.py +28 -0
  268. transformers/models/eomt_dinov3/configuration_eomt_dinov3.py +204 -0
  269. transformers/models/eomt_dinov3/modeling_eomt_dinov3.py +1376 -0
  270. transformers/models/eomt_dinov3/modular_eomt_dinov3.py +454 -0
  271. transformers/models/ernie/configuration_ernie.py +24 -2
  272. transformers/models/ernie/modeling_ernie.py +6 -30
  273. transformers/models/ernie4_5/configuration_ernie4_5.py +5 -7
  274. transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
  275. transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py +7 -10
  276. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +4 -4
  277. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +17 -6
  278. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +229 -188
  279. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +79 -55
  280. transformers/models/esm/configuration_esm.py +9 -11
  281. transformers/models/esm/modeling_esm.py +3 -3
  282. transformers/models/esm/modeling_esmfold.py +1 -6
  283. transformers/models/esm/openfold_utils/protein.py +2 -3
  284. transformers/models/evolla/configuration_evolla.py +21 -8
  285. transformers/models/evolla/modeling_evolla.py +11 -7
  286. transformers/models/evolla/modular_evolla.py +5 -1
  287. transformers/models/exaone4/configuration_exaone4.py +8 -5
  288. transformers/models/exaone4/modeling_exaone4.py +4 -4
  289. transformers/models/exaone4/modular_exaone4.py +11 -8
  290. transformers/models/exaone_moe/__init__.py +27 -0
  291. transformers/models/exaone_moe/configuration_exaone_moe.py +235 -0
  292. transformers/models/exaone_moe/modeling_exaone_moe.py +665 -0
  293. transformers/models/exaone_moe/modular_exaone_moe.py +373 -0
  294. transformers/models/falcon/configuration_falcon.py +9 -1
  295. transformers/models/falcon/modeling_falcon.py +3 -8
  296. transformers/models/falcon_h1/configuration_falcon_h1.py +17 -8
  297. transformers/models/falcon_h1/modeling_falcon_h1.py +22 -54
  298. transformers/models/falcon_h1/modular_falcon_h1.py +21 -52
  299. transformers/models/falcon_mamba/configuration_falcon_mamba.py +5 -1
  300. transformers/models/falcon_mamba/modeling_falcon_mamba.py +18 -26
  301. transformers/models/falcon_mamba/modular_falcon_mamba.py +4 -0
  302. transformers/models/fast_vlm/configuration_fast_vlm.py +10 -1
  303. transformers/models/fast_vlm/modeling_fast_vlm.py +37 -64
  304. transformers/models/fast_vlm/modular_fast_vlm.py +146 -35
  305. transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py +0 -1
  306. transformers/models/flaubert/configuration_flaubert.py +10 -4
  307. transformers/models/flaubert/modeling_flaubert.py +1 -1
  308. transformers/models/flava/configuration_flava.py +4 -3
  309. transformers/models/flava/image_processing_flava_fast.py +4 -4
  310. transformers/models/flava/modeling_flava.py +36 -28
  311. transformers/models/flex_olmo/configuration_flex_olmo.py +11 -14
  312. transformers/models/flex_olmo/modeling_flex_olmo.py +4 -4
  313. transformers/models/flex_olmo/modular_flex_olmo.py +11 -14
  314. transformers/models/florence2/configuration_florence2.py +4 -0
  315. transformers/models/florence2/modeling_florence2.py +57 -32
  316. transformers/models/florence2/modular_florence2.py +48 -26
  317. transformers/models/fnet/configuration_fnet.py +6 -1
  318. transformers/models/focalnet/configuration_focalnet.py +2 -4
  319. transformers/models/focalnet/modeling_focalnet.py +10 -7
  320. transformers/models/fsmt/configuration_fsmt.py +12 -16
  321. transformers/models/funnel/configuration_funnel.py +8 -0
  322. transformers/models/fuyu/configuration_fuyu.py +5 -8
  323. transformers/models/fuyu/image_processing_fuyu_fast.py +5 -4
  324. transformers/models/fuyu/modeling_fuyu.py +24 -23
  325. transformers/models/gemma/configuration_gemma.py +5 -7
  326. transformers/models/gemma/modeling_gemma.py +4 -4
  327. transformers/models/gemma/modular_gemma.py +5 -7
  328. transformers/models/gemma2/configuration_gemma2.py +5 -7
  329. transformers/models/gemma2/modeling_gemma2.py +4 -4
  330. transformers/models/gemma2/modular_gemma2.py +8 -10
  331. transformers/models/gemma3/configuration_gemma3.py +28 -22
  332. transformers/models/gemma3/image_processing_gemma3_fast.py +2 -2
  333. transformers/models/gemma3/modeling_gemma3.py +37 -33
  334. transformers/models/gemma3/modular_gemma3.py +46 -42
  335. transformers/models/gemma3n/configuration_gemma3n.py +35 -22
  336. transformers/models/gemma3n/modeling_gemma3n.py +86 -58
  337. transformers/models/gemma3n/modular_gemma3n.py +112 -75
  338. transformers/models/git/configuration_git.py +5 -7
  339. transformers/models/git/modeling_git.py +31 -41
  340. transformers/models/glm/configuration_glm.py +7 -9
  341. transformers/models/glm/modeling_glm.py +4 -4
  342. transformers/models/glm4/configuration_glm4.py +7 -9
  343. transformers/models/glm4/modeling_glm4.py +4 -4
  344. transformers/models/glm46v/configuration_glm46v.py +4 -0
  345. transformers/models/glm46v/image_processing_glm46v.py +5 -2
  346. transformers/models/glm46v/image_processing_glm46v_fast.py +2 -2
  347. transformers/models/glm46v/modeling_glm46v.py +91 -46
  348. transformers/models/glm46v/modular_glm46v.py +4 -0
  349. transformers/models/glm4_moe/configuration_glm4_moe.py +17 -7
  350. transformers/models/glm4_moe/modeling_glm4_moe.py +4 -4
  351. transformers/models/glm4_moe/modular_glm4_moe.py +17 -7
  352. transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py +8 -10
  353. transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py +7 -7
  354. transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py +8 -10
  355. transformers/models/glm4v/configuration_glm4v.py +12 -8
  356. transformers/models/glm4v/image_processing_glm4v.py +5 -2
  357. transformers/models/glm4v/image_processing_glm4v_fast.py +2 -2
  358. transformers/models/glm4v/modeling_glm4v.py +120 -63
  359. transformers/models/glm4v/modular_glm4v.py +82 -50
  360. transformers/models/glm4v_moe/configuration_glm4v_moe.py +18 -6
  361. transformers/models/glm4v_moe/modeling_glm4v_moe.py +115 -63
  362. transformers/models/glm4v_moe/modular_glm4v_moe.py +23 -12
  363. transformers/models/glm_image/configuration_glm_image.py +26 -20
  364. transformers/models/glm_image/image_processing_glm_image.py +1 -1
  365. transformers/models/glm_image/image_processing_glm_image_fast.py +5 -7
  366. transformers/models/glm_image/modeling_glm_image.py +337 -236
  367. transformers/models/glm_image/modular_glm_image.py +415 -255
  368. transformers/models/glm_image/processing_glm_image.py +65 -17
  369. transformers/{pipelines/deprecated → models/glm_ocr}/__init__.py +15 -2
  370. transformers/models/glm_ocr/configuration_glm_ocr.py +312 -0
  371. transformers/models/glm_ocr/modeling_glm_ocr.py +1633 -0
  372. transformers/models/glm_ocr/modular_glm_ocr.py +428 -0
  373. transformers/models/glmasr/modeling_glmasr.py +34 -28
  374. transformers/models/glmasr/modular_glmasr.py +23 -11
  375. transformers/models/glpn/image_processing_glpn_fast.py +3 -3
  376. transformers/models/glpn/modeling_glpn.py +4 -2
  377. transformers/models/got_ocr2/configuration_got_ocr2.py +6 -6
  378. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +3 -3
  379. transformers/models/got_ocr2/modeling_got_ocr2.py +31 -37
  380. transformers/models/got_ocr2/modular_got_ocr2.py +30 -19
  381. transformers/models/gpt2/configuration_gpt2.py +13 -1
  382. transformers/models/gpt2/modeling_gpt2.py +5 -5
  383. transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +7 -1
  384. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +5 -4
  385. transformers/models/gpt_neo/configuration_gpt_neo.py +9 -1
  386. transformers/models/gpt_neo/modeling_gpt_neo.py +3 -7
  387. transformers/models/gpt_neox/configuration_gpt_neox.py +8 -3
  388. transformers/models/gpt_neox/modeling_gpt_neox.py +4 -4
  389. transformers/models/gpt_neox/modular_gpt_neox.py +4 -4
  390. transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py +9 -1
  391. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +2 -2
  392. transformers/models/gpt_oss/configuration_gpt_oss.py +10 -6
  393. transformers/models/gpt_oss/modeling_gpt_oss.py +46 -79
  394. transformers/models/gpt_oss/modular_gpt_oss.py +45 -78
  395. transformers/models/gptj/configuration_gptj.py +4 -4
  396. transformers/models/gptj/modeling_gptj.py +3 -7
  397. transformers/models/granite/configuration_granite.py +5 -7
  398. transformers/models/granite/modeling_granite.py +4 -4
  399. transformers/models/granite_speech/modeling_granite_speech.py +63 -37
  400. transformers/models/granitemoe/configuration_granitemoe.py +5 -7
  401. transformers/models/granitemoe/modeling_granitemoe.py +4 -4
  402. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +17 -7
  403. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +22 -54
  404. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +39 -45
  405. transformers/models/granitemoeshared/configuration_granitemoeshared.py +6 -7
  406. transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -4
  407. transformers/models/grounding_dino/configuration_grounding_dino.py +10 -45
  408. transformers/models/grounding_dino/image_processing_grounding_dino_fast.py +11 -11
  409. transformers/models/grounding_dino/modeling_grounding_dino.py +68 -86
  410. transformers/models/groupvit/configuration_groupvit.py +4 -1
  411. transformers/models/groupvit/modeling_groupvit.py +29 -22
  412. transformers/models/helium/configuration_helium.py +5 -7
  413. transformers/models/helium/modeling_helium.py +4 -4
  414. transformers/models/hgnet_v2/configuration_hgnet_v2.py +2 -4
  415. transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -5
  416. transformers/models/hgnet_v2/modular_hgnet_v2.py +7 -8
  417. transformers/models/hiera/configuration_hiera.py +2 -4
  418. transformers/models/hiera/modeling_hiera.py +11 -8
  419. transformers/models/hubert/configuration_hubert.py +4 -1
  420. transformers/models/hubert/modeling_hubert.py +7 -4
  421. transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py +5 -7
  422. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +28 -4
  423. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +28 -6
  424. transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py +6 -8
  425. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +22 -9
  426. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +22 -8
  427. transformers/models/ibert/configuration_ibert.py +4 -1
  428. transformers/models/idefics/configuration_idefics.py +5 -7
  429. transformers/models/idefics/modeling_idefics.py +3 -4
  430. transformers/models/idefics/vision.py +5 -4
  431. transformers/models/idefics2/configuration_idefics2.py +1 -2
  432. transformers/models/idefics2/image_processing_idefics2_fast.py +1 -0
  433. transformers/models/idefics2/modeling_idefics2.py +72 -50
  434. transformers/models/idefics3/configuration_idefics3.py +1 -3
  435. transformers/models/idefics3/image_processing_idefics3_fast.py +29 -3
  436. transformers/models/idefics3/modeling_idefics3.py +63 -40
  437. transformers/models/ijepa/modeling_ijepa.py +3 -3
  438. transformers/models/imagegpt/configuration_imagegpt.py +9 -1
  439. transformers/models/imagegpt/image_processing_imagegpt_fast.py +2 -2
  440. transformers/models/imagegpt/modeling_imagegpt.py +8 -4
  441. transformers/models/informer/modeling_informer.py +3 -3
  442. transformers/models/instructblip/configuration_instructblip.py +2 -1
  443. transformers/models/instructblip/modeling_instructblip.py +65 -39
  444. transformers/models/instructblipvideo/configuration_instructblipvideo.py +2 -1
  445. transformers/models/instructblipvideo/modeling_instructblipvideo.py +60 -57
  446. transformers/models/instructblipvideo/modular_instructblipvideo.py +43 -32
  447. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +2 -2
  448. transformers/models/internvl/configuration_internvl.py +5 -0
  449. transformers/models/internvl/modeling_internvl.py +35 -55
  450. transformers/models/internvl/modular_internvl.py +26 -38
  451. transformers/models/internvl/video_processing_internvl.py +2 -2
  452. transformers/models/jais2/configuration_jais2.py +5 -7
  453. transformers/models/jais2/modeling_jais2.py +4 -4
  454. transformers/models/jamba/configuration_jamba.py +5 -7
  455. transformers/models/jamba/modeling_jamba.py +4 -4
  456. transformers/models/jamba/modular_jamba.py +3 -3
  457. transformers/models/janus/image_processing_janus.py +2 -2
  458. transformers/models/janus/image_processing_janus_fast.py +8 -8
  459. transformers/models/janus/modeling_janus.py +63 -146
  460. transformers/models/janus/modular_janus.py +62 -20
  461. transformers/models/jetmoe/configuration_jetmoe.py +6 -4
  462. transformers/models/jetmoe/modeling_jetmoe.py +3 -3
  463. transformers/models/jetmoe/modular_jetmoe.py +3 -3
  464. transformers/models/kosmos2/configuration_kosmos2.py +10 -8
  465. transformers/models/kosmos2/modeling_kosmos2.py +56 -34
  466. transformers/models/kosmos2_5/configuration_kosmos2_5.py +8 -8
  467. transformers/models/kosmos2_5/modeling_kosmos2_5.py +54 -63
  468. transformers/models/kyutai_speech_to_text/configuration_kyutai_speech_to_text.py +8 -3
  469. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +44 -40
  470. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +1 -1
  471. transformers/models/lasr/configuration_lasr.py +2 -4
  472. transformers/models/lasr/modeling_lasr.py +3 -3
  473. transformers/models/lasr/modular_lasr.py +3 -3
  474. transformers/models/layoutlm/configuration_layoutlm.py +14 -1
  475. transformers/models/layoutlm/modeling_layoutlm.py +3 -3
  476. transformers/models/layoutlmv2/configuration_layoutlmv2.py +14 -16
  477. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +2 -2
  478. transformers/models/layoutlmv3/configuration_layoutlmv3.py +16 -18
  479. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +2 -2
  480. transformers/models/layoutxlm/configuration_layoutxlm.py +14 -16
  481. transformers/models/led/configuration_led.py +7 -8
  482. transformers/models/levit/image_processing_levit_fast.py +4 -4
  483. transformers/models/lfm2/configuration_lfm2.py +5 -7
  484. transformers/models/lfm2/modeling_lfm2.py +4 -4
  485. transformers/models/lfm2/modular_lfm2.py +3 -3
  486. transformers/models/lfm2_moe/configuration_lfm2_moe.py +5 -7
  487. transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -4
  488. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
  489. transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py +9 -15
  490. transformers/models/lfm2_vl/modeling_lfm2_vl.py +42 -28
  491. transformers/models/lfm2_vl/modular_lfm2_vl.py +42 -27
  492. transformers/models/lightglue/image_processing_lightglue_fast.py +4 -3
  493. transformers/models/lightglue/modeling_lightglue.py +3 -3
  494. transformers/models/lightglue/modular_lightglue.py +3 -3
  495. transformers/models/lighton_ocr/modeling_lighton_ocr.py +31 -28
  496. transformers/models/lighton_ocr/modular_lighton_ocr.py +19 -18
  497. transformers/models/lilt/configuration_lilt.py +6 -1
  498. transformers/models/llama/configuration_llama.py +5 -7
  499. transformers/models/llama/modeling_llama.py +4 -4
  500. transformers/models/llama4/configuration_llama4.py +67 -47
  501. transformers/models/llama4/image_processing_llama4_fast.py +3 -3
  502. transformers/models/llama4/modeling_llama4.py +46 -44
  503. transformers/models/llava/configuration_llava.py +10 -0
  504. transformers/models/llava/image_processing_llava_fast.py +3 -3
  505. transformers/models/llava/modeling_llava.py +38 -65
  506. transformers/models/llava_next/configuration_llava_next.py +2 -1
  507. transformers/models/llava_next/image_processing_llava_next_fast.py +6 -6
  508. transformers/models/llava_next/modeling_llava_next.py +61 -60
  509. transformers/models/llava_next_video/configuration_llava_next_video.py +10 -6
  510. transformers/models/llava_next_video/modeling_llava_next_video.py +115 -100
  511. transformers/models/llava_next_video/modular_llava_next_video.py +110 -101
  512. transformers/models/llava_onevision/configuration_llava_onevision.py +10 -6
  513. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +8 -7
  514. transformers/models/llava_onevision/modeling_llava_onevision.py +111 -105
  515. transformers/models/llava_onevision/modular_llava_onevision.py +106 -101
  516. transformers/models/longcat_flash/configuration_longcat_flash.py +7 -10
  517. transformers/models/longcat_flash/modeling_longcat_flash.py +7 -7
  518. transformers/models/longcat_flash/modular_longcat_flash.py +6 -5
  519. transformers/models/longformer/configuration_longformer.py +4 -1
  520. transformers/models/longt5/configuration_longt5.py +9 -6
  521. transformers/models/longt5/modeling_longt5.py +2 -1
  522. transformers/models/luke/configuration_luke.py +8 -1
  523. transformers/models/lw_detr/configuration_lw_detr.py +19 -31
  524. transformers/models/lw_detr/modeling_lw_detr.py +43 -44
  525. transformers/models/lw_detr/modular_lw_detr.py +36 -38
  526. transformers/models/lxmert/configuration_lxmert.py +16 -0
  527. transformers/models/m2m_100/configuration_m2m_100.py +7 -8
  528. transformers/models/m2m_100/modeling_m2m_100.py +3 -3
  529. transformers/models/mamba/configuration_mamba.py +5 -2
  530. transformers/models/mamba/modeling_mamba.py +18 -26
  531. transformers/models/mamba2/configuration_mamba2.py +5 -7
  532. transformers/models/mamba2/modeling_mamba2.py +22 -33
  533. transformers/models/marian/configuration_marian.py +10 -4
  534. transformers/models/marian/modeling_marian.py +4 -4
  535. transformers/models/markuplm/configuration_markuplm.py +4 -6
  536. transformers/models/markuplm/modeling_markuplm.py +3 -3
  537. transformers/models/mask2former/configuration_mask2former.py +12 -47
  538. transformers/models/mask2former/image_processing_mask2former_fast.py +8 -8
  539. transformers/models/mask2former/modeling_mask2former.py +18 -12
  540. transformers/models/maskformer/configuration_maskformer.py +14 -45
  541. transformers/models/maskformer/configuration_maskformer_swin.py +2 -4
  542. transformers/models/maskformer/image_processing_maskformer_fast.py +8 -8
  543. transformers/models/maskformer/modeling_maskformer.py +15 -9
  544. transformers/models/maskformer/modeling_maskformer_swin.py +2 -3
  545. transformers/models/mbart/configuration_mbart.py +9 -4
  546. transformers/models/mbart/modeling_mbart.py +9 -6
  547. transformers/models/megatron_bert/configuration_megatron_bert.py +13 -2
  548. transformers/models/megatron_bert/modeling_megatron_bert.py +0 -15
  549. transformers/models/metaclip_2/configuration_metaclip_2.py +4 -1
  550. transformers/models/metaclip_2/modeling_metaclip_2.py +49 -42
  551. transformers/models/metaclip_2/modular_metaclip_2.py +41 -25
  552. transformers/models/mgp_str/modeling_mgp_str.py +4 -2
  553. transformers/models/mimi/configuration_mimi.py +4 -0
  554. transformers/models/mimi/modeling_mimi.py +40 -36
  555. transformers/models/minimax/configuration_minimax.py +8 -11
  556. transformers/models/minimax/modeling_minimax.py +5 -5
  557. transformers/models/minimax/modular_minimax.py +9 -12
  558. transformers/models/minimax_m2/configuration_minimax_m2.py +8 -31
  559. transformers/models/minimax_m2/modeling_minimax_m2.py +4 -4
  560. transformers/models/minimax_m2/modular_minimax_m2.py +8 -31
  561. transformers/models/ministral/configuration_ministral.py +5 -7
  562. transformers/models/ministral/modeling_ministral.py +4 -4
  563. transformers/models/ministral/modular_ministral.py +5 -8
  564. transformers/models/ministral3/configuration_ministral3.py +4 -4
  565. transformers/models/ministral3/modeling_ministral3.py +4 -4
  566. transformers/models/ministral3/modular_ministral3.py +3 -3
  567. transformers/models/mistral/configuration_mistral.py +5 -7
  568. transformers/models/mistral/modeling_mistral.py +4 -4
  569. transformers/models/mistral/modular_mistral.py +3 -3
  570. transformers/models/mistral3/configuration_mistral3.py +4 -0
  571. transformers/models/mistral3/modeling_mistral3.py +36 -40
  572. transformers/models/mistral3/modular_mistral3.py +31 -32
  573. transformers/models/mixtral/configuration_mixtral.py +8 -11
  574. transformers/models/mixtral/modeling_mixtral.py +4 -4
  575. transformers/models/mlcd/modeling_mlcd.py +7 -5
  576. transformers/models/mlcd/modular_mlcd.py +7 -5
  577. transformers/models/mllama/configuration_mllama.py +5 -7
  578. transformers/models/mllama/image_processing_mllama_fast.py +6 -5
  579. transformers/models/mllama/modeling_mllama.py +19 -19
  580. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +10 -45
  581. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +66 -84
  582. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +10 -45
  583. transformers/models/mobilebert/configuration_mobilebert.py +4 -1
  584. transformers/models/mobilebert/modeling_mobilebert.py +3 -3
  585. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +4 -4
  586. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +4 -2
  587. transformers/models/mobilevit/image_processing_mobilevit_fast.py +4 -4
  588. transformers/models/mobilevit/modeling_mobilevit.py +4 -2
  589. transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -2
  590. transformers/models/modernbert/configuration_modernbert.py +46 -21
  591. transformers/models/modernbert/modeling_modernbert.py +146 -899
  592. transformers/models/modernbert/modular_modernbert.py +185 -908
  593. transformers/models/modernbert_decoder/configuration_modernbert_decoder.py +21 -13
  594. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -17
  595. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +24 -23
  596. transformers/models/moonshine/configuration_moonshine.py +12 -7
  597. transformers/models/moonshine/modeling_moonshine.py +7 -7
  598. transformers/models/moonshine/modular_moonshine.py +19 -13
  599. transformers/models/moshi/configuration_moshi.py +28 -2
  600. transformers/models/moshi/modeling_moshi.py +4 -9
  601. transformers/models/mpnet/configuration_mpnet.py +6 -1
  602. transformers/models/mpt/configuration_mpt.py +16 -0
  603. transformers/models/mra/configuration_mra.py +8 -1
  604. transformers/models/mt5/configuration_mt5.py +9 -5
  605. transformers/models/mt5/modeling_mt5.py +5 -8
  606. transformers/models/musicgen/configuration_musicgen.py +12 -7
  607. transformers/models/musicgen/modeling_musicgen.py +6 -5
  608. transformers/models/musicgen_melody/configuration_musicgen_melody.py +15 -7
  609. transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -17
  610. transformers/models/mvp/configuration_mvp.py +8 -4
  611. transformers/models/mvp/modeling_mvp.py +6 -4
  612. transformers/models/nanochat/configuration_nanochat.py +5 -7
  613. transformers/models/nanochat/modeling_nanochat.py +4 -4
  614. transformers/models/nanochat/modular_nanochat.py +4 -4
  615. transformers/models/nemotron/configuration_nemotron.py +5 -7
  616. transformers/models/nemotron/modeling_nemotron.py +4 -14
  617. transformers/models/nllb/tokenization_nllb.py +7 -5
  618. transformers/models/nllb_moe/configuration_nllb_moe.py +7 -9
  619. transformers/models/nllb_moe/modeling_nllb_moe.py +3 -3
  620. transformers/models/nougat/image_processing_nougat_fast.py +8 -8
  621. transformers/models/nystromformer/configuration_nystromformer.py +8 -1
  622. transformers/models/olmo/configuration_olmo.py +5 -7
  623. transformers/models/olmo/modeling_olmo.py +4 -4
  624. transformers/models/olmo/modular_olmo.py +3 -3
  625. transformers/models/olmo2/configuration_olmo2.py +9 -11
  626. transformers/models/olmo2/modeling_olmo2.py +4 -4
  627. transformers/models/olmo2/modular_olmo2.py +7 -7
  628. transformers/models/olmo3/configuration_olmo3.py +10 -11
  629. transformers/models/olmo3/modeling_olmo3.py +4 -4
  630. transformers/models/olmo3/modular_olmo3.py +13 -14
  631. transformers/models/olmoe/configuration_olmoe.py +5 -7
  632. transformers/models/olmoe/modeling_olmoe.py +4 -4
  633. transformers/models/olmoe/modular_olmoe.py +3 -3
  634. transformers/models/omdet_turbo/configuration_omdet_turbo.py +14 -49
  635. transformers/models/omdet_turbo/modeling_omdet_turbo.py +22 -18
  636. transformers/models/oneformer/configuration_oneformer.py +9 -46
  637. transformers/models/oneformer/image_processing_oneformer_fast.py +8 -8
  638. transformers/models/oneformer/modeling_oneformer.py +14 -9
  639. transformers/models/openai/configuration_openai.py +16 -0
  640. transformers/models/opt/configuration_opt.py +6 -6
  641. transformers/models/opt/modeling_opt.py +5 -5
  642. transformers/models/ovis2/configuration_ovis2.py +4 -0
  643. transformers/models/ovis2/image_processing_ovis2_fast.py +3 -3
  644. transformers/models/ovis2/modeling_ovis2.py +58 -99
  645. transformers/models/ovis2/modular_ovis2.py +52 -13
  646. transformers/models/owlv2/configuration_owlv2.py +4 -1
  647. transformers/models/owlv2/image_processing_owlv2_fast.py +5 -5
  648. transformers/models/owlv2/modeling_owlv2.py +40 -27
  649. transformers/models/owlv2/modular_owlv2.py +5 -5
  650. transformers/models/owlvit/configuration_owlvit.py +4 -1
  651. transformers/models/owlvit/modeling_owlvit.py +40 -27
  652. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +9 -10
  653. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +88 -87
  654. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +82 -53
  655. transformers/models/paligemma/configuration_paligemma.py +4 -0
  656. transformers/models/paligemma/modeling_paligemma.py +30 -26
  657. transformers/models/parakeet/configuration_parakeet.py +2 -4
  658. transformers/models/parakeet/modeling_parakeet.py +3 -3
  659. transformers/models/parakeet/modular_parakeet.py +3 -3
  660. transformers/models/patchtsmixer/modeling_patchtsmixer.py +3 -3
  661. transformers/models/patchtst/modeling_patchtst.py +3 -3
  662. transformers/models/pe_audio/modeling_pe_audio.py +4 -4
  663. transformers/models/pe_audio/modular_pe_audio.py +1 -1
  664. transformers/models/pe_audio_video/modeling_pe_audio_video.py +4 -4
  665. transformers/models/pe_audio_video/modular_pe_audio_video.py +4 -4
  666. transformers/models/pe_video/modeling_pe_video.py +36 -24
  667. transformers/models/pe_video/modular_pe_video.py +36 -23
  668. transformers/models/pegasus/configuration_pegasus.py +8 -5
  669. transformers/models/pegasus/modeling_pegasus.py +4 -4
  670. transformers/models/pegasus_x/configuration_pegasus_x.py +5 -3
  671. transformers/models/pegasus_x/modeling_pegasus_x.py +3 -3
  672. transformers/models/perceiver/image_processing_perceiver_fast.py +2 -2
  673. transformers/models/perceiver/modeling_perceiver.py +17 -9
  674. transformers/models/perception_lm/modeling_perception_lm.py +26 -27
  675. transformers/models/perception_lm/modular_perception_lm.py +27 -25
  676. transformers/models/persimmon/configuration_persimmon.py +5 -7
  677. transformers/models/persimmon/modeling_persimmon.py +5 -5
  678. transformers/models/phi/configuration_phi.py +8 -6
  679. transformers/models/phi/modeling_phi.py +4 -4
  680. transformers/models/phi/modular_phi.py +3 -3
  681. transformers/models/phi3/configuration_phi3.py +9 -11
  682. transformers/models/phi3/modeling_phi3.py +4 -4
  683. transformers/models/phi3/modular_phi3.py +3 -3
  684. transformers/models/phi4_multimodal/configuration_phi4_multimodal.py +11 -13
  685. transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py +4 -4
  686. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +46 -61
  687. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +44 -30
  688. transformers/models/phimoe/configuration_phimoe.py +5 -7
  689. transformers/models/phimoe/modeling_phimoe.py +15 -39
  690. transformers/models/phimoe/modular_phimoe.py +12 -7
  691. transformers/models/pix2struct/configuration_pix2struct.py +12 -9
  692. transformers/models/pix2struct/image_processing_pix2struct_fast.py +5 -5
  693. transformers/models/pix2struct/modeling_pix2struct.py +14 -7
  694. transformers/models/pixio/configuration_pixio.py +2 -4
  695. transformers/models/pixio/modeling_pixio.py +9 -8
  696. transformers/models/pixio/modular_pixio.py +4 -2
  697. transformers/models/pixtral/image_processing_pixtral_fast.py +5 -5
  698. transformers/models/pixtral/modeling_pixtral.py +9 -12
  699. transformers/models/plbart/configuration_plbart.py +8 -5
  700. transformers/models/plbart/modeling_plbart.py +9 -7
  701. transformers/models/plbart/modular_plbart.py +1 -1
  702. transformers/models/poolformer/image_processing_poolformer_fast.py +7 -7
  703. transformers/models/pop2piano/configuration_pop2piano.py +7 -6
  704. transformers/models/pop2piano/modeling_pop2piano.py +2 -1
  705. transformers/models/pp_doclayout_v3/__init__.py +30 -0
  706. transformers/models/pp_doclayout_v3/configuration_pp_doclayout_v3.py +277 -0
  707. transformers/models/pp_doclayout_v3/image_processing_pp_doclayout_v3_fast.py +305 -0
  708. transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py +2083 -0
  709. transformers/models/pp_doclayout_v3/modular_pp_doclayout_v3.py +1549 -0
  710. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +12 -46
  711. transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything_fast.py +6 -6
  712. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +8 -6
  713. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +12 -10
  714. transformers/models/prophetnet/configuration_prophetnet.py +11 -10
  715. transformers/models/prophetnet/modeling_prophetnet.py +12 -23
  716. transformers/models/pvt/image_processing_pvt.py +7 -7
  717. transformers/models/pvt/image_processing_pvt_fast.py +1 -1
  718. transformers/models/pvt_v2/configuration_pvt_v2.py +2 -4
  719. transformers/models/pvt_v2/modeling_pvt_v2.py +6 -5
  720. transformers/models/qwen2/configuration_qwen2.py +14 -4
  721. transformers/models/qwen2/modeling_qwen2.py +4 -4
  722. transformers/models/qwen2/modular_qwen2.py +3 -3
  723. transformers/models/qwen2/tokenization_qwen2.py +0 -4
  724. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +17 -5
  725. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +108 -88
  726. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +115 -87
  727. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +7 -10
  728. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +98 -53
  729. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +18 -6
  730. transformers/models/qwen2_audio/modeling_qwen2_audio.py +12 -12
  731. transformers/models/qwen2_moe/configuration_qwen2_moe.py +14 -4
  732. transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
  733. transformers/models/qwen2_moe/modular_qwen2_moe.py +3 -3
  734. transformers/models/qwen2_vl/configuration_qwen2_vl.py +7 -10
  735. transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py +4 -6
  736. transformers/models/qwen2_vl/modeling_qwen2_vl.py +97 -53
  737. transformers/models/qwen2_vl/video_processing_qwen2_vl.py +4 -6
  738. transformers/models/qwen3/configuration_qwen3.py +15 -5
  739. transformers/models/qwen3/modeling_qwen3.py +4 -4
  740. transformers/models/qwen3/modular_qwen3.py +3 -3
  741. transformers/models/qwen3_moe/configuration_qwen3_moe.py +20 -7
  742. transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
  743. transformers/models/qwen3_next/configuration_qwen3_next.py +16 -4
  744. transformers/models/qwen3_next/modeling_qwen3_next.py +5 -5
  745. transformers/models/qwen3_next/modular_qwen3_next.py +4 -4
  746. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +55 -19
  747. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +161 -98
  748. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +107 -34
  749. transformers/models/qwen3_vl/configuration_qwen3_vl.py +7 -6
  750. transformers/models/qwen3_vl/modeling_qwen3_vl.py +115 -49
  751. transformers/models/qwen3_vl/modular_qwen3_vl.py +88 -37
  752. transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py +7 -6
  753. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +173 -99
  754. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +23 -7
  755. transformers/models/rag/configuration_rag.py +6 -6
  756. transformers/models/rag/modeling_rag.py +3 -3
  757. transformers/models/rag/retrieval_rag.py +1 -1
  758. transformers/models/recurrent_gemma/configuration_recurrent_gemma.py +8 -6
  759. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +4 -5
  760. transformers/models/reformer/configuration_reformer.py +7 -7
  761. transformers/models/rembert/configuration_rembert.py +8 -1
  762. transformers/models/rembert/modeling_rembert.py +0 -22
  763. transformers/models/resnet/configuration_resnet.py +2 -4
  764. transformers/models/resnet/modeling_resnet.py +6 -5
  765. transformers/models/roberta/configuration_roberta.py +11 -2
  766. transformers/models/roberta/modeling_roberta.py +6 -6
  767. transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py +11 -2
  768. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +6 -6
  769. transformers/models/roc_bert/configuration_roc_bert.py +8 -1
  770. transformers/models/roc_bert/modeling_roc_bert.py +6 -41
  771. transformers/models/roformer/configuration_roformer.py +13 -2
  772. transformers/models/roformer/modeling_roformer.py +0 -14
  773. transformers/models/rt_detr/configuration_rt_detr.py +8 -49
  774. transformers/models/rt_detr/configuration_rt_detr_resnet.py +2 -4
  775. transformers/models/rt_detr/image_processing_rt_detr_fast.py +24 -11
  776. transformers/models/rt_detr/modeling_rt_detr.py +578 -737
  777. transformers/models/rt_detr/modeling_rt_detr_resnet.py +2 -3
  778. transformers/models/rt_detr/modular_rt_detr.py +1508 -6
  779. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +12 -57
  780. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +318 -453
  781. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +25 -66
  782. transformers/models/rwkv/configuration_rwkv.py +2 -3
  783. transformers/models/rwkv/modeling_rwkv.py +0 -23
  784. transformers/models/sam/configuration_sam.py +2 -0
  785. transformers/models/sam/image_processing_sam_fast.py +4 -4
  786. transformers/models/sam/modeling_sam.py +13 -8
  787. transformers/models/sam/processing_sam.py +3 -3
  788. transformers/models/sam2/configuration_sam2.py +1 -1
  789. transformers/models/sam2/modeling_sam2.py +56 -52
  790. transformers/models/sam2/modular_sam2.py +47 -55
  791. transformers/models/sam2_video/modeling_sam2_video.py +50 -51
  792. transformers/models/sam2_video/modular_sam2_video.py +12 -10
  793. transformers/models/sam3/modeling_sam3.py +43 -47
  794. transformers/models/sam3/processing_sam3.py +8 -4
  795. transformers/models/sam3_tracker/configuration_sam3_tracker.py +1 -2
  796. transformers/models/sam3_tracker/modeling_sam3_tracker.py +50 -49
  797. transformers/models/sam3_tracker/modular_sam3_tracker.py +0 -1
  798. transformers/models/sam3_tracker/processing_sam3_tracker.py +0 -1
  799. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +50 -49
  800. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +10 -22
  801. transformers/models/sam3_video/modeling_sam3_video.py +27 -14
  802. transformers/models/sam_hq/configuration_sam_hq.py +2 -0
  803. transformers/models/sam_hq/modeling_sam_hq.py +13 -9
  804. transformers/models/sam_hq/modular_sam_hq.py +6 -6
  805. transformers/models/sam_hq/processing_sam_hq.py +7 -6
  806. transformers/models/seamless_m4t/configuration_seamless_m4t.py +8 -9
  807. transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py +8 -9
  808. transformers/models/seed_oss/configuration_seed_oss.py +7 -9
  809. transformers/models/seed_oss/modeling_seed_oss.py +4 -4
  810. transformers/models/seed_oss/modular_seed_oss.py +3 -3
  811. transformers/models/segformer/image_processing_segformer_fast.py +4 -4
  812. transformers/models/segformer/modeling_segformer.py +4 -2
  813. transformers/models/segformer/modular_segformer.py +3 -3
  814. transformers/models/seggpt/modeling_seggpt.py +20 -8
  815. transformers/models/sew/configuration_sew.py +4 -1
  816. transformers/models/sew/modeling_sew.py +9 -5
  817. transformers/models/sew/modular_sew.py +2 -1
  818. transformers/models/sew_d/configuration_sew_d.py +4 -1
  819. transformers/models/sew_d/modeling_sew_d.py +4 -1
  820. transformers/models/shieldgemma2/modeling_shieldgemma2.py +4 -4
  821. transformers/models/siglip/configuration_siglip.py +4 -1
  822. transformers/models/siglip/modeling_siglip.py +27 -71
  823. transformers/models/siglip2/__init__.py +1 -0
  824. transformers/models/siglip2/configuration_siglip2.py +4 -2
  825. transformers/models/siglip2/image_processing_siglip2_fast.py +2 -2
  826. transformers/models/siglip2/modeling_siglip2.py +37 -78
  827. transformers/models/siglip2/modular_siglip2.py +74 -25
  828. transformers/models/siglip2/tokenization_siglip2.py +95 -0
  829. transformers/models/smollm3/configuration_smollm3.py +6 -6
  830. transformers/models/smollm3/modeling_smollm3.py +4 -4
  831. transformers/models/smollm3/modular_smollm3.py +9 -9
  832. transformers/models/smolvlm/configuration_smolvlm.py +1 -3
  833. transformers/models/smolvlm/image_processing_smolvlm_fast.py +29 -3
  834. transformers/models/smolvlm/modeling_smolvlm.py +75 -46
  835. transformers/models/smolvlm/modular_smolvlm.py +36 -23
  836. transformers/models/smolvlm/video_processing_smolvlm.py +9 -9
  837. transformers/models/solar_open/__init__.py +27 -0
  838. transformers/models/solar_open/configuration_solar_open.py +184 -0
  839. transformers/models/solar_open/modeling_solar_open.py +642 -0
  840. transformers/models/solar_open/modular_solar_open.py +224 -0
  841. transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +6 -4
  842. transformers/models/speech_to_text/configuration_speech_to_text.py +9 -8
  843. transformers/models/speech_to_text/modeling_speech_to_text.py +3 -3
  844. transformers/models/speecht5/configuration_speecht5.py +7 -8
  845. transformers/models/splinter/configuration_splinter.py +6 -6
  846. transformers/models/splinter/modeling_splinter.py +8 -3
  847. transformers/models/squeezebert/configuration_squeezebert.py +14 -1
  848. transformers/models/stablelm/configuration_stablelm.py +8 -6
  849. transformers/models/stablelm/modeling_stablelm.py +5 -5
  850. transformers/models/starcoder2/configuration_starcoder2.py +11 -5
  851. transformers/models/starcoder2/modeling_starcoder2.py +5 -5
  852. transformers/models/starcoder2/modular_starcoder2.py +4 -4
  853. transformers/models/superglue/configuration_superglue.py +4 -0
  854. transformers/models/superglue/image_processing_superglue_fast.py +4 -3
  855. transformers/models/superglue/modeling_superglue.py +9 -4
  856. transformers/models/superpoint/image_processing_superpoint_fast.py +3 -4
  857. transformers/models/superpoint/modeling_superpoint.py +4 -2
  858. transformers/models/swin/configuration_swin.py +2 -4
  859. transformers/models/swin/modeling_swin.py +11 -8
  860. transformers/models/swin2sr/image_processing_swin2sr_fast.py +2 -2
  861. transformers/models/swin2sr/modeling_swin2sr.py +4 -2
  862. transformers/models/swinv2/configuration_swinv2.py +2 -4
  863. transformers/models/swinv2/modeling_swinv2.py +10 -7
  864. transformers/models/switch_transformers/configuration_switch_transformers.py +11 -6
  865. transformers/models/switch_transformers/modeling_switch_transformers.py +3 -3
  866. transformers/models/switch_transformers/modular_switch_transformers.py +3 -3
  867. transformers/models/t5/configuration_t5.py +9 -8
  868. transformers/models/t5/modeling_t5.py +5 -8
  869. transformers/models/t5gemma/configuration_t5gemma.py +10 -25
  870. transformers/models/t5gemma/modeling_t5gemma.py +9 -9
  871. transformers/models/t5gemma/modular_t5gemma.py +11 -24
  872. transformers/models/t5gemma2/configuration_t5gemma2.py +35 -48
  873. transformers/models/t5gemma2/modeling_t5gemma2.py +143 -100
  874. transformers/models/t5gemma2/modular_t5gemma2.py +152 -136
  875. transformers/models/table_transformer/configuration_table_transformer.py +18 -49
  876. transformers/models/table_transformer/modeling_table_transformer.py +27 -53
  877. transformers/models/tapas/configuration_tapas.py +12 -1
  878. transformers/models/tapas/modeling_tapas.py +1 -1
  879. transformers/models/tapas/tokenization_tapas.py +1 -0
  880. transformers/models/textnet/configuration_textnet.py +4 -6
  881. transformers/models/textnet/image_processing_textnet_fast.py +3 -3
  882. transformers/models/textnet/modeling_textnet.py +15 -14
  883. transformers/models/time_series_transformer/modeling_time_series_transformer.py +3 -3
  884. transformers/models/timesfm/modeling_timesfm.py +5 -6
  885. transformers/models/timesfm/modular_timesfm.py +5 -6
  886. transformers/models/timm_backbone/configuration_timm_backbone.py +33 -7
  887. transformers/models/timm_backbone/modeling_timm_backbone.py +21 -24
  888. transformers/models/timm_wrapper/modeling_timm_wrapper.py +9 -4
  889. transformers/models/trocr/configuration_trocr.py +11 -7
  890. transformers/models/trocr/modeling_trocr.py +4 -2
  891. transformers/models/tvp/configuration_tvp.py +10 -35
  892. transformers/models/tvp/image_processing_tvp_fast.py +6 -5
  893. transformers/models/tvp/modeling_tvp.py +1 -1
  894. transformers/models/udop/configuration_udop.py +16 -7
  895. transformers/models/udop/modeling_udop.py +10 -6
  896. transformers/models/umt5/configuration_umt5.py +8 -6
  897. transformers/models/umt5/modeling_umt5.py +7 -3
  898. transformers/models/unispeech/configuration_unispeech.py +4 -1
  899. transformers/models/unispeech/modeling_unispeech.py +7 -4
  900. transformers/models/unispeech_sat/configuration_unispeech_sat.py +4 -1
  901. transformers/models/unispeech_sat/modeling_unispeech_sat.py +7 -4
  902. transformers/models/upernet/configuration_upernet.py +8 -35
  903. transformers/models/upernet/modeling_upernet.py +1 -1
  904. transformers/models/vaultgemma/configuration_vaultgemma.py +5 -7
  905. transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
  906. transformers/models/video_llama_3/configuration_video_llama_3.py +4 -0
  907. transformers/models/video_llama_3/image_processing_video_llama_3_fast.py +4 -6
  908. transformers/models/video_llama_3/modeling_video_llama_3.py +85 -48
  909. transformers/models/video_llama_3/modular_video_llama_3.py +56 -43
  910. transformers/models/video_llama_3/video_processing_video_llama_3.py +29 -8
  911. transformers/models/video_llava/configuration_video_llava.py +4 -0
  912. transformers/models/video_llava/modeling_video_llava.py +87 -89
  913. transformers/models/videomae/modeling_videomae.py +4 -5
  914. transformers/models/vilt/configuration_vilt.py +4 -1
  915. transformers/models/vilt/image_processing_vilt_fast.py +6 -6
  916. transformers/models/vilt/modeling_vilt.py +27 -12
  917. transformers/models/vipllava/configuration_vipllava.py +4 -0
  918. transformers/models/vipllava/modeling_vipllava.py +57 -31
  919. transformers/models/vipllava/modular_vipllava.py +50 -24
  920. transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +10 -6
  921. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +27 -20
  922. transformers/models/visual_bert/configuration_visual_bert.py +6 -1
  923. transformers/models/vit/configuration_vit.py +2 -2
  924. transformers/models/vit/modeling_vit.py +7 -5
  925. transformers/models/vit_mae/modeling_vit_mae.py +11 -7
  926. transformers/models/vit_msn/modeling_vit_msn.py +11 -7
  927. transformers/models/vitdet/configuration_vitdet.py +2 -4
  928. transformers/models/vitdet/modeling_vitdet.py +2 -3
  929. transformers/models/vitmatte/configuration_vitmatte.py +6 -35
  930. transformers/models/vitmatte/image_processing_vitmatte_fast.py +2 -2
  931. transformers/models/vitmatte/modeling_vitmatte.py +1 -1
  932. transformers/models/vitpose/configuration_vitpose.py +6 -43
  933. transformers/models/vitpose/modeling_vitpose.py +5 -3
  934. transformers/models/vitpose_backbone/configuration_vitpose_backbone.py +2 -4
  935. transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +5 -6
  936. transformers/models/vits/configuration_vits.py +4 -0
  937. transformers/models/vits/modeling_vits.py +9 -7
  938. transformers/models/vivit/modeling_vivit.py +4 -4
  939. transformers/models/vjepa2/modeling_vjepa2.py +9 -9
  940. transformers/models/voxtral/configuration_voxtral.py +0 -1
  941. transformers/models/voxtral/modeling_voxtral.py +25 -24
  942. transformers/models/voxtral/modular_voxtral.py +26 -20
  943. transformers/models/wav2vec2/configuration_wav2vec2.py +4 -1
  944. transformers/models/wav2vec2/modeling_wav2vec2.py +7 -4
  945. transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py +4 -1
  946. transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py +4 -1
  947. transformers/models/wavlm/configuration_wavlm.py +4 -1
  948. transformers/models/wavlm/modeling_wavlm.py +4 -1
  949. transformers/models/whisper/configuration_whisper.py +6 -4
  950. transformers/models/whisper/generation_whisper.py +0 -1
  951. transformers/models/whisper/modeling_whisper.py +3 -3
  952. transformers/models/x_clip/configuration_x_clip.py +4 -1
  953. transformers/models/x_clip/modeling_x_clip.py +26 -27
  954. transformers/models/xglm/configuration_xglm.py +9 -7
  955. transformers/models/xlm/configuration_xlm.py +10 -7
  956. transformers/models/xlm/modeling_xlm.py +1 -1
  957. transformers/models/xlm_roberta/configuration_xlm_roberta.py +11 -2
  958. transformers/models/xlm_roberta/modeling_xlm_roberta.py +6 -6
  959. transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py +10 -1
  960. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +6 -6
  961. transformers/models/xlnet/configuration_xlnet.py +3 -1
  962. transformers/models/xlstm/configuration_xlstm.py +5 -7
  963. transformers/models/xlstm/modeling_xlstm.py +0 -32
  964. transformers/models/xmod/configuration_xmod.py +11 -2
  965. transformers/models/xmod/modeling_xmod.py +13 -16
  966. transformers/models/yolos/image_processing_yolos_fast.py +25 -28
  967. transformers/models/yolos/modeling_yolos.py +7 -7
  968. transformers/models/yolos/modular_yolos.py +16 -16
  969. transformers/models/yoso/configuration_yoso.py +8 -1
  970. transformers/models/youtu/__init__.py +27 -0
  971. transformers/models/youtu/configuration_youtu.py +194 -0
  972. transformers/models/youtu/modeling_youtu.py +619 -0
  973. transformers/models/youtu/modular_youtu.py +254 -0
  974. transformers/models/zamba/configuration_zamba.py +5 -7
  975. transformers/models/zamba/modeling_zamba.py +25 -56
  976. transformers/models/zamba2/configuration_zamba2.py +8 -13
  977. transformers/models/zamba2/modeling_zamba2.py +53 -78
  978. transformers/models/zamba2/modular_zamba2.py +36 -29
  979. transformers/models/zoedepth/configuration_zoedepth.py +17 -40
  980. transformers/models/zoedepth/image_processing_zoedepth_fast.py +9 -9
  981. transformers/models/zoedepth/modeling_zoedepth.py +5 -3
  982. transformers/pipelines/__init__.py +1 -61
  983. transformers/pipelines/any_to_any.py +1 -1
  984. transformers/pipelines/automatic_speech_recognition.py +0 -2
  985. transformers/pipelines/base.py +1 -1
  986. transformers/pipelines/image_text_to_text.py +1 -1
  987. transformers/pipelines/text_to_audio.py +5 -1
  988. transformers/processing_utils.py +35 -44
  989. transformers/pytorch_utils.py +2 -26
  990. transformers/quantizers/quantizer_compressed_tensors.py +7 -5
  991. transformers/quantizers/quantizer_fbgemm_fp8.py +20 -23
  992. transformers/quantizers/quantizer_finegrained_fp8.py +14 -20
  993. transformers/quantizers/quantizer_mxfp4.py +1 -1
  994. transformers/quantizers/quantizer_torchao.py +0 -16
  995. transformers/safetensors_conversion.py +11 -4
  996. transformers/testing_utils.py +3 -28
  997. transformers/tokenization_mistral_common.py +9 -0
  998. transformers/tokenization_python.py +6 -4
  999. transformers/tokenization_utils_base.py +119 -219
  1000. transformers/tokenization_utils_tokenizers.py +31 -2
  1001. transformers/trainer.py +25 -33
  1002. transformers/trainer_seq2seq.py +1 -1
  1003. transformers/training_args.py +411 -417
  1004. transformers/utils/__init__.py +1 -4
  1005. transformers/utils/auto_docstring.py +15 -18
  1006. transformers/utils/backbone_utils.py +13 -373
  1007. transformers/utils/doc.py +4 -36
  1008. transformers/utils/generic.py +69 -33
  1009. transformers/utils/import_utils.py +72 -75
  1010. transformers/utils/loading_report.py +133 -105
  1011. transformers/utils/quantization_config.py +0 -21
  1012. transformers/video_processing_utils.py +5 -5
  1013. transformers/video_utils.py +3 -1
  1014. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/METADATA +118 -237
  1015. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/RECORD +1019 -994
  1016. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/WHEEL +1 -1
  1017. transformers/pipelines/deprecated/text2text_generation.py +0 -408
  1018. transformers/pipelines/image_to_text.py +0 -189
  1019. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/entry_points.txt +0 -0
  1020. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/licenses/LICENSE +0 -0
  1021. {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/top_level.txt +0 -0
@@ -32,7 +32,7 @@ from ...integrations import use_kernel_forward_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GradientCheckpointingLayer
-from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
+from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
@@ -126,9 +126,9 @@ class GlmImageVisionAttention(nn.Module):
         key_states = key_states.transpose(0, 1).unsqueeze(0)
         value_states = value_states.transpose(0, 1).unsqueeze(0)

-        attention_interface: Callable = eager_attention_forward
-        if self.config._attn_implementation != "eager":
-            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
+            self.config._attn_implementation, eager_attention_forward
+        )

         if "flash" in self.config._attn_implementation:
             # Flash Attention: Use cu_seqlens for variable length attention
@@ -402,9 +402,9 @@ class GlmImageTextAttention(nn.Module):
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)

-        attention_interface: Callable = eager_attention_forward
-        if self.config._attn_implementation != "eager":
-            attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
+        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
+            self.config._attn_implementation, eager_attention_forward
+        )

         attn_output, attn_weights = attention_interface(
             self,
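
Aside: both attention classes above swap the explicit eager-vs-registry branch for a single lookup with a fallback. A minimal sketch of that dispatch pattern; the registry class and `get_interface` helper here are simplified stand-ins inferred from the diff, not the actual `transformers` internals:

```python
from typing import Callable

# Hypothetical stand-in for ALL_ATTENTION_FUNCTIONS: a dict-like registry
# whose get_interface() falls back to the eager implementation.
class AttentionRegistry(dict):
    def get_interface(self, name: str, default: Callable) -> Callable:
        # "eager" (or any unregistered key) resolves to the default callable
        return self.get(name, default)

def eager_attention(q, k, v):
    return "eager"

def sdpa_attention(q, k, v):
    return "sdpa"

REGISTRY = AttentionRegistry(sdpa=sdpa_attention)

# Same shape as the diff: one lookup, eager as the fallback.
assert REGISTRY.get_interface("sdpa", eager_attention)(0, 0, 0) == "sdpa"
assert REGISTRY.get_interface("eager", eager_attention)(0, 0, 0) == "eager"
```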
@@ -612,6 +612,23 @@ class GlmImageVQVAEVectorQuantizer(nn.Module):
         return hidden_state_quant, loss, min_encoding_indices


+@dataclass
+@auto_docstring
+class GlmImageVQVAEModelOutput(BaseModelOutputWithPooling):
+    r"""
+    quantized_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+        Quantized last hidden state from the VQ-VAE model.
+    image_tokens (`torch.FloatTensor` of shape `(batch_size, config.vocab_size)`):
+        Indices of the image tokens predicted by the VQ-VAE model.
+    embedding_loss (`torch.FloatTensor`):
+        The embedding loss computed during quantization.
+    """
+
+    quantized_last_hidden_state: torch.FloatTensor | None = None
+    image_tokens: torch.FloatTensor | None = None
+    embedding_loss: torch.FloatTensor | None = None
+
+
 @auto_docstring(
     custom_intro="""
     The VQ-VAE model used in GlmImage for encoding/decoding images into discrete tokens.
@@ -625,6 +642,7 @@ class GlmImageVQVAE(GlmImagePreTrainedModel):
     _no_split_modules = [
         "GlmImageVQVAEVectorQuantizer",
     ]
+    _can_record_outputs = {}

     def __init__(self, config: GlmImageVQVAEConfig):
         super().__init__(config)
@@ -634,16 +652,26 @@ class GlmImageVQVAE(GlmImagePreTrainedModel):
         self.eval()  # GlmImage's VQ model is frozen
         self.post_init()

-    def encode(self, hidden_states):
-        hidden_states = self.quant_conv(hidden_states)
-        quant, emb_loss, indices = self.quantize(hidden_states)
-        return quant, emb_loss, indices
+    @check_model_inputs
+    def encode(self, hidden_states) -> GlmImageVQVAEModelOutput:
+        conv_hidden_states = self.quant_conv(hidden_states)
+        quantized_last_hidden_state, emb_loss, indices = self.quantize(conv_hidden_states)
+        return GlmImageVQVAEModelOutput(
+            last_hidden_state=hidden_states,
+            quantized_last_hidden_state=quantized_last_hidden_state,
+            image_tokens=indices,
+            embedding_loss=emb_loss,
+        )


 class GlmImageVisionModel(GlmImagePreTrainedModel):
     config: GlmImageVisionConfig
     input_modalities = ("image",)
     _no_split_modules = ["GlmImageVisionBlock"]
+    _can_record_outputs = {
+        "hidden_states": GlmImageVisionBlock,
+        "attentions": GlmImageVisionAttention,
+    }
     main_input_name = "pixel_values"

     def __init__(self, config: GlmImageVisionConfig) -> None:
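
Aside: `encode` now returns a typed output object instead of a bare 3-tuple, so call sites read named fields rather than positional slots. A minimal sketch of the pattern with a plain dataclass; this stand-in only illustrates the named-field access, the real class inherits from `BaseModelOutputWithPooling`:

```python
from dataclasses import dataclass

@dataclass
class VQVAEOutput:
    # Hypothetical simplified stand-in for GlmImageVQVAEModelOutput
    last_hidden_state: list | None = None
    quantized_last_hidden_state: list | None = None
    image_tokens: list | None = None
    embedding_loss: float | None = None

# Old style: positional unpacking, order-sensitive and easy to get wrong
#   quant, emb_loss, indices = model.encode(hs)
# New style: named access, order-independent
out = VQVAEOutput(image_tokens=[3, 1, 4], embedding_loss=0.02)
print(out.image_tokens, out.embedding_loss)  # [3, 1, 4] 0.02
```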
@@ -688,13 +716,16 @@ class GlmImageVisionModel(GlmImagePreTrainedModel):
         pos_ids = torch.cat(pos_ids, dim=0)
         return pos_ids

-    def forward(self, pixel_values: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> torch.Tensor:
-        """
-        Args:
-            pixel_values (`torch.Tensor` of shape `(total_patches, num_channels * patch_size * patch_size)`):
-                Packed pixel values.
-            grid_thw (`torch.Tensor` of shape `(num_images, 3)`):
-                The temporal, height and width of feature shape of each image.
+    @check_model_inputs
+    @auto_docstring
+    def forward(
+        self, pixel_values: torch.Tensor, grid_thw: torch.Tensor, **kwargs: Unpack[TransformersKwargs]
+    ) -> tuple | BaseModelOutputWithPooling:
+        r"""
+        pixel_values (`torch.Tensor` of shape `(total_patches, num_channels * patch_size * patch_size)`):
+            Packed pixel values.
+        grid_thw (`torch.Tensor` of shape `(num_images, 3)`):
+            The temporal, height and width of feature shape of each image.

         Returns:
             `torch.Tensor` of shape `(total_patches, hidden_size)`: Hidden states.
@@ -723,7 +754,8 @@ class GlmImageVisionModel(GlmImagePreTrainedModel):
                 hidden_states,
                 cu_seqlens=cu_seqlens,
             )
-        return hidden_states
+
+        return BaseModelOutputWithPooling(last_hidden_state=hidden_states)


 class GlmImageTextRotaryEmbedding(nn.Module):
@@ -927,6 +959,10 @@ class GlmImageModel(GlmImagePreTrainedModel):
         self.rope_deltas = None  # cache rope_deltas here
         self.vqmodel = GlmImageVQVAE._from_config(config.vq_config)

+        # Per-sample caches for batch processing
+        self._cached_decode_position_ids = None  # shape: [batch_size, 3, max_decode_len]
+        self._prefill_len = None  # prefill sequence length (same for all samples in batch)
+
         # Initialize weights and apply final processing
         self.post_init()

@@ -940,220 +976,169 @@ class GlmImageModel(GlmImagePreTrainedModel):
         self,
         input_ids: torch.LongTensor | None = None,
         image_grid_thw: torch.LongTensor | None = None,
+        images_per_sample: torch.LongTensor | None = None,
         attention_mask: torch.LongTensor | None = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """
-        Calculate the 3D rope index for image generation task.
-
-        Explanation:
-            Each embedding sequence may contain image tokens (for generation) and text tokens,
-            or just text tokens.
-
-            Input format:
-            - Text-to-Image: [text tokens] + <|dit_token_16384|>
-            - Image-to-Image: <|dit_token_16384|> [image tokens] <|dit_token_16385|> + [text tokens] + <|dit_token_16384|>
-
-            For pure text embedding sequence, the rotary position embedding is the same across all 3 dimensions.
-            Examples:
-                input_ids: [T T T T T], here T is for text.
-                temporal position_ids: [0, 1, 2, 3, 4]
-                height position_ids: [0, 1, 2, 3, 4]
-                width position_ids: [0, 1, 2, 3, 4]
-
-            For sequences with image tokens, we use special markers to denote image regions:
-            - <|dit_token_16384|>: image start marker
-            - <|dit_token_16385|>: image end marker
-            - Image tokens between these markers use 2D spatial position encoding.
-
-            For image tokens:
-            - temporal: stays constant at (image_start_pos + 1)
-            - height: increments every w tokens, representing row position
-            - width: cycles from 0 to w-1, representing column position
-
-            After each image region, the next position jumps to: image_start_pos + 1 + max(h, w)
-            This ensures sufficient positional separation between images and subsequent tokens.
-
-            Examples:
-                === Case 1: Image-to-Image Generation ===
-
-                Source image with grid [1, 3, 2], followed by text, then generation.
-                input_ids: [<|dit_token_16384|> V V V V V V <|dit_token_16385|> T T T T <|dit_token_16384|>]
-                image_grid_thw: [[1, 3, 2], [1, 4, 4]]  # first is source, second is target
-
-                For source image (h=3, w=2, 6 tokens):
-                    Start marker at position 0
-                    Image tokens at temporal=1, height=[1,1,2,2,3,3], width=[1,2,1,2,1,2]
-                    End marker at position 4 (= 0 + 1 + max(3,2))
-
-                Text tokens and trailing start marker continue from position 5.
-
-                Full prefill position_ids:
-                    temporal: [0, 1,1,1,1,1,1, 4, 5,6,7,8, 9]
-                    height:   [0, 1,1,2,2,3,3, 4, 5,6,7,8, 9]
-                    width:    [0, 1,2,1,2,1,2, 4, 5,6,7,8, 9]
-
-                Decode stage: use image_grid_thw[-1] = [1, 4, 4] to build cached position_ids,
-                starting from gen_st_idx = 10.
-
-                === Case 2: Text-to-Image Generation (multi-resolution) ===
-
-                Pure text input with two image_grids for progressive generation.
-                input_ids: [hello<sop>3 3<eop><sop>3 2<eop><|dit_token_16384|>]
-                Assume "hello<sop>3 3<eop><sop>3 2<eop>" = 4 tokens (positions 0-3)
-                <|dit_token_16384|> at position 4
-                image_grid_thw: [[1, 3, 3], [1, 3, 2]]
-                - image_grid_thw[-1] = [1, 3, 2]: first generated image (smaller/draft)
-                - image_grid_thw[-2] = [1, 3, 3]: second generated image (larger/final)
-
-                Prefill position_ids (5 tokens: 4 text + 1 start marker):
-                    temporal: [0, 1, 2, 3, 4]
-                    height:   [0, 1, 2, 3, 4]
-                    width:    [0, 1, 2, 3, 4]
-
-                Decode stage builds position_ids in reverse order of image_grid_thw:
-
-                First: image_grid_thw[-1] = [1, 3, 2] (6 tokens), starting at position 5:
-                    temporal: [5, 5, 5, 5, 5, 5]
-                    height:   [5, 5, 6, 6, 7, 7]
-                    width:    [5, 6, 5, 6, 5, 6]
-                    next_pos = 5 + max(3, 2) = 8
-
-                Then: image_grid_thw[-2] = [1, 3, 3] (9 tokens), starting at position 8:
-                    temporal: [8, 8, 8, 8, 8, 8, 8, 8, 8]
-                    height:   [8, 8, 8, 9, 9, 9, 10, 10, 10]
-                    width:    [8, 9, 10, 8, 9, 10, 8, 9, 10]
-                    next_pos = 8 + max(3, 3) = 11
-
-                Finally: <|dit_token_16385|> end marker at position 11
-
-                Full sequence position_ids (prefill + decode):
-                    temporal: [0,1,2,3, 4, 5,5,5,5,5,5, 8,8,8,8,8,8,8,8,8, 11]
-                    height:   [0,1,2,3, 4, 5,5,6,6,7,7, 8,8,8,9,9,9,10,10,10, 11]
-                    width:    [0,1,2,3, 4, 5,6,5,6,5,6, 8,9,10,8,9,10,8,9,10, 11]
-
-                _cached_decode_position_ids shape: [3, 6 + 9 + 1] = [3, 16]
-                (includes all generated image tokens + end marker)
+        Calculate the 3D rope index for image generation task with full batch support.

         Args:
             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
-                Indices of input sequence tokens in the vocabulary. Padding will be ignored by default
-                should you provide it.
-            image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
-                The temporal, height and width of feature shape of each image. For image generation,
-                temporal is typically 1.
-                - For image-to-image: includes source image grids + target image grid(s)
-                - For text-to-image with multi-resolution: includes multiple target grids,
-                  processed in reverse order (last grid first, second-to-last grid second, etc.)
+                Indices of input sequence tokens in the vocabulary.
+            image_grid_thw (`torch.LongTensor` of shape `(total_images_in_batch, 3)`, *optional*):
+                The temporal, height and width of feature shape of each image.
+                Images are packed across all samples in the batch.
+            images_per_sample (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+                Number of images (including target grids) for each sample in the batch.
+                Used to split image_grid_thw by sample.
             attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
-                Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
-                - 1 for tokens that are **not masked**,
-                - 0 for tokens that are **masked**.
+                Mask to avoid performing attention on padding token indices.

         Returns:
             position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`):
                 Position IDs for temporal, height, and width dimensions.
             mrope_position_deltas (`torch.Tensor` of shape `(batch_size, 1)`):
-                Position deltas for multi-modal rotary position embedding (zeros for this task).
+                Position deltas for multi-modal rotary position embedding.
         """
-
         batch_size, seq_len = input_ids.shape
         device = input_ids.device
         dtype = input_ids.dtype

         image_start_token_id = self.config.image_start_token_id
         image_end_token_id = self.config.image_end_token_id
-        num_complete_images = (input_ids == image_end_token_id).sum().item()

-        position_ids = torch.ones(
-            3, input_ids.shape[0], input_ids.shape[1], dtype=input_ids.dtype, device=input_ids.device
-        )
-        text_positions = torch.arange(seq_len)[None, :].repeat(3, 1)
+        position_ids = torch.ones(3, batch_size, seq_len, dtype=dtype, device=device)
+        text_positions = torch.arange(seq_len, device=device)[None, :].repeat(3, 1)
+
+        # Split image_grid_thw by sample if images_per_sample is provided
+        if image_grid_thw is not None and images_per_sample is not None:
+            grids_per_sample = torch.split(image_grid_thw, images_per_sample.tolist())
+        elif image_grid_thw is not None:
+            # Fallback: assume all grids belong to first sample (batch_size=1)
+            grids_per_sample = [image_grid_thw] * batch_size
+        else:
+            grids_per_sample = [None] * batch_size
+
+        # Per-sample caches for decode stage
+        all_decode_position_ids = []
+
         for batch_idx in range(batch_size):
             curr_input_ids = input_ids[batch_idx]
-            if attention_mask is not None:
-                curr_input_ids = curr_input_ids[attention_mask[batch_idx] == 1]
+            curr_grids = grids_per_sample[batch_idx]

-            image_end = torch.where(curr_input_ids == image_end_token_id)[0]
-            image_start = torch.where(curr_input_ids == image_start_token_id)[0] + 1
-            current_pos = 0  # track the current position value
+            if attention_mask is not None and attention_mask.shape[1] == seq_len:
+                valid_mask = attention_mask[batch_idx] == 1
+                curr_input_ids_valid = curr_input_ids[valid_mask]
+            else:
+                # attention_mask may have different length during assisted decoding
+                curr_input_ids_valid = curr_input_ids
+                valid_mask = None
+
+            # Find image boundaries in this sample
+            image_end_positions = torch.where(curr_input_ids_valid == image_end_token_id)[0]
+            image_start_positions = torch.where(curr_input_ids_valid == image_start_token_id)[0] + 1
+            num_complete_images = len(image_end_positions)
+
+            current_pos = 0
             prev_image_end = 0
             curr_position_ids = []
-            for start, end, grid in zip(image_start, image_end, image_grid_thw):
-                _, num_width_grid, num_height_grid = grid

-                # Create text position ids first if there are text tokens before image
+            # Process complete images (source images in image-to-image task)
+            for img_idx, (start, end) in enumerate(zip(image_start_positions, image_end_positions)):
+                if curr_grids is None or img_idx >= len(curr_grids):
+                    break
+                grid = curr_grids[img_idx]
+                # grid format is [temporal, height, width]
+                _, height, width = grid.tolist()
+
+                # Text tokens before this image
                 llm_pos_length = start - prev_image_end
-                llm_position_ids = text_positions[:, current_pos : current_pos + llm_pos_length].to(
-                    device=input_ids.device
-                )
+                llm_position_ids = text_positions[:, current_pos : current_pos + llm_pos_length].to(device=device)
                 current_pos += llm_position_ids.shape[-1]

-                # Now create image position ids for each grid
-                image_seq_length = num_height_grid * num_width_grid
-                h_grids = image_seq_length // num_height_grid + current_pos
-                w_grids = image_seq_length // num_width_grid + current_pos
-                position_width = torch.arange(current_pos, w_grids, device=input_ids.device).repeat(num_width_grid)
-                position_height = torch.arange(current_pos, h_grids, device=input_ids.device).repeat_interleave(
-                    num_height_grid
-                )
-                position_temporal = torch.full(
-                    (image_seq_length,), current_pos, device=input_ids.device, dtype=torch.long
+                # Image tokens with 2D spatial encoding
+                # For an image with height H and width W:
+                #   - position_width cycles [0, 1, ..., W-1] for each row, repeated H times
+                #   - position_height stays constant per row, [0]*W, [1]*W, ..., [H-1]*W
+                image_seq_length = height * width
+                position_width = torch.arange(current_pos, current_pos + width, device=device).repeat(height)
+                position_height = torch.arange(current_pos, current_pos + height, device=device).repeat_interleave(
+                    width
                 )
+                position_temporal = torch.full((image_seq_length,), current_pos, device=device, dtype=torch.long)
                 vision_position_ids = torch.stack([position_temporal, position_height, position_width], dim=0)
-                current_pos += max(num_height_grid, num_width_grid)
+                current_pos += max(height, width)

                 prev_image_end = end
                 curr_position_ids.append(torch.cat([llm_position_ids, vision_position_ids], dim=-1))

-            # Add position ids for the last text tokens if any
-            end_position = len(curr_input_ids) - prev_image_end
-            llm_position_ids = text_positions[:, current_pos : current_pos + end_position].to(device=input_ids.device)
+            # Remaining text tokens (including the final image_start token for generation)
+            end_position = len(curr_input_ids_valid) - prev_image_end
+            llm_position_ids = text_positions[:, current_pos : current_pos + end_position].to(device=device)
             current_pos += llm_position_ids.shape[-1]
             curr_position_ids.append(llm_position_ids)
+
+            # Concatenate all position ids for this sample
             curr_position_ids = torch.cat(curr_position_ids, dim=-1)
-            if attention_mask is not None:
-                position_ids[:, batch_idx, attention_mask[batch_idx] == 1] = curr_position_ids.to(position_ids.device)
+
+            # Store in the main position_ids tensor
+            if valid_mask is not None:
+                position_ids[:, batch_idx, valid_mask] = curr_position_ids
             else:
-                position_ids[:, batch_idx, :] = curr_position_ids.to(position_ids.device)
+                position_ids[:, batch_idx, :] = curr_position_ids
+
+            # Build decode position ids for this sample
+            if curr_grids is not None and len(curr_grids) > 0:
+                num_decode_grids = len(curr_grids) - num_complete_images
+                num_decode_grids = max(num_decode_grids, 0)
+                decode_pos = current_pos
+
+                decode_temporal_list = []
+                decode_height_list = []
+                decode_width_list = []
+
+                for i in range(1, num_decode_grids + 1):
+                    grid_idx = -i
+                    h = curr_grids[grid_idx, 1].item()
+                    w = curr_grids[grid_idx, 2].item()
+                    total_tokens = h * w
+
+                    h_indices = torch.arange(h, device=device).unsqueeze(1).expand(h, w).flatten()
+                    w_indices = torch.arange(w, device=device).unsqueeze(0).expand(h, w).flatten()
+
+                    decode_temporal_list.append(
+                        torch.full((total_tokens,), decode_pos, device=device, dtype=torch.long)
+                    )
+                    decode_height_list.append(decode_pos + h_indices)
+                    decode_width_list.append(decode_pos + w_indices)
+                    decode_pos = decode_pos + max(h, w)
+
+                # End marker
+                decode_temporal_list.append(torch.tensor([decode_pos], device=device, dtype=torch.long))
+                decode_height_list.append(torch.tensor([decode_pos], device=device, dtype=torch.long))
+                decode_width_list.append(torch.tensor([decode_pos], device=device, dtype=torch.long))
+
+                sample_decode_pos_ids = torch.stack(
+                    [
+                        torch.cat(decode_temporal_list, dim=0),
+                        torch.cat(decode_height_list, dim=0),
+                        torch.cat(decode_width_list, dim=0),
+                    ],
+                    dim=0,
+                )
+                all_decode_position_ids.append(sample_decode_pos_ids)

-        # Build and store position ids for tokens that will be generated. Later we will just
-        # slice these instead of computing each decoding step
+        # Store prefill length (same for all samples since input_ids is padded to same length)
         self._prefill_len = seq_len
-        if image_grid_thw is not None and len(image_grid_thw) > 0:
-            num_decode_grids = len(image_grid_thw) - num_complete_images
-            num_decode_grids = max(num_decode_grids, 0)
-            decode_pos = current_pos
-
-            decode_temporal_list = []
-            decode_height_list = []
-            decode_width_list = []
-
-            for i in range(1, num_decode_grids + 1):
-                grid_idx = -i
-                h = image_grid_thw[grid_idx, 1].item()
-                w = image_grid_thw[grid_idx, 2].item()
-                total_tokens = h * w
-
-                h_indices = torch.arange(h, device=device).unsqueeze(1).expand(h, w).flatten()
-                w_indices = torch.arange(w, device=device).unsqueeze(0).expand(h, w).flatten()
-
-                decode_temporal_list.append(torch.full((total_tokens,), decode_pos, device=device, dtype=torch.long))
-                decode_height_list.append(decode_pos + h_indices)
-                decode_width_list.append(decode_pos + w_indices)
-                decode_pos = decode_pos + max(h, w)
-
-            decode_temporal_list.append(torch.tensor([decode_pos], device=device, dtype=torch.long))
-            decode_height_list.append(torch.tensor([decode_pos], device=device, dtype=torch.long))
-            decode_width_list.append(torch.tensor([decode_pos], device=device, dtype=torch.long))
-
-            self._cached_decode_position_ids = torch.stack(
-                [
-                    torch.cat(decode_temporal_list, dim=0),
-                    torch.cat(decode_height_list, dim=0),
-                    torch.cat(decode_width_list, dim=0),
-                ],
-                dim=0,
-            )
+
+        # Pad decode position ids to same length and stack
+        if all_decode_position_ids:
+            max_decode_len = max(x.shape[1] for x in all_decode_position_ids)
+            padded_decode_pos_ids = [
+                F.pad(pos_ids, (0, max_decode_len - pos_ids.shape[1]), mode="replicate")
+                for pos_ids in all_decode_position_ids
+            ]
+            self._cached_decode_position_ids = torch.stack(padded_decode_pos_ids, dim=0)  # [batch, 3, max_decode_len]
         else:
             self._cached_decode_position_ids = None

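Aside: the 2D spatial encoding in the rewritten loop is easiest to see with concrete numbers. A minimal sketch for a single image grid, using the same values as the "Image-to-Image" example removed from the old docstring (height=3, width=2, starting at current_pos=1), mirroring the `repeat`/`repeat_interleave` calls above:

```python
import torch

# Assumed values for illustration: height=3, width=2, current_pos=1
height, width, current_pos = 3, 2, 1

# width index cycles [1, 2] within a row, repeated for each of the 3 rows
position_width = torch.arange(current_pos, current_pos + width).repeat(height)
# height index is constant within a row
position_height = torch.arange(current_pos, current_pos + height).repeat_interleave(width)
# temporal index is constant for the whole image
position_temporal = torch.full((height * width,), current_pos)

print(position_temporal.tolist())  # [1, 1, 1, 1, 1, 1]
print(position_height.tolist())   # [1, 1, 2, 2, 3, 3]
print(position_width.tolist())    # [1, 2, 1, 2, 1, 2]
# The next token after the image starts at current_pos + max(height, width) = 4,
# matching the end-marker position in the old docstring's Case 1.
```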
@@ -1161,21 +1146,27 @@ class GlmImageModel(GlmImagePreTrainedModel):

         return position_ids, mrope_position_deltas

-    def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None):
-        """
-        Encodes images into continuous embeddings that can be forwarded to the language model.
-
-        Args:
-            pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
-                The tensors corresponding to the input images.
-            image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
-                The temporal, height and width of feature shape of each image in LLM.
+    @can_return_tuple
+    @auto_docstring
+    def get_image_features(
+        self,
+        pixel_values: torch.FloatTensor,
+        image_grid_thw: torch.LongTensor | None = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> tuple | BaseModelOutputWithPooling:
+        r"""
+        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+            The tensors corresponding to the input images.
+        image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+            The temporal, height and width of feature shape of each image in LLM.
         """
         pixel_values = pixel_values.type(self.visual.dtype)
-        image_embeds = self.visual(pixel_values, grid_thw=image_grid_thw)
+        vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs)
         split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
-        image_embeds = torch.split(image_embeds, split_sizes)
-        return image_embeds
+        image_embeds = torch.split(vision_outputs.last_hidden_state, split_sizes)
+        vision_outputs.pooler_output = image_embeds
+
+        return vision_outputs

     def get_placeholder_mask(
         self,
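
Aside: the split sizes used above come straight from the grids: each image contributes t*h*w patches, reduced by the square of the spatial merge size. A small numeric sketch; the grid values and `spatial_merge_size=2` are assumptions for illustration only:

```python
import torch

# Two images, grids in [t, h, w] format (assumed values)
image_grid_thw = torch.tensor([[1, 4, 4], [1, 8, 4]])
spatial_merge_size = 2

# Per-image token counts after spatial merging: t*h*w // merge**2
split_sizes = (image_grid_thw.prod(-1) // spatial_merge_size**2).tolist()
print(split_sizes)  # [4, 8]

# Packed hidden states for both images, split back into per-image chunks
packed = torch.randn(sum(split_sizes), 16)
per_image = torch.split(packed, split_sizes)
print([t.shape[0] for t in per_image])  # [4, 8]
```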
@@ -1219,23 +1210,63 @@ class GlmImageModel(GlmImagePreTrainedModel):
         inputs_embeds: torch.FloatTensor | None = None,
         pixel_values: torch.Tensor | None = None,
         image_grid_thw: torch.LongTensor | None = None,
+        images_per_sample: torch.LongTensor | None = None,
         rope_deltas: torch.LongTensor | None = None,
         cache_position: torch.LongTensor | None = None,
         **kwargs: Unpack[TransformersKwargs],
     ) -> tuple | GlmImageModelOutputWithPast:
         r"""
-        image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+        image_grid_thw (`torch.LongTensor` of shape `(total_images_in_batch, 3)`, *optional*):
             The temporal, height and width of feature shape of each image in LLM.
+            Images are packed across all samples in the batch.
+        images_per_sample (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Number of images (including target grids) for each sample in the batch.
         rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
             The rope index difference between sequence length and multimodal rope.
         """
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

+        batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
+
         if pixel_values is not None:
-            image_embeds = self.get_image_features(pixel_values, image_grid_thw[:-1])
-            image_embeds = torch.cat(image_embeds, dim=0)
-            image_ids = self.get_image_tokens(image_embeds, image_grid_thw[:-1])
+            # Process source images (image-to-image mode)
+            # Source images are identified by counting image_end_token_id in input_ids
+            # Note: We must exclude padding tokens since pad_token_id == image_end_token_id
+            if images_per_sample is not None:
+                grids_per_sample = torch.split(image_grid_thw, images_per_sample.tolist())
+                # Create mask for non-padding tokens (attention_mask=1 means non-padding)
+                # Handle 4D attention mask (from static cache) by extracting diagonal
+                if attention_mask is not None and attention_mask.ndim == 4:
+                    non_pad_mask = torch.diagonal(attention_mask[:, 0], dim1=1, dim2=2)
+                    if non_pad_mask.dtype.is_floating_point:
+                        non_pad_mask = non_pad_mask / torch.finfo(non_pad_mask.dtype).min
+                        non_pad_mask = (1.0 - non_pad_mask).int()
+                    # Only keep columns matching input_ids length
+                    non_pad_mask = non_pad_mask[:, -input_ids.shape[1] :]
+                else:
+                    non_pad_mask = attention_mask if attention_mask is not None else torch.ones_like(input_ids)
+
+                source_grids_list = []
+                for sample_idx in range(batch_size):
+                    is_image_end = input_ids[sample_idx] == self.config.image_end_token_id
+                    is_non_pad = non_pad_mask[sample_idx] == 1
+                    num_source = (is_image_end & is_non_pad).sum().item()
+                    if num_source > 0:
+                        source_grids_list.append(grids_per_sample[sample_idx][:num_source])
+                if len(source_grids_list) == 0:
+                    raise ValueError(
+                        "pixel_values provided but no source images found in input_ids. "
+                        "Ensure input_ids contains image_end_token_id for each source image."
+                    )
+                source_grids = torch.cat(source_grids_list, dim=0)
+            else:
+                # Fallback for batch_size=1: all but last grid are source images
+                source_grids = image_grid_thw[:-1]
+
+            image_features = self.get_image_features(pixel_values, source_grids, return_dict=True)
+            image_embeds = torch.cat(image_features.pooler_output, dim=0)
+            image_ids = self.get_image_tokens(image_embeds, source_grids)
             image_ids = image_ids.view(-1).to(input_ids.device)
             special_image_mask = self.get_placeholder_mask(input_ids, image_ids)
             input_ids = input_ids.masked_scatter(special_image_mask, image_ids)
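
Aside: the diagonal trick above recovers a 2D padding mask from a 4D additive mask. On the diagonal, an attendable token holds 0 and a masked one holds the dtype minimum, so dividing by that minimum and flipping yields 0/1 flags. A self-contained sketch of the same arithmetic with made-up sizes:

```python
import torch

# Tiny 4D additive mask [batch=1, heads=1, q=3, kv=3]:
# position 0 is padding, positions 1-2 are real tokens (causal).
neg = torch.finfo(torch.float32).min
mask_4d = torch.full((1, 1, 3, 3), neg)
mask_4d[0, 0, 1, 1] = 0.0
mask_4d[0, 0, 2, 1:] = 0.0

# Diagonal of the (q, kv) plane: 0 where a token may attend to itself,
# dtype-min where it is padding.
diag = torch.diagonal(mask_4d[:, 0], dim1=1, dim2=2)
non_pad = (1.0 - diag / neg).int()
print(non_pad.tolist())  # [[0, 1, 1]] -> position 0 is padding
```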
@@ -1253,8 +1284,6 @@ class GlmImageModel(GlmImagePreTrainedModel):
             attention_mask_2d = (1.0 - attention_mask_2d).int()

         # Calculate RoPE index once per generation in the pre-fill stage only.
-        # It is safe to assume that `length!=1` means we're in pre-fill because the
-        # model is used only by DiT pipeline without assisted decoding, etc. techniques
         is_prefill_stage = (input_ids is not None and input_ids.shape[1] != 1) or (
             inputs_embeds is not None and inputs_embeds.shape[1] != 1
         )
@@ -1262,17 +1291,27 @@ class GlmImageModel(GlmImagePreTrainedModel):
             position_ids, rope_deltas = self.get_rope_index(
                 input_ids,
                 image_grid_thw,
+                images_per_sample=images_per_sample,
                 attention_mask=attention_mask_2d,
             )
             self.rope_deltas = rope_deltas
         # then use the prev pre-calculated rope-deltas to get the correct position ids
         else:
             batch_size, seq_length, _ = inputs_embeds.shape
-            # Use prefill token length, not position value
-            step = cache_position[0].item() - self._prefill_len
-            # Direct lookup - no tensor creation overhead
-            position_ids = self._cached_decode_position_ids[:, step : step + seq_length]
-            position_ids = position_ids.unsqueeze(1).expand(-1, batch_size, -1)
+            # Per-sample decode position lookup
+            # _cached_decode_position_ids shape: [batch_size, 3, max_decode_len]
+            if self._cached_decode_position_ids is not None:
+                step = cache_position[0].item() - self._prefill_len
+                # Get position ids for all samples at once, then transpose to [3, batch_size, seq_length]
+                position_ids = self._cached_decode_position_ids[:, :, step : step + seq_length].permute(1, 0, 2)
+            else:
+                # Fallback for text-to-image or cases without cached decode positions
+                # Use simple incremental positions
+                start_pos = cache_position[0].item()
+                position_ids = torch.arange(
+                    start_pos, start_pos + seq_length, device=inputs_embeds.device, dtype=torch.long
+                )
+                position_ids = position_ids.unsqueeze(0).repeat(3, batch_size, 1)

         outputs = self.language_model(
             input_ids=None,
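
Aside: the decode-time lookup amounts to simple indexing into the cache built during prefill: the offset is the number of tokens generated so far, and the slice is transposed from the per-sample layout `[batch, 3, len]` to the `[3, batch, len]` layout the rest of the model expects. A toy sketch with made-up sizes:

```python
import torch

# Assumed toy sizes: batch of 2 samples, 5 cached decode positions each
cached = torch.arange(2 * 3 * 5).reshape(2, 3, 5)  # [batch, 3, max_decode_len]
prefill_len, cache_position = 7, torch.tensor([9])
seq_length = 1  # one new token per decode step

step = cache_position[0].item() - prefill_len  # 2 tokens generated so far
position_ids = cached[:, :, step : step + seq_length].permute(1, 0, 2)
print(position_ids.shape)  # torch.Size([3, 2, 1])
```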
@@ -1319,8 +1358,8 @@ class GlmImageModel(GlmImagePreTrainedModel):
             grid_t, grid_h, grid_w = image_grid_thw[i].tolist()
             hs = hs.view(grid_t, grid_h, grid_w, hidden_size)
             hs = hs.permute(0, 3, 1, 2).contiguous()
-            _, _, image_toks = self.vqmodel.encode(hs)
-            all_image_toks.append(image_toks)
+            vqmodel_outputs: GlmImageVQVAEModelOutput = self.vqmodel.encode(hs)
+            all_image_toks.append(vqmodel_outputs.image_tokens)
         return torch.cat(all_image_toks, dim=0)

@@ -1369,8 +1408,20 @@ class GlmImageForConditionalGeneration(GlmImagePreTrainedModel, GenerationMixin)
         # Initialize weights and apply final processing
         self.post_init()

-    def get_image_features(self, pixel_values: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None):
-        return self.model.get_image_features(pixel_values, image_grid_thw)
+    @auto_docstring
+    def get_image_features(
+        self,
+        pixel_values: torch.FloatTensor,
+        image_grid_thw: torch.LongTensor | None = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> tuple | BaseModelOutputWithPooling:
+        r"""
+        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+            The tensors corresponding to the input images.
+        image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+            The temporal, height and width of feature shape of each image in LLM.
+        """
+        return self.model.get_image_features(pixel_values, image_grid_thw, **kwargs)

     def get_image_tokens(self, hidden_states: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None):
         return self.model.get_image_tokens(hidden_states, image_grid_thw)
@@ -1385,6 +1436,7 @@ class GlmImageForConditionalGeneration(GlmImagePreTrainedModel, GenerationMixin)
         labels: torch.LongTensor | None = None,
         pixel_values: torch.Tensor | None = None,
         image_grid_thw: torch.LongTensor | None = None,
+        images_per_sample: torch.LongTensor | None = None,
         cache_position: torch.LongTensor | None = None,
         logits_to_keep: int | torch.Tensor = 0,
         **kwargs: Unpack[TransformersKwargs],
1394
1446
  Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1395
1447
  config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1396
1448
  (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1397
- image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
1449
+ image_grid_thw (`torch.LongTensor` of shape `(total_images_in_batch, 3)`, *optional*):
1398
1450
  The temporal, height and width of feature shape of each image in LLM.
1451
+ Images are packed across all samples in the batch.
1452
+ images_per_sample (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
1453
+ Number of images (including target grids) for each sample in the batch.
1399
1454
 
1400
1455
  Example:
1401
1456
 
1402
1457
  ```python
1403
1458
  >>> from PIL import Image
1404
- >>> import requests
1459
+ >>> import httpx
1460
+ >>> from io import BytesIO
1405
1461
  >>> from transformers import AutoProcessor, GlmImageForConditionalGeneration
1406
1462
 
1407
1463
  >>> model = GlmImageForConditionalGeneration.from_pretrained("zai-org/GLM-Image")
@@ -1417,7 +1473,8 @@ class GlmImageForConditionalGeneration(GlmImagePreTrainedModel, GenerationMixin)
         },
         ]
         >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
-        >>> image = Image.open(requests.get(url, stream=True).raw)
+        >>> with httpx.stream("GET", url) as response:
+        ...     image = Image.open(BytesIO(response.read()))

         >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
         >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos])
@@ -1431,6 +1488,7 @@ class GlmImageForConditionalGeneration(GlmImagePreTrainedModel, GenerationMixin)
             input_ids=input_ids,
             pixel_values=pixel_values,
            image_grid_thw=image_grid_thw,
+            images_per_sample=images_per_sample,
             position_ids=position_ids,
             attention_mask=attention_mask,
             past_key_values=past_key_values,
1469
1527
  use_cache=True,
1470
1528
  pixel_values=None,
1471
1529
  image_grid_thw=None,
1530
+ images_per_sample=None,
1472
1531
  is_first_iteration=False,
1473
1532
  **kwargs,
1474
1533
  ):
@@ -1487,6 +1546,7 @@ class GlmImageForConditionalGeneration(GlmImagePreTrainedModel, GenerationMixin)
         )

         model_inputs["position_ids"] = None
+        model_inputs["images_per_sample"] = images_per_sample

         if not is_first_iteration and use_cache:
             model_inputs["pixel_values"] = None
1523
1583
  if expand_size == 1:
1524
1584
  return input_ids, model_kwargs
1525
1585
 
1526
- visual_keys = ["pixel_values", "image_grid_thw"]
1586
+ visual_keys = ["pixel_values", "image_grid_thw", "images_per_sample"]
1527
1587
 
1528
1588
  def _expand_dict_for_generation_visual(dict_to_expand):
1529
1589
  image_grid_thw = model_kwargs.get("image_grid_thw", None)
1530
- image_nums = self._get_image_nums(input_ids)
1590
+ if image_grid_thw is None:
1591
+ return dict_to_expand
1592
+
1593
+ images_per_sample = model_kwargs.get("images_per_sample", None)
1594
+
1595
+ # Use images_per_sample if available
1596
+ if images_per_sample is not None:
1597
+ image_nums = images_per_sample.tolist()
1598
+ elif input_ids is not None:
1599
+ # Try to infer from image_grid_thw / batch_size
1600
+ batch_size = input_ids.shape[0]
1601
+ total_grids = image_grid_thw.shape[0]
1602
+ if total_grids % batch_size == 0:
1603
+ grids_per_sample = total_grids // batch_size
1604
+ image_nums = [grids_per_sample] * batch_size
1605
+ else:
1606
+ # Cannot evenly distribute grids - fall back to simple repeat_interleave
1607
+ # This handles test cases where image_grid_thw has (batch_size + 1) rows
1608
+ dict_to_expand["image_grid_thw"] = image_grid_thw.repeat_interleave(expand_size, dim=0)
1609
+ if dict_to_expand.get("pixel_values") is not None:
1610
+ dict_to_expand["pixel_values"] = dict_to_expand["pixel_values"].repeat_interleave(
1611
+ expand_size, dim=0
1612
+ )
1613
+ return dict_to_expand
1614
+ else:
1615
+ image_nums = self._get_image_nums(input_ids).tolist()
1616
+
1617
+ # Get source image counts per sample from image_end_token_id count
1618
+ source_image_nums = [
1619
+ (input_ids[batch_idx] == self.config.image_end_token_id).sum().item()
1620
+ for batch_idx in range(len(image_nums))
1621
+ ]
1531
1622
 
1532
1623
  def _repeat_interleave_samples(x, lengths, repeat_times):
1533
1624
  samples = torch.split(x, lengths)
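
Aside: counting occurrences of a sentinel token per sample, as `source_image_nums` does above, is a one-line tensor comparison. A toy sketch; the token id 9 is an assumed value for illustration:

```python
import torch

image_end_token_id = 9  # assumed sentinel id for illustration
input_ids = torch.tensor([
    [5, 9, 7, 9, 2],  # sample 0: two completed source images
    [5, 7, 2, 2, 2],  # sample 1: none
])

source_image_nums = [
    (input_ids[i] == image_end_token_id).sum().item() for i in range(input_ids.shape[0])
]
print(source_image_nums)  # [2, 0]
```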
@@ -1537,21 +1628,31 @@ class GlmImageForConditionalGeneration(GlmImagePreTrainedModel, GenerationMixin)

            for key in dict_to_expand:
                if key == "pixel_values":
-                    # split images into samples
-                    samples = torch.split(image_grid_thw[: sum(image_nums)], list(image_nums))
-                    # compute the sequence length of images for each sample
-                    lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
-                    dict_to_expand[key] = _repeat_interleave_samples(
-                        dict_to_expand[key], lengths=lengths, repeat_times=expand_size
-                    )
+                    # Split images into samples based on source image counts
+                    if sum(source_image_nums) > 0:
+                        # Split grids by sample to compute pixel counts
+                        grids_per_sample = torch.split(image_grid_thw, image_nums)
+                        lengths = []
+                        for batch_idx, sample_grids in enumerate(grids_per_sample):
+                            num_source = source_image_nums[batch_idx]
+                            if num_source > 0:
+                                source_grids = sample_grids[:num_source]
+                                lengths.append(torch.prod(source_grids, dim=1).sum().item())
+                            else:
+                                lengths.append(0)
+
+                        dict_to_expand[key] = _repeat_interleave_samples(
+                            dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+                        )
                elif key == "image_grid_thw":
-                    # get the num of images for each sample and +1 for the image being generated
-                    lengths = list(image_nums)
-                    last_image = dict_to_expand[key][-1:]
+                    # Expand all grids (source + target) per sample
                    dict_to_expand[key] = _repeat_interleave_samples(
-                        dict_to_expand[key][: sum(image_nums)], lengths=lengths, repeat_times=expand_size
+                        dict_to_expand[key], lengths=image_nums, repeat_times=expand_size
                    )
-                    dict_to_expand[key] = torch.cat([dict_to_expand[key], last_image], dim=0)
+                elif key == "images_per_sample":
+                    # Simply repeat the counts
+                    if dict_to_expand.get(key) is not None:
+                        dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
            return dict_to_expand

        def _expand_dict_for_generation(dict_to_expand):
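
Aside: the expansion helper splits a packed tensor into per-sample chunks, repeats each chunk, and re-packs, so samples contributing different numbers of rows expand correctly. The diff only shows the helper's first line, so the following is a hedged sketch of plausible semantics, not the library's actual implementation:

```python
import torch

def repeat_interleave_samples(x, lengths, repeat_times):
    # Split packed rows into per-sample chunks, repeat each chunk as a
    # block, then re-pack. Assumed behavior of _repeat_interleave_samples.
    samples = torch.split(x, lengths)
    repeated = [s.repeat(repeat_times, *([1] * (s.dim() - 1))) for s in samples]
    return torch.cat(repeated, dim=0)

# Two samples: 2 grids for sample 0, 1 grid for sample 1; expand 2x
grids = torch.tensor([[1, 3, 2], [1, 4, 4], [1, 8, 8]])
out = repeat_interleave_samples(grids, lengths=[2, 1], repeat_times=2)
print(out.tolist())
# [[1, 3, 2], [1, 4, 4], [1, 3, 2], [1, 4, 4], [1, 8, 8], [1, 8, 8]]
```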