transformers-5.0.0rc3-py3-none-any.whl → transformers-5.1.0-py3-none-any.whl
This diff compares the contents of two publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions exactly as they appear in the public registry.
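For readers who want to reproduce a file-level summary like the listing below, here is a minimal sketch using only the Python standard library. It is not the tool that generated this diff: the wheel filenames and the `pip download` commands are assumptions based on the header above, and the counts it prints are unified-diff added/removed line totals, which may not match the registry's numbers exactly.

```python
# Minimal sketch: compare the .py files inside two locally downloaded wheels.
# Assumes the wheels were fetched beforehand, e.g.:
#   pip download transformers==5.0.0rc3 --no-deps
#   pip download transformers==5.1.0 --no-deps
import difflib
import zipfile


def read_py_files(wheel_path: str) -> dict[str, list[str]]:
    """Return {archive path: list of source lines} for every .py file in a wheel."""
    with zipfile.ZipFile(wheel_path) as wheel:
        return {
            name: wheel.read(name).decode("utf-8", "replace").splitlines()
            for name in wheel.namelist()
            if name.endswith(".py")
        }


old = read_py_files("transformers-5.0.0rc3-py3-none-any.whl")  # assumed filename
new = read_py_files("transformers-5.1.0-py3-none-any.whl")     # assumed filename

for name in sorted(set(old) | set(new)):
    diff = list(difflib.unified_diff(old.get(name, []), new.get(name, []), lineterm=""))
    # Count body lines only; skip the "---"/"+++" file headers emitted by unified_diff.
    added = sum(1 for line in diff if line.startswith("+") and not line.startswith("+++"))
    removed = sum(1 for line in diff if line.startswith("-") and not line.startswith("---"))
    if added or removed:
        print(f"{name} +{added} -{removed}")
```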
- transformers/__init__.py +4 -11
- transformers/activations.py +2 -2
- transformers/backbone_utils.py +326 -0
- transformers/cache_utils.py +11 -2
- transformers/cli/serve.py +11 -8
- transformers/configuration_utils.py +1 -69
- transformers/conversion_mapping.py +146 -26
- transformers/convert_slow_tokenizer.py +6 -4
- transformers/core_model_loading.py +207 -118
- transformers/dependency_versions_check.py +0 -1
- transformers/dependency_versions_table.py +7 -8
- transformers/file_utils.py +0 -2
- transformers/generation/candidate_generator.py +1 -2
- transformers/generation/continuous_batching/cache.py +40 -38
- transformers/generation/continuous_batching/cache_manager.py +3 -16
- transformers/generation/continuous_batching/continuous_api.py +94 -406
- transformers/generation/continuous_batching/input_ouputs.py +464 -0
- transformers/generation/continuous_batching/requests.py +54 -17
- transformers/generation/continuous_batching/scheduler.py +77 -95
- transformers/generation/logits_process.py +10 -5
- transformers/generation/stopping_criteria.py +1 -2
- transformers/generation/utils.py +75 -95
- transformers/image_processing_utils.py +0 -3
- transformers/image_processing_utils_fast.py +17 -18
- transformers/image_transforms.py +44 -13
- transformers/image_utils.py +0 -5
- transformers/initialization.py +57 -0
- transformers/integrations/__init__.py +10 -24
- transformers/integrations/accelerate.py +47 -11
- transformers/integrations/deepspeed.py +145 -3
- transformers/integrations/executorch.py +2 -6
- transformers/integrations/finegrained_fp8.py +142 -7
- transformers/integrations/flash_attention.py +2 -7
- transformers/integrations/hub_kernels.py +18 -7
- transformers/integrations/moe.py +226 -106
- transformers/integrations/mxfp4.py +47 -34
- transformers/integrations/peft.py +488 -176
- transformers/integrations/tensor_parallel.py +641 -581
- transformers/masking_utils.py +153 -9
- transformers/modeling_flash_attention_utils.py +1 -2
- transformers/modeling_utils.py +359 -358
- transformers/models/__init__.py +6 -0
- transformers/models/afmoe/configuration_afmoe.py +14 -4
- transformers/models/afmoe/modeling_afmoe.py +8 -8
- transformers/models/afmoe/modular_afmoe.py +7 -7
- transformers/models/aimv2/configuration_aimv2.py +2 -7
- transformers/models/aimv2/modeling_aimv2.py +26 -24
- transformers/models/aimv2/modular_aimv2.py +8 -12
- transformers/models/albert/configuration_albert.py +8 -1
- transformers/models/albert/modeling_albert.py +3 -3
- transformers/models/align/configuration_align.py +8 -5
- transformers/models/align/modeling_align.py +22 -24
- transformers/models/altclip/configuration_altclip.py +4 -6
- transformers/models/altclip/modeling_altclip.py +30 -26
- transformers/models/apertus/configuration_apertus.py +5 -7
- transformers/models/apertus/modeling_apertus.py +4 -4
- transformers/models/apertus/modular_apertus.py +8 -10
- transformers/models/arcee/configuration_arcee.py +5 -7
- transformers/models/arcee/modeling_arcee.py +4 -4
- transformers/models/aria/configuration_aria.py +11 -21
- transformers/models/aria/modeling_aria.py +39 -36
- transformers/models/aria/modular_aria.py +33 -39
- transformers/models/audio_spectrogram_transformer/modeling_audio_spectrogram_transformer.py +3 -3
- transformers/models/audioflamingo3/modeling_audioflamingo3.py +39 -30
- transformers/models/audioflamingo3/modular_audioflamingo3.py +41 -27
- transformers/models/auto/auto_factory.py +8 -6
- transformers/models/auto/configuration_auto.py +22 -0
- transformers/models/auto/image_processing_auto.py +17 -13
- transformers/models/auto/modeling_auto.py +15 -0
- transformers/models/auto/processing_auto.py +9 -18
- transformers/models/auto/tokenization_auto.py +17 -15
- transformers/models/autoformer/modeling_autoformer.py +2 -1
- transformers/models/aya_vision/configuration_aya_vision.py +4 -0
- transformers/models/aya_vision/modeling_aya_vision.py +29 -62
- transformers/models/aya_vision/modular_aya_vision.py +20 -45
- transformers/models/bamba/configuration_bamba.py +17 -7
- transformers/models/bamba/modeling_bamba.py +23 -55
- transformers/models/bamba/modular_bamba.py +19 -54
- transformers/models/bark/configuration_bark.py +2 -1
- transformers/models/bark/modeling_bark.py +24 -10
- transformers/models/bart/configuration_bart.py +9 -4
- transformers/models/bart/modeling_bart.py +9 -12
- transformers/models/beit/configuration_beit.py +2 -4
- transformers/models/beit/image_processing_beit_fast.py +3 -3
- transformers/models/beit/modeling_beit.py +14 -9
- transformers/models/bert/configuration_bert.py +12 -1
- transformers/models/bert/modeling_bert.py +6 -30
- transformers/models/bert_generation/configuration_bert_generation.py +17 -1
- transformers/models/bert_generation/modeling_bert_generation.py +6 -6
- transformers/models/big_bird/configuration_big_bird.py +12 -8
- transformers/models/big_bird/modeling_big_bird.py +0 -15
- transformers/models/bigbird_pegasus/configuration_bigbird_pegasus.py +9 -8
- transformers/models/bigbird_pegasus/modeling_bigbird_pegasus.py +9 -7
- transformers/models/biogpt/configuration_biogpt.py +8 -1
- transformers/models/biogpt/modeling_biogpt.py +4 -8
- transformers/models/biogpt/modular_biogpt.py +1 -5
- transformers/models/bit/configuration_bit.py +2 -4
- transformers/models/bit/modeling_bit.py +6 -5
- transformers/models/bitnet/configuration_bitnet.py +5 -7
- transformers/models/bitnet/modeling_bitnet.py +3 -4
- transformers/models/bitnet/modular_bitnet.py +3 -4
- transformers/models/blenderbot/configuration_blenderbot.py +8 -4
- transformers/models/blenderbot/modeling_blenderbot.py +4 -4
- transformers/models/blenderbot_small/configuration_blenderbot_small.py +8 -4
- transformers/models/blenderbot_small/modeling_blenderbot_small.py +4 -4
- transformers/models/blip/configuration_blip.py +9 -9
- transformers/models/blip/modeling_blip.py +55 -37
- transformers/models/blip_2/configuration_blip_2.py +2 -1
- transformers/models/blip_2/modeling_blip_2.py +81 -56
- transformers/models/bloom/configuration_bloom.py +5 -1
- transformers/models/bloom/modeling_bloom.py +2 -1
- transformers/models/blt/configuration_blt.py +23 -12
- transformers/models/blt/modeling_blt.py +20 -14
- transformers/models/blt/modular_blt.py +70 -10
- transformers/models/bridgetower/configuration_bridgetower.py +7 -1
- transformers/models/bridgetower/image_processing_bridgetower_fast.py +6 -6
- transformers/models/bridgetower/modeling_bridgetower.py +29 -15
- transformers/models/bros/configuration_bros.py +24 -17
- transformers/models/camembert/configuration_camembert.py +8 -1
- transformers/models/camembert/modeling_camembert.py +6 -6
- transformers/models/canine/configuration_canine.py +4 -1
- transformers/models/chameleon/configuration_chameleon.py +5 -7
- transformers/models/chameleon/image_processing_chameleon_fast.py +5 -5
- transformers/models/chameleon/modeling_chameleon.py +82 -36
- transformers/models/chinese_clip/configuration_chinese_clip.py +10 -7
- transformers/models/chinese_clip/modeling_chinese_clip.py +28 -29
- transformers/models/clap/configuration_clap.py +4 -8
- transformers/models/clap/modeling_clap.py +21 -22
- transformers/models/clip/configuration_clip.py +4 -1
- transformers/models/clip/image_processing_clip_fast.py +9 -0
- transformers/models/clip/modeling_clip.py +25 -22
- transformers/models/clipseg/configuration_clipseg.py +4 -1
- transformers/models/clipseg/modeling_clipseg.py +27 -25
- transformers/models/clipseg/processing_clipseg.py +11 -3
- transformers/models/clvp/configuration_clvp.py +14 -2
- transformers/models/clvp/modeling_clvp.py +19 -30
- transformers/models/codegen/configuration_codegen.py +4 -3
- transformers/models/codegen/modeling_codegen.py +2 -1
- transformers/models/cohere/configuration_cohere.py +5 -7
- transformers/models/cohere/modeling_cohere.py +4 -4
- transformers/models/cohere/modular_cohere.py +3 -3
- transformers/models/cohere2/configuration_cohere2.py +6 -8
- transformers/models/cohere2/modeling_cohere2.py +4 -4
- transformers/models/cohere2/modular_cohere2.py +9 -11
- transformers/models/cohere2_vision/configuration_cohere2_vision.py +5 -1
- transformers/models/cohere2_vision/image_processing_cohere2_vision_fast.py +3 -3
- transformers/models/cohere2_vision/modeling_cohere2_vision.py +24 -25
- transformers/models/cohere2_vision/modular_cohere2_vision.py +20 -20
- transformers/models/colqwen2/modeling_colqwen2.py +7 -6
- transformers/models/colqwen2/modular_colqwen2.py +7 -6
- transformers/models/conditional_detr/configuration_conditional_detr.py +19 -46
- transformers/models/conditional_detr/image_processing_conditional_detr.py +3 -4
- transformers/models/conditional_detr/image_processing_conditional_detr_fast.py +28 -14
- transformers/models/conditional_detr/modeling_conditional_detr.py +794 -942
- transformers/models/conditional_detr/modular_conditional_detr.py +901 -3
- transformers/models/convbert/configuration_convbert.py +11 -7
- transformers/models/convnext/configuration_convnext.py +2 -4
- transformers/models/convnext/image_processing_convnext_fast.py +2 -2
- transformers/models/convnext/modeling_convnext.py +7 -6
- transformers/models/convnextv2/configuration_convnextv2.py +2 -4
- transformers/models/convnextv2/modeling_convnextv2.py +7 -6
- transformers/models/cpmant/configuration_cpmant.py +4 -0
- transformers/models/csm/configuration_csm.py +9 -15
- transformers/models/csm/modeling_csm.py +3 -3
- transformers/models/ctrl/configuration_ctrl.py +16 -0
- transformers/models/ctrl/modeling_ctrl.py +13 -25
- transformers/models/cwm/configuration_cwm.py +5 -7
- transformers/models/cwm/modeling_cwm.py +4 -4
- transformers/models/d_fine/configuration_d_fine.py +10 -56
- transformers/models/d_fine/modeling_d_fine.py +728 -868
- transformers/models/d_fine/modular_d_fine.py +335 -412
- transformers/models/dab_detr/configuration_dab_detr.py +22 -48
- transformers/models/dab_detr/modeling_dab_detr.py +11 -7
- transformers/models/dac/modeling_dac.py +1 -1
- transformers/models/data2vec/configuration_data2vec_audio.py +4 -1
- transformers/models/data2vec/configuration_data2vec_text.py +11 -2
- transformers/models/data2vec/modeling_data2vec_audio.py +3 -3
- transformers/models/data2vec/modeling_data2vec_text.py +6 -6
- transformers/models/data2vec/modeling_data2vec_vision.py +4 -2
- transformers/models/dbrx/configuration_dbrx.py +11 -3
- transformers/models/dbrx/modeling_dbrx.py +6 -6
- transformers/models/dbrx/modular_dbrx.py +6 -6
- transformers/models/deberta/configuration_deberta.py +6 -0
- transformers/models/deberta_v2/configuration_deberta_v2.py +6 -0
- transformers/models/decision_transformer/configuration_decision_transformer.py +3 -1
- transformers/models/decision_transformer/modeling_decision_transformer.py +3 -3
- transformers/models/deepseek_v2/configuration_deepseek_v2.py +7 -10
- transformers/models/deepseek_v2/modeling_deepseek_v2.py +7 -8
- transformers/models/deepseek_v2/modular_deepseek_v2.py +8 -10
- transformers/models/deepseek_v3/configuration_deepseek_v3.py +7 -10
- transformers/models/deepseek_v3/modeling_deepseek_v3.py +7 -7
- transformers/models/deepseek_v3/modular_deepseek_v3.py +6 -5
- transformers/models/deepseek_vl/configuration_deepseek_vl.py +4 -0
- transformers/models/deepseek_vl/image_processing_deepseek_vl.py +2 -2
- transformers/models/deepseek_vl/image_processing_deepseek_vl_fast.py +5 -5
- transformers/models/deepseek_vl/modeling_deepseek_vl.py +17 -12
- transformers/models/deepseek_vl/modular_deepseek_vl.py +4 -0
- transformers/models/deepseek_vl_hybrid/configuration_deepseek_vl_hybrid.py +4 -0
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid.py +2 -2
- transformers/models/deepseek_vl_hybrid/image_processing_deepseek_vl_hybrid_fast.py +6 -6
- transformers/models/deepseek_vl_hybrid/modeling_deepseek_vl_hybrid.py +68 -24
- transformers/models/deepseek_vl_hybrid/modular_deepseek_vl_hybrid.py +70 -19
- transformers/models/deformable_detr/configuration_deformable_detr.py +22 -45
- transformers/models/deformable_detr/image_processing_deformable_detr_fast.py +25 -11
- transformers/models/deformable_detr/modeling_deformable_detr.py +410 -607
- transformers/models/deformable_detr/modular_deformable_detr.py +1385 -3
- transformers/models/deit/modeling_deit.py +11 -7
- transformers/models/depth_anything/configuration_depth_anything.py +12 -42
- transformers/models/depth_anything/modeling_depth_anything.py +5 -3
- transformers/models/depth_pro/image_processing_depth_pro_fast.py +2 -2
- transformers/models/depth_pro/modeling_depth_pro.py +8 -4
- transformers/models/detr/configuration_detr.py +18 -49
- transformers/models/detr/image_processing_detr_fast.py +11 -11
- transformers/models/detr/modeling_detr.py +695 -734
- transformers/models/dia/configuration_dia.py +4 -7
- transformers/models/dia/generation_dia.py +8 -17
- transformers/models/dia/modeling_dia.py +7 -7
- transformers/models/dia/modular_dia.py +4 -4
- transformers/models/diffllama/configuration_diffllama.py +5 -7
- transformers/models/diffllama/modeling_diffllama.py +3 -8
- transformers/models/diffllama/modular_diffllama.py +2 -7
- transformers/models/dinat/configuration_dinat.py +2 -4
- transformers/models/dinat/modeling_dinat.py +7 -6
- transformers/models/dinov2/configuration_dinov2.py +2 -4
- transformers/models/dinov2/modeling_dinov2.py +9 -8
- transformers/models/dinov2_with_registers/configuration_dinov2_with_registers.py +2 -4
- transformers/models/dinov2_with_registers/modeling_dinov2_with_registers.py +9 -8
- transformers/models/dinov2_with_registers/modular_dinov2_with_registers.py +6 -7
- transformers/models/dinov3_convnext/configuration_dinov3_convnext.py +2 -4
- transformers/models/dinov3_convnext/modeling_dinov3_convnext.py +2 -3
- transformers/models/dinov3_vit/configuration_dinov3_vit.py +2 -4
- transformers/models/dinov3_vit/image_processing_dinov3_vit_fast.py +2 -2
- transformers/models/dinov3_vit/modeling_dinov3_vit.py +5 -6
- transformers/models/dinov3_vit/modular_dinov3_vit.py +5 -6
- transformers/models/distilbert/configuration_distilbert.py +8 -1
- transformers/models/distilbert/modeling_distilbert.py +3 -3
- transformers/models/doge/configuration_doge.py +17 -7
- transformers/models/doge/modeling_doge.py +4 -4
- transformers/models/doge/modular_doge.py +20 -10
- transformers/models/donut/image_processing_donut_fast.py +4 -4
- transformers/models/dots1/configuration_dots1.py +16 -7
- transformers/models/dots1/modeling_dots1.py +4 -4
- transformers/models/dpr/configuration_dpr.py +19 -1
- transformers/models/dpt/configuration_dpt.py +23 -65
- transformers/models/dpt/image_processing_dpt_fast.py +5 -5
- transformers/models/dpt/modeling_dpt.py +19 -15
- transformers/models/dpt/modular_dpt.py +4 -4
- transformers/models/edgetam/configuration_edgetam.py +1 -1
- transformers/models/edgetam/modeling_edgetam.py +53 -53
- transformers/models/edgetam/modular_edgetam.py +5 -7
- transformers/models/edgetam_video/modeling_edgetam_video.py +55 -56
- transformers/models/edgetam_video/modular_edgetam_video.py +9 -9
- transformers/models/efficientloftr/image_processing_efficientloftr_fast.py +4 -3
- transformers/models/efficientloftr/modeling_efficientloftr.py +19 -9
- transformers/models/efficientnet/image_processing_efficientnet_fast.py +2 -2
- transformers/models/electra/configuration_electra.py +13 -2
- transformers/models/electra/modeling_electra.py +6 -6
- transformers/models/emu3/configuration_emu3.py +12 -10
- transformers/models/emu3/modeling_emu3.py +84 -47
- transformers/models/emu3/modular_emu3.py +77 -39
- transformers/models/encoder_decoder/configuration_encoder_decoder.py +12 -1
- transformers/models/encoder_decoder/modeling_encoder_decoder.py +20 -24
- transformers/models/eomt/configuration_eomt.py +12 -13
- transformers/models/eomt/image_processing_eomt_fast.py +3 -3
- transformers/models/eomt/modeling_eomt.py +3 -3
- transformers/models/eomt/modular_eomt.py +17 -17
- transformers/models/eomt_dinov3/__init__.py +28 -0
- transformers/models/eomt_dinov3/configuration_eomt_dinov3.py +204 -0
- transformers/models/eomt_dinov3/modeling_eomt_dinov3.py +1376 -0
- transformers/models/eomt_dinov3/modular_eomt_dinov3.py +454 -0
- transformers/models/ernie/configuration_ernie.py +24 -2
- transformers/models/ernie/modeling_ernie.py +6 -30
- transformers/models/ernie4_5/configuration_ernie4_5.py +5 -7
- transformers/models/ernie4_5/modeling_ernie4_5.py +4 -4
- transformers/models/ernie4_5_moe/configuration_ernie4_5_moe.py +7 -10
- transformers/models/ernie4_5_moe/modeling_ernie4_5_moe.py +4 -4
- transformers/models/ernie4_5_vl_moe/configuration_ernie4_5_vl_moe.py +17 -6
- transformers/models/ernie4_5_vl_moe/modeling_ernie4_5_vl_moe.py +229 -188
- transformers/models/ernie4_5_vl_moe/modular_ernie4_5_vl_moe.py +79 -55
- transformers/models/esm/configuration_esm.py +9 -11
- transformers/models/esm/modeling_esm.py +3 -3
- transformers/models/esm/modeling_esmfold.py +1 -6
- transformers/models/esm/openfold_utils/protein.py +2 -3
- transformers/models/evolla/configuration_evolla.py +21 -8
- transformers/models/evolla/modeling_evolla.py +11 -7
- transformers/models/evolla/modular_evolla.py +5 -1
- transformers/models/exaone4/configuration_exaone4.py +8 -5
- transformers/models/exaone4/modeling_exaone4.py +4 -4
- transformers/models/exaone4/modular_exaone4.py +11 -8
- transformers/models/exaone_moe/__init__.py +27 -0
- transformers/models/exaone_moe/configuration_exaone_moe.py +235 -0
- transformers/models/exaone_moe/modeling_exaone_moe.py +665 -0
- transformers/models/exaone_moe/modular_exaone_moe.py +373 -0
- transformers/models/falcon/configuration_falcon.py +9 -1
- transformers/models/falcon/modeling_falcon.py +3 -8
- transformers/models/falcon_h1/configuration_falcon_h1.py +17 -8
- transformers/models/falcon_h1/modeling_falcon_h1.py +22 -54
- transformers/models/falcon_h1/modular_falcon_h1.py +21 -52
- transformers/models/falcon_mamba/configuration_falcon_mamba.py +5 -1
- transformers/models/falcon_mamba/modeling_falcon_mamba.py +18 -26
- transformers/models/falcon_mamba/modular_falcon_mamba.py +4 -0
- transformers/models/fast_vlm/configuration_fast_vlm.py +10 -1
- transformers/models/fast_vlm/modeling_fast_vlm.py +37 -64
- transformers/models/fast_vlm/modular_fast_vlm.py +146 -35
- transformers/models/fastspeech2_conformer/configuration_fastspeech2_conformer.py +0 -1
- transformers/models/flaubert/configuration_flaubert.py +10 -4
- transformers/models/flaubert/modeling_flaubert.py +1 -1
- transformers/models/flava/configuration_flava.py +4 -3
- transformers/models/flava/image_processing_flava_fast.py +4 -4
- transformers/models/flava/modeling_flava.py +36 -28
- transformers/models/flex_olmo/configuration_flex_olmo.py +11 -14
- transformers/models/flex_olmo/modeling_flex_olmo.py +4 -4
- transformers/models/flex_olmo/modular_flex_olmo.py +11 -14
- transformers/models/florence2/configuration_florence2.py +4 -0
- transformers/models/florence2/modeling_florence2.py +57 -32
- transformers/models/florence2/modular_florence2.py +48 -26
- transformers/models/fnet/configuration_fnet.py +6 -1
- transformers/models/focalnet/configuration_focalnet.py +2 -4
- transformers/models/focalnet/modeling_focalnet.py +10 -7
- transformers/models/fsmt/configuration_fsmt.py +12 -16
- transformers/models/funnel/configuration_funnel.py +8 -0
- transformers/models/fuyu/configuration_fuyu.py +5 -8
- transformers/models/fuyu/image_processing_fuyu_fast.py +5 -4
- transformers/models/fuyu/modeling_fuyu.py +24 -23
- transformers/models/gemma/configuration_gemma.py +5 -7
- transformers/models/gemma/modeling_gemma.py +4 -4
- transformers/models/gemma/modular_gemma.py +5 -7
- transformers/models/gemma2/configuration_gemma2.py +5 -7
- transformers/models/gemma2/modeling_gemma2.py +4 -4
- transformers/models/gemma2/modular_gemma2.py +8 -10
- transformers/models/gemma3/configuration_gemma3.py +28 -22
- transformers/models/gemma3/image_processing_gemma3_fast.py +2 -2
- transformers/models/gemma3/modeling_gemma3.py +37 -33
- transformers/models/gemma3/modular_gemma3.py +46 -42
- transformers/models/gemma3n/configuration_gemma3n.py +35 -22
- transformers/models/gemma3n/modeling_gemma3n.py +86 -58
- transformers/models/gemma3n/modular_gemma3n.py +112 -75
- transformers/models/git/configuration_git.py +5 -7
- transformers/models/git/modeling_git.py +31 -41
- transformers/models/glm/configuration_glm.py +7 -9
- transformers/models/glm/modeling_glm.py +4 -4
- transformers/models/glm4/configuration_glm4.py +7 -9
- transformers/models/glm4/modeling_glm4.py +4 -4
- transformers/models/glm46v/configuration_glm46v.py +4 -0
- transformers/models/glm46v/image_processing_glm46v.py +5 -2
- transformers/models/glm46v/image_processing_glm46v_fast.py +2 -2
- transformers/models/glm46v/modeling_glm46v.py +91 -46
- transformers/models/glm46v/modular_glm46v.py +4 -0
- transformers/models/glm4_moe/configuration_glm4_moe.py +17 -7
- transformers/models/glm4_moe/modeling_glm4_moe.py +4 -4
- transformers/models/glm4_moe/modular_glm4_moe.py +17 -7
- transformers/models/glm4_moe_lite/configuration_glm4_moe_lite.py +8 -10
- transformers/models/glm4_moe_lite/modeling_glm4_moe_lite.py +7 -7
- transformers/models/glm4_moe_lite/modular_glm4_moe_lite.py +8 -10
- transformers/models/glm4v/configuration_glm4v.py +12 -8
- transformers/models/glm4v/image_processing_glm4v.py +5 -2
- transformers/models/glm4v/image_processing_glm4v_fast.py +2 -2
- transformers/models/glm4v/modeling_glm4v.py +120 -63
- transformers/models/glm4v/modular_glm4v.py +82 -50
- transformers/models/glm4v_moe/configuration_glm4v_moe.py +18 -6
- transformers/models/glm4v_moe/modeling_glm4v_moe.py +115 -63
- transformers/models/glm4v_moe/modular_glm4v_moe.py +23 -12
- transformers/models/glm_image/configuration_glm_image.py +26 -20
- transformers/models/glm_image/image_processing_glm_image.py +1 -1
- transformers/models/glm_image/image_processing_glm_image_fast.py +5 -7
- transformers/models/glm_image/modeling_glm_image.py +337 -236
- transformers/models/glm_image/modular_glm_image.py +415 -255
- transformers/models/glm_image/processing_glm_image.py +65 -17
- transformers/{pipelines/deprecated → models/glm_ocr}/__init__.py +15 -2
- transformers/models/glm_ocr/configuration_glm_ocr.py +312 -0
- transformers/models/glm_ocr/modeling_glm_ocr.py +1633 -0
- transformers/models/glm_ocr/modular_glm_ocr.py +428 -0
- transformers/models/glmasr/modeling_glmasr.py +34 -28
- transformers/models/glmasr/modular_glmasr.py +23 -11
- transformers/models/glpn/image_processing_glpn_fast.py +3 -3
- transformers/models/glpn/modeling_glpn.py +4 -2
- transformers/models/got_ocr2/configuration_got_ocr2.py +6 -6
- transformers/models/got_ocr2/image_processing_got_ocr2_fast.py +3 -3
- transformers/models/got_ocr2/modeling_got_ocr2.py +31 -37
- transformers/models/got_ocr2/modular_got_ocr2.py +30 -19
- transformers/models/gpt2/configuration_gpt2.py +13 -1
- transformers/models/gpt2/modeling_gpt2.py +5 -5
- transformers/models/gpt_bigcode/configuration_gpt_bigcode.py +7 -1
- transformers/models/gpt_bigcode/modeling_gpt_bigcode.py +5 -4
- transformers/models/gpt_neo/configuration_gpt_neo.py +9 -1
- transformers/models/gpt_neo/modeling_gpt_neo.py +3 -7
- transformers/models/gpt_neox/configuration_gpt_neox.py +8 -3
- transformers/models/gpt_neox/modeling_gpt_neox.py +4 -4
- transformers/models/gpt_neox/modular_gpt_neox.py +4 -4
- transformers/models/gpt_neox_japanese/configuration_gpt_neox_japanese.py +9 -1
- transformers/models/gpt_neox_japanese/modeling_gpt_neox_japanese.py +2 -2
- transformers/models/gpt_oss/configuration_gpt_oss.py +10 -6
- transformers/models/gpt_oss/modeling_gpt_oss.py +46 -79
- transformers/models/gpt_oss/modular_gpt_oss.py +45 -78
- transformers/models/gptj/configuration_gptj.py +4 -4
- transformers/models/gptj/modeling_gptj.py +3 -7
- transformers/models/granite/configuration_granite.py +5 -7
- transformers/models/granite/modeling_granite.py +4 -4
- transformers/models/granite_speech/modeling_granite_speech.py +63 -37
- transformers/models/granitemoe/configuration_granitemoe.py +5 -7
- transformers/models/granitemoe/modeling_granitemoe.py +4 -4
- transformers/models/granitemoehybrid/configuration_granitemoehybrid.py +17 -7
- transformers/models/granitemoehybrid/modeling_granitemoehybrid.py +22 -54
- transformers/models/granitemoehybrid/modular_granitemoehybrid.py +39 -45
- transformers/models/granitemoeshared/configuration_granitemoeshared.py +6 -7
- transformers/models/granitemoeshared/modeling_granitemoeshared.py +4 -4
- transformers/models/grounding_dino/configuration_grounding_dino.py +10 -45
- transformers/models/grounding_dino/image_processing_grounding_dino_fast.py +11 -11
- transformers/models/grounding_dino/modeling_grounding_dino.py +68 -86
- transformers/models/groupvit/configuration_groupvit.py +4 -1
- transformers/models/groupvit/modeling_groupvit.py +29 -22
- transformers/models/helium/configuration_helium.py +5 -7
- transformers/models/helium/modeling_helium.py +4 -4
- transformers/models/hgnet_v2/configuration_hgnet_v2.py +2 -4
- transformers/models/hgnet_v2/modeling_hgnet_v2.py +6 -5
- transformers/models/hgnet_v2/modular_hgnet_v2.py +7 -8
- transformers/models/hiera/configuration_hiera.py +2 -4
- transformers/models/hiera/modeling_hiera.py +11 -8
- transformers/models/hubert/configuration_hubert.py +4 -1
- transformers/models/hubert/modeling_hubert.py +7 -4
- transformers/models/hunyuan_v1_dense/configuration_hunyuan_v1_dense.py +5 -7
- transformers/models/hunyuan_v1_dense/modeling_hunyuan_v1_dense.py +28 -4
- transformers/models/hunyuan_v1_dense/modular_hunyuan_v1_dense.py +28 -6
- transformers/models/hunyuan_v1_moe/configuration_hunyuan_v1_moe.py +6 -8
- transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +22 -9
- transformers/models/hunyuan_v1_moe/modular_hunyuan_v1_moe.py +22 -8
- transformers/models/ibert/configuration_ibert.py +4 -1
- transformers/models/idefics/configuration_idefics.py +5 -7
- transformers/models/idefics/modeling_idefics.py +3 -4
- transformers/models/idefics/vision.py +5 -4
- transformers/models/idefics2/configuration_idefics2.py +1 -2
- transformers/models/idefics2/image_processing_idefics2_fast.py +1 -0
- transformers/models/idefics2/modeling_idefics2.py +72 -50
- transformers/models/idefics3/configuration_idefics3.py +1 -3
- transformers/models/idefics3/image_processing_idefics3_fast.py +29 -3
- transformers/models/idefics3/modeling_idefics3.py +63 -40
- transformers/models/ijepa/modeling_ijepa.py +3 -3
- transformers/models/imagegpt/configuration_imagegpt.py +9 -1
- transformers/models/imagegpt/image_processing_imagegpt_fast.py +2 -2
- transformers/models/imagegpt/modeling_imagegpt.py +8 -4
- transformers/models/informer/modeling_informer.py +3 -3
- transformers/models/instructblip/configuration_instructblip.py +2 -1
- transformers/models/instructblip/modeling_instructblip.py +65 -39
- transformers/models/instructblipvideo/configuration_instructblipvideo.py +2 -1
- transformers/models/instructblipvideo/modeling_instructblipvideo.py +60 -57
- transformers/models/instructblipvideo/modular_instructblipvideo.py +43 -32
- transformers/models/instructblipvideo/video_processing_instructblipvideo.py +2 -2
- transformers/models/internvl/configuration_internvl.py +5 -0
- transformers/models/internvl/modeling_internvl.py +35 -55
- transformers/models/internvl/modular_internvl.py +26 -38
- transformers/models/internvl/video_processing_internvl.py +2 -2
- transformers/models/jais2/configuration_jais2.py +5 -7
- transformers/models/jais2/modeling_jais2.py +4 -4
- transformers/models/jamba/configuration_jamba.py +5 -7
- transformers/models/jamba/modeling_jamba.py +4 -4
- transformers/models/jamba/modular_jamba.py +3 -3
- transformers/models/janus/image_processing_janus.py +2 -2
- transformers/models/janus/image_processing_janus_fast.py +8 -8
- transformers/models/janus/modeling_janus.py +63 -146
- transformers/models/janus/modular_janus.py +62 -20
- transformers/models/jetmoe/configuration_jetmoe.py +6 -4
- transformers/models/jetmoe/modeling_jetmoe.py +3 -3
- transformers/models/jetmoe/modular_jetmoe.py +3 -3
- transformers/models/kosmos2/configuration_kosmos2.py +10 -8
- transformers/models/kosmos2/modeling_kosmos2.py +56 -34
- transformers/models/kosmos2_5/configuration_kosmos2_5.py +8 -8
- transformers/models/kosmos2_5/modeling_kosmos2_5.py +54 -63
- transformers/models/kyutai_speech_to_text/configuration_kyutai_speech_to_text.py +8 -3
- transformers/models/kyutai_speech_to_text/modeling_kyutai_speech_to_text.py +44 -40
- transformers/models/kyutai_speech_to_text/modular_kyutai_speech_to_text.py +1 -1
- transformers/models/lasr/configuration_lasr.py +2 -4
- transformers/models/lasr/modeling_lasr.py +3 -3
- transformers/models/lasr/modular_lasr.py +3 -3
- transformers/models/layoutlm/configuration_layoutlm.py +14 -1
- transformers/models/layoutlm/modeling_layoutlm.py +3 -3
- transformers/models/layoutlmv2/configuration_layoutlmv2.py +14 -16
- transformers/models/layoutlmv2/image_processing_layoutlmv2_fast.py +2 -2
- transformers/models/layoutlmv3/configuration_layoutlmv3.py +16 -18
- transformers/models/layoutlmv3/image_processing_layoutlmv3_fast.py +2 -2
- transformers/models/layoutxlm/configuration_layoutxlm.py +14 -16
- transformers/models/led/configuration_led.py +7 -8
- transformers/models/levit/image_processing_levit_fast.py +4 -4
- transformers/models/lfm2/configuration_lfm2.py +5 -7
- transformers/models/lfm2/modeling_lfm2.py +4 -4
- transformers/models/lfm2/modular_lfm2.py +3 -3
- transformers/models/lfm2_moe/configuration_lfm2_moe.py +5 -7
- transformers/models/lfm2_moe/modeling_lfm2_moe.py +4 -4
- transformers/models/lfm2_vl/configuration_lfm2_vl.py +4 -0
- transformers/models/lfm2_vl/image_processing_lfm2_vl_fast.py +9 -15
- transformers/models/lfm2_vl/modeling_lfm2_vl.py +42 -28
- transformers/models/lfm2_vl/modular_lfm2_vl.py +42 -27
- transformers/models/lightglue/image_processing_lightglue_fast.py +4 -3
- transformers/models/lightglue/modeling_lightglue.py +3 -3
- transformers/models/lightglue/modular_lightglue.py +3 -3
- transformers/models/lighton_ocr/modeling_lighton_ocr.py +31 -28
- transformers/models/lighton_ocr/modular_lighton_ocr.py +19 -18
- transformers/models/lilt/configuration_lilt.py +6 -1
- transformers/models/llama/configuration_llama.py +5 -7
- transformers/models/llama/modeling_llama.py +4 -4
- transformers/models/llama4/configuration_llama4.py +67 -47
- transformers/models/llama4/image_processing_llama4_fast.py +3 -3
- transformers/models/llama4/modeling_llama4.py +46 -44
- transformers/models/llava/configuration_llava.py +10 -0
- transformers/models/llava/image_processing_llava_fast.py +3 -3
- transformers/models/llava/modeling_llava.py +38 -65
- transformers/models/llava_next/configuration_llava_next.py +2 -1
- transformers/models/llava_next/image_processing_llava_next_fast.py +6 -6
- transformers/models/llava_next/modeling_llava_next.py +61 -60
- transformers/models/llava_next_video/configuration_llava_next_video.py +10 -6
- transformers/models/llava_next_video/modeling_llava_next_video.py +115 -100
- transformers/models/llava_next_video/modular_llava_next_video.py +110 -101
- transformers/models/llava_onevision/configuration_llava_onevision.py +10 -6
- transformers/models/llava_onevision/image_processing_llava_onevision_fast.py +8 -7
- transformers/models/llava_onevision/modeling_llava_onevision.py +111 -105
- transformers/models/llava_onevision/modular_llava_onevision.py +106 -101
- transformers/models/longcat_flash/configuration_longcat_flash.py +7 -10
- transformers/models/longcat_flash/modeling_longcat_flash.py +7 -7
- transformers/models/longcat_flash/modular_longcat_flash.py +6 -5
- transformers/models/longformer/configuration_longformer.py +4 -1
- transformers/models/longt5/configuration_longt5.py +9 -6
- transformers/models/longt5/modeling_longt5.py +2 -1
- transformers/models/luke/configuration_luke.py +8 -1
- transformers/models/lw_detr/configuration_lw_detr.py +19 -31
- transformers/models/lw_detr/modeling_lw_detr.py +43 -44
- transformers/models/lw_detr/modular_lw_detr.py +36 -38
- transformers/models/lxmert/configuration_lxmert.py +16 -0
- transformers/models/m2m_100/configuration_m2m_100.py +7 -8
- transformers/models/m2m_100/modeling_m2m_100.py +3 -3
- transformers/models/mamba/configuration_mamba.py +5 -2
- transformers/models/mamba/modeling_mamba.py +18 -26
- transformers/models/mamba2/configuration_mamba2.py +5 -7
- transformers/models/mamba2/modeling_mamba2.py +22 -33
- transformers/models/marian/configuration_marian.py +10 -4
- transformers/models/marian/modeling_marian.py +4 -4
- transformers/models/markuplm/configuration_markuplm.py +4 -6
- transformers/models/markuplm/modeling_markuplm.py +3 -3
- transformers/models/mask2former/configuration_mask2former.py +12 -47
- transformers/models/mask2former/image_processing_mask2former_fast.py +8 -8
- transformers/models/mask2former/modeling_mask2former.py +18 -12
- transformers/models/maskformer/configuration_maskformer.py +14 -45
- transformers/models/maskformer/configuration_maskformer_swin.py +2 -4
- transformers/models/maskformer/image_processing_maskformer_fast.py +8 -8
- transformers/models/maskformer/modeling_maskformer.py +15 -9
- transformers/models/maskformer/modeling_maskformer_swin.py +2 -3
- transformers/models/mbart/configuration_mbart.py +9 -4
- transformers/models/mbart/modeling_mbart.py +9 -6
- transformers/models/megatron_bert/configuration_megatron_bert.py +13 -2
- transformers/models/megatron_bert/modeling_megatron_bert.py +0 -15
- transformers/models/metaclip_2/configuration_metaclip_2.py +4 -1
- transformers/models/metaclip_2/modeling_metaclip_2.py +49 -42
- transformers/models/metaclip_2/modular_metaclip_2.py +41 -25
- transformers/models/mgp_str/modeling_mgp_str.py +4 -2
- transformers/models/mimi/configuration_mimi.py +4 -0
- transformers/models/mimi/modeling_mimi.py +40 -36
- transformers/models/minimax/configuration_minimax.py +8 -11
- transformers/models/minimax/modeling_minimax.py +5 -5
- transformers/models/minimax/modular_minimax.py +9 -12
- transformers/models/minimax_m2/configuration_minimax_m2.py +8 -31
- transformers/models/minimax_m2/modeling_minimax_m2.py +4 -4
- transformers/models/minimax_m2/modular_minimax_m2.py +8 -31
- transformers/models/ministral/configuration_ministral.py +5 -7
- transformers/models/ministral/modeling_ministral.py +4 -4
- transformers/models/ministral/modular_ministral.py +5 -8
- transformers/models/ministral3/configuration_ministral3.py +4 -4
- transformers/models/ministral3/modeling_ministral3.py +4 -4
- transformers/models/ministral3/modular_ministral3.py +3 -3
- transformers/models/mistral/configuration_mistral.py +5 -7
- transformers/models/mistral/modeling_mistral.py +4 -4
- transformers/models/mistral/modular_mistral.py +3 -3
- transformers/models/mistral3/configuration_mistral3.py +4 -0
- transformers/models/mistral3/modeling_mistral3.py +36 -40
- transformers/models/mistral3/modular_mistral3.py +31 -32
- transformers/models/mixtral/configuration_mixtral.py +8 -11
- transformers/models/mixtral/modeling_mixtral.py +4 -4
- transformers/models/mlcd/modeling_mlcd.py +7 -5
- transformers/models/mlcd/modular_mlcd.py +7 -5
- transformers/models/mllama/configuration_mllama.py +5 -7
- transformers/models/mllama/image_processing_mllama_fast.py +6 -5
- transformers/models/mllama/modeling_mllama.py +19 -19
- transformers/models/mm_grounding_dino/configuration_mm_grounding_dino.py +10 -45
- transformers/models/mm_grounding_dino/modeling_mm_grounding_dino.py +66 -84
- transformers/models/mm_grounding_dino/modular_mm_grounding_dino.py +10 -45
- transformers/models/mobilebert/configuration_mobilebert.py +4 -1
- transformers/models/mobilebert/modeling_mobilebert.py +3 -3
- transformers/models/mobilenet_v2/image_processing_mobilenet_v2_fast.py +4 -4
- transformers/models/mobilenet_v2/modeling_mobilenet_v2.py +4 -2
- transformers/models/mobilevit/image_processing_mobilevit_fast.py +4 -4
- transformers/models/mobilevit/modeling_mobilevit.py +4 -2
- transformers/models/mobilevitv2/modeling_mobilevitv2.py +4 -2
- transformers/models/modernbert/configuration_modernbert.py +46 -21
- transformers/models/modernbert/modeling_modernbert.py +146 -899
- transformers/models/modernbert/modular_modernbert.py +185 -908
- transformers/models/modernbert_decoder/configuration_modernbert_decoder.py +21 -13
- transformers/models/modernbert_decoder/modeling_modernbert_decoder.py +9 -17
- transformers/models/modernbert_decoder/modular_modernbert_decoder.py +24 -23
- transformers/models/moonshine/configuration_moonshine.py +12 -7
- transformers/models/moonshine/modeling_moonshine.py +7 -7
- transformers/models/moonshine/modular_moonshine.py +19 -13
- transformers/models/moshi/configuration_moshi.py +28 -2
- transformers/models/moshi/modeling_moshi.py +4 -9
- transformers/models/mpnet/configuration_mpnet.py +6 -1
- transformers/models/mpt/configuration_mpt.py +16 -0
- transformers/models/mra/configuration_mra.py +8 -1
- transformers/models/mt5/configuration_mt5.py +9 -5
- transformers/models/mt5/modeling_mt5.py +5 -8
- transformers/models/musicgen/configuration_musicgen.py +12 -7
- transformers/models/musicgen/modeling_musicgen.py +6 -5
- transformers/models/musicgen_melody/configuration_musicgen_melody.py +15 -7
- transformers/models/musicgen_melody/modeling_musicgen_melody.py +7 -17
- transformers/models/mvp/configuration_mvp.py +8 -4
- transformers/models/mvp/modeling_mvp.py +6 -4
- transformers/models/nanochat/configuration_nanochat.py +5 -7
- transformers/models/nanochat/modeling_nanochat.py +4 -4
- transformers/models/nanochat/modular_nanochat.py +4 -4
- transformers/models/nemotron/configuration_nemotron.py +5 -7
- transformers/models/nemotron/modeling_nemotron.py +4 -14
- transformers/models/nllb/tokenization_nllb.py +7 -5
- transformers/models/nllb_moe/configuration_nllb_moe.py +7 -9
- transformers/models/nllb_moe/modeling_nllb_moe.py +3 -3
- transformers/models/nougat/image_processing_nougat_fast.py +8 -8
- transformers/models/nystromformer/configuration_nystromformer.py +8 -1
- transformers/models/olmo/configuration_olmo.py +5 -7
- transformers/models/olmo/modeling_olmo.py +4 -4
- transformers/models/olmo/modular_olmo.py +3 -3
- transformers/models/olmo2/configuration_olmo2.py +9 -11
- transformers/models/olmo2/modeling_olmo2.py +4 -4
- transformers/models/olmo2/modular_olmo2.py +7 -7
- transformers/models/olmo3/configuration_olmo3.py +10 -11
- transformers/models/olmo3/modeling_olmo3.py +4 -4
- transformers/models/olmo3/modular_olmo3.py +13 -14
- transformers/models/olmoe/configuration_olmoe.py +5 -7
- transformers/models/olmoe/modeling_olmoe.py +4 -4
- transformers/models/olmoe/modular_olmoe.py +3 -3
- transformers/models/omdet_turbo/configuration_omdet_turbo.py +14 -49
- transformers/models/omdet_turbo/modeling_omdet_turbo.py +22 -18
- transformers/models/oneformer/configuration_oneformer.py +9 -46
- transformers/models/oneformer/image_processing_oneformer_fast.py +8 -8
- transformers/models/oneformer/modeling_oneformer.py +14 -9
- transformers/models/openai/configuration_openai.py +16 -0
- transformers/models/opt/configuration_opt.py +6 -6
- transformers/models/opt/modeling_opt.py +5 -5
- transformers/models/ovis2/configuration_ovis2.py +4 -0
- transformers/models/ovis2/image_processing_ovis2_fast.py +3 -3
- transformers/models/ovis2/modeling_ovis2.py +58 -99
- transformers/models/ovis2/modular_ovis2.py +52 -13
- transformers/models/owlv2/configuration_owlv2.py +4 -1
- transformers/models/owlv2/image_processing_owlv2_fast.py +5 -5
- transformers/models/owlv2/modeling_owlv2.py +40 -27
- transformers/models/owlv2/modular_owlv2.py +5 -5
- transformers/models/owlvit/configuration_owlvit.py +4 -1
- transformers/models/owlvit/modeling_owlvit.py +40 -27
- transformers/models/paddleocr_vl/configuration_paddleocr_vl.py +9 -10
- transformers/models/paddleocr_vl/modeling_paddleocr_vl.py +88 -87
- transformers/models/paddleocr_vl/modular_paddleocr_vl.py +82 -53
- transformers/models/paligemma/configuration_paligemma.py +4 -0
- transformers/models/paligemma/modeling_paligemma.py +30 -26
- transformers/models/parakeet/configuration_parakeet.py +2 -4
- transformers/models/parakeet/modeling_parakeet.py +3 -3
- transformers/models/parakeet/modular_parakeet.py +3 -3
- transformers/models/patchtsmixer/modeling_patchtsmixer.py +3 -3
- transformers/models/patchtst/modeling_patchtst.py +3 -3
- transformers/models/pe_audio/modeling_pe_audio.py +4 -4
- transformers/models/pe_audio/modular_pe_audio.py +1 -1
- transformers/models/pe_audio_video/modeling_pe_audio_video.py +4 -4
- transformers/models/pe_audio_video/modular_pe_audio_video.py +4 -4
- transformers/models/pe_video/modeling_pe_video.py +36 -24
- transformers/models/pe_video/modular_pe_video.py +36 -23
- transformers/models/pegasus/configuration_pegasus.py +8 -5
- transformers/models/pegasus/modeling_pegasus.py +4 -4
- transformers/models/pegasus_x/configuration_pegasus_x.py +5 -3
- transformers/models/pegasus_x/modeling_pegasus_x.py +3 -3
- transformers/models/perceiver/image_processing_perceiver_fast.py +2 -2
- transformers/models/perceiver/modeling_perceiver.py +17 -9
- transformers/models/perception_lm/modeling_perception_lm.py +26 -27
- transformers/models/perception_lm/modular_perception_lm.py +27 -25
- transformers/models/persimmon/configuration_persimmon.py +5 -7
- transformers/models/persimmon/modeling_persimmon.py +5 -5
- transformers/models/phi/configuration_phi.py +8 -6
- transformers/models/phi/modeling_phi.py +4 -4
- transformers/models/phi/modular_phi.py +3 -3
- transformers/models/phi3/configuration_phi3.py +9 -11
- transformers/models/phi3/modeling_phi3.py +4 -4
- transformers/models/phi3/modular_phi3.py +3 -3
- transformers/models/phi4_multimodal/configuration_phi4_multimodal.py +11 -13
- transformers/models/phi4_multimodal/image_processing_phi4_multimodal_fast.py +4 -4
- transformers/models/phi4_multimodal/modeling_phi4_multimodal.py +46 -61
- transformers/models/phi4_multimodal/modular_phi4_multimodal.py +44 -30
- transformers/models/phimoe/configuration_phimoe.py +5 -7
- transformers/models/phimoe/modeling_phimoe.py +15 -39
- transformers/models/phimoe/modular_phimoe.py +12 -7
- transformers/models/pix2struct/configuration_pix2struct.py +12 -9
- transformers/models/pix2struct/image_processing_pix2struct_fast.py +5 -5
- transformers/models/pix2struct/modeling_pix2struct.py +14 -7
- transformers/models/pixio/configuration_pixio.py +2 -4
- transformers/models/pixio/modeling_pixio.py +9 -8
- transformers/models/pixio/modular_pixio.py +4 -2
- transformers/models/pixtral/image_processing_pixtral_fast.py +5 -5
- transformers/models/pixtral/modeling_pixtral.py +9 -12
- transformers/models/plbart/configuration_plbart.py +8 -5
- transformers/models/plbart/modeling_plbart.py +9 -7
- transformers/models/plbart/modular_plbart.py +1 -1
- transformers/models/poolformer/image_processing_poolformer_fast.py +7 -7
- transformers/models/pop2piano/configuration_pop2piano.py +7 -6
- transformers/models/pop2piano/modeling_pop2piano.py +2 -1
- transformers/models/pp_doclayout_v3/__init__.py +30 -0
- transformers/models/pp_doclayout_v3/configuration_pp_doclayout_v3.py +277 -0
- transformers/models/pp_doclayout_v3/image_processing_pp_doclayout_v3_fast.py +305 -0
- transformers/models/pp_doclayout_v3/modeling_pp_doclayout_v3.py +2083 -0
- transformers/models/pp_doclayout_v3/modular_pp_doclayout_v3.py +1549 -0
- transformers/models/prompt_depth_anything/configuration_prompt_depth_anything.py +12 -46
- transformers/models/prompt_depth_anything/image_processing_prompt_depth_anything_fast.py +6 -6
- transformers/models/prompt_depth_anything/modeling_prompt_depth_anything.py +8 -6
- transformers/models/prompt_depth_anything/modular_prompt_depth_anything.py +12 -10
- transformers/models/prophetnet/configuration_prophetnet.py +11 -10
- transformers/models/prophetnet/modeling_prophetnet.py +12 -23
- transformers/models/pvt/image_processing_pvt.py +7 -7
- transformers/models/pvt/image_processing_pvt_fast.py +1 -1
- transformers/models/pvt_v2/configuration_pvt_v2.py +2 -4
- transformers/models/pvt_v2/modeling_pvt_v2.py +6 -5
- transformers/models/qwen2/configuration_qwen2.py +14 -4
- transformers/models/qwen2/modeling_qwen2.py +4 -4
- transformers/models/qwen2/modular_qwen2.py +3 -3
- transformers/models/qwen2/tokenization_qwen2.py +0 -4
- transformers/models/qwen2_5_omni/configuration_qwen2_5_omni.py +17 -5
- transformers/models/qwen2_5_omni/modeling_qwen2_5_omni.py +108 -88
- transformers/models/qwen2_5_omni/modular_qwen2_5_omni.py +115 -87
- transformers/models/qwen2_5_vl/configuration_qwen2_5_vl.py +7 -10
- transformers/models/qwen2_5_vl/modeling_qwen2_5_vl.py +98 -53
- transformers/models/qwen2_5_vl/modular_qwen2_5_vl.py +18 -6
- transformers/models/qwen2_audio/modeling_qwen2_audio.py +12 -12
- transformers/models/qwen2_moe/configuration_qwen2_moe.py +14 -4
- transformers/models/qwen2_moe/modeling_qwen2_moe.py +4 -4
- transformers/models/qwen2_moe/modular_qwen2_moe.py +3 -3
- transformers/models/qwen2_vl/configuration_qwen2_vl.py +7 -10
- transformers/models/qwen2_vl/image_processing_qwen2_vl_fast.py +4 -6
- transformers/models/qwen2_vl/modeling_qwen2_vl.py +97 -53
- transformers/models/qwen2_vl/video_processing_qwen2_vl.py +4 -6
- transformers/models/qwen3/configuration_qwen3.py +15 -5
- transformers/models/qwen3/modeling_qwen3.py +4 -4
- transformers/models/qwen3/modular_qwen3.py +3 -3
- transformers/models/qwen3_moe/configuration_qwen3_moe.py +20 -7
- transformers/models/qwen3_moe/modeling_qwen3_moe.py +4 -4
- transformers/models/qwen3_next/configuration_qwen3_next.py +16 -4
- transformers/models/qwen3_next/modeling_qwen3_next.py +5 -5
- transformers/models/qwen3_next/modular_qwen3_next.py +4 -4
- transformers/models/qwen3_omni_moe/configuration_qwen3_omni_moe.py +55 -19
- transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py +161 -98
- transformers/models/qwen3_omni_moe/modular_qwen3_omni_moe.py +107 -34
- transformers/models/qwen3_vl/configuration_qwen3_vl.py +7 -6
- transformers/models/qwen3_vl/modeling_qwen3_vl.py +115 -49
- transformers/models/qwen3_vl/modular_qwen3_vl.py +88 -37
- transformers/models/qwen3_vl_moe/configuration_qwen3_vl_moe.py +7 -6
- transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py +173 -99
- transformers/models/qwen3_vl_moe/modular_qwen3_vl_moe.py +23 -7
- transformers/models/rag/configuration_rag.py +6 -6
- transformers/models/rag/modeling_rag.py +3 -3
- transformers/models/rag/retrieval_rag.py +1 -1
- transformers/models/recurrent_gemma/configuration_recurrent_gemma.py +8 -6
- transformers/models/recurrent_gemma/modeling_recurrent_gemma.py +4 -5
- transformers/models/reformer/configuration_reformer.py +7 -7
- transformers/models/rembert/configuration_rembert.py +8 -1
- transformers/models/rembert/modeling_rembert.py +0 -22
- transformers/models/resnet/configuration_resnet.py +2 -4
- transformers/models/resnet/modeling_resnet.py +6 -5
- transformers/models/roberta/configuration_roberta.py +11 -2
- transformers/models/roberta/modeling_roberta.py +6 -6
- transformers/models/roberta_prelayernorm/configuration_roberta_prelayernorm.py +11 -2
- transformers/models/roberta_prelayernorm/modeling_roberta_prelayernorm.py +6 -6
- transformers/models/roc_bert/configuration_roc_bert.py +8 -1
- transformers/models/roc_bert/modeling_roc_bert.py +6 -41
- transformers/models/roformer/configuration_roformer.py +13 -2
- transformers/models/roformer/modeling_roformer.py +0 -14
- transformers/models/rt_detr/configuration_rt_detr.py +8 -49
- transformers/models/rt_detr/configuration_rt_detr_resnet.py +2 -4
- transformers/models/rt_detr/image_processing_rt_detr_fast.py +24 -11
- transformers/models/rt_detr/modeling_rt_detr.py +578 -737
- transformers/models/rt_detr/modeling_rt_detr_resnet.py +2 -3
- transformers/models/rt_detr/modular_rt_detr.py +1508 -6
- transformers/models/rt_detr_v2/configuration_rt_detr_v2.py +12 -57
- transformers/models/rt_detr_v2/modeling_rt_detr_v2.py +318 -453
- transformers/models/rt_detr_v2/modular_rt_detr_v2.py +25 -66
- transformers/models/rwkv/configuration_rwkv.py +2 -3
- transformers/models/rwkv/modeling_rwkv.py +0 -23
- transformers/models/sam/configuration_sam.py +2 -0
- transformers/models/sam/image_processing_sam_fast.py +4 -4
- transformers/models/sam/modeling_sam.py +13 -8
- transformers/models/sam/processing_sam.py +3 -3
- transformers/models/sam2/configuration_sam2.py +1 -1
- transformers/models/sam2/modeling_sam2.py +56 -52
- transformers/models/sam2/modular_sam2.py +47 -55
- transformers/models/sam2_video/modeling_sam2_video.py +50 -51
- transformers/models/sam2_video/modular_sam2_video.py +12 -10
- transformers/models/sam3/modeling_sam3.py +43 -47
- transformers/models/sam3/processing_sam3.py +8 -4
- transformers/models/sam3_tracker/configuration_sam3_tracker.py +1 -2
- transformers/models/sam3_tracker/modeling_sam3_tracker.py +50 -49
- transformers/models/sam3_tracker/modular_sam3_tracker.py +0 -1
- transformers/models/sam3_tracker/processing_sam3_tracker.py +0 -1
- transformers/models/sam3_tracker_video/modeling_sam3_tracker_video.py +50 -49
- transformers/models/sam3_tracker_video/modular_sam3_tracker_video.py +10 -22
- transformers/models/sam3_video/modeling_sam3_video.py +27 -14
- transformers/models/sam_hq/configuration_sam_hq.py +2 -0
- transformers/models/sam_hq/modeling_sam_hq.py +13 -9
- transformers/models/sam_hq/modular_sam_hq.py +6 -6
- transformers/models/sam_hq/processing_sam_hq.py +7 -6
- transformers/models/seamless_m4t/configuration_seamless_m4t.py +8 -9
- transformers/models/seamless_m4t_v2/configuration_seamless_m4t_v2.py +8 -9
- transformers/models/seed_oss/configuration_seed_oss.py +7 -9
- transformers/models/seed_oss/modeling_seed_oss.py +4 -4
- transformers/models/seed_oss/modular_seed_oss.py +3 -3
- transformers/models/segformer/image_processing_segformer_fast.py +4 -4
- transformers/models/segformer/modeling_segformer.py +4 -2
- transformers/models/segformer/modular_segformer.py +3 -3
- transformers/models/seggpt/modeling_seggpt.py +20 -8
- transformers/models/sew/configuration_sew.py +4 -1
- transformers/models/sew/modeling_sew.py +9 -5
- transformers/models/sew/modular_sew.py +2 -1
- transformers/models/sew_d/configuration_sew_d.py +4 -1
- transformers/models/sew_d/modeling_sew_d.py +4 -1
- transformers/models/shieldgemma2/modeling_shieldgemma2.py +4 -4
- transformers/models/siglip/configuration_siglip.py +4 -1
- transformers/models/siglip/modeling_siglip.py +27 -71
- transformers/models/siglip2/__init__.py +1 -0
- transformers/models/siglip2/configuration_siglip2.py +4 -2
- transformers/models/siglip2/image_processing_siglip2_fast.py +2 -2
- transformers/models/siglip2/modeling_siglip2.py +37 -78
- transformers/models/siglip2/modular_siglip2.py +74 -25
- transformers/models/siglip2/tokenization_siglip2.py +95 -0
- transformers/models/smollm3/configuration_smollm3.py +6 -6
- transformers/models/smollm3/modeling_smollm3.py +4 -4
- transformers/models/smollm3/modular_smollm3.py +9 -9
- transformers/models/smolvlm/configuration_smolvlm.py +1 -3
- transformers/models/smolvlm/image_processing_smolvlm_fast.py +29 -3
- transformers/models/smolvlm/modeling_smolvlm.py +75 -46
- transformers/models/smolvlm/modular_smolvlm.py +36 -23
- transformers/models/smolvlm/video_processing_smolvlm.py +9 -9
- transformers/models/solar_open/__init__.py +27 -0
- transformers/models/solar_open/configuration_solar_open.py +184 -0
- transformers/models/solar_open/modeling_solar_open.py +642 -0
- transformers/models/solar_open/modular_solar_open.py +224 -0
- transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py +6 -4
- transformers/models/speech_to_text/configuration_speech_to_text.py +9 -8
- transformers/models/speech_to_text/modeling_speech_to_text.py +3 -3
- transformers/models/speecht5/configuration_speecht5.py +7 -8
- transformers/models/splinter/configuration_splinter.py +6 -6
- transformers/models/splinter/modeling_splinter.py +8 -3
- transformers/models/squeezebert/configuration_squeezebert.py +14 -1
- transformers/models/stablelm/configuration_stablelm.py +8 -6
- transformers/models/stablelm/modeling_stablelm.py +5 -5
- transformers/models/starcoder2/configuration_starcoder2.py +11 -5
- transformers/models/starcoder2/modeling_starcoder2.py +5 -5
- transformers/models/starcoder2/modular_starcoder2.py +4 -4
- transformers/models/superglue/configuration_superglue.py +4 -0
- transformers/models/superglue/image_processing_superglue_fast.py +4 -3
- transformers/models/superglue/modeling_superglue.py +9 -4
- transformers/models/superpoint/image_processing_superpoint_fast.py +3 -4
- transformers/models/superpoint/modeling_superpoint.py +4 -2
- transformers/models/swin/configuration_swin.py +2 -4
- transformers/models/swin/modeling_swin.py +11 -8
- transformers/models/swin2sr/image_processing_swin2sr_fast.py +2 -2
- transformers/models/swin2sr/modeling_swin2sr.py +4 -2
- transformers/models/swinv2/configuration_swinv2.py +2 -4
- transformers/models/swinv2/modeling_swinv2.py +10 -7
- transformers/models/switch_transformers/configuration_switch_transformers.py +11 -6
- transformers/models/switch_transformers/modeling_switch_transformers.py +3 -3
- transformers/models/switch_transformers/modular_switch_transformers.py +3 -3
- transformers/models/t5/configuration_t5.py +9 -8
- transformers/models/t5/modeling_t5.py +5 -8
- transformers/models/t5gemma/configuration_t5gemma.py +10 -25
- transformers/models/t5gemma/modeling_t5gemma.py +9 -9
- transformers/models/t5gemma/modular_t5gemma.py +11 -24
- transformers/models/t5gemma2/configuration_t5gemma2.py +35 -48
- transformers/models/t5gemma2/modeling_t5gemma2.py +143 -100
- transformers/models/t5gemma2/modular_t5gemma2.py +152 -136
- transformers/models/table_transformer/configuration_table_transformer.py +18 -49
- transformers/models/table_transformer/modeling_table_transformer.py +27 -53
- transformers/models/tapas/configuration_tapas.py +12 -1
- transformers/models/tapas/modeling_tapas.py +1 -1
- transformers/models/tapas/tokenization_tapas.py +1 -0
- transformers/models/textnet/configuration_textnet.py +4 -6
- transformers/models/textnet/image_processing_textnet_fast.py +3 -3
- transformers/models/textnet/modeling_textnet.py +15 -14
- transformers/models/time_series_transformer/modeling_time_series_transformer.py +3 -3
- transformers/models/timesfm/modeling_timesfm.py +5 -6
- transformers/models/timesfm/modular_timesfm.py +5 -6
- transformers/models/timm_backbone/configuration_timm_backbone.py +33 -7
- transformers/models/timm_backbone/modeling_timm_backbone.py +21 -24
- transformers/models/timm_wrapper/modeling_timm_wrapper.py +9 -4
- transformers/models/trocr/configuration_trocr.py +11 -7
- transformers/models/trocr/modeling_trocr.py +4 -2
- transformers/models/tvp/configuration_tvp.py +10 -35
- transformers/models/tvp/image_processing_tvp_fast.py +6 -5
- transformers/models/tvp/modeling_tvp.py +1 -1
- transformers/models/udop/configuration_udop.py +16 -7
- transformers/models/udop/modeling_udop.py +10 -6
- transformers/models/umt5/configuration_umt5.py +8 -6
- transformers/models/umt5/modeling_umt5.py +7 -3
- transformers/models/unispeech/configuration_unispeech.py +4 -1
- transformers/models/unispeech/modeling_unispeech.py +7 -4
- transformers/models/unispeech_sat/configuration_unispeech_sat.py +4 -1
- transformers/models/unispeech_sat/modeling_unispeech_sat.py +7 -4
- transformers/models/upernet/configuration_upernet.py +8 -35
- transformers/models/upernet/modeling_upernet.py +1 -1
- transformers/models/vaultgemma/configuration_vaultgemma.py +5 -7
- transformers/models/vaultgemma/modeling_vaultgemma.py +4 -4
- transformers/models/video_llama_3/configuration_video_llama_3.py +4 -0
- transformers/models/video_llama_3/image_processing_video_llama_3_fast.py +4 -6
- transformers/models/video_llama_3/modeling_video_llama_3.py +85 -48
- transformers/models/video_llama_3/modular_video_llama_3.py +56 -43
- transformers/models/video_llama_3/video_processing_video_llama_3.py +29 -8
- transformers/models/video_llava/configuration_video_llava.py +4 -0
- transformers/models/video_llava/modeling_video_llava.py +87 -89
- transformers/models/videomae/modeling_videomae.py +4 -5
- transformers/models/vilt/configuration_vilt.py +4 -1
- transformers/models/vilt/image_processing_vilt_fast.py +6 -6
- transformers/models/vilt/modeling_vilt.py +27 -12
- transformers/models/vipllava/configuration_vipllava.py +4 -0
- transformers/models/vipllava/modeling_vipllava.py +57 -31
- transformers/models/vipllava/modular_vipllava.py +50 -24
- transformers/models/vision_encoder_decoder/modeling_vision_encoder_decoder.py +10 -6
- transformers/models/vision_text_dual_encoder/modeling_vision_text_dual_encoder.py +27 -20
- transformers/models/visual_bert/configuration_visual_bert.py +6 -1
- transformers/models/vit/configuration_vit.py +2 -2
- transformers/models/vit/modeling_vit.py +7 -5
- transformers/models/vit_mae/modeling_vit_mae.py +11 -7
- transformers/models/vit_msn/modeling_vit_msn.py +11 -7
- transformers/models/vitdet/configuration_vitdet.py +2 -4
- transformers/models/vitdet/modeling_vitdet.py +2 -3
- transformers/models/vitmatte/configuration_vitmatte.py +6 -35
- transformers/models/vitmatte/image_processing_vitmatte_fast.py +2 -2
- transformers/models/vitmatte/modeling_vitmatte.py +1 -1
- transformers/models/vitpose/configuration_vitpose.py +6 -43
- transformers/models/vitpose/modeling_vitpose.py +5 -3
- transformers/models/vitpose_backbone/configuration_vitpose_backbone.py +2 -4
- transformers/models/vitpose_backbone/modeling_vitpose_backbone.py +5 -6
- transformers/models/vits/configuration_vits.py +4 -0
- transformers/models/vits/modeling_vits.py +9 -7
- transformers/models/vivit/modeling_vivit.py +4 -4
- transformers/models/vjepa2/modeling_vjepa2.py +9 -9
- transformers/models/voxtral/configuration_voxtral.py +0 -1
- transformers/models/voxtral/modeling_voxtral.py +25 -24
- transformers/models/voxtral/modular_voxtral.py +26 -20
- transformers/models/wav2vec2/configuration_wav2vec2.py +4 -1
- transformers/models/wav2vec2/modeling_wav2vec2.py +7 -4
- transformers/models/wav2vec2_bert/configuration_wav2vec2_bert.py +4 -1
- transformers/models/wav2vec2_conformer/configuration_wav2vec2_conformer.py +4 -1
- transformers/models/wavlm/configuration_wavlm.py +4 -1
- transformers/models/wavlm/modeling_wavlm.py +4 -1
- transformers/models/whisper/configuration_whisper.py +6 -4
- transformers/models/whisper/generation_whisper.py +0 -1
- transformers/models/whisper/modeling_whisper.py +3 -3
- transformers/models/x_clip/configuration_x_clip.py +4 -1
- transformers/models/x_clip/modeling_x_clip.py +26 -27
- transformers/models/xglm/configuration_xglm.py +9 -7
- transformers/models/xlm/configuration_xlm.py +10 -7
- transformers/models/xlm/modeling_xlm.py +1 -1
- transformers/models/xlm_roberta/configuration_xlm_roberta.py +11 -2
- transformers/models/xlm_roberta/modeling_xlm_roberta.py +6 -6
- transformers/models/xlm_roberta_xl/configuration_xlm_roberta_xl.py +10 -1
- transformers/models/xlm_roberta_xl/modeling_xlm_roberta_xl.py +6 -6
- transformers/models/xlnet/configuration_xlnet.py +3 -1
- transformers/models/xlstm/configuration_xlstm.py +5 -7
- transformers/models/xlstm/modeling_xlstm.py +0 -32
- transformers/models/xmod/configuration_xmod.py +11 -2
- transformers/models/xmod/modeling_xmod.py +13 -16
- transformers/models/yolos/image_processing_yolos_fast.py +25 -28
- transformers/models/yolos/modeling_yolos.py +7 -7
- transformers/models/yolos/modular_yolos.py +16 -16
- transformers/models/yoso/configuration_yoso.py +8 -1
- transformers/models/youtu/__init__.py +27 -0
- transformers/models/youtu/configuration_youtu.py +194 -0
- transformers/models/youtu/modeling_youtu.py +619 -0
- transformers/models/youtu/modular_youtu.py +254 -0
- transformers/models/zamba/configuration_zamba.py +5 -7
- transformers/models/zamba/modeling_zamba.py +25 -56
- transformers/models/zamba2/configuration_zamba2.py +8 -13
- transformers/models/zamba2/modeling_zamba2.py +53 -78
- transformers/models/zamba2/modular_zamba2.py +36 -29
- transformers/models/zoedepth/configuration_zoedepth.py +17 -40
- transformers/models/zoedepth/image_processing_zoedepth_fast.py +9 -9
- transformers/models/zoedepth/modeling_zoedepth.py +5 -3
- transformers/pipelines/__init__.py +1 -61
- transformers/pipelines/any_to_any.py +1 -1
- transformers/pipelines/automatic_speech_recognition.py +0 -2
- transformers/pipelines/base.py +1 -1
- transformers/pipelines/image_text_to_text.py +1 -1
- transformers/pipelines/text_to_audio.py +5 -1
- transformers/processing_utils.py +35 -44
- transformers/pytorch_utils.py +2 -26
- transformers/quantizers/quantizer_compressed_tensors.py +7 -5
- transformers/quantizers/quantizer_fbgemm_fp8.py +20 -23
- transformers/quantizers/quantizer_finegrained_fp8.py +14 -20
- transformers/quantizers/quantizer_mxfp4.py +1 -1
- transformers/quantizers/quantizer_torchao.py +0 -16
- transformers/safetensors_conversion.py +11 -4
- transformers/testing_utils.py +3 -28
- transformers/tokenization_mistral_common.py +9 -0
- transformers/tokenization_python.py +6 -4
- transformers/tokenization_utils_base.py +119 -219
- transformers/tokenization_utils_tokenizers.py +31 -2
- transformers/trainer.py +25 -33
- transformers/trainer_seq2seq.py +1 -1
- transformers/training_args.py +411 -417
- transformers/utils/__init__.py +1 -4
- transformers/utils/auto_docstring.py +15 -18
- transformers/utils/backbone_utils.py +13 -373
- transformers/utils/doc.py +4 -36
- transformers/utils/generic.py +69 -33
- transformers/utils/import_utils.py +72 -75
- transformers/utils/loading_report.py +133 -105
- transformers/utils/quantization_config.py +0 -21
- transformers/video_processing_utils.py +5 -5
- transformers/video_utils.py +3 -1
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/METADATA +118 -237
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/RECORD +1019 -994
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/WHEEL +1 -1
- transformers/pipelines/deprecated/text2text_generation.py +0 -408
- transformers/pipelines/image_to_text.py +0 -189
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/entry_points.txt +0 -0
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/licenses/LICENSE +0 -0
- {transformers-5.0.0rc3.dist-info → transformers-5.1.0.dist-info}/top_level.txt +0 -0
@@ -32,7 +32,7 @@ from ...integrations import use_kernel_forward_from_hub, use_kernelized_func
 from ...masking_utils import create_causal_mask
 from ...modeling_flash_attention_utils import FlashAttentionKwargs
 from ...modeling_layers import GradientCheckpointingLayer
-from ...modeling_outputs import BaseModelOutputWithPast, ModelOutput
+from ...modeling_outputs import BaseModelOutputWithPast, BaseModelOutputWithPooling, ModelOutput
 from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
 from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
 from ...processing_utils import Unpack
@@ -126,9 +126,9 @@ class GlmImageVisionAttention(nn.Module):
         key_states = key_states.transpose(0, 1).unsqueeze(0)
         value_states = value_states.transpose(0, 1).unsqueeze(0)
 
-        attention_interface: Callable =
-
-
+        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
+            self.config._attn_implementation, eager_attention_forward
+        )
 
         if "flash" in self.config._attn_implementation:
             # Flash Attention: Use cu_seqlens for variable length attention
@@ -402,9 +402,9 @@ class GlmImageTextAttention(nn.Module):
         cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
         key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
-        attention_interface: Callable =
-
-
+        attention_interface: Callable = ALL_ATTENTION_FUNCTIONS.get_interface(
+            self.config._attn_implementation, eager_attention_forward
+        )
 
         attn_output, attn_weights = attention_interface(
             self,
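Both attention classes above now resolve their backend through `ALL_ATTENTION_FUNCTIONS.get_interface(attn_implementation, default)` with `eager_attention_forward` as the fallback, instead of indexing the registry directly. A minimal, self-contained sketch of that dispatch-with-default pattern (a toy registry written for illustration, not the actual transformers implementation):

```python
from typing import Callable


def eager_attention_forward(*args, **kwargs):
    # Toy stand-in for the eager attention kernel.
    return "eager result"


class AttentionInterfaceRegistry(dict):
    # Toy stand-in for transformers' ALL_ATTENTION_FUNCTIONS registry.
    def get_interface(self, attn_implementation: str, default: Callable) -> Callable:
        # Fall back to the eager implementation when no kernel is registered
        # for the requested implementation name.
        return self.get(attn_implementation, default)


attention_functions = AttentionInterfaceRegistry({"sdpa": lambda *a, **kw: "sdpa result"})

attention_interface = attention_functions.get_interface("flash_attention_2", eager_attention_forward)
print(attention_interface())  # "eager result": no flash kernel registered in this toy registry
```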
@@ -612,6 +612,23 @@ class GlmImageVQVAEVectorQuantizer(nn.Module):
         return hidden_state_quant, loss, min_encoding_indices
 
 
+@dataclass
+@auto_docstring
+class GlmImageVQVAEModelOutput(BaseModelOutputWithPooling):
+    r"""
+    quantized_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+        Quantized last hidden state from the VQ-VAE model.
+    image_tokens (`torch.FloatTensor` of shape `(batch_size, config.vocab_size`):
+        Indices of the image tokens predicted by the VQ-VAE model.
+    embedding_loss (`torch.FloatTensor`):
+        The embedding loss computed during quantization.
+    """
+
+    quantized_last_hidden_state: torch.FloatTensor | None = None
+    image_tokens: torch.FloatTensor | None = None
+    embedding_loss: torch.FloatTensor | None = None
+
+
 @auto_docstring(
     custom_intro="""
     The VQ-VAE model used in GlmImage for encoding/decoding images into discrete tokens.
@@ -625,6 +642,7 @@ class GlmImageVQVAE(GlmImagePreTrainedModel):
     _no_split_modules = [
         "GlmImageVQVAEVectorQuantizer",
     ]
+    _can_record_outputs = {}
 
     def __init__(self, config: GlmImageVQVAEConfig):
         super().__init__(config)
@@ -634,16 +652,26 @@ class GlmImageVQVAE(GlmImagePreTrainedModel):
         self.eval()  # GlmImage's VQ model is frozen
         self.post_init()
 
-
-
-
-
+    @check_model_inputs
+    def encode(self, hidden_states) -> GlmImageVQVAEModelOutput:
+        conv_hidden_states = self.quant_conv(hidden_states)
+        quantized_last_hidden_state, emb_loss, indices = self.quantize(conv_hidden_states)
+        return GlmImageVQVAEModelOutput(
+            last_hidden_state=hidden_states,
+            quantized_last_hidden_state=quantized_last_hidden_state,
+            image_tokens=indices,
+            embedding_loss=emb_loss,
+        )
 
 
 class GlmImageVisionModel(GlmImagePreTrainedModel):
     config: GlmImageVisionConfig
     input_modalities = ("image",)
     _no_split_modules = ["GlmImageVisionBlock"]
+    _can_record_outputs = {
+        "hidden_states": GlmImageVisionBlock,
+        "attentions": GlmImageVisionAttention,
+    }
     main_input_name = "pixel_values"
 
     def __init__(self, config: GlmImageVisionConfig) -> None:
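The new `encode` method packages what `self.quantize` returns (quantized hidden state, embedding loss, codebook indices) into the `GlmImageVQVAEModelOutput` above, with the indices exposed as `image_tokens`. A simplified, runnable stand-in for such a vector-quantization step, for readers unfamiliar with what those three values are; this is illustrative only and much simpler than the real GlmImageVQVAEVectorQuantizer:

```python
import torch


def toy_vector_quantize(z: torch.Tensor, codebook: torch.Tensor):
    """Snap each vector in z to its nearest codebook entry (simplified VQ step)."""
    dists = torch.cdist(z, codebook)                   # (n, k) pairwise distances
    indices = dists.argmin(dim=-1)                     # nearest codebook index per vector -> "image tokens"
    quantized = codebook[indices]                      # (n, d) quantized vectors
    loss = torch.mean((quantized.detach() - z) ** 2)   # commitment-style embedding loss
    return quantized, loss, indices


z = torch.randn(6, 4)
codebook = torch.randn(16, 4)
quantized, loss, indices = toy_vector_quantize(z, codebook)
print(quantized.shape, round(loss.item(), 4), indices.tolist())
```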
@@ -688,13 +716,16 @@ class GlmImageVisionModel(GlmImagePreTrainedModel):
         pos_ids = torch.cat(pos_ids, dim=0)
         return pos_ids
 
-
-
-
-
-
-
-
+    @check_model_inputs
+    @auto_docstring
+    def forward(
+        self, pixel_values: torch.Tensor, grid_thw: torch.Tensor, **kwargs: Unpack[TransformersKwargs]
+    ) -> tuple | BaseModelOutputWithPooling:
+        r"""
+        pixel_values (`torch.Tensor` of shape `(total_patches, num_channels * patch_size * patch_size)`):
+            Packed pixel values.
+        grid_thw (`torch.Tensor` of shape `(num_images, 3)`):
+            The temporal, height and width of feature shape of each image.
 
         Returns:
             `torch.Tensor` of shape `(total_patches, hidden_size)`: Hidden states.
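The vision tower keeps the packed-input convention documented above: all patches from all images are concatenated along the first axis, and `grid_thw` records how many patches each image contributes. A small sketch of that bookkeeping; the channel count and patch size here are made-up values, not read from the GlmImage config:

```python
import torch

# Two images described by (t, h, w) patch grids; values chosen for the sketch.
grid_thw = torch.tensor([[1, 4, 6], [1, 8, 8]])
num_channels, patch_size = 3, 14  # assumed values

total_patches = int(grid_thw.prod(-1).sum())  # 24 + 64 = 88
packed_pixel_values = torch.randn(total_patches, num_channels * patch_size * patch_size)
print(packed_pixel_values.shape)  # torch.Size([88, 588])

# cu_seqlens-style boundaries between images, as used by the variable-length attention path.
cu_seqlens = torch.cat([torch.zeros(1, dtype=torch.long), grid_thw.prod(-1).cumsum(0)])
print(cu_seqlens.tolist())  # [0, 24, 88]
```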
@@ -723,7 +754,8 @@ class GlmImageVisionModel(GlmImagePreTrainedModel):
             hidden_states,
             cu_seqlens=cu_seqlens,
         )
-
+
+        return BaseModelOutputWithPooling(last_hidden_state=hidden_states)
 
 
 class GlmImageTextRotaryEmbedding(nn.Module):
@@ -927,6 +959,10 @@ class GlmImageModel(GlmImagePreTrainedModel):
         self.rope_deltas = None  # cache rope_deltas here
         self.vqmodel = GlmImageVQVAE._from_config(config.vq_config)
 
+        # Per-sample caches for batch processing
+        self._cached_decode_position_ids = None  # shape: [batch_size, 3, max_decode_len]
+        self._prefill_len = None  # prefill sequence length (same for all samples in batch)
+
         # Initialize weights and apply final processing
         self.post_init()
 
@@ -940,220 +976,169 @@ class GlmImageModel(GlmImagePreTrainedModel):
         self,
         input_ids: torch.LongTensor | None = None,
         image_grid_thw: torch.LongTensor | None = None,
+        images_per_sample: torch.LongTensor | None = None,
         attention_mask: torch.LongTensor | None = None,
     ) -> tuple[torch.Tensor, torch.Tensor]:
         """
-        Calculate the 3D rope index for image generation task.
-
-        Explanation:
-            Each embedding sequence may contain image tokens (for generation) and text tokens,
-            or just text tokens.
-
-            Input format:
-            - Text-to-Image: [text tokens] + <|dit_token_16384|>
-            - Image-to-Image: <|dit_token_16384|> [image tokens] <|dit_token_16385|> + [text tokens] + <|dit_token_16384|>
-
-            For pure text embedding sequence, the rotary position embedding is the same across all 3 dimensions.
-            Examples:
-                input_ids: [T T T T T], here T is for text.
-                temporal position_ids: [0, 1, 2, 3, 4]
-                height position_ids: [0, 1, 2, 3, 4]
-                width position_ids: [0, 1, 2, 3, 4]
-
-            For sequences with image tokens, we use special markers to denote image regions:
-            - <|dit_token_16384|>: image start marker
-            - <|dit_token_16385|>: image end marker
-            - Image tokens between these markers use 2D spatial position encoding.
-
-            For image tokens:
-            - temporal: stays constant at (image_start_pos + 1)
-            - height: increments every w tokens, representing row position
-            - width: cycles from 0 to w-1, representing column position
-
-            After each image region, the next position jumps to: image_start_pos + 1 + max(h, w)
-            This ensures sufficient positional separation between images and subsequent tokens.
-
-            Examples:
-                === Case 1: Image-to-Image Generation ===
-
-                Source image with grid [1, 3, 2], followed by text, then generation.
-                input_ids: [<|dit_token_16384|> V V V V V V <|dit_token_16385|> T T T T <|dit_token_16384|>]
-                image_grid_thw: [[1, 3, 2], [1, 4, 4]]  # first is source, second is target
-
-                For source image (h=3, w=2, 6 tokens):
-                Start marker at position 0
-                Image tokens at temporal=1, height=[1,1,2,2,3,3], width=[1,2,1,2,1,2]
-                End marker at position 4 (= 0 + 1 + max(3,2))
-
-                Text tokens and trailing start marker continue from position 5.
-
-                Full prefill position_ids:
-                temporal: [0, 1,1,1,1,1,1, 4, 5,6,7,8, 9]
-                height: [0, 1,1,2,2,3,3, 4, 5,6,7,8, 9]
-                width: [0, 1,2,1,2,1,2, 4, 5,6,7,8, 9]
-
-                Decode stage: use image_grid_thw[-1] = [1, 4, 4] to build cached position_ids,
-                starting from gen_st_idx = 10.
-
-                === Case 2: Text-to-Image Generation (multi-resolution) ===
-
-                Pure text input with two image_grids for progressive generation.
-                input_ids: [hello<sop>3 3<eop><sop>3 2<eop><|dit_token_16384|>]
-                Assume "hello<sop>3 3<eop><sop>3 2<eop>" = 4 tokens (positions 0-3)
-                <|dit_token_16384|> at position 4
-                image_grid_thw: [[1, 3, 3], [1, 3, 2]]
-                - image_grid_thw[-1] = [1, 3, 2]: first generated image (smaller/draft)
-                - image_grid_thw[-2] = [1, 3, 3]: second generated image (larger/final)
-
-                Prefill position_ids (5 tokens: 4 text + 1 start marker):
-                temporal: [0, 1, 2, 3, 4]
-                height: [0, 1, 2, 3, 4]
-                width: [0, 1, 2, 3, 4]
-
-                Decode stage builds position_ids in reverse order of image_grid_thw:
-
-                First: image_grid_thw[-1] = [1, 3, 2] (6 tokens), starting at position 5:
-                temporal: [5, 5, 5, 5, 5, 5]
-                height: [5, 5, 6, 6, 7, 7]
-                width: [5, 6, 5, 6, 5, 6]
-                next_pos = 5 + max(3, 2) = 8
-
-                Then: image_grid_thw[-2] = [1, 3, 3] (9 tokens), starting at position 8:
-                temporal: [8, 8, 8, 8, 8, 8, 8, 8, 8]
-                height: [8, 8, 8, 9, 9, 9, 10, 10, 10]
-                width: [8, 9, 10, 8, 9, 10, 8, 9, 10]
-                next_pos = 8 + max(3, 3) = 11
-
-                Finally: <|dit_token_16385|> end marker at position 11
-
-                Full sequence position_ids (prefill + decode):
-                temporal: [0,1,2,3, 4, 5,5,5,5,5,5, 8,8,8,8,8,8,8,8,8, 11]
-                height: [0,1,2,3, 4, 5,5,6,6,7,7, 8,8,8,9,9,9,10,10,10, 11]
-                width: [0,1,2,3, 4, 5,6,5,6,5,6, 8,9,10,8,9,10,8,9,10, 11]
-
-                _cached_decode_position_ids shape: [3, 6 + 9 + 1] = [3, 16]
-                (includes all generated image tokens + end marker)
+        Calculate the 3D rope index for image generation task with full batch support.
 
         Args:
             input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
-                Indices of input sequence tokens in the vocabulary.
-
-
-
-
-
-
-                processed in reverse order (last grid first, second-to-last grid second, etc.)
+                Indices of input sequence tokens in the vocabulary.
+            image_grid_thw (`torch.LongTensor` of shape `(total_images_in_batch, 3)`, *optional*):
+                The temporal, height and width of feature shape of each image.
+                Images are packed across all samples in the batch.
+            images_per_sample (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+                Number of images (including target grids) for each sample in the batch.
+                Used to split image_grid_thw by sample.
             attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
-                Mask to avoid performing attention on padding token indices.
-                - 1 for tokens that are **not masked**,
-                - 0 for tokens that are **masked**.
+                Mask to avoid performing attention on padding token indices.
 
         Returns:
             position_ids (`torch.LongTensor` of shape `(3, batch_size, sequence_length)`):
                 Position IDs for temporal, height, and width dimensions.
             mrope_position_deltas (`torch.Tensor` of shape `(batch_size, 1)`):
-                Position deltas for multi-modal rotary position embedding
+                Position deltas for multi-modal rotary position embedding.
         """
-
         batch_size, seq_len = input_ids.shape
         device = input_ids.device
         dtype = input_ids.dtype
 
         image_start_token_id = self.config.image_start_token_id
         image_end_token_id = self.config.image_end_token_id
-        num_complete_images = (input_ids == image_end_token_id).sum().item()
 
-        position_ids = torch.ones(
-
-
-
+        position_ids = torch.ones(3, batch_size, seq_len, dtype=dtype, device=device)
+        text_positions = torch.arange(seq_len, device=device)[None, :].repeat(3, 1)
+
+        # Split image_grid_thw by sample if images_per_sample is provided
+        if image_grid_thw is not None and images_per_sample is not None:
+            grids_per_sample = torch.split(image_grid_thw, images_per_sample.tolist())
+        elif image_grid_thw is not None:
+            # Fallback: assume all grids belong to first sample (batch_size=1)
+            grids_per_sample = [image_grid_thw] * batch_size
+        else:
+            grids_per_sample = [None] * batch_size
+
+        # Per-sample caches for decode stage
+        all_decode_position_ids = []
+
         for batch_idx in range(batch_size):
             curr_input_ids = input_ids[batch_idx]
-
-            curr_input_ids = curr_input_ids[attention_mask[batch_idx] == 1]
+            curr_grids = grids_per_sample[batch_idx]
 
-
-
-
+            if attention_mask is not None and attention_mask.shape[1] == seq_len:
+                valid_mask = attention_mask[batch_idx] == 1
+                curr_input_ids_valid = curr_input_ids[valid_mask]
+            else:
+                # attention_mask may have different length during assisted decoding
+                curr_input_ids_valid = curr_input_ids
+                valid_mask = None
+
+            # Find image boundaries in this sample
+            image_end_positions = torch.where(curr_input_ids_valid == image_end_token_id)[0]
+            image_start_positions = torch.where(curr_input_ids_valid == image_start_token_id)[0] + 1
+            num_complete_images = len(image_end_positions)
+
+            current_pos = 0
             prev_image_end = 0
             curr_position_ids = []
-            for start, end, grid in zip(image_start, image_end, image_grid_thw):
-                _, num_width_grid, num_height_grid = grid
 
-
+            # Process complete images (source images in image-to-image task)
+            for img_idx, (start, end) in enumerate(zip(image_start_positions, image_end_positions)):
+                if curr_grids is None or img_idx >= len(curr_grids):
+                    break
+                grid = curr_grids[img_idx]
+                # grid format is [temporal, height, width]
+                _, height, width = grid.tolist()
+
+                # Text tokens before this image
                 llm_pos_length = start - prev_image_end
-                llm_position_ids = text_positions[:, current_pos : current_pos + llm_pos_length].to(
-                    device=input_ids.device
-                )
+                llm_position_ids = text_positions[:, current_pos : current_pos + llm_pos_length].to(device=device)
                 current_pos += llm_position_ids.shape[-1]
 
-                #
-
-
-
-
-
-
-
-                position_temporal = torch.full(
-                    (image_seq_length,), current_pos, device=input_ids.device, dtype=torch.long
+                # Image tokens with 2D spatial encoding
+                # For an image with height H and width W:
+                # - position_width cycles [0, 1, ..., W-1] for each row, repeated H times
+                # - position_height stays constant per row, [0]*W, [1]*W, ..., [H-1]*W
+                image_seq_length = height * width
+                position_width = torch.arange(current_pos, current_pos + width, device=device).repeat(height)
+                position_height = torch.arange(current_pos, current_pos + height, device=device).repeat_interleave(
+                    width
                 )
+                position_temporal = torch.full((image_seq_length,), current_pos, device=device, dtype=torch.long)
                 vision_position_ids = torch.stack([position_temporal, position_height, position_width], dim=0)
-                current_pos += max(
+                current_pos += max(height, width)
 
                 prev_image_end = end
                 curr_position_ids.append(torch.cat([llm_position_ids, vision_position_ids], dim=-1))
 
-            #
-            end_position = len(
-            llm_position_ids = text_positions[:, current_pos : current_pos + end_position].to(device=
+            # Remaining text tokens (including the final image_start token for generation)
+            end_position = len(curr_input_ids_valid) - prev_image_end
+            llm_position_ids = text_positions[:, current_pos : current_pos + end_position].to(device=device)
             current_pos += llm_position_ids.shape[-1]
             curr_position_ids.append(llm_position_ids)
+
+            # Concatenate all position ids for this sample
            curr_position_ids = torch.cat(curr_position_ids, dim=-1)
-
-
+
+            # Store in the main position_ids tensor
+            if valid_mask is not None:
+                position_ids[:, batch_idx, valid_mask] = curr_position_ids
             else:
-                position_ids[:, batch_idx, :] = curr_position_ids
+                position_ids[:, batch_idx, :] = curr_position_ids
+
+            # Build decode position ids for this sample
+            if curr_grids is not None and len(curr_grids) > 0:
+                num_decode_grids = len(curr_grids) - num_complete_images
+                num_decode_grids = max(num_decode_grids, 0)
+                decode_pos = current_pos
+
+                decode_temporal_list = []
+                decode_height_list = []
+                decode_width_list = []
+
+                for i in range(1, num_decode_grids + 1):
+                    grid_idx = -i
+                    h = curr_grids[grid_idx, 1].item()
+                    w = curr_grids[grid_idx, 2].item()
+                    total_tokens = h * w
+
+                    h_indices = torch.arange(h, device=device).unsqueeze(1).expand(h, w).flatten()
+                    w_indices = torch.arange(w, device=device).unsqueeze(0).expand(h, w).flatten()
+
+                    decode_temporal_list.append(
+                        torch.full((total_tokens,), decode_pos, device=device, dtype=torch.long)
+                    )
+                    decode_height_list.append(decode_pos + h_indices)
+                    decode_width_list.append(decode_pos + w_indices)
+                    decode_pos = decode_pos + max(h, w)
+
+                # End marker
+                decode_temporal_list.append(torch.tensor([decode_pos], device=device, dtype=torch.long))
+                decode_height_list.append(torch.tensor([decode_pos], device=device, dtype=torch.long))
+                decode_width_list.append(torch.tensor([decode_pos], device=device, dtype=torch.long))
+
+                sample_decode_pos_ids = torch.stack(
+                    [
+                        torch.cat(decode_temporal_list, dim=0),
+                        torch.cat(decode_height_list, dim=0),
+                        torch.cat(decode_width_list, dim=0),
+                    ],
+                    dim=0,
+                )
+                all_decode_position_ids.append(sample_decode_pos_ids)
 
-        #
-        # slice these instead of computing each decoding step
+        # Store prefill length (same for all samples since input_ids is padded to same length)
         self._prefill_len = seq_len
-
-
-
-
-
-
-
-
-
-        for i in range(1, num_decode_grids + 1):
-            grid_idx = -i
-            h = image_grid_thw[grid_idx, 1].item()
-            w = image_grid_thw[grid_idx, 2].item()
-            total_tokens = h * w
-
-            h_indices = torch.arange(h, device=device).unsqueeze(1).expand(h, w).flatten()
-            w_indices = torch.arange(w, device=device).unsqueeze(0).expand(h, w).flatten()
-
-            decode_temporal_list.append(torch.full((total_tokens,), decode_pos, device=device, dtype=torch.long))
-            decode_height_list.append(decode_pos + h_indices)
-            decode_width_list.append(decode_pos + w_indices)
-            decode_pos = decode_pos + max(h, w)
-
-            decode_temporal_list.append(torch.tensor([decode_pos], device=device, dtype=torch.long))
-            decode_height_list.append(torch.tensor([decode_pos], device=device, dtype=torch.long))
-            decode_width_list.append(torch.tensor([decode_pos], device=device, dtype=torch.long))
-
-            self._cached_decode_position_ids = torch.stack(
-                [
-                    torch.cat(decode_temporal_list, dim=0),
-                    torch.cat(decode_height_list, dim=0),
-                    torch.cat(decode_width_list, dim=0),
-                ],
-                dim=0,
-            )
+
+        # Pad decode position ids to same length and stack
+        if all_decode_position_ids:
+            max_decode_len = max(x.shape[1] for x in all_decode_position_ids)
+            padded_decode_pos_ids = [
+                F.pad(pos_ids, (0, max_decode_len - pos_ids.shape[1]), mode="replicate")
+                for pos_ids in all_decode_position_ids
+            ]
+            self._cached_decode_position_ids = torch.stack(padded_decode_pos_ids, dim=0)  # [batch, 3, max_decode_len]
         else:
             self._cached_decode_position_ids = None
 
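The per-image position construction in this hunk encodes a height x width block with a constant temporal index, a row index repeated along each row, and a column index that cycles per row, exactly as the inline comments describe. A small standalone check of those `arange` / `repeat` / `repeat_interleave` calls for a 3x2 grid starting at position 1, mirroring the h=3, w=2 example from the removed docstring:

```python
import torch

current_pos, height, width = 1, 3, 2

position_width = torch.arange(current_pos, current_pos + width).repeat(height)
position_height = torch.arange(current_pos, current_pos + height).repeat_interleave(width)
position_temporal = torch.full((height * width,), current_pos, dtype=torch.long)

print(position_temporal.tolist())  # [1, 1, 1, 1, 1, 1]
print(position_height.tolist())    # [1, 1, 2, 2, 3, 3]
print(position_width.tolist())     # [1, 2, 1, 2, 1, 2]

# The position after the image jumps by max(height, width), matching the
# end-marker position 4 in the old docstring example.
print(current_pos + max(height, width))  # 4
```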
@@ -1161,21 +1146,27 @@ class GlmImageModel(GlmImagePreTrainedModel):
 
         return position_ids, mrope_position_deltas
 
-
-
-
-
-
-
-
-
-
+    @can_return_tuple
+    @auto_docstring
+    def get_image_features(
+        self,
+        pixel_values: torch.FloatTensor,
+        image_grid_thw: torch.LongTensor | None = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> tuple | BaseModelOutputWithPooling:
+        r"""
+        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+            The tensors corresponding to the input images.
+        image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+            The temporal, height and width of feature shape of each image in LLM.
         """
         pixel_values = pixel_values.type(self.visual.dtype)
-
+        vision_outputs = self.visual(pixel_values, grid_thw=image_grid_thw, return_dict=True, **kwargs)
         split_sizes = (image_grid_thw.prod(-1) // self.visual.spatial_merge_size**2).tolist()
-        image_embeds = torch.split(
-
+        image_embeds = torch.split(vision_outputs.last_hidden_state, split_sizes)
+        vision_outputs.pooler_output = image_embeds
+
+        return vision_outputs
 
     def get_placeholder_mask(
         self,
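`get_image_features` now returns the vision tower's `BaseModelOutputWithPooling` and stores the per-image splits on `pooler_output`. The split sizes come from each grid's token count divided by the squared spatial merge size; a small sketch of that bookkeeping with made-up grids and an assumed merge size of 2 (the real value comes from the vision config):

```python
import torch

image_grid_thw = torch.tensor([[1, 4, 4], [1, 8, 6]])  # two packed images, (t, h, w) in patch units
spatial_merge_size = 2  # assumption for the sketch

split_sizes = (image_grid_thw.prod(-1) // spatial_merge_size**2).tolist()
print(split_sizes)  # [4, 12]

# Packed hidden states for both images, split back into per-image chunks.
packed = torch.randn(sum(split_sizes), 32)
per_image = torch.split(packed, split_sizes)
print([x.shape[0] for x in per_image])  # [4, 12]
```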
@@ -1219,23 +1210,63 @@ class GlmImageModel(GlmImagePreTrainedModel):
         inputs_embeds: torch.FloatTensor | None = None,
         pixel_values: torch.Tensor | None = None,
         image_grid_thw: torch.LongTensor | None = None,
+        images_per_sample: torch.LongTensor | None = None,
         rope_deltas: torch.LongTensor | None = None,
         cache_position: torch.LongTensor | None = None,
         **kwargs: Unpack[TransformersKwargs],
     ) -> tuple | GlmImageModelOutputWithPast:
         r"""
-        image_grid_thw (`torch.LongTensor` of shape `(
+        image_grid_thw (`torch.LongTensor` of shape `(total_images_in_batch, 3)`, *optional*):
             The temporal, height and width of feature shape of each image in LLM.
+            Images are packed across all samples in the batch.
+        images_per_sample (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Number of images (including target grids) for each sample in the batch.
         rope_deltas (`torch.LongTensor` of shape `(batch_size, )`, *optional*):
             The rope index difference between sequence length and multimodal rope.
         """
         if (input_ids is None) ^ (inputs_embeds is not None):
             raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
 
+        batch_size = input_ids.shape[0] if input_ids is not None else inputs_embeds.shape[0]
+
         if pixel_values is not None:
-
-
-
+            # Process source images (image-to-image mode)
+            # Source images are identified by counting image_end_token_id in input_ids
+            # Note: We must exclude padding tokens since pad_token_id == image_end_token_id
+            if images_per_sample is not None:
+                grids_per_sample = torch.split(image_grid_thw, images_per_sample.tolist())
+                # Create mask for non-padding tokens (attention_mask=1 means non-padding)
+                # Handle 4D attention mask (from static cache) by extracting diagonal
+                if attention_mask is not None and attention_mask.ndim == 4:
+                    non_pad_mask = torch.diagonal(attention_mask[:, 0], dim1=1, dim2=2)
+                    if non_pad_mask.dtype.is_floating_point:
+                        non_pad_mask = non_pad_mask / torch.finfo(non_pad_mask.dtype).min
+                    non_pad_mask = (1.0 - non_pad_mask).int()
+                    # Only keep columns matching input_ids length
+                    non_pad_mask = non_pad_mask[:, -input_ids.shape[1] :]
+                else:
+                    non_pad_mask = attention_mask if attention_mask is not None else torch.ones_like(input_ids)
+
+                source_grids_list = []
+                for sample_idx in range(batch_size):
+                    is_image_end = input_ids[sample_idx] == self.config.image_end_token_id
+                    is_non_pad = non_pad_mask[sample_idx] == 1
+                    num_source = (is_image_end & is_non_pad).sum().item()
+                    if num_source > 0:
+                        source_grids_list.append(grids_per_sample[sample_idx][:num_source])
+                if len(source_grids_list) == 0:
+                    raise ValueError(
+                        "pixel_values provided but no source images found in input_ids. "
+                        "Ensure input_ids contains image_end_token_id for each source image."
+                    )
+                source_grids = torch.cat(source_grids_list, dim=0)
+            else:
+                # Fallback for batch_size=1: all but last grid are source images
+                source_grids = image_grid_thw[:-1]
+
+            image_features = self.get_image_features(pixel_values, source_grids, return_dict=True)
+            image_embeds = torch.cat(image_features.pooler_output, dim=0)
+            image_ids = self.get_image_tokens(image_embeds, source_grids)
             image_ids = image_ids.view(-1).to(input_ids.device)
             special_image_mask = self.get_placeholder_mask(input_ids, image_ids)
             input_ids = input_ids.masked_scatter(special_image_mask, image_ids)
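The new source-image counting has to recover a plain per-token padding mask even when the model is handed a 4D additive mask (as produced for static caches), which is what the `torch.diagonal` branch above does. A toy, self-contained illustration of that recovery; the mask construction here is fabricated for the sketch and only mimics the 0 / dtype-min convention:

```python
import torch

# 2D padding mask: 1 = real token, 0 = padding (second sample is left-padded).
attention_mask_2d = torch.tensor([[1, 1, 1, 1], [0, 0, 1, 1]])

# Toy 4D additive mask: 0.0 where attention is allowed, dtype-min where masked.
min_value = torch.finfo(torch.float32).min
additive = (1.0 - attention_mask_2d[:, None, None, :].float()) * min_value  # [batch, 1, 1, k]
additive_4d = additive.expand(-1, 1, attention_mask_2d.shape[1], -1)        # [batch, 1, q, k]

# Recover the per-token padding mask the way the new forward does.
non_pad = torch.diagonal(additive_4d[:, 0], dim1=1, dim2=2)
non_pad = non_pad / min_value          # masked -> 1.0, allowed -> 0.0
non_pad = (1.0 - non_pad).int()        # allowed -> 1, masked -> 0
print(non_pad.tolist())  # [[1, 1, 1, 1], [0, 0, 1, 1]]
```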
@@ -1253,8 +1284,6 @@ class GlmImageModel(GlmImagePreTrainedModel):
             attention_mask_2d = (1.0 - attention_mask_2d).int()
 
         # Calculate RoPE index once per generation in the pre-fill stage only.
-        # It is safe to assume that `length!=1` means we're in pre-fill because the
-        # model is used only by DiT pipeline without assisted decoding, etc. techniques
         is_prefill_stage = (input_ids is not None and input_ids.shape[1] != 1) or (
             inputs_embeds is not None and inputs_embeds.shape[1] != 1
         )
@@ -1262,17 +1291,27 @@ class GlmImageModel(GlmImagePreTrainedModel):
             position_ids, rope_deltas = self.get_rope_index(
                 input_ids,
                 image_grid_thw,
+                images_per_sample=images_per_sample,
                 attention_mask=attention_mask_2d,
             )
             self.rope_deltas = rope_deltas
         # then use the prev pre-calculated rope-deltas to get the correct position ids
         else:
             batch_size, seq_length, _ = inputs_embeds.shape
-            #
-
-
-
-
+            # Per-sample decode position lookup
+            # _cached_decode_position_ids shape: [batch_size, 3, max_decode_len]
+            if self._cached_decode_position_ids is not None:
+                step = cache_position[0].item() - self._prefill_len
+                # Get position ids for all samples at once, then transpose to [3, batch_size, seq_length]
+                position_ids = self._cached_decode_position_ids[:, :, step : step + seq_length].permute(1, 0, 2)
+            else:
+                # Fallback for text-to-image or cases without cached decode positions
+                # Use simple incremental positions
+                start_pos = cache_position[0].item()
+                position_ids = torch.arange(
+                    start_pos, start_pos + seq_length, device=inputs_embeds.device, dtype=torch.long
+                )
+                position_ids = position_ids.unsqueeze(0).repeat(3, batch_size, 1)
 
         outputs = self.language_model(
             input_ids=None,
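During decode the model now slices the per-sample cached decode positions by the offset of the current cache position past the prefill length, instead of recomputing them at every step. A toy illustration of that slice-and-permute with a fabricated cache, showing only the shape handling:

```python
import torch

batch_size, max_decode_len, seq_length = 2, 10, 1
prefill_len = 7
cache_position = torch.tensor([9])  # third decode step: positions 7 and 8 already generated

# Cached decode positions, laid out as [batch_size, 3, max_decode_len].
cached_decode_position_ids = torch.arange(batch_size * 3 * max_decode_len).reshape(batch_size, 3, max_decode_len)

step = cache_position[0].item() - prefill_len
position_ids = cached_decode_position_ids[:, :, step : step + seq_length].permute(1, 0, 2)
print(step)                # 2
print(position_ids.shape)  # torch.Size([3, 2, 1]) -> [3, batch_size, seq_length]
```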
@@ -1319,8 +1358,8 @@ class GlmImageModel(GlmImagePreTrainedModel):
             grid_t, grid_h, grid_w = image_grid_thw[i].tolist()
             hs = hs.view(grid_t, grid_h, grid_w, hidden_size)
             hs = hs.permute(0, 3, 1, 2).contiguous()
-
-            all_image_toks.append(
+            vqmodel_outputs: GlmImageVQVAEModelOutput = self.vqmodel.encode(hs)
+            all_image_toks.append(vqmodel_outputs.image_tokens)
         return torch.cat(all_image_toks, dim=0)
 
 
@@ -1369,8 +1408,20 @@ class GlmImageForConditionalGeneration(GlmImagePreTrainedModel, GenerationMixin)
         # Initialize weights and apply final processing
         self.post_init()
 
-
-
+    @auto_docstring
+    def get_image_features(
+        self,
+        pixel_values: torch.FloatTensor,
+        image_grid_thw: torch.LongTensor | None = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> tuple | BaseModelOutputWithPooling:
+        r"""
+        pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, image_size, image_size)`):
+            The tensors corresponding to the input images.
+        image_grid_thw (`torch.LongTensor` of shape `(num_images, 3)`, *optional*):
+            The temporal, height and width of feature shape of each image in LLM.
+        """
+        return self.model.get_image_features(pixel_values, image_grid_thw, **kwargs)
 
     def get_image_tokens(self, hidden_states: torch.FloatTensor, image_grid_thw: torch.LongTensor | None = None):
         return self.model.get_image_tokens(hidden_states, image_grid_thw)
@@ -1385,6 +1436,7 @@ class GlmImageForConditionalGeneration(GlmImagePreTrainedModel, GenerationMixin)
         labels: torch.LongTensor | None = None,
         pixel_values: torch.Tensor | None = None,
         image_grid_thw: torch.LongTensor | None = None,
+        images_per_sample: torch.LongTensor | None = None,
         cache_position: torch.LongTensor | None = None,
         logits_to_keep: int | torch.Tensor = 0,
         **kwargs: Unpack[TransformersKwargs],
@@ -1394,14 +1446,18 @@ class GlmImageForConditionalGeneration(GlmImagePreTrainedModel, GenerationMixin)
             Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
             config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
             (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
-        image_grid_thw (`torch.LongTensor` of shape `(
+        image_grid_thw (`torch.LongTensor` of shape `(total_images_in_batch, 3)`, *optional*):
             The temporal, height and width of feature shape of each image in LLM.
+            Images are packed across all samples in the batch.
+        images_per_sample (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
+            Number of images (including target grids) for each sample in the batch.
 
         Example:
 
         ```python
         >>> from PIL import Image
-        >>> import
+        >>> import httpx
+        >>> from io import BytesIO
         >>> from transformers import AutoProcessor, GlmImageForConditionalGeneration
 
         >>> model = GlmImageForConditionalGeneration.from_pretrained("zai-org/GLM-Image")
@@ -1417,7 +1473,8 @@ class GlmImageForConditionalGeneration(GlmImagePreTrainedModel, GenerationMixin)
         },
         ]
         >>> url = "https://www.ilankelman.org/stopsigns/australia.jpg"
-        >>>
+        >>> with httpx.stream("GET", url) as response:
+        ...     image = Image.open(BytesIO(response.read()))
 
         >>> text = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
         >>> inputs = processor(text=[text], images=[image], vision_infos=[vision_infos])
@@ -1431,6 +1488,7 @@ class GlmImageForConditionalGeneration(GlmImagePreTrainedModel, GenerationMixin)
             input_ids=input_ids,
             pixel_values=pixel_values,
             image_grid_thw=image_grid_thw,
+            images_per_sample=images_per_sample,
             position_ids=position_ids,
             attention_mask=attention_mask,
             past_key_values=past_key_values,
@@ -1469,6 +1527,7 @@ class GlmImageForConditionalGeneration(GlmImagePreTrainedModel, GenerationMixin)
         use_cache=True,
         pixel_values=None,
         image_grid_thw=None,
+        images_per_sample=None,
         is_first_iteration=False,
         **kwargs,
     ):
@@ -1487,6 +1546,7 @@ class GlmImageForConditionalGeneration(GlmImagePreTrainedModel, GenerationMixin)
         )
 
         model_inputs["position_ids"] = None
+        model_inputs["images_per_sample"] = images_per_sample
 
         if not is_first_iteration and use_cache:
             model_inputs["pixel_values"] = None
@@ -1523,11 +1583,42 @@ class GlmImageForConditionalGeneration(GlmImagePreTrainedModel, GenerationMixin)
         if expand_size == 1:
             return input_ids, model_kwargs
 
-        visual_keys = ["pixel_values", "image_grid_thw"]
+        visual_keys = ["pixel_values", "image_grid_thw", "images_per_sample"]
 
         def _expand_dict_for_generation_visual(dict_to_expand):
             image_grid_thw = model_kwargs.get("image_grid_thw", None)
-
+            if image_grid_thw is None:
+                return dict_to_expand
+
+            images_per_sample = model_kwargs.get("images_per_sample", None)
+
+            # Use images_per_sample if available
+            if images_per_sample is not None:
+                image_nums = images_per_sample.tolist()
+            elif input_ids is not None:
+                # Try to infer from image_grid_thw / batch_size
+                batch_size = input_ids.shape[0]
+                total_grids = image_grid_thw.shape[0]
+                if total_grids % batch_size == 0:
+                    grids_per_sample = total_grids // batch_size
+                    image_nums = [grids_per_sample] * batch_size
+                else:
+                    # Cannot evenly distribute grids - fall back to simple repeat_interleave
+                    # This handles test cases where image_grid_thw has (batch_size + 1) rows
+                    dict_to_expand["image_grid_thw"] = image_grid_thw.repeat_interleave(expand_size, dim=0)
+                    if dict_to_expand.get("pixel_values") is not None:
+                        dict_to_expand["pixel_values"] = dict_to_expand["pixel_values"].repeat_interleave(
+                            expand_size, dim=0
+                        )
+                    return dict_to_expand
+            else:
+                image_nums = self._get_image_nums(input_ids).tolist()
+
+            # Get source image counts per sample from image_end_token_id count
+            source_image_nums = [
+                (input_ids[batch_idx] == self.config.image_end_token_id).sum().item()
+                for batch_idx in range(len(image_nums))
+            ]
 
             def _repeat_interleave_samples(x, lengths, repeat_times):
                 samples = torch.split(x, lengths)
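The expansion path splits packed tensors into per-sample chunks, repeats each chunk, and re-concatenates them, so that packed pixel values and grids are duplicated per sample rather than per row. A standalone sketch of that `_repeat_interleave_samples`-style logic; only the `torch.split(x, lengths)` line is taken from the context above, the rest of the helper body is an assumption for illustration:

```python
import torch


def repeat_interleave_samples(x: torch.Tensor, lengths: list[int], repeat_times: int) -> torch.Tensor:
    # Split the packed tensor into one chunk per sample, repeat each chunk,
    # then re-pack: [s1, s2] with repeat 2 becomes [s1, s1, s2, s2].
    samples = torch.split(x, lengths)
    return torch.cat([s.repeat(repeat_times, *[1] * (s.dim() - 1)) for s in samples], dim=0)


# Packed grids for two samples: sample 0 owns one grid, sample 1 owns two.
image_grid_thw = torch.tensor([[1, 4, 4], [1, 3, 2], [1, 6, 6]])
expanded = repeat_interleave_samples(image_grid_thw, lengths=[1, 2], repeat_times=2)
print(expanded.tolist())
# [[1, 4, 4], [1, 4, 4], [1, 3, 2], [1, 6, 6], [1, 3, 2], [1, 6, 6]]
```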
@@ -1537,21 +1628,31 @@ class GlmImageForConditionalGeneration(GlmImagePreTrainedModel, GenerationMixin)
 
             for key in dict_to_expand:
                 if key == "pixel_values":
-                    #
-
-
-
-
-
-
+                    # Split images into samples based on source image counts
+                    if sum(source_image_nums) > 0:
+                        # Split grids by sample to compute pixel counts
+                        grids_per_sample = torch.split(image_grid_thw, image_nums)
+                        lengths = []
+                        for batch_idx, sample_grids in enumerate(grids_per_sample):
+                            num_source = source_image_nums[batch_idx]
+                            if num_source > 0:
+                                source_grids = sample_grids[:num_source]
+                                lengths.append(torch.prod(source_grids, dim=1).sum().item())
+                            else:
+                                lengths.append(0)
+
+                        dict_to_expand[key] = _repeat_interleave_samples(
+                            dict_to_expand[key], lengths=lengths, repeat_times=expand_size
+                        )
                 elif key == "image_grid_thw":
-                    #
-                    lengths = list(image_nums)
-                    last_image = dict_to_expand[key][:-1]
+                    # Expand all grids (source + target) per sample
                     dict_to_expand[key] = _repeat_interleave_samples(
-                        dict_to_expand[key]
+                        dict_to_expand[key], lengths=image_nums, repeat_times=expand_size
                     )
-
+                elif key == "images_per_sample":
+                    # Simply repeat the counts
+                    if dict_to_expand.get(key) is not None:
+                        dict_to_expand[key] = dict_to_expand[key].repeat_interleave(expand_size, dim=0)
             return dict_to_expand
 
         def _expand_dict_for_generation(dict_to_expand):