crfm-helm 0.4.0__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of crfm-helm might be problematic. Click here for more details.
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.1.dist-info}/METADATA +138 -31
- crfm_helm-0.5.1.dist-info/RECORD +654 -0
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.1.dist-info}/WHEEL +1 -1
- helm/benchmark/adaptation/adapter_spec.py +31 -3
- helm/benchmark/adaptation/adapters/adapter.py +2 -2
- helm/benchmark/adaptation/adapters/adapter_factory.py +24 -27
- helm/benchmark/adaptation/adapters/generation_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/in_context_learning_adapter.py +20 -4
- helm/benchmark/adaptation/adapters/language_modeling_adapter.py +2 -3
- helm/benchmark/adaptation/adapters/multimodal/in_context_learning_multimodal_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/multimodal/multimodal_prompt.py +7 -0
- helm/benchmark/adaptation/adapters/multimodal/multiple_choice_joint_multimodal_adapter.py +104 -0
- helm/benchmark/adaptation/adapters/multimodal/test_in_context_learning_multimodal_adapter.py +2 -1
- helm/benchmark/adaptation/adapters/multimodal/test_multimodal_prompt.py +2 -0
- helm/benchmark/adaptation/adapters/test_adapter.py +2 -1
- helm/benchmark/adaptation/adapters/test_generation_adapter.py +32 -8
- helm/benchmark/adaptation/adapters/test_language_modeling_adapter.py +7 -19
- helm/benchmark/adaptation/adapters/test_multiple_choice_joint_adapter.py +60 -6
- helm/benchmark/adaptation/common_adapter_specs.py +376 -0
- helm/benchmark/adaptation/request_state.py +6 -1
- helm/benchmark/adaptation/scenario_state.py +6 -2
- helm/benchmark/annotation/annotator.py +43 -0
- helm/benchmark/annotation/annotator_factory.py +61 -0
- helm/benchmark/annotation/image2structure/image_compiler_annotator.py +88 -0
- helm/benchmark/annotation/image2structure/latex_compiler_annotator.py +59 -0
- helm/benchmark/annotation/image2structure/lilypond_compiler_annotator.py +84 -0
- helm/benchmark/annotation/image2structure/webpage_compiler_annotator.py +132 -0
- helm/benchmark/annotation/test_annotator_factory.py +26 -0
- helm/benchmark/annotation/test_dummy_annotator.py +44 -0
- helm/benchmark/annotation_executor.py +124 -0
- helm/benchmark/augmentations/data_augmenter.py +0 -2
- helm/benchmark/augmentations/gender_perturbation.py +1 -1
- helm/benchmark/augmentations/perturbation.py +25 -3
- helm/benchmark/augmentations/perturbation_description.py +1 -1
- helm/benchmark/augmentations/suffix_perturbation.py +29 -0
- helm/benchmark/augmentations/test_perturbation.py +41 -7
- helm/benchmark/augmentations/translate_perturbation.py +30 -0
- helm/benchmark/config_registry.py +7 -1
- helm/benchmark/executor.py +46 -16
- helm/benchmark/huggingface_registration.py +20 -7
- helm/benchmark/metrics/basic_metrics.py +169 -664
- helm/benchmark/metrics/bbq_metrics.py +3 -4
- helm/benchmark/metrics/bias_metrics.py +6 -6
- helm/benchmark/metrics/classification_metrics.py +11 -8
- helm/benchmark/metrics/cleva_accuracy_metrics.py +8 -5
- helm/benchmark/metrics/cleva_harms_metrics.py +2 -2
- helm/benchmark/metrics/code_metrics_helper.py +0 -2
- helm/benchmark/metrics/common_metric_specs.py +167 -0
- helm/benchmark/metrics/decodingtrust_fairness_metrics.py +72 -0
- helm/benchmark/metrics/decodingtrust_ood_knowledge_metrics.py +66 -0
- helm/benchmark/metrics/decodingtrust_privacy_metrics.py +101 -0
- helm/benchmark/metrics/decodingtrust_stereotype_bias_metrics.py +202 -0
- helm/benchmark/metrics/disinformation_metrics.py +4 -110
- helm/benchmark/metrics/dry_run_metrics.py +2 -2
- helm/benchmark/metrics/efficiency_metrics.py +213 -0
- helm/benchmark/metrics/evaluate_instances_metric.py +59 -0
- helm/benchmark/metrics/evaluate_reference_metrics.py +392 -0
- helm/benchmark/metrics/image_generation/aesthetics_metrics.py +54 -0
- helm/benchmark/metrics/image_generation/aesthetics_scorer.py +66 -0
- helm/benchmark/metrics/image_generation/clip_score_metrics.py +73 -0
- helm/benchmark/metrics/image_generation/denoised_runtime_metric.py +42 -0
- helm/benchmark/metrics/image_generation/detection_metrics.py +57 -0
- helm/benchmark/metrics/image_generation/detectors/base_detector.py +8 -0
- helm/benchmark/metrics/image_generation/detectors/vitdet.py +178 -0
- helm/benchmark/metrics/image_generation/efficiency_metrics.py +41 -0
- helm/benchmark/metrics/image_generation/fidelity_metrics.py +168 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/fractal_dimension_util.py +63 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/test_fractal_dimension_util.py +33 -0
- helm/benchmark/metrics/image_generation/fractal_dimension_metric.py +50 -0
- helm/benchmark/metrics/image_generation/gender_metrics.py +58 -0
- helm/benchmark/metrics/image_generation/image_critique_metrics.py +284 -0
- helm/benchmark/metrics/image_generation/lpips_metrics.py +82 -0
- helm/benchmark/metrics/image_generation/multi_scale_ssim_metrics.py +82 -0
- helm/benchmark/metrics/image_generation/nsfw_detector.py +96 -0
- helm/benchmark/metrics/image_generation/nsfw_metrics.py +103 -0
- helm/benchmark/metrics/image_generation/nudity_metrics.py +38 -0
- helm/benchmark/metrics/image_generation/photorealism_critique_metrics.py +153 -0
- helm/benchmark/metrics/image_generation/psnr_metrics.py +78 -0
- helm/benchmark/metrics/image_generation/q16/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/q16/q16_toxicity_detector.py +90 -0
- helm/benchmark/metrics/image_generation/q16/test_q16.py +18 -0
- helm/benchmark/metrics/image_generation/q16_toxicity_metrics.py +48 -0
- helm/benchmark/metrics/image_generation/skin_tone_metrics.py +164 -0
- helm/benchmark/metrics/image_generation/uiqi_metrics.py +92 -0
- helm/benchmark/metrics/image_generation/watermark/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/watermark/test_watermark_detector.py +16 -0
- helm/benchmark/metrics/image_generation/watermark/watermark_detector.py +87 -0
- helm/benchmark/metrics/image_generation/watermark_metrics.py +48 -0
- helm/benchmark/metrics/instruction_following_critique_metrics.py +3 -1
- helm/benchmark/metrics/language_modeling_metrics.py +99 -0
- helm/benchmark/metrics/machine_translation_metrics.py +89 -0
- helm/benchmark/metrics/metric.py +93 -172
- helm/benchmark/metrics/metric_name.py +0 -1
- helm/benchmark/metrics/metric_service.py +16 -0
- helm/benchmark/metrics/paraphrase_generation_metrics.py +3 -4
- helm/benchmark/metrics/ranking_metrics.py +2 -2
- helm/benchmark/metrics/reference_metric.py +148 -0
- helm/benchmark/metrics/summac/model_summac.py +0 -2
- helm/benchmark/metrics/summarization_metrics.py +2 -2
- helm/benchmark/metrics/test_classification_metrics.py +8 -5
- helm/benchmark/metrics/test_disinformation_metrics.py +78 -0
- helm/benchmark/metrics/{test_basic_metrics.py → test_evaluate_reference_metrics.py} +5 -1
- helm/benchmark/metrics/test_metric.py +2 -2
- helm/benchmark/metrics/tokens/gooseai_token_cost_estimator.py +10 -2
- helm/benchmark/metrics/toxicity_metrics.py +1 -1
- helm/benchmark/metrics/toxicity_utils.py +23 -0
- helm/benchmark/metrics/unitxt_metrics.py +81 -0
- helm/benchmark/metrics/vision_language/__init__.py +0 -0
- helm/benchmark/metrics/vision_language/emd_utils.py +341 -0
- helm/benchmark/metrics/vision_language/image_metrics.py +575 -0
- helm/benchmark/metrics/vision_language/image_utils.py +100 -0
- helm/benchmark/model_deployment_registry.py +74 -0
- helm/benchmark/model_metadata_registry.py +41 -1
- helm/benchmark/multi_gpu_runner.py +133 -0
- helm/benchmark/presentation/create_plots.py +8 -7
- helm/benchmark/presentation/run_display.py +26 -10
- helm/benchmark/presentation/schema.py +15 -40
- helm/benchmark/presentation/summarize.py +119 -79
- helm/benchmark/presentation/table.py +8 -8
- helm/benchmark/presentation/test_contamination.py +2 -2
- helm/benchmark/presentation/test_run_entry.py +1 -2
- helm/benchmark/presentation/test_summarize.py +3 -3
- helm/benchmark/run.py +54 -26
- helm/benchmark/run_expander.py +205 -35
- helm/benchmark/run_spec.py +93 -0
- helm/benchmark/run_spec_factory.py +163 -0
- helm/benchmark/run_specs/__init__.py +0 -0
- helm/benchmark/run_specs/classic_run_specs.py +1510 -0
- helm/benchmark/run_specs/cleva_run_specs.py +277 -0
- helm/benchmark/run_specs/decodingtrust_run_specs.py +314 -0
- helm/benchmark/run_specs/heim_run_specs.py +623 -0
- helm/benchmark/run_specs/instruction_following_run_specs.py +129 -0
- helm/benchmark/run_specs/lite_run_specs.py +307 -0
- helm/benchmark/run_specs/simple_run_specs.py +104 -0
- helm/benchmark/run_specs/unitxt_run_specs.py +42 -0
- helm/benchmark/run_specs/vlm_run_specs.py +757 -0
- helm/benchmark/runner.py +51 -57
- helm/benchmark/runner_config_registry.py +21 -0
- helm/benchmark/scenarios/bbq_scenario.py +1 -1
- helm/benchmark/scenarios/bold_scenario.py +2 -2
- helm/benchmark/scenarios/code_scenario.py +1 -0
- helm/benchmark/scenarios/decodingtrust_adv_demonstration_scenario.py +169 -0
- helm/benchmark/scenarios/decodingtrust_adv_robustness_scenario.py +121 -0
- helm/benchmark/scenarios/decodingtrust_fairness_scenario.py +77 -0
- helm/benchmark/scenarios/decodingtrust_machine_ethics_scenario.py +324 -0
- helm/benchmark/scenarios/decodingtrust_ood_robustness_scenario.py +204 -0
- helm/benchmark/scenarios/decodingtrust_privacy_scenario.py +559 -0
- helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +67 -0
- helm/benchmark/scenarios/decodingtrust_toxicity_prompts_scenario.py +78 -0
- helm/benchmark/scenarios/dialogue_scenarios.py +0 -1
- helm/benchmark/scenarios/image_generation/__init__.py +0 -0
- helm/benchmark/scenarios/image_generation/common_syntactic_processes_scenario.py +105 -0
- helm/benchmark/scenarios/image_generation/cub200_scenario.py +95 -0
- helm/benchmark/scenarios/image_generation/daily_dalle_scenario.py +124 -0
- helm/benchmark/scenarios/image_generation/demographic_stereotypes_scenario.py +82 -0
- helm/benchmark/scenarios/image_generation/detection_scenario.py +83 -0
- helm/benchmark/scenarios/image_generation/draw_bench_scenario.py +74 -0
- helm/benchmark/scenarios/image_generation/i2p_scenario.py +57 -0
- helm/benchmark/scenarios/image_generation/landing_page_scenario.py +46 -0
- helm/benchmark/scenarios/image_generation/logos_scenario.py +223 -0
- helm/benchmark/scenarios/image_generation/magazine_cover_scenario.py +91 -0
- helm/benchmark/scenarios/image_generation/mental_disorders_scenario.py +46 -0
- helm/benchmark/scenarios/image_generation/mscoco_scenario.py +91 -0
- helm/benchmark/scenarios/image_generation/paint_skills_scenario.py +72 -0
- helm/benchmark/scenarios/image_generation/parti_prompts_scenario.py +94 -0
- helm/benchmark/scenarios/image_generation/radiology_scenario.py +42 -0
- helm/benchmark/scenarios/image_generation/relational_understanding_scenario.py +52 -0
- helm/benchmark/scenarios/image_generation/time_most_significant_historical_figures_scenario.py +124 -0
- helm/benchmark/scenarios/image_generation/winoground_scenario.py +62 -0
- helm/benchmark/scenarios/imdb_scenario.py +0 -1
- helm/benchmark/scenarios/legalbench_scenario.py +6 -2
- helm/benchmark/scenarios/live_qa_scenario.py +94 -0
- helm/benchmark/scenarios/lm_entry_scenario.py +185 -0
- helm/benchmark/scenarios/math_scenario.py +19 -2
- helm/benchmark/scenarios/medication_qa_scenario.py +60 -0
- helm/benchmark/scenarios/numeracy_scenario.py +1 -1
- helm/benchmark/scenarios/opinions_qa_scenario.py +0 -4
- helm/benchmark/scenarios/scenario.py +4 -0
- helm/benchmark/scenarios/simple_scenarios.py +122 -1
- helm/benchmark/scenarios/test_math_scenario.py +6 -0
- helm/benchmark/scenarios/test_scenario.py +6 -3
- helm/benchmark/scenarios/test_simple_scenarios.py +50 -0
- helm/benchmark/scenarios/thai_exam_scenario.py +135 -0
- helm/benchmark/scenarios/unitxt_scenario.py +56 -0
- helm/benchmark/scenarios/verifiability_judgment_scenario.py +3 -1
- helm/benchmark/scenarios/vicuna_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/a_okvqa_scenario.py +83 -0
- helm/benchmark/scenarios/vision_language/bingo_scenario.py +103 -0
- helm/benchmark/scenarios/vision_language/crossmodal_3600_scenario.py +134 -0
- helm/benchmark/scenarios/vision_language/flickr30k_scenario.py +74 -0
- helm/benchmark/scenarios/vision_language/gqa_scenario.py +91 -0
- helm/benchmark/scenarios/vision_language/hateful_memes_scenario.py +94 -0
- helm/benchmark/scenarios/vision_language/heim_human_eval_scenario.py +113 -0
- helm/benchmark/scenarios/vision_language/image2structure/__init__.py +0 -0
- helm/benchmark/scenarios/vision_language/image2structure/chart2csv_scenario.py +55 -0
- helm/benchmark/scenarios/vision_language/image2structure/image2structure_scenario.py +214 -0
- helm/benchmark/scenarios/vision_language/image2structure/latex_scenario.py +25 -0
- helm/benchmark/scenarios/vision_language/image2structure/musicsheet_scenario.py +20 -0
- helm/benchmark/scenarios/vision_language/image2structure/utils_latex.py +347 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/__init__.py +0 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/driver.py +84 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/jekyll_server.py +182 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/utils.py +31 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage_scenario.py +225 -0
- helm/benchmark/scenarios/vision_language/math_vista_scenario.py +117 -0
- helm/benchmark/scenarios/vision_language/mementos_scenario.py +124 -0
- helm/benchmark/scenarios/vision_language/mm_safety_bench_scenario.py +103 -0
- helm/benchmark/scenarios/vision_language/mme_scenario.py +145 -0
- helm/benchmark/scenarios/vision_language/mmmu_scenario.py +187 -0
- helm/benchmark/scenarios/vision_language/mscoco_captioning_scenario.py +92 -0
- helm/benchmark/scenarios/vision_language/mscoco_categorization_scenario.py +117 -0
- helm/benchmark/scenarios/vision_language/multipanelvqa_scenario.py +169 -0
- helm/benchmark/scenarios/vision_language/originality_scenario.py +35 -0
- helm/benchmark/scenarios/vision_language/pairs_scenario.py +246 -0
- helm/benchmark/scenarios/vision_language/pope_scenario.py +104 -0
- helm/benchmark/scenarios/vision_language/seed_bench_scenario.py +129 -0
- helm/benchmark/scenarios/vision_language/unicorn_scenario.py +108 -0
- helm/benchmark/scenarios/vision_language/viz_wiz_scenario.py +3 -4
- helm/benchmark/scenarios/vision_language/vqa_scenario.py +5 -3
- helm/benchmark/scenarios/wmt_14_scenario.py +1 -1
- helm/benchmark/server.py +24 -1
- helm/benchmark/slurm_runner.py +70 -49
- helm/benchmark/static/benchmarking.js +1 -1
- helm/benchmark/static/schema_classic.yaml +258 -1066
- helm/benchmark/static/schema_image2structure.yaml +304 -0
- helm/benchmark/static/schema_instruction_following.yaml +210 -0
- helm/benchmark/static/schema_lite.yaml +2 -227
- helm/benchmark/static/schema_mmlu.yaml +1507 -0
- helm/benchmark/static/schema_unitxt.yaml +428 -0
- helm/benchmark/static/schema_vhelm_lite.yaml +164 -0
- helm/benchmark/static/schema_vlm.yaml +823 -0
- helm/benchmark/static_build/assets/01-694cb9b7.png +0 -0
- helm/benchmark/static_build/assets/ai21-0eb91ec3.png +0 -0
- helm/benchmark/static_build/assets/aleph-alpha-7ce10034.png +0 -0
- helm/benchmark/static_build/assets/anthropic-70d8bc39.png +0 -0
- helm/benchmark/static_build/assets/bigscience-7f0400c0.png +0 -0
- helm/benchmark/static_build/assets/cohere-3550c6cb.png +0 -0
- helm/benchmark/static_build/assets/crfm-logo-74391ab8.png +0 -0
- helm/benchmark/static_build/assets/eleutherai-b9451114.png +0 -0
- helm/benchmark/static_build/assets/google-06d997ad.png +0 -0
- helm/benchmark/static_build/assets/heim-logo-3e5e3aa4.png +0 -0
- helm/benchmark/static_build/assets/helm-logo-simple-2ed5400b.png +0 -0
- helm/benchmark/static_build/assets/helmhero-28e90f4d.png +0 -0
- helm/benchmark/static_build/assets/index-737eef9e.js +10 -0
- helm/benchmark/static_build/assets/index-878a1094.css +1 -0
- helm/benchmark/static_build/assets/meta-5580e9f1.png +0 -0
- helm/benchmark/static_build/assets/microsoft-f5ee5016.png +0 -0
- helm/benchmark/static_build/assets/mistral-18e1be23.png +0 -0
- helm/benchmark/static_build/assets/nvidia-86fa75c1.png +0 -0
- helm/benchmark/static_build/assets/openai-3f8653e4.png +0 -0
- helm/benchmark/static_build/assets/react-d4a0b69b.js +85 -0
- helm/benchmark/static_build/assets/recharts-6d337683.js +97 -0
- helm/benchmark/static_build/assets/tii-24de195c.png +0 -0
- helm/benchmark/static_build/assets/together-a665a35b.png +0 -0
- helm/benchmark/static_build/assets/tremor-54a99cc4.js +10 -0
- helm/benchmark/static_build/assets/tsinghua-keg-97d4b395.png +0 -0
- helm/benchmark/static_build/assets/vhelm-framework-cde7618a.png +0 -0
- helm/benchmark/static_build/assets/vhelm-model-6d812526.png +0 -0
- helm/benchmark/static_build/assets/yandex-38e09d70.png +0 -0
- helm/benchmark/static_build/config.js +4 -0
- helm/benchmark/static_build/index.html +20 -0
- helm/benchmark/test_data_preprocessor.py +3 -3
- helm/benchmark/test_run_expander.py +1 -1
- helm/benchmark/window_services/ai21_window_service.py +22 -33
- helm/benchmark/window_services/cohere_window_service.py +1 -63
- helm/benchmark/window_services/default_window_service.py +2 -44
- helm/benchmark/window_services/encoder_decoder_window_service.py +0 -11
- helm/benchmark/window_services/ice_window_service.py +0 -34
- helm/benchmark/window_services/image_generation/__init__.py +0 -0
- helm/benchmark/window_services/image_generation/clip_window_service.py +15 -0
- helm/benchmark/window_services/image_generation/lexica_search_window_service.py +9 -0
- helm/benchmark/window_services/image_generation/openai_dalle_window_service.py +9 -0
- helm/benchmark/window_services/image_generation/test_clip_window_service.py +29 -0
- helm/benchmark/window_services/image_generation/test_openai_dalle_window_service.py +30 -0
- helm/benchmark/window_services/local_window_service.py +21 -4
- helm/benchmark/window_services/test_anthropic_window_service.py +2 -1
- helm/benchmark/window_services/test_bloom_window_service.py +2 -1
- helm/benchmark/window_services/test_cohere_window_service.py +2 -1
- helm/benchmark/window_services/test_flan_t5_window_service.py +2 -1
- helm/benchmark/window_services/test_gpt2_window_service.py +2 -2
- helm/benchmark/window_services/test_gpt4_window_service.py +2 -1
- helm/benchmark/window_services/test_gptj_window_service.py +3 -2
- helm/benchmark/window_services/test_gptneox_window_service.py +3 -2
- helm/benchmark/window_services/test_ice_window_service.py +2 -1
- helm/benchmark/window_services/test_openai_window_service.py +2 -1
- helm/benchmark/window_services/test_opt_window_service.py +3 -2
- helm/benchmark/window_services/test_palmyra_window_service.py +2 -1
- helm/benchmark/window_services/test_t0pp_window_service.py +2 -1
- helm/benchmark/window_services/test_t511b_window_service.py +2 -1
- helm/benchmark/window_services/test_ul2_window_service.py +2 -1
- helm/benchmark/window_services/test_utils.py +3 -2
- helm/benchmark/window_services/test_yalm_window_service.py +2 -1
- helm/benchmark/window_services/window_service.py +42 -0
- helm/benchmark/window_services/window_service_factory.py +4 -1
- helm/benchmark/window_services/yalm_window_service.py +0 -27
- helm/clients/__init__.py +0 -0
- helm/{proxy/clients → clients}/ai21_client.py +3 -9
- helm/clients/aleph_alpha_client.py +112 -0
- helm/{proxy/clients → clients}/anthropic_client.py +233 -18
- helm/{proxy/clients → clients}/auto_client.py +59 -31
- helm/clients/bedrock_client.py +128 -0
- helm/clients/bedrock_utils.py +72 -0
- helm/{proxy/clients → clients}/client.py +65 -7
- helm/clients/clip_score_client.py +49 -0
- helm/clients/clip_scorers/__init__.py +0 -0
- helm/clients/clip_scorers/base_clip_scorer.py +18 -0
- helm/clients/clip_scorers/clip_scorer.py +50 -0
- helm/clients/clip_scorers/multilingual_clip_scorer.py +50 -0
- helm/{proxy/clients → clients}/cohere_client.py +4 -11
- helm/clients/gcs_client.py +82 -0
- helm/{proxy/clients → clients}/google_client.py +5 -5
- helm/clients/google_translate_client.py +35 -0
- helm/{proxy/clients → clients}/http_model_client.py +5 -7
- helm/{proxy/clients → clients}/huggingface_client.py +43 -64
- helm/clients/image_generation/__init__.py +0 -0
- helm/clients/image_generation/adobe_vision_client.py +78 -0
- helm/clients/image_generation/aleph_alpha_image_generation_client.py +98 -0
- helm/clients/image_generation/cogview2/__init__.py +0 -0
- helm/clients/image_generation/cogview2/coglm_strategy.py +96 -0
- helm/clients/image_generation/cogview2/coglm_utils.py +82 -0
- helm/clients/image_generation/cogview2/sr_pipeline/__init__.py +15 -0
- helm/clients/image_generation/cogview2/sr_pipeline/direct_sr.py +96 -0
- helm/clients/image_generation/cogview2/sr_pipeline/dsr_model.py +254 -0
- helm/clients/image_generation/cogview2/sr_pipeline/dsr_sampling.py +190 -0
- helm/clients/image_generation/cogview2/sr_pipeline/iterative_sr.py +141 -0
- helm/clients/image_generation/cogview2/sr_pipeline/itersr_model.py +269 -0
- helm/clients/image_generation/cogview2/sr_pipeline/itersr_sampling.py +120 -0
- helm/clients/image_generation/cogview2/sr_pipeline/sr_group.py +42 -0
- helm/clients/image_generation/cogview2_client.py +191 -0
- helm/clients/image_generation/dalle2_client.py +192 -0
- helm/clients/image_generation/dalle3_client.py +108 -0
- helm/clients/image_generation/dalle_mini/__init__.py +3 -0
- helm/clients/image_generation/dalle_mini/data.py +442 -0
- helm/clients/image_generation/dalle_mini/model/__init__.py +5 -0
- helm/clients/image_generation/dalle_mini/model/configuration.py +175 -0
- helm/clients/image_generation/dalle_mini/model/modeling.py +1834 -0
- helm/clients/image_generation/dalle_mini/model/partitions.py +84 -0
- helm/clients/image_generation/dalle_mini/model/processor.py +63 -0
- helm/clients/image_generation/dalle_mini/model/text.py +251 -0
- helm/clients/image_generation/dalle_mini/model/tokenizer.py +9 -0
- helm/clients/image_generation/dalle_mini/model/utils.py +29 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/__init__.py +1 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/configuration_vqgan.py +40 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/convert_pt_model_to_jax.py +107 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/modeling_flax_vqgan.py +610 -0
- helm/clients/image_generation/dalle_mini_client.py +190 -0
- helm/clients/image_generation/deep_floyd_client.py +78 -0
- helm/clients/image_generation/huggingface_diffusers_client.py +249 -0
- helm/clients/image_generation/image_generation_client_utils.py +9 -0
- helm/clients/image_generation/lexica_client.py +86 -0
- helm/clients/image_generation/mindalle/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/__init__.py +216 -0
- helm/clients/image_generation/mindalle/models/stage1/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/stage1/layers.py +312 -0
- helm/clients/image_generation/mindalle/models/stage1/vqgan.py +103 -0
- helm/clients/image_generation/mindalle/models/stage2/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/stage2/layers.py +144 -0
- helm/clients/image_generation/mindalle/models/stage2/transformer.py +268 -0
- helm/clients/image_generation/mindalle/models/tokenizer.py +30 -0
- helm/clients/image_generation/mindalle/utils/__init__.py +3 -0
- helm/clients/image_generation/mindalle/utils/config.py +129 -0
- helm/clients/image_generation/mindalle/utils/sampling.py +149 -0
- helm/clients/image_generation/mindalle/utils/utils.py +89 -0
- helm/clients/image_generation/mindalle_client.py +115 -0
- helm/clients/image_generation/nudity_check_client.py +64 -0
- helm/clients/image_generation/together_image_generation_client.py +111 -0
- helm/{proxy/clients → clients}/lit_gpt_client.py +4 -4
- helm/{proxy/clients → clients}/megatron_client.py +5 -5
- helm/clients/mistral_client.py +134 -0
- helm/clients/moderation_api_client.py +109 -0
- helm/clients/open_lm_client.py +43 -0
- helm/clients/openai_client.py +301 -0
- helm/{proxy/clients → clients}/palmyra_client.py +6 -8
- helm/{proxy/clients → clients}/perspective_api_client.py +7 -8
- helm/clients/simple_client.py +64 -0
- helm/{proxy/clients → clients}/test_auto_client.py +13 -15
- helm/clients/test_client.py +100 -0
- helm/{proxy/clients → clients}/test_huggingface_client.py +15 -16
- helm/clients/test_simple_client.py +19 -0
- helm/{proxy/clients → clients}/test_together_client.py +20 -8
- helm/{proxy/clients → clients}/together_client.py +104 -73
- helm/clients/vertexai_client.py +400 -0
- helm/clients/vision_language/__init__.py +0 -0
- helm/clients/vision_language/huggingface_vision2seq_client.py +145 -0
- helm/clients/vision_language/huggingface_vlm_client.py +111 -0
- helm/{proxy/clients → clients}/vision_language/idefics_client.py +54 -49
- helm/clients/vision_language/open_flamingo/__init__.py +2 -0
- helm/clients/vision_language/open_flamingo/src/__init__.py +0 -0
- helm/clients/vision_language/open_flamingo/src/factory.py +147 -0
- helm/clients/vision_language/open_flamingo/src/flamingo.py +337 -0
- helm/clients/vision_language/open_flamingo/src/flamingo_lm.py +155 -0
- helm/clients/vision_language/open_flamingo/src/helpers.py +267 -0
- helm/clients/vision_language/open_flamingo/src/utils.py +47 -0
- helm/clients/vision_language/open_flamingo_client.py +155 -0
- helm/clients/vision_language/qwen_vlm_client.py +171 -0
- helm/clients/vllm_client.py +46 -0
- helm/common/cache.py +16 -4
- helm/common/cache_backend_config.py +47 -0
- helm/common/clip_score_request.py +41 -0
- helm/common/file_caches/__init__.py +0 -0
- helm/common/file_caches/file_cache.py +16 -0
- helm/common/file_caches/local_file_cache.py +61 -0
- helm/common/file_caches/test_local_file_cache.py +25 -0
- helm/common/file_upload_request.py +27 -0
- helm/common/general.py +1 -1
- helm/common/image_generation_parameters.py +25 -0
- helm/common/images_utils.py +33 -3
- helm/common/key_value_store.py +35 -4
- helm/common/media_object.py +13 -0
- helm/common/moderations_api_request.py +71 -0
- helm/common/mongo_key_value_store.py +3 -3
- helm/common/multimodal_request_utils.py +31 -0
- helm/common/nudity_check_request.py +29 -0
- helm/common/request.py +15 -17
- helm/common/test_general.py +6 -0
- helm/common/tokenization_request.py +1 -1
- helm/config/model_deployments.yaml +1159 -538
- helm/config/model_metadata.yaml +868 -41
- helm/config/tokenizer_configs.yaml +149 -43
- helm/proxy/accounts.py +31 -4
- helm/proxy/critique/mechanical_turk_critique_importer.py +3 -0
- helm/proxy/critique/model_critique_client.py +8 -6
- helm/proxy/example_queries.py +29 -17
- helm/proxy/server.py +70 -5
- helm/proxy/services/remote_service.py +31 -0
- helm/proxy/services/server_service.py +96 -16
- helm/proxy/services/service.py +30 -0
- helm/proxy/services/test_remote_service.py +4 -3
- helm/proxy/services/test_service.py +0 -12
- helm/proxy/test_accounts.py +32 -0
- helm/proxy/token_counters/auto_token_counter.py +37 -37
- helm/proxy/token_counters/test_auto_token_counter.py +164 -0
- helm/proxy/token_counters/token_counter.py +3 -5
- helm/tokenizers/__init__.py +0 -0
- helm/{proxy/tokenizers → tokenizers}/ai21_tokenizer.py +3 -3
- helm/{proxy/tokenizers → tokenizers}/anthropic_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/auto_tokenizer.py +6 -9
- helm/{proxy/tokenizers → tokenizers}/cohere_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/http_model_tokenizer.py +3 -3
- helm/{proxy/tokenizers → tokenizers}/huggingface_tokenizer.py +7 -26
- helm/tokenizers/simple_tokenizer.py +33 -0
- helm/{proxy/tokenizers → tokenizers}/test_anthropic_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/test_huggingface_tokenizer.py +3 -0
- helm/tokenizers/test_simple_tokenizer.py +33 -0
- helm/{proxy/tokenizers → tokenizers}/vertexai_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer.py +5 -3
- helm/tokenizers/yalm_tokenizer_data/__init__.py +0 -0
- helm/tokenizers/yalm_tokenizer_data/voc_100b.sp +0 -0
- helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/yalm_tokenizer.py +1 -1
- crfm_helm-0.4.0.dist-info/RECORD +0 -397
- helm/benchmark/run_specs.py +0 -2762
- helm/benchmark/test_model_deployment_definition.py +0 -92
- helm/benchmark/test_model_properties.py +0 -1570
- helm/benchmark/vlm_run_specs.py +0 -97
- helm/benchmark/window_services/flan_t5_window_service.py +0 -29
- helm/benchmark/window_services/gpt2_window_service.py +0 -32
- helm/benchmark/window_services/huggingface_window_service.py +0 -60
- helm/benchmark/window_services/t0pp_window_service.py +0 -35
- helm/benchmark/window_services/t511b_window_service.py +0 -30
- helm/benchmark/window_services/test_mt_nlg_window_service.py +0 -48
- helm/benchmark/window_services/ul2_window_service.py +0 -30
- helm/benchmark/window_services/wider_ai21_window_service.py +0 -24
- helm/common/cache_utils.py +0 -14
- helm/proxy/clients/aleph_alpha_client.py +0 -95
- helm/proxy/clients/goose_ai_client.py +0 -99
- helm/proxy/clients/microsoft_client.py +0 -180
- helm/proxy/clients/openai_client.py +0 -206
- helm/proxy/clients/simple_client.py +0 -60
- helm/proxy/clients/test_client.py +0 -49
- helm/proxy/clients/vertexai_client.py +0 -115
- helm/proxy/token_counters/ai21_token_counter.py +0 -20
- helm/proxy/token_counters/cohere_token_counter.py +0 -13
- helm/proxy/token_counters/free_token_counter.py +0 -12
- helm/proxy/token_counters/gooseai_token_counter.py +0 -24
- helm/proxy/token_counters/openai_token_counter.py +0 -22
- helm/proxy/token_counters/test_ai21_token_counter.py +0 -88
- helm/proxy/token_counters/test_openai_token_counter.py +0 -81
- helm/proxy/tokenizers/simple_tokenizer.py +0 -32
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.1.dist-info}/LICENSE +0 -0
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.1.dist-info}/entry_points.txt +0 -0
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.1.dist-info}/top_level.txt +0 -0
- /helm/{proxy/clients → benchmark/annotation}/__init__.py +0 -0
- /helm/{proxy/clients/vision_language → benchmark/annotation/image2structure}/__init__.py +0 -0
- /helm/{proxy/tokenizers → benchmark/metrics/image_generation}/__init__.py +0 -0
- /helm/{proxy/tokenizers/yalm_tokenizer_data → benchmark/metrics/image_generation/detectors}/__init__.py +0 -0
- /helm/{proxy/clients → clients}/ai21_utils.py +0 -0
- /helm/{proxy/clients → clients}/cohere_utils.py +0 -0
- /helm/{proxy/clients → clients}/lit_gpt_generate.py +0 -0
- /helm/{proxy/clients → clients}/toxicity_classifier_client.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/aleph_alpha_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/caching_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/ice_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/lit_gpt_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/test_ice_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/test_yalm_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/tiktoken_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/test_yalm_tokenizer.py +0 -0
|
@@ -0,0 +1,1834 @@
|
|
|
1
|
+
# coding=utf-8
|
|
2
|
+
# Copyright 2021-2022 The Fairseq Authors and The Google Flax Team Authors And The HuggingFace Inc. team and & DALL·E Mini team. All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
5
|
+
# you may not use this file except in compliance with the License.
|
|
6
|
+
# You may obtain a copy of the License at
|
|
7
|
+
#
|
|
8
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
9
|
+
#
|
|
10
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
11
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
12
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
13
|
+
# See the License for the specific language governing permissions and
|
|
14
|
+
# limitations under the License.
|
|
15
|
+
""" DalleBart model. """
|
|
16
|
+
|
|
17
|
+
import math
|
|
18
|
+
from functools import partial
|
|
19
|
+
from typing import Any, Dict, Optional, Tuple
|
|
20
|
+
|
|
21
|
+
from transformers.modeling_flax_outputs import (
|
|
22
|
+
FlaxBaseModelOutput,
|
|
23
|
+
FlaxBaseModelOutputWithPastAndCrossAttentions,
|
|
24
|
+
FlaxCausalLMOutputWithCrossAttentions,
|
|
25
|
+
FlaxSeq2SeqLMOutput,
|
|
26
|
+
)
|
|
27
|
+
from transformers.modeling_flax_utils import ACT2FN
|
|
28
|
+
from transformers.models.bart.modeling_flax_bart import (
|
|
29
|
+
FlaxBartAttention,
|
|
30
|
+
FlaxBartForConditionalGeneration,
|
|
31
|
+
FlaxBartForConditionalGenerationModule,
|
|
32
|
+
FlaxBartModule,
|
|
33
|
+
)
|
|
34
|
+
from transformers.utils import ModelOutput, logging
|
|
35
|
+
from transformers.generation.configuration_utils import GenerationConfig
|
|
36
|
+
|
|
37
|
+
from helm.common.optional_dependencies import handle_module_not_found_error
|
|
38
|
+
from .configuration import DalleBartConfig
|
|
39
|
+
from .utils import PretrainedFromWandbMixin
|
|
40
|
+
|
|
41
|
+
try:
|
|
42
|
+
import flax
|
|
43
|
+
import flax.linen as nn
|
|
44
|
+
import jax
|
|
45
|
+
import jax.numpy as jnp
|
|
46
|
+
from einops import rearrange
|
|
47
|
+
from flax.core.frozen_dict import unfreeze
|
|
48
|
+
from flax.linen import combine_masks, make_causal_mask
|
|
49
|
+
from flax.linen import partitioning as nn_partitioning
|
|
50
|
+
from flax.linen.linear import PrecisionLike
|
|
51
|
+
from flax.traverse_util import flatten_dict, unflatten_dict
|
|
52
|
+
from jax import custom_jvp, lax
|
|
53
|
+
from jax.random import PRNGKey
|
|
54
|
+
except ModuleNotFoundError as e:
|
|
55
|
+
handle_module_not_found_error(e, ["heim"])
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
logger = logging.get_logger(__name__)

