crfm-helm 0.4.0__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of crfm-helm might be problematic. Click here for more details.
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.1.dist-info}/METADATA +138 -31
- crfm_helm-0.5.1.dist-info/RECORD +654 -0
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.1.dist-info}/WHEEL +1 -1
- helm/benchmark/adaptation/adapter_spec.py +31 -3
- helm/benchmark/adaptation/adapters/adapter.py +2 -2
- helm/benchmark/adaptation/adapters/adapter_factory.py +24 -27
- helm/benchmark/adaptation/adapters/generation_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/in_context_learning_adapter.py +20 -4
- helm/benchmark/adaptation/adapters/language_modeling_adapter.py +2 -3
- helm/benchmark/adaptation/adapters/multimodal/in_context_learning_multimodal_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/multimodal/multimodal_prompt.py +7 -0
- helm/benchmark/adaptation/adapters/multimodal/multiple_choice_joint_multimodal_adapter.py +104 -0
- helm/benchmark/adaptation/adapters/multimodal/test_in_context_learning_multimodal_adapter.py +2 -1
- helm/benchmark/adaptation/adapters/multimodal/test_multimodal_prompt.py +2 -0
- helm/benchmark/adaptation/adapters/test_adapter.py +2 -1
- helm/benchmark/adaptation/adapters/test_generation_adapter.py +32 -8
- helm/benchmark/adaptation/adapters/test_language_modeling_adapter.py +7 -19
- helm/benchmark/adaptation/adapters/test_multiple_choice_joint_adapter.py +60 -6
- helm/benchmark/adaptation/common_adapter_specs.py +376 -0
- helm/benchmark/adaptation/request_state.py +6 -1
- helm/benchmark/adaptation/scenario_state.py +6 -2
- helm/benchmark/annotation/annotator.py +43 -0
- helm/benchmark/annotation/annotator_factory.py +61 -0
- helm/benchmark/annotation/image2structure/image_compiler_annotator.py +88 -0
- helm/benchmark/annotation/image2structure/latex_compiler_annotator.py +59 -0
- helm/benchmark/annotation/image2structure/lilypond_compiler_annotator.py +84 -0
- helm/benchmark/annotation/image2structure/webpage_compiler_annotator.py +132 -0
- helm/benchmark/annotation/test_annotator_factory.py +26 -0
- helm/benchmark/annotation/test_dummy_annotator.py +44 -0
- helm/benchmark/annotation_executor.py +124 -0
- helm/benchmark/augmentations/data_augmenter.py +0 -2
- helm/benchmark/augmentations/gender_perturbation.py +1 -1
- helm/benchmark/augmentations/perturbation.py +25 -3
- helm/benchmark/augmentations/perturbation_description.py +1 -1
- helm/benchmark/augmentations/suffix_perturbation.py +29 -0
- helm/benchmark/augmentations/test_perturbation.py +41 -7
- helm/benchmark/augmentations/translate_perturbation.py +30 -0
- helm/benchmark/config_registry.py +7 -1
- helm/benchmark/executor.py +46 -16
- helm/benchmark/huggingface_registration.py +20 -7
- helm/benchmark/metrics/basic_metrics.py +169 -664
- helm/benchmark/metrics/bbq_metrics.py +3 -4
- helm/benchmark/metrics/bias_metrics.py +6 -6
- helm/benchmark/metrics/classification_metrics.py +11 -8
- helm/benchmark/metrics/cleva_accuracy_metrics.py +8 -5
- helm/benchmark/metrics/cleva_harms_metrics.py +2 -2
- helm/benchmark/metrics/code_metrics_helper.py +0 -2
- helm/benchmark/metrics/common_metric_specs.py +167 -0
- helm/benchmark/metrics/decodingtrust_fairness_metrics.py +72 -0
- helm/benchmark/metrics/decodingtrust_ood_knowledge_metrics.py +66 -0
- helm/benchmark/metrics/decodingtrust_privacy_metrics.py +101 -0
- helm/benchmark/metrics/decodingtrust_stereotype_bias_metrics.py +202 -0
- helm/benchmark/metrics/disinformation_metrics.py +4 -110
- helm/benchmark/metrics/dry_run_metrics.py +2 -2
- helm/benchmark/metrics/efficiency_metrics.py +213 -0
- helm/benchmark/metrics/evaluate_instances_metric.py +59 -0
- helm/benchmark/metrics/evaluate_reference_metrics.py +392 -0
- helm/benchmark/metrics/image_generation/aesthetics_metrics.py +54 -0
- helm/benchmark/metrics/image_generation/aesthetics_scorer.py +66 -0
- helm/benchmark/metrics/image_generation/clip_score_metrics.py +73 -0
- helm/benchmark/metrics/image_generation/denoised_runtime_metric.py +42 -0
- helm/benchmark/metrics/image_generation/detection_metrics.py +57 -0
- helm/benchmark/metrics/image_generation/detectors/base_detector.py +8 -0
- helm/benchmark/metrics/image_generation/detectors/vitdet.py +178 -0
- helm/benchmark/metrics/image_generation/efficiency_metrics.py +41 -0
- helm/benchmark/metrics/image_generation/fidelity_metrics.py +168 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/fractal_dimension_util.py +63 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/test_fractal_dimension_util.py +33 -0
- helm/benchmark/metrics/image_generation/fractal_dimension_metric.py +50 -0
- helm/benchmark/metrics/image_generation/gender_metrics.py +58 -0
- helm/benchmark/metrics/image_generation/image_critique_metrics.py +284 -0
- helm/benchmark/metrics/image_generation/lpips_metrics.py +82 -0
- helm/benchmark/metrics/image_generation/multi_scale_ssim_metrics.py +82 -0
- helm/benchmark/metrics/image_generation/nsfw_detector.py +96 -0
- helm/benchmark/metrics/image_generation/nsfw_metrics.py +103 -0
- helm/benchmark/metrics/image_generation/nudity_metrics.py +38 -0
- helm/benchmark/metrics/image_generation/photorealism_critique_metrics.py +153 -0
- helm/benchmark/metrics/image_generation/psnr_metrics.py +78 -0
- helm/benchmark/metrics/image_generation/q16/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/q16/q16_toxicity_detector.py +90 -0
- helm/benchmark/metrics/image_generation/q16/test_q16.py +18 -0
- helm/benchmark/metrics/image_generation/q16_toxicity_metrics.py +48 -0
- helm/benchmark/metrics/image_generation/skin_tone_metrics.py +164 -0
- helm/benchmark/metrics/image_generation/uiqi_metrics.py +92 -0
- helm/benchmark/metrics/image_generation/watermark/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/watermark/test_watermark_detector.py +16 -0
- helm/benchmark/metrics/image_generation/watermark/watermark_detector.py +87 -0
- helm/benchmark/metrics/image_generation/watermark_metrics.py +48 -0
- helm/benchmark/metrics/instruction_following_critique_metrics.py +3 -1
- helm/benchmark/metrics/language_modeling_metrics.py +99 -0
- helm/benchmark/metrics/machine_translation_metrics.py +89 -0
- helm/benchmark/metrics/metric.py +93 -172
- helm/benchmark/metrics/metric_name.py +0 -1
- helm/benchmark/metrics/metric_service.py +16 -0
- helm/benchmark/metrics/paraphrase_generation_metrics.py +3 -4
- helm/benchmark/metrics/ranking_metrics.py +2 -2
- helm/benchmark/metrics/reference_metric.py +148 -0
- helm/benchmark/metrics/summac/model_summac.py +0 -2
- helm/benchmark/metrics/summarization_metrics.py +2 -2
- helm/benchmark/metrics/test_classification_metrics.py +8 -5
- helm/benchmark/metrics/test_disinformation_metrics.py +78 -0
- helm/benchmark/metrics/{test_basic_metrics.py → test_evaluate_reference_metrics.py} +5 -1
- helm/benchmark/metrics/test_metric.py +2 -2
- helm/benchmark/metrics/tokens/gooseai_token_cost_estimator.py +10 -2
- helm/benchmark/metrics/toxicity_metrics.py +1 -1
- helm/benchmark/metrics/toxicity_utils.py +23 -0
- helm/benchmark/metrics/unitxt_metrics.py +81 -0
- helm/benchmark/metrics/vision_language/__init__.py +0 -0
- helm/benchmark/metrics/vision_language/emd_utils.py +341 -0
- helm/benchmark/metrics/vision_language/image_metrics.py +575 -0
- helm/benchmark/metrics/vision_language/image_utils.py +100 -0
- helm/benchmark/model_deployment_registry.py +74 -0
- helm/benchmark/model_metadata_registry.py +41 -1
- helm/benchmark/multi_gpu_runner.py +133 -0
- helm/benchmark/presentation/create_plots.py +8 -7
- helm/benchmark/presentation/run_display.py +26 -10
- helm/benchmark/presentation/schema.py +15 -40
- helm/benchmark/presentation/summarize.py +119 -79
- helm/benchmark/presentation/table.py +8 -8
- helm/benchmark/presentation/test_contamination.py +2 -2
- helm/benchmark/presentation/test_run_entry.py +1 -2
- helm/benchmark/presentation/test_summarize.py +3 -3
- helm/benchmark/run.py +54 -26
- helm/benchmark/run_expander.py +205 -35
- helm/benchmark/run_spec.py +93 -0
- helm/benchmark/run_spec_factory.py +163 -0
- helm/benchmark/run_specs/__init__.py +0 -0
- helm/benchmark/run_specs/classic_run_specs.py +1510 -0
- helm/benchmark/run_specs/cleva_run_specs.py +277 -0
- helm/benchmark/run_specs/decodingtrust_run_specs.py +314 -0
- helm/benchmark/run_specs/heim_run_specs.py +623 -0
- helm/benchmark/run_specs/instruction_following_run_specs.py +129 -0
- helm/benchmark/run_specs/lite_run_specs.py +307 -0
- helm/benchmark/run_specs/simple_run_specs.py +104 -0
- helm/benchmark/run_specs/unitxt_run_specs.py +42 -0
- helm/benchmark/run_specs/vlm_run_specs.py +757 -0
- helm/benchmark/runner.py +51 -57
- helm/benchmark/runner_config_registry.py +21 -0
- helm/benchmark/scenarios/bbq_scenario.py +1 -1
- helm/benchmark/scenarios/bold_scenario.py +2 -2
- helm/benchmark/scenarios/code_scenario.py +1 -0
- helm/benchmark/scenarios/decodingtrust_adv_demonstration_scenario.py +169 -0
- helm/benchmark/scenarios/decodingtrust_adv_robustness_scenario.py +121 -0
- helm/benchmark/scenarios/decodingtrust_fairness_scenario.py +77 -0
- helm/benchmark/scenarios/decodingtrust_machine_ethics_scenario.py +324 -0
- helm/benchmark/scenarios/decodingtrust_ood_robustness_scenario.py +204 -0
- helm/benchmark/scenarios/decodingtrust_privacy_scenario.py +559 -0
- helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +67 -0
- helm/benchmark/scenarios/decodingtrust_toxicity_prompts_scenario.py +78 -0
- helm/benchmark/scenarios/dialogue_scenarios.py +0 -1
- helm/benchmark/scenarios/image_generation/__init__.py +0 -0
- helm/benchmark/scenarios/image_generation/common_syntactic_processes_scenario.py +105 -0
- helm/benchmark/scenarios/image_generation/cub200_scenario.py +95 -0
- helm/benchmark/scenarios/image_generation/daily_dalle_scenario.py +124 -0
- helm/benchmark/scenarios/image_generation/demographic_stereotypes_scenario.py +82 -0
- helm/benchmark/scenarios/image_generation/detection_scenario.py +83 -0
- helm/benchmark/scenarios/image_generation/draw_bench_scenario.py +74 -0
- helm/benchmark/scenarios/image_generation/i2p_scenario.py +57 -0
- helm/benchmark/scenarios/image_generation/landing_page_scenario.py +46 -0
- helm/benchmark/scenarios/image_generation/logos_scenario.py +223 -0
- helm/benchmark/scenarios/image_generation/magazine_cover_scenario.py +91 -0
- helm/benchmark/scenarios/image_generation/mental_disorders_scenario.py +46 -0
- helm/benchmark/scenarios/image_generation/mscoco_scenario.py +91 -0
- helm/benchmark/scenarios/image_generation/paint_skills_scenario.py +72 -0
- helm/benchmark/scenarios/image_generation/parti_prompts_scenario.py +94 -0
- helm/benchmark/scenarios/image_generation/radiology_scenario.py +42 -0
- helm/benchmark/scenarios/image_generation/relational_understanding_scenario.py +52 -0
- helm/benchmark/scenarios/image_generation/time_most_significant_historical_figures_scenario.py +124 -0
- helm/benchmark/scenarios/image_generation/winoground_scenario.py +62 -0
- helm/benchmark/scenarios/imdb_scenario.py +0 -1
- helm/benchmark/scenarios/legalbench_scenario.py +6 -2
- helm/benchmark/scenarios/live_qa_scenario.py +94 -0
- helm/benchmark/scenarios/lm_entry_scenario.py +185 -0
- helm/benchmark/scenarios/math_scenario.py +19 -2
- helm/benchmark/scenarios/medication_qa_scenario.py +60 -0
- helm/benchmark/scenarios/numeracy_scenario.py +1 -1
- helm/benchmark/scenarios/opinions_qa_scenario.py +0 -4
- helm/benchmark/scenarios/scenario.py +4 -0
- helm/benchmark/scenarios/simple_scenarios.py +122 -1
- helm/benchmark/scenarios/test_math_scenario.py +6 -0
- helm/benchmark/scenarios/test_scenario.py +6 -3
- helm/benchmark/scenarios/test_simple_scenarios.py +50 -0
- helm/benchmark/scenarios/thai_exam_scenario.py +135 -0
- helm/benchmark/scenarios/unitxt_scenario.py +56 -0
- helm/benchmark/scenarios/verifiability_judgment_scenario.py +3 -1
- helm/benchmark/scenarios/vicuna_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/a_okvqa_scenario.py +83 -0
- helm/benchmark/scenarios/vision_language/bingo_scenario.py +103 -0
- helm/benchmark/scenarios/vision_language/crossmodal_3600_scenario.py +134 -0
- helm/benchmark/scenarios/vision_language/flickr30k_scenario.py +74 -0
- helm/benchmark/scenarios/vision_language/gqa_scenario.py +91 -0
- helm/benchmark/scenarios/vision_language/hateful_memes_scenario.py +94 -0
- helm/benchmark/scenarios/vision_language/heim_human_eval_scenario.py +113 -0
- helm/benchmark/scenarios/vision_language/image2structure/__init__.py +0 -0
- helm/benchmark/scenarios/vision_language/image2structure/chart2csv_scenario.py +55 -0
- helm/benchmark/scenarios/vision_language/image2structure/image2structure_scenario.py +214 -0
- helm/benchmark/scenarios/vision_language/image2structure/latex_scenario.py +25 -0
- helm/benchmark/scenarios/vision_language/image2structure/musicsheet_scenario.py +20 -0
- helm/benchmark/scenarios/vision_language/image2structure/utils_latex.py +347 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/__init__.py +0 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/driver.py +84 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/jekyll_server.py +182 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/utils.py +31 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage_scenario.py +225 -0
- helm/benchmark/scenarios/vision_language/math_vista_scenario.py +117 -0
- helm/benchmark/scenarios/vision_language/mementos_scenario.py +124 -0
- helm/benchmark/scenarios/vision_language/mm_safety_bench_scenario.py +103 -0
- helm/benchmark/scenarios/vision_language/mme_scenario.py +145 -0
- helm/benchmark/scenarios/vision_language/mmmu_scenario.py +187 -0
- helm/benchmark/scenarios/vision_language/mscoco_captioning_scenario.py +92 -0
- helm/benchmark/scenarios/vision_language/mscoco_categorization_scenario.py +117 -0
- helm/benchmark/scenarios/vision_language/multipanelvqa_scenario.py +169 -0
- helm/benchmark/scenarios/vision_language/originality_scenario.py +35 -0
- helm/benchmark/scenarios/vision_language/pairs_scenario.py +246 -0
- helm/benchmark/scenarios/vision_language/pope_scenario.py +104 -0
- helm/benchmark/scenarios/vision_language/seed_bench_scenario.py +129 -0
- helm/benchmark/scenarios/vision_language/unicorn_scenario.py +108 -0
- helm/benchmark/scenarios/vision_language/viz_wiz_scenario.py +3 -4
- helm/benchmark/scenarios/vision_language/vqa_scenario.py +5 -3
- helm/benchmark/scenarios/wmt_14_scenario.py +1 -1
- helm/benchmark/server.py +24 -1
- helm/benchmark/slurm_runner.py +70 -49
- helm/benchmark/static/benchmarking.js +1 -1
- helm/benchmark/static/schema_classic.yaml +258 -1066
- helm/benchmark/static/schema_image2structure.yaml +304 -0
- helm/benchmark/static/schema_instruction_following.yaml +210 -0
- helm/benchmark/static/schema_lite.yaml +2 -227
- helm/benchmark/static/schema_mmlu.yaml +1507 -0
- helm/benchmark/static/schema_unitxt.yaml +428 -0
- helm/benchmark/static/schema_vhelm_lite.yaml +164 -0
- helm/benchmark/static/schema_vlm.yaml +823 -0
- helm/benchmark/static_build/assets/01-694cb9b7.png +0 -0
- helm/benchmark/static_build/assets/ai21-0eb91ec3.png +0 -0
- helm/benchmark/static_build/assets/aleph-alpha-7ce10034.png +0 -0
- helm/benchmark/static_build/assets/anthropic-70d8bc39.png +0 -0
- helm/benchmark/static_build/assets/bigscience-7f0400c0.png +0 -0
- helm/benchmark/static_build/assets/cohere-3550c6cb.png +0 -0
- helm/benchmark/static_build/assets/crfm-logo-74391ab8.png +0 -0
- helm/benchmark/static_build/assets/eleutherai-b9451114.png +0 -0
- helm/benchmark/static_build/assets/google-06d997ad.png +0 -0
- helm/benchmark/static_build/assets/heim-logo-3e5e3aa4.png +0 -0
- helm/benchmark/static_build/assets/helm-logo-simple-2ed5400b.png +0 -0
- helm/benchmark/static_build/assets/helmhero-28e90f4d.png +0 -0
- helm/benchmark/static_build/assets/index-737eef9e.js +10 -0
- helm/benchmark/static_build/assets/index-878a1094.css +1 -0
- helm/benchmark/static_build/assets/meta-5580e9f1.png +0 -0
- helm/benchmark/static_build/assets/microsoft-f5ee5016.png +0 -0
- helm/benchmark/static_build/assets/mistral-18e1be23.png +0 -0
- helm/benchmark/static_build/assets/nvidia-86fa75c1.png +0 -0
- helm/benchmark/static_build/assets/openai-3f8653e4.png +0 -0
- helm/benchmark/static_build/assets/react-d4a0b69b.js +85 -0
- helm/benchmark/static_build/assets/recharts-6d337683.js +97 -0
- helm/benchmark/static_build/assets/tii-24de195c.png +0 -0
- helm/benchmark/static_build/assets/together-a665a35b.png +0 -0
- helm/benchmark/static_build/assets/tremor-54a99cc4.js +10 -0
- helm/benchmark/static_build/assets/tsinghua-keg-97d4b395.png +0 -0
- helm/benchmark/static_build/assets/vhelm-framework-cde7618a.png +0 -0
- helm/benchmark/static_build/assets/vhelm-model-6d812526.png +0 -0
- helm/benchmark/static_build/assets/yandex-38e09d70.png +0 -0
- helm/benchmark/static_build/config.js +4 -0
- helm/benchmark/static_build/index.html +20 -0
- helm/benchmark/test_data_preprocessor.py +3 -3
- helm/benchmark/test_run_expander.py +1 -1
- helm/benchmark/window_services/ai21_window_service.py +22 -33
- helm/benchmark/window_services/cohere_window_service.py +1 -63
- helm/benchmark/window_services/default_window_service.py +2 -44
- helm/benchmark/window_services/encoder_decoder_window_service.py +0 -11
- helm/benchmark/window_services/ice_window_service.py +0 -34
- helm/benchmark/window_services/image_generation/__init__.py +0 -0
- helm/benchmark/window_services/image_generation/clip_window_service.py +15 -0
- helm/benchmark/window_services/image_generation/lexica_search_window_service.py +9 -0
- helm/benchmark/window_services/image_generation/openai_dalle_window_service.py +9 -0
- helm/benchmark/window_services/image_generation/test_clip_window_service.py +29 -0
- helm/benchmark/window_services/image_generation/test_openai_dalle_window_service.py +30 -0
- helm/benchmark/window_services/local_window_service.py +21 -4
- helm/benchmark/window_services/test_anthropic_window_service.py +2 -1
- helm/benchmark/window_services/test_bloom_window_service.py +2 -1
- helm/benchmark/window_services/test_cohere_window_service.py +2 -1
- helm/benchmark/window_services/test_flan_t5_window_service.py +2 -1
- helm/benchmark/window_services/test_gpt2_window_service.py +2 -2
- helm/benchmark/window_services/test_gpt4_window_service.py +2 -1
- helm/benchmark/window_services/test_gptj_window_service.py +3 -2
- helm/benchmark/window_services/test_gptneox_window_service.py +3 -2
- helm/benchmark/window_services/test_ice_window_service.py +2 -1
- helm/benchmark/window_services/test_openai_window_service.py +2 -1
- helm/benchmark/window_services/test_opt_window_service.py +3 -2
- helm/benchmark/window_services/test_palmyra_window_service.py +2 -1
- helm/benchmark/window_services/test_t0pp_window_service.py +2 -1
- helm/benchmark/window_services/test_t511b_window_service.py +2 -1
- helm/benchmark/window_services/test_ul2_window_service.py +2 -1
- helm/benchmark/window_services/test_utils.py +3 -2
- helm/benchmark/window_services/test_yalm_window_service.py +2 -1
- helm/benchmark/window_services/window_service.py +42 -0
- helm/benchmark/window_services/window_service_factory.py +4 -1
- helm/benchmark/window_services/yalm_window_service.py +0 -27
- helm/clients/__init__.py +0 -0
- helm/{proxy/clients → clients}/ai21_client.py +3 -9
- helm/clients/aleph_alpha_client.py +112 -0
- helm/{proxy/clients → clients}/anthropic_client.py +233 -18
- helm/{proxy/clients → clients}/auto_client.py +59 -31
- helm/clients/bedrock_client.py +128 -0
- helm/clients/bedrock_utils.py +72 -0
- helm/{proxy/clients → clients}/client.py +65 -7
- helm/clients/clip_score_client.py +49 -0
- helm/clients/clip_scorers/__init__.py +0 -0
- helm/clients/clip_scorers/base_clip_scorer.py +18 -0
- helm/clients/clip_scorers/clip_scorer.py +50 -0
- helm/clients/clip_scorers/multilingual_clip_scorer.py +50 -0
- helm/{proxy/clients → clients}/cohere_client.py +4 -11
- helm/clients/gcs_client.py +82 -0
- helm/{proxy/clients → clients}/google_client.py +5 -5
- helm/clients/google_translate_client.py +35 -0
- helm/{proxy/clients → clients}/http_model_client.py +5 -7
- helm/{proxy/clients → clients}/huggingface_client.py +43 -64
- helm/clients/image_generation/__init__.py +0 -0
- helm/clients/image_generation/adobe_vision_client.py +78 -0
- helm/clients/image_generation/aleph_alpha_image_generation_client.py +98 -0
- helm/clients/image_generation/cogview2/__init__.py +0 -0
- helm/clients/image_generation/cogview2/coglm_strategy.py +96 -0
- helm/clients/image_generation/cogview2/coglm_utils.py +82 -0
- helm/clients/image_generation/cogview2/sr_pipeline/__init__.py +15 -0
- helm/clients/image_generation/cogview2/sr_pipeline/direct_sr.py +96 -0
- helm/clients/image_generation/cogview2/sr_pipeline/dsr_model.py +254 -0
- helm/clients/image_generation/cogview2/sr_pipeline/dsr_sampling.py +190 -0
- helm/clients/image_generation/cogview2/sr_pipeline/iterative_sr.py +141 -0
- helm/clients/image_generation/cogview2/sr_pipeline/itersr_model.py +269 -0
- helm/clients/image_generation/cogview2/sr_pipeline/itersr_sampling.py +120 -0
- helm/clients/image_generation/cogview2/sr_pipeline/sr_group.py +42 -0
- helm/clients/image_generation/cogview2_client.py +191 -0
- helm/clients/image_generation/dalle2_client.py +192 -0
- helm/clients/image_generation/dalle3_client.py +108 -0
- helm/clients/image_generation/dalle_mini/__init__.py +3 -0
- helm/clients/image_generation/dalle_mini/data.py +442 -0
- helm/clients/image_generation/dalle_mini/model/__init__.py +5 -0
- helm/clients/image_generation/dalle_mini/model/configuration.py +175 -0
- helm/clients/image_generation/dalle_mini/model/modeling.py +1834 -0
- helm/clients/image_generation/dalle_mini/model/partitions.py +84 -0
- helm/clients/image_generation/dalle_mini/model/processor.py +63 -0
- helm/clients/image_generation/dalle_mini/model/text.py +251 -0
- helm/clients/image_generation/dalle_mini/model/tokenizer.py +9 -0
- helm/clients/image_generation/dalle_mini/model/utils.py +29 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/__init__.py +1 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/configuration_vqgan.py +40 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/convert_pt_model_to_jax.py +107 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/modeling_flax_vqgan.py +610 -0
- helm/clients/image_generation/dalle_mini_client.py +190 -0
- helm/clients/image_generation/deep_floyd_client.py +78 -0
- helm/clients/image_generation/huggingface_diffusers_client.py +249 -0
- helm/clients/image_generation/image_generation_client_utils.py +9 -0
- helm/clients/image_generation/lexica_client.py +86 -0
- helm/clients/image_generation/mindalle/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/__init__.py +216 -0
- helm/clients/image_generation/mindalle/models/stage1/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/stage1/layers.py +312 -0
- helm/clients/image_generation/mindalle/models/stage1/vqgan.py +103 -0
- helm/clients/image_generation/mindalle/models/stage2/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/stage2/layers.py +144 -0
- helm/clients/image_generation/mindalle/models/stage2/transformer.py +268 -0
- helm/clients/image_generation/mindalle/models/tokenizer.py +30 -0
- helm/clients/image_generation/mindalle/utils/__init__.py +3 -0
- helm/clients/image_generation/mindalle/utils/config.py +129 -0
- helm/clients/image_generation/mindalle/utils/sampling.py +149 -0
- helm/clients/image_generation/mindalle/utils/utils.py +89 -0
- helm/clients/image_generation/mindalle_client.py +115 -0
- helm/clients/image_generation/nudity_check_client.py +64 -0
- helm/clients/image_generation/together_image_generation_client.py +111 -0
- helm/{proxy/clients → clients}/lit_gpt_client.py +4 -4
- helm/{proxy/clients → clients}/megatron_client.py +5 -5
- helm/clients/mistral_client.py +134 -0
- helm/clients/moderation_api_client.py +109 -0
- helm/clients/open_lm_client.py +43 -0
- helm/clients/openai_client.py +301 -0
- helm/{proxy/clients → clients}/palmyra_client.py +6 -8
- helm/{proxy/clients → clients}/perspective_api_client.py +7 -8
- helm/clients/simple_client.py +64 -0
- helm/{proxy/clients → clients}/test_auto_client.py +13 -15
- helm/clients/test_client.py +100 -0
- helm/{proxy/clients → clients}/test_huggingface_client.py +15 -16
- helm/clients/test_simple_client.py +19 -0
- helm/{proxy/clients → clients}/test_together_client.py +20 -8
- helm/{proxy/clients → clients}/together_client.py +104 -73
- helm/clients/vertexai_client.py +400 -0
- helm/clients/vision_language/__init__.py +0 -0
- helm/clients/vision_language/huggingface_vision2seq_client.py +145 -0
- helm/clients/vision_language/huggingface_vlm_client.py +111 -0
- helm/{proxy/clients → clients}/vision_language/idefics_client.py +54 -49
- helm/clients/vision_language/open_flamingo/__init__.py +2 -0
- helm/clients/vision_language/open_flamingo/src/__init__.py +0 -0
- helm/clients/vision_language/open_flamingo/src/factory.py +147 -0
- helm/clients/vision_language/open_flamingo/src/flamingo.py +337 -0
- helm/clients/vision_language/open_flamingo/src/flamingo_lm.py +155 -0
- helm/clients/vision_language/open_flamingo/src/helpers.py +267 -0
- helm/clients/vision_language/open_flamingo/src/utils.py +47 -0
- helm/clients/vision_language/open_flamingo_client.py +155 -0
- helm/clients/vision_language/qwen_vlm_client.py +171 -0
- helm/clients/vllm_client.py +46 -0
- helm/common/cache.py +16 -4
- helm/common/cache_backend_config.py +47 -0
- helm/common/clip_score_request.py +41 -0
- helm/common/file_caches/__init__.py +0 -0
- helm/common/file_caches/file_cache.py +16 -0
- helm/common/file_caches/local_file_cache.py +61 -0
- helm/common/file_caches/test_local_file_cache.py +25 -0
- helm/common/file_upload_request.py +27 -0
- helm/common/general.py +1 -1
- helm/common/image_generation_parameters.py +25 -0
- helm/common/images_utils.py +33 -3
- helm/common/key_value_store.py +35 -4
- helm/common/media_object.py +13 -0
- helm/common/moderations_api_request.py +71 -0
- helm/common/mongo_key_value_store.py +3 -3
- helm/common/multimodal_request_utils.py +31 -0
- helm/common/nudity_check_request.py +29 -0
- helm/common/request.py +15 -17
- helm/common/test_general.py +6 -0
- helm/common/tokenization_request.py +1 -1
- helm/config/model_deployments.yaml +1159 -538
- helm/config/model_metadata.yaml +868 -41
- helm/config/tokenizer_configs.yaml +149 -43
- helm/proxy/accounts.py +31 -4
- helm/proxy/critique/mechanical_turk_critique_importer.py +3 -0
- helm/proxy/critique/model_critique_client.py +8 -6
- helm/proxy/example_queries.py +29 -17
- helm/proxy/server.py +70 -5
- helm/proxy/services/remote_service.py +31 -0
- helm/proxy/services/server_service.py +96 -16
- helm/proxy/services/service.py +30 -0
- helm/proxy/services/test_remote_service.py +4 -3
- helm/proxy/services/test_service.py +0 -12
- helm/proxy/test_accounts.py +32 -0
- helm/proxy/token_counters/auto_token_counter.py +37 -37
- helm/proxy/token_counters/test_auto_token_counter.py +164 -0
- helm/proxy/token_counters/token_counter.py +3 -5
- helm/tokenizers/__init__.py +0 -0
- helm/{proxy/tokenizers → tokenizers}/ai21_tokenizer.py +3 -3
- helm/{proxy/tokenizers → tokenizers}/anthropic_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/auto_tokenizer.py +6 -9
- helm/{proxy/tokenizers → tokenizers}/cohere_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/http_model_tokenizer.py +3 -3
- helm/{proxy/tokenizers → tokenizers}/huggingface_tokenizer.py +7 -26
- helm/tokenizers/simple_tokenizer.py +33 -0
- helm/{proxy/tokenizers → tokenizers}/test_anthropic_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/test_huggingface_tokenizer.py +3 -0
- helm/tokenizers/test_simple_tokenizer.py +33 -0
- helm/{proxy/tokenizers → tokenizers}/vertexai_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer.py +5 -3
- helm/tokenizers/yalm_tokenizer_data/__init__.py +0 -0
- helm/tokenizers/yalm_tokenizer_data/voc_100b.sp +0 -0
- helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/yalm_tokenizer.py +1 -1
- crfm_helm-0.4.0.dist-info/RECORD +0 -397
- helm/benchmark/run_specs.py +0 -2762
- helm/benchmark/test_model_deployment_definition.py +0 -92
- helm/benchmark/test_model_properties.py +0 -1570
- helm/benchmark/vlm_run_specs.py +0 -97
- helm/benchmark/window_services/flan_t5_window_service.py +0 -29
- helm/benchmark/window_services/gpt2_window_service.py +0 -32
- helm/benchmark/window_services/huggingface_window_service.py +0 -60
- helm/benchmark/window_services/t0pp_window_service.py +0 -35
- helm/benchmark/window_services/t511b_window_service.py +0 -30
- helm/benchmark/window_services/test_mt_nlg_window_service.py +0 -48
- helm/benchmark/window_services/ul2_window_service.py +0 -30
- helm/benchmark/window_services/wider_ai21_window_service.py +0 -24
- helm/common/cache_utils.py +0 -14
- helm/proxy/clients/aleph_alpha_client.py +0 -95
- helm/proxy/clients/goose_ai_client.py +0 -99
- helm/proxy/clients/microsoft_client.py +0 -180
- helm/proxy/clients/openai_client.py +0 -206
- helm/proxy/clients/simple_client.py +0 -60
- helm/proxy/clients/test_client.py +0 -49
- helm/proxy/clients/vertexai_client.py +0 -115
- helm/proxy/token_counters/ai21_token_counter.py +0 -20
- helm/proxy/token_counters/cohere_token_counter.py +0 -13
- helm/proxy/token_counters/free_token_counter.py +0 -12
- helm/proxy/token_counters/gooseai_token_counter.py +0 -24
- helm/proxy/token_counters/openai_token_counter.py +0 -22
- helm/proxy/token_counters/test_ai21_token_counter.py +0 -88
- helm/proxy/token_counters/test_openai_token_counter.py +0 -81
- helm/proxy/tokenizers/simple_tokenizer.py +0 -32
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.1.dist-info}/LICENSE +0 -0
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.1.dist-info}/entry_points.txt +0 -0
- {crfm_helm-0.4.0.dist-info → crfm_helm-0.5.1.dist-info}/top_level.txt +0 -0
- /helm/{proxy/clients → benchmark/annotation}/__init__.py +0 -0
- /helm/{proxy/clients/vision_language → benchmark/annotation/image2structure}/__init__.py +0 -0
- /helm/{proxy/tokenizers → benchmark/metrics/image_generation}/__init__.py +0 -0
- /helm/{proxy/tokenizers/yalm_tokenizer_data → benchmark/metrics/image_generation/detectors}/__init__.py +0 -0
- /helm/{proxy/clients → clients}/ai21_utils.py +0 -0
- /helm/{proxy/clients → clients}/cohere_utils.py +0 -0
- /helm/{proxy/clients → clients}/lit_gpt_generate.py +0 -0
- /helm/{proxy/clients → clients}/toxicity_classifier_client.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/aleph_alpha_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/caching_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/ice_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/lit_gpt_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/test_ice_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/test_yalm_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/tiktoken_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/test_yalm_tokenizer.py +0 -0
helm/proxy/server.py
CHANGED
|
@@ -1,5 +1,3 @@
|
|
|
1
|
-
# mypy: check_untyped_defs = False
|
|
2
|
-
|
|
3
1
|
"""
|
|
4
2
|
Starts a REST server for the frontend to interact with.
|
|
5
3
|
Look at `index.js` to see how the functionality is invoked.
|
|
@@ -20,12 +18,17 @@ from helm.benchmark.config_registry import (
|
|
|
20
18
|
register_configs_from_directory,
|
|
21
19
|
register_builtin_configs_from_helm_package,
|
|
22
20
|
)
|
|
21
|
+
from helm.benchmark.model_deployment_registry import get_default_model_deployment_for_model
|
|
23
22
|
from helm.common.authentication import Authentication
|
|
23
|
+
from helm.common.cache_backend_config import CacheBackendConfig, MongoCacheBackendConfig, SqliteCacheBackendConfig
|
|
24
|
+
from helm.common.general import ensure_directory_exists
|
|
24
25
|
from helm.common.hierarchical_logger import hlog
|
|
25
26
|
from helm.common.optional_dependencies import handle_module_not_found_error
|
|
26
27
|
from helm.common.request import Request
|
|
27
28
|
from helm.common.perspective_api_request import PerspectiveAPIRequest
|
|
29
|
+
from helm.common.moderations_api_request import ModerationAPIRequest
|
|
28
30
|
from helm.common.tokenization_request import TokenizationRequest, DecodeRequest
|
|
31
|
+
from helm.proxy.services.service import CACHE_DIR
|
|
29
32
|
from .accounts import Account
|
|
30
33
|
from .services.server_service import ServerService
|
|
31
34
|
from .query import Query
|
|
@@ -39,6 +42,7 @@ except ModuleNotFoundError as e:
|
|
|
39
42
|
bottle.BaseRequest.MEMFILE_MAX = 1024 * 1024
|
|
40
43
|
|
|
41
44
|
app = bottle.default_app()
|
|
45
|
+
service: ServerService
|
|
42
46
|
|
|
43
47
|
|
|
44
48
|
def safe_call(func, to_json=True):
|
|
@@ -87,9 +91,16 @@ def handle_static_filename(filename):
|
|
|
87
91
|
return resp
|
|
88
92
|
|
|
89
93
|
|
|
94
|
+
@app.get("/output/<filename:path>")
|
|
95
|
+
def handle_output_filename(filename):
|
|
96
|
+
resp = bottle.static_file(filename, root=app.config["crfm.proxy.outputpath"])
|
|
97
|
+
return resp
|
|
98
|
+
|
|
99
|
+
|
|
90
100
|
@app.get("/api/general_info")
|
|
91
101
|
def handle_get_general_info():
|
|
92
102
|
def perform(args):
|
|
103
|
+
global service
|
|
93
104
|
return dataclasses.asdict(service.get_general_info())
|
|
94
105
|
|
|
95
106
|
return safe_call(perform)
|
|
@@ -98,6 +109,7 @@ def handle_get_general_info():
|
|
|
98
109
|
@app.get("/api/window_service_info")
|
|
99
110
|
def handle_get_window_service_info():
|
|
100
111
|
def perform(args):
|
|
112
|
+
global service
|
|
101
113
|
return dataclasses.asdict(service.get_window_service_info(args["model_name"]))
|
|
102
114
|
|
|
103
115
|
return safe_call(perform)
|
|
@@ -106,6 +118,7 @@ def handle_get_window_service_info():
|
|
|
106
118
|
@app.post("/api/account")
|
|
107
119
|
def handle_create_account():
|
|
108
120
|
def perform(args):
|
|
121
|
+
global service
|
|
109
122
|
auth = Authentication(**json.loads(args["auth"]))
|
|
110
123
|
return dataclasses.asdict(service.create_account(auth))
|
|
111
124
|
|
|
@@ -115,6 +128,7 @@ def handle_create_account():
|
|
|
115
128
|
@app.delete("/api/account")
|
|
116
129
|
def handle_delete_account():
|
|
117
130
|
def perform(args):
|
|
131
|
+
global service
|
|
118
132
|
auth = Authentication(**json.loads(args["auth"]))
|
|
119
133
|
api_key = args["api_key"]
|
|
120
134
|
return dataclasses.asdict(service.delete_account(auth, api_key))
|
|
@@ -125,6 +139,7 @@ def handle_delete_account():
|
|
|
125
139
|
@app.get("/api/account")
|
|
126
140
|
def handle_get_account():
|
|
127
141
|
def perform(args):
|
|
142
|
+
global service
|
|
128
143
|
auth = Authentication(**json.loads(args["auth"]))
|
|
129
144
|
if "all" in args and args["all"].lower() == "true":
|
|
130
145
|
return [dataclasses.asdict(account) for account in service.get_accounts(auth)]
|
|
@@ -137,6 +152,7 @@ def handle_get_account():
|
|
|
137
152
|
@app.put("/api/account")
|
|
138
153
|
def handle_update_account():
|
|
139
154
|
def perform(args):
|
|
155
|
+
global service
|
|
140
156
|
auth = Authentication(**json.loads(args["auth"]))
|
|
141
157
|
account = from_dict(Account, json.loads(args["account"]))
|
|
142
158
|
return dataclasses.asdict(service.update_account(auth, account))
|
|
@@ -147,6 +163,7 @@ def handle_update_account():
|
|
|
147
163
|
@app.put("/api/account/api_key")
|
|
148
164
|
def handle_update_api_key():
|
|
149
165
|
def perform(args):
|
|
166
|
+
global service
|
|
150
167
|
auth = Authentication(**json.loads(args["auth"]))
|
|
151
168
|
account = from_dict(Account, json.loads(args["account"]))
|
|
152
169
|
return dataclasses.asdict(service.rotate_api_key(auth, account))
|
|
@@ -157,6 +174,7 @@ def handle_update_api_key():
|
|
|
157
174
|
@app.get("/api/query")
|
|
158
175
|
def handle_query():
|
|
159
176
|
def perform(args):
|
|
177
|
+
global service
|
|
160
178
|
query = Query(**args)
|
|
161
179
|
return dataclasses.asdict(service.expand_query(query))
|
|
162
180
|
|
|
@@ -166,9 +184,28 @@ def handle_query():
|
|
|
166
184
|
@app.get("/api/request")
|
|
167
185
|
def handle_request():
|
|
168
186
|
def perform(args):
|
|
187
|
+
global service
|
|
169
188
|
auth = Authentication(**json.loads(args["auth"]))
|
|
170
189
|
request = Request(**json.loads(args["request"]))
|
|
171
|
-
|
|
190
|
+
# Hack to maintain reverse compatibility with clients with version <= 0.3.0.
|
|
191
|
+
# Clients with version <= 0.3.0 do not set model_deployment, but this is now
|
|
192
|
+
# required by Request.
|
|
193
|
+
if not request.model_deployment:
|
|
194
|
+
model_deployment = get_default_model_deployment_for_model(request.model)
|
|
195
|
+
if model_deployment is None:
|
|
196
|
+
raise ValueError(f"Unknown model '{request.model}'")
|
|
197
|
+
request = dataclasses.replace(request, model_deployment=model_deployment)
|
|
198
|
+
|
|
199
|
+
raw_response = dataclasses.asdict(service.make_request(auth, request))
|
|
200
|
+
|
|
201
|
+
# Hack to maintain reverse compatibility with clients with version <= 1.0.0.
|
|
202
|
+
# Clients with version <= 1.0.0 expect each token to contain a `top_logprobs`
|
|
203
|
+
# field of type dict.
|
|
204
|
+
for completion in raw_response["completions"]:
|
|
205
|
+
for token in completion["tokens"]:
|
|
206
|
+
token["top_logprobs"] = {}
|
|
207
|
+
|
|
208
|
+
return raw_response
|
|
172
209
|
|
|
173
210
|
return safe_call(perform)
|
|
174
211
|
|
|
@@ -176,6 +213,7 @@ def handle_request():
|
|
|
176
213
|
@app.get("/api/tokenize")
|
|
177
214
|
def handle_tokenization():
|
|
178
215
|
def perform(args):
|
|
216
|
+
global service
|
|
179
217
|
auth = Authentication(**json.loads(args["auth"]))
|
|
180
218
|
request = TokenizationRequest(**json.loads(args["request"]))
|
|
181
219
|
return dataclasses.asdict(service.tokenize(auth, request))
|
|
@@ -186,6 +224,7 @@ def handle_tokenization():
|
|
|
186
224
|
@app.get("/api/decode")
|
|
187
225
|
def handle_decode():
|
|
188
226
|
def perform(args):
|
|
227
|
+
global service
|
|
189
228
|
auth = Authentication(**json.loads(args["auth"]))
|
|
190
229
|
request = DecodeRequest(**json.loads(args["request"]))
|
|
191
230
|
return dataclasses.asdict(service.decode(auth, request))
|
|
@@ -196,6 +235,7 @@ def handle_decode():
|
|
|
196
235
|
@app.get("/api/toxicity")
|
|
197
236
|
def handle_toxicity_request():
|
|
198
237
|
def perform(args):
|
|
238
|
+
global service
|
|
199
239
|
auth = Authentication(**json.loads(args["auth"]))
|
|
200
240
|
request = PerspectiveAPIRequest(**json.loads(args["request"]))
|
|
201
241
|
return dataclasses.asdict(service.get_toxicity_scores(auth, request))
|
|
@@ -203,9 +243,21 @@ def handle_toxicity_request():
|
|
|
203
243
|
return safe_call(perform)
|
|
204
244
|
|
|
205
245
|
|
|
246
|
+
@app.get("/api/moderation")
|
|
247
|
+
def handle_moderation_request():
|
|
248
|
+
def perform(args):
|
|
249
|
+
global service
|
|
250
|
+
auth = Authentication(**json.loads(args["auth"]))
|
|
251
|
+
request = ModerationAPIRequest(**json.loads(args["request"]))
|
|
252
|
+
return dataclasses.asdict(service.get_moderation_results(auth, request))
|
|
253
|
+
|
|
254
|
+
return safe_call(perform)
|
|
255
|
+
|
|
256
|
+
|
|
206
257
|
@app.get("/api/shutdown")
|
|
207
258
|
def handle_shutdown():
|
|
208
259
|
def perform(args):
|
|
260
|
+
global service
|
|
209
261
|
auth = Authentication(**json.loads(args["auth"]))
|
|
210
262
|
service.shutdown(auth)
|
|
211
263
|
|
|
@@ -218,6 +270,7 @@ def main():
|
|
|
218
270
|
parser.add_argument("-p", "--port", type=int, help="What port to listen on", default=1959)
|
|
219
271
|
parser.add_argument("--ssl-key-file", type=str, help="Path to SSL key file")
|
|
220
272
|
parser.add_argument("--ssl-cert-file", type=str, help="Path to SSL cert file")
|
|
273
|
+
parser.add_argument("--ssl-ca-certs", type=str, help="Path to SSL CA certs")
|
|
221
274
|
parser.add_argument("-b", "--base-path", help="What directory has credentials, etc.", default="prod_env")
|
|
222
275
|
parser.add_argument("-w", "--workers", type=int, help="Number of worker processes to handle requests", default=8)
|
|
223
276
|
parser.add_argument("-t", "--timeout", type=int, help="Request timeout in seconds", default=5 * 60)
|
|
@@ -232,17 +285,29 @@ def main():
|
|
|
232
285
|
register_builtin_configs_from_helm_package()
|
|
233
286
|
register_configs_from_directory(args.base_path)
|
|
234
287
|
|
|
235
|
-
|
|
288
|
+
cache_backend_config: CacheBackendConfig
|
|
289
|
+
if args.mongo_uri:
|
|
290
|
+
cache_backend_config = MongoCacheBackendConfig(args.mongo_uri)
|
|
291
|
+
else:
|
|
292
|
+
sqlite_cache_path = os.path.join(args.base_path, CACHE_DIR)
|
|
293
|
+
ensure_directory_exists(sqlite_cache_path)
|
|
294
|
+
cache_backend_config = SqliteCacheBackendConfig(sqlite_cache_path)
|
|
295
|
+
|
|
296
|
+
service = ServerService(base_path=args.base_path, cache_backend_config=cache_backend_config)
|
|
236
297
|
|
|
237
298
|
gunicorn_args = {
|
|
238
299
|
"workers": args.workers,
|
|
239
300
|
"timeout": args.timeout,
|
|
240
301
|
"limit_request_line": 0, # Controls the maximum size of HTTP request line in bytes. 0 = unlimited.
|
|
241
302
|
}
|
|
242
|
-
if args.ssl_key_file
|
|
303
|
+
if args.ssl_key_file:
|
|
243
304
|
gunicorn_args["keyfile"] = args.ssl_key_file
|
|
305
|
+
if args.ssl_cert_file:
|
|
244
306
|
gunicorn_args["certfile"] = args.ssl_cert_file
|
|
307
|
+
if args.ssl_ca_certs:
|
|
308
|
+
gunicorn_args["ca_certs"] = args.ssl_ca_certs
|
|
245
309
|
|
|
246
310
|
# Clear arguments before running gunicorn as it also uses argparse
|
|
247
311
|
sys.argv = [sys.argv[0]]
|
|
312
|
+
app.config["crfm.proxy.outputpath"] = os.path.join(os.path.realpath(args.base_path), "cache", "output")
|
|
248
313
|
app.run(host="0.0.0.0", port=args.port, server="gunicorn", **gunicorn_args)
|
|
@@ -5,9 +5,15 @@ import urllib.parse
|
|
|
5
5
|
from dataclasses import asdict
|
|
6
6
|
from typing import Any, List, Optional
|
|
7
7
|
|
|
8
|
+
from helm.common.cache import CacheConfig
|
|
9
|
+
from helm.common.cache_backend_config import BlackHoleCacheBackendConfig
|
|
8
10
|
from helm.common.authentication import Authentication
|
|
11
|
+
from helm.common.moderations_api_request import ModerationAPIRequest, ModerationAPIRequestResult
|
|
9
12
|
from helm.common.critique_request import CritiqueRequest, CritiqueRequestResult
|
|
13
|
+
from helm.common.nudity_check_request import NudityCheckRequest, NudityCheckResult
|
|
14
|
+
from helm.common.file_upload_request import FileUploadRequest, FileUploadResult
|
|
10
15
|
from helm.common.perspective_api_request import PerspectiveAPIRequest, PerspectiveAPIRequestResult
|
|
16
|
+
from helm.common.clip_score_request import CLIPScoreRequest, CLIPScoreResult
|
|
11
17
|
from helm.common.tokenization_request import (
|
|
12
18
|
WindowServiceInfo,
|
|
13
19
|
TokenizationRequest,
|
|
@@ -27,6 +33,8 @@ class RemoteServiceError(Exception):
|
|
|
27
33
|
|
|
28
34
|
|
|
29
35
|
class RemoteService(Service):
|
|
36
|
+
NOT_SUPPORTED_ERROR: str = "Not supported through the remote service."
|
|
37
|
+
|
|
30
38
|
def __init__(self, base_url):
|
|
31
39
|
self.base_url: str = base_url
|
|
32
40
|
|
|
@@ -84,6 +92,15 @@ class RemoteService(Service):
|
|
|
84
92
|
RemoteService._check_response(response, request_json)
|
|
85
93
|
return from_dict(DecodeRequestResult, response)
|
|
86
94
|
|
|
95
|
+
def upload(self, auth: Authentication, request: FileUploadRequest) -> FileUploadResult:
|
|
96
|
+
raise NotImplementedError(self.NOT_SUPPORTED_ERROR)
|
|
97
|
+
|
|
98
|
+
def check_nudity(self, auth: Authentication, request: NudityCheckRequest) -> NudityCheckResult:
|
|
99
|
+
raise NotImplementedError(self.NOT_SUPPORTED_ERROR)
|
|
100
|
+
|
|
101
|
+
def compute_clip_score(self, auth: Authentication, request: CLIPScoreRequest) -> CLIPScoreResult:
|
|
102
|
+
raise NotImplementedError(self.NOT_SUPPORTED_ERROR)
|
|
103
|
+
|
|
87
104
|
def get_toxicity_scores(self, auth: Authentication, request: PerspectiveAPIRequest) -> PerspectiveAPIRequestResult:
|
|
88
105
|
request_json: str = json.dumps(asdict(request))
|
|
89
106
|
params = {
|
|
@@ -94,6 +111,16 @@ class RemoteService(Service):
|
|
|
94
111
|
RemoteService._check_response(response, request_json)
|
|
95
112
|
return from_dict(PerspectiveAPIRequestResult, response)
|
|
96
113
|
|
|
114
|
+
def get_moderation_results(self, auth: Authentication, request: ModerationAPIRequest) -> ModerationAPIRequestResult:
|
|
115
|
+
request_json: str = json.dumps(asdict(request))
|
|
116
|
+
params = {
|
|
117
|
+
"auth": json.dumps(asdict(auth)),
|
|
118
|
+
"request": request_json,
|
|
119
|
+
}
|
|
120
|
+
response = requests.get(f"{self.base_url}/api/moderation?{urllib.parse.urlencode(params)}").json()
|
|
121
|
+
RemoteService._check_response(response, request_json)
|
|
122
|
+
return from_dict(ModerationAPIRequestResult, response)
|
|
123
|
+
|
|
97
124
|
def make_critique_request(self, auth: Authentication, request: CritiqueRequest) -> CritiqueRequestResult:
|
|
98
125
|
raise NotImplementedError("make_critique_request is not supported by RemoteServer")
|
|
99
126
|
|
|
@@ -153,6 +180,10 @@ class RemoteService(Service):
|
|
|
153
180
|
# A ConnectionError is expected when shutting down the server.
|
|
154
181
|
pass
|
|
155
182
|
|
|
183
|
+
def get_cache_config(self, shard_name: str) -> CacheConfig:
|
|
184
|
+
"""Returns a CacheConfig"""
|
|
185
|
+
return BlackHoleCacheBackendConfig().get_cache_config(shard_name)
|
|
186
|
+
|
|
156
187
|
|
|
157
188
|
def add_service_args(parser: argparse.ArgumentParser):
|
|
158
189
|
"""Add command-line arguments to enable command-line utilities to specify how to connect to a remote server."""
|
|
@@ -2,10 +2,15 @@ import dataclasses
|
|
|
2
2
|
import os
|
|
3
3
|
import signal
|
|
4
4
|
from typing import List, Optional
|
|
5
|
-
from helm.common.cache_utils import build_cache_config
|
|
6
5
|
|
|
6
|
+
from helm.common.cache import CacheConfig
|
|
7
|
+
from helm.common.cache_backend_config import CacheBackendConfig, BlackHoleCacheBackendConfig
|
|
7
8
|
from helm.common.critique_request import CritiqueRequest, CritiqueRequestResult
|
|
8
9
|
from helm.common.authentication import Authentication
|
|
10
|
+
from helm.common.moderations_api_request import ModerationAPIRequest, ModerationAPIRequestResult
|
|
11
|
+
from helm.common.clip_score_request import CLIPScoreRequest, CLIPScoreResult
|
|
12
|
+
from helm.common.nudity_check_request import NudityCheckRequest, NudityCheckResult
|
|
13
|
+
from helm.common.file_upload_request import FileUploadRequest, FileUploadResult
|
|
9
14
|
from helm.common.general import ensure_directory_exists, parse_hocon, get_credentials
|
|
10
15
|
from helm.common.perspective_api_request import PerspectiveAPIRequest, PerspectiveAPIRequestResult
|
|
11
16
|
from helm.common.tokenization_request import (
|
|
@@ -18,16 +23,20 @@ from helm.common.tokenization_request import (
|
|
|
18
23
|
from helm.common.request import Request, RequestResult
|
|
19
24
|
from helm.common.hierarchical_logger import hlog
|
|
20
25
|
from helm.proxy.accounts import Accounts, Account
|
|
21
|
-
from helm.
|
|
22
|
-
from helm.
|
|
26
|
+
from helm.clients.auto_client import AutoClient
|
|
27
|
+
from helm.clients.moderation_api_client import ModerationAPIClient
|
|
28
|
+
from helm.clients.perspective_api_client import PerspectiveAPIClient
|
|
29
|
+
from helm.clients.image_generation.nudity_check_client import NudityCheckClient
|
|
30
|
+
from helm.clients.gcs_client import GCSClient
|
|
31
|
+
from helm.clients.clip_score_client import CLIPScoreClient
|
|
32
|
+
from helm.clients.toxicity_classifier_client import ToxicityClassifierClient
|
|
23
33
|
from helm.proxy.example_queries import example_queries
|
|
24
34
|
from helm.benchmark.model_metadata_registry import ALL_MODELS_METADATA
|
|
25
35
|
from helm.benchmark.model_deployment_registry import get_model_deployment_host_organization
|
|
26
36
|
from helm.proxy.query import Query, QueryResult
|
|
27
37
|
from helm.proxy.retry import retry_request
|
|
28
38
|
from helm.proxy.token_counters.auto_token_counter import AutoTokenCounter
|
|
29
|
-
from helm.
|
|
30
|
-
from helm.proxy.tokenizers.huggingface_tokenizer import HuggingFaceTokenizer
|
|
39
|
+
from helm.tokenizers.auto_tokenizer import AutoTokenizer
|
|
31
40
|
from .service import (
|
|
32
41
|
Service,
|
|
33
42
|
CACHE_DIR,
|
|
@@ -44,19 +53,32 @@ class ServerService(Service):
|
|
|
44
53
|
Main class that supports various functionality for the server.
|
|
45
54
|
"""
|
|
46
55
|
|
|
47
|
-
def __init__(
|
|
56
|
+
def __init__(
|
|
57
|
+
self,
|
|
58
|
+
base_path: str = "prod_env",
|
|
59
|
+
root_mode: bool = False,
|
|
60
|
+
cache_backend_config: CacheBackendConfig = BlackHoleCacheBackendConfig(),
|
|
61
|
+
):
|
|
62
|
+
ensure_directory_exists(base_path)
|
|
63
|
+
client_file_storage_path = os.path.join(base_path, CACHE_DIR)
|
|
64
|
+
ensure_directory_exists(client_file_storage_path)
|
|
65
|
+
|
|
48
66
|
credentials = get_credentials(base_path)
|
|
49
|
-
cache_path = os.path.join(base_path, CACHE_DIR)
|
|
50
|
-
ensure_directory_exists(cache_path)
|
|
51
67
|
accounts_path = os.path.join(base_path, ACCOUNTS_FILE)
|
|
52
68
|
|
|
53
|
-
self.
|
|
54
|
-
self.
|
|
55
|
-
|
|
56
|
-
self.token_counter = AutoTokenCounter(
|
|
69
|
+
self.cache_backend_config = cache_backend_config
|
|
70
|
+
self.client = AutoClient(credentials, client_file_storage_path, cache_backend_config)
|
|
71
|
+
self.tokenizer = AutoTokenizer(credentials, cache_backend_config)
|
|
72
|
+
self.token_counter = AutoTokenCounter(self.tokenizer)
|
|
57
73
|
self.accounts = Accounts(accounts_path, root_mode=root_mode)
|
|
58
|
-
|
|
74
|
+
|
|
75
|
+
# Lazily instantiate the following clients
|
|
76
|
+
self.moderation_api_client: Optional[ModerationAPIClient] = None
|
|
59
77
|
self.toxicity_classifier_client: Optional[ToxicityClassifierClient] = None
|
|
78
|
+
self.perspective_api_client: Optional[PerspectiveAPIClient] = None
|
|
79
|
+
self.nudity_check_client: Optional[NudityCheckClient] = None
|
|
80
|
+
self.clip_score_client: Optional[CLIPScoreClient] = None
|
|
81
|
+
self.gcs_client: Optional[GCSClient] = None
|
|
60
82
|
|
|
61
83
|
def get_general_info(self) -> GeneralInfo:
|
|
62
84
|
# Can't send release_dates in ModelMetadata bacause dates cannot be round-tripped to and from JSON easily.
|
|
@@ -91,6 +113,21 @@ class ServerService(Service):
|
|
|
91
113
|
requests.append(request)
|
|
92
114
|
return QueryResult(requests=requests)
|
|
93
115
|
|
|
116
|
+
def _get_model_group_for_model_deployment(self, model_deployment: str) -> str:
|
|
117
|
+
if model_deployment.startswith("openai/"):
|
|
118
|
+
if model_deployment.startswith("openai/code-"):
|
|
119
|
+
return "codex"
|
|
120
|
+
elif model_deployment.startswith("openai/dall-e-"):
|
|
121
|
+
return "dall_e"
|
|
122
|
+
elif model_deployment.startswith("openai/gpt-4-"):
|
|
123
|
+
return "gpt4"
|
|
124
|
+
else:
|
|
125
|
+
return "gpt3"
|
|
126
|
+
elif model_deployment.startswith("ai21/"):
|
|
127
|
+
return "jurassic"
|
|
128
|
+
else:
|
|
129
|
+
return get_model_deployment_host_organization(model_deployment)
|
|
130
|
+
|
|
94
131
|
def make_request(self, auth: Authentication, request: Request) -> RequestResult:
|
|
95
132
|
"""Actually make a request to an API."""
|
|
96
133
|
# TODO: try to invoke the API even if we're not authenticated, and if
|
|
@@ -98,9 +135,9 @@ class ServerService(Service):
|
|
|
98
135
|
# https://github.com/stanford-crfm/benchmarking/issues/56
|
|
99
136
|
|
|
100
137
|
self.accounts.authenticate(auth)
|
|
101
|
-
|
|
138
|
+
model_group: str = self._get_model_group_for_model_deployment(request.model_deployment)
|
|
102
139
|
# Make sure we can use
|
|
103
|
-
self.accounts.check_can_use(auth.api_key,
|
|
140
|
+
self.accounts.check_can_use(auth.api_key, model_group)
|
|
104
141
|
|
|
105
142
|
# Use!
|
|
106
143
|
request_result: RequestResult = self.client.make_request(request)
|
|
@@ -109,7 +146,7 @@ class ServerService(Service):
|
|
|
109
146
|
if not request_result.cached:
|
|
110
147
|
# Count the number of tokens used
|
|
111
148
|
count: int = self.token_counter.count_tokens(request, request_result.completions)
|
|
112
|
-
self.accounts.use(auth.api_key,
|
|
149
|
+
self.accounts.use(auth.api_key, model_group, count)
|
|
113
150
|
|
|
114
151
|
return request_result
|
|
115
152
|
|
|
@@ -123,6 +160,36 @@ class ServerService(Service):
|
|
|
123
160
|
self.accounts.authenticate(auth)
|
|
124
161
|
return self.tokenizer.decode(request)
|
|
125
162
|
|
|
163
|
+
def upload(self, auth: Authentication, request: FileUploadRequest) -> FileUploadResult:
|
|
164
|
+
"""Uploads a file to external storage."""
|
|
165
|
+
self.accounts.authenticate(auth)
|
|
166
|
+
|
|
167
|
+
if not self.gcs_client:
|
|
168
|
+
self.gcs_client = self.client.get_gcs_client()
|
|
169
|
+
|
|
170
|
+
assert self.gcs_client
|
|
171
|
+
return self.gcs_client.upload(request)
|
|
172
|
+
|
|
173
|
+
def check_nudity(self, auth: Authentication, request: NudityCheckRequest) -> NudityCheckResult:
|
|
174
|
+
"""Check for nudity."""
|
|
175
|
+
self.accounts.authenticate(auth)
|
|
176
|
+
|
|
177
|
+
if not self.nudity_check_client:
|
|
178
|
+
self.nudity_check_client = self.client.get_nudity_check_client()
|
|
179
|
+
|
|
180
|
+
assert self.nudity_check_client
|
|
181
|
+
return self.nudity_check_client.check_nudity(request)
|
|
182
|
+
|
|
183
|
+
def compute_clip_score(self, auth: Authentication, request: CLIPScoreRequest) -> CLIPScoreResult:
|
|
184
|
+
"""Computes CLIPScore for a given caption and image."""
|
|
185
|
+
self.accounts.authenticate(auth)
|
|
186
|
+
|
|
187
|
+
if not self.clip_score_client:
|
|
188
|
+
self.clip_score_client = self.client.get_clip_score_client()
|
|
189
|
+
|
|
190
|
+
assert self.clip_score_client
|
|
191
|
+
return self.clip_score_client.compute_score(request)
|
|
192
|
+
|
|
126
193
|
def get_toxicity_scores(self, auth: Authentication, request: PerspectiveAPIRequest) -> PerspectiveAPIRequestResult:
|
|
127
194
|
@retry_request
|
|
128
195
|
def get_toxicity_scores_with_retry(request: PerspectiveAPIRequest) -> PerspectiveAPIRequestResult:
|
|
@@ -133,6 +200,16 @@ class ServerService(Service):
|
|
|
133
200
|
self.accounts.authenticate(auth)
|
|
134
201
|
return get_toxicity_scores_with_retry(request)
|
|
135
202
|
|
|
203
|
+
def get_moderation_results(self, auth: Authentication, request: ModerationAPIRequest) -> ModerationAPIRequestResult:
|
|
204
|
+
@retry_request
|
|
205
|
+
def get_moderation_results_with_retry(request: ModerationAPIRequest) -> ModerationAPIRequestResult:
|
|
206
|
+
if not self.moderation_api_client:
|
|
207
|
+
self.moderation_api_client = self.client.get_moderation_api_client()
|
|
208
|
+
return self.moderation_api_client.get_moderation_results(request)
|
|
209
|
+
|
|
210
|
+
self.accounts.authenticate(auth)
|
|
211
|
+
return get_moderation_results_with_retry(request)
|
|
212
|
+
|
|
136
213
|
def make_critique_request(self, auth: Authentication, request: CritiqueRequest) -> CritiqueRequestResult:
|
|
137
214
|
self.accounts.authenticate(auth)
|
|
138
215
|
return self.client.get_critique_client().make_critique_request(request)
|
|
@@ -168,3 +245,6 @@ class ServerService(Service):
|
|
|
168
245
|
hlog(f"Shutting down server by killing its own process {pid}...")
|
|
169
246
|
os.kill(pid, signal.SIGTERM)
|
|
170
247
|
hlog("Done.")
|
|
248
|
+
|
|
249
|
+
def get_cache_config(self, shard_name: str) -> CacheConfig:
|
|
250
|
+
return self.cache_backend_config.get_cache_config(shard_name)
|
helm/proxy/services/service.py
CHANGED
|
@@ -5,7 +5,11 @@ from typing import Dict, List, Tuple, Any
|
|
|
5
5
|
|
|
6
6
|
from helm.common.general import parse_hocon
|
|
7
7
|
from helm.common.critique_request import CritiqueRequest, CritiqueRequestResult
|
|
8
|
+
from helm.common.clip_score_request import CLIPScoreRequest, CLIPScoreResult
|
|
9
|
+
from helm.common.file_upload_request import FileUploadResult, FileUploadRequest
|
|
10
|
+
from helm.common.nudity_check_request import NudityCheckRequest, NudityCheckResult
|
|
8
11
|
from helm.common.perspective_api_request import PerspectiveAPIRequestResult, PerspectiveAPIRequest
|
|
12
|
+
from helm.common.moderations_api_request import ModerationAPIRequest, ModerationAPIRequestResult
|
|
9
13
|
from helm.common.tokenization_request import (
|
|
10
14
|
WindowServiceInfo,
|
|
11
15
|
TokenizationRequest,
|
|
@@ -17,6 +21,7 @@ from helm.common.request import Request, RequestResult
|
|
|
17
21
|
from helm.benchmark.model_metadata_registry import ModelMetadata
|
|
18
22
|
from helm.proxy.query import Query, QueryResult
|
|
19
23
|
from helm.proxy.accounts import Authentication, Account
|
|
24
|
+
from helm.common.cache import CacheConfig
|
|
20
25
|
|
|
21
26
|
VERSION = "1.0"
|
|
22
27
|
ACCOUNTS_FILE = "accounts.sqlite"
|
|
@@ -105,11 +110,31 @@ class Service(ABC):
|
|
|
105
110
|
"""Decodes to text."""
|
|
106
111
|
pass
|
|
107
112
|
|
|
113
|
+
@abstractmethod
|
|
114
|
+
def upload(self, auth: Authentication, request: FileUploadRequest) -> FileUploadResult:
|
|
115
|
+
"""Uploads a file to external storage."""
|
|
116
|
+
pass
|
|
117
|
+
|
|
118
|
+
@abstractmethod
|
|
119
|
+
def check_nudity(self, auth: Authentication, request: NudityCheckRequest) -> NudityCheckResult:
|
|
120
|
+
"""Check for nudity for a batch of images."""
|
|
121
|
+
pass
|
|
122
|
+
|
|
123
|
+
@abstractmethod
|
|
124
|
+
def compute_clip_score(self, auth: Authentication, request: CLIPScoreRequest) -> CLIPScoreResult:
|
|
125
|
+
"""Computes CLIPScore for a given caption and image."""
|
|
126
|
+
pass
|
|
127
|
+
|
|
108
128
|
@abstractmethod
|
|
109
129
|
def get_toxicity_scores(self, auth: Authentication, request: PerspectiveAPIRequest) -> PerspectiveAPIRequestResult:
|
|
110
130
|
"""Get toxicity scores for a batch of text."""
|
|
111
131
|
pass
|
|
112
132
|
|
|
133
|
+
@abstractmethod
|
|
134
|
+
def get_moderation_results(self, auth: Authentication, request: ModerationAPIRequest) -> ModerationAPIRequestResult:
|
|
135
|
+
"""Get OpenAI's moderation results for some text."""
|
|
136
|
+
pass
|
|
137
|
+
|
|
113
138
|
@abstractmethod
|
|
114
139
|
def make_critique_request(self, auth: Authentication, request: CritiqueRequest) -> CritiqueRequestResult:
|
|
115
140
|
"""Get responses to a critique request."""
|
|
@@ -149,3 +174,8 @@ class Service(ABC):
|
|
|
149
174
|
def shutdown(self, auth: Authentication):
|
|
150
175
|
"""Shutdown server."""
|
|
151
176
|
pass
|
|
177
|
+
|
|
178
|
+
@abstractmethod
|
|
179
|
+
def get_cache_config(self, shard_name: str) -> CacheConfig:
|
|
180
|
+
"""Returns a CacheConfig"""
|
|
181
|
+
pass
|
|
@@ -17,7 +17,7 @@ from sqlitedict import SqliteDict
|
|
|
17
17
|
from helm.common.authentication import Authentication
|
|
18
18
|
from helm.common.request import Request, RequestResult
|
|
19
19
|
from helm.common.tokenization_request import TokenizationRequest, TokenizationRequestResult
|
|
20
|
-
from helm.proxy.accounts import Account
|
|
20
|
+
from helm.proxy.accounts import Account, set_default_quotas
|
|
21
21
|
from .remote_service import RemoteService
|
|
22
22
|
from .service import ACCOUNTS_FILE
|
|
23
23
|
|
|
@@ -55,6 +55,7 @@ class TestRemoteServerService:
|
|
|
55
55
|
|
|
56
56
|
with SqliteDict(os.path.join(path, ACCOUNTS_FILE)) as cache:
|
|
57
57
|
account: Account = Account(TestRemoteServerService._ADMIN_API_KEY, is_admin=True)
|
|
58
|
+
set_default_quotas(account)
|
|
58
59
|
cache[TestRemoteServerService._ADMIN_API_KEY] = asdict(account)
|
|
59
60
|
cache.commit()
|
|
60
61
|
return path
|
|
@@ -126,9 +127,9 @@ class TestRemoteServerService:
|
|
|
126
127
|
assert response.success
|
|
127
128
|
|
|
128
129
|
def test_tokenize(self):
|
|
129
|
-
request = TokenizationRequest(text="1 2 3", tokenizer="simple/
|
|
130
|
+
request = TokenizationRequest(text="1 2 3", tokenizer="simple/tokenizer1")
|
|
130
131
|
response: TokenizationRequestResult = self.service.tokenize(self.auth, request)
|
|
131
|
-
assert [token.value for token in response.tokens] == ["1", "2", "3"]
|
|
132
|
+
assert [token.value for token in response.tokens] == ["1", " ", "2", " ", "3"]
|
|
132
133
|
|
|
133
134
|
def test_make_request_plus_sign(self):
|
|
134
135
|
# Ensure + in prompt doesn't get replaced by a blank space
|
|
@@ -197,18 +197,6 @@ def helper_prod_test_service(request: Request, expected_text: str):
|
|
|
197
197
|
# Consistency of log probs
|
|
198
198
|
assert completion.logprob == sum(token.logprob for token in completion.tokens)
|
|
199
199
|
|
|
200
|
-
for token in completion.tokens[1:]:
|
|
201
|
-
assert len(token.top_logprobs) == request.top_k_per_token
|
|
202
|
-
|
|
203
|
-
# If generated token was one of the top, make sure has the right probability
|
|
204
|
-
if token.text in token.top_logprobs:
|
|
205
|
-
assert token.logprob == token.top_logprobs[token.text]
|
|
206
|
-
|
|
207
|
-
# If temperature = 0, then make sure we're getting the top probability token
|
|
208
|
-
if request.temperature == 0:
|
|
209
|
-
assert token.text in token.top_logprobs
|
|
210
|
-
assert token.logprob == max(token.top_logprobs.values())
|
|
211
|
-
|
|
212
200
|
# Make sure we get the expected_text in one of the completions
|
|
213
201
|
assert any(completion.text == expected_text for completion in result.completions)
|
|
214
202
|
|
|
@@ -0,0 +1,32 @@
|
|
|
1
|
+
import os
|
|
2
|
+
import pytest
|
|
3
|
+
import tempfile
|
|
4
|
+
|
|
5
|
+
from helm.proxy.accounts import Accounts, Authentication, InsufficientQuotaError, Usage
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class TestAutoTokenCounter:
|
|
9
|
+
def setup_method(self, method):
|
|
10
|
+
accounts_file = tempfile.NamedTemporaryFile(delete=False)
|
|
11
|
+
self.accounts_path: str = accounts_file.name
|
|
12
|
+
self.accounts = Accounts(self.accounts_path)
|
|
13
|
+
self.root_auth = Authentication(Accounts.DEFAULT_API_KEY)
|
|
14
|
+
|
|
15
|
+
def teardown_method(self, method):
|
|
16
|
+
os.remove(self.accounts_path)
|
|
17
|
+
|
|
18
|
+
def test_check_can_use(self):
|
|
19
|
+
model_group = "anthropic"
|
|
20
|
+
account = self.accounts.create_account(self.root_auth)
|
|
21
|
+
|
|
22
|
+
# Cannot use this account because no quota was added
|
|
23
|
+
with pytest.raises(InsufficientQuotaError):
|
|
24
|
+
self.accounts.check_can_use(account.api_key, model_group)
|
|
25
|
+
|
|
26
|
+
# Add monthly quota
|
|
27
|
+
account.usages[model_group] = {}
|
|
28
|
+
account.usages[model_group]["monthly"] = Usage(quota=1000)
|
|
29
|
+
self.accounts.update_account(self.root_auth, account)
|
|
30
|
+
|
|
31
|
+
# Now this account has quota and can be used
|
|
32
|
+
self.accounts.check_can_use(account.api_key, model_group)
|