crfm-helm 0.4.0__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of crfm-helm might be problematic. Click here for more details.
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.1.dist-info}/METADATA +138 -31
- crfm_helm-0.5.1.dist-info/RECORD +654 -0
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.1.dist-info}/WHEEL +1 -1
- helm/benchmark/adaptation/adapter_spec.py +31 -3
- helm/benchmark/adaptation/adapters/adapter.py +2 -2
- helm/benchmark/adaptation/adapters/adapter_factory.py +24 -27
- helm/benchmark/adaptation/adapters/generation_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/in_context_learning_adapter.py +20 -4
- helm/benchmark/adaptation/adapters/language_modeling_adapter.py +2 -3
- helm/benchmark/adaptation/adapters/multimodal/in_context_learning_multimodal_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/multimodal/multimodal_prompt.py +7 -0
- helm/benchmark/adaptation/adapters/multimodal/multiple_choice_joint_multimodal_adapter.py +104 -0
- helm/benchmark/adaptation/adapters/multimodal/test_in_context_learning_multimodal_adapter.py +2 -1
- helm/benchmark/adaptation/adapters/multimodal/test_multimodal_prompt.py +2 -0
- helm/benchmark/adaptation/adapters/test_adapter.py +2 -1
- helm/benchmark/adaptation/adapters/test_generation_adapter.py +32 -8
- helm/benchmark/adaptation/adapters/test_language_modeling_adapter.py +7 -19
- helm/benchmark/adaptation/adapters/test_multiple_choice_joint_adapter.py +60 -6
- helm/benchmark/adaptation/common_adapter_specs.py +376 -0
- helm/benchmark/adaptation/request_state.py +6 -1
- helm/benchmark/adaptation/scenario_state.py +6 -2
- helm/benchmark/annotation/annotator.py +43 -0
- helm/benchmark/annotation/annotator_factory.py +61 -0
- helm/benchmark/annotation/image2structure/image_compiler_annotator.py +88 -0
- helm/benchmark/annotation/image2structure/latex_compiler_annotator.py +59 -0
- helm/benchmark/annotation/image2structure/lilypond_compiler_annotator.py +84 -0
- helm/benchmark/annotation/image2structure/webpage_compiler_annotator.py +132 -0
- helm/benchmark/annotation/test_annotator_factory.py +26 -0
- helm/benchmark/annotation/test_dummy_annotator.py +44 -0
- helm/benchmark/annotation_executor.py +124 -0
- helm/benchmark/augmentations/data_augmenter.py +0 -2
- helm/benchmark/augmentations/gender_perturbation.py +1 -1
- helm/benchmark/augmentations/perturbation.py +25 -3
- helm/benchmark/augmentations/perturbation_description.py +1 -1
- helm/benchmark/augmentations/suffix_perturbation.py +29 -0
- helm/benchmark/augmentations/test_perturbation.py +41 -7
- helm/benchmark/augmentations/translate_perturbation.py +30 -0
- helm/benchmark/config_registry.py +7 -1
- helm/benchmark/executor.py +46 -16
- helm/benchmark/huggingface_registration.py +20 -7
- helm/benchmark/metrics/basic_metrics.py +169 -664
- helm/benchmark/metrics/bbq_metrics.py +3 -4
- helm/benchmark/metrics/bias_metrics.py +6 -6
- helm/benchmark/metrics/classification_metrics.py +11 -8
- helm/benchmark/metrics/cleva_accuracy_metrics.py +8 -5
- helm/benchmark/metrics/cleva_harms_metrics.py +2 -2
- helm/benchmark/metrics/code_metrics_helper.py +0 -2
- helm/benchmark/metrics/common_metric_specs.py +167 -0
- helm/benchmark/metrics/decodingtrust_fairness_metrics.py +72 -0
- helm/benchmark/metrics/decodingtrust_ood_knowledge_metrics.py +66 -0
- helm/benchmark/metrics/decodingtrust_privacy_metrics.py +101 -0
- helm/benchmark/metrics/decodingtrust_stereotype_bias_metrics.py +202 -0
- helm/benchmark/metrics/disinformation_metrics.py +4 -110
- helm/benchmark/metrics/dry_run_metrics.py +2 -2
- helm/benchmark/metrics/efficiency_metrics.py +213 -0
- helm/benchmark/metrics/evaluate_instances_metric.py +59 -0
- helm/benchmark/metrics/evaluate_reference_metrics.py +392 -0
- helm/benchmark/metrics/image_generation/aesthetics_metrics.py +54 -0
- helm/benchmark/metrics/image_generation/aesthetics_scorer.py +66 -0
- helm/benchmark/metrics/image_generation/clip_score_metrics.py +73 -0
- helm/benchmark/metrics/image_generation/denoised_runtime_metric.py +42 -0
- helm/benchmark/metrics/image_generation/detection_metrics.py +57 -0
- helm/benchmark/metrics/image_generation/detectors/base_detector.py +8 -0
- helm/benchmark/metrics/image_generation/detectors/vitdet.py +178 -0
- helm/benchmark/metrics/image_generation/efficiency_metrics.py +41 -0
- helm/benchmark/metrics/image_generation/fidelity_metrics.py +168 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/fractal_dimension_util.py +63 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/test_fractal_dimension_util.py +33 -0
- helm/benchmark/metrics/image_generation/fractal_dimension_metric.py +50 -0
- helm/benchmark/metrics/image_generation/gender_metrics.py +58 -0
- helm/benchmark/metrics/image_generation/image_critique_metrics.py +284 -0
- helm/benchmark/metrics/image_generation/lpips_metrics.py +82 -0
- helm/benchmark/metrics/image_generation/multi_scale_ssim_metrics.py +82 -0
- helm/benchmark/metrics/image_generation/nsfw_detector.py +96 -0
- helm/benchmark/metrics/image_generation/nsfw_metrics.py +103 -0
- helm/benchmark/metrics/image_generation/nudity_metrics.py +38 -0
- helm/benchmark/metrics/image_generation/photorealism_critique_metrics.py +153 -0
- helm/benchmark/metrics/image_generation/psnr_metrics.py +78 -0
- helm/benchmark/metrics/image_generation/q16/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/q16/q16_toxicity_detector.py +90 -0
- helm/benchmark/metrics/image_generation/q16/test_q16.py +18 -0
- helm/benchmark/metrics/image_generation/q16_toxicity_metrics.py +48 -0
- helm/benchmark/metrics/image_generation/skin_tone_metrics.py +164 -0
- helm/benchmark/metrics/image_generation/uiqi_metrics.py +92 -0
- helm/benchmark/metrics/image_generation/watermark/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/watermark/test_watermark_detector.py +16 -0
- helm/benchmark/metrics/image_generation/watermark/watermark_detector.py +87 -0
- helm/benchmark/metrics/image_generation/watermark_metrics.py +48 -0
- helm/benchmark/metrics/instruction_following_critique_metrics.py +3 -1
- helm/benchmark/metrics/language_modeling_metrics.py +99 -0
- helm/benchmark/metrics/machine_translation_metrics.py +89 -0
- helm/benchmark/metrics/metric.py +93 -172
- helm/benchmark/metrics/metric_name.py +0 -1
- helm/benchmark/metrics/metric_service.py +16 -0
- helm/benchmark/metrics/paraphrase_generation_metrics.py +3 -4
- helm/benchmark/metrics/ranking_metrics.py +2 -2
- helm/benchmark/metrics/reference_metric.py +148 -0
- helm/benchmark/metrics/summac/model_summac.py +0 -2
- helm/benchmark/metrics/summarization_metrics.py +2 -2
- helm/benchmark/metrics/test_classification_metrics.py +8 -5
- helm/benchmark/metrics/test_disinformation_metrics.py +78 -0
- helm/benchmark/metrics/{test_basic_metrics.py → test_evaluate_reference_metrics.py} +5 -1
- helm/benchmark/metrics/test_metric.py +2 -2
- helm/benchmark/metrics/tokens/gooseai_token_cost_estimator.py +10 -2
- helm/benchmark/metrics/toxicity_metrics.py +1 -1
- helm/benchmark/metrics/toxicity_utils.py +23 -0
- helm/benchmark/metrics/unitxt_metrics.py +81 -0
- helm/benchmark/metrics/vision_language/__init__.py +0 -0
- helm/benchmark/metrics/vision_language/emd_utils.py +341 -0
- helm/benchmark/metrics/vision_language/image_metrics.py +575 -0
- helm/benchmark/metrics/vision_language/image_utils.py +100 -0
- helm/benchmark/model_deployment_registry.py +74 -0
- helm/benchmark/model_metadata_registry.py +41 -1
- helm/benchmark/multi_gpu_runner.py +133 -0
- helm/benchmark/presentation/create_plots.py +8 -7
- helm/benchmark/presentation/run_display.py +26 -10
- helm/benchmark/presentation/schema.py +15 -40
- helm/benchmark/presentation/summarize.py +119 -79
- helm/benchmark/presentation/table.py +8 -8
- helm/benchmark/presentation/test_contamination.py +2 -2
- helm/benchmark/presentation/test_run_entry.py +1 -2
- helm/benchmark/presentation/test_summarize.py +3 -3
- helm/benchmark/run.py +54 -26
- helm/benchmark/run_expander.py +205 -35
- helm/benchmark/run_spec.py +93 -0
- helm/benchmark/run_spec_factory.py +163 -0
- helm/benchmark/run_specs/__init__.py +0 -0
- helm/benchmark/run_specs/classic_run_specs.py +1510 -0
- helm/benchmark/run_specs/cleva_run_specs.py +277 -0
- helm/benchmark/run_specs/decodingtrust_run_specs.py +314 -0
- helm/benchmark/run_specs/heim_run_specs.py +623 -0
- helm/benchmark/run_specs/instruction_following_run_specs.py +129 -0
- helm/benchmark/run_specs/lite_run_specs.py +307 -0
- helm/benchmark/run_specs/simple_run_specs.py +104 -0
- helm/benchmark/run_specs/unitxt_run_specs.py +42 -0
- helm/benchmark/run_specs/vlm_run_specs.py +757 -0
- helm/benchmark/runner.py +51 -57
- helm/benchmark/runner_config_registry.py +21 -0
- helm/benchmark/scenarios/bbq_scenario.py +1 -1
- helm/benchmark/scenarios/bold_scenario.py +2 -2
- helm/benchmark/scenarios/code_scenario.py +1 -0
- helm/benchmark/scenarios/decodingtrust_adv_demonstration_scenario.py +169 -0
- helm/benchmark/scenarios/decodingtrust_adv_robustness_scenario.py +121 -0
- helm/benchmark/scenarios/decodingtrust_fairness_scenario.py +77 -0
- helm/benchmark/scenarios/decodingtrust_machine_ethics_scenario.py +324 -0
- helm/benchmark/scenarios/decodingtrust_ood_robustness_scenario.py +204 -0
- helm/benchmark/scenarios/decodingtrust_privacy_scenario.py +559 -0
- helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +67 -0
- helm/benchmark/scenarios/decodingtrust_toxicity_prompts_scenario.py +78 -0
- helm/benchmark/scenarios/dialogue_scenarios.py +0 -1
- helm/benchmark/scenarios/image_generation/__init__.py +0 -0
- helm/benchmark/scenarios/image_generation/common_syntactic_processes_scenario.py +105 -0
- helm/benchmark/scenarios/image_generation/cub200_scenario.py +95 -0
- helm/benchmark/scenarios/image_generation/daily_dalle_scenario.py +124 -0
- helm/benchmark/scenarios/image_generation/demographic_stereotypes_scenario.py +82 -0
- helm/benchmark/scenarios/image_generation/detection_scenario.py +83 -0
- helm/benchmark/scenarios/image_generation/draw_bench_scenario.py +74 -0
- helm/benchmark/scenarios/image_generation/i2p_scenario.py +57 -0
- helm/benchmark/scenarios/image_generation/landing_page_scenario.py +46 -0
- helm/benchmark/scenarios/image_generation/logos_scenario.py +223 -0
- helm/benchmark/scenarios/image_generation/magazine_cover_scenario.py +91 -0
- helm/benchmark/scenarios/image_generation/mental_disorders_scenario.py +46 -0
- helm/benchmark/scenarios/image_generation/mscoco_scenario.py +91 -0
- helm/benchmark/scenarios/image_generation/paint_skills_scenario.py +72 -0
- helm/benchmark/scenarios/image_generation/parti_prompts_scenario.py +94 -0
- helm/benchmark/scenarios/image_generation/radiology_scenario.py +42 -0
- helm/benchmark/scenarios/image_generation/relational_understanding_scenario.py +52 -0
- helm/benchmark/scenarios/image_generation/time_most_significant_historical_figures_scenario.py +124 -0
- helm/benchmark/scenarios/image_generation/winoground_scenario.py +62 -0
- helm/benchmark/scenarios/imdb_scenario.py +0 -1
- helm/benchmark/scenarios/legalbench_scenario.py +6 -2
- helm/benchmark/scenarios/live_qa_scenario.py +94 -0
- helm/benchmark/scenarios/lm_entry_scenario.py +185 -0
- helm/benchmark/scenarios/math_scenario.py +19 -2
- helm/benchmark/scenarios/medication_qa_scenario.py +60 -0
- helm/benchmark/scenarios/numeracy_scenario.py +1 -1
- helm/benchmark/scenarios/opinions_qa_scenario.py +0 -4
- helm/benchmark/scenarios/scenario.py +4 -0
- helm/benchmark/scenarios/simple_scenarios.py +122 -1
- helm/benchmark/scenarios/test_math_scenario.py +6 -0
- helm/benchmark/scenarios/test_scenario.py +6 -3
- helm/benchmark/scenarios/test_simple_scenarios.py +50 -0
- helm/benchmark/scenarios/thai_exam_scenario.py +135 -0
- helm/benchmark/scenarios/unitxt_scenario.py +56 -0
- helm/benchmark/scenarios/verifiability_judgment_scenario.py +3 -1
- helm/benchmark/scenarios/vicuna_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/a_okvqa_scenario.py +83 -0
- helm/benchmark/scenarios/vision_language/bingo_scenario.py +103 -0
- helm/benchmark/scenarios/vision_language/crossmodal_3600_scenario.py +134 -0
- helm/benchmark/scenarios/vision_language/flickr30k_scenario.py +74 -0
- helm/benchmark/scenarios/vision_language/gqa_scenario.py +91 -0
- helm/benchmark/scenarios/vision_language/hateful_memes_scenario.py +94 -0
- helm/benchmark/scenarios/vision_language/heim_human_eval_scenario.py +113 -0
- helm/benchmark/scenarios/vision_language/image2structure/__init__.py +0 -0
- helm/benchmark/scenarios/vision_language/image2structure/chart2csv_scenario.py +55 -0
- helm/benchmark/scenarios/vision_language/image2structure/image2structure_scenario.py +214 -0
- helm/benchmark/scenarios/vision_language/image2structure/latex_scenario.py +25 -0
- helm/benchmark/scenarios/vision_language/image2structure/musicsheet_scenario.py +20 -0
- helm/benchmark/scenarios/vision_language/image2structure/utils_latex.py +347 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/__init__.py +0 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/driver.py +84 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/jekyll_server.py +182 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/utils.py +31 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage_scenario.py +225 -0
- helm/benchmark/scenarios/vision_language/math_vista_scenario.py +117 -0
- helm/benchmark/scenarios/vision_language/mementos_scenario.py +124 -0
- helm/benchmark/scenarios/vision_language/mm_safety_bench_scenario.py +103 -0
- helm/benchmark/scenarios/vision_language/mme_scenario.py +145 -0
- helm/benchmark/scenarios/vision_language/mmmu_scenario.py +187 -0
- helm/benchmark/scenarios/vision_language/mscoco_captioning_scenario.py +92 -0
- helm/benchmark/scenarios/vision_language/mscoco_categorization_scenario.py +117 -0
- helm/benchmark/scenarios/vision_language/multipanelvqa_scenario.py +169 -0
- helm/benchmark/scenarios/vision_language/originality_scenario.py +35 -0
- helm/benchmark/scenarios/vision_language/pairs_scenario.py +246 -0
- helm/benchmark/scenarios/vision_language/pope_scenario.py +104 -0
- helm/benchmark/scenarios/vision_language/seed_bench_scenario.py +129 -0
- helm/benchmark/scenarios/vision_language/unicorn_scenario.py +108 -0
- helm/benchmark/scenarios/vision_language/viz_wiz_scenario.py +3 -4
- helm/benchmark/scenarios/vision_language/vqa_scenario.py +5 -3
- helm/benchmark/scenarios/wmt_14_scenario.py +1 -1
- helm/benchmark/server.py +24 -1
- helm/benchmark/slurm_runner.py +70 -49
- helm/benchmark/static/benchmarking.js +1 -1
- helm/benchmark/static/schema_classic.yaml +258 -1066
- helm/benchmark/static/schema_image2structure.yaml +304 -0
- helm/benchmark/static/schema_instruction_following.yaml +210 -0
- helm/benchmark/static/schema_lite.yaml +2 -227
- helm/benchmark/static/schema_mmlu.yaml +1507 -0
- helm/benchmark/static/schema_unitxt.yaml +428 -0
- helm/benchmark/static/schema_vhelm_lite.yaml +164 -0
- helm/benchmark/static/schema_vlm.yaml +823 -0
- helm/benchmark/static_build/assets/01-694cb9b7.png +0 -0
- helm/benchmark/static_build/assets/ai21-0eb91ec3.png +0 -0
- helm/benchmark/static_build/assets/aleph-alpha-7ce10034.png +0 -0
- helm/benchmark/static_build/assets/anthropic-70d8bc39.png +0 -0
- helm/benchmark/static_build/assets/bigscience-7f0400c0.png +0 -0
- helm/benchmark/static_build/assets/cohere-3550c6cb.png +0 -0
- helm/benchmark/static_build/assets/crfm-logo-74391ab8.png +0 -0
- helm/benchmark/static_build/assets/eleutherai-b9451114.png +0 -0
- helm/benchmark/static_build/assets/google-06d997ad.png +0 -0
- helm/benchmark/static_build/assets/heim-logo-3e5e3aa4.png +0 -0
- helm/benchmark/static_build/assets/helm-logo-simple-2ed5400b.png +0 -0
- helm/benchmark/static_build/assets/helmhero-28e90f4d.png +0 -0
- helm/benchmark/static_build/assets/index-737eef9e.js +10 -0
- helm/benchmark/static_build/assets/index-878a1094.css +1 -0
- helm/benchmark/static_build/assets/meta-5580e9f1.png +0 -0
- helm/benchmark/static_build/assets/microsoft-f5ee5016.png +0 -0
- helm/benchmark/static_build/assets/mistral-18e1be23.png +0 -0
- helm/benchmark/static_build/assets/nvidia-86fa75c1.png +0 -0
- helm/benchmark/static_build/assets/openai-3f8653e4.png +0 -0
- helm/benchmark/static_build/assets/react-d4a0b69b.js +85 -0
- helm/benchmark/static_build/assets/recharts-6d337683.js +97 -0
- helm/benchmark/static_build/assets/tii-24de195c.png +0 -0
- helm/benchmark/static_build/assets/together-a665a35b.png +0 -0
- helm/benchmark/static_build/assets/tremor-54a99cc4.js +10 -0
- helm/benchmark/static_build/assets/tsinghua-keg-97d4b395.png +0 -0
- helm/benchmark/static_build/assets/vhelm-framework-cde7618a.png +0 -0
- helm/benchmark/static_build/assets/vhelm-model-6d812526.png +0 -0
- helm/benchmark/static_build/assets/yandex-38e09d70.png +0 -0
- helm/benchmark/static_build/config.js +4 -0
- helm/benchmark/static_build/index.html +20 -0
- helm/benchmark/test_data_preprocessor.py +3 -3
- helm/benchmark/test_run_expander.py +1 -1
- helm/benchmark/window_services/ai21_window_service.py +22 -33
- helm/benchmark/window_services/cohere_window_service.py +1 -63
- helm/benchmark/window_services/default_window_service.py +2 -44
- helm/benchmark/window_services/encoder_decoder_window_service.py +0 -11
- helm/benchmark/window_services/ice_window_service.py +0 -34
- helm/benchmark/window_services/image_generation/__init__.py +0 -0
- helm/benchmark/window_services/image_generation/clip_window_service.py +15 -0
- helm/benchmark/window_services/image_generation/lexica_search_window_service.py +9 -0
- helm/benchmark/window_services/image_generation/openai_dalle_window_service.py +9 -0
- helm/benchmark/window_services/image_generation/test_clip_window_service.py +29 -0
- helm/benchmark/window_services/image_generation/test_openai_dalle_window_service.py +30 -0
- helm/benchmark/window_services/local_window_service.py +21 -4
- helm/benchmark/window_services/test_anthropic_window_service.py +2 -1
- helm/benchmark/window_services/test_bloom_window_service.py +2 -1
- helm/benchmark/window_services/test_cohere_window_service.py +2 -1
- helm/benchmark/window_services/test_flan_t5_window_service.py +2 -1
- helm/benchmark/window_services/test_gpt2_window_service.py +2 -2
- helm/benchmark/window_services/test_gpt4_window_service.py +2 -1
- helm/benchmark/window_services/test_gptj_window_service.py +3 -2
- helm/benchmark/window_services/test_gptneox_window_service.py +3 -2
- helm/benchmark/window_services/test_ice_window_service.py +2 -1
- helm/benchmark/window_services/test_openai_window_service.py +2 -1
- helm/benchmark/window_services/test_opt_window_service.py +3 -2
- helm/benchmark/window_services/test_palmyra_window_service.py +2 -1
- helm/benchmark/window_services/test_t0pp_window_service.py +2 -1
- helm/benchmark/window_services/test_t511b_window_service.py +2 -1
- helm/benchmark/window_services/test_ul2_window_service.py +2 -1
- helm/benchmark/window_services/test_utils.py +3 -2
- helm/benchmark/window_services/test_yalm_window_service.py +2 -1
- helm/benchmark/window_services/window_service.py +42 -0
- helm/benchmark/window_services/window_service_factory.py +4 -1
- helm/benchmark/window_services/yalm_window_service.py +0 -27
- helm/clients/__init__.py +0 -0
- helm/{proxy/clients → clients}/ai21_client.py +3 -9
- helm/clients/aleph_alpha_client.py +112 -0
- helm/{proxy/clients → clients}/anthropic_client.py +233 -18
- helm/{proxy/clients → clients}/auto_client.py +59 -31
- helm/clients/bedrock_client.py +128 -0
- helm/clients/bedrock_utils.py +72 -0
- helm/{proxy/clients → clients}/client.py +65 -7
- helm/clients/clip_score_client.py +49 -0
- helm/clients/clip_scorers/__init__.py +0 -0
- helm/clients/clip_scorers/base_clip_scorer.py +18 -0
- helm/clients/clip_scorers/clip_scorer.py +50 -0
- helm/clients/clip_scorers/multilingual_clip_scorer.py +50 -0
- helm/{proxy/clients → clients}/cohere_client.py +4 -11
- helm/clients/gcs_client.py +82 -0
- helm/{proxy/clients → clients}/google_client.py +5 -5
- helm/clients/google_translate_client.py +35 -0
- helm/{proxy/clients → clients}/http_model_client.py +5 -7
- helm/{proxy/clients → clients}/huggingface_client.py +43 -64
- helm/clients/image_generation/__init__.py +0 -0
- helm/clients/image_generation/adobe_vision_client.py +78 -0
- helm/clients/image_generation/aleph_alpha_image_generation_client.py +98 -0
- helm/clients/image_generation/cogview2/__init__.py +0 -0
- helm/clients/image_generation/cogview2/coglm_strategy.py +96 -0
- helm/clients/image_generation/cogview2/coglm_utils.py +82 -0
- helm/clients/image_generation/cogview2/sr_pipeline/__init__.py +15 -0
- helm/clients/image_generation/cogview2/sr_pipeline/direct_sr.py +96 -0
- helm/clients/image_generation/cogview2/sr_pipeline/dsr_model.py +254 -0
- helm/clients/image_generation/cogview2/sr_pipeline/dsr_sampling.py +190 -0
- helm/clients/image_generation/cogview2/sr_pipeline/iterative_sr.py +141 -0
- helm/clients/image_generation/cogview2/sr_pipeline/itersr_model.py +269 -0
- helm/clients/image_generation/cogview2/sr_pipeline/itersr_sampling.py +120 -0
- helm/clients/image_generation/cogview2/sr_pipeline/sr_group.py +42 -0
- helm/clients/image_generation/cogview2_client.py +191 -0
- helm/clients/image_generation/dalle2_client.py +192 -0
- helm/clients/image_generation/dalle3_client.py +108 -0
- helm/clients/image_generation/dalle_mini/__init__.py +3 -0
- helm/clients/image_generation/dalle_mini/data.py +442 -0
- helm/clients/image_generation/dalle_mini/model/__init__.py +5 -0
- helm/clients/image_generation/dalle_mini/model/configuration.py +175 -0
- helm/clients/image_generation/dalle_mini/model/modeling.py +1834 -0
- helm/clients/image_generation/dalle_mini/model/partitions.py +84 -0
- helm/clients/image_generation/dalle_mini/model/processor.py +63 -0
- helm/clients/image_generation/dalle_mini/model/text.py +251 -0
- helm/clients/image_generation/dalle_mini/model/tokenizer.py +9 -0
- helm/clients/image_generation/dalle_mini/model/utils.py +29 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/__init__.py +1 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/configuration_vqgan.py +40 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/convert_pt_model_to_jax.py +107 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/modeling_flax_vqgan.py +610 -0
- helm/clients/image_generation/dalle_mini_client.py +190 -0
- helm/clients/image_generation/deep_floyd_client.py +78 -0
- helm/clients/image_generation/huggingface_diffusers_client.py +249 -0
- helm/clients/image_generation/image_generation_client_utils.py +9 -0
- helm/clients/image_generation/lexica_client.py +86 -0
- helm/clients/image_generation/mindalle/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/__init__.py +216 -0
- helm/clients/image_generation/mindalle/models/stage1/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/stage1/layers.py +312 -0
- helm/clients/image_generation/mindalle/models/stage1/vqgan.py +103 -0
- helm/clients/image_generation/mindalle/models/stage2/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/stage2/layers.py +144 -0
- helm/clients/image_generation/mindalle/models/stage2/transformer.py +268 -0
- helm/clients/image_generation/mindalle/models/tokenizer.py +30 -0
- helm/clients/image_generation/mindalle/utils/__init__.py +3 -0
- helm/clients/image_generation/mindalle/utils/config.py +129 -0
- helm/clients/image_generation/mindalle/utils/sampling.py +149 -0
- helm/clients/image_generation/mindalle/utils/utils.py +89 -0
- helm/clients/image_generation/mindalle_client.py +115 -0
- helm/clients/image_generation/nudity_check_client.py +64 -0
- helm/clients/image_generation/together_image_generation_client.py +111 -0
- helm/{proxy/clients → clients}/lit_gpt_client.py +4 -4
- helm/{proxy/clients → clients}/megatron_client.py +5 -5
- helm/clients/mistral_client.py +134 -0
- helm/clients/moderation_api_client.py +109 -0
- helm/clients/open_lm_client.py +43 -0
- helm/clients/openai_client.py +301 -0
- helm/{proxy/clients → clients}/palmyra_client.py +6 -8
- helm/{proxy/clients → clients}/perspective_api_client.py +7 -8
- helm/clients/simple_client.py +64 -0
- helm/{proxy/clients → clients}/test_auto_client.py +13 -15
- helm/clients/test_client.py +100 -0
- helm/{proxy/clients → clients}/test_huggingface_client.py +15 -16
- helm/clients/test_simple_client.py +19 -0
- helm/{proxy/clients → clients}/test_together_client.py +20 -8
- helm/{proxy/clients → clients}/together_client.py +104 -73
- helm/clients/vertexai_client.py +400 -0
- helm/clients/vision_language/__init__.py +0 -0
- helm/clients/vision_language/huggingface_vision2seq_client.py +145 -0
- helm/clients/vision_language/huggingface_vlm_client.py +111 -0
- helm/{proxy/clients → clients}/vision_language/idefics_client.py +54 -49
- helm/clients/vision_language/open_flamingo/__init__.py +2 -0
- helm/clients/vision_language/open_flamingo/src/__init__.py +0 -0
- helm/clients/vision_language/open_flamingo/src/factory.py +147 -0
- helm/clients/vision_language/open_flamingo/src/flamingo.py +337 -0
- helm/clients/vision_language/open_flamingo/src/flamingo_lm.py +155 -0
- helm/clients/vision_language/open_flamingo/src/helpers.py +267 -0
- helm/clients/vision_language/open_flamingo/src/utils.py +47 -0
- helm/clients/vision_language/open_flamingo_client.py +155 -0
- helm/clients/vision_language/qwen_vlm_client.py +171 -0
- helm/clients/vllm_client.py +46 -0
- helm/common/cache.py +16 -4
- helm/common/cache_backend_config.py +47 -0
- helm/common/clip_score_request.py +41 -0
- helm/common/file_caches/__init__.py +0 -0
- helm/common/file_caches/file_cache.py +16 -0
- helm/common/file_caches/local_file_cache.py +61 -0
- helm/common/file_caches/test_local_file_cache.py +25 -0
- helm/common/file_upload_request.py +27 -0
- helm/common/general.py +1 -1
- helm/common/image_generation_parameters.py +25 -0
- helm/common/images_utils.py +33 -3
- helm/common/key_value_store.py +35 -4
- helm/common/media_object.py +13 -0
- helm/common/moderations_api_request.py +71 -0
- helm/common/mongo_key_value_store.py +3 -3
- helm/common/multimodal_request_utils.py +31 -0
- helm/common/nudity_check_request.py +29 -0
- helm/common/request.py +15 -17
- helm/common/test_general.py +6 -0
- helm/common/tokenization_request.py +1 -1
- helm/config/model_deployments.yaml +1159 -538
- helm/config/model_metadata.yaml +868 -41
- helm/config/tokenizer_configs.yaml +149 -43
- helm/proxy/accounts.py +31 -4
- helm/proxy/critique/mechanical_turk_critique_importer.py +3 -0
- helm/proxy/critique/model_critique_client.py +8 -6
- helm/proxy/example_queries.py +29 -17
- helm/proxy/server.py +70 -5
- helm/proxy/services/remote_service.py +31 -0
- helm/proxy/services/server_service.py +96 -16
- helm/proxy/services/service.py +30 -0
- helm/proxy/services/test_remote_service.py +4 -3
- helm/proxy/services/test_service.py +0 -12
- helm/proxy/test_accounts.py +32 -0
- helm/proxy/token_counters/auto_token_counter.py +37 -37
- helm/proxy/token_counters/test_auto_token_counter.py +164 -0
- helm/proxy/token_counters/token_counter.py +3 -5
- helm/tokenizers/__init__.py +0 -0
- helm/{proxy/tokenizers → tokenizers}/ai21_tokenizer.py +3 -3
- helm/{proxy/tokenizers → tokenizers}/anthropic_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/auto_tokenizer.py +6 -9
- helm/{proxy/tokenizers → tokenizers}/cohere_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/http_model_tokenizer.py +3 -3
- helm/{proxy/tokenizers → tokenizers}/huggingface_tokenizer.py +7 -26
- helm/tokenizers/simple_tokenizer.py +33 -0
- helm/{proxy/tokenizers → tokenizers}/test_anthropic_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/test_huggingface_tokenizer.py +3 -0
- helm/tokenizers/test_simple_tokenizer.py +33 -0
- helm/{proxy/tokenizers → tokenizers}/vertexai_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer.py +5 -3
- helm/tokenizers/yalm_tokenizer_data/__init__.py +0 -0
- helm/tokenizers/yalm_tokenizer_data/voc_100b.sp +0 -0
- helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/yalm_tokenizer.py +1 -1
- crfm_helm-0.4.0.dist-info/RECORD +0 -397
- helm/benchmark/run_specs.py +0 -2762
- helm/benchmark/test_model_deployment_definition.py +0 -92
- helm/benchmark/test_model_properties.py +0 -1570
- helm/benchmark/vlm_run_specs.py +0 -97
- helm/benchmark/window_services/flan_t5_window_service.py +0 -29
- helm/benchmark/window_services/gpt2_window_service.py +0 -32
- helm/benchmark/window_services/huggingface_window_service.py +0 -60
- helm/benchmark/window_services/t0pp_window_service.py +0 -35
- helm/benchmark/window_services/t511b_window_service.py +0 -30
- helm/benchmark/window_services/test_mt_nlg_window_service.py +0 -48
- helm/benchmark/window_services/ul2_window_service.py +0 -30
- helm/benchmark/window_services/wider_ai21_window_service.py +0 -24
- helm/common/cache_utils.py +0 -14
- helm/proxy/clients/aleph_alpha_client.py +0 -95
- helm/proxy/clients/goose_ai_client.py +0 -99
- helm/proxy/clients/microsoft_client.py +0 -180
- helm/proxy/clients/openai_client.py +0 -206
- helm/proxy/clients/simple_client.py +0 -60
- helm/proxy/clients/test_client.py +0 -49
- helm/proxy/clients/vertexai_client.py +0 -115
- helm/proxy/token_counters/ai21_token_counter.py +0 -20
- helm/proxy/token_counters/cohere_token_counter.py +0 -13
- helm/proxy/token_counters/free_token_counter.py +0 -12
- helm/proxy/token_counters/gooseai_token_counter.py +0 -24
- helm/proxy/token_counters/openai_token_counter.py +0 -22
- helm/proxy/token_counters/test_ai21_token_counter.py +0 -88
- helm/proxy/token_counters/test_openai_token_counter.py +0 -81
- helm/proxy/tokenizers/simple_tokenizer.py +0 -32
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.1.dist-info}/LICENSE +0 -0
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.1.dist-info}/entry_points.txt +0 -0
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.1.dist-info}/top_level.txt +0 -0
- /helm/{proxy/clients → benchmark/annotation}/__init__.py +0 -0
- /helm/{proxy/clients/vision_language → benchmark/annotation/image2structure}/__init__.py +0 -0
- /helm/{proxy/tokenizers → benchmark/metrics/image_generation}/__init__.py +0 -0
- /helm/{proxy/tokenizers/yalm_tokenizer_data → benchmark/metrics/image_generation/detectors}/__init__.py +0 -0
- /helm/{proxy/clients → clients}/ai21_utils.py +0 -0
- /helm/{proxy/clients → clients}/cohere_utils.py +0 -0
- /helm/{proxy/clients → clients}/lit_gpt_generate.py +0 -0
- /helm/{proxy/clients → clients}/toxicity_classifier_client.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/aleph_alpha_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/caching_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/ice_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/lit_gpt_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/test_ice_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/test_yalm_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/tiktoken_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/test_yalm_tokenizer.py +0 -0
|
@@ -0,0 +1,84 @@
|
|
|
1
|
+
import re
|
|
2
|
+
|
|
3
|
+
from helm.common.optional_dependencies import handle_module_not_found_error
|
|
4
|
+
|
|
5
|
+
try:
|
|
6
|
+
from flax.core.frozen_dict import freeze
|
|
7
|
+
from flax.traverse_util import flatten_dict, unflatten_dict
|
|
8
|
+
from jax.experimental import PartitionSpec as P
|
|
9
|
+
except ModuleNotFoundError as e:
|
|
10
|
+
handle_module_not_found_error(e, ["heim"])
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
# utils adapted from https://github.com/google-research/google-research/blob/master/flax_models/t5x/partitions.py
|
|
14
|
+
# Sentinels
|
|
15
|
+
_unmatched = object()
|
|
16
|
+
|
|
17
|
+
# For specifying empty leaf dict `{}`
|
|
18
|
+
empty_dict = object()
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def _match(qs, ks):
|
|
22
|
+
"""Return True if regexes in qs match any window of strings in tuple ks."""
|
|
23
|
+
# compile regexes and force complete match
|
|
24
|
+
qts = tuple(map(lambda x: re.compile(x + "$"), qs))
|
|
25
|
+
for i in range(len(ks) - len(qs) + 1):
|
|
26
|
+
matches = [x.match(y) for x, y in zip(qts, ks[i:])]
|
|
27
|
+
if matches and all(matches):
|
|
28
|
+
return True
|
|
29
|
+
return False
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
def _replacement_rules(rules):
|
|
33
|
+
def replace(key, val):
|
|
34
|
+
for rule, replacement in rules:
|
|
35
|
+
if _match(rule, key):
|
|
36
|
+
return replacement
|
|
37
|
+
return val
|
|
38
|
+
|
|
39
|
+
return replace
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _get_partition_rules():
    """Return the ordered list of (key-pattern tuple, partition spec) rules.

    Each rule pairs a tuple of regexes (matched against consecutive parts of a
    flattened parameter path) with either a ``P(...)`` spec naming the "mp"
    axis on one dimension, or ``None``. Order is preserved because the
    consumer in this file (``_replacement_rules``) stops at the first match.
    """
    embeddings = [
        (("embed_positions", "embedding"), P("mp", None)),
        (("embed_tokens", "embedding"), P("mp", None)),
        (("rel_bias", "embedding"), P(None, "mp")),
    ]
    attention = [
        (("(q_proj|k_proj|v_proj)", "kernel"), P(None, "mp")),
        (("out_proj", "kernel"), P("mp", None)),
    ]
    feed_forward = [
        (("Dense_0", "kernel"), P(None, "mp")),
        (("GLU.*", "Dense_1", "kernel"), P(None, "mp")),
        (("GLU.*", "Dense_2", "kernel"), P("mp", None)),
        (("FFN.*", "Dense_1", "kernel"), P("mp", None)),
    ]
    layer_norms_and_head = [
        (("(bias|scale)",), None),
        (("lm_head", "kernel"), P(None, "mp")),
    ]
    head_scale_and_tau = [
        (("(head_scale|tau)",), None),
    ]
    return embeddings + attention + feed_forward + layer_norms_and_head + head_scale_and_tau
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def set_partitions(in_dict, use_scan):
    """Map every leaf of a nested parameter dict to its partition spec.

    Args:
        in_dict: nested dict of parameters; it is flattened with
            ``flatten_dict`` and only its key paths are used.
        use_scan: when true, specs of keys containing
            "FlaxBartEncoderLayers" or "FlaxBartDecoderLayers" get an extra
            leading ``None`` dimension.

    Returns:
        A frozen nested dict with the same key structure whose leaves are the
        specs chosen by ``_get_partition_rules`` (a ``P(...)`` or ``None``).

    Raises:
        AssertionError: if any key path is matched by no rule. Each such key
            is printed first.
    """
    replace = _replacement_rules(_get_partition_rules())
    result = {k: replace(k, _unmatched) for k in flatten_dict(in_dict)}

    # Validate before the use_scan rewrite: prepending ``(None,)`` to the
    # sentinel would raise an unrelated TypeError and hide the real problem.
    # Identity comparison, because the sentinel is a unique object().
    unmatched_keys = [k for k, v in result.items() if v is _unmatched]
    for k in unmatched_keys:
        print(f"Unmatched -> {k}")
    if unmatched_keys:
        # Raised explicitly (same exception type as the former ``assert``) so
        # the check is not stripped under ``python -O``.
        raise AssertionError("Incomplete partition spec.")

    if use_scan:
        # add None dimension to layers
        scanned_layers = ("FlaxBartEncoderLayers", "FlaxBartDecoderLayers")
        result = {
            k: (P(*(None,) + v) if v is not None and any(name in k for name in scanned_layers) else v)
            for k, v in result.items()
        }
    return freeze(unflatten_dict(result))