# Flax gradient-checkpointing transform; the encoder layer collection below wraps
# FlaxBartEncoderLayer with it when config.gradient_checkpointing is set.
remat = nn_partitioning.remat
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def smelu(beta: Any = 1.0):
    """
    Implementation of "Real World Large Scale Recommendation Systems Reproducibility and Smooth Activations"
    https://arxiv.org/abs/2202.06499

    Returns an elementwise Smooth ReLU (SmeLU) with half-width ``beta`` (> 0):

        smelu(x) = 0                          for x <= -beta
                 = (x + beta)**2 / (4 * beta) for -beta < x < beta
                 = x                          for x >= beta

    Its derivative is clip((x + beta) / (2 * beta), 0, 1), which the custom JVP
    registered below implements.
    """

    @custom_jvp
    @jax.jit
    def _smelu(x: Any) -> Any:
        # Clamp the left tail to -beta (not 0): the quadratic branch is exactly 0
        # at x == -beta, so every x <= -beta maps to 0. Clamping to 0 would make
        # the quadratic branch return beta / 4 on the whole left tail.
        x = jnp.where(x <= -beta, -beta, x)
        return jnp.where(x >= beta, x, jnp.square(x + beta) / (4 * beta))

    # Tangent rule: g * smelu'(x). The slope ramps linearly from 0 at -beta to 1
    # at +beta and is constant outside that interval, i.e. a clipped ramp.
    _smelu.defjvps(lambda g, ans, x: g * jnp.clip((x + beta) / (2 * beta), 0.0, 1.0))
    return _smelu
