crfm-helm 0.3.0__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/METADATA +144 -36
- crfm_helm-0.5.0.dist-info/RECORD +642 -0
- {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/WHEEL +1 -1
- helm/benchmark/adaptation/adapter_spec.py +37 -2
- helm/benchmark/adaptation/adapters/adapter.py +4 -42
- helm/benchmark/adaptation/adapters/adapter_factory.py +24 -27
- helm/benchmark/adaptation/adapters/binary_ranking_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/generation_adapter.py +2 -0
- helm/benchmark/adaptation/adapters/in_context_learning_adapter.py +21 -4
- helm/benchmark/adaptation/adapters/language_modeling_adapter.py +12 -5
- helm/benchmark/adaptation/adapters/multimodal/generation_multimodal_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/multimodal/in_context_learning_multimodal_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/multimodal/multiple_choice_joint_multimodal_adapter.py +104 -0
- helm/benchmark/adaptation/adapters/multimodal/test_in_context_learning_multimodal_adapter.py +5 -1
- helm/benchmark/adaptation/adapters/multiple_choice_joint_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/multiple_choice_separate_adapter.py +1 -0
- helm/benchmark/adaptation/adapters/test_adapter.py +2 -1
- helm/benchmark/adaptation/adapters/test_generation_adapter.py +59 -14
- helm/benchmark/adaptation/adapters/test_language_modeling_adapter.py +40 -5
- helm/benchmark/adaptation/adapters/test_multiple_choice_joint_adapter.py +78 -10
- helm/benchmark/adaptation/common_adapter_specs.py +376 -0
- helm/benchmark/adaptation/prompt.py +7 -1
- helm/benchmark/adaptation/request_state.py +6 -1
- helm/benchmark/adaptation/scenario_state.py +6 -2
- helm/benchmark/annotation/annotator.py +43 -0
- helm/benchmark/annotation/annotator_factory.py +61 -0
- helm/benchmark/annotation/image2structure/image_compiler_annotator.py +88 -0
- helm/benchmark/annotation/image2structure/latex_compiler_annotator.py +59 -0
- helm/benchmark/annotation/image2structure/lilypond_compiler_annotator.py +84 -0
- helm/benchmark/annotation/image2structure/webpage_compiler_annotator.py +132 -0
- helm/benchmark/annotation/test_annotator_factory.py +26 -0
- helm/benchmark/annotation/test_dummy_annotator.py +44 -0
- helm/benchmark/annotation_executor.py +124 -0
- helm/benchmark/augmentations/cleva_perturbation.py +7 -14
- helm/benchmark/augmentations/contraction_expansion_perturbation.py +3 -3
- helm/benchmark/augmentations/contrast_sets_perturbation.py +0 -3
- helm/benchmark/augmentations/data_augmenter.py +0 -2
- helm/benchmark/augmentations/dialect_perturbation.py +2 -2
- helm/benchmark/augmentations/extra_space_perturbation.py +2 -2
- helm/benchmark/augmentations/filler_words_perturbation.py +2 -2
- helm/benchmark/augmentations/gender_perturbation.py +3 -3
- helm/benchmark/augmentations/lowercase_perturbation.py +2 -2
- helm/benchmark/augmentations/mild_mix_perturbation.py +2 -2
- helm/benchmark/augmentations/misspelling_perturbation.py +2 -2
- helm/benchmark/augmentations/person_name_perturbation.py +0 -7
- helm/benchmark/augmentations/perturbation.py +20 -7
- helm/benchmark/augmentations/perturbation_description.py +1 -1
- helm/benchmark/augmentations/space_perturbation.py +2 -2
- helm/benchmark/augmentations/suffix_perturbation.py +29 -0
- helm/benchmark/augmentations/synonym_perturbation.py +2 -2
- helm/benchmark/augmentations/test_perturbation.py +11 -7
- helm/benchmark/augmentations/translate_perturbation.py +30 -0
- helm/benchmark/augmentations/typos_perturbation.py +2 -2
- helm/benchmark/config_registry.py +38 -0
- helm/benchmark/executor.py +46 -16
- helm/benchmark/huggingface_registration.py +37 -7
- helm/benchmark/metrics/basic_metrics.py +172 -641
- helm/benchmark/metrics/bbq_metrics.py +3 -4
- helm/benchmark/metrics/bias_metrics.py +6 -6
- helm/benchmark/metrics/classification_metrics.py +11 -8
- helm/benchmark/metrics/cleva_accuracy_metrics.py +8 -5
- helm/benchmark/metrics/cleva_harms_metrics.py +2 -2
- helm/benchmark/metrics/code_metrics.py +4 -3
- helm/benchmark/metrics/code_metrics_helper.py +0 -2
- helm/benchmark/metrics/common_metric_specs.py +167 -0
- helm/benchmark/metrics/decodingtrust_fairness_metrics.py +72 -0
- helm/benchmark/metrics/decodingtrust_ood_knowledge_metrics.py +66 -0
- helm/benchmark/metrics/decodingtrust_privacy_metrics.py +101 -0
- helm/benchmark/metrics/decodingtrust_stereotype_bias_metrics.py +202 -0
- helm/benchmark/metrics/disinformation_metrics.py +6 -112
- helm/benchmark/metrics/dry_run_metrics.py +5 -3
- helm/benchmark/metrics/efficiency_metrics.py +206 -0
- helm/benchmark/metrics/evaluate_instances_metric.py +59 -0
- helm/benchmark/metrics/evaluate_reference_metrics.py +376 -0
- helm/benchmark/metrics/image_generation/aesthetics_metrics.py +54 -0
- helm/benchmark/metrics/image_generation/aesthetics_scorer.py +66 -0
- helm/benchmark/metrics/image_generation/clip_score_metrics.py +73 -0
- helm/benchmark/metrics/image_generation/denoised_runtime_metric.py +42 -0
- helm/benchmark/metrics/image_generation/detection_metrics.py +57 -0
- helm/benchmark/metrics/image_generation/detectors/base_detector.py +8 -0
- helm/benchmark/metrics/image_generation/detectors/vitdet.py +178 -0
- helm/benchmark/metrics/image_generation/efficiency_metrics.py +41 -0
- helm/benchmark/metrics/image_generation/fidelity_metrics.py +168 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/fractal_dimension_util.py +63 -0
- helm/benchmark/metrics/image_generation/fractal_dimension/test_fractal_dimension_util.py +33 -0
- helm/benchmark/metrics/image_generation/fractal_dimension_metric.py +50 -0
- helm/benchmark/metrics/image_generation/gender_metrics.py +58 -0
- helm/benchmark/metrics/image_generation/image_critique_metrics.py +284 -0
- helm/benchmark/metrics/image_generation/lpips_metrics.py +82 -0
- helm/benchmark/metrics/image_generation/multi_scale_ssim_metrics.py +82 -0
- helm/benchmark/metrics/image_generation/nsfw_detector.py +96 -0
- helm/benchmark/metrics/image_generation/nsfw_metrics.py +103 -0
- helm/benchmark/metrics/image_generation/nudity_metrics.py +38 -0
- helm/benchmark/metrics/image_generation/photorealism_critique_metrics.py +153 -0
- helm/benchmark/metrics/image_generation/psnr_metrics.py +78 -0
- helm/benchmark/metrics/image_generation/q16/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/q16/q16_toxicity_detector.py +90 -0
- helm/benchmark/metrics/image_generation/q16/test_q16.py +18 -0
- helm/benchmark/metrics/image_generation/q16_toxicity_metrics.py +48 -0
- helm/benchmark/metrics/image_generation/skin_tone_metrics.py +164 -0
- helm/benchmark/metrics/image_generation/uiqi_metrics.py +92 -0
- helm/benchmark/metrics/image_generation/watermark/__init__.py +0 -0
- helm/benchmark/metrics/image_generation/watermark/test_watermark_detector.py +16 -0
- helm/benchmark/metrics/image_generation/watermark/watermark_detector.py +87 -0
- helm/benchmark/metrics/image_generation/watermark_metrics.py +48 -0
- helm/benchmark/metrics/instruction_following_critique_metrics.py +3 -1
- helm/benchmark/metrics/language_modeling_metrics.py +99 -0
- helm/benchmark/metrics/machine_translation_metrics.py +5 -5
- helm/benchmark/metrics/metric.py +93 -172
- helm/benchmark/metrics/metric_name.py +0 -1
- helm/benchmark/metrics/metric_service.py +16 -0
- helm/benchmark/metrics/paraphrase_generation_metrics.py +3 -4
- helm/benchmark/metrics/ranking_metrics.py +6 -7
- helm/benchmark/metrics/reference_metric.py +148 -0
- helm/benchmark/metrics/summac/model_summac.py +0 -2
- helm/benchmark/metrics/summarization_metrics.py +8 -8
- helm/benchmark/metrics/test_classification_metrics.py +9 -6
- helm/benchmark/metrics/test_disinformation_metrics.py +78 -0
- helm/benchmark/metrics/test_evaluate_reference_metrics.py +30 -0
- helm/benchmark/metrics/test_metric.py +2 -2
- helm/benchmark/metrics/tokens/auto_token_cost_estimator.py +1 -1
- helm/benchmark/metrics/tokens/gooseai_token_cost_estimator.py +13 -3
- helm/benchmark/metrics/tokens/openai_token_cost_estimator.py +1 -1
- helm/benchmark/metrics/tokens/test_ai21_token_cost_estimator.py +2 -0
- helm/benchmark/metrics/tokens/test_openai_token_cost_estimator.py +9 -2
- helm/benchmark/metrics/toxicity_metrics.py +1 -1
- helm/benchmark/metrics/toxicity_utils.py +23 -0
- helm/benchmark/metrics/unitxt_metrics.py +81 -0
- helm/benchmark/metrics/vision_language/__init__.py +0 -0
- helm/benchmark/metrics/vision_language/emd_utils.py +341 -0
- helm/benchmark/metrics/vision_language/image_metrics.py +450 -0
- helm/benchmark/metrics/vision_language/image_utils.py +100 -0
- helm/benchmark/model_deployment_registry.py +164 -41
- helm/benchmark/model_metadata_registry.py +181 -35
- helm/benchmark/multi_gpu_runner.py +133 -0
- helm/benchmark/presentation/contamination.py +3 -3
- helm/benchmark/presentation/create_plots.py +8 -7
- helm/benchmark/presentation/run_display.py +50 -17
- helm/benchmark/presentation/schema.py +28 -46
- helm/benchmark/presentation/summarize.py +213 -96
- helm/benchmark/presentation/table.py +8 -8
- helm/benchmark/presentation/test_contamination.py +2 -2
- helm/benchmark/presentation/test_run_entry.py +14 -9
- helm/benchmark/presentation/test_summarize.py +5 -0
- helm/benchmark/run.py +66 -54
- helm/benchmark/run_expander.py +342 -31
- helm/benchmark/run_spec.py +93 -0
- helm/benchmark/run_spec_factory.py +162 -0
- helm/benchmark/run_specs/__init__.py +0 -0
- helm/benchmark/{run_specs.py → run_specs/classic_run_specs.py} +217 -1330
- helm/benchmark/run_specs/cleva_run_specs.py +277 -0
- helm/benchmark/run_specs/decodingtrust_run_specs.py +314 -0
- helm/benchmark/run_specs/heim_run_specs.py +623 -0
- helm/benchmark/run_specs/instruction_following_run_specs.py +129 -0
- helm/benchmark/run_specs/lite_run_specs.py +307 -0
- helm/benchmark/run_specs/simple_run_specs.py +104 -0
- helm/benchmark/run_specs/unitxt_run_specs.py +42 -0
- helm/benchmark/run_specs/vlm_run_specs.py +501 -0
- helm/benchmark/runner.py +116 -69
- helm/benchmark/runner_config_registry.py +21 -0
- helm/benchmark/scenarios/bbq_scenario.py +1 -1
- helm/benchmark/scenarios/bold_scenario.py +2 -2
- helm/benchmark/scenarios/cleva_scenario.py +43 -46
- helm/benchmark/scenarios/code_scenario.py +3 -2
- helm/benchmark/scenarios/commonsense_scenario.py +171 -191
- helm/benchmark/scenarios/decodingtrust_adv_demonstration_scenario.py +169 -0
- helm/benchmark/scenarios/decodingtrust_adv_robustness_scenario.py +121 -0
- helm/benchmark/scenarios/decodingtrust_fairness_scenario.py +77 -0
- helm/benchmark/scenarios/decodingtrust_machine_ethics_scenario.py +324 -0
- helm/benchmark/scenarios/decodingtrust_ood_robustness_scenario.py +204 -0
- helm/benchmark/scenarios/decodingtrust_privacy_scenario.py +559 -0
- helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +67 -0
- helm/benchmark/scenarios/decodingtrust_toxicity_prompts_scenario.py +78 -0
- helm/benchmark/scenarios/dialogue_scenarios.py +0 -1
- helm/benchmark/scenarios/entity_matching_scenario.py +1 -1
- helm/benchmark/scenarios/image_generation/__init__.py +0 -0
- helm/benchmark/scenarios/image_generation/common_syntactic_processes_scenario.py +105 -0
- helm/benchmark/scenarios/image_generation/cub200_scenario.py +95 -0
- helm/benchmark/scenarios/image_generation/daily_dalle_scenario.py +124 -0
- helm/benchmark/scenarios/image_generation/demographic_stereotypes_scenario.py +82 -0
- helm/benchmark/scenarios/image_generation/detection_scenario.py +83 -0
- helm/benchmark/scenarios/image_generation/draw_bench_scenario.py +74 -0
- helm/benchmark/scenarios/image_generation/i2p_scenario.py +57 -0
- helm/benchmark/scenarios/image_generation/landing_page_scenario.py +46 -0
- helm/benchmark/scenarios/image_generation/logos_scenario.py +223 -0
- helm/benchmark/scenarios/image_generation/magazine_cover_scenario.py +91 -0
- helm/benchmark/scenarios/image_generation/mental_disorders_scenario.py +46 -0
- helm/benchmark/scenarios/image_generation/mscoco_scenario.py +91 -0
- helm/benchmark/scenarios/image_generation/paint_skills_scenario.py +72 -0
- helm/benchmark/scenarios/image_generation/parti_prompts_scenario.py +94 -0
- helm/benchmark/scenarios/image_generation/radiology_scenario.py +42 -0
- helm/benchmark/scenarios/image_generation/relational_understanding_scenario.py +52 -0
- helm/benchmark/scenarios/image_generation/time_most_significant_historical_figures_scenario.py +124 -0
- helm/benchmark/scenarios/image_generation/winoground_scenario.py +62 -0
- helm/benchmark/scenarios/imdb_scenario.py +0 -1
- helm/benchmark/scenarios/legalbench_scenario.py +123 -0
- helm/benchmark/scenarios/live_qa_scenario.py +94 -0
- helm/benchmark/scenarios/lm_entry_scenario.py +185 -0
- helm/benchmark/scenarios/lsat_qa_scenario.py +4 -2
- helm/benchmark/scenarios/math_scenario.py +19 -2
- helm/benchmark/scenarios/medication_qa_scenario.py +60 -0
- helm/benchmark/scenarios/numeracy_scenario.py +3 -3
- helm/benchmark/scenarios/opinions_qa_scenario.py +6 -10
- helm/benchmark/scenarios/raft_scenario.py +2 -6
- helm/benchmark/scenarios/scenario.py +14 -2
- helm/benchmark/scenarios/simple_scenarios.py +122 -1
- helm/benchmark/scenarios/test_math_scenario.py +22 -0
- helm/benchmark/scenarios/test_scenario.py +6 -3
- helm/benchmark/scenarios/test_simple_scenarios.py +50 -0
- helm/benchmark/scenarios/thai_exam_scenario.py +135 -0
- helm/benchmark/scenarios/the_pile_scenario.py +6 -7
- helm/benchmark/scenarios/unitxt_scenario.py +56 -0
- helm/benchmark/scenarios/verifiability_judgment_scenario.py +3 -1
- helm/benchmark/scenarios/vicuna_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/bingo_scenario.py +103 -0
- helm/benchmark/scenarios/vision_language/hateful_memes_scenario.py +92 -0
- helm/benchmark/scenarios/vision_language/heim_human_eval_scenario.py +113 -0
- helm/benchmark/scenarios/vision_language/image2structure/__init__.py +0 -0
- helm/benchmark/scenarios/vision_language/image2structure/chart2csv_scenario.py +55 -0
- helm/benchmark/scenarios/vision_language/image2structure/image2structure_scenario.py +214 -0
- helm/benchmark/scenarios/vision_language/image2structure/latex_scenario.py +25 -0
- helm/benchmark/scenarios/vision_language/image2structure/musicsheet_scenario.py +20 -0
- helm/benchmark/scenarios/vision_language/image2structure/utils_latex.py +347 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/__init__.py +0 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/driver.py +84 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/jekyll_server.py +182 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage/utils.py +31 -0
- helm/benchmark/scenarios/vision_language/image2structure/webpage_scenario.py +225 -0
- helm/benchmark/scenarios/vision_language/mementos_scenario.py +124 -0
- helm/benchmark/scenarios/vision_language/mme_scenario.py +145 -0
- helm/benchmark/scenarios/vision_language/mmmu_scenario.py +187 -0
- helm/benchmark/scenarios/vision_language/multipanelvqa_scenario.py +169 -0
- helm/benchmark/scenarios/vision_language/pope_scenario.py +104 -0
- helm/benchmark/scenarios/vision_language/seed_bench_scenario.py +129 -0
- helm/benchmark/scenarios/vision_language/unicorn_scenario.py +108 -0
- helm/benchmark/scenarios/vision_language/viz_wiz_scenario.py +107 -0
- helm/benchmark/scenarios/vision_language/vqa_scenario.py +1 -1
- helm/benchmark/scenarios/wmt_14_scenario.py +18 -18
- helm/benchmark/server.py +59 -2
- helm/benchmark/slurm_jobs.py +12 -0
- helm/benchmark/slurm_runner.py +79 -51
- helm/benchmark/static/benchmarking.js +3 -4
- helm/benchmark/static/contamination.yaml +1 -1
- helm/benchmark/static/images/organizations/together.png +0 -0
- helm/benchmark/static/json-urls.js +4 -0
- helm/benchmark/static/{schema.yaml → schema_classic.yaml} +346 -930
- helm/benchmark/static/schema_instruction_following.yaml +210 -0
- helm/benchmark/static/schema_lite.yaml +824 -0
- helm/benchmark/static/schema_mmlu.yaml +1507 -0
- helm/benchmark/static/schema_unitxt.yaml +428 -0
- helm/benchmark/static/schema_vlm.yaml +576 -0
- helm/benchmark/static_build/assets/01-694cb9b7.png +0 -0
- helm/benchmark/static_build/assets/ai21-0eb91ec3.png +0 -0
- helm/benchmark/static_build/assets/aleph-alpha-7ce10034.png +0 -0
- helm/benchmark/static_build/assets/anthropic-70d8bc39.png +0 -0
- helm/benchmark/static_build/assets/bigscience-7f0400c0.png +0 -0
- helm/benchmark/static_build/assets/cohere-3550c6cb.png +0 -0
- helm/benchmark/static_build/assets/crfm-logo-74391ab8.png +0 -0
- helm/benchmark/static_build/assets/eleutherai-b9451114.png +0 -0
- helm/benchmark/static_build/assets/google-06d997ad.png +0 -0
- helm/benchmark/static_build/assets/heim-logo-3e5e3aa4.png +0 -0
- helm/benchmark/static_build/assets/helm-logo-simple-2ed5400b.png +0 -0
- helm/benchmark/static_build/assets/helmhero-28e90f4d.png +0 -0
- helm/benchmark/static_build/assets/index-5088afcb.css +1 -0
- helm/benchmark/static_build/assets/index-d839df55.js +9 -0
- helm/benchmark/static_build/assets/meta-5580e9f1.png +0 -0
- helm/benchmark/static_build/assets/microsoft-f5ee5016.png +0 -0
- helm/benchmark/static_build/assets/mistral-18e1be23.png +0 -0
- helm/benchmark/static_build/assets/nvidia-86fa75c1.png +0 -0
- helm/benchmark/static_build/assets/openai-3f8653e4.png +0 -0
- helm/benchmark/static_build/assets/react-d4a0b69b.js +85 -0
- helm/benchmark/static_build/assets/recharts-6d337683.js +97 -0
- helm/benchmark/static_build/assets/tii-24de195c.png +0 -0
- helm/benchmark/static_build/assets/together-a665a35b.png +0 -0
- helm/benchmark/static_build/assets/tremor-54a99cc4.js +10 -0
- helm/benchmark/static_build/assets/tsinghua-keg-97d4b395.png +0 -0
- helm/benchmark/static_build/assets/vhelm-framework-cde7618a.png +0 -0
- helm/benchmark/static_build/assets/vhelm-model-6d812526.png +0 -0
- helm/benchmark/static_build/assets/yandex-38e09d70.png +0 -0
- helm/benchmark/static_build/config.js +4 -0
- helm/benchmark/static_build/index.html +20 -0
- helm/benchmark/test_data_preprocessor.py +3 -3
- helm/benchmark/test_model_deployment_definition.py +90 -0
- helm/benchmark/test_run_expander.py +1 -1
- helm/benchmark/tokenizer_config_registry.py +10 -14
- helm/benchmark/window_services/ai21_window_service.py +22 -33
- helm/benchmark/window_services/cohere_window_service.py +1 -63
- helm/benchmark/window_services/default_window_service.py +2 -35
- helm/benchmark/window_services/encoder_decoder_window_service.py +0 -11
- helm/benchmark/window_services/ice_window_service.py +0 -34
- helm/benchmark/window_services/image_generation/__init__.py +0 -0
- helm/benchmark/window_services/image_generation/clip_window_service.py +15 -0
- helm/benchmark/window_services/image_generation/lexica_search_window_service.py +9 -0
- helm/benchmark/window_services/image_generation/openai_dalle_window_service.py +9 -0
- helm/benchmark/window_services/image_generation/test_clip_window_service.py +29 -0
- helm/benchmark/window_services/image_generation/test_openai_dalle_window_service.py +30 -0
- helm/benchmark/window_services/local_window_service.py +21 -4
- helm/benchmark/window_services/no_decoding_window_service.py +32 -0
- helm/benchmark/window_services/test_anthropic_window_service.py +2 -1
- helm/benchmark/window_services/test_bloom_window_service.py +2 -1
- helm/benchmark/window_services/test_cohere_window_service.py +2 -1
- helm/benchmark/window_services/test_flan_t5_window_service.py +2 -1
- helm/benchmark/window_services/test_gpt2_window_service.py +2 -2
- helm/benchmark/window_services/test_gpt4_window_service.py +2 -1
- helm/benchmark/window_services/test_gptj_window_service.py +3 -2
- helm/benchmark/window_services/test_gptneox_window_service.py +3 -2
- helm/benchmark/window_services/test_ice_window_service.py +2 -1
- helm/benchmark/window_services/test_openai_window_service.py +2 -1
- helm/benchmark/window_services/test_opt_window_service.py +3 -2
- helm/benchmark/window_services/test_palmyra_window_service.py +2 -1
- helm/benchmark/window_services/test_t0pp_window_service.py +2 -1
- helm/benchmark/window_services/test_t511b_window_service.py +2 -1
- helm/benchmark/window_services/test_ul2_window_service.py +2 -1
- helm/benchmark/window_services/test_utils.py +3 -2
- helm/benchmark/window_services/test_yalm_window_service.py +2 -1
- helm/benchmark/window_services/window_service.py +42 -0
- helm/benchmark/window_services/window_service_factory.py +24 -269
- helm/benchmark/window_services/yalm_window_service.py +0 -27
- helm/clients/__init__.py +0 -0
- helm/{proxy/clients → clients}/ai21_client.py +5 -12
- helm/clients/aleph_alpha_client.py +112 -0
- helm/{proxy/clients → clients}/anthropic_client.py +213 -24
- helm/clients/auto_client.py +215 -0
- helm/clients/bedrock_client.py +128 -0
- helm/clients/bedrock_utils.py +72 -0
- helm/{proxy/clients → clients}/client.py +67 -55
- helm/clients/clip_score_client.py +49 -0
- helm/clients/clip_scorers/__init__.py +0 -0
- helm/clients/clip_scorers/base_clip_scorer.py +18 -0
- helm/clients/clip_scorers/clip_scorer.py +50 -0
- helm/clients/clip_scorers/multilingual_clip_scorer.py +50 -0
- helm/{proxy/clients → clients}/cohere_client.py +6 -17
- helm/clients/gcs_client.py +82 -0
- helm/{proxy/clients → clients}/google_client.py +7 -8
- helm/clients/google_translate_client.py +35 -0
- helm/{proxy/clients → clients}/http_model_client.py +6 -10
- helm/{proxy/clients → clients}/huggingface_client.py +134 -92
- helm/clients/image_generation/__init__.py +0 -0
- helm/clients/image_generation/adobe_vision_client.py +78 -0
- helm/clients/image_generation/aleph_alpha_image_generation_client.py +98 -0
- helm/clients/image_generation/cogview2/__init__.py +0 -0
- helm/clients/image_generation/cogview2/coglm_strategy.py +96 -0
- helm/clients/image_generation/cogview2/coglm_utils.py +82 -0
- helm/clients/image_generation/cogview2/sr_pipeline/__init__.py +15 -0
- helm/clients/image_generation/cogview2/sr_pipeline/direct_sr.py +96 -0
- helm/clients/image_generation/cogview2/sr_pipeline/dsr_model.py +254 -0
- helm/clients/image_generation/cogview2/sr_pipeline/dsr_sampling.py +190 -0
- helm/clients/image_generation/cogview2/sr_pipeline/iterative_sr.py +141 -0
- helm/clients/image_generation/cogview2/sr_pipeline/itersr_model.py +269 -0
- helm/clients/image_generation/cogview2/sr_pipeline/itersr_sampling.py +120 -0
- helm/clients/image_generation/cogview2/sr_pipeline/sr_group.py +42 -0
- helm/clients/image_generation/cogview2_client.py +191 -0
- helm/clients/image_generation/dalle2_client.py +192 -0
- helm/clients/image_generation/dalle3_client.py +108 -0
- helm/clients/image_generation/dalle_mini/__init__.py +3 -0
- helm/clients/image_generation/dalle_mini/data.py +442 -0
- helm/clients/image_generation/dalle_mini/model/__init__.py +5 -0
- helm/clients/image_generation/dalle_mini/model/configuration.py +175 -0
- helm/clients/image_generation/dalle_mini/model/modeling.py +1834 -0
- helm/clients/image_generation/dalle_mini/model/partitions.py +84 -0
- helm/clients/image_generation/dalle_mini/model/processor.py +63 -0
- helm/clients/image_generation/dalle_mini/model/text.py +251 -0
- helm/clients/image_generation/dalle_mini/model/tokenizer.py +9 -0
- helm/clients/image_generation/dalle_mini/model/utils.py +29 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/__init__.py +1 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/configuration_vqgan.py +40 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/convert_pt_model_to_jax.py +107 -0
- helm/clients/image_generation/dalle_mini/vqgan_jax/modeling_flax_vqgan.py +610 -0
- helm/clients/image_generation/dalle_mini_client.py +190 -0
- helm/clients/image_generation/deep_floyd_client.py +78 -0
- helm/clients/image_generation/huggingface_diffusers_client.py +249 -0
- helm/clients/image_generation/image_generation_client_utils.py +9 -0
- helm/clients/image_generation/lexica_client.py +86 -0
- helm/clients/image_generation/mindalle/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/__init__.py +216 -0
- helm/clients/image_generation/mindalle/models/stage1/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/stage1/layers.py +312 -0
- helm/clients/image_generation/mindalle/models/stage1/vqgan.py +103 -0
- helm/clients/image_generation/mindalle/models/stage2/__init__.py +0 -0
- helm/clients/image_generation/mindalle/models/stage2/layers.py +144 -0
- helm/clients/image_generation/mindalle/models/stage2/transformer.py +268 -0
- helm/clients/image_generation/mindalle/models/tokenizer.py +30 -0
- helm/clients/image_generation/mindalle/utils/__init__.py +3 -0
- helm/clients/image_generation/mindalle/utils/config.py +129 -0
- helm/clients/image_generation/mindalle/utils/sampling.py +149 -0
- helm/clients/image_generation/mindalle/utils/utils.py +89 -0
- helm/clients/image_generation/mindalle_client.py +115 -0
- helm/clients/image_generation/nudity_check_client.py +64 -0
- helm/clients/image_generation/together_image_generation_client.py +111 -0
- helm/{proxy/clients → clients}/lit_gpt_client.py +7 -5
- helm/{proxy/clients → clients}/megatron_client.py +13 -7
- helm/clients/mistral_client.py +134 -0
- helm/clients/moderation_api_client.py +109 -0
- helm/clients/open_lm_client.py +43 -0
- helm/clients/openai_client.py +302 -0
- helm/{proxy/clients → clients}/palmyra_client.py +15 -12
- helm/{proxy/clients → clients}/perspective_api_client.py +7 -8
- helm/clients/simple_client.py +64 -0
- helm/{proxy/clients → clients}/test_auto_client.py +15 -15
- helm/clients/test_client.py +100 -0
- helm/clients/test_huggingface_client.py +70 -0
- helm/clients/test_simple_client.py +19 -0
- helm/{proxy/clients → clients}/test_together_client.py +23 -12
- helm/{proxy/clients → clients}/together_client.py +18 -71
- helm/clients/vertexai_client.py +391 -0
- helm/clients/vision_language/__init__.py +0 -0
- helm/clients/vision_language/huggingface_vlm_client.py +104 -0
- helm/{proxy/clients → clients}/vision_language/idefics_client.py +59 -52
- helm/clients/vision_language/open_flamingo/__init__.py +2 -0
- helm/clients/vision_language/open_flamingo/src/__init__.py +0 -0
- helm/clients/vision_language/open_flamingo/src/factory.py +147 -0
- helm/clients/vision_language/open_flamingo/src/flamingo.py +337 -0
- helm/clients/vision_language/open_flamingo/src/flamingo_lm.py +155 -0
- helm/clients/vision_language/open_flamingo/src/helpers.py +267 -0
- helm/clients/vision_language/open_flamingo/src/utils.py +47 -0
- helm/clients/vision_language/open_flamingo_client.py +155 -0
- helm/clients/vision_language/qwen_vlm_client.py +171 -0
- helm/clients/vllm_client.py +46 -0
- helm/common/cache.py +24 -179
- helm/common/cache_backend_config.py +47 -0
- helm/common/clip_score_request.py +41 -0
- helm/common/concurrency.py +32 -0
- helm/common/credentials_utils.py +28 -0
- helm/common/file_caches/__init__.py +0 -0
- helm/common/file_caches/file_cache.py +16 -0
- helm/common/file_caches/local_file_cache.py +61 -0
- helm/common/file_caches/test_local_file_cache.py +25 -0
- helm/common/file_upload_request.py +27 -0
- helm/common/general.py +29 -10
- helm/common/image_generation_parameters.py +25 -0
- helm/common/images_utils.py +24 -1
- helm/common/key_value_store.py +113 -0
- helm/common/media_object.py +13 -0
- helm/common/moderations_api_request.py +71 -0
- helm/common/mongo_key_value_store.py +88 -0
- helm/common/multimodal_request_utils.py +31 -0
- helm/common/nudity_check_request.py +29 -0
- helm/common/object_spec.py +2 -2
- helm/common/request.py +36 -27
- helm/common/test_general.py +6 -0
- helm/common/tokenization_request.py +6 -3
- helm/config/__init__.py +0 -0
- helm/config/model_deployments.yaml +1942 -0
- helm/config/model_metadata.yaml +2201 -0
- helm/config/tokenizer_configs.yaml +362 -0
- helm/proxy/accounts.py +31 -4
- helm/proxy/critique/mechanical_turk_critique_importer.py +3 -0
- helm/proxy/critique/model_critique_client.py +13 -5
- helm/proxy/example_queries.py +29 -17
- helm/proxy/retry.py +8 -2
- helm/proxy/server.py +77 -5
- helm/proxy/services/remote_service.py +31 -0
- helm/proxy/services/server_service.py +103 -20
- helm/proxy/services/service.py +34 -2
- helm/proxy/services/test_remote_service.py +7 -6
- helm/proxy/services/test_service.py +27 -18
- helm/proxy/test_accounts.py +32 -0
- helm/proxy/token_counters/auto_token_counter.py +37 -37
- helm/proxy/token_counters/test_auto_token_counter.py +164 -0
- helm/proxy/token_counters/token_counter.py +3 -5
- helm/py.typed +0 -0
- helm/tokenizers/__init__.py +0 -0
- helm/{proxy/tokenizers → tokenizers}/ai21_tokenizer.py +3 -3
- helm/{proxy/tokenizers → tokenizers}/aleph_alpha_tokenizer.py +3 -1
- helm/{proxy/tokenizers → tokenizers}/anthropic_tokenizer.py +17 -11
- helm/tokenizers/auto_tokenizer.py +93 -0
- helm/{proxy/tokenizers → tokenizers}/caching_tokenizer.py +8 -2
- helm/{proxy/tokenizers → tokenizers}/cohere_tokenizer.py +1 -1
- helm/{proxy/tokenizers → tokenizers}/http_model_tokenizer.py +3 -3
- helm/{proxy/tokenizers → tokenizers}/huggingface_tokenizer.py +56 -60
- helm/tokenizers/simple_tokenizer.py +33 -0
- helm/tokenizers/test_anthropic_tokenizer.py +82 -0
- helm/tokenizers/test_huggingface_tokenizer.py +136 -0
- helm/tokenizers/test_simple_tokenizer.py +33 -0
- helm/tokenizers/vertexai_tokenizer.py +97 -0
- helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer.py +5 -3
- helm/tokenizers/yalm_tokenizer_data/__init__.py +0 -0
- helm/tokenizers/yalm_tokenizer_data/voc_100b.sp +0 -0
- helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/yalm_tokenizer.py +1 -1
- crfm_helm-0.3.0.dist-info/RECORD +0 -396
- helm/benchmark/vlm_run_specs.py +0 -71
- helm/benchmark/window_services/anthropic_window_service.py +0 -68
- helm/benchmark/window_services/bloom_window_service.py +0 -35
- helm/benchmark/window_services/flan_t5_window_service.py +0 -29
- helm/benchmark/window_services/gpt2_window_service.py +0 -32
- helm/benchmark/window_services/gptj_window_service.py +0 -38
- helm/benchmark/window_services/gptneox_window_service.py +0 -41
- helm/benchmark/window_services/http_model_window_service.py +0 -28
- helm/benchmark/window_services/huggingface_window_service.py +0 -59
- helm/benchmark/window_services/lit_gpt_window_service.py +0 -27
- helm/benchmark/window_services/llama_window_service.py +0 -28
- helm/benchmark/window_services/luminous_window_service.py +0 -67
- helm/benchmark/window_services/megatron_window_service.py +0 -10
- helm/benchmark/window_services/mt_nlg_window_service.py +0 -27
- helm/benchmark/window_services/openai_window_service.py +0 -13
- helm/benchmark/window_services/opt_window_service.py +0 -35
- helm/benchmark/window_services/palmyra_window_service.py +0 -45
- helm/benchmark/window_services/remote_window_service.py +0 -48
- helm/benchmark/window_services/santacoder_window_service.py +0 -27
- helm/benchmark/window_services/starcoder_window_service.py +0 -27
- helm/benchmark/window_services/t0pp_window_service.py +0 -35
- helm/benchmark/window_services/t511b_window_service.py +0 -30
- helm/benchmark/window_services/test_mt_nlg_window_service.py +0 -48
- helm/benchmark/window_services/ul2_window_service.py +0 -30
- helm/benchmark/window_services/wider_ai21_window_service.py +0 -24
- helm/benchmark/window_services/wider_openai_window_service.py +0 -52
- helm/proxy/clients/aleph_alpha_client.py +0 -99
- helm/proxy/clients/auto_client.py +0 -461
- helm/proxy/clients/goose_ai_client.py +0 -100
- helm/proxy/clients/microsoft_client.py +0 -182
- helm/proxy/clients/openai_client.py +0 -206
- helm/proxy/clients/remote_model_registry.py +0 -28
- helm/proxy/clients/simple_client.py +0 -61
- helm/proxy/clients/test_anthropic_client.py +0 -63
- helm/proxy/clients/test_client.py +0 -31
- helm/proxy/clients/test_huggingface_client.py +0 -87
- helm/proxy/models.py +0 -963
- helm/proxy/test_models.py +0 -27
- helm/proxy/token_counters/ai21_token_counter.py +0 -20
- helm/proxy/token_counters/cohere_token_counter.py +0 -13
- helm/proxy/token_counters/free_token_counter.py +0 -12
- helm/proxy/token_counters/gooseai_token_counter.py +0 -24
- helm/proxy/token_counters/openai_token_counter.py +0 -22
- helm/proxy/token_counters/test_ai21_token_counter.py +0 -86
- helm/proxy/token_counters/test_openai_token_counter.py +0 -79
- helm/proxy/tokenizers/simple_tokenizer.py +0 -32
- helm/proxy/tokenizers/test_huggingface_tokenizer.py +0 -56
- {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/LICENSE +0 -0
- {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/entry_points.txt +0 -0
- {crfm_helm-0.3.0.dist-info → crfm_helm-0.5.0.dist-info}/top_level.txt +0 -0
- /helm/{proxy/clients → benchmark/annotation}/__init__.py +0 -0
- /helm/{proxy/clients/vision_language → benchmark/annotation/image2structure}/__init__.py +0 -0
- /helm/{proxy/tokenizers → benchmark/metrics/image_generation}/__init__.py +0 -0
- /helm/{proxy/tokenizers/yalm_tokenizer_data → benchmark/metrics/image_generation/detectors}/__init__.py +0 -0
- /helm/{proxy/clients → clients}/ai21_utils.py +0 -0
- /helm/{proxy/clients → clients}/cohere_utils.py +0 -0
- /helm/{proxy/clients → clients}/lit_gpt_generate.py +0 -0
- /helm/{proxy/clients → clients}/toxicity_classifier_client.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/ice_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/lit_gpt_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/test_ice_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/test_yalm_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/tiktoken_tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/tokenizer.py +0 -0
- /helm/{proxy/tokenizers → tokenizers}/yalm_tokenizer_data/test_yalm_tokenizer.py +0 -0
helm/{proxy/clients → clients}/perspective_api_client.py:

```diff
@@ -1,14 +1,14 @@
 # mypy: check_untyped_defs = False
 import threading
 from dataclasses import asdict
-from typing import List, Dict, Optional
+from typing import Any, List, Dict, Optional
 
 from dacite import from_dict
 from googleapiclient import discovery
 from googleapiclient.errors import BatchError, HttpError
 from googleapiclient.http import BatchHttpRequest
 from httplib2 import HttpLib2Error
-from helm.proxy.clients.toxicity_classifier_client import ToxicityClassifierClient
+from helm.clients.toxicity_classifier_client import ToxicityClassifierClient
 from helm.proxy.retry import NonRetriableException
 
 from helm.common.cache import Cache, CacheConfig
@@ -91,14 +91,9 @@ class PerspectiveAPIClient(ToxicityClassifierClient):
         Batch several requests into a single API request and get the toxicity attributes and scores.
         For more information, see https://googleapis.github.io/google-api-python-client/docs/batch.html.
         """
-
-        with self._client_lock:
-            if not self._client:
-                self._client = self._create_client()
-
         try:
 
-            def do_it():
+            def do_it() -> Dict[str, Any]:
                 text_to_response: Dict[str, Dict] = dict()
 
                 def callback(request_id: str, response: Dict, error: HttpError):
@@ -106,6 +101,10 @@ class PerspectiveAPIClient(ToxicityClassifierClient):
                         raise error
                     text_to_response[request_id] = response
 
+                with self._client_lock:
+                    if not self._client:
+                        self._client = self._create_client()
+
                 # Create a batch request. We will add a request to the batch request for each text string
                 batch_request: BatchHttpRequest = self._client.new_batch_http_request()
 
```
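The perspective_api_client.py hunks above do two things: they repoint the import at the new `helm.clients` package, and they move the lazy construction of the Google API client from the top of the batching method into `do_it()`, so the client is only built when the request actually executes (i.e., on a cache miss) while remaining guarded against concurrent double-initialization. A minimal sketch of that lock-guarded lazy-initialization pattern, not HELM code (`create_client` is a hypothetical stand-in for the `googleapiclient.discovery` call):

```python
import threading
from typing import Callable, Optional


class LazyClientHolder:
    """Builds the wrapped client on first use, at most once across threads."""

    def __init__(self, create_client: Callable[[], object]):
        self._create_client = create_client  # hypothetical stand-in for discovery.build(...)
        self._client: Optional[object] = None
        self._client_lock = threading.Lock()

    def get(self) -> object:
        # Take the lock, then re-check: a thread that lost the race must
        # not construct a second client.
        with self._client_lock:
            if self._client is None:
                self._client = self._create_client()
        return self._client
```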
helm/clients/simple_client.py (new file):

```diff
@@ -0,0 +1,64 @@
+import itertools
+from typing import List, TypedDict
+from typing import Dict, Any
+
+from helm.common.cache import CacheConfig
+from helm.common.request import wrap_request_time, Request, RequestResult, GeneratedOutput, Token
+from helm.clients.client import CachingClient
+
+
+class SimpleClientRequest(TypedDict):
+    engine: str
+    prompt: str
+    num_completions: int
+
+
+class SimpleClient(CachingClient):
+    """Simple client for tutorials and for debugging."""
+
+    def __init__(self, cache_config: CacheConfig):
+        super().__init__(cache_config=cache_config)
+
+    def make_request(self, request: Request) -> RequestResult:
+        raw_request: SimpleClientRequest = {
+            "engine": request.model_engine,
+            "prompt": request.prompt,
+            "num_completions": request.num_completions,
+        }
+
+        def do_it() -> Dict[str, Any]:
+            return self.invoke_model(raw_request)
+
+        cache_key = CachingClient.make_cache_key(raw_request, request)
+        response, cached = self.cache.get(cache_key, wrap_request_time(do_it))
+        logprob = 0
+        completions = [
+            GeneratedOutput(
+                text=text,
+                logprob=logprob,
+                tokens=[Token(text=text, logprob=logprob)],
+            )
+            for text in response["completions"]
+        ]
+
+        return RequestResult(
+            success=True,
+            cached=cached,
+            request_time=response["request_time"],
+            request_datetime=response.get("request_datetime"),
+            completions=completions,
+            embedding=[],
+        )
+
+    def invoke_model(self, raw_request: SimpleClientRequest) -> Dict[str, Any]:
+        """
+        Example:
+            Prompt: 7 2 4 6
+            Completions (num_completions = 3):
+            - 6
+            - 4
+            - 2
+        """
+        prompt_words: List[str] = raw_request["prompt"].split()
+        completions = list(itertools.islice(itertools.cycle(reversed(prompt_words)), raw_request["num_completions"]))
+        return {"completions": completions}
```
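A usage sketch for the new `SimpleClient`, assuming an installed `crfm-helm`; per the `invoke_model` docstring, completions simply cycle through the prompt's words in reverse, which is what makes the client handy for tutorials and debugging:

```python
from helm.clients.simple_client import SimpleClient
from helm.common.cache import BlackHoleCacheConfig
from helm.common.request import Request

client = SimpleClient(BlackHoleCacheConfig())
result = client.make_request(
    Request(
        model="simple/model1",
        model_deployment="simple/model1",
        prompt="7 2 4 6",
        num_completions=3,
    )
)
# Completions cycle through the prompt words in reverse order:
print([completion.text for completion in result.completions])  # ["6", "4", "2"]
```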
helm/{proxy/clients → clients}/test_auto_client.py:

```diff
@@ -1,12 +1,13 @@
 import dataclasses
 from tempfile import TemporaryDirectory
-from helm.common.request import Sequence, Token
+from helm.common.cache_backend_config import BlackHoleCacheBackendConfig
+from helm.common.request import GeneratedOutput, Token
 
 import pytest
 
 from helm.common.request import Request, RequestResult
 from helm.common.general import get_credentials
-from helm.proxy.clients.auto_client import AutoClient
+from helm.clients.auto_client import AutoClient
 
 
 @pytest.mark.models
@@ -15,8 +16,8 @@ class TestAutoClient:
         credentials = get_credentials()
         if not credentials:
             pytest.skip("Skipping test because no credentials found")
-        with TemporaryDirectory() as cache_path:
-            auto_client = AutoClient(credentials, cache_path)
+        with TemporaryDirectory() as temp_dir_path:
+            auto_client = AutoClient(credentials, temp_dir_path, BlackHoleCacheBackendConfig())
             actual_result = auto_client.make_request(request)
             assert actual_result.request_time or actual_result.batch_request_time
             actual_result = dataclasses.replace(
@@ -27,6 +28,7 @@ class TestAutoClient:
     def test_make_request_databricks(self):
         request = Request(
             model="databricks/dolly-v2-3b",
+            model_deployment="together/dolly-v2-3b",
             prompt="Elephants are one of the most",
             temperature=0.0,
             max_tokens=10,
@@ -35,32 +37,29 @@ class TestAutoClient:
             success=True,
             embedding=[],
             completions=[
-                Sequence(
+                GeneratedOutput(
                     text=" intelligent species on the planet. They are also one",
                     logprob=-9.087313510477543,
                     tokens=[
                         Token(
                             text="Ġintelligent",
                             logprob=-1.9816237688064575,
-                            top_logprobs={"Ġintelligent": -1.9816237688064575},
                         ),
                         Token(
                             text="Ġspecies",
                             logprob=-1.2881066799163818,
-                            top_logprobs={"Ġspecies": -1.2881066799163818},
                         ),
-                        Token(text="Ġon", logprob=-0.16092979907989502, top_logprobs={"Ġon": -0.16092979907989502}),
-                        Token(text="Ġthe", logprob=-0.23620447516441345, top_logprobs={"Ġthe": -0.23620447516441345}),
+                        Token(text="Ġon", logprob=-0.16092979907989502),
+                        Token(text="Ġthe", logprob=-0.23620447516441345),
                         Token(
                             text="Ġplanet",
                             logprob=-0.015416033565998077,
-                            top_logprobs={"Ġplanet": -0.015416033565998077},
                         ),
-                        Token(text=".", logprob=-0.6683081388473511, top_logprobs={".": -0.6683081388473511}),
-                        Token(text="ĠThey", logprob=-1.9231040477752686, top_logprobs={"ĠThey": -1.9231040477752686}),
-                        Token(text="Ġare", logprob=-0.9322243332862854, top_logprobs={"Ġare": -0.9322243332862854}),
-                        Token(text="Ġalso", logprob=-0.7750787138938904, top_logprobs={"Ġalso": -0.7750787138938904}),
-                        Token(text="Ġone", logprob=-1.1063175201416016, top_logprobs={"Ġone": -1.1063175201416016}),
+                        Token(text=".", logprob=-0.6683081388473511),
+                        Token(text="ĠThey", logprob=-1.9231040477752686),
+                        Token(text="Ġare", logprob=-0.9322243332862854),
+                        Token(text="Ġalso", logprob=-0.7750787138938904),
+                        Token(text="Ġone", logprob=-1.1063175201416016),
                     ],
                     finish_reason={"reason": "length"},
                 )
@@ -69,6 +68,7 @@ class TestAutoClient:
         )
         request = Request(
             model="databricks/dolly-v2-3b",
+            model_deployment="together/dolly-v2-3b",
             prompt="Elephants are one of the most",
             temperature=0.0,
             max_tokens=10,
```
helm/clients/test_client.py (new file):

```diff
@@ -0,0 +1,100 @@
+from helm.common.cache import BlackHoleCacheConfig
+from helm.tokenizers.huggingface_tokenizer import HuggingFaceTokenizer
+from .client import truncate_sequence, truncate_and_tokenize_response_text
+from typing import List
+from helm.common.request import Request, GeneratedOutput, Token
+
+
+def truncate_sequence_helper(tokens: List[str], request: Request, expected_tokens: List[str]):
+    sequence = GeneratedOutput(
+        text="".join(tokens),
+        tokens=[Token(text=text, logprob=-1) for text in tokens],
+        logprob=-len(tokens),
+    )
+
+    output_sequence = truncate_sequence(sequence, request)
+
+    assert expected_tokens == [token.text for token in output_sequence.tokens]
+    assert "".join(expected_tokens) == output_sequence.text
+    assert output_sequence.logprob == sum(token.logprob for token in output_sequence.tokens)
+
+
+def test_truncate_sequence():
+    # echo_prompt = True, nothing gets truncated
+    truncate_sequence_helper(
+        ["a", "b", "c"],
+        Request(
+            model="openai/text-davinci-002", model_deployment="openai/text-davinci-002", prompt="abc", echo_prompt=True
+        ),
+        ["a", "b", "c"],
+    )
+
+    # Nothing gets truncated
+    truncate_sequence_helper(
+        ["hello", " world"],
+        Request(model="openai/text-davinci-002", model_deployment="openai/text-davinci-002", stop_sequences=["#"]),
+        ["hello", " world"],
+    )
+
+    # Truncate using stop sequences
+    truncate_sequence_helper(
+        ["hello", " world", "\n", "what"],
+        Request(model="openai/text-davinci-002", model_deployment="openai/text-davinci-002", stop_sequences=["\n"]),
+        ["hello", " world"],
+    )
+
+    # Truncate using max tokens
+    truncate_sequence_helper(
+        ["a", "b", "c"],
+        Request(model="openai/text-davinci-002", model_deployment="openai/text-davinci-002", max_tokens=2),
+        ["a", "b"],
+    )
+
+
+def test_truncate_and_tokenize_response_text():
+    tokenizer = HuggingFaceTokenizer(BlackHoleCacheConfig())
+    tokenizer_name = "huggingface/gpt2"
+
+    # No truncation
+    response = truncate_and_tokenize_response_text(
+        "I am a scientist. I am a scientist.", Request(max_tokens=100, stop_sequences=[]), tokenizer, tokenizer_name
+    )
+    assert response.finish_reason
+    assert response.finish_reason["reason"] == "endoftext"
+    assert response.text == "I am a scientist. I am a scientist."
+    assert response.tokens == [
+        Token("I", 0.0),
+        Token(" am", 0.0),
+        Token(" a", 0.0),
+        Token(" scientist", 0.0),
+        Token(".", 0.0),
+        Token(" I", 0.0),
+        Token(" am", 0.0),
+        Token(" a", 0.0),
+        Token(" scientist", 0.0),
+        Token(".", 0.0),
+    ]
+
+    response = truncate_and_tokenize_response_text(
+        "I am a scientist. I am a scientist.", Request(max_tokens=7, stop_sequences=["."]), tokenizer, tokenizer_name
+    )
+    assert response.finish_reason
+    assert response.finish_reason["reason"] == "stop"
+    assert response.text == "I am a scientist"
+    assert response.tokens == [Token("I", 0.0), Token(" am", 0.0), Token(" a", 0.0), Token(" scientist", 0.0)]
+
+    response = truncate_and_tokenize_response_text(
+        "I am a scientist. I am a scientist.", Request(max_tokens=3, stop_sequences=[]), tokenizer, tokenizer_name
+    )
+    assert response.finish_reason
+    assert response.finish_reason["reason"] == "length"
+    assert response.text == "I am a"
+    assert response.tokens == [Token("I", 0.0), Token(" am", 0.0), Token(" a", 0.0)]
+
+    response = truncate_and_tokenize_response_text(
+        "I am a scientist. I am a scientist.", Request(max_tokens=3, stop_sequences=["."]), tokenizer, tokenizer_name
+    )
+    assert response.finish_reason
+    assert response.finish_reason["reason"] == "length"
+    assert response.text == "I am a"
+    assert response.tokens == [Token("I", 0.0), Token(" am", 0.0), Token(" a", 0.0)]
```
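The `test_truncate_sequence` cases above pin down the truncation contract: with `echo_prompt=True` nothing is cut; otherwise the completion is cut at the earliest stop sequence and then limited to `max_tokens` tokens, with the text and the token list kept consistent. A simplified stand-in that satisfies those cases (not the actual `helm.clients.client.truncate_sequence`, which also handles the `echo_prompt` pass-through and logprob bookkeeping):

```python
from typing import List


def truncate_tokens(tokens: List[str], stop_sequences: List[str], max_tokens: int) -> List[str]:
    """Cut at the earliest stop sequence, then keep at most max_tokens whole tokens."""
    text = "".join(tokens)
    # Position of the earliest stop sequence, or end of text if none occurs.
    cut = min((text.index(stop) for stop in stop_sequences if stop in text), default=len(text))
    kept: List[str] = []
    length = 0
    for token in tokens:
        # Keep only whole tokens that end before the cut, up to max_tokens.
        if length + len(token) > cut or len(kept) >= max_tokens:
            break
        kept.append(token)
        length += len(token)
    return kept


assert truncate_tokens(["hello", " world", "\n", "what"], ["\n"], 100) == ["hello", " world"]
assert truncate_tokens(["a", "b", "c"], [], 2) == ["a", "b"]
```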
helm/clients/test_huggingface_client.py (new file):

```diff
@@ -0,0 +1,70 @@
+import pytest
+
+from helm.common.cache import BlackHoleCacheConfig
+from helm.common.request import Request, RequestResult
+from helm.clients.huggingface_client import HuggingFaceClient
+
+
+class TestHuggingFaceClient:
+    def test_gpt2(self):
+        client = HuggingFaceClient(
+            cache_config=BlackHoleCacheConfig(), pretrained_model_name_or_path="openai-community/gpt2"
+        )
+        prompt: str = "I am a computer scientist."
+        result: RequestResult = client.make_request(
+            Request(
+                model="openai/gpt2",
+                model_deployment="huggingface/gpt2",
+                prompt=prompt,
+                num_completions=3,
+                top_k_per_token=5,
+                max_tokens=1,
+                echo_prompt=True,
+            )
+        )
+        assert len(result.completions) == 3
+        assert result.completions[0].text.startswith(
+            prompt
+        ), "echo_prompt was set to true. Expected the prompt at the beginning of each completion"
+
+    @pytest.mark.skip(reason="GPT-J 6B is 22 GB and extremely slow without a GPU.")
+    def test_gptj_6b(self):
+        client = HuggingFaceClient(
+            cache_config=BlackHoleCacheConfig(), pretrained_model_name_or_path="openai-community/gpt2"
+        )
+        result: RequestResult = client.make_request(
+            Request(
+                model="eleutherai/gpt-j-6b",
+                model_deployment="huggingface/gpt-j-6b",
+                prompt="I am a computer scientist.",
+                num_completions=3,
+                top_k_per_token=5,
+                max_tokens=0,
+            )
+        )
+        assert len(result.completions) == 3
+
+    def test_logprob(self):
+        client = HuggingFaceClient(
+            cache_config=BlackHoleCacheConfig(), pretrained_model_name_or_path="openai-community/gpt2"
+        )
+        prompt: str = "I am a computer scientist."
+        result: RequestResult = client.make_request(
+            Request(
+                model="openai/gpt2",
+                model_deployment="huggingface/gpt2",
+                prompt=prompt,
+                num_completions=1,
+                max_tokens=0,
+                echo_prompt=True,
+            )
+        )
+        assert result.completions[0].text.startswith(
+            prompt
+        ), "echo_prompt was set to true. Expected the prompt at the beginning of each completion"
+        total_logprob: float = 0
+        assert len(result.completions[0].tokens) == 6, "Expected 6 tokens in the completion"
+        for token in result.completions[0].tokens[1:]:
+            assert token.logprob != 0
+            total_logprob += token.logprob
+        assert result.completions[0].logprob == pytest.approx(total_logprob)
```
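`test_logprob` above encodes the scoring invariant: with `echo_prompt=True` and `max_tokens=0`, the completion is the prompt itself re-tokenized, the first token (which has no preceding context) carries no log probability, and the sequence logprob equals the sum over the remaining tokens. Restated with toy values, assuming an installed `crfm-helm`:

```python
import pytest

from helm.common.request import GeneratedOutput, Token

# Toy stand-in for a scored echo_prompt completion: the first token gets
# logprob 0 because nothing precedes it to condition on.
tokens = [Token("I", 0.0), Token(" am", -2.5), Token(" a", -1.25)]
completion = GeneratedOutput(text="I am a", logprob=-3.75, tokens=tokens)

assert completion.logprob == pytest.approx(sum(t.logprob for t in completion.tokens[1:]))
```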
helm/clients/test_simple_client.py (new file):

```diff
@@ -0,0 +1,19 @@
+from helm.clients.simple_client import SimpleClient
+from helm.common.cache import BlackHoleCacheConfig
+from helm.common.request import GeneratedOutput, Request, Token
+
+
+def test_simple_client_make_request():
+    client = SimpleClient(BlackHoleCacheConfig())
+    request = Request(
+        model="simple/model1",
+        model_deployment="simple/model1",
+        prompt="Elephants are one of the most",
+        temperature=0.0,
+        max_tokens=10,
+    )
+    result = client.make_request(request)
+    assert result.success
+    assert not result.cached
+    assert result.embedding == []
+    assert result.completions == [GeneratedOutput(text="most", logprob=0, tokens=[Token(text="most", logprob=0)])]
```
helm/{proxy/clients → clients}/test_together_client.py:

```diff
@@ -4,7 +4,6 @@ import tempfile
 
 from helm.common.cache import SqliteCacheConfig
 from helm.common.request import Request
-from helm.proxy.tokenizers.huggingface_tokenizer import HuggingFaceTokenizer
 
 from .together_client import TogetherClient, TogetherClientError
 
@@ -13,25 +12,22 @@ class TestTogetherClient:
     def setup_method(self, method):
        cache_file = tempfile.NamedTemporaryFile(delete=False)
        self.cache_path: str = cache_file.name
-        self.client = TogetherClient(
-            tokenizer=HuggingFaceTokenizer(SqliteCacheConfig(self.cache_path)),
-            cache_config=SqliteCacheConfig(self.cache_path),
-        )
 
     def teardown_method(self, method):
         os.remove(self.cache_path)
 
     @pytest.mark.parametrize(
-        "test_input,expected",
+        "together_model,test_input,expected",
         [
             (
+                "togethercomputer/RedPajama-INCITE-Base-3B-v1",
                 Request(
                     model="together/redpajama-incite-base-3b-v1",
+                    model_deployment="together/redpajama-incite-base-3b-v1",
                 ),
                 {
                     "best_of": 1,
                     "echo": False,
-                    "logprobs": 1,
                     "max_tokens": 100,
                     "model": "togethercomputer/RedPajama-INCITE-Base-3B-v1",
                     "n": 1,
@@ -43,8 +39,10 @@ class TestTogetherClient:
                 },
             ),
             (
+                "huggyllama/llama-7b",
                 Request(
                     model="meta/llama-7b",
+                    model_deployment="together/llama-7b",
                     prompt="I am a computer scientist.",
                     temperature=0,
                     num_completions=4,
@@ -57,7 +55,6 @@ class TestTogetherClient:
                 {
                     "best_of": 3,
                     "echo": True,
-                    "logprobs": 3,
                     "max_tokens": 24,
                     "model": "huggyllama/llama-7b",
                     "n": 4,
@@ -69,14 +66,15 @@ class TestTogetherClient:
                 },
             ),
             (
+                "togethercomputer/alpaca-7b",
                 Request(
                     model="stanford/alpaca-7b",
+                    model_deployment="together/alpaca-7b",
                     stop_sequences=["\n"],
                 ),
                 {
                     "best_of": 1,
                     "echo": False,
-                    "logprobs": 1,
                     "max_tokens": 100,
                     "model": "togethercomputer/alpaca-7b",
                     "n": 1,
@@ -90,9 +88,22 @@ class TestTogetherClient:
            # TODO(#1828): Add test for `SET_DETAILS_TO_TRUE` after Together supports it.
         ],
     )
-    def test_convert_to_raw_request(self, test_input, expected):
-        assert expected == TogetherClient.convert_to_raw_request(test_input)
+    def test_convert_to_raw_request(self, together_model, test_input, expected):
+        client = TogetherClient(
+            cache_config=SqliteCacheConfig(self.cache_path),
+            together_model=together_model,
+        )
+        assert expected == client.convert_to_raw_request(test_input)
 
     def test_api_key_error(self):
+        client = TogetherClient(
+            cache_config=SqliteCacheConfig(self.cache_path),
+            together_model="togethercomputer/RedPajama-INCITE-Base-3B-v1",
+        )
         with pytest.raises(TogetherClientError):
-            self.client.make_request(Request(model="together/redpajama-incite-base-3b-v1"))
+            client.make_request(
+                Request(
+                    model="together/redpajama-incite-base-3b-v1",
+                    model_deployment="together/redpajama-incite-base-3b-v1",
+                )
+            )
```
helm/{proxy/clients → clients}/together_client.py:

```diff
@@ -5,63 +5,10 @@ import requests
 from retrying import retry
 
 from helm.common.cache import CacheConfig
-from helm.common.request import wrap_request_time, Request, RequestResult, Sequence, Token
-from helm.proxy.tokenizers.tokenizer import Tokenizer
+from helm.common.request import wrap_request_time, Request, RequestResult, GeneratedOutput, Token
 from .client import CachingClient, truncate_sequence, cleanup_str
 
 
-MODEL_ALIASES: Dict[str, str] = {
-    # Legacy models
-    "flan-t5-xxl": "flan-t5-xxl-hf",
-    "h3-2.7b": "h3-2.7b-h3",
-    "opt-1.3b": "opt-1.3b-ft-tp1",
-    "opt-6.7b": "opt-6.7b-ft-tp1",
-    # Production models
-    "redpajama-incite-base-3b-v1": "togethercomputer/RedPajama-INCITE-Base-3B-v1",
-    "redpajama-incite-instruct-3b-v1": "togethercomputer/RedPajama-INCITE-Instruct-3B-v1",
-    "redpajama-incite-base-7b": "togethercomputer/RedPajama-INCITE-7B-Base",
-    "redpajama-incite-instruct-7b": "togethercomputer/RedPajama-INCITE-7B-Instruct",
-    "alpaca-7b": "togethercomputer/alpaca-7b",
-    "dolly-v2-3b": "databricks/dolly-v2-3b",
-    "dolly-v2-7b": "databricks/dolly-v2-7b",
-    "dolly-v2-12b": "databricks/dolly-v2-12b",
-    "falcon-7b": "togethercomputer/falcon-7b",
-    "falcon-7b-instruct": "togethercomputer/falcon-7b-instruct",
-    "falcon-40b": "togethercomputer/falcon-40b",
-    "falcon-40b-instruct": "togethercomputer/falcon-40b-instruct",
-    "llama-7b": "huggyllama/llama-7b",
-    "llama-13b": "huggyllama/llama-13b",
-    "llama-30b": "huggyllama/llama-30b",
-    "llama-65b": "huggyllama/llama-65b",
-    "llama-2-7b": "togethercomputer/llama-2-7b",
-    "llama-2-13b": "togethercomputer/llama-2-13b",
-    "llama-2-70b": "togethercomputer/llama-2-70b",
-    "mistral-7b-v0.1": "mistralai/Mistral-7B-v0.1",
-    "mpt-7b": "togethercomputer/mpt-7b",
-    "mpt-instruct-7b": "togethercomputer/mpt-7b-instruct",
-    "mpt-30b": "togethercomputer/mpt-30b",
-    "mpt-instruct-30b": "togethercomputer/mpt-30b-instruct",
-    "pythia-1b-v0": "EleutherAI/pythia-1b-v0",
-    "pythia-2.8b-v0": "EleutherAI/pythia-2.8b-v0",
-    "pythia-6.9b": "EleutherAI/pythia-6.9b",
-    "pythia-12b-v0": "EleutherAI/pythia-12b-v0",
-    "stablelm-base-alpha-3b": "stabilityai/stablelm-base-alpha-3b",
-    "stablelm-base-alpha-7b": "stabilityai/stablelm-base-alpha-7b",
-    "vicuna-7b-v1.3": "lmsys/vicuna-7b-v1.3",
-    "vicuna-13b-v1.3": "lmsys/vicuna-13b-v1.3",
-}
-"""Together model name aliases.
-
-HELM users use a shorter model name (e.g. together/flan-t5-xxl)
-whereas the Together client sends and caches requests using
-a longer model name that is suffixed with the implementation framework
-(e.g. flan-t5-xxl-hf). This allows tracking exactly which
-implementation was used in the cached results, since some results may
-be different depending on the implementation (e.g. efficiency metrics).
-This also allows future migration of results in the case of changes of
-available implementations on Together."""
-
-
 class _RewriteRequestTags:
     """Tags that indicate that the request for the model must be rewritten before sending to Together."""
 
@@ -105,6 +52,10 @@ The keys are the model engine of the HELM model name (e.g. "alpaca-7b"), not the
 (e.g. "stanford/alpaca-7b") or the Together model name (e.g. "togethercomputer/alpaca-7b")."""
 
 
+TOGETHER_SUPPORTS_ASYNC_REQUESTS = False
+"""Whether Together AI currently supports asynchronous requests."""
+
+
 def _rewrite_raw_request_for_model_tags(raw_request: Dict[str, Any], model_engine: str) -> Dict[str, Any]:
     """Rewrite the raw request given the model."""
     # Make a deepcopy to avoid mutating the input in unexpected ways
@@ -146,43 +97,41 @@ class TogetherClient(CachingClient):
     INFERENCE_ENDPOINT: str = "https://api.together.xyz/api/inference"
     RETRIEVE_JOB_MAX_WAIT_SECONDS: int = 60
 
-    @staticmethod
-    def convert_to_raw_request(request: Request) -> Dict:
+    def convert_to_raw_request(self, request: Request) -> Dict:
         # Following the examples from https://github.com/togethercomputer/open-models-api
         raw_request = {
             "request_type": "language-model-inference",
-            "model": MODEL_ALIASES.get(request.model_engine, request.model_engine),
+            "model": self.together_model or request.model,
             "prompt": request.prompt,
             "temperature": request.temperature,
             "n": request.num_completions,
             "max_tokens": request.max_tokens,
             "best_of": request.top_k_per_token,
-            "logprobs": request.top_k_per_token,
             "stop": request.stop_sequences or None,
             "echo": request.echo_prompt,
             "top_p": request.top_p,
         }
         return _rewrite_raw_request_for_model_tags(raw_request, request.model_engine)
 
-    def __init__(self, tokenizer: Tokenizer, cache_config: CacheConfig, api_key: Optional[str] = None):
-        super().__init__(cache_config=cache_config, tokenizer=tokenizer)
+    def __init__(self, cache_config: CacheConfig, together_model: Optional[str] = None, api_key: Optional[str] = None):
+        super().__init__(cache_config=cache_config)
         # TODO: the endpoint currently doesn't require an API key. When an API key is not specified
         # in credentials.conf, we rely on offline evaluation only.
         self.api_key: Optional[str] = api_key
+        self.together_model = together_model
 
     def _get_job_url(self, job_id: str) -> str:
         return f"https://api.together.xyz/jobs/job/{job_id}"
 
     def make_request(self, request: Request) -> RequestResult:
-        raw_request = TogetherClient.convert_to_raw_request(request)
+        raw_request = self.convert_to_raw_request(request)
         cache_key = CachingClient.make_cache_key(raw_request, request)
 
         if not self.api_key:
             raise TogetherClientError("togetherApiKey not set in credentials.conf")
         headers: Dict[str, str] = {"Authorization": f"Bearer {self.api_key}"}
 
-
-        if request.model_engine in MODEL_ALIASES:
+        if TOGETHER_SUPPORTS_ASYNC_REQUESTS:
 
             def submit_job() -> str:
                 submit_request = {**raw_request, "async": True}
@@ -271,7 +220,7 @@ class TogetherClient(CachingClient):
             )
 
         # Expect the result to be structured the same way as a response from OpenAI API.
-        completions: List[Sequence] = []
+        completions: List[GeneratedOutput] = []
         for raw_completion in response["choices"]:
             sequence_logprob = 0
             tokens: List[Token] = []
@@ -281,22 +230,20 @@ class TogetherClient(CachingClient):
             # Waiting for a fix.
             if "logprobs" in raw_completion:
                 raw_data = raw_completion["logprobs"]
-                for text, logprob, top_logprobs in zip(
-                    raw_data["tokens"], raw_data["token_logprobs"], raw_data["top_logprobs"]
-                ):
+                for text, logprob in zip(raw_data["tokens"], raw_data["token_logprobs"]):
                     # TODO #1654: Check if this is still needed
                     text = cleanup_str(text, "together")
-                    tokens.append(Token(text=text, logprob=logprob or 0, top_logprobs=dict(top_logprobs or {})))
+                    tokens.append(Token(text=text, logprob=logprob or 0))
                     sequence_logprob += logprob or 0
             else:
                 # hack: just make the entire text one token so that something shows up in the frontend
                 text = cleanup_str(raw_completion["text"], "together")
-                tokens.append(Token(text=text, logprob=0, top_logprobs={}))
+                tokens.append(Token(text=text, logprob=0))
 
             raw_finish_reason: Optional[str] = raw_completion.get("finish_reason")
             finish_reason: Optional[Dict] = {"reason": raw_finish_reason} if raw_finish_reason else None
 
-            completion = Sequence(
+            completion = GeneratedOutput(
                 text=cleanup_str(raw_completion["text"], "together"),
                 logprob=sequence_logprob,
                 tokens=tokens,
```