crfm-helm 0.3.0__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/METADATA +144 -36
- crfm_helm-0.5.0.dist-info/RECORD +642 -0
- {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/WHEEL +1 -1
- helm/benchmark/adaptation/adapter_spec.py +37 -2
- helm/benchmark/adaptation/adapters/adapter.py +4 -42
- helm/benchmark/adaptation/adapters/adapter_factory.py +24 -27
- helm/benchmark/adaptation/adapters/binary_ranking_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/generation_adapter.py +2 -0
- helm/benchmark/adaptation/adapters/in_context_learning_adapter.py +21 -4
- helm/benchmark/adaptation/adapters/language_modeling_adapter.py +12 -5
- helm/benchmark/adaptation/adapters/multimodal/generation_multimodal_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/multimodal/in_context_learning_multimodal_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/multimodal/multiple_choice_joint_multimodal_adapter.py +104 -0
- helm/benchmark/adaptation/adapters/multimodal/test_in_context_learning_multimodal_adapter.py +5 -1
- helm/benchmark/adaptation/adapters/multiple_choice_joint_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/multiple_choice_separate_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/test_adapter.py +2 -1
- helm/benchmark/adaptation/adapters/test_generation_adapter.py +59 -14
- helm/benchmark/adaptation/adapters/test_language_modeling_adapter.py +40 -5
- helm/benchmark/adaptation/adapters/test_multiple_choice_joint_adapter.py +78 -10
- helm/benchmark/adaptation/common_adapter_specs.py +376 -0
- helm/benchmark/adaptation/prompt.py +7 -1
- helm/benchmark/adaptation/request_state.py +6 -1
- helm/benchmark/adaptation/scenario_state.py +6 -2
- helm/benchmark/annotation/annotator.py +43 -0
- helm/benchmark/annotation/annotator_factory.py +61 -0
- helm/benchmark/annotation/image2structure/image_compiler_annotator.py +88 -0
- helm/benchmark/annotation/image2structure/latex_compiler_annotator.py +59 -0
- helm/benchmark/annotation/image2structure/lilypond_compiler_annotator.py +84 -0
- helm/benchmark/annotation/image2structure/webpage_compiler_annotator.py +132 -0
- helm/benchmark/annotation/test_annotator_factory.py +26 -0
- helm/benchmark/annotation/test_dummy_annotator.py +44 -0
- helm/benchmark/annotation_executor.py +124 -0
- helm/benchmark/augmentations/cleva_perturbation.py +7 -14
- helm/benchmark/augmentations/contraction_expansion_perturbation.py +3 -3
- helm/benchmark/augmentations/contrast_sets_perturbation.py +0 -3
- helm/benchmark/augmentations/data_augmenter.py +0 -2
- helm/benchmark/augmentations/dialect_perturbation.py +2 -2
- helm/benchmark/augmentations/extra_space_perturbation.py +2 -2
- helm/benchmark/augmentations/filler_words_perturbation.py +2 -2
- helm/benchmark/augmentations/gender_perturbation.py +3 -3
- helm/benchmark/augmentations/lowercase_perturbation.py +2 -2
- helm/benchmark/augmentations/mild_mix_perturbation.py +2 -2
- helm/benchmark/augmentations/misspelling_perturbation.py +2 -2
- helm/benchmark/augmentations/person_name_perturbation.py +0 -7
- helm/benchmark/augmentations/perturbation.py +20 -7
- helm/benchmark/augmentations/perturbation_description.py +1 -1
- helm/benchmark/augmentations/space_perturbation.py +2 -2
- helm/benchmark/augmentations/suffix_perturbation.py +29 -0
- helm/benchmark/augmentations/synonym_perturbation.py +2 -2
- helm/benchmark/augmentations/test_perturbation.py +11 -7
- helm/benchmark/augmentations/translate_perturbation.py +30 -0
- helm/benchmark/augmentations/typos_perturbation.py +2 -2
- helm/benchmark/config_registry.py +38 -0
- helm/benchmark/executor.py +46 -16
- helm/benchmark/huggingface_registration.py +37 -7
- helm/benchmark/metrics/basic_metrics.py +172 -641
- helm/benchmark/metrics/bbq_metrics.py +3 -4
- helm/benchmark/metrics/bias_metrics.py +6 -6
- helm/benchmark/metrics/classification_metrics.py +11 -8
- helm/benchmark/metrics/cleva_accuracy_metrics.py +8 -5
- helm/benchmark/metrics/cleva_harms_metrics.py +2 -2
- helm/benchmark/metrics/code_metrics.py +4 -3
- helm/benchmark/metrics/code_metrics_helper.py +0 -2
- helm/benchmark/metrics/common_metric_specs.py +167 -0
- helm/benchmark/metrics/decodingtrust_fairness_metrics.py +72 -0
- helm/benchmark/metrics/decodingtrust_ood_knowledge_metrics.py +66 -0
- helm/benchmark/metrics/decodingtrust_privacy_metrics.py +101 -0
- helm/benchmark/metrics/decodingtrust_stereotype_bias_metrics.py +202 -0
- helm/benchmark/metrics/disinformation_metrics.py +6 -112
- helm/benchmark/metrics/dry_run_metrics.py +5 -3
- helm/benchmark/metrics/efficiency_metrics.py +206 -0
- helm/benchmark/metrics/evaluate_instances_metric.py +59 -0
- helm/benchmark/metrics/evaluate_reference_metrics.py +376 -0
- helm/benchmark/metrics/image_generation/aesthetics_metrics.py +54 -0
- helm/benchmark/metrics/image_generation/aesthetics_scorer.py +66 -0
- helm/benchmark/metrics/image_generation/clip_score_metrics.py +73 -0
- helm/benchmark/metrics/image_generation/denoised_runtime_metric.py +42 -0
- helm/benchmark/metrics/image_generation/detection_metrics.py +57 -0
- helm/benchmark/metrics/image_generation/detectors/base_detector.py +8 -0
- helm/benchmark/metrics/image_generation/detectors/vitdet.py +178 -0
- helm/benchmark/metrics/image_generation/efficiency_metrics.py +41 -0
- helm/benchmark/metrics/image_generation/fidelity_metrics.py +168 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/fractal_dimension_util.py +63 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/test_fractal_dimension_util.py +33 -0
- helm/benchmark/metrics/image_generation/fractal_dimension_metric.py +50 -0
- helm/benchmark/metrics/image_generation/gender_metrics.py +58 -0
- helm/benchmark/metrics/image_generation/image_critique_metrics.py +284 -0
- helm/benchmark/metrics/image_generation/lpips_metrics.py +82 -0
- helm/benchmark/metrics/image_generation/multi_scale_ssim_metrics.py +82 -0
- helm/benchmark/metrics/image_generation/nsfw_detector.py +96 -0
- helm/benchmark/metrics/image_generation/nsfw_metrics.py +103 -0
- helm/benchmark/metrics/image_generation/nudity_metrics.py +38 -0
- helm/benchmark/metrics/image_generation/photorealism_critique_metrics.py +153 -0
- helm/benchmark/metrics/image_generation/psnr_metrics.py +78 -0
- helm/benchmark/metrics/image_generation/q16/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/q16/q16_toxicity_detector.py +90 -0
- helm/benchmark/metrics/image_generation/q16/test_q16.py +18 -0
- helm/benchmark/metrics/image_generation/q16_toxicity_metrics.py +48 -0
- helm/benchmark/metrics/image_generation/skin_tone_metrics.py +164 -0
- helm/benchmark/metrics/image_generation/uiqi_metrics.py +92 -0
- helm/benchmark/metrics/image_generation/watermark/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/watermark/test_watermark_detector.py +16 -0
- helm/benchmark/metrics/image_generation/watermark/watermark_detector.py +87 -0
- helm/benchmark/metrics/image_generation/watermark_metrics.py +48 -0
- helm/benchmark/metrics/instruction_following_critique_metrics.py +3 -1
- helm/benchmark/metrics/language_modeling_metrics.py +99 -0
- helm/benchmark/metrics/machine_translation_metrics.py +5 -5
- helm/benchmark/metrics/metric.py +93 -172
- helm/benchmark/metrics/metric_name.py +0 -1
- helm/benchmark/metrics/metric_service.py +16 -0
- helm/benchmark/metrics/paraphrase_generation_metrics.py +3 -4
- helm/benchmark/metrics/ranking_metrics.py +6 -7
- helm/benchmark/metrics/reference_metric.py +148 -0
- helm/benchmark/metrics/summac/model_summac.py +0 -2
- helm/benchmark/metrics/summarization_metrics.py +8 -8
- helm/benchmark/metrics/test_classification_metrics.py +9 -6
- helm/benchmark/metrics/test_disinformation_metrics.py +78 -0
- helm/benchmark/metrics/test_evaluate_reference_metrics.py +30 -0
- helm/benchmark/metrics/test_metric.py +2 -2
- helm/benchmark/metrics/tokens/auto_token_cost_estimator.py +1 -1
- helm/benchmark/metrics/tokens/gooseai_token_cost_estimator.py +13 -3
- helm/benchmark/metrics/tokens/openai_token_cost_estimator.py +1 -1
- helm/benchmark/metrics/tokens/test_ai21_token_cost_estimator.py +2 -0
- helm/benchmark/metrics/tokens/test_openai_token_cost_estimator.py +9 -2
- helm/benchmark/metrics/toxicity_metrics.py +1 -1
- helm/benchmark/metrics/toxicity_utils.py +23 -0
- helm/benchmark/metrics/unitxt_metrics.py +81 -0
- helm/benchmark/metrics/vision_language/__init__.py +0 -0
- helm/benchmark/metrics/vision_language/emd_utils.py +341 -0
- helm/benchmark/metrics/vision_language/image_metrics.py +450 -0
- helm/benchmark/metrics/vision_language/image_utils.py +100 -0
- helm/benchmark/model_deployment_registry.py +164 -41
- helm/benchmark/model_metadata_registry.py +181 -35
- helm/benchmark/multi_gpu_runner.py +133 -0
- helm/benchmark/presentation/contamination.py +3 -3
- helm/benchmark/presentation/create_plots.py +8 -7
- helm/benchmark/presentation/run_display.py +50 -17
- helm/benchmark/presentation/schema.py +28 -46
- helm/benchmark/presentation/summarize.py +213 -96
- helm/benchmark/presentation/table.py +8 -8
- helm/benchmark/presentation/test_contamination.py +2 -2
- helm/benchmark/presentation/test_run_entry.py +14 -9
- helm/benchmark/presentation/test_summarize.py +5 -0
- helm/benchmark/run.py +66 -54
- helm/benchmark/run_expander.py +342 -31
- helm/benchmark/run_spec.py +93 -0
- helm/benchmark/run_spec_factory.py +162 -0
- helm/benchmark/run_specs/__init__.py +0 -0
- helm/benchmark/{run_specs.py → run_specs/classic_run_specs.py} +217 -1330
- helm/benchmark/run_specs/cleva_run_specs.py +277 -0
- helm/benchmark/run_specs/decodingtrust_run_specs.py +314 -0
- helm/benchmark/run_specs/heim_run_specs.py +623 -0
- helm/benchmark/run_specs/instruction_following_run_specs.py +129 -0
- helm/benchmark/run_specs/lite_run_specs.py +307 -0
- helm/benchmark/run_specs/simple_run_specs.py +104 -0
- helm/benchmark/run_specs/unitxt_run_specs.py +42 -0
- helm/benchmark/run_specs/vlm_run_specs.py +501 -0
- helm/benchmark/runner.py +116 -69
- helm/benchmark/runner_config_registry.py +21 -0
- helm/benchmark/scenarios/bbq_scenario.py +1 -1
- helm/benchmark/scenarios/bold_scenario.py +2 -2
- helm/benchmark/scenarios/cleva_scenario.py +43 -46
- helm/benchmark/scenarios/code_scenario.py +3 -2
- helm/benchmark/scenarios/commonsense_scenario.py +171 -191
- helm/benchmark/scenarios/decodingtrust_adv_demonstration_scenario.py +169 -0
- helm/benchmark/scenarios/decodingtrust_adv_robustness_scenario.py +121 -0
- helm/benchmark/scenarios/decodingtrust_fairness_scenario.py +77 -0
- helm/benchmark/scenarios/decodingtrust_machine_ethics_scenario.py +324 -0
- helm/benchmark/scenarios/decodingtrust_ood_robustness_scenario.py +204 -0
- helm/benchmark/scenarios/decodingtrust_privacy_scenario.py +559 -0
- helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +67 -0
- helm/benchmark/scenarios/decodingtrust_toxicity_prompts_scenario.py +78 -0
- helm/benchmark/scenarios/dialogue_scenarios.py +0 -1
- helm/benchmark/scenarios/entity_matching_scenario.py +1 -1
- helm/benchmark/scenarios/image_generation/__init__.py +0 -0
- helm/benchmark/scenarios/image_generation/common_syntactic_processes_scenario.py +105 -0
- helm/benchmark/scenarios/image_generation/cub200_scenario.py +95 -0
- helm/benchmark/scenarios/image_generation/daily_dalle_scenario.py +124 -0
- helm/benchmark/scenarios/image_generation/demographic_stereotypes_scenario.py +82 -0
- helm/benchmark/scenarios/image_generation/detection_scenario.py +83 -0
- helm/benchmark/scenarios/image_generation/draw_bench_scenario.py +74 -0
- helm/benchmark/scenarios/image_generation/i2p_scenario.py +57 -0
- helm/benchmark/scenarios/image_generation/landing_page_scenario.py +46 -0
- helm/benchmark/scenarios/image_generation/logos_scenario.py +223 -0
- helm/benchmark/scenarios/image_generation/magazine_cover_scenario.py +91 -0
- helm/benchmark/scenarios/image_generation/mental_disorders_scenario.py +46 -0
- helm/benchmark/scenarios/image_generation/mscoco_scenario.py +91 -0
- helm/benchmark/scenarios/image_generation/paint_skills_scenario.py +72 -0
- helm/benchmark/scenarios/image_generation/parti_prompts_scenario.py +94 -0
- helm/benchmark/scenarios/image_generation/radiology_scenario.py +42 -0
- helm/benchmark/scenarios/image_generation/relational_understanding_scenario.py +52 -0
- helm/benchmark/scenarios/image_generation/time_most_significant_historical_figures_scenario.py +124 -0
- helm/benchmark/scenarios/image_generation/winoground_scenario.py +62 -0
- helm/benchmark/scenarios/imdb_scenario.py +0 -1
- helm/benchmark/scenarios/legalbench_scenario.py +123 -0
- helm/benchmark/scenarios/live_qa_scenario.py +94 -0
- helm/benchmark/scenarios/lm_entry_scenario.py +185 -0
- helm/benchmark/scenarios/lsat_qa_scenario.py +4 -2
- helm/benchmark/scenarios/math_scenario.py +19 -2
- helm/benchmark/scenarios/medication_qa_scenario.py +60 -0
- helm/benchmark/scenarios/numeracy_scenario.py +3 -3
- helm/benchmark/scenarios/opinions_qa_scenario.py +6 -10
- helm/benchmark/scenarios/raft_scenario.py +2 -6
- helm/benchmark/scenarios/scenario.py +14 -2
- helm/benchmark/scenarios/simple_scenarios.py +122 -1
- helm/benchmark/scenarios/test_math_scenario.py +22 -0
- helm/benchmark/scenarios/test_scenario.py +6 -3
- helm/benchmark/scenarios/test_simple_scenarios.py +50 -0
- helm/benchmark/scenarios/thai_exam_scenario.py +135 -0
- helm/benchmark/scenarios/the_pile_scenario.py +6 -7
- helm/benchmark/scenarios/unitxt_scenario.py +56 -0
- helm/benchmark/scenarios/verifiability_judgment_scenario.py +3 -1
- helm/benchmark/scenarios/vicuna_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/bingo_scenario.py +103 -0
- helm/benchmark/scenarios/vision_language/hateful_memes_scenario.py +92 -0
- helm/benchmark/scenarios/vision_language/heim_human_eval_scenario.py +113 -0
- helm/benchmark/scenarios/vision_language/image2structure/__init__.py +0 -0
- helm/benchmark/scenarios/vision_language/image2structure/chart2csv_scenario.py +55 -0
- helm/benchmark/scenarios/vision_language/image2structure/image2structure_scenario.py +214 -0
- helm/benchmark/scenarios/vision_language/image2structure/latex_scenario.py +25 -0
- helm/benchmark/scenarios/vision_language/image2structure/musicsheet_scenario.py +20 -0
- helm/benchmark/scenarios/vision_language/image2structure/utils_latex.py +347 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/__init__.py +0 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/driver.py +84 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/jekyll_server.py +182 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/utils.py +31 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage_scenario.py +225 -0
- helm/benchmark/scenarios/vision_language/mementos_scenario.py +124 -0
- helm/benchmark/scenarios/vision_language/mme_scenario.py +145 -0
- helm/benchmark/scenarios/vision_language/mmmu_scenario.py +187 -0
- helm/benchmark/scenarios/vision_language/multipanelvqa_scenario.py +169 -0
- helm/benchmark/scenarios/vision_language/pope_scenario.py +104 -0
- helm/benchmark/scenarios/vision_language/seed_bench_scenario.py +129 -0
- helm/benchmark/scenarios/vision_language/unicorn_scenario.py +108 -0
- helm/benchmark/scenarios/vision_language/viz_wiz_scenario.py +107 -0
- helm/benchmark/scenarios/vision_language/vqa_scenario.py +1 -1
- helm/benchmark/scenarios/wmt_14_scenario.py +18 -18
- helm/benchmark/server.py +59 -2
- helm/benchmark/slurm_jobs.py +12 -0
- helm/benchmark/slurm_runner.py +79 -51
- helm/benchmark/static/benchmarking.js +3 -4
- helm/benchmark/static/contamination.yaml +1 -1
- helm/benchmark/static/images/organizations/together.png +0 -0
- helm/benchmark/static/json-urls.js +4 -0
- helm/benchmark/static/{schema.yaml → schema_classic.yaml} +346 -930
- helm/benchmark/static/schema_instruction_following.yaml +210 -0
- helm/benchmark/static/schema_lite.yaml +824 -0
- helm/benchmark/static/schema_mmlu.yaml +1507 -0
- helm/benchmark/static/schema_unitxt.yaml +428 -0
- helm/benchmark/static/schema_vlm.yaml +576 -0
- helm/benchmark/static_build/assets/01-694cb9b7.png +0 -0
- helm/benchmark/static_build/assets/ai21-0eb91ec3.png +0 -0
- helm/benchmark/static_build/assets/aleph-alpha-7ce10034.png +0 -0
- helm/benchmark/static_build/assets/anthropic-70d8bc39.png +0 -0
- helm/benchmark/static_build/assets/bigscience-7f0400c0.png +0 -0
- helm/benchmark/static_build/assets/cohere-3550c6cb.png +0 -0
- helm/benchmark/static_build/assets/crfm-logo-74391ab8.png +0 -0
- helm/benchmark/static_build/assets/eleutherai-b9451114.png +0 -0
- helm/benchmark/static_build/assets/google-06d997ad.png +0 -0
- helm/benchmark/static_build/assets/heim-logo-3e5e3aa4.png +0 -0
- helm/benchmark/static_build/assets/helm-logo-simple-2ed5400b.png +0 -0
- helm/benchmark/static_build/assets/helmhero-28e90f4d.png +0 -0
- helm/benchmark/static_build/assets/index-5088afcb.css +1 -0
- helm/benchmark/static_build/assets/index-d839df55.js +9 -0
- helm/benchmark/static_build/assets/meta-5580e9f1.png +0 -0
- helm/benchmark/static_build/assets/microsoft-f5ee5016.png +0 -0
- helm/benchmark/static_build/assets/mistral-18e1be23.png +0 -0
- helm/benchmark/static_build/assets/nvidia-86fa75c1.png +0 -0
- helm/benchmark/static_build/assets/openai-3f8653e4.png +0 -0
- helm/benchmark/static_build/assets/react-d4a0b69b.js +85 -0
- helm/benchmark/static_build/assets/recharts-6d337683.js +97 -0
- helm/benchmark/static_build/assets/tii-24de195c.png +0 -0
- helm/benchmark/static_build/assets/together-a665a35b.png +0 -0
- helm/benchmark/static_build/assets/tremor-54a99cc4.js +10 -0
- helm/benchmark/static_build/assets/tsinghua-keg-97d4b395.png +0 -0
- helm/benchmark/static_build/assets/vhelm-framework-cde7618a.png +0 -0
- helm/benchmark/static_build/assets/vhelm-model-6d812526.png +0 -0
- helm/benchmark/static_build/assets/yandex-38e09d70.png +0 -0
- helm/benchmark/static_build/config.js +4 -0
- helm/benchmark/static_build/index.html +20 -0
- helm/benchmark/test_data_preprocessor.py +3 -3
- helm/benchmark/test_model_deployment_definition.py +90 -0
- helm/benchmark/test_run_expander.py +1 -1
- helm/benchmark/tokenizer_config_registry.py +10 -14
- helm/benchmark/window_services/ai21_window_service.py +22 -33
- helm/benchmark/window_services/cohere_window_service.py +1 -63
- helm/benchmark/window_services/default_window_service.py +2 -35
- helm/benchmark/window_services/encoder_decoder_window_service.py +0 -11
- helm/benchmark/window_services/ice_window_service.py +0 -34
- helm/benchmark/window_services/image_generation/__init__.py +0 -0
- helm/benchmark/window_services/image_generation/clip_window_service.py +15 -0
- helm/benchmark/window_services/image_generation/lexica_search_window_service.py +9 -0
- helm/benchmark/window_services/image_generation/openai_dalle_window_service.py +9 -0
- helm/benchmark/window_services/image_generation/test_clip_window_service.py +29 -0
- helm/benchmark/window_services/image_generation/test_openai_dalle_window_service.py +30 -0
- helm/benchmark/window_services/local_window_service.py +21 -4
- helm/benchmark/window_services/no_decoding_window_service.py +32 -0
- helm/benchmark/window_services/test_anthropic_window_service.py +2 -1
- helm/benchmark/window_services/test_bloom_window_service.py +2 -1
- helm/benchmark/window_services/test_cohere_window_service.py +2 -1
- helm/benchmark/window_services/test_flan_t5_window_service.py +2 -1
- helm/benchmark/window_services/test_gpt2_window_service.py +2 -2
- helm/benchmark/window_services/test_gpt4_window_service.py +2 -1
- helm/benchmark/window_services/test_gptj_window_service.py +3 -2
- helm/benchmark/window_services/test_gptneox_window_service.py +3 -2
- helm/benchmark/window_services/test_ice_window_service.py +2 -1
- helm/benchmark/window_services/test_openai_window_service.py +2 -1
- helm/benchmark/window_services/test_opt_window_service.py +3 -2
- helm/benchmark/window_services/test_palmyra_window_service.py +2 -1
- helm/benchmark/window_services/test_t0pp_window_service.py +2 -1
- helm/benchmark/window_services/test_t511b_window_service.py +2 -1
- helm/benchmark/window_services/test_ul2_window_service.py +2 -1
- helm/benchmark/window_services/test_utils.py +3 -2
- helm/benchmark/window_services/test_yalm_window_service.py +2 -1
- helm/benchmark/window_services/window_service.py +42 -0
- helm/benchmark/window_services/window_service_factory.py +24 -269
- helm/benchmark/window_services/yalm_window_service.py +0 -27
- helm/clients/__init__.py +0 -0
- helm/{proxy/clients → clients}/ai21_client.py +5 -12
- helm/clients/aleph_alpha_client.py +112 -0
- helm/{proxy/clients → clients}/anthropic_client.py +213 -24
- helm/clients/auto_client.py +215 -0
- helm/clients/bedrock_client.py +128 -0
- helm/clients/bedrock_utils.py +72 -0
- helm/{proxy/clients → clients}/client.py +67 -55
- helm/clients/clip_score_client.py +49 -0
- helm/clients/clip_scorers/__init__.py +0 -0
- helm/clients/clip_scorers/base_clip_scorer.py +18 -0
- helm/clients/clip_scorers/clip_scorer.py +50 -0
- helm/clients/clip_scorers/multilingual_clip_scorer.py +50 -0
- helm/{proxy/clients → clients}/cohere_client.py +6 -17
- helm/clients/gcs_client.py +82 -0
- helm/{proxy/clients → clients}/google_client.py +7 -8
- helm/clients/google_translate_client.py +35 -0
- helm/{proxy/clients → clients}/http_model_client.py +6 -10
- helm/{proxy/clients → clients}/huggingface_client.py +134 -92
- helm/clients/image_generation/__init__.py +0 -0
- helm/clients/image_generation/adobe_vision_client.py +78 -0
- helm/clients/image_generation/aleph_alpha_image_generation_client.py +98 -0
- helm/clients/image_generation/cogview2/__init__.py +0 -0
- helm/clients/image_generation/cogview2/coglm_strategy.py +96 -0
- helm/clients/image_generation/cogview2/coglm_utils.py +82 -0
- helm/clients/image_generation/cogview2/sr_pipeline/__init__.py +15 -0
- helm/clients/image_generation/cogview2/sr_pipeline/direct_sr.py +96 -0
- helm/clients/image_generation/cogview2/sr_pipeline/dsr_model.py +254 -0
- helm/clients/image_generation/cogview2/sr_pipeline/dsr_sampling.py +190 -0
- helm/clients/image_generation/cogview2/sr_pipeline/iterative_sr.py +141 -0
- helm/clients/image_generation/cogview2/sr_pipeline/itersr_model.py +269 -0
- helm/clients/image_generation/cogview2/sr_pipeline/itersr_sampling.py +120 -0
- helm/clients/image_generation/cogview2/sr_pipeline/sr_group.py +42 -0
- helm/clients/image_generation/cogview2_client.py +191 -0
- helm/clients/image_generation/dalle2_client.py +192 -0
- helm/clients/image_generation/dalle3_client.py +108 -0
- helm/clients/image_generation/dalle_mini/__init__.py +3 -0
- helm/clients/image_generation/dalle_mini/data.py +442 -0
- helm/clients/image_generation/dalle_mini/model/__init__.py +5 -0
- helm/clients/image_generation/dalle_mini/model/configuration.py +175 -0
- helm/clients/image_generation/dalle_mini/model/modeling.py +1834 -0
- helm/clients/image_generation/dalle_mini/model/partitions.py +84 -0
- helm/clients/image_generation/dalle_mini/model/processor.py +63 -0
- helm/clients/image_generation/dalle_mini/model/text.py +251 -0
- helm/clients/image_generation/dalle_mini/model/tokenizer.py +9 -0
- helm/clients/image_generation/dalle_mini/model/utils.py +29 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/__init__.py +1 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/configuration_vqgan.py +40 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/convert_pt_model_to_jax.py +107 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/modeling_flax_vqgan.py +610 -0
- helm/clients/image_generation/dalle_mini_client.py +190 -0
- helm/clients/image_generation/deep_floyd_client.py +78 -0
- helm/clients/image_generation/huggingface_diffusers_client.py +249 -0
- helm/clients/image_generation/image_generation_client_utils.py +9 -0
- helm/clients/image_generation/lexica_client.py +86 -0
- helm/clients/image_generation/mindalle/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/__init__.py +216 -0
- helm/clients/image_generation/mindalle/models/stage1/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/stage1/layers.py +312 -0
- helm/clients/image_generation/mindalle/models/stage1/vqgan.py +103 -0
- helm/clients/image_generation/mindalle/models/stage2/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/stage2/layers.py +144 -0
- helm/clients/image_generation/mindalle/models/stage2/transformer.py +268 -0
- helm/clients/image_generation/mindalle/models/tokenizer.py +30 -0
- helm/clients/image_generation/mindalle/utils/__init__.py +3 -0
- helm/clients/image_generation/mindalle/utils/config.py +129 -0
- helm/clients/image_generation/mindalle/utils/sampling.py +149 -0
- helm/clients/image_generation/mindalle/utils/utils.py +89 -0
- helm/clients/image_generation/mindalle_client.py +115 -0
- helm/clients/image_generation/nudity_check_client.py +64 -0
- helm/clients/image_generation/together_image_generation_client.py +111 -0
- helm/{proxy/clients → clients}/lit_gpt_client.py +7 -5
- helm/{proxy/clients → clients}/megatron_client.py +13 -7
- helm/clients/mistral_client.py +134 -0
- helm/clients/moderation_api_client.py +109 -0
- helm/clients/open_lm_client.py +43 -0
- helm/clients/openai_client.py +302 -0
- helm/{proxy/clients → clients}/palmyra_client.py +15 -12
- helm/{proxy/clients → clients}/perspective_api_client.py +7 -8
- helm/clients/simple_client.py +64 -0
- helm/{proxy/clients → clients}/test_auto_client.py +15 -15
- helm/clients/test_client.py +100 -0
- helm/clients/test_huggingface_client.py +70 -0
- helm/clients/test_simple_client.py +19 -0
- helm/{proxy/clients → clients}/test_together_client.py +23 -12
- helm/{proxy/clients → clients}/together_client.py +18 -71
- helm/clients/vertexai_client.py +391 -0
- helm/clients/vision_language/__init__.py +0 -0
- helm/clients/vision_language/huggingface_vlm_client.py +104 -0
- helm/{proxy/clients → clients}/vision_language/idefics_client.py +59 -52
- helm/clients/vision_language/open_flamingo/__init__.py +2 -0
- helm/clients/vision_language/open_flamingo/src/__init__.py +0 -0
- helm/clients/vision_language/open_flamingo/src/factory.py +147 -0
- helm/clients/vision_language/open_flamingo/src/flamingo.py +337 -0
- helm/clients/vision_language/open_flamingo/src/flamingo_lm.py +155 -0
- helm/clients/vision_language/open_flamingo/src/helpers.py +267 -0
- helm/clients/vision_language/open_flamingo/src/utils.py +47 -0
- helm/clients/vision_language/open_flamingo_client.py +155 -0
- helm/clients/vision_language/qwen_vlm_client.py +171 -0
- helm/clients/vllm_client.py +46 -0
- helm/common/cache.py +24 -179
- helm/common/cache_backend_config.py +47 -0
- helm/common/clip_score_request.py +41 -0
- helm/common/concurrency.py +32 -0
- helm/common/credentials_utils.py +28 -0
- helm/common/file_caches/__init__.py +0 -0
- helm/common/file_caches/file_cache.py +16 -0
- helm/common/file_caches/local_file_cache.py +61 -0
- helm/common/file_caches/test_local_file_cache.py +25 -0
- helm/common/file_upload_request.py +27 -0
- helm/common/general.py +29 -10
- helm/common/image_generation_parameters.py +25 -0
- helm/common/images_utils.py +24 -1
- helm/common/key_value_store.py +113 -0
- helm/common/media_object.py +13 -0
- helm/common/moderations_api_request.py +71 -0
- helm/common/mongo_key_value_store.py +88 -0
- helm/common/multimodal_request_utils.py +31 -0
- helm/common/nudity_check_request.py +29 -0
- helm/common/object_spec.py +2 -2
- helm/common/request.py +36 -27
- helm/common/test_general.py +6 -0
- helm/common/tokenization_request.py +6 -3
- helm/config/__init__.py +0 -0
- helm/config/model_deployments.yaml +1942 -0
- helm/config/model_metadata.yaml +2201 -0
- helm/config/tokenizer_configs.yaml +362 -0
- helm/proxy/accounts.py +31 -4
- helm/proxy/critique/mechanical_turk_critique_importer.py +3 -0
- helm/proxy/critique/model_critique_client.py +13 -5
- helm/proxy/example_queries.py +29 -17
- helm/proxy/retry.py +8 -2
- helm/proxy/server.py +77 -5
- helm/proxy/services/remote_service.py +31 -0
- helm/proxy/services/server_service.py +103 -20
- helm/proxy/services/service.py +34 -2
- helm/proxy/services/test_remote_service.py +7 -6
- helm/proxy/services/test_service.py +27 -18
- helm/proxy/test_accounts.py +32 -0
- helm/proxy/token_counters/auto_token_counter.py +37 -37
- helm/proxy/token_counters/test_auto_token_counter.py +164 -0
- helm/proxy/token_counters/token_counter.py +3 -5
- helm/py.typed +0 -0
- helm/tokenizers/__init__.py +0 -0
- helm/{proxy/tokenizers → tokenizers}/ai21_tokenizer.py +3 -3
- helm/{proxy/tokenizers → tokenizers}/aleph_alpha_tokenizer.py +3 -1
- helm/{proxy/tokenizers → tokenizers}/anthropic_tokenizer.py +17 -11
- helm/tokenizers/auto_tokenizer.py +93 -0
- helm/{proxy/tokenizers → tokenizers}/caching_tokenizer.py +8 -2
- helm/{proxy/tokenizers → tokenizers}/cohere_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/http_model_tokenizer.py +3 -3
- helm/{proxy/tokenizers → tokenizers}/huggingface_tokenizer.py +56 -60
- helm/tokenizers/simple_tokenizer.py +33 -0
- helm/tokenizers/test_anthropic_tokenizer.py +82 -0
- helm/tokenizers/test_huggingface_tokenizer.py +136 -0
- helm/tokenizers/test_simple_tokenizer.py +33 -0
- helm/tokenizers/vertexai_tokenizer.py +97 -0
- helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer.py +5 -3
- helm/tokenizers/yalm_tokenizer_data/__init__.py +0 -0
- helm/tokenizers/yalm_tokenizer_data/voc_100b.sp +0 -0
- helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/yalm_tokenizer.py +1 -1
- crfm_helm-0.3.0.dist-info/RECORD +0 -396
- helm/benchmark/vlm_run_specs.py +0 -71
- helm/benchmark/window_services/anthropic_window_service.py +0 -68
- helm/benchmark/window_services/bloom_window_service.py +0 -35
- helm/benchmark/window_services/flan_t5_window_service.py +0 -29
- helm/benchmark/window_services/gpt2_window_service.py +0 -32
- helm/benchmark/window_services/gptj_window_service.py +0 -38
- helm/benchmark/window_services/gptneox_window_service.py +0 -41
- helm/benchmark/window_services/http_model_window_service.py +0 -28
- helm/benchmark/window_services/huggingface_window_service.py +0 -59
- helm/benchmark/window_services/lit_gpt_window_service.py +0 -27
- helm/benchmark/window_services/llama_window_service.py +0 -28
- helm/benchmark/window_services/luminous_window_service.py +0 -67
- helm/benchmark/window_services/megatron_window_service.py +0 -10
- helm/benchmark/window_services/mt_nlg_window_service.py +0 -27
- helm/benchmark/window_services/openai_window_service.py +0 -13
- helm/benchmark/window_services/opt_window_service.py +0 -35
- helm/benchmark/window_services/palmyra_window_service.py +0 -45
- helm/benchmark/window_services/remote_window_service.py +0 -48
- helm/benchmark/window_services/santacoder_window_service.py +0 -27
- helm/benchmark/window_services/starcoder_window_service.py +0 -27
- helm/benchmark/window_services/t0pp_window_service.py +0 -35
- helm/benchmark/window_services/t511b_window_service.py +0 -30
- helm/benchmark/window_services/test_mt_nlg_window_service.py +0 -48
- helm/benchmark/window_services/ul2_window_service.py +0 -30
- helm/benchmark/window_services/wider_ai21_window_service.py +0 -24
- helm/benchmark/window_services/wider_openai_window_service.py +0 -52
- helm/proxy/clients/aleph_alpha_client.py +0 -99
- helm/proxy/clients/auto_client.py +0 -461
- helm/proxy/clients/goose_ai_client.py +0 -100
- helm/proxy/clients/microsoft_client.py +0 -182
- helm/proxy/clients/openai_client.py +0 -206
- helm/proxy/clients/remote_model_registry.py +0 -28
- helm/proxy/clients/simple_client.py +0 -61
- helm/proxy/clients/test_anthropic_client.py +0 -63
- helm/proxy/clients/test_client.py +0 -31
- helm/proxy/clients/test_huggingface_client.py +0 -87
- helm/proxy/models.py +0 -963
- helm/proxy/test_models.py +0 -27
- helm/proxy/token_counters/ai21_token_counter.py +0 -20
- helm/proxy/token_counters/cohere_token_counter.py +0 -13
- helm/proxy/token_counters/free_token_counter.py +0 -12
- helm/proxy/token_counters/gooseai_token_counter.py +0 -24
- helm/proxy/token_counters/openai_token_counter.py +0 -22
- helm/proxy/token_counters/test_ai21_token_counter.py +0 -86
- helm/proxy/token_counters/test_openai_token_counter.py +0 -79
- helm/proxy/tokenizers/simple_tokenizer.py +0 -32
- helm/proxy/tokenizers/test_huggingface_tokenizer.py +0 -56
- {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/LICENSE +0 -0
- {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/entry_points.txt +0 -0
- {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/top_level.txt +0 -0
- /helm/{proxy/clients → benchmark/annotation}/__init__.py +0 -0
- /helm/{proxy/clients/vision_language → benchmark/annotation/image2structure}/__init__.py +0 -0
- /helm/{proxy/tokenizers → benchmark/metrics/image_generation}/__init__.py +0 -0
- /helm/{proxy/tokenizers/yalm_tokenizer_data → benchmark/metrics/image_generation/detectors}/__init__.py +0 -0
- /helm/{proxy/clients → clients}/ai21_utils.py +0 -0
- /helm/{proxy/clients → clients}/cohere_utils.py +0 -0
- /helm/{proxy/clients → clients}/lit_gpt_generate.py +0 -0
- /helm/{proxy/clients → clients}/toxicity_classifier_client.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/ice_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/lit_gpt_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/test_ice_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/test_yalm_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/tiktoken_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/test_yalm_tokenizer.py +0 -0
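
Most of the path renames above come down to two package moves: client modules left `helm.proxy.clients` for `helm.clients`, and tokenizer modules left `helm.proxy.tokenizers` for `helm.tokenizers`. The following is a minimal sketch of what that means for downstream imports; the `HuggingFaceClient` and `HuggingFaceTokenizer` names are taken from the moved files listed above, and the exact symbols should be verified against the installed 0.5.0 wheel.

```python
# Sketch only: 0.3.0-style imports under helm.proxy no longer resolve in 0.5.0,
# because the modules were moved as shown in the rename list above.
# from helm.proxy.clients.huggingface_client import HuggingFaceClient          # 0.3.0 layout
# from helm.proxy.tokenizers.huggingface_tokenizer import HuggingFaceTokenizer  # 0.3.0 layout

# 0.5.0 layout (assumed from the renamed paths; verify against the installed package):
from helm.clients.huggingface_client import HuggingFaceClient
from helm.tokenizers.huggingface_tokenizer import HuggingFaceTokenizer
```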
--- a/helm/proxy/clients/microsoft_client.py
+++ /dev/null
@@ -1,182 +0,0 @@
-from typing import List, Optional, Dict
-
-from filelock import FileLock
-from openai.api_resources.abstract import engine_api_resource
-import openai as turing
-
-from helm.common.cache import CacheConfig
-from helm.common.request import (
-    wrap_request_time,
-    EMBEDDING_UNAVAILABLE_REQUEST_RESULT,
-    Request,
-    RequestResult,
-    Sequence,
-    Token,
-)
-from helm.proxy.tokenizers.tokenizer import Tokenizer
-from .client import CachingClient, truncate_sequence
-from .openai_client import ORIGINAL_COMPLETION_ATTRIBUTES
-
-
-class MicrosoftClient(CachingClient):
-    """
-    Client for the Microsoft's Megatron-Turing NLG models (https://arxiv.org/abs/2201.11990).
-
-    According to the internal documentation: https://github.com/microsoft/turing-academic-TNLG,
-    "the model will generate roughly 3 tokens per second. The response will be returned once
-    all tokens have been generated."
-    """
-
-    @staticmethod
-    def convert_to_raw_request(request: Request) -> Dict:
-        return {
-            "engine": request.model_engine,
-            "prompt": request.prompt,
-            "temperature": request.temperature,
-            "max_tokens": request.max_tokens,
-            "best_of": request.top_k_per_token,
-            "logprobs": request.top_k_per_token,
-            # Despite what was stated here: https://github.com/microsoft/turing-academic-TNLG#api-parameters,
-            # their API supports at most one stop sequence. Pass in the first one for now and handle the rest
-            # of the stop sequences during post processing (see `truncate_sequence` below).
-            "stop": None if len(request.stop_sequences) == 0 else request.stop_sequences[0],
-            "top_p": request.top_p,
-            "echo": request.echo_prompt,
-        }
-
-    def __init__(
-        self,
-        lock_file_path: str,
-        tokenizer: Tokenizer,
-        cache_config: CacheConfig,
-        api_key: Optional[str] = None,
-        org_id: Optional[str] = None,
-    ):
-        super().__init__(cache_config=cache_config, tokenizer=tokenizer)
-
-        # Adapted from their documentation: https://github.com/microsoft/turing-academic-TNLG
-        class EngineAPIResource(engine_api_resource.EngineAPIResource):
-            @classmethod
-            def class_url(
-                cls, engine: Optional[str] = None, api_type: Optional[str] = None, api_version: Optional[str] = None
-            ) -> str:
-                return f"/{engine}/inference"
-
-        self.org_id: Optional[str] = org_id
-        self.api_key: Optional[str] = api_key
-        self.api_base: str = "https://turingnlg-turingnlg-mstap-v2.turingase.p.azurewebsites.net"
-        self.completion_attributes = (EngineAPIResource,) + ORIGINAL_COMPLETION_ATTRIBUTES[1:]
-
-        # The Microsoft Turing server only allows a single request at a time, so acquire a
-        # process-safe lock before making a request.
-        # https://github.com/microsoft/turing-academic-TNLG#rate-limitations
-        #
-        # Since the model will generate roughly three tokens per second and the max context window
-        # is 2048 tokens, we expect the maximum time for a request to be fulfilled to be 700 seconds.
-        self._lock = FileLock(lock_file_path, timeout=700)
-
-    def make_request(self, request: Request) -> RequestResult:
-        """
-        Make a request for the Microsoft MT-NLG models.
-
-        They mimicked the OpenAI completions API, but not all the parameters are supported.
-
-        Supported parameters:
-            engine
-            prompt
-            temperature
-            max_tokens
-            best_of
-            logprobs
-            stop ("Only a single "stop" value (str) is currently supported.")
-            top_p
-            echo
-            n (Not originally supported, but we simulate n by making multiple requests)
-
-        Not supported parameters:
-            presence_penalty
-            frequency_penalty
-        """
-        # Embedding not supported for this model
-        if request.embedding:
-            return EMBEDDING_UNAVAILABLE_REQUEST_RESULT
-
-        raw_request = MicrosoftClient.convert_to_raw_request(request)
-        completions: List[Sequence] = []
-        request_time = 0
-        request_datetime: Optional[int] = None
-        all_cached = True
-
-        # API currently only supports 1 completion at a time, so we have to hit it multiple times.
-        for completion_index in range(request.num_completions):
-            try:
-
-                def do_it():
-                    with self._lock:
-                        # Following https://beta.openai.com/docs/api-reference/authentication
-                        # `organization` can be set to None.
-                        turing.organization = self.org_id
-                        turing.api_key = self.api_key
-                        turing.api_base = self.api_base
-                        turing.api_resources.completion.Completion.__bases__ = self.completion_attributes
-
-                        response: Dict = turing.Completion.create(**raw_request)
-                        # Validate the responses, so we don't cache malformed responses with null `logprobs` and `text`
-                        if (
-                            "choices" not in response
-                            or len(response["choices"]) == 0
-                            or response["choices"][0].get("text") is None
-                            or response["choices"][0].get("logprobs") is None
-                        ):
-                            raise turing.error.OpenAIError(
-                                f"For request: {raw_request}, invalid response from the MT-NLG server: {response}."
-                            )
-
-                        return response
-
-                def fail():
-                    raise RuntimeError(
-                        f"The result has not been uploaded to the cache for the following request: {cache_key}"
-                    )
-
-                # We want to make `request.num_completions` fresh requests,
-                # cache key should contain the completion_index.
-                cache_key = CachingClient.make_cache_key({"completion_index": completion_index, **raw_request}, request)
-                response, cached = self.cache.get(cache_key, wrap_request_time(do_it if self.api_key else fail))
-            except turing.error.OpenAIError as e:
-                error: str = f"OpenAI (Turing API) error: {e}"
-                return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])
-
-            for raw_completion in response["choices"]:
-                sequence_logprob = 0
-                tokens: List[Token] = []
-
-                raw_data = raw_completion["logprobs"]
-                for text, logprob, top_logprobs in zip(
-                    raw_data["tokens"], raw_data["token_logprobs"], raw_data["top_logprobs"]
-                ):
-                    tokens.append(Token(text=text, logprob=logprob or 0, top_logprobs=dict(top_logprobs or {})))
-                    sequence_logprob += logprob or 0
-
-                completion = Sequence(
-                    text=raw_completion["text"],
-                    logprob=sequence_logprob,
-                    tokens=tokens,
-                    finish_reason={"reason": raw_completion["finish_reason"]},
-                )
-                completion = truncate_sequence(completion, request)
-                completions.append(completion)
-
-            request_time += response["request_time"]
-            # Use the datetime from the first completion because that's when the request was fired
-            request_datetime = request_datetime or response.get("request_datetime")
-            all_cached = all_cached and cached
-
-        return RequestResult(
-            success=True,
-            cached=all_cached,
-            request_time=request_time,
-            request_datetime=request_datetime,
-            completions=completions,
-            embedding=[],
-        )
--- a/helm/proxy/clients/openai_client.py
+++ /dev/null
@@ -1,206 +0,0 @@
-# mypy: check_untyped_defs = False
-from dataclasses import replace
-from typing import Any, Dict, List, Optional, cast
-
-from helm.common.cache import CacheConfig
-from helm.common.request import wrap_request_time, Request, RequestResult, Sequence, Token
-from helm.common.hierarchical_logger import hlog
-from helm.common.optional_dependencies import handle_module_not_found_error
-from helm.common.tokenization_request import (
-    TokenizationRequest,
-    TokenizationRequestResult,
-)
-from helm.proxy.tokenizers.tokenizer import Tokenizer
-from .client import CachingClient, truncate_sequence
-
-try:
-    import openai
-    import tiktoken
-except ModuleNotFoundError as e:
-    handle_module_not_found_error(e, ["openai"])
-
-
-ORIGINAL_COMPLETION_ATTRIBUTES = openai.api_resources.completion.Completion.__bases__
-
-
-class OpenAIClient(CachingClient):
-    END_OF_TEXT: str = "<|endoftext|>"
-
-    def __init__(
-        self,
-        tokenizer: Tokenizer,
-        cache_config: CacheConfig,
-        api_key: Optional[str] = None,
-        org_id: Optional[str] = None,
-    ):
-        super().__init__(cache_config=cache_config, tokenizer=tokenizer)
-        self.org_id: Optional[str] = org_id
-        self.api_key: Optional[str] = api_key
-        self.api_base: str = "https://api.openai.com/v1"
-
-    def _is_chat_model_engine(self, model_engine: str):
-        return model_engine.startswith("gpt-3.5") or model_engine.startswith("gpt-4")
-
-    def make_request(self, request: Request) -> RequestResult:
-        if self.api_key is None:
-            raise ValueError("OpenAI API key is required")
-
-        raw_request: Dict[str, Any]
-        if request.embedding:
-            raw_request = {
-                "input": request.prompt,
-                "engine": request.model_engine,
-            }
-        elif self._is_chat_model_engine(request.model_engine):
-            messages: Optional[List[Dict[str, str]]] = request.messages
-            if request.messages and len(request.messages) > 1:
-                # Checks that all messages have a role and some content
-                for message in request.messages:
-                    if not message.get("role") or not message.get("content"):
-                        raise ValueError("All messages must have a role and content")
-                # Checks that the last role is "user"
-                if request.messages[-1]["role"] != "user":
-                    raise ValueError("Last message must have role 'user'")
-                if request.prompt != "":
-                    hlog("WARNING: Since message is set, prompt will be ignored")
-            else:
-                # Convert prompt into a single message
-                # For now, put the whole prompt in a single user message, and expect the response
-                # to be returned in a single assistant message.
-                # TODO: Support ChatML for creating multiple messages with different roles.
-                # See: https://github.com/openai/openai-python/blob/main/chatml.md
-                messages = [{"role": "user", "content": request.prompt}]
-            raw_request = {
-                "model": request.model_engine,
-                "messages": messages,
-                "temperature": request.temperature,
-                "top_p": request.top_p,
-                "n": request.num_completions,
-                "stop": request.stop_sequences or None,  # API doesn't like empty list
-                # Note: Chat models may require adding an extra token to max_tokens
-                # for the internal special role token.
-                "max_tokens": request.max_tokens,
-                "presence_penalty": request.presence_penalty,
-                "frequency_penalty": request.frequency_penalty,
-            }
-        else:
-            raw_request = {
-                "engine": request.model_engine,
-                "prompt": request.prompt,
-                "temperature": request.temperature,
-                "n": request.num_completions,
-                "max_tokens": request.max_tokens,
-                "best_of": request.top_k_per_token,
-                "logprobs": request.top_k_per_token,
-                "stop": request.stop_sequences or None,  # API doesn't like empty list
-                "top_p": request.top_p,
-                "presence_penalty": request.presence_penalty,
-                "frequency_penalty": request.frequency_penalty,
-                "echo": request.echo_prompt,
-            }
-
-            # OpenAI doesn't let you ask for more completions than the number of
-            # per-token candidates.
-            raw_request["best_of"] = max(raw_request["best_of"], raw_request["n"])
-            raw_request["logprobs"] = max(raw_request["logprobs"], raw_request["n"])
-
-        try:
-            if request.embedding:
-
-                def do_it():
-                    openai.organization = self.org_id
-                    openai.api_key = self.api_key
-                    openai.api_base = self.api_base
-                    return openai.Embedding.create(**raw_request)
-
-            elif self._is_chat_model_engine(request.model_engine):
-
-                def do_it():
-                    openai.organization = self.org_id
-                    openai.api_key = self.api_key
-                    openai.api_base = self.api_base
-                    return openai.ChatCompletion.create(**raw_request)
-
-            else:
-
-                def do_it():
-                    # Following https://beta.openai.com/docs/api-reference/authentication
-                    # `organization` can be set to None.
-                    openai.organization = self.org_id
-                    openai.api_key = self.api_key
-                    openai.api_base = self.api_base
-                    openai.api_resources.completion.Completion.__bases__ = ORIGINAL_COMPLETION_ATTRIBUTES
-                    return openai.Completion.create(**raw_request)
-
-            cache_key = CachingClient.make_cache_key(raw_request, request)
-            response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
-        except openai.error.OpenAIError as e:
-            error: str = f"OpenAI error: {e}"
-            return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])
-
-        # If the user is requesting completions instead of an embedding, then `completions`
-        # needs to be populated, and `embedding` should be an empty list and vice-versa.
-        embedding: List[float] = []
-        completions: List[Sequence] = []
-        tokens: List[Token]
-        if request.embedding:
-            # If the user is requesting an embedding instead of completion
-            # then completions would be left as an empty list. The embedding needs to be set.
-            embedding = response["data"][0]["embedding"]
-        elif self._is_chat_model_engine(request.model_engine):
-            for raw_completion in response["choices"]:
-                # The OpenAI chat completion API doesn't support echo.
-                # If `echo_prompt` is true, combine the prompt and completion.
-                raw_completion_content = raw_completion["message"]["content"]
-                text: str = request.prompt + raw_completion_content if request.echo_prompt else raw_completion_content
-                # The OpenAI chat completion API doesn't return us tokens or logprobs, so we tokenize ourselves.
-                tokenization_result: TokenizationRequestResult = self.tokenizer.tokenize(
-                    TokenizationRequest(
-                        text, tokenizer="openai/" + tiktoken.encoding_for_model(request.model_engine).name
-                    )
-                )
-                # Log probs are not currently not supported by the OpenAI chat completion API, so set to 0 for now.
-                tokens = [
-                    Token(text=cast(str, raw_token), logprob=0, top_logprobs={})
-                    for raw_token in tokenization_result.raw_tokens
-                ]
-                completion = Sequence(
-                    text=text,
-                    logprob=0,  # OpenAI does not provide logprobs
-                    tokens=tokens,
-                    finish_reason={"reason": raw_completion["finish_reason"]},
-                )
-                completions.append(truncate_sequence(completion, request))  # Truncate the text by stop sequences
-        else:
-            for raw_completion in response["choices"]:
-                sequence_logprob = 0
-                tokens = []
-
-                raw_data = raw_completion["logprobs"]
-                for text, logprob, top_logprobs in zip(
-                    raw_data["tokens"], raw_data["token_logprobs"], raw_data["top_logprobs"]
-                ):
-                    tokens.append(Token(text=text, logprob=logprob or 0, top_logprobs=dict(top_logprobs or {})))
-                    sequence_logprob += logprob or 0
-                completion = Sequence(
-                    text=raw_completion["text"],
-                    logprob=sequence_logprob,
-                    tokens=tokens,
-                    finish_reason={"reason": raw_completion["finish_reason"]},
-                )
-                # OpenAI sends us back tokens past the end of text token,
-                # so we need to manually truncate the list of tokens.
-                # TODO: filed an issue with their support to check what the expected behavior here is.
-                completion = truncate_sequence(
-                    completion, replace(request, stop_sequences=request.stop_sequences + [OpenAIClient.END_OF_TEXT])
-                )
-                completions.append(completion)
-
-        return RequestResult(
-            success=True,
-            cached=cached,
-            request_time=response["request_time"],
-            request_datetime=response.get("request_datetime"),
-            completions=completions,
-            embedding=embedding,
-        )
--- a/helm/proxy/clients/remote_model_registry.py
+++ /dev/null
@@ -1,28 +0,0 @@
-from typing import Dict, List, Optional
-
-from helm.proxy.models import Model
-from helm.proxy.services.remote_service import RemoteService
-
-
-_remote_model_registry: Dict[str, Model] = {}
-
-
-def get_remote_model(model_name: str) -> Optional[Model]:
-    """Returns a Model for the model_name."""
-    return _remote_model_registry.get(model_name)
-
-
-def check_and_register_remote_model(server_url: str, model_names: List[str]):
-    try:
-        service = RemoteService(server_url)
-        info = service.get_general_info()
-        models = {}
-        for model in info.all_models:
-            models[model.name] = model
-        for model_name in model_names:
-            if model_name in models:
-                _remote_model_registry[model_name] = models[model_name]
-            else:
-                raise RuntimeError(f"remote service not contain {model_name}")
-    except Exception as e:
-        raise RuntimeError(f"check and register remote service error: {e}")
--- a/helm/proxy/clients/simple_client.py
+++ /dev/null
@@ -1,61 +0,0 @@
-from typing import List, Dict
-
-from helm.common.cache import CacheConfig
-from helm.common.request import wrap_request_time, Request, RequestResult, Sequence, Token
-from helm.proxy.tokenizers.simple_tokenizer import SimpleTokenizer
-from helm.proxy.tokenizers.tokenizer import Tokenizer
-from .client import CachingClient
-
-
-class SimpleClient(CachingClient):
-    """Implements some "models" that just generate silly things quickly just to debug the infrastructure."""
-
-    def __init__(self, tokenizer: Tokenizer, cache_config: CacheConfig):
-        super().__init__(cache_config=cache_config, tokenizer=tokenizer)
-
-    def make_request(self, request: Request) -> RequestResult:
-        raw_request = {
-            "engine": request.model_engine,
-            "prompt": request.prompt,
-            "n": request.num_completions,
-        }
-
-        if request.model_engine == "model1":
-
-            def do_it():
-                return self.invoke_model1(raw_request)
-
-            cache_key = CachingClient.make_cache_key(raw_request, request)
-            response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
-            completions = [
-                Sequence(
-                    text=text,
-                    logprob=logprob,
-                    tokens=[Token(text=text, logprob=logprob, top_logprobs=response["completions"])],
-                )
-                for text, logprob in response["completions"].items()
-            ]
-        else:
-            raise ValueError(f"Invalid model: {request.model}")
-
-        return RequestResult(
-            success=True,
-            cached=False,
-            request_time=0,
-            request_datetime=response.get("request_datetime"),
-            completions=completions,
-            embedding=[],
-        )
-
-    def invoke_model1(self, raw_request: Dict) -> Dict:
-        """
-        Example: 7 2 4 6
-        Completions (num_completions = 3):
-        - 6
-        - 4
-        - 2
-        """
-        prompt_tokens: List[str] = SimpleTokenizer.tokenize_by_space(raw_request["prompt"])
-        choices = reversed(prompt_tokens[-raw_request["n"] :])
-        response = {"completions": dict((text, -i) for i, text in enumerate(choices))}
-        return response
@@ -1,63 +0,0 @@
-# mypy: check_untyped_defs = False
-import os
-import tempfile
-from typing import List
-
-from helm.common.cache import SqliteCacheConfig
-from helm.common.tokenization_request import (
-    DecodeRequest,
-    DecodeRequestResult,
-    TokenizationRequest,
-    TokenizationRequestResult,
-)
-from helm.proxy.tokenizers.anthropic_tokenizer import AnthropicTokenizer
-from .anthropic_client import AnthropicClient
-
-
-class TestAnthropicClient:
-    TEST_PROMPT: str = "I am a computer scientist."
-    TEST_ENCODED: List[int] = [45, 1413, 269, 6797, 22228, 18]
-    TEST_TOKENS: List[str] = ["I", " am", " a", " computer", " scientist", "."]
-
-    def setup_method(self, method):
-        cache_file = tempfile.NamedTemporaryFile(delete=False)
-        self.cache_path: str = cache_file.name
-        self.client = AnthropicClient(
-            tokenizer=AnthropicTokenizer(SqliteCacheConfig(self.cache_path)),
-            cache_config=SqliteCacheConfig(self.cache_path),
-        )
-
-    def teardown_method(self, method):
-        os.remove(self.cache_path)
-
-    def test_tokenize(self):
-        request = TokenizationRequest(text=self.TEST_PROMPT)
-        result: TokenizationRequestResult = self.client.tokenize(request)
-        assert not result.cached, "First time making the tokenize request. Result should not be cached"
-        assert result.raw_tokens == self.TEST_TOKENS
-        result: TokenizationRequestResult = self.client.tokenize(request)
-        assert result.cached, "Result should be cached"
-        assert result.raw_tokens == self.TEST_TOKENS
-
-    def test_encode(self):
-        request = TokenizationRequest(text=self.TEST_PROMPT, encode=True, truncation=True, max_length=1)
-        result: TokenizationRequestResult = self.client.tokenize(request)
-        assert not result.cached, "First time making the tokenize request. Result should not be cached"
-        assert result.raw_tokens == [self.TEST_ENCODED[0]]
-        result: TokenizationRequestResult = self.client.tokenize(request)
-        assert result.cached, "Result should be cached"
-        assert result.raw_tokens == [self.TEST_ENCODED[0]]
-
-        request = TokenizationRequest(text=self.TEST_PROMPT, encode=True, truncation=True, max_length=1024)
-        result = self.client.tokenize(request)
-        assert not result.cached, "First time making this particular request. Result should not be cached"
-        assert result.raw_tokens == self.TEST_ENCODED
-
-    def test_decode(self):
-        request = DecodeRequest(tokens=self.TEST_ENCODED)
-        result: DecodeRequestResult = self.client.decode(request)
-        assert not result.cached, "First time making the decode request. Result should not be cached"
-        assert result.text == self.TEST_PROMPT
-        result: DecodeRequestResult = self.client.decode(request)
-        assert result.cached, "Result should be cached"
-        assert result.text == self.TEST_PROMPT
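These deleted tests mostly assert the read-through caching behavior shared by HELM clients and tokenizers: the first identical request computes and stores a result, and the second returns the cached copy. A minimal sketch of that pattern (a hypothetical `get_or_compute` helper, not HELM's actual Cache API) is:

```python
from typing import Callable, Dict, Tuple

# Hypothetical read-through cache helper illustrating the behavior the deleted
# tests assert (first call computes, second call hits the cache).
# This is NOT HELM's actual Cache API.
def get_or_compute(cache: Dict[str, dict], key: str, compute: Callable[[], dict]) -> Tuple[dict, bool]:
    if key in cache:
        return cache[key], True   # cached=True on a repeated request
    response = compute()
    cache[key] = response
    return response, False        # cached=False the first time

cache: Dict[str, dict] = {}
_, cached = get_or_compute(cache, "tokenize:I am a computer scientist.", lambda: {"tokens": ["I", " am"]})
assert not cached
_, cached = get_or_compute(cache, "tokenize:I am a computer scientist.", lambda: {"tokens": ["I", " am"]})
assert cached
```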
@@ -1,31 +0,0 @@
-from .client import truncate_sequence
-from typing import List
-from helm.common.request import Request, Sequence, Token
-
-
-def truncate_sequence_helper(tokens: List[str], request: Request, expected_tokens: List[str]):
-    sequence = Sequence(
-        text="".join(tokens),
-        tokens=[Token(text=text, logprob=-1, top_logprobs={}) for text in tokens],
-        logprob=-len(tokens),
-    )
-
-    output_sequence = truncate_sequence(sequence, request)
-
-    assert expected_tokens == [token.text for token in output_sequence.tokens]
-    assert "".join(expected_tokens) == output_sequence.text
-    assert output_sequence.logprob == sum(token.logprob for token in output_sequence.tokens)
-
-
-def test_truncate_sequence():
-    # echo_prompt = True, nothing gets truncated
-    truncate_sequence_helper(["a", "b", "c"], Request(prompt="abc", echo_prompt=True), ["a", "b", "c"])
-
-    # Nothing gets truncated
-    truncate_sequence_helper(["hello", " world"], Request(stop_sequences=["#"]), ["hello", " world"])
-
-    # Truncate using stop sequences
-    truncate_sequence_helper(["hello", " world", "\n", "what"], Request(stop_sequences=["\n"]), ["hello", " world"])
-
-    # Truncate using max tokens
-    truncate_sequence_helper(["a", "b", "c"], Request(max_tokens=2), ["a", "b"])
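These tests pin down the truncation contract: with `echo_prompt` nothing is cut, stop sequences cut the completion at the first match, and `max_tokens` caps the number of tokens kept. A minimal sketch consistent with those expectations (not the actual `helm.proxy.clients.client.truncate_sequence` implementation) is:

```python
from typing import List

# Sketch of sequence truncation consistent with the deleted tests: keep
# everything when echo_prompt is set, otherwise cut at the first stop sequence
# and cap the token count. Not the actual HELM implementation.
def truncate_tokens(tokens: List[str], echo_prompt: bool, stop_sequences: List[str], max_tokens: int) -> List[str]:
    if echo_prompt:
        return tokens
    text = "".join(tokens)
    for stop in stop_sequences:
        if stop in text:
            text = text[: text.index(stop)]
    kept: List[str] = []
    length = 0
    for token in tokens:
        if length + len(token) > len(text) or len(kept) >= max_tokens:
            break
        kept.append(token)
        length += len(token)
    return kept

assert truncate_tokens(["hello", " world", "\n", "what"], False, ["\n"], 100) == ["hello", " world"]
assert truncate_tokens(["a", "b", "c"], False, [], 2) == ["a", "b"]
```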
@@ -1,87 +0,0 @@
-# mypy: check_untyped_defs = False
-import os
-import pytest
-import tempfile
-
-from helm.common.cache import SqliteCacheConfig
-from helm.common.request import Request, RequestResult
-from helm.common.tokenization_request import (
-    DecodeRequest,
-    DecodeRequestResult,
-    TokenizationRequest,
-    TokenizationRequestResult,
-)
-from helm.proxy.tokenizers.huggingface_tokenizer import HuggingFaceTokenizer
-from .huggingface_client import HuggingFaceClient
-
-
-class TestHuggingFaceClient:
-    def setup_method(self, method):
-        cache_file = tempfile.NamedTemporaryFile(delete=False)
-        self.cache_path: str = cache_file.name
-        self.client = HuggingFaceClient(
-            tokenizer=HuggingFaceTokenizer(SqliteCacheConfig(self.cache_path)),
-            cache_config=SqliteCacheConfig(self.cache_path),
-        )
-
-    def teardown_method(self, method):
-        os.remove(self.cache_path)
-
-    def test_tokenize(self):
-        request = TokenizationRequest(text="I am a computer scientist.")
-        result: TokenizationRequestResult = self.client.tokenize(request)
-        assert not result.cached, "First time making the tokenize request. Result should not be cached"
-        result: TokenizationRequestResult = self.client.tokenize(request)
-        assert result.cached, "Result should be cached"
-        assert result.raw_tokens == ["I", " am", " a", " computer", " scientist", "."]
-
-    def test_encode(self):
-        request = TokenizationRequest(text="I am a computer scientist.", encode=True, truncation=True, max_length=1)
-        result: TokenizationRequestResult = self.client.tokenize(request)
-        assert not result.cached, "First time making the tokenize request. Result should not be cached"
-        result: TokenizationRequestResult = self.client.tokenize(request)
-        assert result.cached, "Result should be cached"
-        assert result.raw_tokens == [40]
-
-        request = TokenizationRequest(text="I am a computer scientist.", encode=True, truncation=True, max_length=1024)
-        result = self.client.tokenize(request)
-        assert not result.cached, "First time making this particular request. Result should not be cached"
-        assert result.raw_tokens == [40, 716, 257, 3644, 11444, 13]
-
-    def test_decode(self):
-        request = DecodeRequest(tokens=[40, 716, 257, 3644, 11444, 13])
-        result: DecodeRequestResult = self.client.decode(request)
-        assert not result.cached, "First time making the decode request. Result should not be cached"
-        result: DecodeRequestResult = self.client.decode(request)
-        assert result.cached, "Result should be cached"
-        assert result.text == "I am a computer scientist."
-
-    def test_gpt2(self):
-        prompt: str = "I am a computer scientist."
-        result: RequestResult = self.client.make_request(
-            Request(
-                model="huggingface/gpt2",
-                prompt=prompt,
-                num_completions=3,
-                top_k_per_token=5,
-                max_tokens=0,
-                echo_prompt=True,
-            )
-        )
-        assert len(result.completions) == 3
-        assert result.completions[0].text.startswith(
-            prompt
-        ), "echo_prompt was set to true. Expected the prompt at the beginning of each completion"
-
-    @pytest.mark.skip(reason="GPT-J 6B is 22 GB and extremely slow without a GPU.")
-    def test_gptj_6b(self):
-        result: RequestResult = self.client.make_request(
-            Request(
-                model="huggingface/gpt-j-6b",
-                prompt="I am a computer scientist.",
-                num_completions=3,
-                top_k_per_token=5,
-                max_tokens=0,
-            )
-        )
-        assert len(result.completions) == 3
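The hard-coded ids in `test_encode` and `test_decode` are simply GPT-2's BPE encoding of the test prompt. They can be reproduced directly with the Hugging Face `transformers` tokenizer (assuming `transformers` is installed), independent of the HELM client wrapper:

```python
# Reproduces the token ids asserted by the deleted test using the Hugging Face
# transformers GPT-2 tokenizer directly (assumes `pip install transformers`).
from transformers import GPT2TokenizerFast

tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
ids = tokenizer.encode("I am a computer scientist.")
print(ids)                    # expected: [40, 716, 257, 3644, 11444, 13]
print(tokenizer.decode(ids))  # "I am a computer scientist."
```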