|
|
83
|
+
|
|
84
|
+
|
|
85
|
+
# Register SmeLU so configs can select it with activation_function="smelu"
# (GLU/FFN look activations up as ACT2FN[config.activation_function]).
ACT2FN["smelu"] = smelu()
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
# deepnet initialization
|
|
89
|
+
def deepnet_init(init_std, gain=1):
    """Build a normal(``init_std``) kernel initializer whose samples are multiplied by ``gain``.

    The returned callable has the same signature as the underlying jax
    initializer (``key, shape, dtype=...``).
    """
    base_initializer = jax.nn.initializers.normal(init_std)

    def _scaled_initializer(*args, **kwargs):
        sample = base_initializer(*args, **kwargs)
        return gain * sample

    return _scaled_initializer
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
# deepnet gain
|
|
99
|
+
def _build_deepnet_gain():
    """Return the DeepNet (DeepNorm) constants for an encoder-decoder stack.

    ``alpha`` scales the residual branch and ``beta`` scales the initializer
    of the affected projections. Each entry is a function of the model config,
    which must provide ``encoder_layers`` and ``decoder_layers``.
    """

    def _encoder_depth(config):
        # Shared term N_enc**4 * N_dec of both encoder-side constants.
        return config.encoder_layers**4 * config.decoder_layers

    def _encoder_alpha(config):
        return 0.81 * _encoder_depth(config) ** 0.0625

    def _encoder_beta(config):
        return 0.87 * _encoder_depth(config) ** -0.0625

    def _decoder_alpha(config):
        return (3 * config.decoder_layers) ** 0.25

    def _decoder_beta(config):
        return (12 * config.decoder_layers) ** -0.25

    return {
        "encoder": {"alpha": _encoder_alpha, "beta": _encoder_beta},
        "decoder": {"alpha": _decoder_alpha, "beta": _decoder_beta},
    }


deepnet_gain = _build_deepnet_gain()
|
|
109
|
+
|
|
110
|
+
# subln gain
|
|
111
|
+
def _build_subln_gain():
    """Return the Sub-LN initialization gains, keyed by stack side.

    Each entry maps a config (with ``encoder_layers`` / ``decoder_layers``) to
    the multiplier used by ``deepnet_init`` when ``config.use_subln_init`` is
    set. The formulas appear to follow the Sub-LN initialization of
    "Foundation Transformers" (Magneto) — confirm against the paper if changed.
    """

    def _encoder(config):
        log_decoder = math.log(3 * config.decoder_layers)
        log_encoder = math.log(2 * config.encoder_layers)
        return math.sqrt(1.0 / 3.0 * log_decoder * log_encoder)

    def _decoder(config):
        return math.sqrt(math.log(3 * config.decoder_layers))

    return {"encoder": _encoder, "decoder": _decoder}


subln_gain = _build_subln_gain()
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
class RMSNorm(nn.Module):
    """
    Root-mean-square layer normalization, from "Root Mean Square Layer
    Normalization" (https://arxiv.org/abs/1910.07467).

    Adapted from flax.linen.LayerNorm. No mean is subtracted and no bias is
    learned: the input is multiplied by rsqrt(mean(x**2) + epsilon) over the
    last axis and, when ``use_scale`` is set, by a learned per-feature
    ``scale`` parameter.
    """

    epsilon: float = 1e-6
    dtype: Any = jnp.float32
    param_dtype: Any = jnp.float32
    use_scale: bool = True
    scale_init: Any = jax.nn.initializers.ones

    @nn.compact
    def __call__(self, x):
        # Statistics and the learned scale both live on the last axis only.
        last_axis = (-1,)
        mean_square = self._compute_rms_sq(x, last_axis)
        return self._normalize(
            self,
            x,
            mean_square,
            last_axis,
            last_axis,
            self.dtype,
            self.param_dtype,
            self.epsilon,
            self.use_scale,
            self.scale_init,
        )

    def _compute_rms_sq(self, x, axes):
        """Mean of squares over ``axes``, accumulated in at least float32."""
        accumulation_dtype = jnp.promote_types(jnp.float32, jnp.result_type(x))
        widened = jnp.asarray(x, accumulation_dtype)
        return jnp.mean(jax.lax.square(widened), axes)

    def _normalize(
        self,
        mdl,
        x,
        rms_sq,
        reduction_axes,
        feature_axes,
        dtype,
        param_dtype,
        epsilon,
        use_scale,
        scale_init,
    ):
        """Scale ``x`` by rsqrt(rms_sq + epsilon) and an optional learned scale owned by ``mdl``."""
        reduction_axes = nn.normalization._canonicalize_axes(x.ndim, reduction_axes)
        feature_axes = nn.normalization._canonicalize_axes(x.ndim, feature_axes)

        # Re-insert the reduced axes as size-1 dims so the statistic broadcasts against x.
        stats_shape = [1 if axis in reduction_axes else size for axis, size in enumerate(x.shape)]
        rms_sq = rms_sq.reshape(stats_shape)

        # Broadcastable layout of the scale, and its compact parameter shape.
        feature_shape = [size if axis in feature_axes else 1 for axis, size in enumerate(x.shape)]
        reduced_feature_shape = [x.shape[axis] for axis in feature_axes]

        mul = lax.rsqrt(rms_sq + epsilon)
        if use_scale:
            learned_scale = mdl.param("scale", scale_init, reduced_feature_shape, param_dtype)
            mul = mul * learned_scale.reshape(feature_shape)
        return jnp.asarray(mul * x, dtype)