|
|
@@ -0,0 +1,63 @@
|
|
|
1
|
+
""" DalleBart processor """
|
|
2
|
+
|
|
3
|
+
from typing import List
|
|
4
|
+
|
|
5
|
+
from .configuration import DalleBartConfig
|
|
6
|
+
from .text import TextNormalizer
|
|
7
|
+
from .tokenizer import DalleBartTokenizer
|
|
8
|
+
from .utils import PretrainedFromWandbMixin
|
|
9
|
+
from helm.common.optional_dependencies import handle_module_not_found_error
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class DalleBartProcessorBase:
    """Tokenize (and optionally normalize) text prompts for DalleBart.

    At construction the empty prompt is tokenized once; those "unconditional"
    ids and mask are repeated per prompt in ``__call__`` and returned next to
    the tokenizer output.
    """

    def __init__(self, tokenizer: DalleBartTokenizer, normalize_text: bool, max_text_length: int):
        self.tokenizer = tokenizer
        self.normalize_text = normalize_text
        self.max_text_length = max_text_length
        if normalize_text:
            self.text_processor = TextNormalizer()

        # create unconditional tokens
        empty_prompt = self.tokenizer(
            "",
            return_tensors="jax",
            padding="max_length",
            truncation=True,
            max_length=self.max_text_length,
        )
        uncond = empty_prompt.data
        self.input_ids_uncond = uncond["input_ids"]
        self.attention_mask_uncond = uncond["attention_mask"]

    def __call__(self, text: List[str] = None):
        """Tokenize a list of prompts and attach the unconditional tokens.

        Returns the tokenizer's ``.data`` mapping with two extra entries,
        "input_ids_uncond" and "attention_mask_uncond", each repeated
        ``len(text)`` times along axis 0.
        """
        try:
            import jax.numpy as jnp
        except ModuleNotFoundError as e:
            handle_module_not_found_error(e, ["heim"])

        # check that text is not a string
        assert not isinstance(text, str), "text must be a list of strings"

        if self.normalize_text:
            text = [self.text_processor(prompt) for prompt in text]

        encoded = self.tokenizer(
            text,
            return_tensors="jax",
            padding="max_length",
            truncation=True,
            max_length=self.max_text_length,
        )
        res = encoded.data

        # tokens used only with super conditioning
        batch_size = len(text)
        for name, value in (
            ("input_ids_uncond", self.input_ids_uncond),
            ("attention_mask_uncond", self.attention_mask_uncond),
        ):
            res[name] = jnp.repeat(value, batch_size, axis=0)
        return res

    @classmethod
    def from_pretrained(cls, *args, **kwargs):
        """Build a processor from the tokenizer and config loaded with the same arguments."""
        tokenizer = DalleBartTokenizer.from_pretrained(*args, **kwargs)
        config = DalleBartConfig.from_pretrained(*args, **kwargs)
        return cls(tokenizer, config.normalize_text, config.max_text_length)
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
class DalleBartProcessor(PretrainedFromWandbMixin, DalleBartProcessorBase):
    """DalleBartProcessorBase combined with PretrainedFromWandbMixin; adds no members of its own."""
