transformers-5.0.0rc2-py3-none-any.whl → transformers-5.1.0-py3-none-any.whl

This diff shows the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (1594)
  1. transformers/__init__.py +11 -37
  2. transformers/activations.py +2 -2
  3. transformers/audio_utils.py +32 -32
  4. transformers/backbone_utils.py +326 -0
  5. transformers/cache_utils.py +26 -126
  6. transformers/cli/chat.py +3 -3
  7. transformers/cli/serve.py +13 -10
  8. transformers/cli/transformers.py +2 -1
  9. transformers/configuration_utils.py +22 -92
  10. transformers/conversion_mapping.py +150 -26
  11. transformers/convert_slow_tokenizer.py +9 -12
  12. transformers/core_model_loading.py +217 -129
  13. transformers/data/processors/glue.py +0 -1
  14. transformers/data/processors/utils.py +0 -1
  15. transformers/data/processors/xnli.py +0 -1
  16. transformers/dependency_versions_check.py +0 -1
  17. transformers/dependency_versions_table.py +10 -11
  18. transformers/distributed/configuration_utils.py +1 -2
  19. transformers/dynamic_module_utils.py +23 -23
  20. transformers/feature_extraction_sequence_utils.py +19 -23
  21. transformers/feature_extraction_utils.py +14 -14
  22. transformers/file_utils.py +0 -2
  23. transformers/generation/candidate_generator.py +2 -4
  24. transformers/generation/configuration_utils.py +54 -39
  25. transformers/generation/continuous_batching/__init__.py +0 -1
  26. transformers/generation/continuous_batching/cache.py +74 -44
  27. transformers/generation/continuous_batching/cache_manager.py +28 -28
  28. transformers/generation/continuous_batching/continuous_api.py +133 -414
  29. transformers/generation/continuous_batching/input_ouputs.py +464 -0
  30. transformers/generation/continuous_batching/requests.py +77 -19
  31. transformers/generation/continuous_batching/scheduler.py +154 -104
  32. transformers/generation/logits_process.py +10 -133
  33. transformers/generation/stopping_criteria.py +1 -2
  34. transformers/generation/streamers.py +0 -1
  35. transformers/generation/utils.py +91 -121
  36. transformers/generation/watermarking.py +2 -3
  37. transformers/hf_argparser.py +9 -13
  38. transformers/hyperparameter_search.py +1 -2
  39. transformers/image_processing_base.py +9 -9
  40. transformers/image_processing_utils.py +11 -15
  41. transformers/image_processing_utils_fast.py +70 -71
  42. transformers/image_transforms.py +73 -42
  43. transformers/image_utils.py +30 -37
  44. transformers/initialization.py +57 -0
  45. transformers/integrations/__init__.py +10 -24
  46. transformers/integrations/accelerate.py +47 -11
  47. transformers/integrations/awq.py +1 -3
  48. transformers/integrations/deepspeed.py +146 -4
  49. transformers/integrations/eetq.py +0 -1
  50. transformers/integrations/executorch.py +2 -6
  51. transformers/integrations/fbgemm_fp8.py +1 -2
  52. transformers/integrations/finegrained_fp8.py +149 -13
  53. transformers/integrations/flash_attention.py +3 -8
  54. transformers/integrations/flex_attention.py +1 -1
  55. transformers/integrations/fp_quant.py +4 -6
  56. transformers/integrations/ggml.py +0 -1
  57. transformers/integrations/hub_kernels.py +18 -7
  58. transformers/integrations/integration_utils.py +2 -3
  59. transformers/integrations/moe.py +226 -106
  60. transformers/integrations/mxfp4.py +52 -40
  61. transformers/integrations/peft.py +488 -176
  62. transformers/integrations/quark.py +2 -4
  63. transformers/integrations/tensor_parallel.py +641 -581
  64. transformers/integrations/torchao.py +4 -6
  65. transformers/loss/loss_lw_detr.py +356 -0
  66. transformers/loss/loss_utils.py +2 -0
  67. transformers/masking_utils.py +199 -59
  68. transformers/model_debugging_utils.py +4 -5
  69. transformers/modelcard.py +14 -192
  70. transformers/modeling_attn_mask_utils.py +19 -19
  71. transformers/modeling_flash_attention_utils.py +28 -29
  72. transformers/modeling_gguf_pytorch_utils.py +5 -5
  73. transformers/modeling_layers.py +21 -22
  74. transformers/modeling_outputs.py +242 -253
  75. transformers/modeling_rope_utils.py +32 -32
  76. transformers/modeling_utils.py +416 -438
  77. transformers/models/__init__.py +10 -0
  78. transformers/models/afmoe/configuration_afmoe.py +40 -33
  79. transformers/models/afmoe/modeling_afmoe.py +38 -41
  80. transformers/models/afmoe/modular_afmoe.py +23 -25
  81. transformers/models/aimv2/configuration_aimv2.py +2 -10
  82. transformers/models/aimv2/modeling_aimv2.py +46 -45
  83. transformers/models/aimv2/modular_aimv2.py +13 -19
  84. transformers/models/albert/configuration_albert.py +8 -2
  85. transformers/models/albert/modeling_albert.py +70 -72
  86. transformers/models/albert/tokenization_albert.py +1 -4
  87. transformers/models/align/configuration_align.py +8 -6
  88. transformers/models/align/modeling_align.py +83 -86
  89. transformers/models/align/processing_align.py +2 -30
  90. transformers/models/altclip/configuration_altclip.py +4 -7
  91. transformers/models/altclip/modeling_altclip.py +106 -103
  92. transformers/models/altclip/processing_altclip.py +2 -15
  93. transformers/models/apertus/__init__.py +0 -1
  94. transformers/models/apertus/configuration_apertus.py +23 -28
  95. transformers/models/apertus/modeling_apertus.py +35 -38
  96. transformers/models/apertus/modular_apertus.py +36 -40
  97. transformers/models/arcee/configuration_arcee.py +25 -30
  98. transformers/models/arcee/modeling_arcee.py +35 -38
  99. transformers/models/arcee/modular_arcee.py +20 -23
  100. transformers/models/aria/configuration_aria.py +31 -44
  101. transformers/models/aria/image_processing_aria.py +25 -27
  102. transformers/models/aria/modeling_aria.py +102 -102
  103. transformers/models/aria/modular_aria.py +111 -124
  104. transformers/models/aria/processing_aria.py +28 -35
  105. transformers/models/audio_spectrogram_transformer/configuration_audio_spectrogram_transformer.py +0 -1
  106. transformers/models/audio_spectrogram_transformer/feature_extraction_audio_spectrogram_transformer.py +3 -6
  107. transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +9 -11
  108. transformers/models/audioflamingo3/__init__.py +0 -1
  109. transformers/models/audioflamingo3/configuration_audioflamingo3.py +0 -1
  110. transformers/models/audioflamingo3/modeling_audioflamingo3.py +60 -52
  111. transformers/models/audioflamingo3/modular_audioflamingo3.py +52 -43
  112. transformers/models/audioflamingo3/processing_audioflamingo3.py +6 -8
  113. transformers/models/auto/auto_factory.py +12 -11
  114. transformers/models/auto/configuration_auto.py +48 -5
  115. transformers/models/auto/feature_extraction_auto.py +5 -7
  116. transformers/models/auto/image_processing_auto.py +30 -39
  117. transformers/models/auto/modeling_auto.py +33 -199
  118. transformers/models/auto/processing_auto.py +11 -19
  119. transformers/models/auto/tokenization_auto.py +38 -37
  120. transformers/models/auto/video_processing_auto.py +7 -8
  121. transformers/models/autoformer/configuration_autoformer.py +4 -7
  122. transformers/models/autoformer/modeling_autoformer.py +100 -101
  123. transformers/models/aya_vision/configuration_aya_vision.py +4 -1
  124. transformers/models/aya_vision/modeling_aya_vision.py +64 -99
  125. transformers/models/aya_vision/modular_aya_vision.py +46 -74
  126. transformers/models/aya_vision/processing_aya_vision.py +25 -53
  127. transformers/models/bamba/configuration_bamba.py +46 -39
  128. transformers/models/bamba/modeling_bamba.py +83 -119
  129. transformers/models/bamba/modular_bamba.py +70 -109
  130. transformers/models/bark/configuration_bark.py +6 -8
  131. transformers/models/bark/generation_configuration_bark.py +3 -5
  132. transformers/models/bark/modeling_bark.py +64 -65
  133. transformers/models/bark/processing_bark.py +19 -41
  134. transformers/models/bart/configuration_bart.py +9 -5
  135. transformers/models/bart/modeling_bart.py +124 -129
  136. transformers/models/barthez/tokenization_barthez.py +1 -4
  137. transformers/models/bartpho/tokenization_bartpho.py +6 -7
  138. transformers/models/beit/configuration_beit.py +2 -15
  139. transformers/models/beit/image_processing_beit.py +53 -56
  140. transformers/models/beit/image_processing_beit_fast.py +11 -12
  141. transformers/models/beit/modeling_beit.py +65 -62
  142. transformers/models/bert/configuration_bert.py +12 -2
  143. transformers/models/bert/modeling_bert.py +117 -152
  144. transformers/models/bert/tokenization_bert.py +2 -4
  145. transformers/models/bert/tokenization_bert_legacy.py +3 -5
  146. transformers/models/bert_generation/configuration_bert_generation.py +17 -2
  147. transformers/models/bert_generation/modeling_bert_generation.py +53 -55
  148. transformers/models/bert_generation/tokenization_bert_generation.py +2 -3
  149. transformers/models/bert_japanese/tokenization_bert_japanese.py +5 -6
  150. transformers/models/bertweet/tokenization_bertweet.py +1 -3
  151. transformers/models/big_bird/configuration_big_bird.py +12 -9
  152. transformers/models/big_bird/modeling_big_bird.py +107 -124
  153. transformers/models/big_bird/tokenization_big_bird.py +1 -4
  154. transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py +9 -9
  155. transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +118 -118
  156. transformers/models/biogpt/configuration_biogpt.py +8 -2
  157. transformers/models/biogpt/modeling_biogpt.py +73 -79
  158. transformers/models/biogpt/modular_biogpt.py +60 -66
  159. transformers/models/biogpt/tokenization_biogpt.py +3 -5
  160. transformers/models/bit/configuration_bit.py +2 -5
  161. transformers/models/bit/image_processing_bit.py +21 -24
  162. transformers/models/bit/image_processing_bit_fast.py +0 -1
  163. transformers/models/bit/modeling_bit.py +15 -16
  164. transformers/models/bitnet/configuration_bitnet.py +23 -28
  165. transformers/models/bitnet/modeling_bitnet.py +34 -38
  166. transformers/models/bitnet/modular_bitnet.py +7 -10
  167. transformers/models/blenderbot/configuration_blenderbot.py +8 -5
  168. transformers/models/blenderbot/modeling_blenderbot.py +68 -99
  169. transformers/models/blenderbot/tokenization_blenderbot.py +0 -1
  170. transformers/models/blenderbot_small/configuration_blenderbot_small.py +8 -5
  171. transformers/models/blenderbot_small/modeling_blenderbot_small.py +70 -72
  172. transformers/models/blenderbot_small/tokenization_blenderbot_small.py +1 -3
  173. transformers/models/blip/configuration_blip.py +9 -10
  174. transformers/models/blip/image_processing_blip.py +17 -20
  175. transformers/models/blip/image_processing_blip_fast.py +0 -1
  176. transformers/models/blip/modeling_blip.py +115 -108
  177. transformers/models/blip/modeling_blip_text.py +63 -65
  178. transformers/models/blip/processing_blip.py +5 -36
  179. transformers/models/blip_2/configuration_blip_2.py +2 -2
  180. transformers/models/blip_2/modeling_blip_2.py +145 -121
  181. transformers/models/blip_2/processing_blip_2.py +8 -38
  182. transformers/models/bloom/configuration_bloom.py +5 -2
  183. transformers/models/bloom/modeling_bloom.py +60 -60
  184. transformers/models/blt/configuration_blt.py +94 -86
  185. transformers/models/blt/modeling_blt.py +93 -90
  186. transformers/models/blt/modular_blt.py +127 -69
  187. transformers/models/bridgetower/configuration_bridgetower.py +7 -2
  188. transformers/models/bridgetower/image_processing_bridgetower.py +34 -35
  189. transformers/models/bridgetower/image_processing_bridgetower_fast.py +13 -14
  190. transformers/models/bridgetower/modeling_bridgetower.py +136 -124
  191. transformers/models/bridgetower/processing_bridgetower.py +2 -16
  192. transformers/models/bros/configuration_bros.py +24 -18
  193. transformers/models/bros/modeling_bros.py +78 -80
  194. transformers/models/bros/processing_bros.py +2 -12
  195. transformers/models/byt5/tokenization_byt5.py +4 -6
  196. transformers/models/camembert/configuration_camembert.py +8 -2
  197. transformers/models/camembert/modeling_camembert.py +97 -99
  198. transformers/models/camembert/modular_camembert.py +51 -54
  199. transformers/models/camembert/tokenization_camembert.py +1 -4
  200. transformers/models/canine/configuration_canine.py +4 -2
  201. transformers/models/canine/modeling_canine.py +73 -75
  202. transformers/models/canine/tokenization_canine.py +0 -1
  203. transformers/models/chameleon/configuration_chameleon.py +29 -34
  204. transformers/models/chameleon/image_processing_chameleon.py +21 -24
  205. transformers/models/chameleon/image_processing_chameleon_fast.py +5 -6
  206. transformers/models/chameleon/modeling_chameleon.py +135 -92
  207. transformers/models/chameleon/processing_chameleon.py +16 -41
  208. transformers/models/chinese_clip/configuration_chinese_clip.py +10 -8
  209. transformers/models/chinese_clip/image_processing_chinese_clip.py +21 -24
  210. transformers/models/chinese_clip/image_processing_chinese_clip_fast.py +0 -1
  211. transformers/models/chinese_clip/modeling_chinese_clip.py +93 -95
  212. transformers/models/chinese_clip/processing_chinese_clip.py +2 -15
  213. transformers/models/clap/configuration_clap.py +4 -9
  214. transformers/models/clap/feature_extraction_clap.py +9 -10
  215. transformers/models/clap/modeling_clap.py +109 -111
  216. transformers/models/clap/processing_clap.py +2 -15
  217. transformers/models/clip/configuration_clip.py +4 -2
  218. transformers/models/clip/image_processing_clip.py +21 -24
  219. transformers/models/clip/image_processing_clip_fast.py +9 -1
  220. transformers/models/clip/modeling_clip.py +70 -68
  221. transformers/models/clip/processing_clip.py +2 -14
  222. transformers/models/clip/tokenization_clip.py +2 -5
  223. transformers/models/clipseg/configuration_clipseg.py +4 -2
  224. transformers/models/clipseg/modeling_clipseg.py +113 -112
  225. transformers/models/clipseg/processing_clipseg.py +19 -42
  226. transformers/models/clvp/configuration_clvp.py +15 -5
  227. transformers/models/clvp/feature_extraction_clvp.py +7 -10
  228. transformers/models/clvp/modeling_clvp.py +138 -145
  229. transformers/models/clvp/number_normalizer.py +1 -2
  230. transformers/models/clvp/processing_clvp.py +3 -20
  231. transformers/models/clvp/tokenization_clvp.py +0 -1
  232. transformers/models/code_llama/tokenization_code_llama.py +3 -6
  233. transformers/models/codegen/configuration_codegen.py +4 -4
  234. transformers/models/codegen/modeling_codegen.py +50 -49
  235. transformers/models/codegen/tokenization_codegen.py +5 -6
  236. transformers/models/cohere/configuration_cohere.py +25 -30
  237. transformers/models/cohere/modeling_cohere.py +39 -42
  238. transformers/models/cohere/modular_cohere.py +27 -31
  239. transformers/models/cohere/tokenization_cohere.py +5 -6
  240. transformers/models/cohere2/configuration_cohere2.py +27 -32
  241. transformers/models/cohere2/modeling_cohere2.py +38 -41
  242. transformers/models/cohere2/modular_cohere2.py +48 -52
  243. transformers/models/cohere2_vision/configuration_cohere2_vision.py +5 -1
  244. transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +9 -10
  245. transformers/models/cohere2_vision/modeling_cohere2_vision.py +52 -55
  246. transformers/models/cohere2_vision/modular_cohere2_vision.py +41 -43
  247. transformers/models/cohere2_vision/processing_cohere2_vision.py +6 -36
  248. transformers/models/colpali/configuration_colpali.py +0 -1
  249. transformers/models/colpali/modeling_colpali.py +14 -16
  250. transformers/models/colpali/modular_colpali.py +11 -51
  251. transformers/models/colpali/processing_colpali.py +14 -52
  252. transformers/models/colqwen2/modeling_colqwen2.py +27 -28
  253. transformers/models/colqwen2/modular_colqwen2.py +36 -74
  254. transformers/models/colqwen2/processing_colqwen2.py +16 -52
  255. transformers/models/conditional_detr/configuration_conditional_detr.py +19 -47
  256. transformers/models/conditional_detr/image_processing_conditional_detr.py +67 -70
  257. transformers/models/conditional_detr/image_processing_conditional_detr_fast.py +50 -36
  258. transformers/models/conditional_detr/modeling_conditional_detr.py +851 -1001
  259. transformers/models/conditional_detr/modular_conditional_detr.py +901 -5
  260. transformers/models/convbert/configuration_convbert.py +11 -8
  261. transformers/models/convbert/modeling_convbert.py +85 -87
  262. transformers/models/convbert/tokenization_convbert.py +0 -1
  263. transformers/models/convnext/configuration_convnext.py +2 -5
  264. transformers/models/convnext/image_processing_convnext.py +18 -21
  265. transformers/models/convnext/image_processing_convnext_fast.py +7 -8
  266. transformers/models/convnext/modeling_convnext.py +12 -14
  267. transformers/models/convnextv2/configuration_convnextv2.py +2 -5
  268. transformers/models/convnextv2/modeling_convnextv2.py +12 -14
  269. transformers/models/cpm/tokenization_cpm.py +6 -7
  270. transformers/models/cpm/tokenization_cpm_fast.py +3 -5
  271. transformers/models/cpmant/configuration_cpmant.py +4 -1
  272. transformers/models/cpmant/modeling_cpmant.py +38 -40
  273. transformers/models/cpmant/tokenization_cpmant.py +1 -3
  274. transformers/models/csm/configuration_csm.py +58 -66
  275. transformers/models/csm/generation_csm.py +13 -14
  276. transformers/models/csm/modeling_csm.py +81 -84
  277. transformers/models/csm/modular_csm.py +56 -58
  278. transformers/models/csm/processing_csm.py +25 -68
  279. transformers/models/ctrl/configuration_ctrl.py +16 -1
  280. transformers/models/ctrl/modeling_ctrl.py +51 -66
  281. transformers/models/ctrl/tokenization_ctrl.py +0 -1
  282. transformers/models/cvt/configuration_cvt.py +0 -1
  283. transformers/models/cvt/modeling_cvt.py +13 -15
  284. transformers/models/cwm/__init__.py +0 -1
  285. transformers/models/cwm/configuration_cwm.py +8 -12
  286. transformers/models/cwm/modeling_cwm.py +36 -38
  287. transformers/models/cwm/modular_cwm.py +10 -12
  288. transformers/models/d_fine/configuration_d_fine.py +10 -57
  289. transformers/models/d_fine/modeling_d_fine.py +786 -927
  290. transformers/models/d_fine/modular_d_fine.py +339 -417
  291. transformers/models/dab_detr/configuration_dab_detr.py +22 -49
  292. transformers/models/dab_detr/modeling_dab_detr.py +79 -77
  293. transformers/models/dac/configuration_dac.py +0 -1
  294. transformers/models/dac/feature_extraction_dac.py +6 -9
  295. transformers/models/dac/modeling_dac.py +22 -24
  296. transformers/models/data2vec/configuration_data2vec_audio.py +4 -2
  297. transformers/models/data2vec/configuration_data2vec_text.py +11 -3
  298. transformers/models/data2vec/configuration_data2vec_vision.py +0 -1
  299. transformers/models/data2vec/modeling_data2vec_audio.py +55 -59
  300. transformers/models/data2vec/modeling_data2vec_text.py +97 -99
  301. transformers/models/data2vec/modeling_data2vec_vision.py +45 -44
  302. transformers/models/data2vec/modular_data2vec_audio.py +6 -1
  303. transformers/models/data2vec/modular_data2vec_text.py +51 -54
  304. transformers/models/dbrx/configuration_dbrx.py +29 -22
  305. transformers/models/dbrx/modeling_dbrx.py +45 -48
  306. transformers/models/dbrx/modular_dbrx.py +37 -39
  307. transformers/models/deberta/configuration_deberta.py +6 -1
  308. transformers/models/deberta/modeling_deberta.py +57 -60
  309. transformers/models/deberta/tokenization_deberta.py +2 -5
  310. transformers/models/deberta_v2/configuration_deberta_v2.py +6 -1
  311. transformers/models/deberta_v2/modeling_deberta_v2.py +63 -65
  312. transformers/models/deberta_v2/tokenization_deberta_v2.py +1 -4
  313. transformers/models/decision_transformer/configuration_decision_transformer.py +3 -2
  314. transformers/models/decision_transformer/modeling_decision_transformer.py +51 -53
  315. transformers/models/deepseek_v2/configuration_deepseek_v2.py +41 -47
  316. transformers/models/deepseek_v2/modeling_deepseek_v2.py +39 -41
  317. transformers/models/deepseek_v2/modular_deepseek_v2.py +48 -52
  318. transformers/models/deepseek_v3/configuration_deepseek_v3.py +42 -48
  319. transformers/models/deepseek_v3/modeling_deepseek_v3.py +38 -40
  320. transformers/models/deepseek_v3/modular_deepseek_v3.py +10 -10
  321. transformers/models/deepseek_vl/configuration_deepseek_vl.py +6 -3
  322. transformers/models/deepseek_vl/image_processing_deepseek_vl.py +27 -28
  323. transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +12 -11
  324. transformers/models/deepseek_vl/modeling_deepseek_vl.py +48 -43
  325. transformers/models/deepseek_vl/modular_deepseek_vl.py +15 -43
  326. transformers/models/deepseek_vl/processing_deepseek_vl.py +10 -41
  327. transformers/models/deepseek_vl_hybrid/configuration_deepseek_vl_hybrid.py +7 -5
  328. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py +37 -37
  329. transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +22 -22
  330. transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +100 -56
  331. transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +141 -109
  332. transformers/models/deepseek_vl_hybrid/processing_deepseek_vl_hybrid.py +12 -44
  333. transformers/models/deformable_detr/configuration_deformable_detr.py +22 -46
  334. transformers/models/deformable_detr/image_processing_deformable_detr.py +59 -61
  335. transformers/models/deformable_detr/image_processing_deformable_detr_fast.py +42 -28
  336. transformers/models/deformable_detr/modeling_deformable_detr.py +454 -652
  337. transformers/models/deformable_detr/modular_deformable_detr.py +1385 -5
  338. transformers/models/deit/configuration_deit.py +0 -1
  339. transformers/models/deit/image_processing_deit.py +18 -21
  340. transformers/models/deit/image_processing_deit_fast.py +0 -1
  341. transformers/models/deit/modeling_deit.py +27 -25
  342. transformers/models/depth_anything/configuration_depth_anything.py +12 -43
  343. transformers/models/depth_anything/modeling_depth_anything.py +10 -11
  344. transformers/models/depth_pro/configuration_depth_pro.py +0 -1
  345. transformers/models/depth_pro/image_processing_depth_pro.py +22 -23
  346. transformers/models/depth_pro/image_processing_depth_pro_fast.py +8 -9
  347. transformers/models/depth_pro/modeling_depth_pro.py +29 -27
  348. transformers/models/detr/configuration_detr.py +18 -50
  349. transformers/models/detr/image_processing_detr.py +64 -66
  350. transformers/models/detr/image_processing_detr_fast.py +33 -34
  351. transformers/models/detr/modeling_detr.py +748 -789
  352. transformers/models/dia/configuration_dia.py +9 -15
  353. transformers/models/dia/feature_extraction_dia.py +6 -9
  354. transformers/models/dia/generation_dia.py +48 -53
  355. transformers/models/dia/modeling_dia.py +68 -71
  356. transformers/models/dia/modular_dia.py +56 -58
  357. transformers/models/dia/processing_dia.py +39 -29
  358. transformers/models/dia/tokenization_dia.py +3 -6
  359. transformers/models/diffllama/configuration_diffllama.py +25 -30
  360. transformers/models/diffllama/modeling_diffllama.py +45 -53
  361. transformers/models/diffllama/modular_diffllama.py +18 -25
  362. transformers/models/dinat/configuration_dinat.py +2 -5
  363. transformers/models/dinat/modeling_dinat.py +47 -48
  364. transformers/models/dinov2/configuration_dinov2.py +2 -5
  365. transformers/models/dinov2/modeling_dinov2.py +20 -21
  366. transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py +3 -5
  367. transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +21 -21
  368. transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py +11 -14
  369. transformers/models/dinov3_convnext/configuration_dinov3_convnext.py +6 -11
  370. transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +5 -9
  371. transformers/models/dinov3_vit/configuration_dinov3_vit.py +7 -12
  372. transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +7 -8
  373. transformers/models/dinov3_vit/modeling_dinov3_vit.py +19 -22
  374. transformers/models/dinov3_vit/modular_dinov3_vit.py +16 -19
  375. transformers/models/distilbert/configuration_distilbert.py +8 -2
  376. transformers/models/distilbert/modeling_distilbert.py +47 -49
  377. transformers/models/distilbert/tokenization_distilbert.py +0 -1
  378. transformers/models/doge/__init__.py +0 -1
  379. transformers/models/doge/configuration_doge.py +42 -35
  380. transformers/models/doge/modeling_doge.py +46 -49
  381. transformers/models/doge/modular_doge.py +77 -68
  382. transformers/models/donut/configuration_donut_swin.py +0 -1
  383. transformers/models/donut/image_processing_donut.py +26 -29
  384. transformers/models/donut/image_processing_donut_fast.py +9 -14
  385. transformers/models/donut/modeling_donut_swin.py +44 -46
  386. transformers/models/donut/processing_donut.py +5 -26
  387. transformers/models/dots1/configuration_dots1.py +43 -36
  388. transformers/models/dots1/modeling_dots1.py +35 -38
  389. transformers/models/dots1/modular_dots1.py +0 -1
  390. transformers/models/dpr/configuration_dpr.py +19 -2
  391. transformers/models/dpr/modeling_dpr.py +37 -39
  392. transformers/models/dpr/tokenization_dpr.py +7 -9
  393. transformers/models/dpr/tokenization_dpr_fast.py +7 -9
  394. transformers/models/dpt/configuration_dpt.py +23 -66
  395. transformers/models/dpt/image_processing_dpt.py +65 -66
  396. transformers/models/dpt/image_processing_dpt_fast.py +18 -19
  397. transformers/models/dpt/modeling_dpt.py +38 -36
  398. transformers/models/dpt/modular_dpt.py +14 -15
  399. transformers/models/edgetam/configuration_edgetam.py +1 -2
  400. transformers/models/edgetam/modeling_edgetam.py +87 -89
  401. transformers/models/edgetam/modular_edgetam.py +7 -13
  402. transformers/models/edgetam_video/__init__.py +0 -1
  403. transformers/models/edgetam_video/configuration_edgetam_video.py +0 -1
  404. transformers/models/edgetam_video/modeling_edgetam_video.py +126 -128
  405. transformers/models/edgetam_video/modular_edgetam_video.py +25 -27
  406. transformers/models/efficientloftr/configuration_efficientloftr.py +4 -5
  407. transformers/models/efficientloftr/image_processing_efficientloftr.py +14 -16
  408. transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +8 -7
  409. transformers/models/efficientloftr/modeling_efficientloftr.py +46 -38
  410. transformers/models/efficientloftr/modular_efficientloftr.py +1 -3
  411. transformers/models/efficientnet/configuration_efficientnet.py +0 -1
  412. transformers/models/efficientnet/image_processing_efficientnet.py +23 -26
  413. transformers/models/efficientnet/image_processing_efficientnet_fast.py +16 -17
  414. transformers/models/efficientnet/modeling_efficientnet.py +12 -14
  415. transformers/models/electra/configuration_electra.py +13 -3
  416. transformers/models/electra/modeling_electra.py +107 -109
  417. transformers/models/emu3/configuration_emu3.py +17 -17
  418. transformers/models/emu3/image_processing_emu3.py +44 -39
  419. transformers/models/emu3/modeling_emu3.py +143 -109
  420. transformers/models/emu3/modular_emu3.py +109 -73
  421. transformers/models/emu3/processing_emu3.py +18 -43
  422. transformers/models/encodec/configuration_encodec.py +2 -4
  423. transformers/models/encodec/feature_extraction_encodec.py +10 -13
  424. transformers/models/encodec/modeling_encodec.py +25 -29
  425. transformers/models/encoder_decoder/configuration_encoder_decoder.py +12 -2
  426. transformers/models/encoder_decoder/modeling_encoder_decoder.py +37 -43
  427. transformers/models/eomt/configuration_eomt.py +12 -14
  428. transformers/models/eomt/image_processing_eomt.py +53 -55
  429. transformers/models/eomt/image_processing_eomt_fast.py +18 -19
  430. transformers/models/eomt/modeling_eomt.py +19 -21
  431. transformers/models/eomt/modular_eomt.py +28 -30
  432. transformers/models/eomt_dinov3/__init__.py +28 -0
  433. transformers/models/eomt_dinov3/configuration_eomt_dinov3.py +204 -0
  434. transformers/models/eomt_dinov3/modeling_eomt_dinov3.py +1376 -0
  435. transformers/models/eomt_dinov3/modular_eomt_dinov3.py +454 -0
  436. transformers/models/ernie/configuration_ernie.py +24 -3
  437. transformers/models/ernie/modeling_ernie.py +127 -162
  438. transformers/models/ernie/modular_ernie.py +91 -103
  439. transformers/models/ernie4_5/configuration_ernie4_5.py +23 -27
  440. transformers/models/ernie4_5/modeling_ernie4_5.py +35 -37
  441. transformers/models/ernie4_5/modular_ernie4_5.py +1 -3
  442. transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py +34 -39
  443. transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +40 -42
  444. transformers/models/ernie4_5_moe/modular_ernie4_5_moe.py +7 -9
  445. transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +17 -7
  446. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe.py +34 -35
  447. transformers/models/ernie4_5_vl_moe/image_processing_ernie4_5_vl_moe_fast.py +6 -7
  448. transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +305 -267
  449. transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +163 -142
  450. transformers/models/ernie4_5_vl_moe/processing_ernie4_5_vl_moe.py +3 -5
  451. transformers/models/ernie4_5_vl_moe/video_processing_ernie4_5_vl_moe.py +17 -18
  452. transformers/models/esm/configuration_esm.py +11 -15
  453. transformers/models/esm/modeling_esm.py +35 -37
  454. transformers/models/esm/modeling_esmfold.py +43 -50
  455. transformers/models/esm/openfold_utils/chunk_utils.py +6 -6
  456. transformers/models/esm/openfold_utils/loss.py +1 -2
  457. transformers/models/esm/openfold_utils/protein.py +15 -16
  458. transformers/models/esm/openfold_utils/tensor_utils.py +6 -6
  459. transformers/models/esm/tokenization_esm.py +2 -4
  460. transformers/models/evolla/configuration_evolla.py +50 -40
  461. transformers/models/evolla/modeling_evolla.py +69 -68
  462. transformers/models/evolla/modular_evolla.py +50 -48
  463. transformers/models/evolla/processing_evolla.py +23 -35
  464. transformers/models/exaone4/configuration_exaone4.py +27 -27
  465. transformers/models/exaone4/modeling_exaone4.py +36 -39
  466. transformers/models/exaone4/modular_exaone4.py +51 -50
  467. transformers/models/exaone_moe/__init__.py +27 -0
  468. transformers/models/exaone_moe/configuration_exaone_moe.py +235 -0
  469. transformers/models/exaone_moe/modeling_exaone_moe.py +665 -0
  470. transformers/models/exaone_moe/modular_exaone_moe.py +373 -0
  471. transformers/models/falcon/configuration_falcon.py +31 -26
  472. transformers/models/falcon/modeling_falcon.py +76 -84
  473. transformers/models/falcon_h1/configuration_falcon_h1.py +57 -51
  474. transformers/models/falcon_h1/modeling_falcon_h1.py +74 -109
  475. transformers/models/falcon_h1/modular_falcon_h1.py +68 -100
  476. transformers/models/falcon_mamba/configuration_falcon_mamba.py +5 -2
  477. transformers/models/falcon_mamba/modeling_falcon_mamba.py +64 -73
  478. transformers/models/falcon_mamba/modular_falcon_mamba.py +14 -13
  479. transformers/models/fast_vlm/configuration_fast_vlm.py +10 -0
  480. transformers/models/fast_vlm/modeling_fast_vlm.py +70 -97
  481. transformers/models/fast_vlm/modular_fast_vlm.py +148 -38
  482. transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py +2 -6
  483. transformers/models/fastspeech2_conformer/modeling_fastspeech2_conformer.py +45 -47
  484. transformers/models/fastspeech2_conformer/tokenization_fastspeech2_conformer.py +1 -3
  485. transformers/models/flaubert/configuration_flaubert.py +10 -5
  486. transformers/models/flaubert/modeling_flaubert.py +125 -129
  487. transformers/models/flaubert/tokenization_flaubert.py +3 -5
  488. transformers/models/flava/configuration_flava.py +9 -9
  489. transformers/models/flava/image_processing_flava.py +66 -67
  490. transformers/models/flava/image_processing_flava_fast.py +46 -47
  491. transformers/models/flava/modeling_flava.py +144 -135
  492. transformers/models/flava/processing_flava.py +2 -12
  493. transformers/models/flex_olmo/__init__.py +0 -1
  494. transformers/models/flex_olmo/configuration_flex_olmo.py +34 -39
  495. transformers/models/flex_olmo/modeling_flex_olmo.py +41 -43
  496. transformers/models/flex_olmo/modular_flex_olmo.py +46 -51
  497. transformers/models/florence2/configuration_florence2.py +4 -1
  498. transformers/models/florence2/modeling_florence2.py +96 -72
  499. transformers/models/florence2/modular_florence2.py +100 -107
  500. transformers/models/florence2/processing_florence2.py +18 -47
  501. transformers/models/fnet/configuration_fnet.py +6 -2
  502. transformers/models/fnet/modeling_fnet.py +69 -80
  503. transformers/models/fnet/tokenization_fnet.py +0 -1
  504. transformers/models/focalnet/configuration_focalnet.py +2 -5
  505. transformers/models/focalnet/modeling_focalnet.py +49 -48
  506. transformers/models/fsmt/configuration_fsmt.py +12 -17
  507. transformers/models/fsmt/modeling_fsmt.py +47 -48
  508. transformers/models/fsmt/tokenization_fsmt.py +3 -5
  509. transformers/models/funnel/configuration_funnel.py +8 -1
  510. transformers/models/funnel/modeling_funnel.py +91 -93
  511. transformers/models/funnel/tokenization_funnel.py +2 -5
  512. transformers/models/fuyu/configuration_fuyu.py +28 -34
  513. transformers/models/fuyu/image_processing_fuyu.py +29 -31
  514. transformers/models/fuyu/image_processing_fuyu_fast.py +17 -17
  515. transformers/models/fuyu/modeling_fuyu.py +50 -52
  516. transformers/models/fuyu/processing_fuyu.py +9 -36
  517. transformers/models/gemma/configuration_gemma.py +25 -30
  518. transformers/models/gemma/modeling_gemma.py +36 -38
  519. transformers/models/gemma/modular_gemma.py +33 -36
  520. transformers/models/gemma/tokenization_gemma.py +3 -6
  521. transformers/models/gemma2/configuration_gemma2.py +30 -35
  522. transformers/models/gemma2/modeling_gemma2.py +38 -41
  523. transformers/models/gemma2/modular_gemma2.py +63 -67
  524. transformers/models/gemma3/configuration_gemma3.py +53 -48
  525. transformers/models/gemma3/image_processing_gemma3.py +29 -31
  526. transformers/models/gemma3/image_processing_gemma3_fast.py +11 -12
  527. transformers/models/gemma3/modeling_gemma3.py +123 -122
  528. transformers/models/gemma3/modular_gemma3.py +128 -125
  529. transformers/models/gemma3/processing_gemma3.py +5 -5
  530. transformers/models/gemma3n/configuration_gemma3n.py +42 -30
  531. transformers/models/gemma3n/feature_extraction_gemma3n.py +9 -11
  532. transformers/models/gemma3n/modeling_gemma3n.py +166 -147
  533. transformers/models/gemma3n/modular_gemma3n.py +176 -148
  534. transformers/models/gemma3n/processing_gemma3n.py +12 -26
  535. transformers/models/git/configuration_git.py +5 -8
  536. transformers/models/git/modeling_git.py +115 -127
  537. transformers/models/git/processing_git.py +2 -14
  538. transformers/models/glm/configuration_glm.py +26 -30
  539. transformers/models/glm/modeling_glm.py +36 -39
  540. transformers/models/glm/modular_glm.py +4 -7
  541. transformers/models/glm4/configuration_glm4.py +26 -30
  542. transformers/models/glm4/modeling_glm4.py +39 -41
  543. transformers/models/glm4/modular_glm4.py +8 -10
  544. transformers/models/glm46v/configuration_glm46v.py +4 -1
  545. transformers/models/glm46v/image_processing_glm46v.py +40 -38
  546. transformers/models/glm46v/image_processing_glm46v_fast.py +9 -9
  547. transformers/models/glm46v/modeling_glm46v.py +138 -93
  548. transformers/models/glm46v/modular_glm46v.py +5 -3
  549. transformers/models/glm46v/processing_glm46v.py +7 -41
  550. transformers/models/glm46v/video_processing_glm46v.py +9 -11
  551. transformers/models/glm4_moe/configuration_glm4_moe.py +42 -35
  552. transformers/models/glm4_moe/modeling_glm4_moe.py +36 -39
  553. transformers/models/glm4_moe/modular_glm4_moe.py +43 -36
  554. transformers/models/glm4_moe_lite/__init__.py +28 -0
  555. transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py +233 -0
  556. transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py +740 -0
  557. transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py +302 -0
  558. transformers/models/glm4v/configuration_glm4v.py +25 -24
  559. transformers/models/glm4v/image_processing_glm4v.py +39 -38
  560. transformers/models/glm4v/image_processing_glm4v_fast.py +8 -9
  561. transformers/models/glm4v/modeling_glm4v.py +249 -210
  562. transformers/models/glm4v/modular_glm4v.py +211 -230
  563. transformers/models/glm4v/processing_glm4v.py +7 -41
  564. transformers/models/glm4v/video_processing_glm4v.py +9 -11
  565. transformers/models/glm4v_moe/configuration_glm4v_moe.py +136 -127
  566. transformers/models/glm4v_moe/modeling_glm4v_moe.py +348 -356
  567. transformers/models/glm4v_moe/modular_glm4v_moe.py +76 -174
  568. transformers/models/glm_image/__init__.py +31 -0
  569. transformers/models/glm_image/configuration_glm_image.py +358 -0
  570. transformers/models/glm_image/image_processing_glm_image.py +503 -0
  571. transformers/models/glm_image/image_processing_glm_image_fast.py +294 -0
  572. transformers/models/glm_image/modeling_glm_image.py +1691 -0
  573. transformers/models/glm_image/modular_glm_image.py +1640 -0
  574. transformers/models/glm_image/processing_glm_image.py +265 -0
  575. transformers/models/glm_ocr/__init__.py +28 -0
  576. transformers/models/glm_ocr/configuration_glm_ocr.py +312 -0
  577. transformers/models/glm_ocr/modeling_glm_ocr.py +1633 -0
  578. transformers/models/glm_ocr/modular_glm_ocr.py +428 -0
  579. transformers/models/glmasr/__init__.py +0 -1
  580. transformers/models/glmasr/configuration_glmasr.py +0 -1
  581. transformers/models/glmasr/modeling_glmasr.py +51 -46
  582. transformers/models/glmasr/modular_glmasr.py +39 -29
  583. transformers/models/glmasr/processing_glmasr.py +7 -8
  584. transformers/models/glpn/configuration_glpn.py +0 -1
  585. transformers/models/glpn/image_processing_glpn.py +11 -12
  586. transformers/models/glpn/image_processing_glpn_fast.py +11 -12
  587. transformers/models/glpn/modeling_glpn.py +14 -14
  588. transformers/models/got_ocr2/configuration_got_ocr2.py +10 -13
  589. transformers/models/got_ocr2/image_processing_got_ocr2.py +22 -24
  590. transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +9 -10
  591. transformers/models/got_ocr2/modeling_got_ocr2.py +69 -77
  592. transformers/models/got_ocr2/modular_got_ocr2.py +60 -52
  593. transformers/models/got_ocr2/processing_got_ocr2.py +42 -63
  594. transformers/models/gpt2/configuration_gpt2.py +13 -2
  595. transformers/models/gpt2/modeling_gpt2.py +111 -113
  596. transformers/models/gpt2/tokenization_gpt2.py +6 -9
  597. transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +7 -2
  598. transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +78 -84
  599. transformers/models/gpt_neo/configuration_gpt_neo.py +9 -2
  600. transformers/models/gpt_neo/modeling_gpt_neo.py +66 -71
  601. transformers/models/gpt_neox/configuration_gpt_neox.py +27 -25
  602. transformers/models/gpt_neox/modeling_gpt_neox.py +74 -76
  603. transformers/models/gpt_neox/modular_gpt_neox.py +68 -70
  604. transformers/models/gpt_neox/tokenization_gpt_neox.py +2 -5
  605. transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py +24 -19
  606. transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +43 -46
  607. transformers/models/gpt_neox_japanese/tokenization_gpt_neox_japanese.py +1 -3
  608. transformers/models/gpt_oss/configuration_gpt_oss.py +31 -30
  609. transformers/models/gpt_oss/modeling_gpt_oss.py +80 -114
  610. transformers/models/gpt_oss/modular_gpt_oss.py +62 -97
  611. transformers/models/gpt_sw3/tokenization_gpt_sw3.py +4 -4
  612. transformers/models/gptj/configuration_gptj.py +4 -5
  613. transformers/models/gptj/modeling_gptj.py +85 -88
  614. transformers/models/granite/configuration_granite.py +28 -33
  615. transformers/models/granite/modeling_granite.py +43 -45
  616. transformers/models/granite/modular_granite.py +29 -31
  617. transformers/models/granite_speech/configuration_granite_speech.py +0 -1
  618. transformers/models/granite_speech/feature_extraction_granite_speech.py +1 -3
  619. transformers/models/granite_speech/modeling_granite_speech.py +84 -60
  620. transformers/models/granite_speech/processing_granite_speech.py +11 -4
  621. transformers/models/granitemoe/configuration_granitemoe.py +31 -36
  622. transformers/models/granitemoe/modeling_granitemoe.py +39 -41
  623. transformers/models/granitemoe/modular_granitemoe.py +21 -23
  624. transformers/models/granitemoehybrid/__init__.py +0 -1
  625. transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +55 -48
  626. transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +82 -118
  627. transformers/models/granitemoehybrid/modular_granitemoehybrid.py +57 -65
  628. transformers/models/granitemoeshared/configuration_granitemoeshared.py +33 -37
  629. transformers/models/granitemoeshared/modeling_granitemoeshared.py +52 -56
  630. transformers/models/granitemoeshared/modular_granitemoeshared.py +19 -21
  631. transformers/models/grounding_dino/configuration_grounding_dino.py +10 -46
  632. transformers/models/grounding_dino/image_processing_grounding_dino.py +60 -62
  633. transformers/models/grounding_dino/image_processing_grounding_dino_fast.py +28 -29
  634. transformers/models/grounding_dino/modeling_grounding_dino.py +161 -181
  635. transformers/models/grounding_dino/modular_grounding_dino.py +2 -3
  636. transformers/models/grounding_dino/processing_grounding_dino.py +10 -38
  637. transformers/models/groupvit/configuration_groupvit.py +4 -2
  638. transformers/models/groupvit/modeling_groupvit.py +98 -92
  639. transformers/models/helium/configuration_helium.py +25 -29
  640. transformers/models/helium/modeling_helium.py +37 -40
  641. transformers/models/helium/modular_helium.py +3 -7
  642. transformers/models/herbert/tokenization_herbert.py +4 -6
  643. transformers/models/hgnet_v2/configuration_hgnet_v2.py +2 -5
  644. transformers/models/hgnet_v2/modeling_hgnet_v2.py +12 -14
  645. transformers/models/hgnet_v2/modular_hgnet_v2.py +13 -17
  646. transformers/models/hiera/configuration_hiera.py +2 -5
  647. transformers/models/hiera/modeling_hiera.py +71 -70
  648. transformers/models/hubert/configuration_hubert.py +4 -2
  649. transformers/models/hubert/modeling_hubert.py +42 -41
  650. transformers/models/hubert/modular_hubert.py +8 -11
  651. transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py +26 -31
  652. transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +58 -37
  653. transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +31 -11
  654. transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py +31 -36
  655. transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +54 -44
  656. transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +27 -15
  657. transformers/models/ibert/configuration_ibert.py +4 -2
  658. transformers/models/ibert/modeling_ibert.py +60 -62
  659. transformers/models/ibert/quant_modules.py +0 -1
  660. transformers/models/idefics/configuration_idefics.py +5 -8
  661. transformers/models/idefics/image_processing_idefics.py +13 -15
  662. transformers/models/idefics/modeling_idefics.py +63 -65
  663. transformers/models/idefics/perceiver.py +1 -3
  664. transformers/models/idefics/processing_idefics.py +32 -48
  665. transformers/models/idefics/vision.py +27 -28
  666. transformers/models/idefics2/configuration_idefics2.py +1 -3
  667. transformers/models/idefics2/image_processing_idefics2.py +31 -32
  668. transformers/models/idefics2/image_processing_idefics2_fast.py +8 -8
  669. transformers/models/idefics2/modeling_idefics2.py +126 -106
  670. transformers/models/idefics2/processing_idefics2.py +10 -68
  671. transformers/models/idefics3/configuration_idefics3.py +1 -4
  672. transformers/models/idefics3/image_processing_idefics3.py +42 -43
  673. transformers/models/idefics3/image_processing_idefics3_fast.py +40 -15
  674. transformers/models/idefics3/modeling_idefics3.py +113 -92
  675. transformers/models/idefics3/processing_idefics3.py +15 -69
  676. transformers/models/ijepa/configuration_ijepa.py +0 -1
  677. transformers/models/ijepa/modeling_ijepa.py +13 -14
  678. transformers/models/ijepa/modular_ijepa.py +5 -7
  679. transformers/models/imagegpt/configuration_imagegpt.py +9 -2
  680. transformers/models/imagegpt/image_processing_imagegpt.py +17 -18
  681. transformers/models/imagegpt/image_processing_imagegpt_fast.py +10 -11
  682. transformers/models/imagegpt/modeling_imagegpt.py +65 -62
  683. transformers/models/informer/configuration_informer.py +6 -9
  684. transformers/models/informer/modeling_informer.py +87 -89
  685. transformers/models/informer/modular_informer.py +13 -16
  686. transformers/models/instructblip/configuration_instructblip.py +2 -2
  687. transformers/models/instructblip/modeling_instructblip.py +104 -79
  688. transformers/models/instructblip/processing_instructblip.py +10 -36
  689. transformers/models/instructblipvideo/configuration_instructblipvideo.py +2 -2
  690. transformers/models/instructblipvideo/modeling_instructblipvideo.py +108 -105
  691. transformers/models/instructblipvideo/modular_instructblipvideo.py +73 -64
  692. transformers/models/instructblipvideo/processing_instructblipvideo.py +14 -33
  693. transformers/models/instructblipvideo/video_processing_instructblipvideo.py +6 -7
  694. transformers/models/internvl/configuration_internvl.py +5 -1
  695. transformers/models/internvl/modeling_internvl.py +76 -98
  696. transformers/models/internvl/modular_internvl.py +45 -59
  697. transformers/models/internvl/processing_internvl.py +12 -45
  698. transformers/models/internvl/video_processing_internvl.py +10 -11
  699. transformers/models/jais2/configuration_jais2.py +25 -29
  700. transformers/models/jais2/modeling_jais2.py +36 -38
  701. transformers/models/jais2/modular_jais2.py +20 -22
  702. transformers/models/jamba/configuration_jamba.py +5 -8
  703. transformers/models/jamba/modeling_jamba.py +47 -50
  704. transformers/models/jamba/modular_jamba.py +40 -41
  705. transformers/models/janus/configuration_janus.py +0 -1
  706. transformers/models/janus/image_processing_janus.py +37 -39
  707. transformers/models/janus/image_processing_janus_fast.py +20 -21
  708. transformers/models/janus/modeling_janus.py +103 -188
  709. transformers/models/janus/modular_janus.py +122 -83
  710. transformers/models/janus/processing_janus.py +17 -43
  711. transformers/models/jetmoe/configuration_jetmoe.py +26 -27
  712. transformers/models/jetmoe/modeling_jetmoe.py +42 -45
  713. transformers/models/jetmoe/modular_jetmoe.py +33 -36
  714. transformers/models/kosmos2/configuration_kosmos2.py +10 -9
  715. transformers/models/kosmos2/modeling_kosmos2.py +199 -178
  716. transformers/models/kosmos2/processing_kosmos2.py +40 -55
  717. transformers/models/kosmos2_5/__init__.py +0 -1
  718. transformers/models/kosmos2_5/configuration_kosmos2_5.py +8 -9
  719. transformers/models/kosmos2_5/image_processing_kosmos2_5.py +10 -12
  720. transformers/models/kosmos2_5/image_processing_kosmos2_5_fast.py +2 -11
  721. transformers/models/kosmos2_5/modeling_kosmos2_5.py +162 -172
  722. transformers/models/kosmos2_5/processing_kosmos2_5.py +8 -29
  723. transformers/models/kyutai_speech_to_text/configuration_kyutai_speech_to_text.py +31 -28
  724. transformers/models/kyutai_speech_to_text/feature_extraction_kyutai_speech_to_text.py +12 -14
  725. transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +103 -106
  726. transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +20 -22
  727. transformers/models/kyutai_speech_to_text/processing_kyutai_speech_to_text.py +2 -8
  728. transformers/models/lasr/configuration_lasr.py +3 -7
  729. transformers/models/lasr/feature_extraction_lasr.py +10 -12
  730. transformers/models/lasr/modeling_lasr.py +21 -24
  731. transformers/models/lasr/modular_lasr.py +11 -13
  732. transformers/models/lasr/processing_lasr.py +12 -6
  733. transformers/models/lasr/tokenization_lasr.py +2 -4
  734. transformers/models/layoutlm/configuration_layoutlm.py +14 -2
  735. transformers/models/layoutlm/modeling_layoutlm.py +70 -72
  736. transformers/models/layoutlmv2/configuration_layoutlmv2.py +14 -17
  737. transformers/models/layoutlmv2/image_processing_layoutlmv2.py +18 -21
  738. transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +7 -8
  739. transformers/models/layoutlmv2/modeling_layoutlmv2.py +48 -50
  740. transformers/models/layoutlmv2/processing_layoutlmv2.py +14 -44
  741. transformers/models/layoutlmv2/tokenization_layoutlmv2.py +63 -74
  742. transformers/models/layoutlmv3/configuration_layoutlmv3.py +16 -19
  743. transformers/models/layoutlmv3/image_processing_layoutlmv3.py +24 -26
  744. transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +9 -10
  745. transformers/models/layoutlmv3/modeling_layoutlmv3.py +49 -51
  746. transformers/models/layoutlmv3/processing_layoutlmv3.py +14 -46
  747. transformers/models/layoutlmv3/tokenization_layoutlmv3.py +64 -75
  748. transformers/models/layoutxlm/configuration_layoutxlm.py +14 -17
  749. transformers/models/layoutxlm/modular_layoutxlm.py +0 -1
  750. transformers/models/layoutxlm/processing_layoutxlm.py +14 -44
  751. transformers/models/layoutxlm/tokenization_layoutxlm.py +65 -76
  752. transformers/models/led/configuration_led.py +8 -12
  753. transformers/models/led/modeling_led.py +113 -267
  754. transformers/models/levit/configuration_levit.py +0 -1
  755. transformers/models/levit/image_processing_levit.py +19 -21
  756. transformers/models/levit/image_processing_levit_fast.py +4 -5
  757. transformers/models/levit/modeling_levit.py +17 -19
  758. transformers/models/lfm2/configuration_lfm2.py +27 -30
  759. transformers/models/lfm2/modeling_lfm2.py +46 -48
  760. transformers/models/lfm2/modular_lfm2.py +32 -32
  761. transformers/models/lfm2_moe/__init__.py +0 -1
  762. transformers/models/lfm2_moe/configuration_lfm2_moe.py +6 -9
  763. transformers/models/lfm2_moe/modeling_lfm2_moe.py +48 -49
  764. transformers/models/lfm2_moe/modular_lfm2_moe.py +8 -9
  765. transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -1
  766. transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py +43 -20
  767. transformers/models/lfm2_vl/modeling_lfm2_vl.py +73 -61
  768. transformers/models/lfm2_vl/modular_lfm2_vl.py +66 -54
  769. transformers/models/lfm2_vl/processing_lfm2_vl.py +14 -34
  770. transformers/models/lightglue/image_processing_lightglue.py +16 -15
  771. transformers/models/lightglue/image_processing_lightglue_fast.py +8 -7
  772. transformers/models/lightglue/modeling_lightglue.py +31 -33
  773. transformers/models/lightglue/modular_lightglue.py +31 -31
  774. transformers/models/lighton_ocr/__init__.py +28 -0
  775. transformers/models/lighton_ocr/configuration_lighton_ocr.py +128 -0
  776. transformers/models/lighton_ocr/modeling_lighton_ocr.py +463 -0
  777. transformers/models/lighton_ocr/modular_lighton_ocr.py +404 -0
  778. transformers/models/lighton_ocr/processing_lighton_ocr.py +229 -0
  779. transformers/models/lilt/configuration_lilt.py +6 -2
  780. transformers/models/lilt/modeling_lilt.py +53 -55
  781. transformers/models/llama/configuration_llama.py +26 -31
  782. transformers/models/llama/modeling_llama.py +35 -38
  783. transformers/models/llama/tokenization_llama.py +2 -4
  784. transformers/models/llama4/configuration_llama4.py +87 -69
  785. transformers/models/llama4/image_processing_llama4_fast.py +11 -12
  786. transformers/models/llama4/modeling_llama4.py +116 -115
  787. transformers/models/llama4/processing_llama4.py +33 -57
  788. transformers/models/llava/configuration_llava.py +10 -1
  789. transformers/models/llava/image_processing_llava.py +25 -28
  790. transformers/models/llava/image_processing_llava_fast.py +9 -10
  791. transformers/models/llava/modeling_llava.py +73 -102
  792. transformers/models/llava/processing_llava.py +18 -51
  793. transformers/models/llava_next/configuration_llava_next.py +2 -2
  794. transformers/models/llava_next/image_processing_llava_next.py +43 -45
  795. transformers/models/llava_next/image_processing_llava_next_fast.py +11 -12
  796. transformers/models/llava_next/modeling_llava_next.py +103 -104
  797. transformers/models/llava_next/processing_llava_next.py +18 -47
  798. transformers/models/llava_next_video/configuration_llava_next_video.py +10 -7
  799. transformers/models/llava_next_video/modeling_llava_next_video.py +168 -155
  800. transformers/models/llava_next_video/modular_llava_next_video.py +154 -147
  801. transformers/models/llava_next_video/processing_llava_next_video.py +21 -63
  802. transformers/models/llava_next_video/video_processing_llava_next_video.py +0 -1
  803. transformers/models/llava_onevision/configuration_llava_onevision.py +10 -7
  804. transformers/models/llava_onevision/image_processing_llava_onevision.py +40 -42
  805. transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +14 -14
  806. transformers/models/llava_onevision/modeling_llava_onevision.py +170 -166
  807. transformers/models/llava_onevision/modular_llava_onevision.py +156 -152
  808. transformers/models/llava_onevision/processing_llava_onevision.py +21 -53
  809. transformers/models/llava_onevision/video_processing_llava_onevision.py +0 -1
  810. transformers/models/longcat_flash/__init__.py +0 -1
  811. transformers/models/longcat_flash/configuration_longcat_flash.py +39 -45
  812. transformers/models/longcat_flash/modeling_longcat_flash.py +37 -38
  813. transformers/models/longcat_flash/modular_longcat_flash.py +23 -24
  814. transformers/models/longformer/configuration_longformer.py +5 -5
  815. transformers/models/longformer/modeling_longformer.py +99 -101
  816. transformers/models/longt5/configuration_longt5.py +9 -7
  817. transformers/models/longt5/modeling_longt5.py +45 -45
  818. transformers/models/luke/configuration_luke.py +8 -2
  819. transformers/models/luke/modeling_luke.py +179 -181
  820. transformers/models/luke/tokenization_luke.py +99 -105
  821. transformers/{pipelines/deprecated → models/lw_detr}/__init__.py +14 -3
  822. transformers/models/lw_detr/configuration_lw_detr.py +362 -0
  823. transformers/models/lw_detr/modeling_lw_detr.py +1697 -0
  824. transformers/models/lw_detr/modular_lw_detr.py +1609 -0
  825. transformers/models/lxmert/configuration_lxmert.py +16 -1
  826. transformers/models/lxmert/modeling_lxmert.py +63 -74
  827. transformers/models/m2m_100/configuration_m2m_100.py +7 -9
  828. transformers/models/m2m_100/modeling_m2m_100.py +72 -74
  829. transformers/models/m2m_100/tokenization_m2m_100.py +8 -8
  830. transformers/models/mamba/configuration_mamba.py +5 -3
  831. transformers/models/mamba/modeling_mamba.py +61 -70
  832. transformers/models/mamba2/configuration_mamba2.py +5 -8
  833. transformers/models/mamba2/modeling_mamba2.py +66 -79
  834. transformers/models/marian/configuration_marian.py +10 -5
  835. transformers/models/marian/modeling_marian.py +88 -90
  836. transformers/models/marian/tokenization_marian.py +6 -6
  837. transformers/models/markuplm/configuration_markuplm.py +4 -7
  838. transformers/models/markuplm/feature_extraction_markuplm.py +1 -2
  839. transformers/models/markuplm/modeling_markuplm.py +63 -65
  840. transformers/models/markuplm/processing_markuplm.py +31 -38
  841. transformers/models/markuplm/tokenization_markuplm.py +67 -77
  842. transformers/models/mask2former/configuration_mask2former.py +14 -52
  843. transformers/models/mask2former/image_processing_mask2former.py +84 -85
  844. transformers/models/mask2former/image_processing_mask2former_fast.py +36 -36
  845. transformers/models/mask2former/modeling_mask2former.py +108 -104
  846. transformers/models/mask2former/modular_mask2former.py +6 -8
  847. transformers/models/maskformer/configuration_maskformer.py +17 -51
  848. transformers/models/maskformer/configuration_maskformer_swin.py +2 -5
  849. transformers/models/maskformer/image_processing_maskformer.py +84 -85
  850. transformers/models/maskformer/image_processing_maskformer_fast.py +35 -36
  851. transformers/models/maskformer/modeling_maskformer.py +71 -67
  852. transformers/models/maskformer/modeling_maskformer_swin.py +20 -23
  853. transformers/models/mbart/configuration_mbart.py +9 -5
  854. transformers/models/mbart/modeling_mbart.py +120 -119
  855. transformers/models/mbart/tokenization_mbart.py +2 -4
  856. transformers/models/mbart50/tokenization_mbart50.py +3 -5
  857. transformers/models/megatron_bert/configuration_megatron_bert.py +13 -3
  858. transformers/models/megatron_bert/modeling_megatron_bert.py +139 -165
  859. transformers/models/metaclip_2/configuration_metaclip_2.py +4 -1
  860. transformers/models/metaclip_2/modeling_metaclip_2.py +94 -87
  861. transformers/models/metaclip_2/modular_metaclip_2.py +59 -45
  862. transformers/models/mgp_str/configuration_mgp_str.py +0 -1
  863. transformers/models/mgp_str/modeling_mgp_str.py +18 -18
  864. transformers/models/mgp_str/processing_mgp_str.py +3 -20
  865. transformers/models/mgp_str/tokenization_mgp_str.py +1 -3
  866. transformers/models/mimi/configuration_mimi.py +42 -40
  867. transformers/models/mimi/modeling_mimi.py +116 -115
  868. transformers/models/minimax/__init__.py +0 -1
  869. transformers/models/minimax/configuration_minimax.py +40 -47
  870. transformers/models/minimax/modeling_minimax.py +46 -49
  871. transformers/models/minimax/modular_minimax.py +59 -65
  872. transformers/models/minimax_m2/__init__.py +28 -0
  873. transformers/models/minimax_m2/configuration_minimax_m2.py +188 -0
  874. transformers/models/minimax_m2/modeling_minimax_m2.py +704 -0
  875. transformers/models/minimax_m2/modular_minimax_m2.py +346 -0
  876. transformers/models/ministral/configuration_ministral.py +25 -29
  877. transformers/models/ministral/modeling_ministral.py +35 -37
  878. transformers/models/ministral/modular_ministral.py +32 -37
  879. transformers/models/ministral3/configuration_ministral3.py +23 -26
  880. transformers/models/ministral3/modeling_ministral3.py +35 -37
  881. transformers/models/ministral3/modular_ministral3.py +7 -8
  882. transformers/models/mistral/configuration_mistral.py +24 -29
  883. transformers/models/mistral/modeling_mistral.py +35 -37
  884. transformers/models/mistral/modular_mistral.py +14 -15
  885. transformers/models/mistral3/configuration_mistral3.py +4 -1
  886. transformers/models/mistral3/modeling_mistral3.py +79 -82
  887. transformers/models/mistral3/modular_mistral3.py +66 -67
  888. transformers/models/mixtral/configuration_mixtral.py +32 -38
  889. transformers/models/mixtral/modeling_mixtral.py +39 -42
  890. transformers/models/mixtral/modular_mixtral.py +26 -29
  891. transformers/models/mlcd/configuration_mlcd.py +0 -1
  892. transformers/models/mlcd/modeling_mlcd.py +17 -17
  893. transformers/models/mlcd/modular_mlcd.py +16 -16
  894. transformers/models/mllama/configuration_mllama.py +10 -15
  895. transformers/models/mllama/image_processing_mllama.py +23 -25
  896. transformers/models/mllama/image_processing_mllama_fast.py +11 -11
  897. transformers/models/mllama/modeling_mllama.py +100 -103
  898. transformers/models/mllama/processing_mllama.py +6 -55
  899. transformers/models/mluke/tokenization_mluke.py +97 -103
  900. transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +10 -46
  901. transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +159 -179
  902. transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +10 -46
  903. transformers/models/mobilebert/configuration_mobilebert.py +4 -2
  904. transformers/models/mobilebert/modeling_mobilebert.py +78 -88
  905. transformers/models/mobilebert/tokenization_mobilebert.py +0 -1
  906. transformers/models/mobilenet_v1/configuration_mobilenet_v1.py +0 -1
  907. transformers/models/mobilenet_v1/image_processing_mobilenet_v1.py +20 -23
  908. transformers/models/mobilenet_v1/image_processing_mobilenet_v1_fast.py +0 -1
  909. transformers/models/mobilenet_v1/modeling_mobilenet_v1.py +13 -16
  910. transformers/models/mobilenet_v2/configuration_mobilenet_v2.py +0 -1
  911. transformers/models/mobilenet_v2/image_processing_mobilenet_v2.py +48 -51
  912. transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +14 -15
  913. transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +21 -22
  914. transformers/models/mobilevit/configuration_mobilevit.py +0 -1
  915. transformers/models/mobilevit/image_processing_mobilevit.py +41 -44
  916. transformers/models/mobilevit/image_processing_mobilevit_fast.py +12 -13
  917. transformers/models/mobilevit/modeling_mobilevit.py +21 -21
  918. transformers/models/mobilevitv2/configuration_mobilevitv2.py +0 -1
  919. transformers/models/mobilevitv2/modeling_mobilevitv2.py +21 -22
  920. transformers/models/modernbert/configuration_modernbert.py +76 -51
  921. transformers/models/modernbert/modeling_modernbert.py +188 -943
  922. transformers/models/modernbert/modular_modernbert.py +255 -978
  923. transformers/models/modernbert_decoder/configuration_modernbert_decoder.py +50 -44
  924. transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +54 -64
  925. transformers/models/modernbert_decoder/modular_modernbert_decoder.py +92 -92
  926. transformers/models/moonshine/configuration_moonshine.py +34 -31
  927. transformers/models/moonshine/modeling_moonshine.py +70 -72
  928. transformers/models/moonshine/modular_moonshine.py +91 -86
  929. transformers/models/moshi/configuration_moshi.py +46 -23
  930. transformers/models/moshi/modeling_moshi.py +134 -142
  931. transformers/models/mpnet/configuration_mpnet.py +6 -2
  932. transformers/models/mpnet/modeling_mpnet.py +55 -57
  933. transformers/models/mpnet/tokenization_mpnet.py +1 -4
  934. transformers/models/mpt/configuration_mpt.py +17 -9
  935. transformers/models/mpt/modeling_mpt.py +58 -60
  936. transformers/models/mra/configuration_mra.py +8 -2
  937. transformers/models/mra/modeling_mra.py +54 -56
  938. transformers/models/mt5/configuration_mt5.py +9 -6
  939. transformers/models/mt5/modeling_mt5.py +80 -85
  940. transformers/models/musicgen/configuration_musicgen.py +12 -8
  941. transformers/models/musicgen/modeling_musicgen.py +114 -116
  942. transformers/models/musicgen/processing_musicgen.py +3 -21
  943. transformers/models/musicgen_melody/configuration_musicgen_melody.py +15 -8
  944. transformers/models/musicgen_melody/feature_extraction_musicgen_melody.py +8 -9
  945. transformers/models/musicgen_melody/modeling_musicgen_melody.py +113 -126
  946. transformers/models/musicgen_melody/processing_musicgen_melody.py +3 -22
  947. transformers/models/mvp/configuration_mvp.py +8 -5
  948. transformers/models/mvp/modeling_mvp.py +121 -123
  949. transformers/models/myt5/tokenization_myt5.py +8 -10
  950. transformers/models/nanochat/configuration_nanochat.py +5 -8
  951. transformers/models/nanochat/modeling_nanochat.py +36 -39
  952. transformers/models/nanochat/modular_nanochat.py +16 -18
  953. transformers/models/nemotron/configuration_nemotron.py +25 -30
  954. transformers/models/nemotron/modeling_nemotron.py +53 -66
  955. transformers/models/nllb/tokenization_nllb.py +14 -14
  956. transformers/models/nllb_moe/configuration_nllb_moe.py +7 -10
  957. transformers/models/nllb_moe/modeling_nllb_moe.py +70 -72
  958. transformers/models/nougat/image_processing_nougat.py +29 -32
  959. transformers/models/nougat/image_processing_nougat_fast.py +12 -13
  960. transformers/models/nougat/processing_nougat.py +37 -39
  961. transformers/models/nougat/tokenization_nougat.py +5 -7
  962. transformers/models/nystromformer/configuration_nystromformer.py +8 -2
  963. transformers/models/nystromformer/modeling_nystromformer.py +61 -63
  964. transformers/models/olmo/configuration_olmo.py +23 -28
  965. transformers/models/olmo/modeling_olmo.py +35 -38
  966. transformers/models/olmo/modular_olmo.py +8 -12
  967. transformers/models/olmo2/configuration_olmo2.py +27 -32
  968. transformers/models/olmo2/modeling_olmo2.py +36 -39
  969. transformers/models/olmo2/modular_olmo2.py +36 -38
  970. transformers/models/olmo3/__init__.py +0 -1
  971. transformers/models/olmo3/configuration_olmo3.py +30 -34
  972. transformers/models/olmo3/modeling_olmo3.py +35 -38
  973. transformers/models/olmo3/modular_olmo3.py +44 -47
  974. transformers/models/olmoe/configuration_olmoe.py +29 -33
  975. transformers/models/olmoe/modeling_olmoe.py +41 -43
  976. transformers/models/olmoe/modular_olmoe.py +15 -16
  977. transformers/models/omdet_turbo/configuration_omdet_turbo.py +14 -50
  978. transformers/models/omdet_turbo/modeling_omdet_turbo.py +59 -57
  979. transformers/models/omdet_turbo/processing_omdet_turbo.py +19 -67
  980. transformers/models/oneformer/configuration_oneformer.py +11 -51
  981. transformers/models/oneformer/image_processing_oneformer.py +83 -84
  982. transformers/models/oneformer/image_processing_oneformer_fast.py +41 -42
  983. transformers/models/oneformer/modeling_oneformer.py +137 -133
  984. transformers/models/oneformer/processing_oneformer.py +28 -43
  985. transformers/models/openai/configuration_openai.py +16 -1
  986. transformers/models/openai/modeling_openai.py +50 -51
  987. transformers/models/openai/tokenization_openai.py +2 -5
  988. transformers/models/opt/configuration_opt.py +6 -7
  989. transformers/models/opt/modeling_opt.py +79 -80
  990. transformers/models/ovis2/__init__.py +0 -1
  991. transformers/models/ovis2/configuration_ovis2.py +4 -1
  992. transformers/models/ovis2/image_processing_ovis2.py +22 -24
  993. transformers/models/ovis2/image_processing_ovis2_fast.py +9 -10
  994. transformers/models/ovis2/modeling_ovis2.py +99 -142
  995. transformers/models/ovis2/modular_ovis2.py +82 -45
  996. transformers/models/ovis2/processing_ovis2.py +12 -40
  997. transformers/models/owlv2/configuration_owlv2.py +4 -2
  998. transformers/models/owlv2/image_processing_owlv2.py +20 -21
  999. transformers/models/owlv2/image_processing_owlv2_fast.py +12 -13
  1000. transformers/models/owlv2/modeling_owlv2.py +122 -114
  1001. transformers/models/owlv2/modular_owlv2.py +11 -12
  1002. transformers/models/owlv2/processing_owlv2.py +20 -49
  1003. transformers/models/owlvit/configuration_owlvit.py +4 -2
  1004. transformers/models/owlvit/image_processing_owlvit.py +21 -22
  1005. transformers/models/owlvit/image_processing_owlvit_fast.py +2 -3
  1006. transformers/models/owlvit/modeling_owlvit.py +121 -113
  1007. transformers/models/owlvit/processing_owlvit.py +20 -48
  1008. transformers/models/paddleocr_vl/__init__.py +0 -1
  1009. transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +28 -29
  1010. transformers/models/paddleocr_vl/image_processing_paddleocr_vl.py +34 -35
  1011. transformers/models/paddleocr_vl/image_processing_paddleocr_vl_fast.py +12 -12
  1012. transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +159 -158
  1013. transformers/models/paddleocr_vl/modular_paddleocr_vl.py +148 -119
  1014. transformers/models/paddleocr_vl/processing_paddleocr_vl.py +1 -3
  1015. transformers/models/paligemma/configuration_paligemma.py +4 -1
  1016. transformers/models/paligemma/modeling_paligemma.py +81 -79
  1017. transformers/models/paligemma/processing_paligemma.py +13 -66
  1018. transformers/models/parakeet/configuration_parakeet.py +3 -8
  1019. transformers/models/parakeet/feature_extraction_parakeet.py +10 -12
  1020. transformers/models/parakeet/modeling_parakeet.py +21 -25
  1021. transformers/models/parakeet/modular_parakeet.py +19 -21
  1022. transformers/models/parakeet/processing_parakeet.py +12 -5
  1023. transformers/models/parakeet/tokenization_parakeet.py +2 -4
  1024. transformers/models/patchtsmixer/configuration_patchtsmixer.py +5 -8
  1025. transformers/models/patchtsmixer/modeling_patchtsmixer.py +63 -65
  1026. transformers/models/patchtst/configuration_patchtst.py +6 -9
  1027. transformers/models/patchtst/modeling_patchtst.py +75 -77
  1028. transformers/models/pe_audio/__init__.py +0 -1
  1029. transformers/models/pe_audio/configuration_pe_audio.py +14 -16
  1030. transformers/models/pe_audio/feature_extraction_pe_audio.py +6 -8
  1031. transformers/models/pe_audio/modeling_pe_audio.py +30 -31
  1032. transformers/models/pe_audio/modular_pe_audio.py +17 -18
  1033. transformers/models/pe_audio/processing_pe_audio.py +0 -1
  1034. transformers/models/pe_audio_video/__init__.py +0 -1
  1035. transformers/models/pe_audio_video/configuration_pe_audio_video.py +15 -17
  1036. transformers/models/pe_audio_video/modeling_pe_audio_video.py +64 -65
  1037. transformers/models/pe_audio_video/modular_pe_audio_video.py +56 -57
  1038. transformers/models/pe_audio_video/processing_pe_audio_video.py +0 -1
  1039. transformers/models/pe_video/__init__.py +0 -1
  1040. transformers/models/pe_video/configuration_pe_video.py +14 -16
  1041. transformers/models/pe_video/modeling_pe_video.py +57 -46
  1042. transformers/models/pe_video/modular_pe_video.py +47 -35
  1043. transformers/models/pe_video/video_processing_pe_video.py +2 -4
  1044. transformers/models/pegasus/configuration_pegasus.py +8 -6
  1045. transformers/models/pegasus/modeling_pegasus.py +67 -69
  1046. transformers/models/pegasus/tokenization_pegasus.py +1 -4
  1047. transformers/models/pegasus_x/configuration_pegasus_x.py +5 -4
  1048. transformers/models/pegasus_x/modeling_pegasus_x.py +53 -55
  1049. transformers/models/perceiver/configuration_perceiver.py +0 -1
  1050. transformers/models/perceiver/image_processing_perceiver.py +22 -25
  1051. transformers/models/perceiver/image_processing_perceiver_fast.py +7 -8
  1052. transformers/models/perceiver/modeling_perceiver.py +152 -145
  1053. transformers/models/perceiver/tokenization_perceiver.py +3 -6
  1054. transformers/models/perception_lm/configuration_perception_lm.py +0 -1
  1055. transformers/models/perception_lm/image_processing_perception_lm_fast.py +8 -9
  1056. transformers/models/perception_lm/modeling_perception_lm.py +64 -67
  1057. transformers/models/perception_lm/modular_perception_lm.py +58 -58
  1058. transformers/models/perception_lm/processing_perception_lm.py +13 -47
  1059. transformers/models/perception_lm/video_processing_perception_lm.py +0 -1
  1060. transformers/models/persimmon/configuration_persimmon.py +23 -28
  1061. transformers/models/persimmon/modeling_persimmon.py +44 -47
  1062. transformers/models/phi/configuration_phi.py +27 -28
  1063. transformers/models/phi/modeling_phi.py +39 -41
  1064. transformers/models/phi/modular_phi.py +26 -26
  1065. transformers/models/phi3/configuration_phi3.py +32 -37
  1066. transformers/models/phi3/modeling_phi3.py +37 -40
  1067. transformers/models/phi3/modular_phi3.py +16 -20
  1068. transformers/models/phi4_multimodal/configuration_phi4_multimodal.py +36 -39
  1069. transformers/models/phi4_multimodal/feature_extraction_phi4_multimodal.py +7 -9
  1070. transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py +11 -11
  1071. transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +100 -117
  1072. transformers/models/phi4_multimodal/modular_phi4_multimodal.py +103 -90
  1073. transformers/models/phi4_multimodal/processing_phi4_multimodal.py +7 -42
  1074. transformers/models/phimoe/configuration_phimoe.py +31 -36
  1075. transformers/models/phimoe/modeling_phimoe.py +50 -77
  1076. transformers/models/phimoe/modular_phimoe.py +12 -8
  1077. transformers/models/phobert/tokenization_phobert.py +4 -6
  1078. transformers/models/pix2struct/configuration_pix2struct.py +12 -10
  1079. transformers/models/pix2struct/image_processing_pix2struct.py +15 -19
  1080. transformers/models/pix2struct/image_processing_pix2struct_fast.py +12 -15
  1081. transformers/models/pix2struct/modeling_pix2struct.py +56 -52
  1082. transformers/models/pix2struct/processing_pix2struct.py +5 -26
  1083. transformers/models/pixio/__init__.py +0 -1
  1084. transformers/models/pixio/configuration_pixio.py +2 -5
  1085. transformers/models/pixio/modeling_pixio.py +16 -17
  1086. transformers/models/pixio/modular_pixio.py +7 -8
  1087. transformers/models/pixtral/configuration_pixtral.py +11 -14
  1088. transformers/models/pixtral/image_processing_pixtral.py +26 -28
  1089. transformers/models/pixtral/image_processing_pixtral_fast.py +10 -11
  1090. transformers/models/pixtral/modeling_pixtral.py +31 -37
  1091. transformers/models/pixtral/processing_pixtral.py +18 -52
  1092. transformers/models/plbart/configuration_plbart.py +8 -6
  1093. transformers/models/plbart/modeling_plbart.py +109 -109
  1094. transformers/models/plbart/modular_plbart.py +31 -33
  1095. transformers/models/plbart/tokenization_plbart.py +4 -5
  1096. transformers/models/poolformer/configuration_poolformer.py +0 -1
  1097. transformers/models/poolformer/image_processing_poolformer.py +21 -24
  1098. transformers/models/poolformer/image_processing_poolformer_fast.py +13 -14
  1099. transformers/models/poolformer/modeling_poolformer.py +10 -12
  1100. transformers/models/pop2piano/configuration_pop2piano.py +7 -7
  1101. transformers/models/pop2piano/feature_extraction_pop2piano.py +6 -9
  1102. transformers/models/pop2piano/modeling_pop2piano.py +24 -24
  1103. transformers/models/pop2piano/processing_pop2piano.py +25 -33
  1104. transformers/models/pop2piano/tokenization_pop2piano.py +15 -23
  1105. transformers/models/pp_doclayout_v3/__init__.py +30 -0
  1106. transformers/models/pp_doclayout_v3/configuration_pp_doclayout_v3.py +277 -0
  1107. transformers/models/pp_doclayout_v3/image_processing_pp_doclayout_v3_fast.py +305 -0
  1108. transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py +2083 -0
  1109. transformers/models/pp_doclayout_v3/modular_pp_doclayout_v3.py +1549 -0
  1110. transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +13 -46
  1111. transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything.py +28 -28
  1112. transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything_fast.py +20 -21
  1113. transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +17 -16
  1114. transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +21 -20
  1115. transformers/models/prophetnet/configuration_prophetnet.py +37 -38
  1116. transformers/models/prophetnet/modeling_prophetnet.py +121 -153
  1117. transformers/models/prophetnet/tokenization_prophetnet.py +14 -16
  1118. transformers/models/pvt/configuration_pvt.py +0 -1
  1119. transformers/models/pvt/image_processing_pvt.py +24 -27
  1120. transformers/models/pvt/image_processing_pvt_fast.py +1 -2
  1121. transformers/models/pvt/modeling_pvt.py +19 -21
  1122. transformers/models/pvt_v2/configuration_pvt_v2.py +4 -8
  1123. transformers/models/pvt_v2/modeling_pvt_v2.py +27 -28
  1124. transformers/models/qwen2/configuration_qwen2.py +32 -25
  1125. transformers/models/qwen2/modeling_qwen2.py +35 -37
  1126. transformers/models/qwen2/modular_qwen2.py +14 -15
  1127. transformers/models/qwen2/tokenization_qwen2.py +2 -9
  1128. transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +36 -27
  1129. transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +241 -214
  1130. transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +228 -193
  1131. transformers/models/qwen2_5_omni/processing_qwen2_5_omni.py +41 -49
  1132. transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +28 -34
  1133. transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +188 -145
  1134. transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +64 -91
  1135. transformers/models/qwen2_5_vl/processing_qwen2_5_vl.py +7 -43
  1136. transformers/models/qwen2_audio/configuration_qwen2_audio.py +0 -1
  1137. transformers/models/qwen2_audio/modeling_qwen2_audio.py +39 -41
  1138. transformers/models/qwen2_audio/processing_qwen2_audio.py +13 -42
  1139. transformers/models/qwen2_moe/configuration_qwen2_moe.py +42 -35
  1140. transformers/models/qwen2_moe/modeling_qwen2_moe.py +40 -43
  1141. transformers/models/qwen2_moe/modular_qwen2_moe.py +10 -13
  1142. transformers/models/qwen2_vl/configuration_qwen2_vl.py +28 -33
  1143. transformers/models/qwen2_vl/image_processing_qwen2_vl.py +38 -40
  1144. transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py +12 -15
  1145. transformers/models/qwen2_vl/modeling_qwen2_vl.py +184 -141
  1146. transformers/models/qwen2_vl/processing_qwen2_vl.py +7 -44
  1147. transformers/models/qwen2_vl/video_processing_qwen2_vl.py +38 -18
  1148. transformers/models/qwen3/configuration_qwen3.py +34 -27
  1149. transformers/models/qwen3/modeling_qwen3.py +35 -38
  1150. transformers/models/qwen3/modular_qwen3.py +7 -9
  1151. transformers/models/qwen3_moe/configuration_qwen3_moe.py +45 -35
  1152. transformers/models/qwen3_moe/modeling_qwen3_moe.py +40 -43
  1153. transformers/models/qwen3_moe/modular_qwen3_moe.py +10 -13
  1154. transformers/models/qwen3_next/configuration_qwen3_next.py +47 -38
  1155. transformers/models/qwen3_next/modeling_qwen3_next.py +44 -47
  1156. transformers/models/qwen3_next/modular_qwen3_next.py +37 -38
  1157. transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +139 -106
  1158. transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +266 -206
  1159. transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +228 -181
  1160. transformers/models/qwen3_omni_moe/processing_qwen3_omni_moe.py +40 -48
  1161. transformers/models/qwen3_vl/configuration_qwen3_vl.py +22 -24
  1162. transformers/models/qwen3_vl/modeling_qwen3_vl.py +185 -122
  1163. transformers/models/qwen3_vl/modular_qwen3_vl.py +153 -139
  1164. transformers/models/qwen3_vl/processing_qwen3_vl.py +6 -42
  1165. transformers/models/qwen3_vl/video_processing_qwen3_vl.py +10 -12
  1166. transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py +27 -30
  1167. transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +249 -178
  1168. transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +55 -42
  1169. transformers/models/rag/configuration_rag.py +6 -7
  1170. transformers/models/rag/modeling_rag.py +119 -121
  1171. transformers/models/rag/retrieval_rag.py +3 -5
  1172. transformers/models/rag/tokenization_rag.py +0 -50
  1173. transformers/models/recurrent_gemma/configuration_recurrent_gemma.py +29 -30
  1174. transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +35 -39
  1175. transformers/models/reformer/configuration_reformer.py +7 -8
  1176. transformers/models/reformer/modeling_reformer.py +67 -68
  1177. transformers/models/reformer/tokenization_reformer.py +3 -6
  1178. transformers/models/regnet/configuration_regnet.py +0 -1
  1179. transformers/models/regnet/modeling_regnet.py +7 -9
  1180. transformers/models/rembert/configuration_rembert.py +8 -2
  1181. transformers/models/rembert/modeling_rembert.py +108 -132
  1182. transformers/models/rembert/tokenization_rembert.py +1 -4
  1183. transformers/models/resnet/configuration_resnet.py +2 -5
  1184. transformers/models/resnet/modeling_resnet.py +14 -15
  1185. transformers/models/roberta/configuration_roberta.py +11 -3
  1186. transformers/models/roberta/modeling_roberta.py +97 -99
  1187. transformers/models/roberta/modular_roberta.py +55 -58
  1188. transformers/models/roberta/tokenization_roberta.py +2 -5
  1189. transformers/models/roberta/tokenization_roberta_old.py +2 -4
  1190. transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py +11 -3
  1191. transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +97 -99
  1192. transformers/models/roc_bert/configuration_roc_bert.py +8 -2
  1193. transformers/models/roc_bert/modeling_roc_bert.py +125 -162
  1194. transformers/models/roc_bert/tokenization_roc_bert.py +88 -94
  1195. transformers/models/roformer/configuration_roformer.py +13 -3
  1196. transformers/models/roformer/modeling_roformer.py +79 -95
  1197. transformers/models/roformer/tokenization_roformer.py +3 -6
  1198. transformers/models/roformer/tokenization_utils.py +0 -1
  1199. transformers/models/rt_detr/configuration_rt_detr.py +8 -50
  1200. transformers/models/rt_detr/configuration_rt_detr_resnet.py +2 -5
  1201. transformers/models/rt_detr/image_processing_rt_detr.py +54 -55
  1202. transformers/models/rt_detr/image_processing_rt_detr_fast.py +39 -26
  1203. transformers/models/rt_detr/modeling_rt_detr.py +643 -804
  1204. transformers/models/rt_detr/modeling_rt_detr_resnet.py +4 -7
  1205. transformers/models/rt_detr/modular_rt_detr.py +1522 -20
  1206. transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +12 -58
  1207. transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +384 -521
  1208. transformers/models/rt_detr_v2/modular_rt_detr_v2.py +27 -70
  1209. transformers/models/rwkv/configuration_rwkv.py +2 -4
  1210. transformers/models/rwkv/modeling_rwkv.py +29 -54
  1211. transformers/models/sam/configuration_sam.py +2 -1
  1212. transformers/models/sam/image_processing_sam.py +59 -60
  1213. transformers/models/sam/image_processing_sam_fast.py +25 -26
  1214. transformers/models/sam/modeling_sam.py +46 -43
  1215. transformers/models/sam/processing_sam.py +39 -27
  1216. transformers/models/sam2/configuration_sam2.py +1 -2
  1217. transformers/models/sam2/image_processing_sam2_fast.py +14 -15
  1218. transformers/models/sam2/modeling_sam2.py +96 -94
  1219. transformers/models/sam2/modular_sam2.py +85 -94
  1220. transformers/models/sam2/processing_sam2.py +31 -47
  1221. transformers/models/sam2_video/configuration_sam2_video.py +0 -1
  1222. transformers/models/sam2_video/modeling_sam2_video.py +114 -116
  1223. transformers/models/sam2_video/modular_sam2_video.py +72 -89
  1224. transformers/models/sam2_video/processing_sam2_video.py +49 -66
  1225. transformers/models/sam2_video/video_processing_sam2_video.py +1 -4
  1226. transformers/models/sam3/configuration_sam3.py +0 -1
  1227. transformers/models/sam3/image_processing_sam3_fast.py +17 -20
  1228. transformers/models/sam3/modeling_sam3.py +94 -100
  1229. transformers/models/sam3/modular_sam3.py +3 -8
  1230. transformers/models/sam3/processing_sam3.py +37 -52
  1231. transformers/models/sam3_tracker/__init__.py +0 -1
  1232. transformers/models/sam3_tracker/configuration_sam3_tracker.py +1 -3
  1233. transformers/models/sam3_tracker/modeling_sam3_tracker.py +79 -80
  1234. transformers/models/sam3_tracker/modular_sam3_tracker.py +0 -2
  1235. transformers/models/sam3_tracker/processing_sam3_tracker.py +31 -48
  1236. transformers/models/sam3_tracker_video/__init__.py +0 -1
  1237. transformers/models/sam3_tracker_video/configuration_sam3_tracker_video.py +0 -1
  1238. transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +115 -114
  1239. transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +10 -24
  1240. transformers/models/sam3_tracker_video/processing_sam3_tracker_video.py +50 -66
  1241. transformers/models/sam3_video/configuration_sam3_video.py +0 -1
  1242. transformers/models/sam3_video/modeling_sam3_video.py +56 -45
  1243. transformers/models/sam3_video/processing_sam3_video.py +25 -45
  1244. transformers/models/sam_hq/__init__.py +1 -1
  1245. transformers/models/sam_hq/configuration_sam_hq.py +2 -1
  1246. transformers/models/sam_hq/modeling_sam_hq.py +52 -50
  1247. transformers/models/sam_hq/modular_sam_hq.py +23 -25
  1248. transformers/models/sam_hq/{processing_samhq.py → processing_sam_hq.py} +41 -29
  1249. transformers/models/seamless_m4t/configuration_seamless_m4t.py +8 -10
  1250. transformers/models/seamless_m4t/feature_extraction_seamless_m4t.py +8 -11
  1251. transformers/models/seamless_m4t/modeling_seamless_m4t.py +180 -182
  1252. transformers/models/seamless_m4t/processing_seamless_m4t.py +18 -39
  1253. transformers/models/seamless_m4t/tokenization_seamless_m4t.py +15 -20
  1254. transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py +8 -10
  1255. transformers/models/seamless_m4t_v2/modeling_seamless_m4t_v2.py +193 -195
  1256. transformers/models/seed_oss/configuration_seed_oss.py +30 -34
  1257. transformers/models/seed_oss/modeling_seed_oss.py +34 -36
  1258. transformers/models/seed_oss/modular_seed_oss.py +6 -7
  1259. transformers/models/segformer/configuration_segformer.py +0 -10
  1260. transformers/models/segformer/image_processing_segformer.py +39 -42
  1261. transformers/models/segformer/image_processing_segformer_fast.py +11 -12
  1262. transformers/models/segformer/modeling_segformer.py +28 -28
  1263. transformers/models/segformer/modular_segformer.py +8 -9
  1264. transformers/models/seggpt/configuration_seggpt.py +0 -1
  1265. transformers/models/seggpt/image_processing_seggpt.py +38 -41
  1266. transformers/models/seggpt/modeling_seggpt.py +48 -38
  1267. transformers/models/sew/configuration_sew.py +4 -2
  1268. transformers/models/sew/modeling_sew.py +42 -40
  1269. transformers/models/sew/modular_sew.py +12 -13
  1270. transformers/models/sew_d/configuration_sew_d.py +4 -2
  1271. transformers/models/sew_d/modeling_sew_d.py +32 -31
  1272. transformers/models/shieldgemma2/configuration_shieldgemma2.py +0 -1
  1273. transformers/models/shieldgemma2/modeling_shieldgemma2.py +19 -21
  1274. transformers/models/shieldgemma2/processing_shieldgemma2.py +3 -5
  1275. transformers/models/siglip/configuration_siglip.py +4 -2
  1276. transformers/models/siglip/image_processing_siglip.py +17 -20
  1277. transformers/models/siglip/image_processing_siglip_fast.py +0 -1
  1278. transformers/models/siglip/modeling_siglip.py +65 -110
  1279. transformers/models/siglip/processing_siglip.py +2 -14
  1280. transformers/models/siglip/tokenization_siglip.py +6 -7
  1281. transformers/models/siglip2/__init__.py +1 -0
  1282. transformers/models/siglip2/configuration_siglip2.py +4 -2
  1283. transformers/models/siglip2/image_processing_siglip2.py +15 -16
  1284. transformers/models/siglip2/image_processing_siglip2_fast.py +6 -7
  1285. transformers/models/siglip2/modeling_siglip2.py +89 -130
  1286. transformers/models/siglip2/modular_siglip2.py +95 -48
  1287. transformers/models/siglip2/processing_siglip2.py +2 -14
  1288. transformers/models/siglip2/tokenization_siglip2.py +95 -0
  1289. transformers/models/smollm3/configuration_smollm3.py +29 -32
  1290. transformers/models/smollm3/modeling_smollm3.py +35 -38
  1291. transformers/models/smollm3/modular_smollm3.py +36 -38
  1292. transformers/models/smolvlm/configuration_smolvlm.py +2 -4
  1293. transformers/models/smolvlm/image_processing_smolvlm.py +42 -43
  1294. transformers/models/smolvlm/image_processing_smolvlm_fast.py +41 -15
  1295. transformers/models/smolvlm/modeling_smolvlm.py +124 -96
  1296. transformers/models/smolvlm/modular_smolvlm.py +50 -39
  1297. transformers/models/smolvlm/processing_smolvlm.py +15 -76
  1298. transformers/models/smolvlm/video_processing_smolvlm.py +16 -17
  1299. transformers/models/solar_open/__init__.py +27 -0
  1300. transformers/models/solar_open/configuration_solar_open.py +184 -0
  1301. transformers/models/solar_open/modeling_solar_open.py +642 -0
  1302. transformers/models/solar_open/modular_solar_open.py +224 -0
  1303. transformers/models/speech_encoder_decoder/configuration_speech_encoder_decoder.py +0 -1
  1304. transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +26 -27
  1305. transformers/models/speech_to_text/configuration_speech_to_text.py +9 -9
  1306. transformers/models/speech_to_text/feature_extraction_speech_to_text.py +10 -13
  1307. transformers/models/speech_to_text/modeling_speech_to_text.py +55 -57
  1308. transformers/models/speech_to_text/processing_speech_to_text.py +4 -30
  1309. transformers/models/speech_to_text/tokenization_speech_to_text.py +5 -6
  1310. transformers/models/speecht5/configuration_speecht5.py +7 -9
  1311. transformers/models/speecht5/feature_extraction_speecht5.py +16 -37
  1312. transformers/models/speecht5/modeling_speecht5.py +172 -174
  1313. transformers/models/speecht5/number_normalizer.py +0 -1
  1314. transformers/models/speecht5/processing_speecht5.py +3 -37
  1315. transformers/models/speecht5/tokenization_speecht5.py +4 -5
  1316. transformers/models/splinter/configuration_splinter.py +6 -7
  1317. transformers/models/splinter/modeling_splinter.py +62 -59
  1318. transformers/models/splinter/tokenization_splinter.py +2 -4
  1319. transformers/models/squeezebert/configuration_squeezebert.py +14 -2
  1320. transformers/models/squeezebert/modeling_squeezebert.py +60 -62
  1321. transformers/models/squeezebert/tokenization_squeezebert.py +0 -1
  1322. transformers/models/stablelm/configuration_stablelm.py +28 -29
  1323. transformers/models/stablelm/modeling_stablelm.py +44 -47
  1324. transformers/models/starcoder2/configuration_starcoder2.py +30 -27
  1325. transformers/models/starcoder2/modeling_starcoder2.py +38 -41
  1326. transformers/models/starcoder2/modular_starcoder2.py +17 -19
  1327. transformers/models/superglue/configuration_superglue.py +7 -3
  1328. transformers/models/superglue/image_processing_superglue.py +15 -15
  1329. transformers/models/superglue/image_processing_superglue_fast.py +8 -8
  1330. transformers/models/superglue/modeling_superglue.py +41 -37
  1331. transformers/models/superpoint/image_processing_superpoint.py +15 -15
  1332. transformers/models/superpoint/image_processing_superpoint_fast.py +7 -9
  1333. transformers/models/superpoint/modeling_superpoint.py +17 -16
  1334. transformers/models/swiftformer/configuration_swiftformer.py +0 -1
  1335. transformers/models/swiftformer/modeling_swiftformer.py +12 -14
  1336. transformers/models/swin/configuration_swin.py +2 -5
  1337. transformers/models/swin/modeling_swin.py +69 -78
  1338. transformers/models/swin2sr/configuration_swin2sr.py +0 -1
  1339. transformers/models/swin2sr/image_processing_swin2sr.py +10 -13
  1340. transformers/models/swin2sr/image_processing_swin2sr_fast.py +4 -7
  1341. transformers/models/swin2sr/modeling_swin2sr.py +30 -30
  1342. transformers/models/swinv2/configuration_swinv2.py +2 -5
  1343. transformers/models/swinv2/modeling_swinv2.py +65 -74
  1344. transformers/models/switch_transformers/configuration_switch_transformers.py +11 -7
  1345. transformers/models/switch_transformers/modeling_switch_transformers.py +35 -36
  1346. transformers/models/switch_transformers/modular_switch_transformers.py +32 -33
  1347. transformers/models/t5/configuration_t5.py +9 -9
  1348. transformers/models/t5/modeling_t5.py +80 -85
  1349. transformers/models/t5/tokenization_t5.py +1 -3
  1350. transformers/models/t5gemma/configuration_t5gemma.py +43 -59
  1351. transformers/models/t5gemma/modeling_t5gemma.py +105 -108
  1352. transformers/models/t5gemma/modular_t5gemma.py +128 -142
  1353. transformers/models/t5gemma2/configuration_t5gemma2.py +86 -100
  1354. transformers/models/t5gemma2/modeling_t5gemma2.py +234 -194
  1355. transformers/models/t5gemma2/modular_t5gemma2.py +279 -264
  1356. transformers/models/table_transformer/configuration_table_transformer.py +18 -50
  1357. transformers/models/table_transformer/modeling_table_transformer.py +73 -101
  1358. transformers/models/tapas/configuration_tapas.py +12 -2
  1359. transformers/models/tapas/modeling_tapas.py +65 -67
  1360. transformers/models/tapas/tokenization_tapas.py +116 -153
  1361. transformers/models/textnet/configuration_textnet.py +4 -7
  1362. transformers/models/textnet/image_processing_textnet.py +22 -25
  1363. transformers/models/textnet/image_processing_textnet_fast.py +8 -9
  1364. transformers/models/textnet/modeling_textnet.py +28 -28
  1365. transformers/models/time_series_transformer/configuration_time_series_transformer.py +5 -8
  1366. transformers/models/time_series_transformer/modeling_time_series_transformer.py +82 -84
  1367. transformers/models/timesfm/configuration_timesfm.py +0 -1
  1368. transformers/models/timesfm/modeling_timesfm.py +22 -25
  1369. transformers/models/timesfm/modular_timesfm.py +21 -24
  1370. transformers/models/timesformer/configuration_timesformer.py +0 -1
  1371. transformers/models/timesformer/modeling_timesformer.py +13 -16
  1372. transformers/models/timm_backbone/configuration_timm_backbone.py +33 -8
  1373. transformers/models/timm_backbone/modeling_timm_backbone.py +25 -30
  1374. transformers/models/timm_wrapper/configuration_timm_wrapper.py +2 -3
  1375. transformers/models/timm_wrapper/image_processing_timm_wrapper.py +4 -5
  1376. transformers/models/timm_wrapper/modeling_timm_wrapper.py +22 -19
  1377. transformers/models/trocr/configuration_trocr.py +11 -8
  1378. transformers/models/trocr/modeling_trocr.py +42 -42
  1379. transformers/models/trocr/processing_trocr.py +5 -25
  1380. transformers/models/tvp/configuration_tvp.py +10 -36
  1381. transformers/models/tvp/image_processing_tvp.py +50 -52
  1382. transformers/models/tvp/image_processing_tvp_fast.py +15 -15
  1383. transformers/models/tvp/modeling_tvp.py +26 -28
  1384. transformers/models/tvp/processing_tvp.py +2 -14
  1385. transformers/models/udop/configuration_udop.py +16 -8
  1386. transformers/models/udop/modeling_udop.py +73 -72
  1387. transformers/models/udop/processing_udop.py +7 -26
  1388. transformers/models/udop/tokenization_udop.py +80 -93
  1389. transformers/models/umt5/configuration_umt5.py +8 -7
  1390. transformers/models/umt5/modeling_umt5.py +87 -84
  1391. transformers/models/unispeech/configuration_unispeech.py +4 -2
  1392. transformers/models/unispeech/modeling_unispeech.py +54 -53
  1393. transformers/models/unispeech/modular_unispeech.py +20 -22
  1394. transformers/models/unispeech_sat/configuration_unispeech_sat.py +4 -2
  1395. transformers/models/unispeech_sat/modeling_unispeech_sat.py +70 -69
  1396. transformers/models/unispeech_sat/modular_unispeech_sat.py +21 -23
  1397. transformers/models/univnet/feature_extraction_univnet.py +14 -14
  1398. transformers/models/univnet/modeling_univnet.py +7 -8
  1399. transformers/models/upernet/configuration_upernet.py +8 -36
  1400. transformers/models/upernet/modeling_upernet.py +11 -14
  1401. transformers/models/vaultgemma/__init__.py +0 -1
  1402. transformers/models/vaultgemma/configuration_vaultgemma.py +29 -33
  1403. transformers/models/vaultgemma/modeling_vaultgemma.py +38 -40
  1404. transformers/models/vaultgemma/modular_vaultgemma.py +29 -31
  1405. transformers/models/video_llama_3/configuration_video_llama_3.py +4 -0
  1406. transformers/models/video_llama_3/image_processing_video_llama_3.py +40 -40
  1407. transformers/models/video_llama_3/image_processing_video_llama_3_fast.py +12 -14
  1408. transformers/models/video_llama_3/modeling_video_llama_3.py +149 -112
  1409. transformers/models/video_llama_3/modular_video_llama_3.py +152 -150
  1410. transformers/models/video_llama_3/processing_video_llama_3.py +5 -39
  1411. transformers/models/video_llama_3/video_processing_video_llama_3.py +45 -24
  1412. transformers/models/video_llava/configuration_video_llava.py +4 -1
  1413. transformers/models/video_llava/image_processing_video_llava.py +35 -38
  1414. transformers/models/video_llava/modeling_video_llava.py +139 -143
  1415. transformers/models/video_llava/processing_video_llava.py +38 -78
  1416. transformers/models/video_llava/video_processing_video_llava.py +0 -1
  1417. transformers/models/videomae/configuration_videomae.py +0 -1
  1418. transformers/models/videomae/image_processing_videomae.py +31 -34
  1419. transformers/models/videomae/modeling_videomae.py +17 -20
  1420. transformers/models/videomae/video_processing_videomae.py +0 -1
  1421. transformers/models/vilt/configuration_vilt.py +4 -2
  1422. transformers/models/vilt/image_processing_vilt.py +29 -30
  1423. transformers/models/vilt/image_processing_vilt_fast.py +15 -16
  1424. transformers/models/vilt/modeling_vilt.py +103 -90
  1425. transformers/models/vilt/processing_vilt.py +2 -14
  1426. transformers/models/vipllava/configuration_vipllava.py +4 -1
  1427. transformers/models/vipllava/modeling_vipllava.py +92 -67
  1428. transformers/models/vipllava/modular_vipllava.py +78 -54
  1429. transformers/models/vision_encoder_decoder/configuration_vision_encoder_decoder.py +0 -1
  1430. transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +28 -27
  1431. transformers/models/vision_text_dual_encoder/configuration_vision_text_dual_encoder.py +0 -1
  1432. transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +45 -41
  1433. transformers/models/vision_text_dual_encoder/processing_vision_text_dual_encoder.py +2 -16
  1434. transformers/models/visual_bert/configuration_visual_bert.py +6 -2
  1435. transformers/models/visual_bert/modeling_visual_bert.py +90 -92
  1436. transformers/models/vit/configuration_vit.py +2 -3
  1437. transformers/models/vit/image_processing_vit.py +19 -22
  1438. transformers/models/vit/image_processing_vit_fast.py +0 -1
  1439. transformers/models/vit/modeling_vit.py +20 -20
  1440. transformers/models/vit_mae/configuration_vit_mae.py +0 -1
  1441. transformers/models/vit_mae/modeling_vit_mae.py +32 -30
  1442. transformers/models/vit_msn/configuration_vit_msn.py +0 -1
  1443. transformers/models/vit_msn/modeling_vit_msn.py +21 -19
  1444. transformers/models/vitdet/configuration_vitdet.py +2 -5
  1445. transformers/models/vitdet/modeling_vitdet.py +14 -17
  1446. transformers/models/vitmatte/configuration_vitmatte.py +7 -39
  1447. transformers/models/vitmatte/image_processing_vitmatte.py +15 -18
  1448. transformers/models/vitmatte/image_processing_vitmatte_fast.py +16 -17
  1449. transformers/models/vitmatte/modeling_vitmatte.py +10 -12
  1450. transformers/models/vitpose/configuration_vitpose.py +7 -47
  1451. transformers/models/vitpose/image_processing_vitpose.py +24 -25
  1452. transformers/models/vitpose/image_processing_vitpose_fast.py +9 -10
  1453. transformers/models/vitpose/modeling_vitpose.py +15 -15
  1454. transformers/models/vitpose_backbone/configuration_vitpose_backbone.py +2 -5
  1455. transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +13 -16
  1456. transformers/models/vits/configuration_vits.py +4 -1
  1457. transformers/models/vits/modeling_vits.py +43 -42
  1458. transformers/models/vits/tokenization_vits.py +3 -4
  1459. transformers/models/vivit/configuration_vivit.py +0 -1
  1460. transformers/models/vivit/image_processing_vivit.py +36 -39
  1461. transformers/models/vivit/modeling_vivit.py +9 -11
  1462. transformers/models/vjepa2/__init__.py +0 -1
  1463. transformers/models/vjepa2/configuration_vjepa2.py +0 -1
  1464. transformers/models/vjepa2/modeling_vjepa2.py +39 -41
  1465. transformers/models/vjepa2/video_processing_vjepa2.py +0 -1
  1466. transformers/models/voxtral/__init__.py +0 -1
  1467. transformers/models/voxtral/configuration_voxtral.py +0 -2
  1468. transformers/models/voxtral/modeling_voxtral.py +41 -48
  1469. transformers/models/voxtral/modular_voxtral.py +35 -38
  1470. transformers/models/voxtral/processing_voxtral.py +25 -48
  1471. transformers/models/wav2vec2/configuration_wav2vec2.py +4 -2
  1472. transformers/models/wav2vec2/feature_extraction_wav2vec2.py +7 -10
  1473. transformers/models/wav2vec2/modeling_wav2vec2.py +74 -126
  1474. transformers/models/wav2vec2/processing_wav2vec2.py +6 -35
  1475. transformers/models/wav2vec2/tokenization_wav2vec2.py +20 -332
  1476. transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py +4 -2
  1477. transformers/models/wav2vec2_bert/modeling_wav2vec2_bert.py +49 -52
  1478. transformers/models/wav2vec2_bert/modular_wav2vec2_bert.py +45 -48
  1479. transformers/models/wav2vec2_bert/processing_wav2vec2_bert.py +6 -35
  1480. transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py +4 -2
  1481. transformers/models/wav2vec2_conformer/modeling_wav2vec2_conformer.py +62 -65
  1482. transformers/models/wav2vec2_conformer/modular_wav2vec2_conformer.py +15 -18
  1483. transformers/models/wav2vec2_phoneme/tokenization_wav2vec2_phoneme.py +16 -17
  1484. transformers/models/wav2vec2_with_lm/processing_wav2vec2_with_lm.py +36 -55
  1485. transformers/models/wavlm/configuration_wavlm.py +4 -2
  1486. transformers/models/wavlm/modeling_wavlm.py +49 -49
  1487. transformers/models/wavlm/modular_wavlm.py +4 -5
  1488. transformers/models/whisper/configuration_whisper.py +6 -5
  1489. transformers/models/whisper/english_normalizer.py +3 -4
  1490. transformers/models/whisper/feature_extraction_whisper.py +9 -24
  1491. transformers/models/whisper/generation_whisper.py +26 -49
  1492. transformers/models/whisper/modeling_whisper.py +71 -73
  1493. transformers/models/whisper/processing_whisper.py +3 -20
  1494. transformers/models/whisper/tokenization_whisper.py +9 -30
  1495. transformers/models/x_clip/configuration_x_clip.py +4 -2
  1496. transformers/models/x_clip/modeling_x_clip.py +94 -96
  1497. transformers/models/x_clip/processing_x_clip.py +2 -14
  1498. transformers/models/xcodec/configuration_xcodec.py +4 -6
  1499. transformers/models/xcodec/modeling_xcodec.py +15 -17
  1500. transformers/models/xglm/configuration_xglm.py +9 -8
  1501. transformers/models/xglm/modeling_xglm.py +49 -55
  1502. transformers/models/xglm/tokenization_xglm.py +1 -4
  1503. transformers/models/xlm/configuration_xlm.py +10 -8
  1504. transformers/models/xlm/modeling_xlm.py +127 -131
  1505. transformers/models/xlm/tokenization_xlm.py +3 -5
  1506. transformers/models/xlm_roberta/configuration_xlm_roberta.py +11 -3
  1507. transformers/models/xlm_roberta/modeling_xlm_roberta.py +96 -98
  1508. transformers/models/xlm_roberta/modular_xlm_roberta.py +50 -53
  1509. transformers/models/xlm_roberta/tokenization_xlm_roberta.py +1 -4
  1510. transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py +10 -2
  1511. transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +97 -99
  1512. transformers/models/xlm_roberta_xl/modular_xlm_roberta_xl.py +67 -70
  1513. transformers/models/xlnet/configuration_xlnet.py +3 -12
  1514. transformers/models/xlnet/modeling_xlnet.py +149 -162
  1515. transformers/models/xlnet/tokenization_xlnet.py +1 -4
  1516. transformers/models/xlstm/configuration_xlstm.py +8 -12
  1517. transformers/models/xlstm/modeling_xlstm.py +61 -96
  1518. transformers/models/xmod/configuration_xmod.py +11 -3
  1519. transformers/models/xmod/modeling_xmod.py +111 -116
  1520. transformers/models/yolos/configuration_yolos.py +0 -1
  1521. transformers/models/yolos/image_processing_yolos.py +60 -62
  1522. transformers/models/yolos/image_processing_yolos_fast.py +42 -45
  1523. transformers/models/yolos/modeling_yolos.py +19 -21
  1524. transformers/models/yolos/modular_yolos.py +17 -19
  1525. transformers/models/yoso/configuration_yoso.py +8 -2
  1526. transformers/models/yoso/modeling_yoso.py +60 -62
  1527. transformers/models/youtu/__init__.py +27 -0
  1528. transformers/models/youtu/configuration_youtu.py +194 -0
  1529. transformers/models/youtu/modeling_youtu.py +619 -0
  1530. transformers/models/youtu/modular_youtu.py +254 -0
  1531. transformers/models/zamba/configuration_zamba.py +5 -8
  1532. transformers/models/zamba/modeling_zamba.py +93 -125
  1533. transformers/models/zamba2/configuration_zamba2.py +44 -50
  1534. transformers/models/zamba2/modeling_zamba2.py +137 -165
  1535. transformers/models/zamba2/modular_zamba2.py +79 -74
  1536. transformers/models/zoedepth/configuration_zoedepth.py +17 -41
  1537. transformers/models/zoedepth/image_processing_zoedepth.py +28 -29
  1538. transformers/models/zoedepth/image_processing_zoedepth_fast.py +20 -21
  1539. transformers/models/zoedepth/modeling_zoedepth.py +19 -19
  1540. transformers/pipelines/__init__.py +47 -106
  1541. transformers/pipelines/any_to_any.py +15 -23
  1542. transformers/pipelines/audio_utils.py +1 -2
  1543. transformers/pipelines/automatic_speech_recognition.py +0 -2
  1544. transformers/pipelines/base.py +13 -17
  1545. transformers/pipelines/image_text_to_text.py +1 -2
  1546. transformers/pipelines/question_answering.py +4 -43
  1547. transformers/pipelines/text_classification.py +1 -14
  1548. transformers/pipelines/text_to_audio.py +5 -1
  1549. transformers/pipelines/token_classification.py +1 -22
  1550. transformers/pipelines/video_classification.py +1 -9
  1551. transformers/pipelines/zero_shot_audio_classification.py +0 -1
  1552. transformers/pipelines/zero_shot_classification.py +0 -6
  1553. transformers/pipelines/zero_shot_image_classification.py +0 -7
  1554. transformers/processing_utils.py +128 -137
  1555. transformers/pytorch_utils.py +2 -26
  1556. transformers/quantizers/base.py +10 -0
  1557. transformers/quantizers/quantizer_compressed_tensors.py +7 -5
  1558. transformers/quantizers/quantizer_fbgemm_fp8.py +20 -23
  1559. transformers/quantizers/quantizer_finegrained_fp8.py +14 -20
  1560. transformers/quantizers/quantizer_mxfp4.py +1 -1
  1561. transformers/quantizers/quantizer_quark.py +0 -1
  1562. transformers/quantizers/quantizer_torchao.py +3 -19
  1563. transformers/safetensors_conversion.py +11 -4
  1564. transformers/testing_utils.py +6 -65
  1565. transformers/tokenization_mistral_common.py +563 -903
  1566. transformers/tokenization_python.py +6 -4
  1567. transformers/tokenization_utils_base.py +228 -341
  1568. transformers/tokenization_utils_sentencepiece.py +5 -6
  1569. transformers/tokenization_utils_tokenizers.py +36 -7
  1570. transformers/trainer.py +30 -41
  1571. transformers/trainer_jit_checkpoint.py +1 -2
  1572. transformers/trainer_seq2seq.py +1 -1
  1573. transformers/training_args.py +414 -420
  1574. transformers/utils/__init__.py +1 -4
  1575. transformers/utils/attention_visualizer.py +1 -1
  1576. transformers/utils/auto_docstring.py +567 -18
  1577. transformers/utils/backbone_utils.py +13 -373
  1578. transformers/utils/doc.py +4 -36
  1579. transformers/utils/dummy_pt_objects.py +0 -42
  1580. transformers/utils/generic.py +70 -34
  1581. transformers/utils/import_utils.py +72 -75
  1582. transformers/utils/loading_report.py +135 -107
  1583. transformers/utils/quantization_config.py +8 -31
  1584. transformers/video_processing_utils.py +24 -25
  1585. transformers/video_utils.py +21 -23
  1586. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/METADATA +120 -239
  1587. transformers-5.1.0.dist-info/RECORD +2092 -0
  1588. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/WHEEL +1 -1
  1589. transformers/pipelines/deprecated/text2text_generation.py +0 -408
  1590. transformers/pipelines/image_to_text.py +0 -229
  1591. transformers-5.0.0rc2.dist-info/RECORD +0 -2042
  1592. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/entry_points.txt +0 -0
  1593. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/licenses/LICENSE +0 -0
  1594. {transformers-5.0.0rc2.dist-info → transformers-5.1.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1633 @@
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+ # This file was automatically generated from src/transformers/models/glm_ocr/modular_glm_ocr.py.
+ # Do NOT edit this file manually as any edits will be overwritten by the generation of
+ # the file from the modular. If any change should be done, please apply the change to the
+ # modular_glm_ocr.py file directly. One of our CI enforces this.
+ # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
+ # Copyright 2026 the HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import itertools
+ from collections.abc import Callable
+ from dataclasses import dataclass
+ from typing import Any, Optional
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.nn import LayerNorm
+
+ from ... import initialization as init
+ from ...activations import ACT2FN
+ from ...cache_utils import Cache, DynamicCache
+ from ...generation import GenerationMixin
+ from ...integrations import use_kernel_forward_from_hub
+ from ...masking_utils import create_causal_mask
+ from ...modeling_flash_attention_utils import FlashAttentionKwargs
+ from ...modeling_layers import GradientCheckpointingLayer
+ from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput
+ from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+ from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+ from ...processing_utils import Unpack
+ from ...utils import (
+     TransformersKwargs,
+     auto_docstring,
+     can_return_tuple,
+     is_torchdynamo_compiling,
+     torch_compilable_check,
+ )
+ from ...utils.generic import check_model_inputs, is_flash_attention_requested, maybe_autocast
+ from .configuration_glm_ocr import GlmOcrConfig, GlmOcrTextConfig, GlmOcrVisionConfig
+
+
+ @use_kernel_forward_from_hub("RMSNorm")
+ class GlmOcrRMSNorm(nn.Module):
+     def __init__(self, hidden_size, eps=1e-6):
+         """
+         GlmOcrRMSNorm is equivalent to T5LayerNorm
+         """
+         super().__init__()
+         self.weight = nn.Parameter(torch.ones(hidden_size))
+         self.variance_epsilon = eps
+
+     def forward(self, hidden_states):
+         input_dtype = hidden_states.dtype
+         hidden_states = hidden_states.to(torch.float32)
+         variance = hidden_states.pow(2).mean(-1, keepdim=True)
+         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
+         return self.weight * hidden_states.to(input_dtype)
+
+     def extra_repr(self):
+         return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
+
+
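For readers skimming the diff: GlmOcrRMSNorm normalizes each hidden vector by its root mean square (accumulated in float32, then cast back to the input dtype) and applies a learned per-channel scale. A minimal standalone sketch of the same computation, illustrative only and not part of the package:

    import torch

    def rms_norm(hidden_states: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
        # Same math as GlmOcrRMSNorm.forward: float32 accumulation, then cast back to the input dtype.
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        return weight * (hidden_states * torch.rsqrt(variance + eps)).to(input_dtype)

    x = torch.randn(2, 5, 8, dtype=torch.float16)
    print(rms_norm(x, torch.ones(8)).shape)  # torch.Size([2, 5, 8])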
+ class GlmOcrVisionMlp(nn.Module):
+     def __init__(self, config, bias: bool = True):
+         super().__init__()
+         self.hidden_size = config.hidden_size
+         self.intermediate_size = config.intermediate_size
+         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias)
+         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=bias)
+         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=bias)
+         self.act_fn = ACT2FN[config.hidden_act]
+
+     def forward(self, hidden_state):
+         return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
+
+
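GlmOcrVisionMlp is a gated feed-forward block: the input is projected twice, the activated gate multiplies the up projection, and the product is projected back down. A shape-level sketch with hypothetical sizes (the real ones come from the vision config), not part of the package:

    import torch
    import torch.nn as nn

    hidden, intermediate = 16, 32        # hypothetical sizes, not read from any GlmOcr config
    gate_proj = nn.Linear(hidden, intermediate)
    up_proj = nn.Linear(hidden, intermediate)
    down_proj = nn.Linear(intermediate, hidden)
    act = nn.SiLU()                      # stand-in; the real activation is ACT2FN[config.hidden_act]

    x = torch.randn(3, hidden)
    print(down_proj(act(gate_proj(x)) * up_proj(x)).shape)  # torch.Size([3, 16])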
+ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+     """
+     This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+     num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+     """
+     batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+     if n_rep == 1:
+         return hidden_states
+     hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
+     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
+
+
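repeat_kv is the grouped-query-attention helper: key/value tensors carrying fewer heads are expanded to match the number of query heads before the attention matmuls. A quick shape check with illustrative values:

    import torch

    batch, kv_heads, n_rep, seq_len, head_dim = 2, 4, 2, 10, 64   # e.g. 8 query heads sharing 4 KV heads
    kv = torch.randn(batch, kv_heads, seq_len, head_dim)

    expanded = kv[:, :, None, :, :].expand(batch, kv_heads, n_rep, seq_len, head_dim)
    expanded = expanded.reshape(batch, kv_heads * n_rep, seq_len, head_dim)
    print(expanded.shape)  # torch.Size([2, 8, 10, 64])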
+ def eager_attention_forward(
+     module: nn.Module,
+     query: torch.Tensor,
+     key: torch.Tensor,
+     value: torch.Tensor,
+     attention_mask: torch.Tensor | None,
+     scaling: float,
+     dropout: float = 0.0,
+     **kwargs: Unpack[TransformersKwargs],
+ ):
+     key_states = repeat_kv(key, module.num_key_value_groups)
+     value_states = repeat_kv(value, module.num_key_value_groups)
+
+     attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
+     if attention_mask is not None:
+         causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
+         attn_weights = attn_weights + causal_mask
+
+     attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
+     attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
+     attn_output = torch.matmul(attn_weights, value_states)
+     attn_output = attn_output.transpose(1, 2).contiguous()
+
+     return attn_output, attn_weights
+
+
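eager_attention_forward takes (batch, heads, seq_len, head_dim) tensors plus any object exposing num_key_value_groups and training, and returns the attention output already transposed to (batch, seq_len, heads, head_dim) along with the attention weights. A toy invocation with made-up sizes, assuming the two functions above have been defined in scope:

    import types
    import torch

    # Assumes repeat_kv and eager_attention_forward from the hunk above are in scope.
    b, q_heads, kv_heads, s, d = 1, 8, 4, 6, 16
    module = types.SimpleNamespace(num_key_value_groups=q_heads // kv_heads, training=False)

    q = torch.randn(b, q_heads, s, d)
    k = torch.randn(b, kv_heads, s, d)
    v = torch.randn(b, kv_heads, s, d)

    out, weights = eager_attention_forward(module, q, k, v, attention_mask=None, scaling=d**-0.5)
    print(out.shape, weights.shape)  # torch.Size([1, 6, 8, 16]) torch.Size([1, 8, 6, 6])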
+ def rotate_half_llm(x):
+     """Rotates half the hidden dims of the input."""
+     x1 = x[..., 0::2]
+     x2 = x[..., 1::2]
+     return torch.stack((-x2, x1), dim=-1).flatten(-2)
+
+
134
+ def apply_rotary_pos_emb(q, k, cos, sin, unsqueeze_dim=1):
135
+ """Applies Rotary Position Embedding to the query and key tensors.
136
+
137
+ Args:
138
+ q (`torch.Tensor`): The query tensor.
139
+ k (`torch.Tensor`): The key tensor.
140
+ cos (`torch.Tensor`): The cosine part of the rotary embedding.
141
+ sin (`torch.Tensor`): The sine part of the rotary embedding.
142
+ unsqueeze_dim (`int`, *optional*, defaults to 1):
143
+ The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
144
+ sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
145
+ that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
146
+ k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
147
+ cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
148
+ the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
149
+ Returns:
150
+ `tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
151
+ """
152
+ cos = cos.unsqueeze(unsqueeze_dim)
153
+ sin = sin.unsqueeze(unsqueeze_dim)
154
+
155
+ # Interleave the rotary angles (each frequency repeated twice) instead of the usual half-split layout
156
+ cos = cos[..., : cos.shape[-1] // 2].repeat_interleave(2, dim=-1)
157
+ sin = sin[..., : sin.shape[-1] // 2].repeat_interleave(2, dim=-1)
158
+
159
+ # Keep half or full tensor for later concatenation
160
+ rotary_dim = cos.shape[-1]
161
+ q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
162
+ k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
163
+
164
+ # Apply rotary embeddings on the first half or full tensor
165
+ q_embed = (q_rot * cos) + (rotate_half_llm(q_rot) * sin)
166
+ k_embed = (k_rot * cos) + (rotate_half_llm(k_rot) * sin)
167
+
168
+ # Concatenate back to full shape
169
+ q_embed = torch.cat([q_embed, q_pass], dim=-1)
170
+ k_embed = torch.cat([k_embed, k_pass], dim=-1)
171
+ return q_embed, k_embed
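+ # Illustrative note, not part of the model: because cos/sin are repeat-interleaved and
+ # rotate_half_llm pairs even/odd channels, each adjacent pair (a, b) of the rotary
+ # portion is rotated by its angle theta:
+ #   (a, b) -> (a*cos(theta) - b*sin(theta), a*sin(theta) + b*cos(theta))
+ # while the *_pass channels beyond `rotary_dim` are passed through unchanged.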
172
+
173
+
174
+ class GlmOcrTextAttention(nn.Module):
175
+ """
176
+ Multi-headed attention from the papers 'Attention Is All You Need'
177
+ and "Generating Long Sequences with Sparse Transformers".
178
+ """
179
+
180
+ def __init__(self, config: GlmOcrTextConfig, layer_idx: int | None = None):
181
+ super().__init__()
182
+ self.config = config
183
+ self.layer_idx = layer_idx
184
+
185
+ self.hidden_size = config.hidden_size
186
+ self.num_heads = config.num_attention_heads
187
+ self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
188
+ self.num_key_value_heads = config.num_key_value_heads
189
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
190
+ self.is_causal = True
191
+ self.attention_dropout = config.attention_dropout
192
+ self.rope_parameters = config.rope_parameters
193
+ self.scaling = self.head_dim**-0.5
194
+ self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
195
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
196
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
197
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
198
+
199
+ def forward(
200
+ self,
201
+ hidden_states: torch.Tensor,
202
+ position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
203
+ attention_mask: torch.Tensor | None = None,
204
+ past_key_values: Cache | None = None,
205
+ cache_position: torch.LongTensor | None = None,
206
+ **kwargs: Unpack[FlashAttentionKwargs],
207
+ ) -> tuple[torch.Tensor, torch.Tensor | None, tuple[torch.Tensor] | None]:
208
+ bsz, q_len, _ = hidden_states.size()
209
+
210
+ query_states = self.q_proj(hidden_states)
211
+ key_states = self.k_proj(hidden_states)
212
+ value_states = self.v_proj(hidden_states)
213
+
214
+ query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
215
+ key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
216
+ value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
217
+
218
+ cos, sin = position_embeddings
219
+ query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
220
+
221
+ if past_key_values is not None:
222
+ cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} # Specific to RoPE models
223
+ key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
224
+
225
+ attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
226
+ self.config._attn_implementation, eager_attention_forward
227
+ )
228
+
229
+ attn_output, attn_weights = attention_interface(
230
+ self,
231
+ query_states,
232
+ key_states,
233
+ value_states,
234
+ attention_mask,
235
+ dropout=0.0 if not self.training else self.attention_dropout,
236
+ scaling=self.scaling,
237
+ **kwargs,
238
+ )
239
+
240
+ attn_output = attn_output.reshape(bsz, q_len, -1).contiguous()
241
+ attn_output = self.o_proj(attn_output)
242
+ return attn_output, attn_weights
243
+
244
+
245
+ class GlmOcrVisionRotaryEmbedding(nn.Module):
246
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
247
+
248
+ def __init__(self, dim: int, theta: float = 10000.0) -> None:
249
+ super().__init__()
250
+ self.dim = dim
251
+ self.theta = theta
252
+ inv_freq = 1.0 / (theta ** (torch.arange(0, dim, 2, dtype=torch.float) / dim))
253
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
254
+
255
+ def forward(self, seqlen: int) -> torch.Tensor:
256
+ seq = torch.arange(seqlen, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
257
+ freqs = torch.outer(seq, self.inv_freq)
258
+ return freqs
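+ # Illustrative note, not part of the model: position p gets angles p * inv_freq, so the
+ # returned table has shape (seqlen, dim // 2); rot_pos_emb in the vision model below
+ # indexes it with per-token (h, w) grid coordinates to build the 2D rotary embedding.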
259
+
260
+
261
+ class GlmOcrTextMLP(nn.Module):
262
+ def __init__(self, config):
263
+ super().__init__()
264
+
265
+ self.config = config
266
+ self.gate_up_proj = nn.Linear(config.hidden_size, 2 * config.intermediate_size, bias=False)
267
+ self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
268
+ self.activation_fn = ACT2FN[config.hidden_act]
269
+
270
+ def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
271
+ up_states = self.gate_up_proj(hidden_states)
272
+
273
+ gate, up_states = up_states.chunk(2, dim=-1)
274
+ up_states = up_states * self.activation_fn(gate)
275
+
276
+ return self.down_proj(up_states)
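+ # Illustrative sketch, not part of the model: the fused gate_up_proj is equivalent to two
+ # separate projections taken from the halves of its weight along dim 0:
+ #   gate = hidden_states @ gate_up_proj.weight[:intermediate_size].T
+ #   up   = hidden_states @ gate_up_proj.weight[intermediate_size:].T
+ #   out  = down_proj(up * activation_fn(gate))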
277
+
278
+
279
+ class GlmOcrTextDecoderLayer(GradientCheckpointingLayer):
280
+ def __init__(self, config: GlmOcrTextConfig, layer_idx: int):
281
+ super().__init__()
282
+ self.hidden_size = config.hidden_size
283
+ self.self_attn = GlmOcrTextAttention(config, layer_idx)
284
+ self.mlp = GlmOcrTextMLP(config)
285
+ self.input_layernorm = GlmOcrRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
286
+ self.post_attention_layernorm = GlmOcrRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
287
+ self.post_self_attn_layernorm = GlmOcrRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
288
+ self.post_mlp_layernorm = GlmOcrRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
289
+
290
+ @auto_docstring
291
+ def forward(
292
+ self,
293
+ hidden_states: torch.Tensor,
294
+ position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
295
+ attention_mask: torch.Tensor | None = None,
296
+ position_ids: torch.LongTensor | None = None,
297
+ past_key_values: Cache | None = None,
298
+ use_cache: bool | None = False,
299
+ cache_position: torch.LongTensor | None = None,
300
+ **kwargs,
301
+ ) -> tuple[torch.FloatTensor, tuple[torch.FloatTensor, torch.FloatTensor] | None]:
302
+ residual = hidden_states
303
+
304
+ hidden_states = self.input_layernorm(hidden_states)
305
+
306
+ # Self Attention
307
+ hidden_states, _ = self.self_attn(
308
+ hidden_states=hidden_states,
309
+ position_embeddings=position_embeddings,
310
+ attention_mask=attention_mask,
311
+ position_ids=position_ids,
312
+ past_key_values=past_key_values,
313
+ use_cache=use_cache,
314
+ cache_position=cache_position,
315
+ **kwargs,
316
+ )
317
+
318
+ hidden_states = self.post_self_attn_layernorm(hidden_states)
319
+ hidden_states = residual + hidden_states
320
+
321
+ # Fully Connected
322
+ residual = hidden_states
323
+ hidden_states = self.post_attention_layernorm(hidden_states)
324
+ hidden_states = self.mlp(hidden_states)
325
+ hidden_states = self.post_mlp_layernorm(hidden_states)
326
+ hidden_states = residual + hidden_states
327
+
328
+ return hidden_states
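+ # Illustrative note, not part of the model: the layer uses a "sandwich" normalization
+ # scheme, adding RMSNorms after attention and after the MLP on top of the usual pre-norms:
+ #   x -> input_layernorm -> self_attn -> post_self_attn_layernorm -> (+ residual)
+ #     -> post_attention_layernorm -> mlp -> post_mlp_layernorm -> (+ residual)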
329
+
330
+
331
+ @auto_docstring
332
+ class GlmOcrPreTrainedModel(PreTrainedModel):
333
+ config: GlmOcrConfig
334
+ base_model_prefix = "model"
335
+ input_modalities = ("image", "video", "text")
336
+ supports_gradient_checkpointing = True
337
+ _no_split_modules = ["GlmOcrTextDecoderLayer", "GlmOcrVisionBlock"]
338
+ _skip_keys_device_placement = "past_key_values"
339
+ _supports_flash_attn = True
340
+ _supports_sdpa = True
341
+
342
+ _can_compile_fullgraph = True
343
+ _supports_attention_backend = True
344
+ _can_record_outputs = {
345
+ "hidden_states": GlmOcrTextDecoderLayer,
346
+ "attentions": GlmOcrTextAttention,
347
+ }
348
+ _keys_to_ignore_on_load_unexpected = [r"model\.language_model\.layers\.16.*"]
349
+
350
+ def _init_weights(self, module):
351
+ super()._init_weights(module)
352
+ if isinstance(module, GlmOcrVisionRotaryEmbedding):
353
+ inv_freq = 1.0 / (module.theta ** (torch.arange(0, module.dim, 2, dtype=torch.float) / module.dim))
354
+ init.copy_(module.inv_freq, inv_freq)
355
+
356
+
357
+ @dataclass
358
+ @auto_docstring(
359
+ custom_intro="""
360
+ Base class for GlmOcr model outputs, with hidden states and attentions.
361
+ """
362
+ )
363
+ class GlmOcrModelOutputWithPast(ModelOutput):
364
+ r"""
365
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
366
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
367
+
368
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
369
+ `past_key_values` input) to speed up sequential decoding.
370
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
371
+ The rope index difference between sequence length and multimodal rope.
372
+ """
373
+
374
+ last_hidden_state: torch.FloatTensor | None = None
375
+ past_key_values: Cache | None = None
376
+ hidden_states: tuple[torch.FloatTensor] | None = None
377
+ attentions: tuple[torch.FloatTensor] | None = None
378
+ rope_deltas: torch.LongTensor | None = None
379
+
380
+
381
+ def rotate_half(x):
382
+ """Rotates half the hidden dims of the input."""
383
+ x1 = x[..., : x.shape[-1] // 2]
384
+ x2 = x[..., x.shape[-1] // 2 :]
385
+ return torch.cat((-x2, x1), dim=-1)
386
+
387
+
388
+ def apply_rotary_pos_emb_vision(
389
+ q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor
390
+ ) -> tuple[torch.Tensor, torch.Tensor]:
391
+ orig_q_dtype = q.dtype
392
+ orig_k_dtype = k.dtype
393
+ q, k = q.float(), k.float()
394
+ cos, sin = cos.unsqueeze(-2).float(), sin.unsqueeze(-2).float()
395
+ q_embed = (q * cos) + (rotate_half(q) * sin)
396
+ k_embed = (k * cos) + (rotate_half(k) * sin)
397
+ q_embed = q_embed.to(orig_q_dtype)
398
+ k_embed = k_embed.to(orig_k_dtype)
399
+ return q_embed, k_embed
400
+
401
+
402
+ class GlmOcrVisionAttention(nn.Module):
403
+ def __init__(self, config: GlmOcrVisionConfig) -> None:
404
+ super().__init__()
405
+ self.dim = config.hidden_size
406
+ self.num_heads = config.num_heads
407
+ self.head_dim = self.dim // self.num_heads
408
+ self.num_key_value_groups = 1 # needed for eager attention
409
+ self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3, bias=config.attention_bias)
410
+ self.proj = nn.Linear(config.hidden_size, config.hidden_size, bias=config.attention_bias)
411
+ self.scaling = self.head_dim**-0.5
412
+ self.config = config
413
+ self.attention_dropout = config.attention_dropout
414
+ self.is_causal = False
415
+ self.q_norm = GlmOcrRMSNorm(self.head_dim, eps=config.rms_norm_eps)
416
+ self.k_norm = GlmOcrRMSNorm(self.head_dim, eps=config.rms_norm_eps)
417
+
418
+ def forward(
419
+ self,
420
+ hidden_states: torch.Tensor,
421
+ cu_seqlens: torch.Tensor,
422
+ rotary_pos_emb: torch.Tensor | None = None,
423
+ position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
424
+ **kwargs,
425
+ ) -> torch.Tensor:
426
+ seq_length = hidden_states.shape[0]
427
+ query_states, key_states, value_states = (
428
+ self.qkv(hidden_states).reshape(seq_length, 3, self.num_heads, -1).permute(1, 0, 2, 3).unbind(0)
429
+ )
430
+
431
+ query_states = self.q_norm(query_states)
432
+ key_states = self.k_norm(key_states)
433
+
434
+ cos, sin = position_embeddings
435
+ query_states, key_states = apply_rotary_pos_emb_vision(query_states, key_states, cos, sin)
436
+ query_states = query_states.transpose(0, 1).unsqueeze(0)
437
+ key_states = key_states.transpose(0, 1).unsqueeze(0)
438
+ value_states = value_states.transpose(0, 1).unsqueeze(0)
439
+
440
+ attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
441
+ self.config._attn_implementation, eager_attention_forward
442
+ )
443
+
444
+ if is_flash_attention_requested(self.config):
445
+ # Flash Attention: Use cu_seqlens for variable length attention
446
+ max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max()
447
+ attn_output, _ = attention_interface(
448
+ self,
449
+ query_states,
450
+ key_states,
451
+ value_states,
452
+ attention_mask=None,
453
+ scaling=self.scaling,
454
+ dropout=0.0 if not self.training else self.attention_dropout,
455
+ cu_seq_lens_q=cu_seqlens,
456
+ cu_seq_lens_k=cu_seqlens,
457
+ max_length_q=max_seqlen,
458
+ max_length_k=max_seqlen,
459
+ is_causal=False,
460
+ **kwargs,
461
+ )
462
+ else:
463
+ # Other implementations: Process each chunk separately
464
+ lengths = cu_seqlens[1:] - cu_seqlens[:-1]
465
+ splits = [
466
+ torch.split(tensor, lengths.tolist(), dim=2) for tensor in (query_states, key_states, value_states)
467
+ ]
468
+
469
+ attn_outputs = [
470
+ attention_interface(
471
+ self,
472
+ q,
473
+ k,
474
+ v,
475
+ attention_mask=None,
476
+ scaling=self.scaling,
477
+ dropout=0.0 if not self.training else self.attention_dropout,
478
+ is_causal=False,
479
+ **kwargs,
480
+ )[0]
481
+ for q, k, v in zip(*splits)
482
+ ]
483
+ attn_output = torch.cat(attn_outputs, dim=1)
484
+
485
+ attn_output = attn_output.reshape(seq_length, -1).contiguous()
486
+ attn_output = self.proj(attn_output)
487
+ return attn_output
488
+
489
+
490
+ class GlmOcrVisionBlock(GradientCheckpointingLayer):
491
+ def __init__(self, config) -> None:
492
+ super().__init__()
493
+ self.norm1 = GlmOcrRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
494
+ self.norm2 = GlmOcrRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
495
+ self.attn = GlmOcrVisionAttention(config)
496
+ self.mlp = GlmOcrVisionMlp(config, bias=config.attention_bias)
497
+
498
+ def forward(
499
+ self,
500
+ hidden_states: torch.Tensor,
501
+ cu_seqlens: torch.Tensor,
502
+ rotary_pos_emb: torch.Tensor | None = None,
503
+ position_embeddings: tuple[torch.Tensor, torch.Tensor] | None = None,
504
+ **kwargs,
505
+ ) -> torch.Tensor:
506
+ hidden_states = hidden_states + self.attn(
507
+ self.norm1(hidden_states),
508
+ cu_seqlens=cu_seqlens,
509
+ rotary_pos_emb=rotary_pos_emb,
510
+ position_embeddings=position_embeddings,
511
+ **kwargs,
512
+ )
513
+ hidden_states = hidden_states + self.mlp(self.norm2(hidden_states))
514
+ return hidden_states
515
+
516
+
517
+ class GlmOcrVisionPatchMerger(nn.Module):
518
+ def __init__(self, dim: int, context_dim: int, hidden_act: str, bias: bool = False) -> None:
519
+ super().__init__()
520
+ self.proj = nn.Linear(dim, dim, bias=bias)
521
+ self.post_projection_norm = LayerNorm(dim)
522
+ self.gate_proj = nn.Linear(dim, context_dim, bias=bias)
523
+ self.up_proj = nn.Linear(dim, context_dim, bias=bias)
524
+ self.down_proj = nn.Linear(context_dim, dim, bias=bias)
525
+ self.act1 = nn.GELU()
526
+ self.act_fn = ACT2FN[hidden_act]
527
+
528
+ def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
529
+ hidden_state = self.proj(hidden_state)
530
+ hidden_state = self.act1(self.post_projection_norm(hidden_state))
531
+ return self.down_proj(self.act_fn(self.gate_proj(hidden_state)) * self.up_proj(hidden_state))
532
+
533
+
534
+ class GlmOcrVisionPatchEmbed(nn.Module):
535
+ def __init__(self, config: GlmOcrVisionConfig) -> None:
536
+ super().__init__()
537
+ self.patch_size = config.patch_size
538
+ self.temporal_patch_size = config.temporal_patch_size
539
+ self.in_channels = config.in_channels
540
+ self.embed_dim = config.hidden_size
541
+
542
+ kernel_size = [self.temporal_patch_size, self.patch_size, self.patch_size]
543
+ self.proj = nn.Conv3d(self.in_channels, self.embed_dim, kernel_size=kernel_size, stride=kernel_size)
544
+
545
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
546
+ target_dtype = self.proj.weight.dtype
547
+ hidden_states = hidden_states.view(
548
+ -1, self.in_channels, self.temporal_patch_size, self.patch_size, self.patch_size
549
+ )
550
+ hidden_states = self.proj(hidden_states.to(dtype=target_dtype)).view(-1, self.embed_dim)
551
+ return hidden_states
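+ # Shape sketch, illustrative only: flattened patches of length
+ # in_channels * temporal_patch_size * patch_size**2 are reshaped into 3D patches and
+ # projected by a Conv3d whose kernel equals its stride, so each patch collapses to a
+ # single row of the (num_patches, embed_dim) output.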
552
+
553
+
554
+ class GlmOcrVisionModel(GlmOcrPreTrainedModel):
555
+ config: GlmOcrVisionConfig
556
+ input_modalities = ("image", "video")
557
+ _no_split_modules = ["GlmOcrVisionBlock"]
558
+ _can_record_outputs = {
559
+ "hidden_states": GlmOcrVisionBlock,
560
+ "attentions": GlmOcrVisionAttention,
561
+ }
562
+
563
+ def __init__(self, config) -> None:
564
+ super().__init__(config)
565
+ self.spatial_merge_size = config.spatial_merge_size
566
+ self.patch_size = config.patch_size
567
+ self.patch_embed = GlmOcrVisionPatchEmbed(config)
568
+
569
+ head_dim = config.hidden_size // config.num_heads
570
+ self.rotary_pos_emb = GlmOcrVisionRotaryEmbedding(head_dim // 2)
571
+
572
+ self.blocks = nn.ModuleList([GlmOcrVisionBlock(config) for _ in range(config.depth)])
573
+ self.merger = GlmOcrVisionPatchMerger(
574
+ dim=config.out_hidden_size,
575
+ context_dim=config.out_hidden_size * config.in_channels,
576
+ hidden_act=config.hidden_act,
577
+ )
578
+ self.downsample = nn.Conv2d(
579
+ in_channels=config.hidden_size,
580
+ out_channels=config.out_hidden_size,
581
+ kernel_size=config.spatial_merge_size,
582
+ stride=config.spatial_merge_size,
583
+ )
584
+ self.post_layernorm = GlmOcrRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
585
+
586
+ self.gradient_checkpointing = False
587
+ self.post_init()
588
+
589
+ def rot_pos_emb(self, grid_thw):
590
+ pos_ids = []
591
+ for t, h, w in grid_thw:
592
+ hpos_ids = torch.arange(h).unsqueeze(1).expand(-1, w)
593
+ hpos_ids = hpos_ids.reshape(
594
+ h // self.spatial_merge_size,
595
+ self.spatial_merge_size,
596
+ w // self.spatial_merge_size,
597
+ self.spatial_merge_size,
598
+ )
599
+ hpos_ids = hpos_ids.permute(0, 2, 1, 3)
600
+ hpos_ids = hpos_ids.flatten()
601
+
602
+ wpos_ids = torch.arange(w).unsqueeze(0).expand(h, -1)
603
+ wpos_ids = wpos_ids.reshape(
604
+ h // self.spatial_merge_size,
605
+ self.spatial_merge_size,
606
+ w // self.spatial_merge_size,
607
+ self.spatial_merge_size,
608
+ )
609
+ wpos_ids = wpos_ids.permute(0, 2, 1, 3)
610
+ wpos_ids = wpos_ids.flatten()
611
+ pos_ids.append(torch.stack([hpos_ids, wpos_ids], dim=-1).repeat(t, 1))
612
+ pos_ids = torch.cat(pos_ids, dim=0)
613
+ max_grid_size = grid_thw[:, 1:].max()
614
+ rotary_pos_emb_full = self.rotary_pos_emb(max_grid_size)
615
+ rotary_pos_emb = rotary_pos_emb_full[pos_ids].flatten(1)
616
+ return rotary_pos_emb, pos_ids
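+ # Illustrative example, not part of the model: for a 4x4 grid with spatial_merge_size=2,
+ # the (h, w) ids are emitted window by window, e.g.
+ #   (0,0), (0,1), (1,0), (1,1), (0,2), (0,3), (1,2), (1,3), (2,0), ...
+ # so that the tokens of each 2x2 merge window are contiguous in the sequence.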
617
+
618
+ @check_model_inputs
619
+ @auto_docstring
620
+ def forward(self, hidden_states: torch.Tensor, grid_thw: torch.Tensor, **kwargs) -> tuple | BaseModelOutputWithPooling:
621
+ r"""
622
+ hidden_states (`torch.Tensor` of shape `(seq_len, hidden_size)`):
623
+ Packed and flattened image patches to be embedded and encoded by the vision transformer.
624
+ grid_thw (`torch.Tensor` of shape `(num_images_or_videos, 3)`):
625
+ The temporal, height and width of feature shape of each image in LLM.
626
+
627
+ Returns:
628
+ `BaseModelOutputWithPooling`: the per-patch hidden states and the merged (pooled) image features.
629
+ """
630
+ hidden_states = self.patch_embed(hidden_states)
631
+ rotary_pos_emb, image_type_ids = self.rot_pos_emb(grid_thw)
632
+ emb = torch.cat((rotary_pos_emb, rotary_pos_emb), dim=-1)
633
+ position_embeddings = (emb.cos(), emb.sin())
634
+
635
+ cu_seqlens = torch.repeat_interleave(grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0]).cumsum(
636
+ dim=0,
637
+ # Select dtype based on the following factors:
638
+ # - FA2 requires that cu_seqlens_q must have dtype int32
639
+ # - torch.onnx.export requires that cu_seqlens_q must have same dtype as grid_thw
640
+ # See https://github.com/huggingface/transformers/pull/34852 for more information
641
+ dtype=grid_thw.dtype if torch.jit.is_tracing() else torch.int32,
642
+ )
643
+ cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0)
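+ # Illustrative example, not part of the model: for grid_thw = [[1, 4, 6], [2, 2, 2]] the
+ # per-frame token counts are [24, 4, 4], so cu_seqlens = [0, 24, 28, 32]; each [start, end)
+ # slice delimits one frame for the variable-length attention below.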
644
+
645
+ for blk in self.blocks:
646
+ hidden_states = blk(
647
+ hidden_states,
648
+ cu_seqlens=cu_seqlens,
649
+ position_embeddings=position_embeddings,
650
+ )
651
+
652
+ hidden_states = self.post_layernorm(hidden_states)
653
+
654
+ hidden_states = hidden_states.view(
655
+ -1, self.spatial_merge_size, self.spatial_merge_size, hidden_states.shape[-1]
656
+ )
657
+ hidden_states = hidden_states.permute(0, 3, 1, 2)
658
+ hidden_states = self.downsample(hidden_states).view(-1, self.config.out_hidden_size)
659
+
660
+ merged_hidden_states = self.merger(hidden_states)
661
+ return BaseModelOutputWithPooling(
662
+ last_hidden_state=hidden_states,
663
+ pooler_output=merged_hidden_states,
664
+ )
665
+
666
+
667
+ class GlmOcrTextRotaryEmbedding(nn.Module):
668
+ inv_freq: torch.Tensor # fix linting for `register_buffer`
669
+
670
+ def __init__(self, config: GlmOcrTextConfig, device=None):
671
+ super().__init__()
672
+ self.max_seq_len_cached = config.max_position_embeddings
673
+ self.original_max_seq_len = config.max_position_embeddings
674
+
675
+ self.config = config
676
+
677
+ self.rope_type = self.config.rope_parameters["rope_type"]
678
+ rope_init_fn: Callable = self.compute_default_rope_parameters
679
+ if self.rope_type != "default":
680
+ rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
681
+ inv_freq, self.attention_scaling = rope_init_fn(self.config, device)
682
+
683
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
684
+ self.register_buffer("original_inv_freq", inv_freq.clone(), persistent=False)
685
+ self.mrope_section = config.rope_parameters.get("mrope_section", [8, 12, 12])
686
+
687
+ @staticmethod
688
+ def compute_default_rope_parameters(
689
+ config: GlmOcrTextConfig | None = None,
690
+ device: Optional["torch.device"] = None,
691
+ seq_len: int | None = None,
692
+ ) -> tuple["torch.Tensor", float]:
693
+ """
694
+ Computes the inverse frequencies according to the original RoPE implementation
695
+ Args:
696
+ config ([`~transformers.PreTrainedConfig`]):
697
+ The model configuration.
698
+ device (`torch.device`):
699
+ The device to use for initialization of the inverse frequencies.
700
+ seq_len (`int`, *optional*):
701
+ The current sequence length. Unused for this type of RoPE.
702
+ Returns:
703
+ Tuple of (`torch.Tensor`, `float`), containing the inverse frequencies for the RoPE embeddings and the
704
+ post-processing scaling factor applied to the computed cos/sin (unused in this type of RoPE).
705
+ """
706
+ base = config.rope_parameters["rope_theta"]
707
+ partial_rotary_factor = config.rope_parameters.get("partial_rotary_factor", 1.0)
708
+ head_dim = getattr(config, "head_dim", None) or config.hidden_size // config.num_attention_heads
709
+ dim = int(head_dim * partial_rotary_factor)
710
+
711
+ attention_factor = 1.0 # Unused in this type of RoPE
712
+
713
+ # Compute the inverse frequencies
714
+ inv_freq = 1.0 / (
715
+ base ** (torch.arange(0, dim, 2, dtype=torch.int64).to(device=device, dtype=torch.float) / dim)
716
+ )
717
+ return inv_freq, attention_factor
718
+
719
+ @torch.no_grad()
720
+ @dynamic_rope_update # power user: used with advanced RoPE types (e.g. dynamic rope)
721
+ def forward(self, x, position_ids):
722
+ # In contrast to other models, GLM-V has different position ids for the grids
723
+ # So we expand the inv_freq to shape (3, ...)
724
+ inv_freq_expanded = self.inv_freq[None, None, :, None].float().expand(3, position_ids.shape[1], -1, 1)
725
+ position_ids_expanded = position_ids[:, :, None, :].float() # shape (3, bs, 1, positions)
726
+
727
+ device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
728
+ with maybe_autocast(device_type=device_type, enabled=False): # Force float32
729
+ freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(2, 3)
730
+ freqs = self.apply_mrope(freqs, self.mrope_section)
731
+ emb = torch.cat((freqs, freqs), dim=-1)
732
+ cos = emb.cos() * self.attention_scaling
733
+ sin = emb.sin() * self.attention_scaling
734
+
735
+ return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
736
+
737
+ def apply_mrope(self, freqs, mrope_section):
738
+ section = mrope_section
739
+ chunks = freqs.split(section, dim=-1)
740
+ result = torch.cat([chunk[i % 3] for i, chunk in enumerate(chunks)], dim=-1)
741
+ return result
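+ # Illustrative note, not part of the model: with mrope_section = [8, 12, 12] (the fallback
+ # above), the first 8 rotary angles come from the temporal grid, the next 12 from the
+ # height grid and the last 12 from the width grid, i.e.
+ #   result = cat([freqs[0][..., :8], freqs[1][..., 8:20], freqs[2][..., 20:32]], dim=-1)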
742
+
743
+
744
+ @auto_docstring
745
+ class GlmOcrTextModel(GlmOcrPreTrainedModel):
746
+ config: GlmOcrTextConfig
747
+ input_modalities = ("text",)
748
+
749
+ def __init__(self, config: GlmOcrTextConfig):
750
+ super().__init__(config)
751
+ self.padding_idx = config.pad_token_id
752
+ self.vocab_size = config.vocab_size
753
+
754
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
755
+ self.layers = nn.ModuleList(
756
+ [GlmOcrTextDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
757
+ )
758
+ self.norm = GlmOcrRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
759
+ self.rotary_emb = GlmOcrTextRotaryEmbedding(config=config)
760
+
761
+ self.gradient_checkpointing = False
762
+ # Initialize weights and apply final processing
763
+ self.post_init()
764
+
765
+ @auto_docstring
766
+ @check_model_inputs
767
+ def forward(
768
+ self,
769
+ input_ids: torch.LongTensor | None = None,
770
+ attention_mask: torch.Tensor | None = None,
771
+ position_ids: torch.LongTensor | None = None,
772
+ past_key_values: Cache | None = None,
773
+ inputs_embeds: torch.FloatTensor | None = None,
774
+ use_cache: bool | None = None,
775
+ cache_position: torch.LongTensor | None = None,
776
+ **kwargs: Unpack[FlashAttentionKwargs],
777
+ ) -> tuple | BaseModelOutputWithPast:
778
+ if (input_ids is None) ^ (inputs_embeds is not None):
779
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
780
+
781
+ # torch.jit.trace() doesn't support cache objects in the output
782
+ if use_cache and past_key_values is None and not torch.jit.is_tracing():
783
+ past_key_values = DynamicCache(config=self.config)
784
+
785
+ if inputs_embeds is None:
786
+ inputs_embeds = self.embed_tokens(input_ids)
787
+
788
+ if cache_position is None:
789
+ past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
790
+ cache_position = torch.arange(
791
+ past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
792
+ )
793
+
794
+ # the hard coded `3` is for temporal, height and width.
795
+ if position_ids is None:
796
+ position_ids = cache_position.view(1, 1, -1).expand(3, inputs_embeds.shape[0], -1)
797
+ elif position_ids.ndim == 2:
798
+ position_ids = position_ids[None, ...].expand(3, position_ids.shape[0], -1)
799
+
800
+ # NOTE: we need to pass text position ids for packing. Qwen2-VL uses 3D positions
801
+ # where each dim indicates visual spatial positions for temporal/height/width grids.
802
+ # There are two scenarios when FA2-like packed masking might be activated.
803
+ # 1. User specifically passed packed `position_ids` and no attention mask.
804
+ # In this case we expect the user to create correct position ids for all 3 grids
805
+ # and prepend text-only position ids to it. The final tensor will be [4, bs, seq-len]
806
+ # 2. User runs forward with no attention mask and no position ids. In this case, position ids
807
+ # are prepared by the model (`get_rope_index`) as `[4, bs, seq-len]` tensor. Text-only positions are
808
+ # prepended by us when creating positions so that the mask is constructed correctly. NOTE: failing to pass
809
+ # text-only positions will cause incorrect mask construction; do not change `prepare_inputs_for_generation`
810
+ if position_ids.ndim == 3 and position_ids.shape[0] == 4:
811
+ text_position_ids = position_ids[0]
812
+ position_ids = position_ids[1:]
813
+ else:
814
+ # If inputs are not packed (usual 3D positions), do not prepare mask from position_ids
815
+ text_position_ids = None
816
+
817
+ mask_kwargs = {
818
+ "config": self.config,
819
+ "input_embeds": inputs_embeds,
820
+ "attention_mask": attention_mask,
821
+ "cache_position": cache_position,
822
+ "past_key_values": past_key_values,
823
+ "position_ids": text_position_ids,
824
+ }
825
+ # Create the masks
826
+ causal_mask = create_causal_mask(**mask_kwargs)
827
+
828
+ hidden_states = inputs_embeds
829
+ position_embeddings = self.rotary_emb(hidden_states, position_ids=position_ids)
830
+
831
+ for decoder_layer in self.layers:
832
+ layer_outputs = decoder_layer(
833
+ hidden_states,
834
+ attention_mask=causal_mask,
835
+ position_ids=text_position_ids,
836
+ past_key_values=past_key_values,
837
+ cache_position=cache_position,
838
+ position_embeddings=position_embeddings,
839
+ **kwargs,
840
+ )
841
+ hidden_states = layer_outputs
842
+
843
+ hidden_states = self.norm(hidden_states)
844
+
845
+ return BaseModelOutputWithPast(
846
+ last_hidden_state=hidden_states,
847
+ past_key_values=past_key_values,
848
+ )
849
+
850
+
851
+ @auto_docstring
852
+ class GlmOcrModel(GlmOcrPreTrainedModel):
853
+ base_model_prefix = "model"
854
+ _checkpoint_conversion_mapping = {}
855
+ # Reference: fix gemma3 grad acc #37208
856
+ accepts_loss_kwargs = False
857
+ config: GlmOcrConfig
858
+ _no_split_modules = ["GlmOcrTextDecoderLayer", "GlmOcrVisionBlock"]
859
+
860
+ def __init__(self, config):
861
+ super().__init__(config)
862
+ self.visual = GlmOcrVisionModel._from_config(config.vision_config)
863
+ self.language_model = GlmOcrTextModel._from_config(config.text_config)
864
+ self.rope_deltas = None # cache rope_deltas here
865
+
866
+ # Initialize weights and apply final processing
867
+ self.post_init()
868
+
869
+ def get_input_embeddings(self):
870
+ return self.language_model.get_input_embeddings()
871
+
872
+ def set_input_embeddings(self, value):
873
+ self.language_model.set_input_embeddings(value)
874
+
875
+ def get_rope_index(
876
+ self,
877
+ input_ids: torch.LongTensor | None = None,
878
+ image_grid_thw: torch.LongTensor | None = None,
879
+ video_grid_thw: torch.LongTensor | None = None,
880
+ attention_mask: torch.Tensor | None = None,
881
+ ) -> tuple[torch.Tensor, torch.Tensor]:
882
+ """
883
+ Calculate the 3D rope index based on image and video's temporal, height and width in LLM.
884
+
885
+ Explanation:
886
+ Each embedding sequence contains vision embedding and text embedding or just contains text embedding.
887
+
888
+ For pure text embedding sequence, the rotary position embedding has no difference with modern LLMs.
889
+ Examples:
890
+ input_ids: [T T T T T], here T is for text.
891
+ temporal position_ids: [0, 1, 2, 3, 4]
892
+ height position_ids: [0, 1, 2, 3, 4]
893
+ width position_ids: [0, 1, 2, 3, 4]
894
+
895
+ For vision and text embedding sequence, we calculate 3D rotary position embedding for vision part
896
+ and 1D rotary position embedding for text part.
897
+ Examples:
898
+ Temporal (Time): 3 patches, representing different segments of the video in time.
899
+ Height: 2 patches, dividing each frame vertically.
900
+ Width: 2 patches, dividing each frame horizontally.
901
+ We also have some important parameters:
902
+ fps (Frames Per Second): The video's frame rate, set to 1. This means one frame is processed each second.
903
+ tokens_per_second: This is a crucial parameter. It dictates how many "time-steps" or "temporal tokens" are conceptually packed into a one-second interval of the video. In this case, we have 25 tokens per second. So each second of the video will be represented with 25 separate time points. It essentially defines the temporal granularity.
904
+ temporal_patch_size: The number of frames that compose one temporal patch. Here, it's 2 frames.
905
+ interval: The step size for the temporal position IDs, calculated as tokens_per_second * temporal_patch_size / fps. In this case, 25 * 2 / 1 = 50. This means that each temporal patch will have a difference of 50 in the temporal position IDs.
906
+ input_ids: [V V V V V V V V V V V V T T T T T], here V is for vision.
907
+ vision temporal position_ids: [0, 0, 0, 0, 50, 50, 50, 50, 100, 100, 100, 100]
908
+ vision height position_ids: [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 1]
909
+ vision width position_ids: [0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1]
910
+ text temporal position_ids: [101, 102, 103, 104, 105]
911
+ text height position_ids: [101, 102, 103, 104, 105]
912
+ text width position_ids: [101, 102, 103, 104, 105]
913
+ Here we calculate the text start position_ids as the max vision position_ids plus 1.
914
+
915
+ Args:
916
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
917
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
918
+ it.
919
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
920
+ The temporal, height and width of feature shape of each image in LLM.
921
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
922
+ The temporal, height and width of feature shape of each video in LLM.
923
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
924
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
925
+
926
+ - 1 for tokens that are **not masked**,
927
+ - 0 for tokens that are **masked**.
928
+
929
+ Returns:
930
+ position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`)
931
+ mrope_position_deltas (`torch.Tensor` of shape `(batch_size)`)
932
+ """
933
+
934
+ spatial_merge_size = self.config.vision_config.spatial_merge_size
935
+ image_token_id = self.config.image_token_id
936
+ video_start_token_id = self.config.video_start_token_id
937
+ video_end_token_id = self.config.video_end_token_id
938
+
939
+ mrope_position_deltas = []
940
+ if input_ids is not None and (image_grid_thw is not None or video_grid_thw is not None):
941
+ total_input_ids = input_ids
942
+ if attention_mask is None:
943
+ attention_mask = torch.ones_like(total_input_ids)
944
+ position_ids = torch.ones(
945
+ 3,
946
+ input_ids.shape[0],
947
+ input_ids.shape[1],
948
+ dtype=input_ids.dtype,
949
+ device=input_ids.device,
950
+ )
951
+ image_index, video_index = 0, 0
952
+ video_group_index = 0
953
+ attention_mask = attention_mask.to(total_input_ids.device)
954
+ for i, input_ids in enumerate(total_input_ids):
955
+ input_ids = input_ids[attention_mask[i] == 1]
956
+ input_tokens = input_ids.tolist()
957
+
958
+ input_token_type = []
959
+ video_check_flg = False
960
+ for token in input_tokens:
961
+ if token == video_start_token_id:
962
+ video_check_flg = True
963
+ elif token == video_end_token_id:
964
+ video_check_flg = False
965
+
966
+ if token == image_token_id and not video_check_flg:
967
+ input_token_type.append("image")
968
+ elif token == image_token_id and video_check_flg:
969
+ input_token_type.append("video")
970
+ else:
971
+ input_token_type.append("text")
972
+
973
+ input_type_group = []
974
+ for key, group in itertools.groupby(enumerate(input_token_type), lambda x: x[1]):
975
+ group = list(group)
976
+ start_index = group[0][0]
977
+ end_index = group[-1][0] + 1
978
+ input_type_group.append((key, start_index, end_index))
979
+
980
+ llm_pos_ids_list = []
981
+ video_frame_num = 1
982
+ for modality_type, start_idx, end_idx in input_type_group:
983
+ st_idx = llm_pos_ids_list[-1].max() + 1 if len(llm_pos_ids_list) > 0 else 0
984
+
985
+ if modality_type == "image":
986
+ t, h, w = (
987
+ image_grid_thw[image_index][0],
988
+ image_grid_thw[image_index][1],
989
+ image_grid_thw[image_index][2],
990
+ )
991
+ llm_grid_t, llm_grid_h, llm_grid_w = (
992
+ t.item(),
993
+ h.item() // spatial_merge_size,
994
+ w.item() // spatial_merge_size,
995
+ )
996
+
997
+ t_index = torch.arange(llm_grid_t).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
998
+ h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(llm_grid_t, -1, llm_grid_w).flatten()
999
+ w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(llm_grid_t, llm_grid_h, -1).flatten()
1000
+ llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx)
1001
+
1002
+ image_index += 1
1003
+ video_frame_num = 1
1004
+
1005
+ elif modality_type == "video":
1006
+ t, h, w = (
1007
+ video_frame_num,
1008
+ video_grid_thw[video_index][1],
1009
+ video_grid_thw[video_index][2],
1010
+ )
1011
+
1012
+ llm_grid_t, llm_grid_h, llm_grid_w = (
1013
+ t,
1014
+ h.item() // spatial_merge_size,
1015
+ w.item() // spatial_merge_size,
1016
+ )
1017
+
1018
+ for t_idx in range(llm_grid_t):
1019
+ t_index = torch.tensor(t_idx).view(-1, 1).expand(-1, llm_grid_h * llm_grid_w).flatten()
1020
+
1021
+ h_index = torch.arange(llm_grid_h).view(1, -1, 1).expand(1, -1, llm_grid_w).flatten()
1022
+ w_index = torch.arange(llm_grid_w).view(1, 1, -1).expand(1, llm_grid_h, -1).flatten()
1023
+ llm_pos_ids_list.append(torch.stack([t_index, h_index, w_index]) + st_idx)
1024
+
1025
+ video_group_index += 1
1026
+
1027
+ if video_group_index >= video_grid_thw[video_index][0]:
1028
+ video_index += 1
1029
+ video_group_index = 0
1030
+
1031
+ video_frame_num += 1
1032
+
1033
+ else:
1034
+ text_len = end_idx - start_idx
1035
+ llm_pos_ids_list.append(torch.arange(text_len).view(1, -1).expand(3, -1) + st_idx)
1036
+
1037
+ video_frame_num = 1
1038
+
1039
+ llm_positions = torch.cat(llm_pos_ids_list, dim=1).reshape(3, -1)
1040
+ position_ids[..., i, attention_mask[i] == 1] = llm_positions.to(position_ids.device)
1041
+ mrope_position_deltas.append(llm_positions.max() + 1 - len(total_input_ids[i]))
1042
+ mrope_position_deltas = torch.tensor(mrope_position_deltas, device=input_ids.device).unsqueeze(1)
1043
+ return position_ids, mrope_position_deltas
1044
+ else:
1045
+ if attention_mask is not None:
1046
+ position_ids = attention_mask.long().cumsum(-1) - 1
1047
+ position_ids.masked_fill_(attention_mask == 0, 1)
1048
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1).to(attention_mask.device)
1049
+ max_position_ids = position_ids.max(0, keepdim=False)[0].max(-1, keepdim=True)[0]
1050
+ mrope_position_deltas = max_position_ids + 1 - attention_mask.shape[-1]
1051
+ else:
1052
+ position_ids = (
1053
+ torch.arange(input_ids.shape[1], device=input_ids.device)
1054
+ .view(1, 1, -1)
1055
+ .expand(3, input_ids.shape[0], -1)
1056
+ )
1057
+ mrope_position_deltas = torch.zeros(
1058
+ [input_ids.shape[0], 1],
1059
+ device=input_ids.device,
1060
+ dtype=input_ids.dtype,
1061
+ )
1062
+
1063
+ return position_ids, mrope_position_deltas
1064
+
1065
+ @can_return_tuple
1066
+ @auto_docstring
1067
+ def get_video_features(
1068
+ self,
1069
+ pixel_values_videos: torch.FloatTensor,
1070
+ video_grid_thw: torch.LongTensor | None = None,
1071
+ **kwargs: Unpack[TransformersKwargs],
1072
+ ) -> tuple | BaseModelOutputWithPooling:
1073
+ r"""
1074
+ pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
1075
+ The tensors corresponding to the input videos.
1076
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
1077
+ The temporal, height and width of feature shape of each video in LLM.
1078
+ """
1079
+ pixel_values_videos = pixel_values_videos.type(self.visual.dtype)
1080
+ # reshape video_grid_thw -> [b, 3] -> [1, h, w] * frames
1081
+ temp_frames_hw = []
1082
+ for t, h, w in video_grid_thw:
1083
+ repeated_row = torch.tensor([1, h.item(), w.item()]).unsqueeze(0).repeat(t, 1)
1084
+ temp_frames_hw.append(repeated_row)
1085
+ flattened_video_grid_thw = torch.cat(temp_frames_hw, dim=0)
1086
+ vision_outputs = self.visual(
1087
+ pixel_values_videos, grid_thw=flattened_video_grid_thw, return_dict=True, **kwargs
1088
+ )
1089
+ split_sizes = (video_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
1090
+ video_embeds = torch.split(vision_outputs.pooler_output, split_sizes)
1091
+ vision_outputs.pooler_output = video_embeds
1092
+
1093
+ return vision_outputs
1094
+
1095
+ @can_return_tuple
1096
+ @auto_docstring
1097
+ def get_image_features(
1098
+ self,
1099
+ pixel_values: torch.FloatTensor,
1100
+ image_grid_thw: torch.LongTensor | None = None,
1101
+ **kwargs: Unpack[TransformersKwargs],
1102
+ ) -> tuple | BaseModelOutputWithPooling:
1103
+ r"""
1104
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
1105
+ The tensors corresponding to the input images.
1106
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
1107
+ The temporal, height and width of feature shape of each image in LLM.
1108
+ """
1109
+ pixel_values = pixel_values.type(self.visual.dtype)
1110
+ vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs)
1111
+ split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
1112
+ image_embeds = torch.split(vision_outputs.pooler_output, split_sizes)
1113
+ vision_outputs.pooler_output = image_embeds
1114
+
1115
+ return vision_outputs
1116
+
1117
+ def get_placeholder_mask(
1118
+ self,
1119
+ input_ids: torch.LongTensor,
1120
+ inputs_embeds: torch.FloatTensor,
1121
+ image_features: torch.FloatTensor | None = None,
1122
+ video_features: torch.FloatTensor | None = None,
1123
+ ):
1124
+ """
1125
+ Obtains multimodal placeholder mask from `input_ids` or `inputs_embeds`, and checks that the placeholder token count is
1126
+ equal to the length of multimodal features. If the lengths are different, an error is raised.
1127
+ """
1128
+ if input_ids is None:
1129
+ special_image_mask = inputs_embeds == self.get_input_embeddings()(
1130
+ torch.tensor(self.config.image_token_id, dtype=torch.long, device=inputs_embeds.device)
1131
+ )
1132
+ special_image_mask = special_image_mask.all(-1)
1133
+ special_video_mask = inputs_embeds == self.get_input_embeddings()(
1134
+ torch.tensor(self.config.video_token_id, dtype=torch.long, device=inputs_embeds.device)
1135
+ )
1136
+ special_video_mask = special_video_mask.all(-1)
1137
+ else:
1138
+ # In GLM-4.1V and GLM-4.5V, the special_video_mask is the same as the special_image_mask
1139
+ special_image_mask = input_ids == self.config.image_token_id
1140
+ special_video_mask = input_ids == self.config.image_token_id
1141
+
1142
+ n_image_tokens = special_image_mask.sum()
1143
+ special_image_mask = special_image_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
1144
+ if image_features is not None:
1145
+ torch_compilable_check(
1146
+ inputs_embeds[special_image_mask].numel() == image_features.numel(),
1147
+ f"Image features and image tokens do not match, tokens: {n_image_tokens}, features: {image_features.shape[0]}",
1148
+ )
1149
+
1150
+ n_video_tokens = special_video_mask.sum()
1151
+ special_video_mask = special_video_mask.unsqueeze(-1).expand_as(inputs_embeds).to(inputs_embeds.device)
1152
+ if video_features is not None:
1153
+ torch_compilable_check(
1154
+ inputs_embeds[special_video_mask].numel() == video_features.numel(),
1155
+ f"Video features and video tokens do not match, tokens: {n_video_tokens}, features: {video_features.shape[0]}",
1156
+ )
1157
+ return special_image_mask, special_video_mask
1158
+
1159
+ @auto_docstring
1160
+ @can_return_tuple
1161
+ def forward(
1162
+ self,
1163
+ input_ids: torch.LongTensor | None = None,
1164
+ attention_mask: torch.Tensor | None = None,
1165
+ position_ids: torch.LongTensor | None = None,
1166
+ past_key_values: Cache | None = None,
1167
+ inputs_embeds: torch.FloatTensor | None = None,
1168
+ pixel_values: torch.Tensor | None = None,
1169
+ pixel_values_videos: torch.FloatTensor | None = None,
1170
+ image_grid_thw: torch.LongTensor | None = None,
1171
+ video_grid_thw: torch.LongTensor | None = None,
1172
+ rope_deltas: torch.LongTensor | None = None,
1173
+ cache_position: torch.LongTensor | None = None,
1174
+ **kwargs: Unpack[TransformersKwargs],
1175
+ ) -> tuple | GlmOcrModelOutputWithPast:
1176
+ r"""
1177
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
1178
+ The temporal, height and width of feature shape of each image in LLM.
1179
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
1180
+ The temporal, height and width of feature shape of each video in LLM.
1181
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
1182
+ The rope index difference between sequence length and multimodal rope.
1183
+ """
1184
+ if (input_ids is None) ^ (inputs_embeds is not None):
1185
+ raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
1186
+
1187
+ if inputs_embeds is None:
1188
+ inputs_embeds = self.get_input_embeddings()(input_ids)
1189
+
1190
+ if pixel_values is not None:
1191
+ image_embeds = self.get_image_features(pixel_values, image_grid_thw, return_dict=True).pooler_output
1192
+ image_embeds = torch.cat(image_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
1193
+ image_mask, _ = self.get_placeholder_mask(input_ids, inputs_embeds, image_features=image_embeds)
1194
+ inputs_embeds = inputs_embeds.masked_scatter(image_mask, image_embeds)
1195
+
1196
+ if pixel_values_videos is not None:
1197
+ video_embeds = self.get_video_features(pixel_values_videos, video_grid_thw, return_dict=True).pooler_output
1198
+ video_embeds = torch.cat(video_embeds, dim=0).to(inputs_embeds.device, inputs_embeds.dtype)
1199
+ _, video_mask = self.get_placeholder_mask(input_ids, inputs_embeds, video_features=video_embeds)
1200
+ inputs_embeds = inputs_embeds.masked_scatter(video_mask, video_embeds)
1201
+
1202
+ if position_ids is None:
1203
+ attention_mask_tensor = (
1204
+ attention_mask if not isinstance(attention_mask, dict) else attention_mask["full_attention"]
1205
+ )
1206
+ if attention_mask_tensor is not None and attention_mask_tensor.ndim == 4:
1207
+ attention_mask_tensor = torch.diagonal(attention_mask_tensor[:, 0], dim1=1, dim2=2)
1208
+ # Only apply conversion for floating point tensors (inverted masks)
1209
+ if attention_mask_tensor.dtype.is_floating_point:
1210
+ attention_mask_tensor = attention_mask_tensor / torch.finfo(attention_mask_tensor.dtype).min
1211
+ attention_mask_tensor = (1.0 - attention_mask_tensor).int()
1212
+
1213
+ # Calculate RoPE index once per generation in the pre-fill stage only.
1214
+ # When compiling, we can't check tensor values, so we check only the input length
1215
+ # It is safe to assume that `length!=1` means we're in pre-fill because compiled
1216
+ # models currently cannot do assisted decoding
1217
+ prefill_compiled_stage = is_torchdynamo_compiling() and (
1218
+ (input_ids is not None and input_ids.shape[1] != 1)
1219
+ or (inputs_embeds is not None and inputs_embeds.shape[1] != 1)
1220
+ )
1221
+ prefill_noncompiled_stage = not is_torchdynamo_compiling() and (
1222
+ (cache_position is not None and cache_position[0] == 0)
1223
+ or (past_key_values is None or past_key_values.get_seq_length() == 0)
1224
+ )
1225
+ if (prefill_compiled_stage or prefill_noncompiled_stage) or self.rope_deltas is None:
1226
+ position_ids, rope_deltas = self.get_rope_index(
1227
+ input_ids,
1228
+ image_grid_thw,
1229
+ video_grid_thw,
1230
+ attention_mask=attention_mask_tensor,
1231
+ )
1232
+ self.rope_deltas = rope_deltas
1233
+ # then use the prev pre-calculated rope-deltas to get the correct position ids
1234
+ else:
1235
+ batch_size, seq_length, _ = inputs_embeds.shape
1236
+ delta = (
1237
+ (cache_position[0] + self.rope_deltas).to(inputs_embeds.device)
1238
+ if cache_position is not None
1239
+ else 0
1240
+ )
1241
+ position_ids = torch.arange(seq_length, device=inputs_embeds.device)
1242
+ position_ids = position_ids.view(1, -1).expand(batch_size, -1)
1243
+ if cache_position is not None: # otherwise `deltas` is an int `0`
1244
+ delta = delta.repeat_interleave(batch_size // delta.shape[0], dim=0)
1245
+ position_ids = position_ids.add(delta)
1246
+ position_ids = position_ids.unsqueeze(0).expand(3, -1, -1)
1247
+
1248
+ outputs = self.language_model(
1249
+ input_ids=None,
1250
+ position_ids=position_ids,
1251
+ attention_mask=attention_mask,
1252
+ past_key_values=past_key_values,
1253
+ inputs_embeds=inputs_embeds,
1254
+ cache_position=cache_position,
1255
+ **kwargs,
1256
+ )
1257
+
1258
+ return GlmOcrModelOutputWithPast(
1259
+ last_hidden_state=outputs.last_hidden_state,
1260
+ past_key_values=outputs.past_key_values,
1261
+ hidden_states=outputs.hidden_states,
1262
+ attentions=outputs.attentions,
1263
+ rope_deltas=self.rope_deltas,
1264
+ )
1265
+
1266
+
1267
+ @dataclass
1268
+ @auto_docstring(
1269
+ custom_intro="""
1270
+ Base class for GlmOcr causal language model (or autoregressive) outputs.
1271
+ """
1272
+ )
1273
+ class GlmOcrCausalLMOutputWithPast(ModelOutput):
1274
+ r"""
1275
+ loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
1276
+ Language modeling loss (for next-token prediction).
1277
+ logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
1278
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
1279
+ past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
1280
+ It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
1281
+
1282
+ Contains pre-computed hidden-states (key and values in the self-attention blocks) that can be used (see
1283
+ `past_key_values` input) to speed up sequential decoding.
1284
+ rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
1285
+ The rope index difference between sequence length and multimodal rope.
1286
+ """
1287
+
1288
+ loss: torch.FloatTensor | None = None
1289
+ logits: torch.FloatTensor | None = None
1290
+ past_key_values: Cache | None = None
1291
+ hidden_states: tuple[torch.FloatTensor] | None = None
1292
+ attentions: tuple[torch.FloatTensor] | None = None
1293
+ rope_deltas: torch.LongTensor | None = None
1294
+
1295
+
1296
+ class GlmOcrForConditionalGeneration(GlmOcrPreTrainedModel, GenerationMixin):
1297
+ _checkpoint_conversion_mapping = {}
1298
+ _tied_weights_keys = {"lm_head.weight": "model.language_model.embed_tokens.weight"}
1299
+ # Reference: fix gemma3 grad acc #37208
1300
+ accepts_loss_kwargs = False
1301
+
1302
+ def __init__(self, config):
1303
+ super().__init__(config)
1304
+ self.model = GlmOcrModel(config)
1305
+ self.lm_head = nn.Linear(config.text_config.hidden_size, config.text_config.vocab_size, bias=False)
1306
+
1307
+ self.post_init()
1308
+
1309
+ def get_input_embeddings(self):
1310
+ return self.model.get_input_embeddings()
1311
+
1312
+ def set_input_embeddings(self, value):
1313
+ self.model.set_input_embeddings(value)
1314
+
1315
+ @auto_docstring
1316
+ def get_video_features(
1317
+ self,
1318
+ pixel_values_videos: torch.FloatTensor,
1319
+ video_grid_thw: torch.LongTensor | None = None,
1320
+ **kwargs: Unpack[TransformersKwargs],
1321
+ ) -> tuple | BaseModelOutputWithPooling:
1322
+ r"""
1323
+ pixel_values_videos (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
1324
+ The tensors corresponding to the input videos.
1325
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
1326
+ The temporal, height and width of feature shape of each video in LLM.
1327
+ """
1328
+ return self.model.get_video_features(
1329
+ pixel_values_videos=pixel_values_videos, video_grid_thw=video_grid_thw, **kwargs
1330
+ )
1331
+
1332
+ @auto_docstring
1333
+ def get_image_features(
1334
+ self,
1335
+ pixel_values: torch.FloatTensor,
1336
+ image_grid_thw: torch.LongTensor | None = None,
1337
+ **kwargs: Unpack[TransformersKwargs],
1338
+ ) -> tuple | BaseModelOutputWithPooling:
1339
+ r"""
1340
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
1341
+ The tensors corresponding to the input images.
1342
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
1343
+ The temporal, height and width of feature shape of each image in LLM.
1344
+ """
1345
+ return self.model.get_image_features(pixel_values=pixel_values, image_grid_thw=image_grid_thw, **kwargs)
1346
+
1347
+ @can_return_tuple
1348
+ @auto_docstring
1349
+ def forward(
1350
+ self,
1351
+ input_ids: torch.LongTensor | None = None,
1352
+ attention_mask: torch.Tensor | None = None,
1353
+ position_ids: torch.LongTensor | None = None,
1354
+ past_key_values: Cache | None = None,
1355
+ inputs_embeds: torch.FloatTensor | None = None,
1356
+ labels: torch.LongTensor | None = None,
1357
+ pixel_values: torch.Tensor | None = None,
1358
+ pixel_values_videos: torch.FloatTensor | None = None,
1359
+ image_grid_thw: torch.LongTensor | None = None,
1360
+ video_grid_thw: torch.LongTensor | None = None,
1361
+ cache_position: torch.LongTensor | None = None,
1362
+ logits_to_keep: int | torch.Tensor = 0,
1363
+ **kwargs: Unpack[TransformersKwargs],
1364
+ ) -> tuple | GlmOcrCausalLMOutputWithPast:
1365
+ r"""
1366
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
1367
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
1368
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
1369
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
1370
+ image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
1371
+ The temporal, height and width of feature shape of each image in LLM.
1372
+ video_grid_thw (`torch.LongTensor` of shape `(num_videos, 3)`, *optional*):
1373
+ The temporal, height and width of feature shape of each video in LLM.
1374
+
1375
+ Example:
1376
+
1377
+ ```python
1378
+ >>> from PIL import Image
1379
+ >>> import httpx
1380
+ >>> from io import BytesIO
1381
+ >>> from transformers import AutoProcessor, GlmOcrForConditionalGeneration
1382
+
1383
+ >>> model = GlmOcrForConditionalGeneration.from_pretrained("zai-org/GLM-4.1V-9B-Thinking")
1384
+ >>> processor = AutoProcessor.from_pretrained("zai-org/GLM-4.1V-9B-Thinking")
1385
+
1386
+ >>> messages = [
1387
+ {
1388
+ "role": "user",
1389
+ "content": [
1390
+ {"type": "image", "url": "https://www.ilankelman.org/stopsigns/australia.jpg"},
1391
+ {"type": "text", "text": "What is shown in this image?"},
1392
+ ],
1393
+ },
1394
+ ]
1395
+ >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
1396
+ >>> with httpx.stream("GET", url) as response:
1397
+ ... image = Image.open(BytesIO(response.read()))
1398
+
1399
+ >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
1400
+ >>> inputs = processor(text=[text], images=[image], return_tensors="pt")
1401
+
1402
+ >>> # Generate
1403
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
1404
+ >>> processor.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
1405
+ "The image shows a street scene with a red stop sign in the foreground. In the background, there is a large red gate with Chinese characters ..."
1406
+ ```"""
1407
+ outputs = self.model(
1408
+ input_ids=input_ids,
1409
+ pixel_values=pixel_values,
1410
+ pixel_values_videos=pixel_values_videos,
1411
+ image_grid_thw=image_grid_thw,
1412
+ video_grid_thw=video_grid_thw,
1413
+ position_ids=position_ids,
1414
+ attention_mask=attention_mask,
1415
+ past_key_values=past_key_values,
1416
+ inputs_embeds=inputs_embeds,
1417
+ cache_position=cache_position,
1418
+ **kwargs,
1419
+ )
1420
+
1421
+ hidden_states = outputs[0]
1422
+
1423
+ # Only compute necessary logits, and do not upcast them to float if we are not computing the loss
1424
+ slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
1425
+ logits = self.lm_head(hidden_states[:, slice_indices, :])
1426
+
1427
+ loss = None
1428
+ if labels is not None:
1429
+ loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.config.text_config.vocab_size)
1430
+
1431
+ return GlmOcrCausalLMOutputWithPast(
1432
+ loss=loss,
1433
+ logits=logits,
1434
+ past_key_values=outputs.past_key_values,
1435
+ hidden_states=outputs.hidden_states,
1436
+ attentions=outputs.attentions,
1437
+ rope_deltas=outputs.rope_deltas,
1438
+ )
1439
+
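A minimal sketch (not part of this diff) of the `logits_to_keep` slicing used in `forward` above, with invented tensor sizes; the default of `0` keeps the full sequence because `slice(-0, None)` is `slice(0, None)`:

```python
import torch

hidden_states = torch.randn(2, 10, 16)         # (batch, seq_len, hidden) -- invented sizes
lm_head = torch.nn.Linear(16, 32, bias=False)  # stand-in for self.lm_head

logits_to_keep = 1                              # during decoding, only the last position is needed
slice_indices = slice(-logits_to_keep, None) if isinstance(logits_to_keep, int) else logits_to_keep
logits = lm_head(hidden_states[:, slice_indices, :])
print(logits.shape)                             # torch.Size([2, 1, 32])
```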
1440
+ def prepare_inputs_for_generation(
1441
+ self,
1442
+ input_ids,
1443
+ past_key_values=None,
1444
+ attention_mask=None,
1445
+ inputs_embeds=None,
1446
+ cache_position=None,
1447
+ position_ids=None,
1448
+ use_cache=True,
1449
+ pixel_values=None,
1450
+ pixel_values_videos=None,
1451
+ image_grid_thw=None,
1452
+ video_grid_thw=None,
1453
+ is_first_iteration=False,
1454
+ **kwargs,
1455
+ ):
1456
+ # Overwritten -- in specific circumstances we don't want to forward image inputs to the model
1457
+
1458
+ model_inputs = super().prepare_inputs_for_generation(
1459
+ input_ids,
1460
+ past_key_values=past_key_values,
1461
+ attention_mask=attention_mask,
1462
+ inputs_embeds=inputs_embeds,
1463
+ cache_position=cache_position,
1464
+ position_ids=position_ids,
1465
+ pixel_values=pixel_values,
1466
+ pixel_values_videos=pixel_values_videos,
1467
+ image_grid_thw=image_grid_thw,
1468
+ video_grid_thw=video_grid_thw,
1469
+ use_cache=use_cache,
1470
+ is_first_iteration=is_first_iteration,
1471
+ **kwargs,
1472
+ )
1473
+
1474
+ # GLM-V position_ids are prepared with rope_deltas in forward
1475
+ model_inputs["position_ids"] = None
1476
+
1477
+ if not is_first_iteration and use_cache:
1478
+ model_inputs["pixel_values"] = None
1479
+ model_inputs["pixel_values_videos"] = None
1480
+
1481
+ return model_inputs
1482
+
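To illustrate the rule applied in `prepare_inputs_for_generation` above: once the prefill pass has merged the vision features into the KV cache, later decoding steps no longer need the pixel inputs. A toy restatement of that conditional (not the library API; names and values invented):

```python
def drop_vision_inputs(model_inputs: dict, is_first_iteration: bool, use_cache: bool) -> dict:
    # Same condition as in prepare_inputs_for_generation: only the prefill step keeps pixel inputs.
    if not is_first_iteration and use_cache:
        model_inputs = {**model_inputs, "pixel_values": None, "pixel_values_videos": None}
    return model_inputs

prefill = drop_vision_inputs({"pixel_values": "tensor"}, is_first_iteration=True, use_cache=True)
decode = drop_vision_inputs({"pixel_values": "tensor"}, is_first_iteration=False, use_cache=True)
assert prefill["pixel_values"] is not None and decode["pixel_values"] is None
```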
1483
+ def _get_image_nums_and_video_nums(
1484
+ self,
1485
+ input_ids: torch.LongTensor | None,
1486
+ inputs_embeds: torch.Tensor | None = None,
1487
+ ) -> tuple[torch.Tensor, torch.Tensor]:
1488
+ """
1489
+ Count the number of images and videos in each sample, so that the flattened visual tensors can be split back into per-sample chunks.
1490
+ The counts are recomputed from the inputs here rather than passed through the processor, so changes to the processor interface cannot affect them.
1491
+
1492
+ Args:
1493
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
1494
+ Indices of input sequence tokens in the vocabulary.
1495
+
1496
+ Returns:
1497
+ image_nums (`torch.LongTensor` of shape `(batch_size,)`): number of standalone images in each sample.
1498
+ video_nums (`torch.LongTensor` of shape `(batch_size,)`): number of videos in each sample.
1499
+ """
1500
+
1501
+ if inputs_embeds is not None:
1502
+ is_image = (
1503
+ inputs_embeds
1504
+ == self.get_input_embeddings()(
1505
+ torch.tensor(self.config.image_start_token_id, dtype=torch.long, device=inputs_embeds.device)
1506
+ )
1507
+ )[..., 0]
1508
+ is_video_start = (
1509
+ inputs_embeds
1510
+ == self.get_input_embeddings()(
1511
+ torch.tensor(self.config.video_start_token_id, dtype=torch.long, device=inputs_embeds.device)
1512
+ )
1513
+ )[..., 0]
1514
+ is_video_end = (
1515
+ inputs_embeds
1516
+ == self.get_input_embeddings()(
1517
+ torch.tensor(self.config.video_end_token_id, dtype=torch.long, device=inputs_embeds.device)
1518
+ )
1519
+ )[..., 0]
1520
+ else:
1521
+ is_image = input_ids == self.config.image_start_token_id
1522
+ is_video_start = input_ids == self.config.video_start_token_id
1523
+ is_video_end = input_ids == self.config.video_end_token_id
1524
+
1525
+ # Cumulative sum to track if we're inside a video span
1526
+ # We'll assume well-formed video tags (i.e. matching starts and ends)
1527
+ video_level = torch.cumsum(is_video_start.int() - is_video_end.int(), dim=1)
1528
+ inside_video = video_level > 0 # shape (batch_size, seq_length)
1529
+
1530
+ # Mask out image tokens that are inside video spans
1531
+ standalone_images = is_image & (~inside_video)
1532
+
1533
+ # Count per batch
1534
+ image_counts = standalone_images.sum(dim=1)
1535
+ video_counts = is_video_start.sum(dim=1)
1536
+
1537
+ return image_counts, video_counts
1538
+
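A quick, self-contained check (invented token ids, not the real special-token ids) of the cumulative-sum span tracking above: an image marker that falls inside a video span is excluded from the per-sample image count.

```python
import torch

IMG_START, VID_START, VID_END, TXT = 1, 2, 3, 0
input_ids = torch.tensor([[TXT, IMG_START, TXT, VID_START, IMG_START, VID_END, TXT]])

is_image = input_ids == IMG_START
is_video_start = input_ids == VID_START
is_video_end = input_ids == VID_END

# Running level > 0 means the position sits between a video start and its matching end.
video_level = torch.cumsum(is_video_start.int() - is_video_end.int(), dim=1)
inside_video = video_level > 0
standalone_images = is_image & ~inside_video

print(standalone_images.sum(dim=1))  # tensor([1]) -- the image marker inside the video span is not counted
print(is_video_start.sum(dim=1))     # tensor([1])
```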
1539
+ def _expand_inputs_for_generation(
1540
+ self,
1541
+ expand_size: int = 1,
1542
+ is_encoder_decoder: bool = False,
1543
+ input_ids: torch.LongTensor | None = None,
1544
+ **model_kwargs,
1545
+ ) -> tuple[torch.LongTensor, dict[str, Any]]:
1546
+ # Overwritten -- Support for expanding tensors without a batch size dimension
1547
+ # e.g., pixel_values, image_grid_thw, pixel_values_videos, video_grid_thw, second_per_grid_t
1548
+ # pixel_values.shape[0] is sum(seqlen_images for samples)
1549
+ # image_grid_thw.shape[0] is sum(num_images for samples)
1550
+
1551
+ if expand_size == 1:
1552
+ return input_ids, model_kwargs
1553
+
1554
+ visual_keys = ["pixel_values", "image_grid_thw", "pixel_values_videos", "video_grid_thw", "second_per_grid_ts"]
1555
+
1556
+ def _expand_dict_for_generation_visual(dict_to_expand):
1557
+ image_grid_thw = model_kwargs.get("image_grid_thw", None)
1558
+ video_grid_thw = model_kwargs.get("video_grid_thw", None)
1559
+ image_nums, video_nums = self._get_image_nums_and_video_nums(
1560
+ input_ids, inputs_embeds=model_kwargs.get("inputs_embeds", None)
1561
+ )
1562
+
1563
+ def _repeat_interleave_samples(x, lengths, repeat_times):
1564
+ samples = torch.split(x, lengths)
1565
+ repeat_args = [repeat_times] + [1] * (x.dim() - 1)
1566
+ result = torch.cat([sample.repeat(*repeat_args) for sample in samples], dim=0)
1567
+ return result
1568
+
1569
+ for key in dict_to_expand:
1570
+ if key == "pixel_values":
1571
+ # split images into samples
1572
+ samples = torch.split(image_grid_thw, list(image_nums))
1573
+ # compute the sequence length of images for each sample
1574
+ lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
1575
+ dict_to_expand[key] = _repeat_interleave_samples(
1576
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
1577
+ )
1578
+ elif key == "image_grid_thw":
1579
+ # get the num of images for each sample
1580
+ lengths = list(image_nums)
1581
+ dict_to_expand[key] = _repeat_interleave_samples(
1582
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
1583
+ )
1584
+ elif key == "pixel_values_videos":
1585
+ samples = torch.split(video_grid_thw, list(video_nums))
1586
+ lengths = [torch.prod(sample, dim=1).sum() for sample in samples]
1587
+ dict_to_expand[key] = _repeat_interleave_samples(
1588
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
1589
+ )
1590
+ elif key == "video_grid_thw":
1591
+ lengths = list(video_nums)
1592
+ dict_to_expand[key] = _repeat_interleave_samples(
1593
+ dict_to_expand[key], lengths=lengths, repeat_times=expand_size
1594
+ )
1595
+ elif key == "second_per_grid_ts":
1596
+ dict_to_expand[key] = _repeat_interleave_samples(
1597
+ dict_to_expand[key], lengths=list(video_nums), repeat_times=expand_size
1598
+ )
1599
+ return dict_to_expand
1600
+
1601
+ def _expand_dict_for_generation(dict_to_expand):
1602
+ for key in dict_to_expand:
1603
+ if (
1604
+ key != "cache_position"
1605
+ and dict_to_expand[key] is not None
1606
+ and isinstance(dict_to_expand[key], torch.Tensor)
1607
+ and key not in visual_keys
1608
+ ):
1609
+ dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
1610
+ return dict_to_expand
1611
+
1612
+ model_kwargs = _expand_dict_for_generation_visual(model_kwargs)
1613
+
1614
+ if input_ids is not None:
1615
+ input_ids = input_ids.repeat_interleave(expand_size, dim=0)
1616
+
1617
+ model_kwargs = _expand_dict_for_generation(model_kwargs)
1618
+
1619
+ if is_encoder_decoder:
1620
+ if model_kwargs.get("encoder_outputs") is None:
1621
+ raise ValueError("If `is_encoder_decoder` is True, make sure that `encoder_outputs` is defined.")
1622
+ model_kwargs["encoder_outputs"] = _expand_dict_for_generation(model_kwargs["encoder_outputs"])
1623
+
1624
+ return input_ids, model_kwargs
1625
+
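A small illustration (invented shapes, not real model tensors) of the per-sample repeat performed by `_repeat_interleave_samples` above: the batch-less visual tensors are split by per-sample lengths and each chunk is repeated `expand_size` times in place, preserving sample order.

```python
import torch

# Two samples: sample 0 owns rows 0-5, sample 1 owns rows 6-8 (analogous to prod(grid_thw).sum() lengths).
pixel_values = torch.arange(9).reshape(9, 1)
lengths = [6, 3]
expand_size = 2  # e.g. num_return_sequences=2

samples = torch.split(pixel_values, lengths)
repeated = torch.cat([s.repeat(expand_size, 1) for s in samples], dim=0)
print(repeated.shape)  # torch.Size([18, 1]) -- each sample's block duplicated in place
```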
1626
+
1627
+ __all__ = [
1628
+ "GlmOcrTextModel",
1629
+ "GlmOcrVisionModel",
1630
+ "GlmOcrModel",
1631
+ "GlmOcrPreTrainedModel",
1632
+ "GlmOcrForConditionalGeneration",
1633
+ ]