|
|
187
|
+
|
|
188
|
+
|
|
189
|
+
def norm(type, *args, **kwargs):
    """Instantiate a normalization module by name.

    ``type`` is ``"rmsnorm"`` (RMSNorm above) or ``"layernorm"``
    (flax ``nn.LayerNorm``); remaining arguments go to the module constructor.
    Any other name raises ``ValueError``.
    """
    if type not in ("rmsnorm", "layernorm"):
        raise ValueError(f"Unknown norm type {type}")
    norm_cls = RMSNorm if type == "rmsnorm" else nn.LayerNorm
    return norm_cls(*args, **kwargs)
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
def dot_product_attention_weights(
    query: Any,
    key: Any,
    bias: Optional[Any] = None,
    mask: Optional[Any] = None,
    embed_pos: Optional[Any] = None,
    broadcast_dropout: bool = True,
    dropout_rng: Optional[PRNGKey] = None,
    dropout_rate: float = 0.0,
    deterministic: bool = False,
    dtype: Any = jnp.float32,
    precision: PrecisionLike = None,
    sinkhorn_iters: int = 1,
    is_encoder: bool = False,
    tau=None,
):
    """
    Compute normalized dot-product attention weights from ``query`` and ``key``.

    Adapted from flax.linen.attention.dot_product_attention_weights. Masking is
    expected to arrive through ``bias`` (e.g. -inf on disallowed positions);
    ``mask`` is only consulted by the Sinkhorn branch.

    Args:
        query, key: arrays laid out as (batch..., length, num_heads, head_dim);
            batch dims, head count and head_dim must agree.
        bias: optional additive bias broadcastable to the scores.
        mask: optional boolean mask; where it is False, scores are reset to
            -inf after every Sinkhorn iteration.
        embed_pos: optional additive relative-position bias (added after ``bias``).
        broadcast_dropout: share one dropout pattern across batch and head dims.
        dropout_rng: PRNG key, required when dropout is active.
        dropout_rate: probability of dropping a weight.
        deterministic: when True, dropout is skipped.
        dtype: dtype of the returned weights.
        precision: precision passed to the score einsum.
        sinkhorn_iters: number of alternating row/column normalizations; only
            used when ``is_encoder`` is True and the value is not 1.
        is_encoder: enables the Sinkhorn branch. It is never used for the
            decoder because it would leak future tokens into the past.
        tau: temperature; when given, scores are divided by it instead of by
            sqrt(head_dim).

    Returns:
        Weights of shape (batch..., num_heads, q_length, kv_length).
    """
    assert query.ndim == key.ndim, "q, k must have same rank."
    assert query.shape[:-3] == key.shape[:-3], "q, k batch dims must match."
    assert query.shape[-2] == key.shape[-2], "q, k num_heads must match."
    assert query.shape[-1] == key.shape[-1], "q, k depths must match."

    # Raw similarities, laid out as (batch..., num_heads, q_length, kv_length).
    scores = jnp.einsum("...qhd,...khd->...hqk", query, key, precision=precision)

    # Temperature: a fixed tau (cosine / Swin v2 attention) or the usual 1/sqrt(d).
    if tau is None:
        scores = scores / jnp.sqrt(query.shape[-1]).astype(dtype)
    else:
        scores = scores / tau

    # Additive terms, in order: masking/proximity bias, then relative positions.
    for additive in (bias, embed_pos):
        if additive is not None:
            scores = scores + additive

    if sinkhorn_iters == 1 or not is_encoder:
        scores = jax.nn.softmax(scores).astype(dtype)
    else:
        # Sinkhorn normalization in log space, alternating rows (even steps) and
        # columns (odd steps); adapted from https://github.com/lucidrains/sinkhorn-transformer
        for step in range(sinkhorn_iters):
            norm_axis = -1 if step % 2 == 0 else -2
            scores -= jax.nn.logsumexp(scores, axis=norm_axis, keepdims=True)
            if mask is not None:
                # Fully masked rows/columns turn into NaN above; force them back to -inf.
                scores = jnp.where(mask, scores, -jnp.inf)
        scores = jnp.exp(scores).astype(dtype)

    if deterministic or not dropout_rate > 0.0:
        return scores

    # Inverted dropout: drop with probability dropout_rate, rescale survivors by 1/keep_prob.
    keep_prob = 1.0 - dropout_rate
    keep_shape = (1,) * (key.ndim - 2) + scores.shape[-2:] if broadcast_dropout else scores.shape
    keep = jax.random.bernoulli(dropout_rng, keep_prob, keep_shape)
    multiplier = keep.astype(scores.dtype) / jnp.asarray(keep_prob, dtype=dtype)
    return scores * multiplier
|
|
272
|
+
|
|
273
|
+
|
|
274
|
+
class FlaxBartAttention(FlaxBartAttention):
    """
    Edits:
    - causal mask is used only in decoder and considers image_length
    - scale attention heads per NormFormer paper

    This class subclasses (and shadows) the FlaxBartAttention imported from
    transformers. Fields read here but not declared below (config, embed_dim,
    num_heads, dropout, causal, bias, dtype) are presumably inherited from that
    base class — confirm against the installed transformers version.

    Config switches handled here: DeepNet / SubLN init gains for v_proj and
    out_proj, NormFormer head scaling, cosine attention with a fixed
    temperature, Swin-style learned relative position bias, Sinkhorn
    normalization (encoder only) and a SubLN mid layernorm before out_proj.
    """

    # Encoder-side attention: selects encoder gains and allows Sinkhorn normalization.
    is_encoder: bool = False
    # Cross-attention skips the SubLN init gain and the SubLN mid layernorm.
    is_cross_attention: bool = False
    # Fixed query/key lengths; only used to size the Swin relative position table.
    q_length: int = None
    k_length: int = None

    def setup(self) -> None:
        self.head_dim = self.embed_dim // self.num_heads
        if self.head_dim * self.num_heads != self.embed_dim:
            raise ValueError(
                f"embed_dim must be divisible by num_heads (got `embed_dim`: {self.embed_dim}"
                f" and `num_heads`: {self.num_heads})."
            )

        dense = partial(
            nn.Dense,
            self.embed_dim,
            use_bias=self.bias,
            dtype=self.dtype,
        )

        # `gain` is bound only when a scaled init applies; the v_proj / out_proj
        # kernel_init expressions below test the same condition, so it is never
        # read while unbound. q_proj / k_proj always use the plain normal init.
        if self.config.use_deepnet_scaling:
            gain = deepnet_gain["encoder" if self.is_encoder else "decoder"]["beta"](self.config)
        elif self.config.use_subln_init and not self.is_cross_attention:
            gain = subln_gain["encoder" if self.is_encoder else "decoder"](self.config)

        self.q_proj = dense(kernel_init=jax.nn.initializers.normal(self.config.init_std))
        self.k_proj = dense(kernel_init=jax.nn.initializers.normal(self.config.init_std))
        self.v_proj = dense(
            kernel_init=(
                deepnet_init(self.config.init_std, gain)
                if (self.config.use_deepnet_scaling or (self.config.use_subln_init and not self.is_cross_attention))
                else jax.nn.initializers.normal(self.config.init_std)
            )
        )
        self.out_proj = dense(
            kernel_init=(
                deepnet_init(self.config.init_std, gain)
                if (self.config.use_deepnet_scaling or (self.config.use_subln_init and not self.is_cross_attention))
                else jax.nn.initializers.normal(self.config.init_std)
            )
        )
        # NOTE(review): not referenced in __call__ below; attention dropout is
        # applied inside dot_product_attention_weights via dropout_rate instead.
        self.dropout_layer = nn.Dropout(rate=self.dropout)

        if self.config.use_head_scale:
            # Learnable per-head multiplier (NormFormer); shape (1, 1, num_heads, 1)
            # lines up with the heads axis of the "...qhd" attention output.
            self.head_scale = self.param("head_scale", jax.nn.initializers.ones, (1, 1, self.num_heads, 1))

        if self.config.use_cosine_attention:
            # TODO: try using a learnt scale, somehow it immediately diverges in my experiments
            # Fixed temperature that replaces 1/sqrt(head_dim) once q and k are L2-normalized.
            self.tau = self.config.tau_init

        if self.config.use_swin_position_embeddings:
            # One embedding row per query position holding k_length * num_heads
            # biases; reshaped to (1, heads, q, k) in __call__.
            self.rel_bias = nn.Embed(
                self.q_length,
                self.k_length * self.num_heads,
                embedding_init=jax.nn.initializers.normal(self.config.init_std),
            )

        if self.causal:
            # used only in decoder
            # Precomputed causal mask over the full image_length; it is sliced
            # per call (4-D indices below) to the current query/key window.
            self.causal_mask = make_causal_mask(jnp.ones((1, self.config.image_length), dtype="bool"), dtype="bool")

        if self.config.ln_positions in ["subln"] and not self.is_cross_attention:
            # SubLN: extra norm on the merged heads, right before out_proj.
            self.mid_layernorm = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)

    def __call__(
        self,
        hidden_states: jnp.ndarray,
        key_value_states: Optional[jnp.ndarray] = None,
        attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray]:
        """Input shape: Batch x Time x Channel

        Returns ``(attn_output, attn_weights)``. ``key_value_states`` switches
        to cross-attention; ``attention_mask`` is a key-padding mask that is
        combined with the causal mask when ``self.causal`` is set.
        """

        # if key_value_states are provided this layer is used as a cross-attention layer
        # for the decoder
        is_cross_attention = key_value_states is not None
        batch_size = hidden_states.shape[0]

        # get query proj
        query_states = self.q_proj(hidden_states)
        # get key, value proj
        if is_cross_attention:
            # cross_attentions
            key_states = self.k_proj(key_value_states)
            value_states = self.v_proj(key_value_states)
        else:
            # self_attention
            key_states = self.k_proj(hidden_states)
            value_states = self.v_proj(hidden_states)

        query_states = self._split_heads(query_states)
        key_states = self._split_heads(key_states)
        value_states = self._split_heads(value_states)

        # handle cache prepare causal attention mask
        if self.causal:
            query_length, key_length = query_states.shape[1], key_states.shape[1]
            if self.has_variable("cache", "cached_key"):
                # Incremental decoding: take the mask rows starting at the current
                # cache position, spanning the whole cached key length.
                mask_shift = self.variables["cache"]["cache_index"]
                max_decoder_length = self.variables["cache"]["cached_key"].shape[1]
                causal_mask = lax.dynamic_slice(
                    self.causal_mask,
                    (0, 0, mask_shift, 0),
                    (1, 1, query_length, max_decoder_length),
                )
            else:
                causal_mask = self.causal_mask[:, :, :query_length, :key_length]
            causal_mask = jnp.broadcast_to(causal_mask, (batch_size,) + causal_mask.shape[1:])

        # combine masks if needed
        if attention_mask is not None and self.causal:
            attention_mask = jnp.broadcast_to(jnp.expand_dims(attention_mask, axis=(-3, -2)), causal_mask.shape)
            attention_mask = combine_masks(attention_mask, causal_mask)
        elif self.causal:
            attention_mask = causal_mask
        elif attention_mask is not None:
            # Key-padding mask only: add broadcast axes for heads and queries.
            attention_mask = jnp.expand_dims(attention_mask, axis=(-3, -2))

        # During fast autoregressive decoding, we feed one position at a time,
        # and cache the keys and values step by step.
        if self.causal and (self.has_variable("cache", "cached_key") or init_cache):
            key_states, value_states, attention_mask = self._concatenate_to_cache(
                key_states, value_states, query_states, attention_mask
            )

        # Convert the boolean attention mask to an attention bias.
        if attention_mask is not None:
            # attention mask in the form of attention bias
            # (0 where attention is allowed, -inf where it is masked out)
            attention_bias = lax.select(
                attention_mask > 0,
                jnp.full(attention_mask.shape, 0.0).astype(self.dtype),
                jnp.full(attention_mask.shape, -jnp.inf).astype(self.dtype),
            )
        else:
            attention_bias = None

        dropout_rng = None
        if not deterministic and self.dropout > 0.0:
            dropout_rng = self.make_rng("dropout")

        if self.config.use_cosine_attention:
            # normalize q and k
            # (unit L2 norm over head_dim; the epsilon guards against zero vectors)
            query_states = query_states / (jnp.linalg.norm(query_states, axis=-1, keepdims=True) + 1e-8)
            key_states = key_states / (jnp.linalg.norm(key_states, axis=-1, keepdims=True) + 1e-8)

        # relative position embeddings
        # NOTE(review): the bias always covers all q_length query rows, even when
        # fewer queries are fed (e.g. one token per step with the cache); confirm
        # Swin position embeddings are never combined with cached decoding.
        if self.config.use_swin_position_embeddings:
            position_ids = jnp.arange(self.q_length)
            embed_pos = self.rel_bias(position_ids)
            embed_pos = rearrange(embed_pos, "q (k h) -> 1 h q k", h=self.num_heads)
        else:
            embed_pos = None

        tau = self.tau if self.config.use_cosine_attention else None
        attn_weights = dot_product_attention_weights(
            query_states,
            key_states,
            bias=attention_bias,
            mask=attention_mask,
            embed_pos=embed_pos,
            dropout_rng=dropout_rng,
            dropout_rate=self.dropout,
            broadcast_dropout=True,
            deterministic=deterministic,
            dtype=self.dtype,
            precision=None,
            sinkhorn_iters=self.config.sinkhorn_iters,
            is_encoder=self.is_encoder,
            tau=tau,
        )

        attn_output = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value_states)
        if self.config.use_head_scale:
            # per Normformer
            attn_output = attn_output * self.head_scale
        attn_output = self._merge_heads(attn_output)

        if self.config.ln_positions in ["subln"] and not self.is_cross_attention:
            attn_output = self.mid_layernorm(attn_output)

        attn_output = self.out_proj(attn_output)

        return attn_output, attn_weights