|
|
@@ -0,0 +1,251 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Utilities for processing text.
|
|
3
|
+
"""
|
|
4
|
+
|
|
5
|
+
import html
|
|
6
|
+
import math
|
|
7
|
+
import random
|
|
8
|
+
import re
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
|
|
11
|
+
import emoji
|
|
12
|
+
from huggingface_hub import hf_hub_download
|
|
13
|
+
|
|
14
|
+
from helm.common.optional_dependencies import handle_module_not_found_error
|
|
15
|
+
|
|
16
|
+
try:
|
|
17
|
+
import ftfy
|
|
18
|
+
from unidecode import unidecode
|
|
19
|
+
except ModuleNotFoundError as e:
|
|
20
|
+
handle_module_not_found_error(e, ["heim"])
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
# based on wiki word occurrence
|
|
24
|
+
person_token = [("a person", 282265), ("someone", 121194), ("somebody", 12219)]
|
|
25
|
+
temp_token = "xtokx" # avoid repeating chars
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class HashtagProcessor:
    """Split a run-together string (e.g. a hashtag body) into words.

    Adapted from the wordninja library. Word costs come from a frequency list
    downloaded from the "dalle-mini/dalle-mini" hub repo: the word on line
    ``i`` (0-based) gets cost ``log(i + 1)``.
    """

    def __init__(self):
        frequency_file = hf_hub_download("dalle-mini/dalle-mini", filename="enwiki-words-frequency.txt")
        lines = Path(frequency_file).read_text(encoding="utf8").splitlines()
        # First whitespace-separated field of each line is the word; a word
        # seen again later overwrites its earlier cost.
        self._word_cost = {}
        for rank, line in enumerate(lines):
            self._word_cost[str(line.split()[0])] = math.log(float(rank + 1))
        self._max_word = max(map(len, self._word_cost))
        self._SPLIT_RE = re.compile("[^a-zA-Z0-9']+")

    def __call__(self, s):
        """Uses dynamic programming to infer the location of spaces in a string without spaces."""
        words = []
        for chunk in self._SPLIT_RE.split(s):
            words.extend(self._split(chunk))
        return " ".join(words)

    def _split(self, s):
        """Return an iterator over the minimal-cost segmentation of ``s``."""
        # costs[i] is the best total cost of segmenting the first i characters.
        costs = [0]

        def best_match(end):
            # Try every word length ending at ``end`` (bounded by the longest
            # known word) and return the cheapest (total_cost, word_length).
            # Unknown words cost 9e999, which evaluates to float infinity.
            window = costs[max(0, end - self._max_word) : end]
            options = []
            for offset, prefix_cost in enumerate(reversed(window)):
                word = s[end - offset - 1 : end].lower()
                options.append((prefix_cost + self._word_cost.get(word, 9e999), offset + 1))
            return min(options)

        # Build the cost array
        for end in range(1, len(s) + 1):
            costs.append(best_match(end)[0])

        # Backtrack to recover the minimal-cost string.
        pieces = []
        end = len(s)
        while end > 0:
            best_cost, length = best_match(end)
            assert best_cost == costs[end]
            token = s[end - length : end]
            merged = False
            # ignore a lone apostrophe; otherwise re-attach split 's and split digits
            if token != "'" and pieces:
                follows_possessive = pieces[-1] == "'s"
                digit_run = s[end - 1].isdigit() and pieces[-1][0].isdigit()
                if follows_possessive or digit_run:
                    pieces[-1] = token + pieces[-1]  # combine current token with previous token
                    merged = True
            if not merged:
                pieces.append(token)
            end -= length

        return reversed(pieces)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def replace_person_token(t):
    """Replace "<person>" placeholders with natural wording (used for CC12M).

    A run of placeholders joined by commas, whitespace and/or "and" becomes
    " people "; each remaining placeholder becomes a randomly drawn phrase
    from ``person_token``, weighted by the counts stored there.
    """
    # Raw string: "\s" in a normal literal is an invalid escape sequence.
    t = re.sub(r"<person>([,\s]*(and)*[,\s]*<person>)+", " people ", t)
    if "<person>" in t:
        # Unzip once instead of on every loop iteration.
        phrases, weights = zip(*person_token)
        while "<person>" in t:
            t = t.replace("<person>", f" {random.choices(phrases, weights)[0]} ", 1)
    return t
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def fix_html(t):
    """Unescape HTML entities twice, so doubly-escaped text is fully decoded (from OpenAI CLIP)."""
    for _ in range(2):
        t = html.unescape(t)
    return t
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def replace_punctuation_with_commas(t):
    """Replace each of ``( ) [ ] . , | : ; ? ! = + ~ - / { }`` with a comma."""
    # Raw string: "\]", "\-" and "\/" are invalid escapes in a normal literal.
    # The pattern itself is unchanged.
    return re.sub(r"[()[\].,|:;?!=+~\-\/{}]", ",", t)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def simplify_quotes(t):
    """Turn every single quote, double quote or backtick into a space-padded double quote."""
    any_quote = re.compile("""['"`]""")
    return any_quote.sub(' " ', t)
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
def merge_quotes(t):
    """Collapse any run of double quotes and surrounding whitespace into one ' " '."""
    # Raw string: "\s" is an invalid escape in a normal literal.
    return re.sub(r'(\s*"+\s*)+', ' " ', t)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def remove_comma_numbers(t):
    """Remove thousands-separator commas inside numbers, e.g. "1,234,567" -> "1234567"."""

    def _f(t):
        # Raw string: "\d" is an invalid escape in a normal literal.
        return re.sub(r"(\d),(\d{3})", r"\1\2", t)

    # Two passes: after one substitution the consumed digits cannot anchor the
    # next separator, so "1,234,567" needs a second sweep.
    return _f(_f(t))
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
def pre_process_dot_numbers(t):
    """Replace a dot between two word characters with a ``temp_token``-delimited "dot" placeholder.

    ``post_process_dot_numbers`` turns the placeholder back into ".".
    """
    # Raw string: "\w" and "\." are invalid escapes in a normal literal.
    return re.sub(r"(\w)\.(\w)", rf"\1{temp_token}dot{temp_token}\2", t)
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def post_process_dot_numbers(t):
    """Turn every ``temp_token``-delimited "dot" placeholder back into "."."""
    placeholder = f"{temp_token}dot{temp_token}"
    return re.sub(placeholder, ".", t)
|
|
117
|
+
|
|
118
|
+
|
|
119
|
+
def pre_process_quotes(t):
    """Replace apostrophes that start a contraction suffix with a placeholder.

    Covers 's, 't, 'd, 'm, 'll, 're and 've (suffix must end at a word
    boundary). ``post_process_quotes`` restores the apostrophe.
    """
    contraction_apostrophe = r"'(?=([stdm]|(ll)|(re)|(ve)|(ll))\b)"
    placeholder = rf"{temp_token}quote{temp_token}"
    return re.sub(contraction_apostrophe, placeholder, t)
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def post_process_quotes(t):
    """Turn every ``temp_token``-delimited "quote" placeholder back into an apostrophe."""
    placeholder = f"{temp_token}quote{temp_token}"
    return re.sub(placeholder, "'", t)
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
def pre_process_dates(t):
    """Replace a slash between two digits with a ``temp_token``-delimited "slash" placeholder.

    ``post_process_dates`` turns the placeholder back into "/".
    """
    # Raw string: "\d" is an invalid escape in a normal literal.
    return re.sub(r"(\d)/(\d)", rf"\1{temp_token}slash{temp_token}\2", t)
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def post_process_dates(t):
    """Turn every ``temp_token``-delimited "slash" placeholder back into "/"."""
    placeholder = f"{temp_token}slash{temp_token}"
    return re.sub(placeholder, "/", t)
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def merge_commas(t):
    """Collapse any run of commas and surrounding whitespace into a single ", "."""
    # Raw string: "\s" is an invalid escape in a normal literal.
    return re.sub(r"(\s*,+\s*)+", ", ", t)
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
def add_space_after_commas(t):
    """Insert one space after every comma."""
    return t.replace(",", ", ")
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def handle_special_chars(t):
    """Handle special characters.

    A hyphen directly between two word characters becomes a space, then each
    of ``% & / $ *`` is surrounded by spaces.
    """
    # Raw strings: "\w" and "\/" are invalid escapes in normal literals.
    # replace "-" with a space when between words without space
    t = re.sub(r"(\w)-(\w)", r"\1 \2", t)
    # always add space around some characters
    return re.sub(r"([%&\/$*])", r" \1 ", t)
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def expand_hashtags(t, hashtag_processor):
    """Remove "#" and replace the tag body with ``hashtag_processor(body)``.

    Args:
        t: input text.
        hashtag_processor: callable taking the tag text (without "#") and
            returning its replacement string.
    """
    # Raw string: "\w" is an invalid escape in a normal literal.
    return re.sub(r"#(\w+)", lambda m: hashtag_processor(m.group(1)), t)
|
|
155
|
+
|
|
156
|
+
|
|
157
|
+
# Characters treated as noise: underscore, hash and backslash.
_re_ignore_chars = r"[_#\\]"


def ignore_chars(t):
    """Replace each character matched by ``_re_ignore_chars`` with a space."""
    noise = re.compile(_re_ignore_chars)
    return noise.sub(" ", t)
|
|
163
|
+
|
|
164
|
+
|
|
165
|
+
def remove_extra_spaces(t):
    """Collapse every run of whitespace (including tabs and newlines) into one space."""
    # Raw string: "\s" is an invalid escape in a normal literal.
    return re.sub(r"\s+", " ", t)
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
def remove_repeating_chars(t):
    """Reduce a non-digit character repeated 4+ times in a row to one instance.

    Three repeats are left alone (roman numerals such as 'VIII'); runs of
    digits are never touched.
    """
    long_run = re.compile(r"(\D)(\1{3,})")
    return long_run.sub(r"\1", t)
|
|
173
|
+
|
|
174
|
+
|
|
175
|
+
def remove_urls(t):
    """Delete every token that starts with "http", up to the next whitespace."""
    url = re.compile(r"http\S+")
    return url.sub("", t)
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def remove_html_tags(t):
    """Replace every "<...>" tag (shortest match, no nested "<") with a space."""
    tag = re.compile("<[^<]+?>")
    return tag.sub(" ", t)
|
|
181
|
+
|
|
182
|
+
|
|
183
|
+
def remove_first_last_commas(t):
    """Strip whitespace, drop one trailing and one leading comma, then strip again."""
    t = t.strip()
    if t.endswith(","):
        t = t[:-1]
    if t.startswith(","):
        t = t[1:]
    return t.strip()
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
def remove_wiki_ref(t):
    """Remove a "[digits]" reference marker at the very start and/or very end of the text."""
    leading_ref = re.compile(r"\A\s*\[\d+\]")
    trailing_ref = re.compile(r"\[\d+\]\s*\Z")
    without_leading = leading_ref.sub("", t)
    return trailing_ref.sub("", without_leading)
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
class TextNormalizer:
    """Normalize text by running it through a fixed sequence of cleanup steps."""

    def __init__(self):
        self._hashtag_processor = HashtagProcessor()

    def __call__(self, t):
        """Return the normalized form of ``t``, always prefixed with one space."""
        # Order matters: the pre_process_* steps hide dots, apostrophes and
        # slashes behind placeholders before punctuation is turned into
        # commas, and the post_process_* steps restore them afterwards.
        steps = (
            ftfy.fix_text,  # fix some characters
            fix_html,  # fix html
            emoji.demojize,  # decode emojis (would be removed by unidecode)
            unidecode,  # decode and simplify text: see unidecode library
            lambda s: s.lower(),  # lower case
            replace_person_token,  # replace <PERSON> (for CC12M)
            remove_wiki_ref,  # remove wiki reference (for WIT)
            remove_html_tags,  # remove html tags
            remove_urls,  # remove urls
            remove_comma_numbers,  # remove commas in numbers
            pre_process_dot_numbers,  # handle dots in numbers and quotes - Part 1
            pre_process_quotes,
            pre_process_dates,
            handle_special_chars,  # handle special characters
            lambda s: expand_hashtags(s, self._hashtag_processor),  # handle hashtags
            ignore_chars,  # ignore useless characters
            simplify_quotes,  # simplify quotes
            replace_punctuation_with_commas,  # all punctuation becomes commas
            post_process_dot_numbers,  # handle dots in numbers and quotes - Part 2
            post_process_quotes,
            post_process_dates,
            remove_repeating_chars,  # handle repeating characters
            merge_quotes,  # merge quotes
            merge_commas,  # merge commas
            remove_extra_spaces,  # remove multiple spaces
            remove_first_last_commas,  # remove first and last comma
        )
        for step in steps:
            t = step(t)
        # always start with a space
        return f" {t}"
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import tempfile
|
|
3
|
+
|
|
4
|
+
from helm.common.optional_dependencies import handle_module_not_found_error
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class PretrainedFromWandbMixin:
    """Mixin adding wandb-artifact support to a class's ``from_pretrained``."""

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """
        Initializes from a wandb artifact or delegates loading to the superclass.

        An identifier that contains ":" and is not an existing directory is
        treated as a wandb artifact: it is downloaded into a temporary
        directory and the next ``from_pretrained`` in the MRO loads it from
        there. Any other identifier is passed through unchanged.

        wandb is imported only on the artifact path, so plain local or hub
        identifiers load without the optional dependency installed.
        """
        next_loader = super(PretrainedFromWandbMixin, cls).from_pretrained

        is_wandb_artifact = ":" in pretrained_model_name_or_path and not os.path.isdir(
            pretrained_model_name_or_path
        )
        if not is_wandb_artifact:
            return next_loader(pretrained_model_name_or_path, *model_args, **kwargs)

        try:
            import wandb
        except ModuleNotFoundError as e:
            handle_module_not_found_error(e, ["heim"])

        with tempfile.TemporaryDirectory() as tmp_dir:  # avoid multiple artifact copies
            # Use the active run when there is one, otherwise the public API.
            if wandb.run is not None:
                artifact = wandb.run.use_artifact(pretrained_model_name_or_path)
            else:
                artifact = wandb.Api().artifact(pretrained_model_name_or_path)
            local_path = artifact.download(tmp_dir)

            # Load inside the ``with`` block: the download is deleted on exit.
            return next_loader(local_path, *model_args, **kwargs)
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from . import *
|
|
@@ -0,0 +1,40 @@
|
|
|
1
|
+
from typing import Tuple
|
|
2
|
+
|
|
3
|
+
from transformers import PretrainedConfig
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class VQGANConfig(PretrainedConfig):
    """Configuration holder for a VQGAN model.

    Every constructor argument is stored as an attribute of the same name.
    ``ch_mult`` and ``attn_resolutions`` are stored as lists, and
    ``num_resolutions`` is set to ``len(ch_mult)``. Extra keyword arguments
    go to ``PretrainedConfig``.
    """

    def __init__(
        self,
        ch: int = 128,
        out_ch: int = 3,
        in_channels: int = 3,
        num_res_blocks: int = 2,
        resolution: int = 256,
        z_channels: int = 256,
        ch_mult: Tuple[int, ...] = (1, 1, 2, 2, 4),
        # Was annotated ``int`` although the default is a tuple and the value
        # is passed to ``list()`` below.
        attn_resolutions: Tuple[int, ...] = (16,),
        n_embed: int = 1024,
        embed_dim: int = 256,
        dropout: float = 0.0,
        double_z: bool = False,
        resamp_with_conv: bool = True,
        give_pre_end: bool = False,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.ch = ch
        self.out_ch = out_ch
        self.in_channels = in_channels
        self.num_res_blocks = num_res_blocks
        self.resolution = resolution
        self.z_channels = z_channels
        self.ch_mult = list(ch_mult)
        self.attn_resolutions = list(attn_resolutions)
        self.n_embed = n_embed
        self.embed_dim = embed_dim
        self.dropout = dropout
        self.double_z = double_z
        self.resamp_with_conv = resamp_with_conv
        self.give_pre_end = give_pre_end
        self.num_resolutions = len(ch_mult)
|
|
@@ -0,0 +1,107 @@
|
|
|
1
|
+
import re
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
|
|
5
|
+
from .modeling_flax_vqgan import VQModel
|
|
6
|
+
from .configuration_vqgan import VQGANConfig
|
|
7
|
+
from helm.common.optional_dependencies import handle_module_not_found_error
|
|
8
|
+
|
|
9
|
+
try:
|
|
10
|
+
import jax.numpy as jnp
|
|
11
|
+
from flax.traverse_util import flatten_dict, unflatten_dict
|
|
12
|
+
except ModuleNotFoundError as e:
|
|
13
|
+
handle_module_not_found_error(e, ["heim"])
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
# A word followed by a dot and an index, e.g. "block.1".
regex = r"\w+[.]\d+"


def rename_key(key):
    """Rewrite every "name.<digits>" fragment found in ``key`` as "name_<digits>"."""
    for fragment in re.findall(regex, key):
        # str.replace rewrites every occurrence of the fragment text, not only
        # the position where it was found.
        key = key.replace(fragment, fragment.replace(".", "_"))
    return key
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
# Adapted from https://github.com/huggingface/transformers/blob/ff5cdc086be1e0c3e2bbad8e3469b34cffb55a85/src/transformers/modeling_flax_pytorch_utils.py#L61
|
|
27
|
+
def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
    """Convert a PyTorch state dict into a nested Flax parameter dict.

    Args:
        pt_state_dict: mapping of dotted parameter names to tensors exposing
            ``.numpy()``.
        flax_model: model providing ``params`` (used as the reference for
            names and shapes) and ``base_model_prefix``.

    Returns:
        The unflattened dict of ``jnp`` arrays keyed by the renamed paths.
        Weights with no counterpart in ``flax_model.params`` are still
        included.

    Raises:
        ValueError: if a converted weight's shape differs from the reference
            parameter at the same path.
    """
    # convert pytorch tensor to numpy
    pt_state_dict = {name: tensor.numpy() for name, tensor in pt_state_dict.items()}

    params = flax_model.params
    prefix = flax_model.base_model_prefix
    random_flax_state_dict = flatten_dict(params)
    flax_state_dict = {}

    # Decide once whether the base-model prefix must be stripped from, or
    # prepended to, the checkpoint keys.
    pt_top_level_names = {name.split(".")[0] for name in pt_state_dict}
    remove_base_model_prefix = prefix not in params and prefix in pt_top_level_names
    add_base_model_prefix = prefix in params and prefix not in pt_top_level_names

    # Need to change some parameters name to match Flax names so that we don't have to fork any layer
    for pt_key, pt_tensor in pt_state_dict.items():
        pt_tuple_key = tuple(pt_key.split("."))

        if remove_base_model_prefix and pt_tuple_key[0] == prefix:
            pt_tuple_key = pt_tuple_key[1:]
        elif add_base_model_prefix and (prefix,) + pt_tuple_key in random_flax_state_dict:
            pt_tuple_key = (prefix,) + pt_tuple_key

        # Correctly rename weight parameters
        parent = pt_tuple_key[:-1]
        leaf = pt_tuple_key[-1]
        scale_key = parent + ("scale",)
        if (
            "norm" in pt_key
            and leaf == "bias"
            and parent + ("bias",) not in random_flax_state_dict
            and scale_key in random_flax_state_dict
        ):
            pt_tuple_key = scale_key
        elif leaf in ["weight", "gamma"] and scale_key in random_flax_state_dict:
            pt_tuple_key = scale_key

        # Second, independent renaming pass on the (possibly updated) last component.
        leaf = pt_tuple_key[-1]
        if leaf == "weight" and parent + ("embedding",) in random_flax_state_dict:
            pt_tuple_key = parent + ("embedding",)
        elif leaf == "weight" and pt_tensor.ndim == 4 and pt_tuple_key not in random_flax_state_dict:
            # conv layer
            pt_tuple_key = parent + ("kernel",)
            pt_tensor = pt_tensor.transpose(2, 3, 1, 0)
        elif leaf == "weight" and pt_tuple_key not in random_flax_state_dict:
            # linear layer
            pt_tuple_key = parent + ("kernel",)
            pt_tensor = pt_tensor.T
        elif leaf == "gamma":
            pt_tuple_key = parent + ("weight",)
        elif leaf == "beta":
            pt_tuple_key = parent + ("bias",)

        if pt_tuple_key in random_flax_state_dict:
            expected_shape = random_flax_state_dict[pt_tuple_key].shape
            if pt_tensor.shape != expected_shape:
                raise ValueError(
                    f"PyTorch checkpoint seems to be incorrect. Weight {pt_key} was expected to be of shape "
                    f"{random_flax_state_dict[pt_tuple_key].shape}, but is {pt_tensor.shape}."
                )

        # also add unexpected weight so that warning is thrown
        flax_state_dict[pt_tuple_key] = jnp.asarray(pt_tensor)

    return unflatten_dict(flax_state_dict)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def convert_model(config_path, pt_state_dict_path, save_path):
    """Convert a PyTorch VQGAN checkpoint to Flax, save it and return the model.

    Args:
        config_path: location given to ``VQGANConfig.from_pretrained``.
        pt_state_dict_path: checkpoint file; its "state_dict" entry is used.
        save_path: destination given to ``model.save_pretrained``.
    """
    model = VQModel(VQGANConfig.from_pretrained(config_path))

    # NOTE(review): torch.load unpickles the file; only use trusted checkpoints.
    checkpoint = torch.load(pt_state_dict_path, map_location="cpu")
    state_dict = checkpoint["state_dict"]

    # Iterate over a snapshot of the keys because the dict is mutated in the loop.
    for original_key in list(state_dict.keys()):
        tensor = state_dict.pop(original_key)
        if original_key.startswith("loss"):
            # Entries whose key starts with "loss" are dropped.
            continue
        state_dict[rename_key(original_key)] = tensor

    model.params = convert_pytorch_state_dict_to_flax(state_dict, model)
    model.save_pretrained(save_path)
    return model
|