crfm-helm 0.3.0__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
- {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/METADATA +144 -36
- crfm_helm-0.5.0.dist-info/RECORD +642 -0
- {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/WHEEL +1 -1
- helm/benchmark/adaptation/adapter_spec.py +37 -2
- helm/benchmark/adaptation/adapters/adapter.py +4 -42
- helm/benchmark/adaptation/adapters/adapter_factory.py +24 -27
- helm/benchmark/adaptation/adapters/binary_ranking_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/generation_adapter.py +2 -0
- helm/benchmark/adaptation/adapters/in_context_learning_adapter.py +21 -4
- helm/benchmark/adaptation/adapters/language_modeling_adapter.py +12 -5
- helm/benchmark/adaptation/adapters/multimodal/generation_multimodal_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/multimodal/in_context_learning_multimodal_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/multimodal/multiple_choice_joint_multimodal_adapter.py +104 -0
- helm/benchmark/adaptation/adapters/multimodal/test_in_context_learning_multimodal_adapter.py +5 -1
- helm/benchmark/adaptation/adapters/multiple_choice_joint_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/multiple_choice_separate_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/test_adapter.py +2 -1
- helm/benchmark/adaptation/adapters/test_generation_adapter.py +59 -14
- helm/benchmark/adaptation/adapters/test_language_modeling_adapter.py +40 -5
- helm/benchmark/adaptation/adapters/test_multiple_choice_joint_adapter.py +78 -10
- helm/benchmark/adaptation/common_adapter_specs.py +376 -0
- helm/benchmark/adaptation/prompt.py +7 -1
- helm/benchmark/adaptation/request_state.py +6 -1
- helm/benchmark/adaptation/scenario_state.py +6 -2
- helm/benchmark/annotation/annotator.py +43 -0
- helm/benchmark/annotation/annotator_factory.py +61 -0
- helm/benchmark/annotation/image2structure/image_compiler_annotator.py +88 -0
- helm/benchmark/annotation/image2structure/latex_compiler_annotator.py +59 -0
- helm/benchmark/annotation/image2structure/lilypond_compiler_annotator.py +84 -0
- helm/benchmark/annotation/image2structure/webpage_compiler_annotator.py +132 -0
- helm/benchmark/annotation/test_annotator_factory.py +26 -0
- helm/benchmark/annotation/test_dummy_annotator.py +44 -0
- helm/benchmark/annotation_executor.py +124 -0
- helm/benchmark/augmentations/cleva_perturbation.py +7 -14
- helm/benchmark/augmentations/contraction_expansion_perturbation.py +3 -3
- helm/benchmark/augmentations/contrast_sets_perturbation.py +0 -3
- helm/benchmark/augmentations/data_augmenter.py +0 -2
- helm/benchmark/augmentations/dialect_perturbation.py +2 -2
- helm/benchmark/augmentations/extra_space_perturbation.py +2 -2
- helm/benchmark/augmentations/filler_words_perturbation.py +2 -2
- helm/benchmark/augmentations/gender_perturbation.py +3 -3
- helm/benchmark/augmentations/lowercase_perturbation.py +2 -2
- helm/benchmark/augmentations/mild_mix_perturbation.py +2 -2
- helm/benchmark/augmentations/misspelling_perturbation.py +2 -2
- helm/benchmark/augmentations/person_name_perturbation.py +0 -7
- helm/benchmark/augmentations/perturbation.py +20 -7
- helm/benchmark/augmentations/perturbation_description.py +1 -1
- helm/benchmark/augmentations/space_perturbation.py +2 -2
- helm/benchmark/augmentations/suffix_perturbation.py +29 -0
- helm/benchmark/augmentations/synonym_perturbation.py +2 -2
- helm/benchmark/augmentations/test_perturbation.py +11 -7
- helm/benchmark/augmentations/translate_perturbation.py +30 -0
- helm/benchmark/augmentations/typos_perturbation.py +2 -2
- helm/benchmark/config_registry.py +38 -0
- helm/benchmark/executor.py +46 -16
- helm/benchmark/huggingface_registration.py +37 -7
- helm/benchmark/metrics/basic_metrics.py +172 -641
- helm/benchmark/metrics/bbq_metrics.py +3 -4
- helm/benchmark/metrics/bias_metrics.py +6 -6
- helm/benchmark/metrics/classification_metrics.py +11 -8
- helm/benchmark/metrics/cleva_accuracy_metrics.py +8 -5
- helm/benchmark/metrics/cleva_harms_metrics.py +2 -2
- helm/benchmark/metrics/code_metrics.py +4 -3
- helm/benchmark/metrics/code_metrics_helper.py +0 -2
- helm/benchmark/metrics/common_metric_specs.py +167 -0
- helm/benchmark/metrics/decodingtrust_fairness_metrics.py +72 -0
- helm/benchmark/metrics/decodingtrust_ood_knowledge_metrics.py +66 -0
- helm/benchmark/metrics/decodingtrust_privacy_metrics.py +101 -0
- helm/benchmark/metrics/decodingtrust_stereotype_bias_metrics.py +202 -0
- helm/benchmark/metrics/disinformation_metrics.py +6 -112
- helm/benchmark/metrics/dry_run_metrics.py +5 -3
- helm/benchmark/metrics/efficiency_metrics.py +206 -0
- helm/benchmark/metrics/evaluate_instances_metric.py +59 -0
- helm/benchmark/metrics/evaluate_reference_metrics.py +376 -0
- helm/benchmark/metrics/image_generation/aesthetics_metrics.py +54 -0
- helm/benchmark/metrics/image_generation/aesthetics_scorer.py +66 -0
- helm/benchmark/metrics/image_generation/clip_score_metrics.py +73 -0
- helm/benchmark/metrics/image_generation/denoised_runtime_metric.py +42 -0
- helm/benchmark/metrics/image_generation/detection_metrics.py +57 -0
- helm/benchmark/metrics/image_generation/detectors/base_detector.py +8 -0
- helm/benchmark/metrics/image_generation/detectors/vitdet.py +178 -0
- helm/benchmark/metrics/image_generation/efficiency_metrics.py +41 -0
- helm/benchmark/metrics/image_generation/fidelity_metrics.py +168 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/fractal_dimension_util.py +63 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/test_fractal_dimension_util.py +33 -0
- helm/benchmark/metrics/image_generation/fractal_dimension_metric.py +50 -0
- helm/benchmark/metrics/image_generation/gender_metrics.py +58 -0
- helm/benchmark/metrics/image_generation/image_critique_metrics.py +284 -0
- helm/benchmark/metrics/image_generation/lpips_metrics.py +82 -0
- helm/benchmark/metrics/image_generation/multi_scale_ssim_metrics.py +82 -0
- helm/benchmark/metrics/image_generation/nsfw_detector.py +96 -0
- helm/benchmark/metrics/image_generation/nsfw_metrics.py +103 -0
- helm/benchmark/metrics/image_generation/nudity_metrics.py +38 -0
- helm/benchmark/metrics/image_generation/photorealism_critique_metrics.py +153 -0
- helm/benchmark/metrics/image_generation/psnr_metrics.py +78 -0
- helm/benchmark/metrics/image_generation/q16/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/q16/q16_toxicity_detector.py +90 -0
- helm/benchmark/metrics/image_generation/q16/test_q16.py +18 -0
- helm/benchmark/metrics/image_generation/q16_toxicity_metrics.py +48 -0
- helm/benchmark/metrics/image_generation/skin_tone_metrics.py +164 -0
- helm/benchmark/metrics/image_generation/uiqi_metrics.py +92 -0
- helm/benchmark/metrics/image_generation/watermark/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/watermark/test_watermark_detector.py +16 -0
- helm/benchmark/metrics/image_generation/watermark/watermark_detector.py +87 -0
- helm/benchmark/metrics/image_generation/watermark_metrics.py +48 -0
- helm/benchmark/metrics/instruction_following_critique_metrics.py +3 -1
- helm/benchmark/metrics/language_modeling_metrics.py +99 -0
- helm/benchmark/metrics/machine_translation_metrics.py +5 -5
- helm/benchmark/metrics/metric.py +93 -172
- helm/benchmark/metrics/metric_name.py +0 -1
- helm/benchmark/metrics/metric_service.py +16 -0
- helm/benchmark/metrics/paraphrase_generation_metrics.py +3 -4
- helm/benchmark/metrics/ranking_metrics.py +6 -7
- helm/benchmark/metrics/reference_metric.py +148 -0
- helm/benchmark/metrics/summac/model_summac.py +0 -2
- helm/benchmark/metrics/summarization_metrics.py +8 -8
- helm/benchmark/metrics/test_classification_metrics.py +9 -6
- helm/benchmark/metrics/test_disinformation_metrics.py +78 -0
- helm/benchmark/metrics/test_evaluate_reference_metrics.py +30 -0
- helm/benchmark/metrics/test_metric.py +2 -2
- helm/benchmark/metrics/tokens/auto_token_cost_estimator.py +1 -1
- helm/benchmark/metrics/tokens/gooseai_token_cost_estimator.py +13 -3
- helm/benchmark/metrics/tokens/openai_token_cost_estimator.py +1 -1
- helm/benchmark/metrics/tokens/test_ai21_token_cost_estimator.py +2 -0
- helm/benchmark/metrics/tokens/test_openai_token_cost_estimator.py +9 -2
- helm/benchmark/metrics/toxicity_metrics.py +1 -1
- helm/benchmark/metrics/toxicity_utils.py +23 -0
- helm/benchmark/metrics/unitxt_metrics.py +81 -0
- helm/benchmark/metrics/vision_language/__init__.py +0 -0
- helm/benchmark/metrics/vision_language/emd_utils.py +341 -0
- helm/benchmark/metrics/vision_language/image_metrics.py +450 -0
- helm/benchmark/metrics/vision_language/image_utils.py +100 -0
- helm/benchmark/model_deployment_registry.py +164 -41
- helm/benchmark/model_metadata_registry.py +181 -35
- helm/benchmark/multi_gpu_runner.py +133 -0
- helm/benchmark/presentation/contamination.py +3 -3
- helm/benchmark/presentation/create_plots.py +8 -7
- helm/benchmark/presentation/run_display.py +50 -17
- helm/benchmark/presentation/schema.py +28 -46
- helm/benchmark/presentation/summarize.py +213 -96
- helm/benchmark/presentation/table.py +8 -8
- helm/benchmark/presentation/test_contamination.py +2 -2
- helm/benchmark/presentation/test_run_entry.py +14 -9
- helm/benchmark/presentation/test_summarize.py +5 -0
- helm/benchmark/run.py +66 -54
- helm/benchmark/run_expander.py +342 -31
- helm/benchmark/run_spec.py +93 -0
- helm/benchmark/run_spec_factory.py +162 -0
- helm/benchmark/run_specs/__init__.py +0 -0
- helm/benchmark/{run_specs.py → run_specs/classic_run_specs.py} +217 -1330
- helm/benchmark/run_specs/cleva_run_specs.py +277 -0
- helm/benchmark/run_specs/decodingtrust_run_specs.py +314 -0
- helm/benchmark/run_specs/heim_run_specs.py +623 -0
- helm/benchmark/run_specs/instruction_following_run_specs.py +129 -0
- helm/benchmark/run_specs/lite_run_specs.py +307 -0
- helm/benchmark/run_specs/simple_run_specs.py +104 -0
- helm/benchmark/run_specs/unitxt_run_specs.py +42 -0
- helm/benchmark/run_specs/vlm_run_specs.py +501 -0
- helm/benchmark/runner.py +116 -69
- helm/benchmark/runner_config_registry.py +21 -0
- helm/benchmark/scenarios/bbq_scenario.py +1 -1
- helm/benchmark/scenarios/bold_scenario.py +2 -2
- helm/benchmark/scenarios/cleva_scenario.py +43 -46
- helm/benchmark/scenarios/code_scenario.py +3 -2
- helm/benchmark/scenarios/commonsense_scenario.py +171 -191
- helm/benchmark/scenarios/decodingtrust_adv_demonstration_scenario.py +169 -0
- helm/benchmark/scenarios/decodingtrust_adv_robustness_scenario.py +121 -0
- helm/benchmark/scenarios/decodingtrust_fairness_scenario.py +77 -0
- helm/benchmark/scenarios/decodingtrust_machine_ethics_scenario.py +324 -0
- helm/benchmark/scenarios/decodingtrust_ood_robustness_scenario.py +204 -0
- helm/benchmark/scenarios/decodingtrust_privacy_scenario.py +559 -0
- helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +67 -0
- helm/benchmark/scenarios/decodingtrust_toxicity_prompts_scenario.py +78 -0
- helm/benchmark/scenarios/dialogue_scenarios.py +0 -1
- helm/benchmark/scenarios/entity_matching_scenario.py +1 -1
- helm/benchmark/scenarios/image_generation/__init__.py +0 -0
- helm/benchmark/scenarios/image_generation/common_syntactic_processes_scenario.py +105 -0
- helm/benchmark/scenarios/image_generation/cub200_scenario.py +95 -0
- helm/benchmark/scenarios/image_generation/daily_dalle_scenario.py +124 -0
- helm/benchmark/scenarios/image_generation/demographic_stereotypes_scenario.py +82 -0
- helm/benchmark/scenarios/image_generation/detection_scenario.py +83 -0
- helm/benchmark/scenarios/image_generation/draw_bench_scenario.py +74 -0
- helm/benchmark/scenarios/image_generation/i2p_scenario.py +57 -0
- helm/benchmark/scenarios/image_generation/landing_page_scenario.py +46 -0
- helm/benchmark/scenarios/image_generation/logos_scenario.py +223 -0
- helm/benchmark/scenarios/image_generation/magazine_cover_scenario.py +91 -0
- helm/benchmark/scenarios/image_generation/mental_disorders_scenario.py +46 -0
- helm/benchmark/scenarios/image_generation/mscoco_scenario.py +91 -0
- helm/benchmark/scenarios/image_generation/paint_skills_scenario.py +72 -0
- helm/benchmark/scenarios/image_generation/parti_prompts_scenario.py +94 -0
- helm/benchmark/scenarios/image_generation/radiology_scenario.py +42 -0
- helm/benchmark/scenarios/image_generation/relational_understanding_scenario.py +52 -0
- helm/benchmark/scenarios/image_generation/time_most_significant_historical_figures_scenario.py +124 -0
- helm/benchmark/scenarios/image_generation/winoground_scenario.py +62 -0
- helm/benchmark/scenarios/imdb_scenario.py +0 -1
- helm/benchmark/scenarios/legalbench_scenario.py +123 -0
- helm/benchmark/scenarios/live_qa_scenario.py +94 -0
- helm/benchmark/scenarios/lm_entry_scenario.py +185 -0
- helm/benchmark/scenarios/lsat_qa_scenario.py +4 -2
- helm/benchmark/scenarios/math_scenario.py +19 -2
- helm/benchmark/scenarios/medication_qa_scenario.py +60 -0
- helm/benchmark/scenarios/numeracy_scenario.py +3 -3
- helm/benchmark/scenarios/opinions_qa_scenario.py +6 -10
- helm/benchmark/scenarios/raft_scenario.py +2 -6
- helm/benchmark/scenarios/scenario.py +14 -2
- helm/benchmark/scenarios/simple_scenarios.py +122 -1
- helm/benchmark/scenarios/test_math_scenario.py +22 -0
- helm/benchmark/scenarios/test_scenario.py +6 -3
- helm/benchmark/scenarios/test_simple_scenarios.py +50 -0
- helm/benchmark/scenarios/thai_exam_scenario.py +135 -0
- helm/benchmark/scenarios/the_pile_scenario.py +6 -7
- helm/benchmark/scenarios/unitxt_scenario.py +56 -0
- helm/benchmark/scenarios/verifiability_judgment_scenario.py +3 -1
- helm/benchmark/scenarios/vicuna_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/bingo_scenario.py +103 -0
- helm/benchmark/scenarios/vision_language/hateful_memes_scenario.py +92 -0
- helm/benchmark/scenarios/vision_language/heim_human_eval_scenario.py +113 -0
- helm/benchmark/scenarios/vision_language/image2structure/__init__.py +0 -0
- helm/benchmark/scenarios/vision_language/image2structure/chart2csv_scenario.py +55 -0
- helm/benchmark/scenarios/vision_language/image2structure/image2structure_scenario.py +214 -0
- helm/benchmark/scenarios/vision_language/image2structure/latex_scenario.py +25 -0
- helm/benchmark/scenarios/vision_language/image2structure/musicsheet_scenario.py +20 -0
- helm/benchmark/scenarios/vision_language/image2structure/utils_latex.py +347 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/__init__.py +0 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/driver.py +84 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/jekyll_server.py +182 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/utils.py +31 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage_scenario.py +225 -0
- helm/benchmark/scenarios/vision_language/mementos_scenario.py +124 -0
- helm/benchmark/scenarios/vision_language/mme_scenario.py +145 -0
- helm/benchmark/scenarios/vision_language/mmmu_scenario.py +187 -0
- helm/benchmark/scenarios/vision_language/multipanelvqa_scenario.py +169 -0
- helm/benchmark/scenarios/vision_language/pope_scenario.py +104 -0
- helm/benchmark/scenarios/vision_language/seed_bench_scenario.py +129 -0
- helm/benchmark/scenarios/vision_language/unicorn_scenario.py +108 -0
- helm/benchmark/scenarios/vision_language/viz_wiz_scenario.py +107 -0
- helm/benchmark/scenarios/vision_language/vqa_scenario.py +1 -1
- helm/benchmark/scenarios/wmt_14_scenario.py +18 -18
- helm/benchmark/server.py +59 -2
- helm/benchmark/slurm_jobs.py +12 -0
- helm/benchmark/slurm_runner.py +79 -51
- helm/benchmark/static/benchmarking.js +3 -4
- helm/benchmark/static/contamination.yaml +1 -1
- helm/benchmark/static/images/organizations/together.png +0 -0
- helm/benchmark/static/json-urls.js +4 -0
- helm/benchmark/static/{schema.yaml → schema_classic.yaml} +346 -930
- helm/benchmark/static/schema_instruction_following.yaml +210 -0
- helm/benchmark/static/schema_lite.yaml +824 -0
- helm/benchmark/static/schema_mmlu.yaml +1507 -0
- helm/benchmark/static/schema_unitxt.yaml +428 -0
- helm/benchmark/static/schema_vlm.yaml +576 -0
- helm/benchmark/static_build/assets/01-694cb9b7.png +0 -0
- helm/benchmark/static_build/assets/ai21-0eb91ec3.png +0 -0
- helm/benchmark/static_build/assets/aleph-alpha-7ce10034.png +0 -0
- helm/benchmark/static_build/assets/anthropic-70d8bc39.png +0 -0
- helm/benchmark/static_build/assets/bigscience-7f0400c0.png +0 -0
- helm/benchmark/static_build/assets/cohere-3550c6cb.png +0 -0
- helm/benchmark/static_build/assets/crfm-logo-74391ab8.png +0 -0
- helm/benchmark/static_build/assets/eleutherai-b9451114.png +0 -0
- helm/benchmark/static_build/assets/google-06d997ad.png +0 -0
- helm/benchmark/static_build/assets/heim-logo-3e5e3aa4.png +0 -0
- helm/benchmark/static_build/assets/helm-logo-simple-2ed5400b.png +0 -0
- helm/benchmark/static_build/assets/helmhero-28e90f4d.png +0 -0
- helm/benchmark/static_build/assets/index-5088afcb.css +1 -0
- helm/benchmark/static_build/assets/index-d839df55.js +9 -0
- helm/benchmark/static_build/assets/meta-5580e9f1.png +0 -0
- helm/benchmark/static_build/assets/microsoft-f5ee5016.png +0 -0
- helm/benchmark/static_build/assets/mistral-18e1be23.png +0 -0
- helm/benchmark/static_build/assets/nvidia-86fa75c1.png +0 -0
- helm/benchmark/static_build/assets/openai-3f8653e4.png +0 -0
- helm/benchmark/static_build/assets/react-d4a0b69b.js +85 -0
- helm/benchmark/static_build/assets/recharts-6d337683.js +97 -0
- helm/benchmark/static_build/assets/tii-24de195c.png +0 -0
- helm/benchmark/static_build/assets/together-a665a35b.png +0 -0
- helm/benchmark/static_build/assets/tremor-54a99cc4.js +10 -0
- helm/benchmark/static_build/assets/tsinghua-keg-97d4b395.png +0 -0
- helm/benchmark/static_build/assets/vhelm-framework-cde7618a.png +0 -0
- helm/benchmark/static_build/assets/vhelm-model-6d812526.png +0 -0
- helm/benchmark/static_build/assets/yandex-38e09d70.png +0 -0
- helm/benchmark/static_build/config.js +4 -0
- helm/benchmark/static_build/index.html +20 -0
- helm/benchmark/test_data_preprocessor.py +3 -3
- helm/benchmark/test_model_deployment_definition.py +90 -0
- helm/benchmark/test_run_expander.py +1 -1
- helm/benchmark/tokenizer_config_registry.py +10 -14
- helm/benchmark/window_services/ai21_window_service.py +22 -33
- helm/benchmark/window_services/cohere_window_service.py +1 -63
- helm/benchmark/window_services/default_window_service.py +2 -35
- helm/benchmark/window_services/encoder_decoder_window_service.py +0 -11
- helm/benchmark/window_services/ice_window_service.py +0 -34
- helm/benchmark/window_services/image_generation/__init__.py +0 -0
- helm/benchmark/window_services/image_generation/clip_window_service.py +15 -0
- helm/benchmark/window_services/image_generation/lexica_search_window_service.py +9 -0
- helm/benchmark/window_services/image_generation/openai_dalle_window_service.py +9 -0
- helm/benchmark/window_services/image_generation/test_clip_window_service.py +29 -0
- helm/benchmark/window_services/image_generation/test_openai_dalle_window_service.py +30 -0
- helm/benchmark/window_services/local_window_service.py +21 -4
- helm/benchmark/window_services/no_decoding_window_service.py +32 -0
- helm/benchmark/window_services/test_anthropic_window_service.py +2 -1
- helm/benchmark/window_services/test_bloom_window_service.py +2 -1
- helm/benchmark/window_services/test_cohere_window_service.py +2 -1
- helm/benchmark/window_services/test_flan_t5_window_service.py +2 -1
- helm/benchmark/window_services/test_gpt2_window_service.py +2 -2
- helm/benchmark/window_services/test_gpt4_window_service.py +2 -1
- helm/benchmark/window_services/test_gptj_window_service.py +3 -2
- helm/benchmark/window_services/test_gptneox_window_service.py +3 -2
- helm/benchmark/window_services/test_ice_window_service.py +2 -1
- helm/benchmark/window_services/test_openai_window_service.py +2 -1
- helm/benchmark/window_services/test_opt_window_service.py +3 -2
- helm/benchmark/window_services/test_palmyra_window_service.py +2 -1
- helm/benchmark/window_services/test_t0pp_window_service.py +2 -1
- helm/benchmark/window_services/test_t511b_window_service.py +2 -1
- helm/benchmark/window_services/test_ul2_window_service.py +2 -1
- helm/benchmark/window_services/test_utils.py +3 -2
- helm/benchmark/window_services/test_yalm_window_service.py +2 -1
- helm/benchmark/window_services/window_service.py +42 -0
- helm/benchmark/window_services/window_service_factory.py +24 -269
- helm/benchmark/window_services/yalm_window_service.py +0 -27
- helm/clients/__init__.py +0 -0
- helm/{proxy/clients → clients}/ai21_client.py +5 -12
- helm/clients/aleph_alpha_client.py +112 -0
- helm/{proxy/clients → clients}/anthropic_client.py +213 -24
- helm/clients/auto_client.py +215 -0
- helm/clients/bedrock_client.py +128 -0
- helm/clients/bedrock_utils.py +72 -0
- helm/{proxy/clients → clients}/client.py +67 -55
- helm/clients/clip_score_client.py +49 -0
- helm/clients/clip_scorers/__init__.py +0 -0
- helm/clients/clip_scorers/base_clip_scorer.py +18 -0
- helm/clients/clip_scorers/clip_scorer.py +50 -0
- helm/clients/clip_scorers/multilingual_clip_scorer.py +50 -0
- helm/{proxy/clients → clients}/cohere_client.py +6 -17
- helm/clients/gcs_client.py +82 -0
- helm/{proxy/clients → clients}/google_client.py +7 -8
- helm/clients/google_translate_client.py +35 -0
- helm/{proxy/clients → clients}/http_model_client.py +6 -10
- helm/{proxy/clients → clients}/huggingface_client.py +134 -92
- helm/clients/image_generation/__init__.py +0 -0
- helm/clients/image_generation/adobe_vision_client.py +78 -0
- helm/clients/image_generation/aleph_alpha_image_generation_client.py +98 -0
- helm/clients/image_generation/cogview2/__init__.py +0 -0
- helm/clients/image_generation/cogview2/coglm_strategy.py +96 -0
- helm/clients/image_generation/cogview2/coglm_utils.py +82 -0
- helm/clients/image_generation/cogview2/sr_pipeline/__init__.py +15 -0
- helm/clients/image_generation/cogview2/sr_pipeline/direct_sr.py +96 -0
- helm/clients/image_generation/cogview2/sr_pipeline/dsr_model.py +254 -0
- helm/clients/image_generation/cogview2/sr_pipeline/dsr_sampling.py +190 -0
- helm/clients/image_generation/cogview2/sr_pipeline/iterative_sr.py +141 -0
- helm/clients/image_generation/cogview2/sr_pipeline/itersr_model.py +269 -0
- helm/clients/image_generation/cogview2/sr_pipeline/itersr_sampling.py +120 -0
- helm/clients/image_generation/cogview2/sr_pipeline/sr_group.py +42 -0
- helm/clients/image_generation/cogview2_client.py +191 -0
- helm/clients/image_generation/dalle2_client.py +192 -0
- helm/clients/image_generation/dalle3_client.py +108 -0
- helm/clients/image_generation/dalle_mini/__init__.py +3 -0
- helm/clients/image_generation/dalle_mini/data.py +442 -0
- helm/clients/image_generation/dalle_mini/model/__init__.py +5 -0
- helm/clients/image_generation/dalle_mini/model/configuration.py +175 -0
- helm/clients/image_generation/dalle_mini/model/modeling.py +1834 -0
- helm/clients/image_generation/dalle_mini/model/partitions.py +84 -0
- helm/clients/image_generation/dalle_mini/model/processor.py +63 -0
- helm/clients/image_generation/dalle_mini/model/text.py +251 -0
- helm/clients/image_generation/dalle_mini/model/tokenizer.py +9 -0
- helm/clients/image_generation/dalle_mini/model/utils.py +29 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/__init__.py +1 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/configuration_vqgan.py +40 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/convert_pt_model_to_jax.py +107 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/modeling_flax_vqgan.py +610 -0
- helm/clients/image_generation/dalle_mini_client.py +190 -0
- helm/clients/image_generation/deep_floyd_client.py +78 -0
- helm/clients/image_generation/huggingface_diffusers_client.py +249 -0
- helm/clients/image_generation/image_generation_client_utils.py +9 -0
- helm/clients/image_generation/lexica_client.py +86 -0
- helm/clients/image_generation/mindalle/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/__init__.py +216 -0
- helm/clients/image_generation/mindalle/models/stage1/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/stage1/layers.py +312 -0
- helm/clients/image_generation/mindalle/models/stage1/vqgan.py +103 -0
- helm/clients/image_generation/mindalle/models/stage2/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/stage2/layers.py +144 -0
- helm/clients/image_generation/mindalle/models/stage2/transformer.py +268 -0
- helm/clients/image_generation/mindalle/models/tokenizer.py +30 -0
- helm/clients/image_generation/mindalle/utils/__init__.py +3 -0
- helm/clients/image_generation/mindalle/utils/config.py +129 -0
- helm/clients/image_generation/mindalle/utils/sampling.py +149 -0
- helm/clients/image_generation/mindalle/utils/utils.py +89 -0
- helm/clients/image_generation/mindalle_client.py +115 -0
- helm/clients/image_generation/nudity_check_client.py +64 -0
- helm/clients/image_generation/together_image_generation_client.py +111 -0
- helm/{proxy/clients → clients}/lit_gpt_client.py +7 -5
- helm/{proxy/clients → clients}/megatron_client.py +13 -7
- helm/clients/mistral_client.py +134 -0
- helm/clients/moderation_api_client.py +109 -0
- helm/clients/open_lm_client.py +43 -0
- helm/clients/openai_client.py +302 -0
- helm/{proxy/clients → clients}/palmyra_client.py +15 -12
- helm/{proxy/clients → clients}/perspective_api_client.py +7 -8
- helm/clients/simple_client.py +64 -0
- helm/{proxy/clients → clients}/test_auto_client.py +15 -15
- helm/clients/test_client.py +100 -0
- helm/clients/test_huggingface_client.py +70 -0
- helm/clients/test_simple_client.py +19 -0
- helm/{proxy/clients → clients}/test_together_client.py +23 -12
- helm/{proxy/clients → clients}/together_client.py +18 -71
- helm/clients/vertexai_client.py +391 -0
- helm/clients/vision_language/__init__.py +0 -0
- helm/clients/vision_language/huggingface_vlm_client.py +104 -0
- helm/{proxy/clients → clients}/vision_language/idefics_client.py +59 -52
- helm/clients/vision_language/open_flamingo/__init__.py +2 -0
- helm/clients/vision_language/open_flamingo/src/__init__.py +0 -0
- helm/clients/vision_language/open_flamingo/src/factory.py +147 -0
- helm/clients/vision_language/open_flamingo/src/flamingo.py +337 -0
- helm/clients/vision_language/open_flamingo/src/flamingo_lm.py +155 -0
- helm/clients/vision_language/open_flamingo/src/helpers.py +267 -0
- helm/clients/vision_language/open_flamingo/src/utils.py +47 -0
- helm/clients/vision_language/open_flamingo_client.py +155 -0
- helm/clients/vision_language/qwen_vlm_client.py +171 -0
- helm/clients/vllm_client.py +46 -0
- helm/common/cache.py +24 -179
- helm/common/cache_backend_config.py +47 -0
- helm/common/clip_score_request.py +41 -0
- helm/common/concurrency.py +32 -0
- helm/common/credentials_utils.py +28 -0
- helm/common/file_caches/__init__.py +0 -0
- helm/common/file_caches/file_cache.py +16 -0
- helm/common/file_caches/local_file_cache.py +61 -0
- helm/common/file_caches/test_local_file_cache.py +25 -0
- helm/common/file_upload_request.py +27 -0
- helm/common/general.py +29 -10
- helm/common/image_generation_parameters.py +25 -0
- helm/common/images_utils.py +24 -1
- helm/common/key_value_store.py +113 -0
- helm/common/media_object.py +13 -0
- helm/common/moderations_api_request.py +71 -0
- helm/common/mongo_key_value_store.py +88 -0
- helm/common/multimodal_request_utils.py +31 -0
- helm/common/nudity_check_request.py +29 -0
- helm/common/object_spec.py +2 -2
- helm/common/request.py +36 -27
- helm/common/test_general.py +6 -0
- helm/common/tokenization_request.py +6 -3
- helm/config/__init__.py +0 -0
- helm/config/model_deployments.yaml +1942 -0
- helm/config/model_metadata.yaml +2201 -0
- helm/config/tokenizer_configs.yaml +362 -0
- helm/proxy/accounts.py +31 -4
- helm/proxy/critique/mechanical_turk_critique_importer.py +3 -0
- helm/proxy/critique/model_critique_client.py +13 -5
- helm/proxy/example_queries.py +29 -17
- helm/proxy/retry.py +8 -2
- helm/proxy/server.py +77 -5
- helm/proxy/services/remote_service.py +31 -0
- helm/proxy/services/server_service.py +103 -20
- helm/proxy/services/service.py +34 -2
- helm/proxy/services/test_remote_service.py +7 -6
- helm/proxy/services/test_service.py +27 -18
- helm/proxy/test_accounts.py +32 -0
- helm/proxy/token_counters/auto_token_counter.py +37 -37
- helm/proxy/token_counters/test_auto_token_counter.py +164 -0
- helm/proxy/token_counters/token_counter.py +3 -5
- helm/py.typed +0 -0
- helm/tokenizers/__init__.py +0 -0
- helm/{proxy/tokenizers → tokenizers}/ai21_tokenizer.py +3 -3
- helm/{proxy/tokenizers → tokenizers}/aleph_alpha_tokenizer.py +3 -1
- helm/{proxy/tokenizers → tokenizers}/anthropic_tokenizer.py +17 -11
- helm/tokenizers/auto_tokenizer.py +93 -0
- helm/{proxy/tokenizers → tokenizers}/caching_tokenizer.py +8 -2
- helm/{proxy/tokenizers → tokenizers}/cohere_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/http_model_tokenizer.py +3 -3
- helm/{proxy/tokenizers → tokenizers}/huggingface_tokenizer.py +56 -60
- helm/tokenizers/simple_tokenizer.py +33 -0
- helm/tokenizers/test_anthropic_tokenizer.py +82 -0
- helm/tokenizers/test_huggingface_tokenizer.py +136 -0
- helm/tokenizers/test_simple_tokenizer.py +33 -0
- helm/tokenizers/vertexai_tokenizer.py +97 -0
- helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer.py +5 -3
- helm/tokenizers/yalm_tokenizer_data/__init__.py +0 -0
- helm/tokenizers/yalm_tokenizer_data/voc_100b.sp +0 -0
- helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/yalm_tokenizer.py +1 -1
- crfm_helm-0.3.0.dist-info/RECORD +0 -396
- helm/benchmark/vlm_run_specs.py +0 -71
- helm/benchmark/window_services/anthropic_window_service.py +0 -68
- helm/benchmark/window_services/bloom_window_service.py +0 -35
- helm/benchmark/window_services/flan_t5_window_service.py +0 -29
- helm/benchmark/window_services/gpt2_window_service.py +0 -32
- helm/benchmark/window_services/gptj_window_service.py +0 -38
- helm/benchmark/window_services/gptneox_window_service.py +0 -41
- helm/benchmark/window_services/http_model_window_service.py +0 -28
- helm/benchmark/window_services/huggingface_window_service.py +0 -59
- helm/benchmark/window_services/lit_gpt_window_service.py +0 -27
- helm/benchmark/window_services/llama_window_service.py +0 -28
- helm/benchmark/window_services/luminous_window_service.py +0 -67
- helm/benchmark/window_services/megatron_window_service.py +0 -10
- helm/benchmark/window_services/mt_nlg_window_service.py +0 -27
- helm/benchmark/window_services/openai_window_service.py +0 -13
- helm/benchmark/window_services/opt_window_service.py +0 -35
- helm/benchmark/window_services/palmyra_window_service.py +0 -45
- helm/benchmark/window_services/remote_window_service.py +0 -48
- helm/benchmark/window_services/santacoder_window_service.py +0 -27
- helm/benchmark/window_services/starcoder_window_service.py +0 -27
- helm/benchmark/window_services/t0pp_window_service.py +0 -35
- helm/benchmark/window_services/t511b_window_service.py +0 -30
- helm/benchmark/window_services/test_mt_nlg_window_service.py +0 -48
- helm/benchmark/window_services/ul2_window_service.py +0 -30
- helm/benchmark/window_services/wider_ai21_window_service.py +0 -24
- helm/benchmark/window_services/wider_openai_window_service.py +0 -52
- helm/proxy/clients/aleph_alpha_client.py +0 -99
- helm/proxy/clients/auto_client.py +0 -461
- helm/proxy/clients/goose_ai_client.py +0 -100
- helm/proxy/clients/microsoft_client.py +0 -182
- helm/proxy/clients/openai_client.py +0 -206
- helm/proxy/clients/remote_model_registry.py +0 -28
- helm/proxy/clients/simple_client.py +0 -61
- helm/proxy/clients/test_anthropic_client.py +0 -63
- helm/proxy/clients/test_client.py +0 -31
- helm/proxy/clients/test_huggingface_client.py +0 -87
- helm/proxy/models.py +0 -963
- helm/proxy/test_models.py +0 -27
- helm/proxy/token_counters/ai21_token_counter.py +0 -20
- helm/proxy/token_counters/cohere_token_counter.py +0 -13
- helm/proxy/token_counters/free_token_counter.py +0 -12
- helm/proxy/token_counters/gooseai_token_counter.py +0 -24
- helm/proxy/token_counters/openai_token_counter.py +0 -22
- helm/proxy/token_counters/test_ai21_token_counter.py +0 -86
- helm/proxy/token_counters/test_openai_token_counter.py +0 -79
- helm/proxy/tokenizers/simple_tokenizer.py +0 -32
- helm/proxy/tokenizers/test_huggingface_tokenizer.py +0 -56
- {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/LICENSE +0 -0
- {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/entry_points.txt +0 -0
- {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/top_level.txt +0 -0
- /helm/{proxy/clients → benchmark/annotation}/__init__.py +0 -0
- /helm/{proxy/clients/vision_language → benchmark/annotation/image2structure}/__init__.py +0 -0
- /helm/{proxy/tokenizers → benchmark/metrics/image_generation}/__init__.py +0 -0
- /helm/{proxy/tokenizers/yalm_tokenizer_data → benchmark/metrics/image_generation/detectors}/__init__.py +0 -0
- /helm/{proxy/clients → clients}/ai21_utils.py +0 -0
- /helm/{proxy/clients → clients}/cohere_utils.py +0 -0
- /helm/{proxy/clients → clients}/lit_gpt_generate.py +0 -0
- /helm/{proxy/clients → clients}/toxicity_classifier_client.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/ice_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/lit_gpt_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/test_ice_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/test_yalm_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/tiktoken_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/test_yalm_tokenizer.py +0 -0

helm/clients/mistral_client.py (new file)
@@ -0,0 +1,134 @@
+import requests
+from typing import Any, Dict, List, Optional, TypedDict
+
+from helm.proxy.retry import NonRetriableException
+from helm.common.cache import CacheConfig
+from helm.common.optional_dependencies import handle_module_not_found_error
+from helm.common.request import wrap_request_time, Request, RequestResult, GeneratedOutput
+from helm.tokenizers.tokenizer import Tokenizer
+from .client import CachingClient, truncate_and_tokenize_response_text
+
+try:
+    from mistralai.client import MistralClient
+    from mistralai.models.chat_completion import ChatMessage, ChatCompletionResponse
+except ModuleNotFoundError as e:
+    handle_module_not_found_error(e, ["mistral"])
+
+
+class MistralAIRequest(TypedDict):
+    """Data passed between make_request and _send_request. Used as the cache key."""
+
+    model: str
+    prompt: str
+    max_tokens: int
+    temperature: float
+    top_p: float
+    random_seed: Optional[int]
+
+
+class MistralAIClient(CachingClient):
+    """
+    Client for Mistral API.
+    """
+
+    def __init__(
+        self,
+        tokenizer: Tokenizer,
+        tokenizer_name: str,
+        cache_config: CacheConfig,
+        api_key: str,
+        mistral_model: Optional[str] = None,
+    ):
+        super().__init__(cache_config=cache_config)
+        self.api_key: str = api_key
+        self.tokenizer = tokenizer
+        self.tokenizer_name = tokenizer_name
+        self._client = MistralClient(api_key=self.api_key)
+        self.mistral_model = mistral_model
+
+    def _send_request(self, raw_request: MistralAIRequest) -> Dict[str, Any]:
+        messages = [ChatMessage(role="user", content=raw_request["prompt"])]
+
+        chat_response: ChatCompletionResponse = self._client.chat(
+            model=raw_request["model"],
+            messages=messages,
+            temperature=raw_request["temperature"],
+            max_tokens=raw_request["max_tokens"],
+            top_p=raw_request["top_p"],
+            random_seed=raw_request["random_seed"],
+            safe_prompt=False,  # Disable safe_prompt
+        )
+        # Documentation: "If mode is 'json', the output will only contain JSON serializable types."
+        # Source: https://docs.pydantic.dev/latest/api/base_model/#pydantic.BaseModel.model_dump
+        #
+        # We need to ensure that the output only contains JSON serializable types because the output
+        # will be serialized for storage in the cache.
+        return chat_response.model_dump(mode="json")
+
+    def _get_random_seed(self, request: Request, completion_index: int) -> Optional[int]:
+        if request.random is None and completion_index == 0:
+            return None
+
+        # Treat the user's request.random as an integer for the random seed.
+        try:
+            request_random_seed = int(request.random) if request.random is not None else 0
+        except ValueError:
+            raise NonRetriableException("MistralAIClient only supports integer values for request.random")
+
+        # A large prime is used so that the resulting values are unlikely to collide
+        # with request.random values chosen by the user.
+        fixed_large_prime = 1911011
+        completion_index_random_seed = completion_index * fixed_large_prime
+
+        return request_random_seed + completion_index_random_seed
+
+    def make_request(self, request: Request) -> RequestResult:
+        """Make a request"""
+        completions: List[GeneratedOutput] = []
+
+        # `num_completions` is not supported, so instead make `num_completions` separate requests.
+        for completion_index in range(request.num_completions):
+            try:
+                raw_request: MistralAIRequest = {
+                    "model": self.mistral_model or request.model_engine,
+                    "prompt": request.prompt,
+                    "max_tokens": request.max_tokens,
+                    "temperature": request.temperature,
+                    "top_p": request.top_p,
+                    "random_seed": self._get_random_seed(request, completion_index),
+                }
+
+                def do_it() -> Dict[str, Any]:
+                    result: Dict[str, Any] = self._send_request(raw_request)
+                    return result
+
+                # We need to include the engine's name to differentiate among requests made for different model
+                # engines since the engine name is not included in the request itself.
+                # In addition, we want to make `request.num_completions` fresh
+                # requests, cache key should contain the completion_index.
+                # Echoing the original prompt is not officially supported by Mistral. We instead prepend the
+                # completion with the prompt when `echo_prompt` is true, so keep track of it in the cache key.
+                cache_key = CachingClient.make_cache_key(raw_request, request)
+
+                response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
+            except (requests.exceptions.RequestException, AssertionError) as e:
+                error: str = f"MistralClient error: {e}"
+                return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])
+
+            response_message: Dict[str, Any] = response["choices"][0]["message"]
+            assert response_message["role"] == "assistant"
+            response_text: str = response_message["content"]
+
+            # The Mistral API doesn't support echo. If `echo_prompt` is true, combine the prompt and completion.
+            text: str = request.prompt + response_text if request.echo_prompt else response_text
+            sequence = truncate_and_tokenize_response_text(text, request, self.tokenizer, self.tokenizer_name)
+            completions.append(sequence)
+
+        return RequestResult(
+            success=True,
+            cached=cached,
+            request_time=response["request_time"],
+            request_datetime=response["request_datetime"],
+            completions=completions,
+            embedding=[],
+        )
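
In a normal HELM run this client is constructed by `AutoClient` from `helm/config/model_deployments.yaml`, but a minimal sketch of driving it directly is shown below. The tokenizer class and its constructor signature, the cache path, and the model and tokenizer names are illustrative assumptions, not values taken from this diff.

```python
# Minimal sketch of calling MistralAIClient directly (assumptions flagged inline).
from helm.common.cache import SqliteCacheConfig  # assumed importable from helm.common.cache
from helm.common.request import Request
from helm.clients.mistral_client import MistralAIClient
from helm.tokenizers.huggingface_tokenizer import HuggingFaceTokenizer  # assumed class/signature

cache_config = SqliteCacheConfig(path="mistral_cache.sqlite")  # illustrative cache location
tokenizer = HuggingFaceTokenizer(cache_config=cache_config)    # assumed constructor

client = MistralAIClient(
    tokenizer=tokenizer,
    tokenizer_name="mistralai/Mistral-7B-v0.1",  # illustrative tokenizer name
    cache_config=cache_config,
    api_key="...",  # your Mistral API key
)

request = Request(
    model="mistralai/mistral-tiny",             # illustrative model name
    model_deployment="mistralai/mistral-tiny",  # illustrative deployment name
    prompt="Write a haiku about held-out test sets.",
    temperature=0.7,
    max_tokens=64,
    num_completions=2,  # the client issues one API call per completion
)
result = client.make_request(request)
if result.success:
    for completion in result.completions:
        print(completion.text)
```

Because the Mistral API has no `num_completions` parameter, each completion gets its own call with a distinct `random_seed` (derived from `request.random` plus `completion_index` times a fixed large prime), which also keeps the per-completion cache entries distinct.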

helm/clients/moderation_api_client.py (new file)
@@ -0,0 +1,109 @@
+from typing import Any, Dict
+
+from helm.common.request import wrap_request_time
+from helm.common.cache import Cache, CacheConfig
+from helm.common.moderations_api_request import (
+    ModerationCategoryScores,
+    ModerationCategoryFlaggedResults,
+    ModerationAPIRequest,
+    ModerationAPIRequestResult,
+)
+from helm.common.optional_dependencies import handle_module_not_found_error
+
+
+class ModerationAPIClient:
+    """
+    From https://beta.openai.com/docs/guides/moderation/overview, the moderation endpoint is a tool
+    to check whether content complies with OpenAI's content policy. Developers can thus identify content
+    that OpenAI's content policy prohibits and take action, for instance by filtering it.
+    """
+
+    # For descriptions of the models, see https://beta.openai.com/docs/api-reference/moderations/create
+    LATEST_MODEL: str = "text-moderation-latest"
+    STABLE_MODEL: str = "text-moderation-stable"
+
+    # List of categories (https://beta.openai.com/docs/guides/moderation/overview)
+    HATE: str = "hate"
+    HATE_THREATENING: str = "hate/threatening"
+    SELF_HARM: str = "self-harm"
+    SEXUAL: str = "sexual"
+    SEXUAL_MINORS: str = "sexual/minors"
+    VIOLENCE: str = "violence"
+    VIOLENCE_GRAPHIC: str = "violence/graphic"
+
+    def __init__(self, api_key: str, cache_config: CacheConfig):
+        self.cache = Cache(cache_config)
+        try:
+            from openai import OpenAI
+        except ModuleNotFoundError as e:
+            handle_module_not_found_error(e, ["openai"])
+        # TODO: Add OpenAI organization.
+        self.client = OpenAI(api_key=api_key)
+
+    def get_moderation_results(self, request: ModerationAPIRequest) -> ModerationAPIRequestResult:
+        """
+        Sends a request to OpenAI's moderation endpoint.
+        https://beta.openai.com/docs/api-reference/moderations/create
+        """
+        try:
+            import openai
+        except ModuleNotFoundError as e:
+            handle_module_not_found_error(e, ["openai"])
+
+        raw_request: Dict[str, str] = {
+            "input": request.text,
+            "model": self.LATEST_MODEL if request.use_latest_model else self.STABLE_MODEL,
+        }
+
+        try:
+
+            def do_it() -> Dict[str, Any]:
+                result = self.client.moderations.create(input=request.text).model_dump(mode="json")
+                assert "results" in result and len(result["results"]) > 0, f"Invalid response: {result}"
+                return result
+
+            response, cached = self.cache.get(raw_request, wrap_request_time(do_it))
+        except openai.OpenAIError as e:
+            error: str = f"Moderation API error: {e}"
+            return ModerationAPIRequestResult(
+                success=False, cached=False, error=error, flagged=None, flagged_results=None, scores=None
+            )
+
+        moderation_result = response["results"][0]
+        category_results: Dict[str, bool] = moderation_result["categories"]
+        score_results: Dict[str, float] = moderation_result["category_scores"]
+
+        flagged_results = ModerationCategoryFlaggedResults(
+            hate_flagged=category_results[self.HATE],
+            hate_threatening_flagged=category_results[self.HATE_THREATENING],
+            self_harm_flagged=category_results[self.SELF_HARM],
+            sexual_flagged=category_results[self.SEXUAL],
+            sexual_minors_flagged=category_results[self.SEXUAL_MINORS],
+            violence_flagged=category_results[self.VIOLENCE],
+            violence_graphic_flagged=category_results[self.VIOLENCE_GRAPHIC],
+        )
+        scores = ModerationCategoryScores(
+            hate_score=score_results[self.HATE],
+            hate_threatening_score=score_results[self.HATE_THREATENING],
+            self_harm_score=score_results[self.SELF_HARM],
+            sexual_score=score_results[self.SEXUAL],
+            sexual_minors_score=score_results[self.SEXUAL_MINORS],
+            violence_score=score_results[self.VIOLENCE],
+            violence_graphic_score=score_results[self.VIOLENCE_GRAPHIC],
+        )
+        return ModerationAPIRequestResult(
+            success=True,
+            cached=cached,
+            flagged=moderation_result["flagged"],
+            flagged_results=flagged_results,
+            scores=scores,
+        )
+
+    def will_be_flagged(self, text: str) -> bool:
+        """Returns True if the text is against OpenAI's content policy and will be flagged, False otherwise."""
+        result: ModerationAPIRequestResult = self.get_moderation_results(
+            # Use the latest model so the account does not get banned
+            ModerationAPIRequest(text=text, use_latest_model=True)
+        )
+        assert result.flagged is not None
+        return result.flagged
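
A minimal sketch of using the moderation client on its own follows; the cache config class and path are illustrative assumptions and the API key placeholder is not real.

```python
# Minimal sketch of ModerationAPIClient usage (cache config is an assumption).
from helm.common.cache import SqliteCacheConfig  # assumed importable from helm.common.cache
from helm.common.moderations_api_request import ModerationAPIRequest
from helm.clients.moderation_api_client import ModerationAPIClient

client = ModerationAPIClient(
    api_key="sk-...",  # OpenAI API key
    cache_config=SqliteCacheConfig(path="moderation_cache.sqlite"),  # illustrative path
)

# Boolean convenience check (always uses the latest moderation model):
if client.will_be_flagged("some candidate text"):
    print("would be flagged")

# Full per-category results, here with the stable model:
result = client.get_moderation_results(
    ModerationAPIRequest(text="some candidate text", use_latest_model=False)
)
print(result.flagged, result.flagged_results, result.scores)
```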

helm/clients/open_lm_client.py (new file)
@@ -0,0 +1,43 @@
+from threading import Lock
+from typing import Optional
+
+from transformers import AutoConfig, AutoModelForCausalLM
+from helm.common.cache import CacheConfig
+
+from helm.common.optional_dependencies import OptionalDependencyNotInstalled
+from helm.clients.huggingface_client import HuggingFaceClient
+
+
+_register_open_lm_lock = Lock()
+_register_open_lm_done = False
+
+
+def _register_open_lm_for_auto_model():
+    """Register OpenLMForCausalLM for AutoModelForCausalLM."""
+    try:
+        from open_lm.utils.transformers.hf_model import OpenLMforCausalLM
+        from open_lm.utils.transformers.hf_config import OpenLMConfig
+    except ModuleNotFoundError as e:
+        # Provide manual instructions for installing open_lm from GitHub
+        # because PyPI does not allow installing dependencies directly from GitHub.
+        raise OptionalDependencyNotInstalled(
+            f"Optional dependency {e.name} is not installed. "
+            "Please run `pip install open_lm@git+https://github.com/mlfoundations/open_lm.git@main` to install it."
+        ) from e
+
+    with _register_open_lm_lock:
+        global _register_open_lm_done
+        if not _register_open_lm_done:
+            AutoConfig.register("openlm", OpenLMConfig)
+            AutoModelForCausalLM.register(OpenLMConfig, OpenLMforCausalLM)
+            _register_open_lm_done = True
+
+
+class OpenLMClient(HuggingFaceClient):
+    """Client for OpenLM: https://github.com/mlfoundations/open_lm"""
+
+    def __init__(self, cache_config: CacheConfig, pretrained_model_name_or_path: Optional[str] = None, **kwargs):
+        _register_open_lm_for_auto_model()
+        super().__init__(
+            cache_config=cache_config, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs
+        )
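
The point of the registration helper above is that, once `OpenLMConfig` and `OpenLMforCausalLM` are registered with the `transformers` Auto classes, an OpenLM checkpoint loads through the same `AutoModelForCausalLM` path that `HuggingFaceClient` already uses. A hedged sketch (the checkpoint path is a placeholder, and `open_lm` must be installed as described in the error message above):

```python
from transformers import AutoModelForCausalLM

from helm.clients.open_lm_client import _register_open_lm_for_auto_model

# Register the OpenLM config/model classes with transformers (idempotent and lock-protected).
_register_open_lm_for_auto_model()

# After registration, a checkpoint whose config declares model_type "openlm" loads normally.
model = AutoModelForCausalLM.from_pretrained("path/to/an/openlm/checkpoint")  # placeholder path
```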
@@ -0,0 +1,302 @@
|
|
|
1
|
+
# mypy: check_untyped_defs = False
|
|
2
|
+
from dataclasses import replace
|
|
3
|
+
from typing import Any, Dict, List, Optional, cast, Union
|
|
4
|
+
|
|
5
|
+
from helm.benchmark.model_metadata_registry import is_vlm
|
|
6
|
+
from helm.common.cache import CacheConfig
|
|
7
|
+
from helm.common.media_object import TEXT_TYPE
|
|
8
|
+
from helm.common.request import wrap_request_time, Request, RequestResult, GeneratedOutput, Token
|
|
9
|
+
from helm.common.hierarchical_logger import hlog
|
|
10
|
+
from helm.common.optional_dependencies import handle_module_not_found_error
|
|
11
|
+
from helm.common.tokenization_request import (
|
|
12
|
+
TokenizationRequest,
|
|
13
|
+
TokenizationRequestResult,
|
|
14
|
+
)
|
|
15
|
+
from helm.tokenizers.tokenizer import Tokenizer
|
|
16
|
+
from .client import CachingClient, truncate_sequence, generate_uid_for_multimodal_prompt
|
|
17
|
+
|
|
18
|
+
try:
|
|
19
|
+
import openai
|
|
20
|
+
from openai import OpenAI
|
|
21
|
+
except ModuleNotFoundError as e:
|
|
22
|
+
handle_module_not_found_error(e, ["openai"])
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
class OpenAIClient(CachingClient):
|
|
26
|
+
END_OF_TEXT: str = "<|endoftext|>"
|
|
27
|
+
|
|
28
|
+
# Error OpenAI throws when the image in the prompt violates their content policy
|
|
29
|
+
INAPPROPRIATE_IMAGE_ERROR: str = "Your input image may contain content that is not allowed by our safety system"
|
|
30
|
+
|
|
31
|
+
# Set the finish reason to this if the prompt violates OpenAI's content policy
|
|
32
|
+
CONTENT_POLICY_VIOLATED_FINISH_REASON: str = (
|
|
33
|
+
"The prompt violates OpenAI's content policy. "
|
|
34
|
+
"See https://labs.openai.com/policies/content-policy for more information."
|
|
35
|
+
)
|
|
36
|
+
|
|
37
|
+
def __init__(
|
|
38
|
+
self,
|
|
39
|
+
tokenizer: Tokenizer,
|
|
40
|
+
tokenizer_name: str,
|
|
41
|
+
cache_config: CacheConfig,
|
|
42
|
+
api_key: Optional[str] = None,
|
|
43
|
+
org_id: Optional[str] = None,
|
|
44
|
+
base_url: Optional[str] = None,
|
|
45
|
+
):
|
|
46
|
+
super().__init__(cache_config=cache_config)
|
|
47
|
+
self.tokenizer = tokenizer
|
|
48
|
+
self.tokenizer_name = tokenizer_name
|
|
49
|
+
self.client = OpenAI(api_key=api_key, organization=org_id, base_url=base_url)
|
|
50
|
+
|
|
51
|
+
def _is_chat_model_engine(self, model_engine: str) -> bool:
|
|
52
|
+
if model_engine == "gpt-3.5-turbo-instruct":
|
|
53
|
+
return False
|
|
54
|
+
elif model_engine.startswith("gpt-3.5") or model_engine.startswith("gpt-4"):
|
|
55
|
+
return True
|
|
56
|
+
return False
|
|
57
|
+
|
|
58
|
+
def _get_model_for_request(self, request: Request) -> str:
|
|
59
|
+
return request.model_engine
|
|
60
|
+
|
|
61
|
+
def _get_cache_key(self, raw_request: Dict, request: Request):
|
|
62
|
+
cache_key = CachingClient.make_cache_key(raw_request, request)
|
|
63
|
+
if is_vlm(request.model):
|
|
64
|
+
assert request.multimodal_prompt is not None
|
|
65
|
+
prompt_key: str = generate_uid_for_multimodal_prompt(request.multimodal_prompt)
|
|
66
|
+
cache_key = {**cache_key, "multimodal_prompt": prompt_key}
|
|
67
|
+
del cache_key["messages"]
|
|
68
|
+
return cache_key
|
|
69
|
+
|
|
70
|
+
def _make_embedding_request(self, request: Request) -> RequestResult:
|
|
71
|
+
raw_request: Dict[str, Any]
|
|
72
|
+
raw_request = {
|
|
73
|
+
"input": request.prompt,
|
|
74
|
+
# Note: In older deprecated versions of the OpenAI API, "model" used to be "engine".
|
|
75
|
+
"model": self._get_model_for_request(request),
|
|
76
|
+
}
|
|
77
|
+
|
|
78
|
+
def do_it() -> Dict[str, Any]:
|
|
79
|
+
return self.client.embeddings.create(**raw_request).model_dump(mode="json")
|
|
80
|
+
|
|
81
|
+
try:
|
|
82
|
+
cache_key = self._get_cache_key(raw_request, request)
|
|
83
|
+
response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
|
|
84
|
+
except openai.OpenAIError as e:
|
|
85
|
+
error: str = f"OpenAI error: {e}"
|
|
86
|
+
return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])
|
|
87
|
+
|
|
88
|
+
# If the user is requesting completions instead of an embedding, then `completions`
|
|
89
|
+
# needs to be populated, and `embedding` should be an empty list and vice-versa.
|
|
90
|
+
embedding: List[float] = []
|
|
91
|
+
# If the user is requesting an embedding instead of completion
|
|
92
|
+
# then completions would be left as an empty list. The embedding needs to be set.
|
|
93
|
+
embedding = response["data"][0]["embedding"]
|
|
94
|
+
|
|
95
|
+
return RequestResult(
|
|
96
|
+
success=True,
|
|
97
|
+
cached=cached,
|
|
98
|
+
request_time=response["request_time"],
|
|
99
|
+
request_datetime=response.get("request_datetime"),
|
|
100
|
+
completions=[],
|
|
101
|
+
embedding=embedding,
|
|
102
|
+
)
|
|
103
|
+
|
|
104
|
+
def _make_chat_request(self, request: Request) -> RequestResult:
|
|
105
|
+
messages: Optional[List[Dict[str, Union[str, Any]]]] = request.messages
|
|
106
|
+
if request.messages is not None:
|
|
107
|
+
# Checks that all messages have a role and some content
|
|
108
|
+
for message in request.messages:
|
|
109
|
+
if not message.get("role") or not message.get("content"):
|
|
110
|
+
raise ValueError("All messages must have a role and content")
|
|
111
|
+
# Checks that the last role is "user"
|
|
112
|
+
if request.messages[-1]["role"] != "user":
|
|
113
|
+
raise ValueError("Last message must have role 'user'")
|
|
114
|
+
if request.prompt != "":
|
|
115
|
+
hlog("WARNING: Since message is set, prompt will be ignored")
|
|
116
|
+
else:
|
|
117
|
+
# Convert prompt into a single message
|
|
118
|
+
# For now, put the whole prompt in a single user message, and expect the response
|
|
119
|
+
# to be returned in a single assistant message.
|
|
120
|
+
# TODO: Support ChatML for creating multiple messages with different roles.
|
|
121
|
+
# See: https://github.com/openai/openai-python/blob/main/chatml.md
|
|
122
|
+
|
|
123
|
+
# Content can either be text or a list of multimodal content made up of text and images:
|
|
124
|
+
# https://platform.openai.com/docs/guides/vision
|
|
125
|
+
content: Union[str, List[Union[str, Any]]]
|
|
126
|
+
if request.multimodal_prompt is not None:
|
|
127
|
+
content = []
|
|
128
|
+
for media_object in request.multimodal_prompt.media_objects:
|
|
129
|
+
if media_object.is_type("image") and media_object.location:
|
|
130
|
+
from helm.common.images_utils import encode_base64
|
|
131
|
+
|
|
132
|
+
base64_image: str = encode_base64(media_object.location)
|
|
133
|
+
content.append(
|
|
134
|
+
{"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{base64_image}"}}
|
|
135
|
+
)
|
|
136
|
+
elif media_object.is_type(TEXT_TYPE):
|
|
137
|
+
if media_object.text is None:
|
|
138
|
+
raise ValueError("MediaObject of text type has missing text field value")
|
|
139
|
+
content.append({"type": media_object.type, "text": media_object.text})
|
|
140
|
+
else:
|
|
141
|
+
raise ValueError(f"Unrecognized MediaObject type {media_object.type}")
|
|
142
|
+
|
|
143
|
+
else:
|
|
144
|
+
content = request.prompt
|
|
145
|
+
|
|
146
|
+
messages = [{"role": "user", "content": content}]
|
|
147
|
+
+        raw_request: Dict[str, Any] = {
+            "model": self._get_model_for_request(request),
+            "messages": messages,
+            "temperature": request.temperature,
+            "top_p": request.top_p,
+            "n": request.num_completions,
+            "stop": request.stop_sequences or None,  # API doesn't like empty list
+            # Note: Chat models may require adding an extra token to max_tokens
+            # for the internal special role token.
+            "max_tokens": request.max_tokens,
+            "presence_penalty": request.presence_penalty,
+            "frequency_penalty": request.frequency_penalty,
+        }
+
+        # OpenAI's vision API doesn't allow None values for stop.
+        # Fails with "body -> stop: none is not an allowed value" if None is passed.
+        if is_vlm(request.model) and raw_request["stop"] is None:
+            raw_request.pop("stop")
+
+        def do_it() -> Dict[str, Any]:
+            return self.client.chat.completions.create(**raw_request).model_dump(mode="json")
+
+        try:
+            cache_key = self._get_cache_key(raw_request, request)
+            response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
+        except openai.OpenAIError as e:
+            if self.INAPPROPRIATE_IMAGE_ERROR in str(e):
+                hlog(f"Failed safety check: {str(request)}")
+                empty_completion = GeneratedOutput(
+                    text="",
+                    logprob=0,
+                    tokens=[],
+                    finish_reason={"reason": self.CONTENT_POLICY_VIOLATED_FINISH_REASON},
+                )
+                return RequestResult(
+                    success=True,
+                    cached=False,
+                    request_time=0,
+                    completions=[empty_completion] * request.num_completions,
+                    embedding=[],
+                )
+
+            error: str = f"OpenAI error: {e}"
+            return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])
+
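The `try` block above goes through HELM's client-side cache: `self.cache.get(cache_key, wrap_request_time(do_it))` returns a previously stored response for an identical request, and otherwise calls the API once, records the request time, and stores the result. A toy, in-memory sketch of those read-through semantics (illustration only; not HELM's actual `Cache` or `wrap_request_time`, which persist responses rather than keeping them in memory):

```python
import time
from typing import Any, Callable, Dict, Tuple


def wrap_request_time(compute: Callable[[], Dict[str, Any]]) -> Callable[[], Dict[str, Any]]:
    """Simplified stand-in: attach wall-clock timing to the computed response."""

    def wrapped() -> Dict[str, Any]:
        start = time.time()
        response = compute()
        response["request_time"] = time.time() - start
        return response

    return wrapped


class InMemoryCache:
    """Toy read-through cache: return a stored response, or compute and store it."""

    def __init__(self) -> None:
        self._store: Dict[str, Dict[str, Any]] = {}

    def get(self, key: str, compute: Callable[[], Dict[str, Any]]) -> Tuple[Dict[str, Any], bool]:
        if key in self._store:
            return self._store[key], True  # served from cache
        response = compute()
        self._store[key] = response
        return response, False  # freshly computed


cache = InMemoryCache()
response, cached = cache.get("example-key", wrap_request_time(lambda: {"choices": []}))
print(cached, "request_time" in response)  # False True
```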
+        completions: List[GeneratedOutput] = []
+        for raw_completion in response["choices"]:
+            # The OpenAI chat completion API doesn't support echo.
+            # If `echo_prompt` is true, combine the prompt and completion.
+            raw_completion_content = raw_completion["message"]["content"]
+            text: str = request.prompt + raw_completion_content if request.echo_prompt else raw_completion_content
+            # The OpenAI chat completion API doesn't return us tokens or logprobs, so we tokenize ourselves.
+            tokenization_result: TokenizationRequestResult = self.tokenizer.tokenize(
+                TokenizationRequest(text, tokenizer=self.tokenizer_name)
+            )
+            # Log probs are not currently supported by the OpenAI chat completion API, so set to 0 for now.
+            tokens: List[Token] = [
+                Token(text=cast(str, raw_token), logprob=0) for raw_token in tokenization_result.raw_tokens
+            ]
+            completion = GeneratedOutput(
+                text=text,
+                logprob=0,  # OpenAI does not provide logprobs
+                tokens=tokens,
+                finish_reason={"reason": raw_completion["finish_reason"]},
+            )
+            completions.append(truncate_sequence(completion, request))  # Truncate the text by stop sequences
+
+        return RequestResult(
+            success=True,
+            cached=cached,
+            request_time=response["request_time"],
+            request_datetime=response.get("request_datetime"),
+            completions=completions,
+            embedding=[],
+        )
+
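`truncate_sequence` (imported from `.client`) then cuts each completion at the request's stop sequences, since that post-processing happens client-side here. The real helper also truncates the token list and enforces `max_tokens`; the sketch below only illustrates the string-level idea:

```python
from typing import List


def truncate_at_stop_sequences(text: str, stop_sequences: List[str]) -> str:
    """Naive illustration: cut the completion at the earliest occurrence of any stop sequence."""
    cut = len(text)
    for stop in stop_sequences:
        index = text.find(stop)
        if index != -1:
            cut = min(cut, index)
    return text[:cut]


print(truncate_at_stop_sequences("Answer: 42\n\nQuestion: what is next?", ["\n\n"]))  # "Answer: 42"
```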
+    def _to_raw_completion_request(self, request: Request) -> Dict[str, Any]:
+        raw_request: Dict[str, Any] = {
+            # Note: In older deprecated versions of the OpenAI API, "model" used to be "engine".
+            "model": self._get_model_for_request(request),
+            "prompt": request.prompt,
+            "temperature": request.temperature,
+            "n": request.num_completions,
+            "max_tokens": request.max_tokens,
+            "best_of": request.top_k_per_token,
+            "logprobs": request.top_k_per_token,
+            "stop": request.stop_sequences or None,  # API doesn't like empty list
+            "top_p": request.top_p,
+            "presence_penalty": request.presence_penalty,
+            "frequency_penalty": request.frequency_penalty,
+            "echo": request.echo_prompt,
+        }
+
+        # OpenAI doesn't let you ask for more completions than the number of
+        # per-token candidates.
+        raw_request["best_of"] = max(raw_request["best_of"], raw_request["n"])
+        raw_request["logprobs"] = max(raw_request["logprobs"], raw_request["n"])
+
+        return raw_request
+
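The two `max(...)` lines above keep the request self-consistent when more completions are requested than per-token candidates. With made-up numbers:

```python
# Hypothetical request values: 5 completions requested, but only 1 per-token candidate.
raw_request = {"n": 5, "best_of": 1, "logprobs": 1}

# Clamp upward so the API never sees best_of (or logprobs) smaller than n.
raw_request["best_of"] = max(raw_request["best_of"], raw_request["n"])
raw_request["logprobs"] = max(raw_request["logprobs"], raw_request["n"])

print(raw_request)  # {'n': 5, 'best_of': 5, 'logprobs': 5}
```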
+    def _make_completion_request(self, request: Request) -> RequestResult:
+        raw_request = self._to_raw_completion_request(request)
+
+        def do_it() -> Dict[str, Any]:
+            return self.client.completions.create(**raw_request).model_dump(mode="json")
+
+        try:
+            cache_key = self._get_cache_key(raw_request, request)
+            response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
+        except openai.OpenAIError as e:
+            error: str = f"OpenAI error: {e}"
+            return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])
+
+        completions: List[GeneratedOutput] = []
+        for raw_completion in response["choices"]:
+            sequence_logprob = 0
+            tokens: List[Token] = []
+
+            raw_data = raw_completion["logprobs"]
+            for (
+                text,
+                logprob,
+            ) in zip(raw_data["tokens"], raw_data["token_logprobs"]):
+                tokens.append(Token(text=text, logprob=logprob or 0))
+                sequence_logprob += logprob or 0
+            completion = GeneratedOutput(
+                text=raw_completion["text"],
+                logprob=sequence_logprob,
+                tokens=tokens,
+                finish_reason={"reason": raw_completion["finish_reason"]},
+            )
+            # OpenAI sends us back tokens past the end of text token,
+            # so we need to manually truncate the list of tokens.
+            # TODO: filed an issue with their support to check what the expected behavior here is.
+            completion = truncate_sequence(
+                completion, replace(request, stop_sequences=request.stop_sequences + [OpenAIClient.END_OF_TEXT])
+            )
+            completions.append(completion)
+
+        return RequestResult(
+            success=True,
+            cached=cached,
+            request_time=response["request_time"],
+            request_datetime=response.get("request_datetime"),
+            completions=completions,
+            embedding=[],
+        )
+
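In the loop above, per-token log probabilities from the completions endpoint are accumulated into a sequence-level logprob, with `logprob or 0` guarding against `None` entries. A self-contained illustration with made-up values:

```python
from typing import List, Optional

# Made-up per-token log probabilities; the first entry is None, as APIs often
# return no logprob for the very first token.
token_logprobs: List[Optional[float]] = [None, -0.25, -1.5, -0.1]

sequence_logprob = 0.0
for logprob in token_logprobs:
    sequence_logprob += logprob or 0  # treat missing logprobs as 0, as in the diff

print(sequence_logprob)  # -1.85
```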
+    def make_request(self, request: Request) -> RequestResult:
+        if request.embedding:
+            return self._make_embedding_request(request)
+        elif self._is_chat_model_engine(request.model_engine):
+            return self._make_chat_request(request)
+        else:
+            return self._make_completion_request(request)
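For reference, `make_request` routes each request three ways: embedding requests, chat-style engines, and everything else via the legacy completions endpoint. A toy sketch of that dispatch (the dataclass and the chat-engine check are hypothetical stand-ins, not HELM's `Request` or `_is_chat_model_engine`):

```python
from dataclasses import dataclass


@dataclass
class FakeRequest:
    """Hypothetical stand-in for HELM's Request, with only the fields the dispatch reads."""
    model_engine: str
    embedding: bool = False


def is_chat_model_engine(model_engine: str) -> bool:
    """Illustrative heuristic only; the real check is a method on the client."""
    return model_engine.startswith("gpt-3.5-turbo") or model_engine.startswith("gpt-4")


def dispatch(request: FakeRequest) -> str:
    """Mirror the three-way routing in make_request above."""
    if request.embedding:
        return "embedding"
    elif is_chat_model_engine(request.model_engine):
        return "chat"
    else:
        return "completion"


print(dispatch(FakeRequest(model_engine="gpt-4")))                   # chat
print(dispatch(FakeRequest(model_engine="davinci-002")))             # completion
print(dispatch(FakeRequest(model_engine="gpt-4", embedding=True)))   # embedding
```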
@@ -5,12 +5,12 @@ from typing import Any, Dict, List
 
 from helm.common.cache import CacheConfig
 from helm.common.hierarchical_logger import hlog
-from helm.common.request import wrap_request_time, Request, RequestResult,
+from helm.common.request import wrap_request_time, Request, RequestResult, GeneratedOutput, Token, ErrorFlags
 from helm.common.tokenization_request import (
     TokenizationRequest,
     TokenizationRequestResult,
 )
-from helm.
+from helm.tokenizers.tokenizer import Tokenizer
 from .client import CachingClient, truncate_sequence
 
 
@@ -28,9 +28,11 @@ def _is_content_moderation_failure(response: Dict) -> bool:
 
 
 class PalmyraClient(CachingClient):
-    def __init__(self,
-        super().__init__(cache_config=cache_config
+    def __init__(self, tokenizer: Tokenizer, tokenizer_name: str, cache_config: CacheConfig, api_key: str):
+        super().__init__(cache_config=cache_config)
         self.api_key: str = api_key
+        self.tokenizer = tokenizer
+        self.tokenizer_name = tokenizer_name
 
     def _send_request(self, model_name: str, raw_request: Dict[str, Any]) -> Dict[str, Any]:
         response = requests.request(
@@ -65,14 +67,14 @@ class PalmyraClient(CachingClient):
             # "random_seed": request.random,
         }
 
-        completions: List[
+        completions: List[GeneratedOutput] = []
         model_name: str = request.model_engine
 
         # `num_completions` is not supported, so instead make `num_completions` separate requests.
         for completion_index in range(request.num_completions):
             try:
 
-                def do_it():
+                def do_it() -> Dict[str, Any]:
                     # Add an argument timeout to raw_request to avoid waiting getting timeout of 60s
                     # which happens for long prompts.
                     request_with_timeout = {"timeout": 300, **raw_request}
@@ -100,7 +102,10 @@ class PalmyraClient(CachingClient):
                 return RequestResult(success=False, cached=False, error=error, completions=[], embedding=[])
 
             if _is_content_moderation_failure(response):
-                hlog(
+                hlog(
+                    f"WARNING: Returning empty request for {request.model_deployment} "
+                    "due to content moderation filter"
+                )
                 return RequestResult(
                     success=False,
                     cached=False,
@@ -119,15 +124,13 @@ class PalmyraClient(CachingClient):
             # The Writer API doesn't return us tokens or logprobs, so we tokenize ourselves.
             tokenization_result: TokenizationRequestResult = self.tokenizer.tokenize(
                 # Writer uses the GPT-2 tokenizer
-                TokenizationRequest(text, tokenizer=
+                TokenizationRequest(text, tokenizer=self.tokenizer_name)
             )
 
             # Log probs are not currently not supported by the Writer, so set to 0 for now.
-            tokens: List[Token] = [
-                Token(text=str(text), logprob=0, top_logprobs={}) for text in tokenization_result.raw_tokens
-            ]
+            tokens: List[Token] = [Token(text=str(text), logprob=0) for text in tokenization_result.raw_tokens]
 
-            completion =
+            completion = GeneratedOutput(text=response_text, logprob=0, tokens=tokens)
             sequence = truncate_sequence(completion, request, print_warning=True)
             completions.append(sequence)
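In this last hunk the `Token` list comprehension is collapsed onto one line and the `top_logprobs={}` argument is dropped, matching a `Token` type that no longer takes per-token top logprobs; since the Writer API exposes no log probabilities, every token still gets `logprob=0`. A small standalone sketch of that conversion, using a hypothetical `FakeToken` stand-in rather than HELM's `Token`:

```python
from dataclasses import dataclass
from typing import List


@dataclass
class FakeToken:
    """Hypothetical stand-in for helm.common.request.Token, with just the two fields used above."""
    text: str
    logprob: float


# Made-up raw tokens, standing in for tokenization_result.raw_tokens.
raw_tokens = ["Hello", ",", " world", "!"]

# Mirror the diff: no logprobs are available, so every token gets logprob 0.
tokens: List[FakeToken] = [FakeToken(text=str(t), logprob=0) for t in raw_tokens]
print(tokens)
```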
|