|
|
465
|
+
|
|
466
|
+
|
|
467
|
+
class GLU(nn.Module):
    """Gated feed-forward block from "GLU Variants Improve Transformer" (https://arxiv.org/abs/2002.05202).

    Computes ``Dense_out(act(Dense_w(x)) * Dense_v(x))``. Optional normalization
    layers (chosen by ``config.ln_positions``) and dropout surround the
    projections; ``is_encoder`` selects which side's init gain is used.
    """

    config: DalleBartConfig
    ffn_dim: int
    embed_dim: int
    dtype: jnp.dtype = jnp.float32
    is_encoder: bool = False

    @nn.compact
    def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
        cfg = self.config
        side = "encoder" if self.is_encoder else "decoder"

        # All three projections share one initializer: DeepNet beta gain, SubLN
        # gain, or a plain normal(init_std). Submodules are still created in the
        # same order, so their auto-generated parameter names are unchanged.
        if cfg.use_deepnet_scaling:
            kernel_init = deepnet_init(cfg.init_std, deepnet_gain[side]["beta"](cfg))
        elif cfg.use_subln_init:
            kernel_init = deepnet_init(cfg.init_std, subln_gain[side](cfg))
        else:
            kernel_init = jax.nn.initializers.normal(cfg.init_std)

        # Pre-norm.
        if cfg.ln_positions in ["normformer", "cogview", "preln", "subln"]:
            x = norm(cfg.ln_type, dtype=self.dtype, epsilon=1e-05, use_scale=cfg.force_ln_scale)(x)

        # Gate branch (activated) times linear branch.
        gate = nn.Dense(self.ffn_dim, dtype=self.dtype, use_bias=cfg.use_bias, kernel_init=kernel_init)(x)
        gate = ACT2FN[cfg.activation_function](gate)
        linear = nn.Dense(self.ffn_dim, dtype=self.dtype, use_bias=cfg.use_bias, kernel_init=kernel_init)(x)
        x = gate * linear

        # Mid-norm on the hidden (ffn_dim) activations.
        if cfg.ln_positions in ["normformer", "subln"]:
            x = norm(cfg.ln_type, dtype=self.dtype, epsilon=1e-05, use_scale=cfg.force_ln_scale)(x)
        x = nn.Dropout(rate=cfg.activation_dropout)(x, deterministic=deterministic)

        # Project back to embed_dim, then optional post-norm and output dropout.
        x = nn.Dense(self.embed_dim, dtype=self.dtype, use_bias=cfg.use_bias, kernel_init=kernel_init)(x)
        if cfg.ln_positions in ["swinv2", "cogview"]:
            x = norm(cfg.ln_type, dtype=self.dtype, epsilon=1e-05)(x)
        return nn.Dropout(rate=cfg.dropout)(x, deterministic=deterministic)
|
|
536
|
+
|
|
537
|
+
|
|
538
|
+
class FFN(nn.Module):
    """Plain two-layer feed-forward block: Dense(ffn_dim) -> activation -> Dense(embed_dim).

    Optional normalization layers (chosen by ``config.ln_positions``) and dropout
    surround the projections; ``is_encoder`` selects which side's init gain is used.
    """

    config: DalleBartConfig
    ffn_dim: int
    embed_dim: int
    dtype: jnp.dtype = jnp.float32
    is_encoder: bool = False

    @nn.compact
    def __call__(self, x: jnp.ndarray, deterministic: bool = True) -> jnp.ndarray:
        cfg = self.config
        side = "encoder" if self.is_encoder else "decoder"

        # Both projections share one initializer: DeepNet beta gain, SubLN gain,
        # or a plain normal(init_std). Submodule creation order is unchanged.
        if cfg.use_deepnet_scaling:
            kernel_init = deepnet_init(cfg.init_std, deepnet_gain[side]["beta"](cfg))
        elif cfg.use_subln_init:
            kernel_init = deepnet_init(cfg.init_std, subln_gain[side](cfg))
        else:
            kernel_init = jax.nn.initializers.normal(cfg.init_std)

        # Pre-norm.
        if cfg.ln_positions in ["normformer", "cogview", "preln", "subln"]:
            x = norm(cfg.ln_type, dtype=self.dtype, epsilon=1e-05, use_scale=cfg.force_ln_scale)(x)

        # Expand and activate.
        x = nn.Dense(self.ffn_dim, dtype=self.dtype, use_bias=cfg.use_bias, kernel_init=kernel_init)(x)
        x = ACT2FN[cfg.activation_function](x)

        # Mid-norm on the hidden (ffn_dim) activations.
        if cfg.ln_positions in ["normformer", "subln"]:
            x = norm(cfg.ln_type, dtype=self.dtype, epsilon=1e-05, use_scale=cfg.force_ln_scale)(x)
        x = nn.Dropout(rate=cfg.activation_dropout)(x, deterministic=deterministic)

        # Project back to embed_dim, then optional post-norm and output dropout.
        x = nn.Dense(self.embed_dim, dtype=self.dtype, use_bias=cfg.use_bias, kernel_init=kernel_init)(x)
        if cfg.ln_positions in ["swinv2", "cogview"]:
            x = norm(cfg.ln_type, dtype=self.dtype, epsilon=1e-05)(x)
        return nn.Dropout(rate=cfg.dropout)(x, deterministic=deterministic)
|
|
594
|
+
|
|
595
|
+
|
|
596
|
+
class FlaxBartEncoderLayer(nn.Module):
    """
    Edits:
    - no bias
    - use custom FlaxBartAttention

    One encoder block: bidirectional self-attention followed by a feed-forward
    sub-block (GLU or FFN depending on config.use_glu). Each residual is scaled
    by the encoder DeepNet alpha gain when config.use_deepnet_scaling is set;
    normalization placement is controlled by config.ln_positions.
    """

    config: DalleBartConfig
    dtype: jnp.dtype = jnp.float32
    # When True, a final norm is applied after the feed-forward residual.
    add_norm: bool = False
    # Whether that final norm has a learnable scale (config.force_ln_scale also forces it on).
    use_scale: bool = True

    @nn.compact
    def __call__(
        self,
        hidden_states: jnp.ndarray,
        attention_mask: jnp.ndarray,
        output_attentions: bool = True,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray]:
        """
        Run the encoder block.

        Returns ``(hidden_states,)``, or ``(hidden_states, attn_weights)`` when
        ``output_attentions`` is set. With ``config.use_scan`` the input arrives
        as a 1-tuple carry and the result is wrapped as ``(outputs, None)``.

        NOTE(review): ``deterministic`` reaches only the residual and FFN dropout.
        The attention module is called without it, so it keeps its default
        (deterministic=True) and ``config.attention_dropout`` has no effect here.
        The attention module branches on ``deterministic`` in Python, so
        forwarding it would also require a concrete bool under remat/scan.
        Confirm before changing.
        """

        if self.config.use_scan:
            # nn.scan passes the carry as a 1-tuple; unpack it.
            hidden_states = hidden_states[0]

        res_gain = deepnet_gain["encoder"]["alpha"](self.config) if self.config.use_deepnet_scaling else 1

        embed_dim = self.config.d_model
        residual = hidden_states
        # Self attention (pre-norm variants normalize the input first).
        if self.config.ln_positions in ["normformer", "cogview", "preln", "subln"]:
            hidden_states = norm(
                self.config.ln_type,
                dtype=self.dtype,
                epsilon=1e-05,
                use_scale=self.config.force_ln_scale,
            )(hidden_states)
        hidden_states, attn_weights = FlaxBartAttention(
            config=self.config,
            embed_dim=embed_dim,
            num_heads=self.config.encoder_attention_heads,
            dropout=self.config.attention_dropout,
            bias=self.config.use_bias,
            dtype=self.dtype,
            is_encoder=True,
            is_cross_attention=False,
            q_length=self.config.max_text_length,
            k_length=self.config.max_text_length,
        )(hidden_states=hidden_states, attention_mask=attention_mask)

        if self.config.ln_positions in ["normformer", "swinv2", "cogview"]:
            hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(hidden_states)
        hidden_states = nn.Dropout(rate=self.config.dropout)(hidden_states, deterministic=deterministic)
        hidden_states = residual * res_gain + hidden_states
        if self.config.ln_positions in ["postln"]:
            hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(hidden_states)

        # Feed forward (GLU or plain FFN), with its own scaled residual.
        residual = hidden_states
        ff_block = (
            GLU(
                config=self.config,
                ffn_dim=self.config.encoder_ffn_dim,
                embed_dim=embed_dim,
                dtype=self.dtype,
                is_encoder=True,
            )
            if self.config.use_glu
            else FFN(
                config=self.config,
                ffn_dim=self.config.encoder_ffn_dim,
                embed_dim=embed_dim,
                dtype=self.dtype,
                is_encoder=True,
            )
        )
        hidden_states = ff_block(hidden_states, deterministic=deterministic)
        hidden_states = residual * res_gain + hidden_states
        if self.add_norm:
            use_scale = self.use_scale or self.config.force_ln_scale
            hidden_states = norm(
                self.config.ln_type,
                dtype=self.dtype,
                epsilon=1e-05,
                use_scale=use_scale,
            )(hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights,)

        if self.config.use_scan:
            # (carry, per-step output) pair expected by nn.scan.
            outputs = (outputs, None)

        return outputs
|
|
689
|
+
|
|
690
|
+
|
|
691
|
+
class FlaxBartDecoderLayer(nn.Module):
    """
    Edits:
    - no bias
    - use custom FlaxBartAttention

    One decoder block: causal self-attention, optional cross-attention over
    ``encoder_hidden_states``, then a feed-forward sub-block (GLU or FFN
    depending on config.use_glu). Each residual is scaled by the decoder DeepNet
    alpha gain when config.use_deepnet_scaling is set; normalization placement
    is controlled by config.ln_positions.
    """

    config: DalleBartConfig
    dtype: jnp.dtype = jnp.float32
    # When True, a final norm is applied after the feed-forward residual.
    add_norm: bool = False
    # Whether that final norm has a learnable scale (config.force_ln_scale also forces it on).
    use_scale: bool = True

    @nn.compact
    def __call__(
        self,
        hidden_states: jnp.ndarray,
        attention_mask: jnp.ndarray,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        output_attentions: bool = True,
        deterministic: bool = True,
    ) -> Tuple[jnp.ndarray]:
        """
        Run the decoder block.

        Returns ``(hidden_states,)``, or with ``output_attentions`` set
        ``(hidden_states, self_attn_weights, cross_attn_weights)``. The
        cross-attention weights are None when no encoder states are given.
        With ``config.use_scan`` the input arrives as a 1-tuple carry and the
        result is wrapped as ``(outputs, None)``.

        NOTE(review): ``deterministic`` reaches only the residual and FFN dropout.
        Both attention modules are called without it, so they keep their default
        (deterministic=True) and ``config.attention_dropout`` has no effect here.
        The attention module branches on ``deterministic`` in Python, so
        forwarding it would also require a concrete bool under remat/scan.
        Confirm before changing.
        """

        if self.config.use_scan:
            # nn.scan passes the carry as a 1-tuple; unpack it.
            hidden_states = hidden_states[0]

        res_gain = deepnet_gain["decoder"]["alpha"](self.config) if self.config.use_deepnet_scaling else 1

        embed_dim = self.config.d_model
        residual = hidden_states

        # Self Attention
        # NOTE(review): unlike FlaxBartEncoderLayer, "subln" is not in this
        # pre-norm list (nor in the cross-attention one below); confirm this is
        # intentional before changing, as it affects the parameter tree.
        if self.config.ln_positions in ["normformer", "cogview", "preln"]:
            hidden_states = norm(
                self.config.ln_type,
                dtype=self.dtype,
                epsilon=1e-05,
                use_scale=self.config.force_ln_scale,
            )(hidden_states)
        hidden_states, attn_weights = FlaxBartAttention(
            config=self.config,
            embed_dim=embed_dim,
            num_heads=self.config.decoder_attention_heads,
            dropout=self.config.attention_dropout,
            causal=True,
            bias=self.config.use_bias,
            dtype=self.dtype,
            is_encoder=False,
            is_cross_attention=False,
            q_length=self.config.image_length,
            k_length=self.config.image_length,
        )(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            init_cache=init_cache,
        )

        if self.config.ln_positions in ["normformer", "swinv2", "cogview"]:
            hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(hidden_states)
        hidden_states = nn.Dropout(rate=self.config.dropout)(hidden_states, deterministic=deterministic)
        hidden_states = residual * res_gain + hidden_states
        if self.config.ln_positions in ["postln"]:
            hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(hidden_states)

        # Cross Attention
        # (image-token queries attend to the encoder's text states; skipped when
        # no encoder states are passed)
        cross_attn_weights = None
        if encoder_hidden_states is not None:
            residual = hidden_states
            if self.config.ln_positions in ["normformer", "cogview", "preln"]:
                hidden_states = norm(
                    self.config.ln_type,
                    dtype=self.dtype,
                    epsilon=1e-05,
                    use_scale=self.config.force_ln_scale,
                )(hidden_states)
            hidden_states, cross_attn_weights = FlaxBartAttention(
                config=self.config,
                embed_dim=embed_dim,
                num_heads=self.config.decoder_attention_heads,
                dropout=self.config.attention_dropout,
                bias=self.config.use_bias,
                dtype=self.dtype,
                is_encoder=False,
                is_cross_attention=True,
                q_length=self.config.image_length,
                k_length=self.config.max_text_length,
            )(
                hidden_states=hidden_states,
                key_value_states=encoder_hidden_states,
                attention_mask=encoder_attention_mask,
            )
            if self.config.ln_positions in ["normformer", "swinv2", "cogview"]:
                hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(hidden_states)
            hidden_states = nn.Dropout(rate=self.config.dropout)(hidden_states, deterministic=deterministic)
            hidden_states = residual * res_gain + hidden_states
            if self.config.ln_positions in ["postln"]:
                hidden_states = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)(hidden_states)

        # Feed forward
        residual = hidden_states
        ff_block = (
            GLU(
                config=self.config,
                ffn_dim=self.config.decoder_ffn_dim,
                embed_dim=embed_dim,
                dtype=self.dtype,
                is_encoder=False,
            )
            if self.config.use_glu
            else FFN(
                config=self.config,
                ffn_dim=self.config.decoder_ffn_dim,
                embed_dim=embed_dim,
                dtype=self.dtype,
                is_encoder=False,
            )
        )
        hidden_states = ff_block(hidden_states, deterministic=deterministic)
        hidden_states = residual * res_gain + hidden_states
        if self.add_norm:
            use_scale = self.use_scale or self.config.force_ln_scale
            hidden_states = norm(
                self.config.ln_type,
                dtype=self.dtype,
                epsilon=1e-05,
                use_scale=use_scale,
            )(hidden_states)

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (attn_weights, cross_attn_weights)

        if self.config.use_scan:
            # (carry, per-step output) pair expected by nn.scan.
            outputs = (outputs, None)

        return outputs
|
|
829
|
+
|
|
830
|
+
|
|
831
|
+
class FlaxBartEncoderLayerCollection(nn.Module):
    """Stack of ``config.encoder_layers`` encoder layers.

    Differences from the Hugging Face implementation:
    - builds the custom ``FlaxBartEncoderLayer`` defined in this module;
    - optionally wraps each layer with gradient checkpointing (``remat``);
    - optionally runs the identical layers through ``nn.scan``.
    """

    config: DalleBartConfig
    dtype: jnp.dtype = jnp.float32  # dtype used for the computation

    @nn.compact
    def __call__(
        self,
        hidden_states,
        attention_mask,
        deterministic: bool = True,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """Run every encoder layer over ``hidden_states``.

        Returns a ``FlaxBaseModelOutput`` when ``return_dict`` is truthy,
        otherwise a tuple of the non-``None`` entries among (last hidden
        state, per-layer hidden states, per-layer self-attention weights).
        Scanning (``config.use_scan``) is incompatible with collecting
        attentions or hidden states.
        """
        cfg = self.config
        num_layers = cfg.encoder_layers
        collected_states = () if output_hidden_states else None
        collected_attns = () if output_attentions else None

        if cfg.gradient_checkpointing:
            # NOTE: static_argnums refers to the layer's positional call
            # arguments, so the call order below must not change.
            layer_cls = remat(
                FlaxBartEncoderLayer,
                static_argnums=(2, 3),
                prevent_cse=not cfg.use_scan,
            )
        else:
            layer_cls = FlaxBartEncoderLayer

        if cfg.use_scan:
            # Every layer has the same structure, so one scanned module is used
            # with its parameters stacked along axis 0.
            assert not output_attentions, "cannot scan with output_attentions"
            assert not output_hidden_states, "cannot scan with output_hidden_states"
            scanned_stack = nn.scan(
                layer_cls,
                variable_axes={"params": 0, "cache": 0},
                split_rngs={"params": True, "dropout": True},
                in_axes=(nn.broadcast, nn.broadcast, nn.broadcast),
                length=num_layers,
            )
            # The norm keeps a scale on every layer (even the last one) so that
            # all scanned layers are identical.
            carry, _ = scanned_stack(
                cfg,
                dtype=self.dtype,
                add_norm=cfg.ln_positions == "postln",
                name="FlaxBartEncoderLayers",
            )(
                (hidden_states,),
                attention_mask,
                output_attentions,
                deterministic,
            )
            hidden_states = carry[0]
        else:
            for idx in range(num_layers):
                if output_hidden_states:
                    collected_states += (hidden_states,)
                is_last = idx == num_layers - 1
                # Post-LN normalizes after every layer; Swin v2 after every 6th
                # layer except the final one.
                swin_boundary = cfg.ln_positions == "swinv2" and (idx + 1) % 6 == 0 and not is_last
                layer_out = layer_cls(
                    cfg,
                    dtype=self.dtype,
                    add_norm=cfg.ln_positions == "postln" or swin_boundary,
                    use_scale=not is_last,  # the last layer's norm needs no scale
                    name=f"FlaxBartEncoderLayer_{idx}",
                )(
                    hidden_states,
                    attention_mask,
                    output_attentions,
                    deterministic,
                )
                hidden_states = layer_out[0]
                if output_attentions:
                    collected_attns += (layer_out[1],)

            # include the output of the final layer
            if output_hidden_states:
                collected_states += (hidden_states,)

        if not return_dict:
            candidates = (hidden_states, collected_states, collected_attns)
            return tuple(item for item in candidates if item is not None)

        return FlaxBaseModelOutput(
            last_hidden_state=hidden_states,
            hidden_states=collected_states,
            attentions=collected_attns,
        )
|
|
933
|
+
|
|
934
|
+
|
|
935
|
+
class FlaxBartDecoderLayerCollection(nn.Module):
    """Stack of ``config.decoder_layers`` decoder layers.

    Differences from the Hugging Face implementation:
    - builds the custom ``FlaxBartDecoderLayer`` defined in this module;
    - optionally wraps each layer with gradient checkpointing (``remat``);
    - optionally runs the identical layers through ``nn.scan``.
    """

    config: DalleBartConfig
    dtype: jnp.dtype = jnp.float32  # dtype used for the computation

    @nn.compact
    def __call__(
        self,
        hidden_states,
        attention_mask,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        deterministic: bool = True,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ):
        """Run every decoder layer over ``hidden_states``.

        Returns a ``FlaxBaseModelOutputWithPastAndCrossAttentions`` when
        ``return_dict`` is truthy, otherwise a tuple of the non-``None``
        entries among (last hidden state, per-layer hidden states, per-layer
        self-attention weights, per-layer cross-attention weights).
        Cross-attention weights are only collected when ``output_attentions``
        is set and encoder hidden states are given. Scanning
        (``config.use_scan``) is incompatible with collecting attentions or
        hidden states.
        """
        cfg = self.config
        num_layers = cfg.decoder_layers
        has_encoder = encoder_hidden_states is not None
        collected_states = () if output_hidden_states else None
        collected_self_attns = () if output_attentions else None
        collected_cross_attns = () if (output_attentions and has_encoder) else None

        if cfg.gradient_checkpointing:
            # NOTE: static_argnums refers to the layer's positional call
            # arguments, so the call order below must not change.
            layer_cls = remat(
                FlaxBartDecoderLayer,
                static_argnums=(4, 5, 6),
                prevent_cse=not cfg.use_scan,
            )
        else:
            layer_cls = FlaxBartDecoderLayer

        if cfg.use_scan:
            # Every layer has the same structure, so one scanned module is used
            # with its parameters (and cache) stacked along axis 0.
            assert not output_attentions, "cannot scan with output_attentions"
            assert not output_hidden_states, "cannot scan with output_hidden_states"
            scanned_stack = nn.scan(
                layer_cls,
                variable_axes={"params": 0, "cache": 0},
                split_rngs={"params": True, "dropout": True},
                in_axes=(nn.broadcast,) * 6,
                length=num_layers,
            )
            # The norm keeps a scale on every layer (even the last one) so that
            # all scanned layers are identical.
            carry, _ = scanned_stack(
                cfg,
                dtype=self.dtype,
                add_norm=cfg.ln_positions == "postln",
                name="FlaxBartDecoderLayers",
            )(
                (hidden_states,),
                attention_mask,
                encoder_hidden_states,
                encoder_attention_mask,
                init_cache,
                output_attentions,
                deterministic,
            )
            hidden_states = carry[0]
        else:
            for idx in range(num_layers):
                if output_hidden_states:
                    collected_states += (hidden_states,)
                is_last = idx == num_layers - 1
                # Post-LN normalizes after every layer; Swin v2 after every 6th
                # layer except the final one.
                swin_boundary = cfg.ln_positions == "swinv2" and (idx + 1) % 6 == 0 and not is_last
                layer_out = layer_cls(
                    cfg,
                    dtype=self.dtype,
                    add_norm=cfg.ln_positions == "postln" or swin_boundary,
                    use_scale=not is_last,  # the last layer's norm needs no scale
                    name=f"FlaxBartDecoderLayer_{idx}",
                )(
                    hidden_states,
                    attention_mask,
                    encoder_hidden_states,
                    encoder_attention_mask,
                    init_cache,
                    output_attentions,
                    deterministic,
                )
                hidden_states = layer_out[0]
                if output_attentions:
                    collected_self_attns += (layer_out[1],)
                    if has_encoder:
                        collected_cross_attns += (layer_out[2],)

            # include the output of the final decoder layer
            if output_hidden_states:
                collected_states += (hidden_states,)

        if not return_dict:
            candidates = (
                hidden_states,
                collected_states,
                collected_self_attns,
                collected_cross_attns,
            )
            return tuple(item for item in candidates if item is not None)

        return FlaxBaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            hidden_states=collected_states,
            attentions=collected_self_attns,
            cross_attentions=collected_cross_attns,
        )
|
|
1062
|
+
|
|
1063
|
+
|
|
1064
|
+
class FlaxBartEncoder(nn.Module):
    """Text encoder.

    Pipeline: token embeddings (optionally scaled), optional absolute position
    embeddings, embedding layer norm, dropout, the encoder layer stack, and an
    optional final layer norm.

    Differences from the Hugging Face implementation:
    - position-id offset is 0 (no padding-token offset);
    - the position table has ``config.max_text_length`` rows instead of
      ``max_position_embeddings``;
    - uses the custom ``FlaxBartEncoderLayerCollection``;
    - ``embed_tokens`` is required (``None`` causes issues at compile time).
    """

    config: DalleBartConfig
    embed_tokens: nn.Embed
    dtype: jnp.dtype = jnp.float32  # dtype used for the computation

    def setup(self):
        cfg = self.config
        self.dropout_layer = nn.Dropout(rate=cfg.dropout)

        model_dim = cfg.d_model
        self.padding_idx = cfg.pad_token_id
        self.embed_scale = math.sqrt(model_dim) if cfg.scale_embedding else 1.0

        # Hugging Face BART shifts position ids by 2 when a padding index is
        # set; this model does not, hence an offset of 0.
        self.offset = 0
        if cfg.use_absolute_position_embeddings:
            self.embed_positions = nn.Embed(
                cfg.max_text_length + self.offset,
                model_dim,
                embedding_init=jax.nn.initializers.normal(cfg.init_std),
            )
        self.layers = FlaxBartEncoderLayerCollection(cfg, self.dtype)
        self.layernorm_embedding = norm(cfg.ln_type, dtype=self.dtype, epsilon=1e-05)

        # With post-LN every layer already ends in a norm, so no final norm.
        needs_final_ln = cfg.use_final_ln_encoder and cfg.ln_positions != "postln"
        self.final_ln = (
            norm(
                cfg.ln_type,
                dtype=self.dtype,
                epsilon=1e-05,
                use_scale=cfg.force_ln_scale,
            )
            if needs_final_ln
            else None
        )

    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        """Encode ``input_ids``.

        ``input_ids`` is reshaped to 2-D, keeping its last dimension.
        ``position_ids`` is only used when absolute position embeddings are
        enabled. Returns a ``FlaxBaseModelOutput`` when ``return_dict`` is
        truthy, otherwise ``(last_hidden_state, *rest_of_layer_outputs)``.
        """
        token_ids = input_ids.reshape(-1, input_ids.shape[-1])
        hidden_states = self.embed_tokens(token_ids) * self.embed_scale

        if self.config.use_absolute_position_embeddings:
            hidden_states = hidden_states + self.embed_positions(position_ids + self.offset)

        hidden_states = self.layernorm_embedding(hidden_states)
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)

        layer_outputs = self.layers(
            hidden_states,
            attention_mask,
            deterministic=deterministic,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        last_state = layer_outputs[0] if self.final_ln is None else self.final_ln(layer_outputs[0])

        if not return_dict:
            return (last_state,) + layer_outputs[1:]

        return FlaxBaseModelOutput(
            last_hidden_state=last_state,
            hidden_states=layer_outputs.hidden_states,
            attentions=layer_outputs.attentions,
        )
|
|
1150
|
+
|
|
1151
|
+
|
|
1152
|
+
class FlaxBartDecoder(nn.Module):
    """Image-token decoder.

    Pipeline: token embeddings (optionally scaled), optional absolute position
    embeddings, embedding layer norm, dropout, the decoder layer stack, and an
    optional final layer norm.

    Differences from the Hugging Face implementation:
    - position-id offset is 0 (no padding-token offset);
    - the position table has ``config.image_length`` rows instead of
      ``max_position_embeddings``;
    - uses the custom ``FlaxBartDecoderLayerCollection``;
    - ``embed_tokens`` is required (``None`` causes issues at compile time).
    """

    config: DalleBartConfig
    embed_tokens: nn.Embed
    dtype: jnp.dtype = jnp.float32  # the dtype of the computation

    def setup(self):
        self.dropout_layer = nn.Dropout(rate=self.config.dropout)

        embed_dim = self.config.d_model
        self.padding_idx = self.config.pad_token_id
        self.embed_scale = math.sqrt(self.config.d_model) if self.config.scale_embedding else 1.0

        # Hugging Face BART shifts position ids by 2 when a padding index is
        # set; this model does not, hence an offset of 0.
        self.offset = 0
        if self.config.use_absolute_position_embeddings:
            self.embed_positions = nn.Embed(
                self.config.image_length + self.offset,  # table sized by config.image_length
                embed_dim,
                embedding_init=jax.nn.initializers.normal(self.config.init_std),
            )

        self.layers = FlaxBartDecoderLayerCollection(self.config, self.dtype)
        self.layernorm_embedding = norm(self.config.ln_type, dtype=self.dtype, epsilon=1e-05)

        # With post-LN every layer already ends in a norm, so no final norm.
        if self.config.use_final_ln_decoder and self.config.ln_positions != "postln":
            self.final_ln = norm(
                self.config.ln_type,
                dtype=self.dtype,
                epsilon=1e-05,
                use_scale=self.config.force_ln_scale,
            )
        else:
            # Always define the attribute: __call__ checks `self.final_ln is None`,
            # and reading a Flax module attribute that setup never assigned
            # raises AttributeError (this mirrors FlaxBartEncoder.setup).
            self.final_ln = None

    def __call__(
        self,
        input_ids,
        attention_mask,
        position_ids,
        encoder_hidden_states: Optional[jnp.ndarray] = None,
        encoder_attention_mask: Optional[jnp.ndarray] = None,
        init_cache: bool = False,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        """Decode ``input_ids``, attending to ``encoder_hidden_states`` if given.

        ``input_ids`` is reshaped to 2-D, keeping its last dimension.
        ``position_ids`` is only used when absolute position embeddings are
        enabled. Returns a ``FlaxBaseModelOutputWithPastAndCrossAttentions``
        when ``return_dict`` is truthy, otherwise
        ``(last_hidden_state, *rest_of_layer_outputs)``.
        """
        input_shape = input_ids.shape
        input_ids = input_ids.reshape(-1, input_shape[-1])

        hidden_states = self.embed_tokens(input_ids) * self.embed_scale

        if self.config.use_absolute_position_embeddings:
            embed_pos = self.embed_positions(position_ids + self.offset)
            hidden_states = hidden_states + embed_pos

        hidden_states = self.layernorm_embedding(hidden_states)
        hidden_states = self.dropout_layer(hidden_states, deterministic=deterministic)

        outputs = self.layers(
            hidden_states,
            attention_mask,
            encoder_hidden_states,
            encoder_attention_mask,
            deterministic=deterministic,
            init_cache=init_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if self.final_ln is None:
            final_output = outputs[0]
        else:
            final_output = self.final_ln(outputs[0])

        if not return_dict:
            return (final_output,) + outputs[1:]

        return FlaxBaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=final_output,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
            cross_attentions=outputs.cross_attentions,
        )
|
|
1244
|
+
|
|
1245
|
+
|
|
1246
|
+
class FlaxBartModule(FlaxBartModule):
    """BART backbone with separate encoder and decoder embedding tables.

    Differences from the Hugging Face ``FlaxBartModule`` it extends:
    - uses the custom ``FlaxBartEncoder`` and ``FlaxBartDecoder`` of this module;
    - the encoder and decoder each get their own token embeddings instead of
      one shared table.
    """

    def setup(self):
        cfg = self.config
        embedding_init = jax.nn.initializers.normal(cfg.init_std)

        text_embeddings = nn.Embed(
            cfg.encoder_vocab_size,
            cfg.d_model,
            embedding_init=embedding_init,
        )
        # one extra row for the BOS token
        image_embeddings = nn.Embed(
            cfg.image_vocab_size + 1,
            cfg.d_model,
            embedding_init=embedding_init,
        )

        self.encoder = FlaxBartEncoder(cfg, dtype=self.dtype, embed_tokens=text_embeddings)
        self.decoder = FlaxBartDecoder(cfg, dtype=self.dtype, embed_tokens=image_embeddings)
|
|
1267
|
+
|
|
1268
|
+
|
|
1269
|
+
class FlaxBartForConditionalGenerationModule(FlaxBartForConditionalGenerationModule):
    """Seq2seq model whose head maps decoder states to image-token logits.

    Differences from the Hugging Face module it extends:
    - the LM head has no bias;
    - the LM head outputs ``image_vocab_size + 1`` logits (extra slot for BOS);
    - the backbone is the custom ``FlaxBartModule`` of this module.
    """

    def setup(self):
        cfg = self.config
        self.model = FlaxBartModule(config=cfg, dtype=self.dtype)
        # +1 for BOS so the head matches the decoder input vocabulary size
        # (for sharding)
        self.lm_head = nn.Dense(
            cfg.image_vocab_size + 1,
            use_bias=False,
            dtype=self.dtype,
            kernel_init=jax.nn.initializers.normal(cfg.init_std),
        )

    def __call__(
        self,
        input_ids,
        attention_mask,
        decoder_input_ids,
        decoder_attention_mask,
        position_ids,
        decoder_position_ids,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
        deterministic: bool = True,
    ):
        """Run encoder + decoder and project the decoder output to logits.

        Returns a ``FlaxSeq2SeqLMOutput`` when ``return_dict`` is truthy,
        otherwise ``(logits, *backbone_outputs[1:])``.
        """
        model_outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            decoder_input_ids=decoder_input_ids,
            decoder_attention_mask=decoder_attention_mask,
            position_ids=position_ids,
            decoder_position_ids=decoder_position_ids,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            deterministic=deterministic,
        )
        decoder_states = model_outputs[0]

        if self.config.tie_word_embeddings:
            # NOTE(review): the custom FlaxBartModule.setup in this file creates
            # separate encoder/decoder embeddings and no "shared" one, so this
            # lookup likely fails if tie_word_embeddings is ever enabled — confirm.
            tied_kernel = self.model.variables["params"]["shared"]["embedding"].T
            lm_logits = self.lm_head.apply({"params": {"kernel": tied_kernel}}, decoder_states)
        else:
            lm_logits = self.lm_head(decoder_states)

        if not return_dict:
            return (lm_logits,) + model_outputs[1:]

        return FlaxSeq2SeqLMOutput(
            logits=lm_logits,
            decoder_hidden_states=model_outputs.decoder_hidden_states,
            decoder_attentions=model_outputs.decoder_attentions,
            cross_attentions=model_outputs.cross_attentions,
            encoder_last_hidden_state=model_outputs.encoder_last_hidden_state,
            encoder_hidden_states=model_outputs.encoder_hidden_states,
            encoder_attentions=model_outputs.encoder_attentions,
        )
|
|
1334
|
+
|
|
1335
|
+
|
|
1336
|
+
@flax.struct.dataclass
class SampleState:
    """Immutable state record for the sampling loop.

    NOTE(review): the per-field comments below are inferred from field names;
    this block alone does not establish their meaning — confirm against the
    sampling code.
    """

    cur_len: jnp.ndarray  # presumably the current generated length
    sequences: jnp.ndarray  # presumably the token ids generated so far
    running_token: jnp.ndarray  # presumably the token fed to the next step
    is_sent_finished: jnp.ndarray  # presumably per-sequence completion flags
    prng_key: jnp.ndarray  # PRNG key
    model_kwargs: Dict[str, jnp.ndarray]  # presumably kwargs for the conditioned pass
    model_kwargs_uncond: Dict[str, jnp.ndarray]  # presumably kwargs for the unconditioned pass
|
|
1345
|
+
|
|
1346
|
+
|
|
1347
|
+
@flax.struct.dataclass
class FlaxSampleOutput(ModelOutput):
    """Flax output container for decoder-only generation by sampling.

    Args:
        sequences (`jnp.ndarray` of shape `(batch_size, max_length)`):
            The generated token sequences.
    """

    sequences: jnp.ndarray = None  # generated sequences; None until populated
|
|
1359
|
+
|
|
1360
|
+
|
|
1361
|
+
class DalleBart(PretrainedFromWandbMixin, FlaxBartForConditionalGeneration):
|
|
1362
|
+
"""
|
|
1363
|
+
Edits:
|
|
1364
|
+
- renamed from FlaxBartForConditionalGeneration
|
|
1365
|
+
- uses custom FlaxBartForConditionalGenerationModule
|
|
1366
|
+
- no bias in decode method
|
|
1367
|
+
- custom prepare_inputs_for_generation using "max_length - 1" to avoid issues
|
|
1368
|
+
related to position embedding during model.generate()
|
|
1369
|
+
- custom generate method to allow super conditions
|
|
1370
|
+
- num_params property
|
|
1371
|
+
- unscan function
|
|
1372
|
+
"""
|
|
1373
|
+
|
|
1374
|
+
module_class = FlaxBartForConditionalGenerationModule
|
|
1375
|
+
config_class = DalleBartConfig
|
|
1376
|
+
|
|
1377
|
+
def num_params(self, params=None):
    """Return the total number of scalar parameters.

    Args:
        params: parameter tree to count; defaults to ``self.params``.

    Returns:
        The sum of ``.size`` over every leaf of the flattened parameter tree.
    """
    tree = self.params if params is None else params
    flat = flatten_dict(unfreeze(tree))
    return sum(leaf.size for leaf in flat.values())
|
|
1382
|
+
|
|
1383
|
+
def unscan(self, params):
    """Convert scanned (stacked) layer parameters to per-layer parameters.

    When ``self.config.use_scan`` is set, this switches it off in place. Every
    flattened key containing a ``"layers"`` element is then expanded.

    For example, ``(..., "layers", "FlaxBartEncoderLayers", ...)`` holding a
    stacked array becomes one key per layer index ``i``:
    ``(..., "layers", "FlaxBartEncoderLayer_i", ...)``. The trailing ``s`` of
    the scanned name is dropped.

    Returns the (possibly unchanged) parameter tree.
    """
    if not self.config.use_scan:
        return params

    self.config.use_scan = False
    flat = flatten_dict(params)
    stacked_keys = [key for key in flat.keys() if "layers" in key]
    for key in stacked_keys:
        stacked = flat[key]
        pos = key.index("layers") + 1
        prefix, scanned_name, suffix = key[:pos], key[pos], key[pos + 1 :]
        base_name = scanned_name[:-1]
        for layer_idx in range(len(stacked)):
            flat[(*prefix, f"{base_name}_{layer_idx}", *suffix)] = stacked[layer_idx]
        del flat[key]
    return unflatten_dict(flat)
|
|
1401
|
+
|
|
1402
|
+
def decode(
    self,
    decoder_input_ids,
    encoder_outputs,
    encoder_attention_mask: Optional[jnp.ndarray] = None,
    decoder_attention_mask: Optional[jnp.ndarray] = None,
    decoder_position_ids: Optional[jnp.ndarray] = None,
    past_key_values: dict = None,
    output_attentions: Optional[bool] = None,
    output_hidden_states: Optional[bool] = None,
    return_dict: Optional[bool] = None,
    train: bool = False,
    params: dict = None,
    dropout_rng: PRNGKey = None,
):
    """Run the decoder and the (bias-free) LM head on precomputed encoder outputs.

    Args:
        decoder_input_ids: decoder token ids, unpacked as ``(batch, seq_len)``.
        encoder_outputs: encoder output; element ``[0]`` is used as the
            encoder hidden states.
        encoder_attention_mask: defaults to all ones over the first two
            dimensions of the encoder hidden states.
        decoder_attention_mask: defaults to all ones of shape ``(batch, seq_len)``.
        decoder_position_ids: defaults to ``0..seq_len-1`` per row; required
            when ``past_key_values`` is given.
        past_key_values: decoding cache. When truthy it is passed as the
            mutable ``"cache"`` collection and the updated cache is returned.
        output_attentions, output_hidden_states, return_dict: fall back to the
            matching ``self.config`` value when ``None``.
        train: enables dropout (``deterministic=not train``).
        params: parameters to use instead of ``self.params``.
        dropout_rng: PRNG key for dropout.

    Returns:
        If ``return_dict``: a ``FlaxCausalLMOutputWithCrossAttentions``, with a
        ``"past_key_values"`` entry added when a cache was given. Otherwise a
        tuple ``(logits, [updated_cache,] *decoder_outputs[1:])``.

    Raises:
        ValueError: if ``past_key_values`` is given without ``decoder_position_ids``.
    """
    cfg = self.config
    if output_attentions is None:
        output_attentions = cfg.output_attentions
    if output_hidden_states is None:
        output_hidden_states = cfg.output_hidden_states
    if return_dict is None:
        return_dict = cfg.return_dict

    encoder_hidden_states = encoder_outputs[0]
    if encoder_attention_mask is None:
        enc_batch, enc_len = encoder_hidden_states.shape[:2]
        encoder_attention_mask = jnp.ones((enc_batch, enc_len))

    batch_size, dec_len = decoder_input_ids.shape
    if decoder_attention_mask is None:
        decoder_attention_mask = jnp.ones((batch_size, dec_len))

    if decoder_position_ids is None:
        if past_key_values is not None:
            raise ValueError("Make sure to provide `decoder_position_ids` when passing `past_key_values`.")
        decoder_position_ids = jnp.broadcast_to(jnp.arange(dec_len)[None, :], (batch_size, dec_len))

    rngs = {"dropout": dropout_rng} if dropout_rng is not None else {}
    variables = {"params": params or self.params}

    # A (non-empty) cache is supplied as a mutable collection so the
    # attention modules can update it during this call.
    mutable = False
    if past_key_values:
        variables["cache"] = past_key_values
        mutable = ["cache"]

    def _decoder_forward(
        module,
        decoder_input_ids,
        decoder_attention_mask,
        decoder_position_ids,
        **kwargs,
    ):
        dec_out = module._get_decoder_module()(
            decoder_input_ids,
            decoder_attention_mask,
            decoder_position_ids,
            **kwargs,
        )
        last_hidden = dec_out[0]
        if cfg.tie_word_embeddings:
            shared_embedding = module.model.variables["params"]["shared"]["embedding"]
            logits = module.lm_head.apply({"params": {"kernel": shared_embedding.T}}, last_hidden)
        else:
            logits = module.lm_head(last_hidden)
        return logits, dec_out

    applied = self.module.apply(
        variables,
        decoder_input_ids=jnp.array(decoder_input_ids, dtype="i4"),
        decoder_attention_mask=jnp.array(decoder_attention_mask, dtype="i4"),
        decoder_position_ids=jnp.array(decoder_position_ids, dtype="i4"),
        encoder_hidden_states=encoder_hidden_states,
        encoder_attention_mask=jnp.array(encoder_attention_mask, dtype="i4"),
        output_attentions=output_attentions,
        output_hidden_states=output_hidden_states,
        return_dict=return_dict,
        deterministic=not train,
        rngs=rngs,
        mutable=mutable,
        method=_decoder_forward,
    )

    # With a mutable collection, apply() also returns the updated variables.
    if past_key_values is None:
        lm_logits, decoder_outputs = applied
    else:
        (lm_logits, decoder_outputs), past = applied

    if return_dict:
        result = FlaxCausalLMOutputWithCrossAttentions(
            logits=lm_logits,
            hidden_states=decoder_outputs.hidden_states,
            attentions=decoder_outputs.attentions,
            cross_attentions=decoder_outputs.cross_attentions,
        )
    else:
        result = (lm_logits,) + decoder_outputs[1:]

    if past_key_values is not None:
        # hand the updated cache back to the caller
        updated_cache = unfreeze(past["cache"])
        if return_dict:
            result["past_key_values"] = updated_cache
        else:
            result = result[:1] + (updated_cache,) + result[1:]

    return result
|
|
1517
|
+
|
|
1518
|
+
def prepare_inputs_for_generation(
    self,
    decoder_input_ids,
    max_length,
    attention_mask: Optional[jnp.ndarray] = None,
    decoder_attention_mask: Optional[jnp.ndarray] = None,
    encoder_outputs=None,
    **kwargs,
):
    """Build the per-call decoder inputs used by ``generate``.

    The cache and the static attention mask have length ``max_length - 1``
    (see the class docstring: this avoids position-embedding issues during
    generation).

    Args:
        decoder_input_ids: initial decoder ids, unpacked as ``(batch, seq_len)``.
        max_length: generation length; the cache holds ``max_length - 1`` steps.
        attention_mask: encoder attention mask, returned unchanged as
            ``"encoder_attention_mask"``.
        decoder_attention_mask: optional mask over ``decoder_input_ids``. When
            given, position ids are its cumulative sum minus one, and it is
            written into the start of the extended mask.
        encoder_outputs: passed to ``self.init_cache`` and returned unchanged.
        **kwargs: ignored.

    Returns:
        A dict with keys ``past_key_values``, ``encoder_outputs``,
        ``encoder_attention_mask``, ``decoder_attention_mask`` and
        ``decoder_position_ids``.
    """
    # Annotations use jnp.ndarray: jnp.DeviceArray was removed from JAX, and
    # referencing it raises AttributeError when this function is defined.
    batch_size, seq_length = decoder_input_ids.shape

    past_key_values = self.init_cache(batch_size, max_length - 1, encoder_outputs)
    # Normally positions in [input length, cache length) would be masked with
    # 0s, but the decoder's causal mask already hides them. A single static
    # all-ones mask is therefore enough, and it is cheaper to compile.
    extended_attention_mask = jnp.ones((batch_size, max_length - 1), dtype="i4")
    if decoder_attention_mask is not None:
        position_ids = decoder_attention_mask.cumsum(axis=-1) - 1
        extended_attention_mask = lax.dynamic_update_slice(extended_attention_mask, decoder_attention_mask, (0, 0))
    else:
        position_ids = jnp.broadcast_to(jnp.arange(seq_length, dtype="i4")[None, :], (batch_size, seq_length))

    return {
        "past_key_values": past_key_values,
        "encoder_outputs": encoder_outputs,
        "encoder_attention_mask": attention_mask,
        "decoder_attention_mask": extended_attention_mask,
        "decoder_position_ids": position_ids,
    }
|
|
1548
|
+
|
|
1549
|
+
def generate(
    self,
    input_ids: jnp.ndarray,
    attention_mask: Optional[jnp.ndarray] = None,
    max_length: Optional[int] = None,
    pad_token_id: Optional[int] = None,
    bos_token_id: Optional[int] = None,
    eos_token_id: Optional[int] = None,
    decoder_start_token_id: Optional[int] = None,
    do_sample: Optional[bool] = None,
    prng_key: Optional[jnp.ndarray] = None,
    top_k: Optional[int] = None,
    top_p: Optional[float] = None,
    temperature: Optional[float] = None,
    num_beams: Optional[int] = None,
    no_repeat_ngram_size: Optional[int] = None,
    min_length: Optional[int] = None,
    forced_bos_token_id: Optional[int] = None,
    forced_eos_token_id: Optional[int] = None,
    length_penalty: Optional[float] = None,
    early_stopping: Optional[bool] = None,
    trace: bool = True,
    params: Optional[Dict[str, jnp.ndarray]] = None,
    condition_scale: Optional[float] = 1.0,
    input_ids_uncond: Optional[jnp.ndarray] = None,
    attention_mask_uncond: Optional[jnp.ndarray] = None,
    **model_kwargs,
):
    """Generate sequences by greedy search, sampling, or beam search.

    Edit: Allow super conditioning. When ``condition_scale != 1.0`` (sampling
    only, ``num_beams == 1``), the encoder is additionally run on
    ``input_ids_uncond`` / ``attention_mask_uncond``. The resulting kwargs are
    forwarded to ``self._sample`` as ``model_kwargs_uncond``.

    Unset arguments fall back to ``self.config``. For encoder-decoder models the
    encoder is run here unless ``encoder_outputs`` is supplied in
    ``model_kwargs``, and decoding starts from ``decoder_start_token_id``.

    Returns:
        Whatever ``self._greedy_search``, ``self._sample`` or
        ``self._beam_search`` returns.

    Raises:
        ValueError: if an encoder-decoder model has no
            ``decoder_start_token_id``, or if super conditioning is requested
            for sampling but no unconditional kwargs could be built (e.g.
            ``encoder_outputs`` was passed in).
        NotImplementedError: for ``do_sample=True`` with ``num_beams > 1``.
    """
    # set init values
    max_length = max_length if max_length is not None else self.config.max_length
    bos_token_id = bos_token_id if bos_token_id is not None else self.config.bos_token_id
    pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
    eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
    # Compare against None (not truthiness) so an explicit token id of 0 is honored.
    decoder_start_token_id = (
        decoder_start_token_id if decoder_start_token_id is not None else self.config.decoder_start_token_id
    )
    prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)

    if decoder_start_token_id is None and self.config.is_encoder_decoder:
        raise ValueError("`decoder_start_token_id` has to be defined for encoder-decoder generation.")

    do_sample = do_sample if do_sample is not None else self.config.do_sample
    num_beams = num_beams if num_beams is not None else self.config.num_beams

    # Kwargs for the unconditional (super-conditioning) pass. They are only built
    # when the encoder is run below; bind a default so every branch sees the name.
    model_kwargs_uncond = None

    if self.config.is_encoder_decoder:
        # add encoder_outputs to model_kwargs
        if model_kwargs.get("encoder_outputs") is None:
            model_kwargs_input = dict(model_kwargs)
            model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
                input_ids,
                params,
                {"attention_mask": attention_mask, **model_kwargs_input},
            )
            if condition_scale != 1.0:
                assert input_ids_uncond is not None, "`input_ids_uncond` has to be defined for super conditioning."
                assert do_sample is True, "`do_sample` has to be True for super conditioning."
                assert num_beams == 1, "`num_beams` has to be 1 for super conditioning."
                model_kwargs_uncond = self._prepare_encoder_decoder_kwargs_for_generation(
                    input_ids_uncond,
                    params,
                    {
                        "attention_mask": attention_mask_uncond,
                        **model_kwargs_input,
                    },
                )
        # prepare decoder_input_ids for generation
        input_ids = jnp.ones((input_ids.shape[0], 1), dtype="i4") * decoder_start_token_id

    if not do_sample and num_beams == 1:
        logits_processor = self._get_logits_processor(
            no_repeat_ngram_size,
            min_length,
            max_length,
            eos_token_id,
            forced_bos_token_id,
            forced_eos_token_id,
        )
        return self._greedy_search(
            input_ids,
            max_length,
            pad_token_id,
            eos_token_id,
            logits_processor=logits_processor,
            trace=trace,
            params=params,
            model_kwargs=model_kwargs,
        )
    elif do_sample and num_beams == 1:
        if condition_scale != 1.0 and model_kwargs_uncond is None:
            raise ValueError(
                "Super conditioning (`condition_scale != 1.0`) requires `generate` to run the encoder on "
                "`input_ids_uncond`; it is unavailable when `encoder_outputs` are passed in or the model "
                "is not an encoder-decoder."
            )
        try:
            logits_warper = self._get_logits_warper(top_k=top_k, top_p=top_p, temperature=temperature)
            logits_processor = self._get_logits_processor(
                no_repeat_ngram_size,
                min_length,
                max_length,
                eos_token_id,
                forced_bos_token_id,
                forced_eos_token_id,
            )
        except TypeError:
            # Fallback for helper signatures that take a single `GenerationConfig`
            # (presumably newer `transformers` releases). Against those, the
            # keyword/positional calls above fail with TypeError. Other errors propagate.
            logits_warper = self._get_logits_warper(
                generation_config=GenerationConfig(top_k=top_k, top_p=top_p, temperature=temperature)
            )
            logits_processor = self._get_logits_processor(
                generation_config=GenerationConfig(
                    no_repeat_ngram_size=no_repeat_ngram_size,
                    min_length=min_length,
                    max_length=max_length,
                    eos_token_id=eos_token_id,
                    forced_bos_token_id=forced_bos_token_id,
                    forced_eos_token_id=forced_eos_token_id,
                )
            )

        return self._sample(
            input_ids,
            max_length,
            pad_token_id,
            eos_token_id,
            prng_key,
            logits_warper=logits_warper,
            logits_processor=logits_processor,
            trace=trace,
            params=params,
            model_kwargs=model_kwargs,
            condition_scale=condition_scale,
            model_kwargs_uncond=model_kwargs_uncond,
        )
    elif not do_sample and num_beams > 1:
        # broadcast input_ids & encoder_outputs
        input_ids = self._expand_to_num_beams(input_ids, num_beams=num_beams)

        if "encoder_outputs" in model_kwargs:
            model_kwargs["encoder_outputs"]["last_hidden_state"] = self._expand_to_num_beams(
                model_kwargs["encoder_outputs"]["last_hidden_state"],
                num_beams=num_beams,
            )

        if "attention_mask" in model_kwargs:
            model_kwargs["attention_mask"] = self._expand_to_num_beams(
                model_kwargs["attention_mask"], num_beams=num_beams
            )

        logits_processor = self._get_logits_processor(
            no_repeat_ngram_size,
            min_length,
            max_length,
            eos_token_id,
            forced_bos_token_id,
            forced_eos_token_id,
        )

        return self._beam_search(
            input_ids,
            max_length,
            pad_token_id,
            eos_token_id,
            length_penalty=length_penalty,
            early_stopping=early_stopping,
            logits_processor=logits_processor,
            trace=trace,
            params=params,
            model_kwargs=model_kwargs,
        )
    else:
        raise NotImplementedError("`Beam sampling is currently not implemented.")
|
|
1718
|
+
|
|
1719
|
+
def _sample(
|
|
1720
|
+
self,
|
|
1721
|
+
input_ids: None,
|
|
1722
|
+
max_length: Optional[int] = None,
|
|
1723
|
+
pad_token_id: Optional[int] = None,
|
|
1724
|
+
eos_token_id: Optional[int] = None,
|
|
1725
|
+
prng_key: Optional[jnp.ndarray] = None,
|
|
1726
|
+
logits_processor=None,
|
|
1727
|
+
logits_warper=None,
|
|
1728
|
+
trace: bool = True,
|
|
1729
|
+
params: Optional[Dict[str, jnp.ndarray]] = None,
|
|
1730
|
+
model_kwargs: Optional[Dict[str, jnp.ndarray]] = None,
|
|
1731
|
+
condition_scale: float = 1.0,
|
|
1732
|
+
model_kwargs_uncond: Optional[Dict[str, jnp.ndarray]] = None,
|
|
1733
|
+
):
|
|
1734
|
+
# init values
|
|
1735
|
+
max_length = max_length if max_length is not None else self.config.max_length
|
|
1736
|
+
pad_token_id = pad_token_id if pad_token_id is not None else self.config.pad_token_id
|
|
1737
|
+
eos_token_id = eos_token_id if eos_token_id is not None else self.config.eos_token_id
|
|
1738
|
+
prng_key = prng_key if prng_key is not None else jax.random.PRNGKey(0)
|
|
1739
|
+
|
|
1740
|
+
batch_size, cur_len = input_ids.shape
|
|
1741
|
+
|
|
1742
|
+
eos_token_id = jnp.array(eos_token_id)
|
|
1743
|
+
pad_token_id = jnp.array(pad_token_id)
|
|
1744
|
+
cur_len = jnp.array(cur_len)
|
|
1745
|
+
|
|
1746
|
+
# per batch-item holding current token in loop.
|
|
1747
|
+
sequences = jnp.full((batch_size, max_length), pad_token_id, dtype=jnp.int32)
|
|
1748
|
+
sequences = lax.dynamic_update_slice(sequences, input_ids, (0, 0))
|
|
1749
|
+
|
|
1750
|
+
# per batch-item state bit indicating if sentence has finished.
|
|
1751
|
+
is_sent_finished = jnp.zeros((batch_size,), dtype=jnp.bool_)
|
|
1752
|
+
|
|
1753
|
+
# For Seq2Seq generation, we only need to use the decoder instead of the whole model in generation loop
|
|
1754
|
+
# and pass it the `encoder_outputs`, which are part of the `model_kwargs`.
|
|
1755
|
+
model = self.decode if self.config.is_encoder_decoder else self
|
|
1756
|
+
|
|
1757
|
+
# initialize model specific kwargs
|
|
1758
|
+
model_kwargs = self.prepare_inputs_for_generation(input_ids, max_length, **model_kwargs)
|
|
1759
|
+
if condition_scale != 1.0:
|
|
1760
|
+
model_kwargs_uncond = self.prepare_inputs_for_generation(input_ids, max_length, **model_kwargs_uncond)
|
|
1761
|
+
|
|
1762
|
+
# initialize state
|
|
1763
|
+
state = SampleState(
|
|
1764
|
+
cur_len=cur_len,
|
|
1765
|
+
sequences=sequences,
|
|
1766
|
+
running_token=input_ids,
|
|
1767
|
+
is_sent_finished=is_sent_finished,
|
|
1768
|
+
prng_key=prng_key,
|
|
1769
|
+
model_kwargs=model_kwargs,
|
|
1770
|
+
model_kwargs_uncond=model_kwargs_uncond,
|
|
1771
|
+
)
|
|
1772
|
+
|
|
1773
|
+
def sample_search_cond_fn(state):
|
|
1774
|
+
"""state termination condition fn."""
|
|
1775
|
+
has_reached_max_length = state.cur_len == max_length
|
|
1776
|
+
all_sequence_finished = jnp.all(state.is_sent_finished)
|
|
1777
|
+
finish_generation = jnp.logical_or(has_reached_max_length, all_sequence_finished)
|
|
1778
|
+
return ~finish_generation
|
|
1779
|
+
|
|
1780
|
+
def sample_search_body_fn(state):
|
|
1781
|
+
"""state update fn."""
|
|
1782
|
+
prng_key, prng_key_next = jax.random.split(state.prng_key)
|
|
1783
|
+
model_outputs = model(state.running_token, params=params, **state.model_kwargs)
|
|
1784
|
+
|
|
1785
|
+
logits = model_outputs.logits[:, -1]
|
|
1786
|
+
|
|
1787
|
+
# perform super conditioning
|
|
1788
|
+
# Source: @RiversHaveWings - https://twitter.com/RiversHaveWings/status/1478093658716966912?s=20&t=xdm-wZ61Wf7OLnE_NJHZ1w
|
|
1789
|
+
if condition_scale != 1.0:
|
|
1790
|
+
model_outputs_uncond = model(state.running_token, params=params, **state.model_kwargs_uncond)
|
|
1791
|
+
logits_uncond = model_outputs_uncond.logits[:, -1]
|
|
1792
|
+
logits = logits_uncond + condition_scale * (logits - logits_uncond)
|
|
1793
|
+
else:
|
|
1794
|
+
model_outputs_uncond = None
|
|
1795
|
+
|
|
1796
|
+
# apply min_length, ...
|
|
1797
|
+
logits = logits_processor(state.sequences, logits, state.cur_len)
|
|
1798
|
+
# apply top_k, top_k, temperature
|
|
1799
|
+
logits = logits_warper(logits, logits, state.cur_len)
|
|
1800
|
+
|
|
1801
|
+
next_token = jax.random.categorical(prng_key, logits, axis=-1)
|
|
1802
|
+
|
|
1803
|
+
next_is_sent_finished = state.is_sent_finished | (next_token == eos_token_id)
|
|
1804
|
+
next_token = next_token * ~next_is_sent_finished + pad_token_id * next_is_sent_finished
|
|
1805
|
+
next_token = next_token[:, None]
|
|
1806
|
+
|
|
1807
|
+
next_sequences = lax.dynamic_update_slice(state.sequences, next_token, (0, state.cur_len))
|
|
1808
|
+
next_model_kwargs = self.update_inputs_for_generation(model_outputs, state.model_kwargs)
|
|
1809
|
+
next_model_kwargs_uncond = (
|
|
1810
|
+
self.update_inputs_for_generation(model_outputs_uncond, state.model_kwargs_uncond)
|
|
1811
|
+
if condition_scale != 1.0
|
|
1812
|
+
else None
|
|
1813
|
+
)
|
|
1814
|
+
|
|
1815
|
+
return SampleState(
|
|
1816
|
+
cur_len=state.cur_len + 1,
|
|
1817
|
+
sequences=next_sequences,
|
|
1818
|
+
running_token=next_token,
|
|
1819
|
+
is_sent_finished=next_is_sent_finished,
|
|
1820
|
+
model_kwargs=next_model_kwargs,
|
|
1821
|
+
model_kwargs_uncond=next_model_kwargs_uncond,
|
|
1822
|
+
prng_key=prng_key_next,
|
|
1823
|
+
)
|
|
1824
|
+
|
|
1825
|
+
# The very first prompt often has sequence length > 1, so run outside of `lax.while_loop` to comply with TPU
|
|
1826
|
+
if input_ids.shape[1] > 1:
|
|
1827
|
+
state = sample_search_body_fn(state)
|
|
1828
|
+
|
|
1829
|
+
if not trace:
|
|
1830
|
+
state = self._run_loop_in_debug(sample_search_cond_fn, sample_search_body_fn, state)
|
|
1831
|
+
else:
|
|
1832
|
+
state = lax.while_loop(sample_search_cond_fn, sample_search_body_fn, state)
|
|
1833
|
+
|
|
1834
|
+
return FlaxSampleOutput(sequences=state.sequences)
|