crfm-helm 0.5.4__py3-none-any.whl → 0.5.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of crfm-helm might be problematic.
- crfm_helm-0.5.5.dist-info/METADATA +413 -0
- crfm_helm-0.5.5.dist-info/RECORD +894 -0
- {crfm_helm-0.5.4.dist-info → crfm_helm-0.5.5.dist-info}/WHEEL +1 -1
- helm/benchmark/adaptation/adapter_spec.py +13 -1
- helm/benchmark/adaptation/adapters/adapter_factory.py +15 -1
- helm/benchmark/adaptation/adapters/binary_ranking_adapter.py +1 -1
- helm/benchmark/adaptation/adapters/chat_adapter.py +49 -0
- helm/benchmark/adaptation/adapters/ehr_instruction_adapter.py +108 -0
- helm/benchmark/adaptation/adapters/generation_adapter.py +1 -1
- helm/benchmark/adaptation/adapters/in_context_learning_adapter.py +1 -1
- helm/benchmark/adaptation/adapters/language_modeling_adapter.py +1 -1
- helm/benchmark/adaptation/adapters/multimodal/generation_multimodal_adapter.py +4 -2
- helm/benchmark/adaptation/adapters/multimodal/in_context_learning_multimodal_adapter.py +1 -1
- helm/benchmark/adaptation/adapters/multimodal/multiple_choice_joint_multimodal_adapter.py +1 -1
- helm/benchmark/adaptation/adapters/multimodal/test_in_context_learning_multimodal_adapter.py +4 -2
- helm/benchmark/adaptation/adapters/multimodal/test_multimodal_prompt.py +1 -1
- helm/benchmark/adaptation/adapters/multiple_choice_calibrated_adapter.py +1 -1
- helm/benchmark/adaptation/adapters/multiple_choice_joint_adapter.py +2 -2
- helm/benchmark/adaptation/adapters/multiple_choice_joint_chain_of_thought_adapter.py +87 -0
- helm/benchmark/adaptation/adapters/multiple_choice_separate_adapter.py +1 -1
- helm/benchmark/adaptation/adapters/test_generation_adapter.py +3 -3
- helm/benchmark/adaptation/adapters/test_language_modeling_adapter.py +2 -2
- helm/benchmark/adaptation/adapters/test_multiple_choice_joint_adapter.py +2 -2
- helm/benchmark/adaptation/common_adapter_specs.py +69 -4
- helm/benchmark/adaptation/prompt.py +1 -1
- helm/benchmark/annotation/aci_bench_annotator.py +95 -0
- helm/benchmark/annotation/air_bench_annotator.py +20 -5
- helm/benchmark/annotation/annotator.py +5 -0
- helm/benchmark/annotation/annotator_factory.py +3 -20
- helm/benchmark/annotation/autobencher_capabilities_annotator.py +107 -0
- helm/benchmark/annotation/autobencher_safety_annotator.py +98 -0
- helm/benchmark/annotation/bigcodebench_annotator.py +108 -0
- helm/benchmark/annotation/bird_sql_annotator.py +58 -0
- helm/benchmark/annotation/chw_care_plan_annotator.py +98 -0
- helm/benchmark/annotation/czech_bank_qa_annotator.py +78 -0
- helm/benchmark/annotation/dischargeme_annotator.py +107 -0
- helm/benchmark/annotation/ehr_sql_annotator.py +87 -0
- helm/benchmark/annotation/helpdesk_call_summarization_annotator.py +131 -0
- helm/benchmark/annotation/image2struct/image_compiler_annotator.py +6 -1
- helm/benchmark/annotation/live_qa_annotator.py +1 -1
- helm/benchmark/annotation/med_dialog_annotator.py +99 -0
- helm/benchmark/annotation/medalign_annotator.py +100 -0
- helm/benchmark/annotation/medi_qa_annotator.py +98 -0
- helm/benchmark/annotation/medication_qa_annotator.py +87 -63
- helm/benchmark/annotation/mental_health_annotator.py +98 -0
- helm/benchmark/annotation/mimic_rrs_annotator.py +100 -0
- helm/benchmark/annotation/model_as_judge.py +218 -6
- helm/benchmark/annotation/mtsamples_procedures_annotator.py +98 -0
- helm/benchmark/annotation/mtsamples_replicate_annotator.py +101 -0
- helm/benchmark/annotation/omni_math/gpt_evaluation_template.txt +152 -0
- helm/benchmark/annotation/omni_math/gpt_evaluation_zero_shot_template.txt +36 -0
- helm/benchmark/annotation/omni_math_annotator.py +132 -0
- helm/benchmark/annotation/spider_annotator.py +18 -0
- helm/benchmark/annotation/starr_patient_instructions_annotator.py +98 -0
- helm/benchmark/annotation/wildbench/eval_template.pairwise.v2.md +75 -0
- helm/benchmark/annotation/wildbench/eval_template.score.v2.md +66 -0
- helm/benchmark/annotation/wildbench_annotator.py +119 -0
- helm/benchmark/annotation_executor.py +35 -15
- helm/benchmark/augmentations/cleva_perturbation.py +9 -8
- helm/benchmark/augmentations/contraction_expansion_perturbation.py +2 -2
- helm/benchmark/augmentations/contrast_sets_perturbation.py +2 -2
- helm/benchmark/augmentations/dialect_perturbation.py +4 -5
- helm/benchmark/augmentations/extra_space_perturbation.py +2 -2
- helm/benchmark/augmentations/filler_words_perturbation.py +2 -2
- helm/benchmark/augmentations/gender_perturbation.py +2 -2
- helm/benchmark/augmentations/lowercase_perturbation.py +2 -2
- helm/benchmark/augmentations/mild_mix_perturbation.py +6 -6
- helm/benchmark/augmentations/misspelling_perturbation.py +2 -2
- helm/benchmark/augmentations/person_name_perturbation.py +4 -5
- helm/benchmark/augmentations/perturbation.py +1 -1
- helm/benchmark/augmentations/space_perturbation.py +2 -2
- helm/benchmark/augmentations/suffix_perturbation.py +2 -2
- helm/benchmark/augmentations/synonym_perturbation.py +4 -3
- helm/benchmark/augmentations/test_perturbation.py +16 -13
- helm/benchmark/augmentations/translate_perturbation.py +2 -2
- helm/benchmark/augmentations/typos_perturbation.py +2 -2
- helm/benchmark/data_preprocessor.py +2 -2
- helm/benchmark/huggingface_registration.py +2 -7
- helm/benchmark/metrics/aci_bench_metrics.py +34 -0
- helm/benchmark/metrics/basic_metrics.py +6 -6
- helm/benchmark/metrics/bbq_metrics.py +2 -2
- helm/benchmark/metrics/bias_metrics.py +12 -3
- helm/benchmark/metrics/bigcodebench_metrics.py +25 -0
- helm/benchmark/metrics/bird_sql_metrics.py +28 -0
- helm/benchmark/metrics/chw_care_plan_metrics.py +34 -0
- helm/benchmark/metrics/classification_metrics.py +76 -12
- helm/benchmark/metrics/cleva_harms_metrics.py +8 -7
- helm/benchmark/metrics/code_metrics.py +5 -5
- helm/benchmark/metrics/comet_metric.py +125 -0
- helm/benchmark/metrics/common_metric_specs.py +9 -2
- helm/benchmark/metrics/conv_fin_qa_calc_metrics.py +72 -0
- helm/benchmark/metrics/copyright_metrics.py +4 -4
- helm/benchmark/metrics/czech_bank_qa_metrics.py +29 -0
- helm/benchmark/metrics/decodingtrust_fairness_metrics.py +2 -2
- helm/benchmark/metrics/decodingtrust_privacy_metrics.py +2 -2
- helm/benchmark/metrics/decodingtrust_stereotype_bias_metrics.py +2 -2
- helm/benchmark/metrics/dischargeme_metrics.py +34 -0
- helm/benchmark/metrics/disinformation_metrics.py +4 -4
- helm/benchmark/metrics/dry_run_metrics.py +5 -5
- helm/benchmark/metrics/efficiency_metrics.py +3 -3
- helm/benchmark/metrics/ehr_sql_metrics.py +103 -0
- helm/benchmark/metrics/evaluate_instances_metric.py +3 -3
- helm/benchmark/metrics/evaluate_reference_metrics.py +144 -16
- helm/benchmark/metrics/gpqa_chain_of_thought_metric.py +103 -0
- helm/benchmark/metrics/gpt4_audio_critique_metrics.py +167 -0
- helm/benchmark/metrics/helpdesk_call_summarization_metrics.py +36 -0
- helm/benchmark/metrics/ifeval/__init__.py +0 -0
- helm/benchmark/metrics/ifeval/instructions.py +1574 -0
- helm/benchmark/metrics/ifeval/instructions_registry.py +182 -0
- helm/benchmark/metrics/ifeval/instructions_registry.pyi +3 -0
- helm/benchmark/metrics/ifeval/instructions_util.py +153 -0
- helm/benchmark/metrics/ifeval_metrics.py +55 -0
- helm/benchmark/metrics/image_generation/aesthetics_metrics.py +1 -1
- helm/benchmark/metrics/image_generation/detection_metrics.py +1 -1
- helm/benchmark/metrics/image_generation/detectors/vitdet.py +1 -1
- helm/benchmark/metrics/image_generation/fractal_dimension/test_fractal_dimension_util.py +1 -1
- helm/benchmark/metrics/image_generation/fractal_dimension_metric.py +1 -1
- helm/benchmark/metrics/image_generation/nsfw_metrics.py +1 -1
- helm/benchmark/metrics/image_generation/q16/test_q16.py +3 -1
- helm/benchmark/metrics/image_generation/q16_toxicity_metrics.py +1 -1
- helm/benchmark/metrics/image_generation/skin_tone_metrics.py +2 -2
- helm/benchmark/metrics/image_generation/watermark/test_watermark_detector.py +1 -1
- helm/benchmark/metrics/image_generation/watermark_metrics.py +1 -1
- helm/benchmark/metrics/instruction_following_critique_metrics.py +4 -4
- helm/benchmark/metrics/language_modeling_metrics.py +4 -4
- helm/benchmark/metrics/machine_translation_metrics.py +2 -2
- helm/benchmark/metrics/med_dialog_metrics.py +34 -0
- helm/benchmark/metrics/medalign_metrics.py +34 -0
- helm/benchmark/metrics/medcalc_bench_metrics.py +124 -0
- helm/benchmark/metrics/medec_metrics.py +101 -0
- helm/benchmark/metrics/medi_qa_metrics.py +34 -0
- helm/benchmark/metrics/medication_qa_metrics.py +15 -4
- helm/benchmark/metrics/mental_health_metrics.py +34 -0
- helm/benchmark/metrics/metric.py +3 -3
- helm/benchmark/metrics/mimic_rrs_metrics.py +34 -0
- helm/benchmark/metrics/mimiciv_billing_code_metrics.py +96 -0
- helm/benchmark/metrics/mtsamples_procedures_metrics.py +34 -0
- helm/benchmark/metrics/mtsamples_replicate_metrics.py +34 -0
- helm/benchmark/metrics/nltk_helper.py +32 -0
- helm/benchmark/metrics/numeracy_metrics.py +4 -4
- helm/benchmark/metrics/omni_math_metrics.py +32 -0
- helm/benchmark/metrics/output_processing_metric.py +60 -0
- helm/benchmark/metrics/output_processors.py +15 -0
- helm/benchmark/metrics/paraphrase_generation_metrics.py +2 -2
- helm/benchmark/metrics/ranking_metrics.py +3 -3
- helm/benchmark/metrics/reference_metric.py +3 -3
- helm/benchmark/metrics/{bhasa_metrics.py → seahelm_metrics.py} +3 -3
- helm/benchmark/metrics/seahelm_metrics_specs.py +10 -0
- helm/benchmark/metrics/spider_metrics.py +7 -0
- helm/benchmark/metrics/starr_patient_instructions_metrics.py +34 -0
- helm/benchmark/metrics/statistic.py +1 -1
- helm/benchmark/metrics/summac/model_summac.py +1 -1
- helm/benchmark/metrics/summarization_critique_metrics.py +4 -4
- helm/benchmark/metrics/summarization_metrics.py +19 -9
- helm/benchmark/metrics/test_bias_metrics.py +5 -1
- helm/benchmark/metrics/test_classification_metrics.py +140 -68
- helm/benchmark/metrics/test_evaluate_reference_metrics.py +15 -0
- helm/benchmark/metrics/test_metric.py +1 -1
- helm/benchmark/metrics/test_statistic.py +2 -2
- helm/benchmark/metrics/tokens/ai21_token_cost_estimator.py +1 -1
- helm/benchmark/metrics/tokens/auto_token_cost_estimator.py +6 -6
- helm/benchmark/metrics/tokens/cohere_token_cost_estimator.py +1 -1
- helm/benchmark/metrics/tokens/free_token_cost_estimator.py +1 -1
- helm/benchmark/metrics/tokens/gooseai_token_cost_estimator.py +1 -1
- helm/benchmark/metrics/tokens/openai_token_cost_estimator.py +1 -1
- helm/benchmark/metrics/tokens/test_ai21_token_cost_estimator.py +1 -1
- helm/benchmark/metrics/tokens/test_openai_token_cost_estimator.py +1 -1
- helm/benchmark/metrics/toxicity_metrics.py +4 -4
- helm/benchmark/metrics/unitxt_metrics.py +4 -1
- helm/benchmark/metrics/vision_language/image_metrics.py +1 -1
- helm/benchmark/metrics/wildbench_metrics.py +34 -0
- helm/benchmark/model_metadata_registry.py +16 -0
- helm/benchmark/presentation/summarize.py +23 -10
- helm/benchmark/presentation/torr_robustness_summarizer.py +178 -0
- helm/benchmark/reeval_run.py +203 -0
- helm/benchmark/reeval_runner.py +355 -0
- helm/benchmark/run.py +8 -17
- helm/benchmark/run_expander.py +78 -8
- helm/benchmark/run_spec_factory.py +12 -0
- helm/benchmark/run_specs/air_bench_run_specs.py +21 -3
- helm/benchmark/run_specs/audio_run_specs.py +613 -0
- helm/benchmark/run_specs/call_center_run_specs.py +49 -0
- helm/benchmark/run_specs/capabilities_run_specs.py +308 -0
- helm/benchmark/run_specs/classic_run_specs.py +1 -69
- helm/benchmark/run_specs/enem_challenge_specs.py +31 -0
- helm/benchmark/run_specs/enterprise_run_specs.py +260 -0
- helm/benchmark/run_specs/experimental_run_specs.py +112 -3
- helm/benchmark/run_specs/imdb_ptbr_run_specs.py +30 -0
- helm/benchmark/run_specs/lite_run_specs.py +2 -2
- helm/benchmark/run_specs/long_context_run_specs.py +89 -0
- helm/benchmark/run_specs/medhelm_run_specs.py +1155 -0
- helm/benchmark/run_specs/mmlu_clinical_afr_run_specs.py +49 -0
- helm/benchmark/run_specs/oab_exams_specs.py +32 -0
- helm/benchmark/run_specs/safety_run_specs.py +37 -0
- helm/benchmark/run_specs/{bhasa_run_specs.py → seahelm_run_specs.py} +44 -44
- helm/benchmark/run_specs/sql_run_specs.py +54 -0
- helm/benchmark/run_specs/tweetsentbr_run_specs.py +32 -0
- helm/benchmark/run_specs/unitxt_run_specs.py +14 -5
- helm/benchmark/run_specs/vlm_run_specs.py +75 -2
- helm/benchmark/run_specs/winogrande_afr_run_specs.py +47 -0
- helm/benchmark/scenarios/aci_bench_scenario.py +120 -0
- helm/benchmark/scenarios/air_bench_scenario.py +6 -1
- helm/benchmark/scenarios/anthropic_hh_rlhf_scenario.py +5 -3
- helm/benchmark/scenarios/anthropic_red_team_scenario.py +1 -1
- helm/benchmark/scenarios/audio_language/__init__.py +0 -0
- helm/benchmark/scenarios/audio_language/air_bench_chat_scenario.py +128 -0
- helm/benchmark/scenarios/audio_language/air_bench_foundation_scenario.py +154 -0
- helm/benchmark/scenarios/audio_language/ami_scenario.py +96 -0
- helm/benchmark/scenarios/audio_language/audio_mnist_scenario.py +62 -0
- helm/benchmark/scenarios/audio_language/audio_pairs_scenario.py +62 -0
- helm/benchmark/scenarios/audio_language/audiocaps_scenario.py +59 -0
- helm/benchmark/scenarios/audio_language/casual_conversations2_scenario.py +152 -0
- helm/benchmark/scenarios/audio_language/common_voice_15_scenario.py +99 -0
- helm/benchmark/scenarios/audio_language/covost2_scenario.py +163 -0
- helm/benchmark/scenarios/audio_language/fleurs_fairness_scenario.py +83 -0
- helm/benchmark/scenarios/audio_language/fleurs_scenario.py +312 -0
- helm/benchmark/scenarios/audio_language/iemocap_audio_scenario.py +83 -0
- helm/benchmark/scenarios/audio_language/librispeech_fairness_scenario.py +96 -0
- helm/benchmark/scenarios/audio_language/librispeech_scenario.py +80 -0
- helm/benchmark/scenarios/audio_language/meld_audio_scenario.py +113 -0
- helm/benchmark/scenarios/audio_language/multilingual_librispeech_scenario.py +80 -0
- helm/benchmark/scenarios/audio_language/mustard_scenario.py +142 -0
- helm/benchmark/scenarios/audio_language/mutox_scenario.py +254 -0
- helm/benchmark/scenarios/audio_language/parade_scenario.py +97 -0
- helm/benchmark/scenarios/audio_language/speech_robust_bench_scenario.py +124 -0
- helm/benchmark/scenarios/audio_language/vocal_sound_scenario.py +69 -0
- helm/benchmark/scenarios/audio_language/voice_jailbreak_attacks_scenario.py +87 -0
- helm/benchmark/scenarios/audio_language/voxceleb2_scenario.py +106 -0
- helm/benchmark/scenarios/autobencher_capabilities_scenario.py +68 -0
- helm/benchmark/scenarios/autobencher_safety_scenario.py +51 -0
- helm/benchmark/scenarios/babi_qa_scenario.py +1 -1
- helm/benchmark/scenarios/banking77_scenario.py +6 -1
- helm/benchmark/scenarios/bbq_scenario.py +1 -1
- helm/benchmark/scenarios/big_bench_scenario.py +11 -1
- helm/benchmark/scenarios/bigcodebench_scenario.py +58 -0
- helm/benchmark/scenarios/bird_sql_scenario.py +94 -0
- helm/benchmark/scenarios/bird_sql_scenario_helper.py +118 -0
- helm/benchmark/scenarios/blimp_scenario.py +1 -1
- helm/benchmark/scenarios/bold_scenario.py +1 -1
- helm/benchmark/scenarios/boolq_scenario.py +1 -1
- helm/benchmark/scenarios/casehold_scenario.py +79 -0
- helm/benchmark/scenarios/chw_care_plan_scenario.py +105 -0
- helm/benchmark/scenarios/civil_comments_scenario.py +1 -1
- helm/benchmark/scenarios/clear_scenario.py +153 -0
- helm/benchmark/scenarios/cleva_scenario.py +2 -2
- helm/benchmark/scenarios/code_scenario.py +17 -4
- helm/benchmark/scenarios/commonsense_scenario.py +1 -1
- helm/benchmark/scenarios/conv_fin_qa_calc_scenario.py +97 -0
- helm/benchmark/scenarios/copyright_scenario.py +1 -1
- helm/benchmark/scenarios/covid_dialog_scenario.py +10 -1
- helm/benchmark/scenarios/cti_to_mitre_scenario.py +240 -0
- helm/benchmark/scenarios/custom_mcqa_scenario.py +1 -1
- helm/benchmark/scenarios/czech_bank_qa_scenario.py +130 -0
- helm/benchmark/scenarios/decodingtrust_adv_demonstration_scenario.py +1 -1
- helm/benchmark/scenarios/decodingtrust_privacy_scenario.py +1 -1
- helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +1 -1
- helm/benchmark/scenarios/decodingtrust_toxicity_prompts_scenario.py +1 -1
- helm/benchmark/scenarios/dialogue_scenarios.py +13 -2
- helm/benchmark/scenarios/dischargeme_scenario.py +157 -0
- helm/benchmark/scenarios/disinformation_scenario.py +10 -1
- helm/benchmark/scenarios/dyck_language_scenario.py +10 -1
- helm/benchmark/scenarios/echr_judgment_classification_scenario.py +113 -0
- helm/benchmark/scenarios/ehr_sql_scenario.py +131 -0
- helm/benchmark/scenarios/ehrshot_scenario.py +1546 -0
- helm/benchmark/scenarios/enem_challenge_scenario.py +58 -0
- helm/benchmark/scenarios/entity_data_imputation_scenario.py +11 -1
- helm/benchmark/scenarios/entity_matching_scenario.py +12 -2
- helm/benchmark/scenarios/financial_phrasebank_scenario.py +94 -0
- helm/benchmark/scenarios/gold_commodity_news_scenario.py +124 -0
- helm/benchmark/scenarios/gpqa_scenario.py +80 -0
- helm/benchmark/scenarios/grammar_scenario.py +2 -2
- helm/benchmark/scenarios/gsm_scenario.py +10 -1
- helm/benchmark/scenarios/harm_bench_gcg_transfer_scenario.py +50 -0
- helm/benchmark/scenarios/harm_bench_scenario.py +1 -1
- helm/benchmark/scenarios/headqa_scenario.py +131 -0
- helm/benchmark/scenarios/helpdesk_call_summarization_scenario.py +37 -0
- helm/benchmark/scenarios/ice_scenario.py +8 -4
- helm/benchmark/scenarios/ifeval_scenario.py +53 -0
- helm/benchmark/scenarios/imdb_ptbr_scenario.py +60 -0
- helm/benchmark/scenarios/imdb_scenario.py +11 -2
- helm/benchmark/scenarios/infinite_bench_sum_scenario.py +82 -0
- helm/benchmark/scenarios/interactive_qa_mmlu_scenario.py +2 -2
- helm/benchmark/scenarios/koala_scenario.py +1 -1
- helm/benchmark/scenarios/legal_contract_summarization_scenario.py +129 -0
- helm/benchmark/scenarios/legal_opinion_sentiment_classification_scenario.py +77 -0
- helm/benchmark/scenarios/legal_summarization_scenario.py +11 -1
- helm/benchmark/scenarios/legal_support_scenario.py +11 -1
- helm/benchmark/scenarios/legalbench_scenario.py +22 -3
- helm/benchmark/scenarios/lex_glue_scenario.py +12 -2
- helm/benchmark/scenarios/lextreme_scenario.py +11 -1
- helm/benchmark/scenarios/live_qa_scenario.py +1 -1
- helm/benchmark/scenarios/lm_entry_scenario.py +1 -1
- helm/benchmark/scenarios/lsat_qa_scenario.py +1 -1
- helm/benchmark/scenarios/math_scenario.py +9 -1
- helm/benchmark/scenarios/me_q_sum_scenario.py +10 -1
- helm/benchmark/scenarios/med_dialog_scenario.py +22 -24
- helm/benchmark/scenarios/med_mcqa_scenario.py +10 -1
- helm/benchmark/scenarios/med_paragraph_simplification_scenario.py +10 -1
- helm/benchmark/scenarios/med_qa_scenario.py +10 -1
- helm/benchmark/scenarios/medalign_scenario.py +88 -0
- helm/benchmark/scenarios/medalign_scenario_helper.py +429 -0
- helm/benchmark/scenarios/medbullets_scenario.py +140 -0
- helm/benchmark/scenarios/medcalc_bench_scenario.py +125 -0
- helm/benchmark/scenarios/medec_scenario.py +120 -0
- helm/benchmark/scenarios/medhallu_scenario.py +66 -0
- helm/benchmark/scenarios/medi_qa_scenario.py +105 -0
- helm/benchmark/scenarios/medication_qa_scenario.py +2 -2
- helm/benchmark/scenarios/mental_health_scenario.py +112 -0
- helm/benchmark/scenarios/mimic_bhc_scenario.py +98 -0
- helm/benchmark/scenarios/mimic_rrs_scenario.py +89 -0
- helm/benchmark/scenarios/mimiciv_billing_code_scenario.py +71 -0
- helm/benchmark/scenarios/mmlu_clinical_afr_scenario.py +74 -0
- helm/benchmark/scenarios/mmlu_pro_scenario.py +95 -0
- helm/benchmark/scenarios/mmlu_scenario.py +11 -1
- helm/benchmark/scenarios/msmarco_scenario.py +1 -1
- helm/benchmark/scenarios/mtsamples_procedures_scenario.py +141 -0
- helm/benchmark/scenarios/mtsamples_replicate_scenario.py +141 -0
- helm/benchmark/scenarios/n2c2_ct_matching_scenario.py +271 -0
- helm/benchmark/scenarios/narrativeqa_scenario.py +1 -1
- helm/benchmark/scenarios/natural_qa_scenario.py +1 -1
- helm/benchmark/scenarios/newsqa_scenario.py +1 -1
- helm/benchmark/scenarios/numeracy_scenario.py +10 -1
- helm/benchmark/scenarios/oab_exams_scenario.py +57 -0
- helm/benchmark/scenarios/omni_math_scenario.py +53 -0
- helm/benchmark/scenarios/open_assistant_scenario.py +11 -2
- helm/benchmark/scenarios/opinions_qa_scenario.py +1 -1
- helm/benchmark/scenarios/pubmed_qa_scenario.py +54 -43
- helm/benchmark/scenarios/quac_scenario.py +10 -1
- helm/benchmark/scenarios/race_based_med_scenario.py +142 -0
- helm/benchmark/scenarios/raft_scenario.py +17 -2
- helm/benchmark/scenarios/real_toxicity_prompts_scenario.py +1 -1
- helm/benchmark/scenarios/ruler_qa_scenario_helper.py +171 -0
- helm/benchmark/scenarios/ruler_qa_scenarios.py +88 -0
- helm/benchmark/scenarios/scenario.py +9 -1
- helm/benchmark/scenarios/{bhasa_scenario.py → seahelm_scenario.py} +7 -2
- helm/benchmark/scenarios/self_instruct_scenario.py +1 -1
- helm/benchmark/scenarios/shc_bmt_scenario.py +69 -0
- helm/benchmark/scenarios/shc_cdi_scenario.py +70 -0
- helm/benchmark/scenarios/shc_conf_scenario.py +70 -0
- helm/benchmark/scenarios/shc_ent_scenario.py +72 -0
- helm/benchmark/scenarios/shc_gip_scenario.py +66 -0
- helm/benchmark/scenarios/shc_ptbm_scenario.py +76 -0
- helm/benchmark/scenarios/shc_sei_scenario.py +89 -0
- helm/benchmark/scenarios/shc_sequoia_scenario.py +69 -0
- helm/benchmark/scenarios/simple_safety_tests_scenario.py +1 -1
- helm/benchmark/scenarios/spider_scenario.py +91 -0
- helm/benchmark/scenarios/starr_patient_instructions_scenario.py +90 -0
- helm/benchmark/scenarios/summarization_scenario.py +11 -1
- helm/benchmark/scenarios/sumosum_scenario.py +157 -0
- helm/benchmark/scenarios/synthetic_efficiency_scenario.py +1 -1
- helm/benchmark/scenarios/synthetic_reasoning_natural_scenario.py +11 -1
- helm/benchmark/scenarios/synthetic_reasoning_scenario.py +11 -1
- helm/benchmark/scenarios/test_bigcodebench_scenario.py +26 -0
- helm/benchmark/scenarios/test_czech_bank_qa_scenario.py +18 -0
- helm/benchmark/scenarios/test_enem_challenge_scenario.py +53 -0
- helm/benchmark/scenarios/test_ewok_scenario.py +6 -2
- helm/benchmark/scenarios/test_gold_commodity_news_scenario.py +18 -0
- helm/benchmark/scenarios/test_gpqa_scenario.py +44 -0
- helm/benchmark/scenarios/test_ifeval_scenario.py +36 -0
- helm/benchmark/scenarios/test_imdb_ptbr_scenario.py +27 -0
- helm/benchmark/scenarios/test_infinite_bench_sum_scenario.py +46 -0
- helm/benchmark/scenarios/test_math_scenario.py +1 -0
- helm/benchmark/scenarios/test_mmlu_clinical_afr_scenario.py +21 -0
- helm/benchmark/scenarios/test_mmlu_pro_scenario.py +53 -0
- helm/benchmark/scenarios/test_oab_exams_scenario.py +51 -0
- helm/benchmark/scenarios/test_omni_math_scenario.py +27 -0
- helm/benchmark/scenarios/test_tweetsentbr_scenario.py +24 -0
- helm/benchmark/scenarios/test_wildbench_scenario.py +15 -0
- helm/benchmark/scenarios/test_winogrande_afr_scenario.py +19 -0
- helm/benchmark/scenarios/thai_exam_scenario.py +10 -1
- helm/benchmark/scenarios/the_pile_scenario.py +1 -1
- helm/benchmark/scenarios/truthful_qa_scenario.py +10 -1
- helm/benchmark/scenarios/tweetsentbr_scenario.py +66 -0
- helm/benchmark/scenarios/twitter_aae_scenario.py +1 -1
- helm/benchmark/scenarios/unitxt_scenario.py +8 -2
- helm/benchmark/scenarios/verifiability_judgment_scenario.py +1 -1
- helm/benchmark/scenarios/vicuna_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/blink_scenario.py +140 -0
- helm/benchmark/scenarios/vision_language/mm_star_scenario.py +95 -0
- helm/benchmark/scenarios/vision_language/vqa_rad_scenario.py +88 -0
- helm/benchmark/scenarios/wikifact_scenario.py +11 -1
- helm/benchmark/scenarios/wikitext_103_scenario.py +1 -1
- helm/benchmark/scenarios/wildbench_scenario.py +83 -0
- helm/benchmark/scenarios/winogrande_afr_scenario.py +78 -0
- helm/benchmark/scenarios/wmt_14_scenario.py +14 -2
- helm/benchmark/scenarios/xstest_scenario.py +1 -1
- helm/benchmark/server.py +11 -0
- helm/benchmark/slurm_runner.py +1 -1
- helm/benchmark/static/schema_audio.yaml +752 -0
- helm/benchmark/static/schema_autobencher.yaml +150 -0
- helm/benchmark/static/schema_call_center.yaml +97 -60
- helm/benchmark/static/schema_capabilities.yaml +254 -0
- helm/benchmark/static/schema_czech_bank.yaml +148 -0
- helm/benchmark/static/schema_enem_challenge.yaml +146 -0
- helm/benchmark/static/schema_enterprise.yaml +298 -0
- helm/benchmark/static/schema_finance.yaml +14 -12
- helm/benchmark/static/schema_heim.yaml +1389 -0
- helm/benchmark/static/{schema_medical.yaml → schema_long_context.yaml} +67 -82
- helm/benchmark/static/schema_medhelm.yaml +1081 -0
- helm/benchmark/static/schema_mmlu_winogrande_afr.yaml +1045 -0
- helm/benchmark/static/schema_safety.yaml +18 -1
- helm/benchmark/static/{schema_bhasa.yaml → schema_seahelm.yaml} +30 -16
- helm/benchmark/static/schema_social_audio.yaml +224 -0
- helm/benchmark/static/schema_sql.yaml +171 -0
- helm/benchmark/static/{schema_tables.yaml → schema_torr.yaml} +169 -36
- helm/benchmark/static/schema_tweetsentbr.yaml +146 -0
- helm/benchmark/static/schema_vhelm.yaml +109 -36
- helm/benchmark/static_build/assets/helm-safety-2907a7b6.png +0 -0
- helm/benchmark/static_build/assets/index-262903c1.js +10 -0
- helm/benchmark/static_build/assets/index-42060d71.css +1 -0
- helm/benchmark/static_build/assets/medhelm-overview-3ddfcd65.png +0 -0
- helm/benchmark/static_build/assets/{react-d4a0b69b.js → react-f82877fd.js} +1 -1
- helm/benchmark/static_build/assets/{recharts-6d337683.js → recharts-4037aff0.js} +1 -1
- helm/benchmark/static_build/assets/{tremor-54a99cc4.js → tremor-9cefc3c5.js} +1 -1
- helm/benchmark/static_build/config.js +1 -1
- helm/benchmark/static_build/index.html +5 -5
- helm/benchmark/window_services/default_window_service.py +1 -1
- helm/benchmark/window_services/encoder_decoder_window_service.py +1 -1
- helm/benchmark/window_services/ice_window_service.py +1 -1
- helm/benchmark/window_services/image_generation/lexica_search_window_service.py +1 -1
- helm/benchmark/window_services/image_generation/openai_dalle_window_service.py +1 -1
- helm/benchmark/window_services/local_window_service.py +2 -2
- helm/benchmark/window_services/test_anthropic_window_service.py +3 -3
- helm/benchmark/window_services/test_bloom_window_service.py +3 -3
- helm/benchmark/window_services/test_gpt2_window_service.py +7 -2
- helm/benchmark/window_services/test_gpt4_window_service.py +8 -3
- helm/benchmark/window_services/test_gptj_window_service.py +8 -3
- helm/benchmark/window_services/test_gptneox_window_service.py +3 -3
- helm/benchmark/window_services/test_openai_window_service.py +8 -3
- helm/benchmark/window_services/test_opt_window_service.py +3 -3
- helm/benchmark/window_services/test_palmyra_window_service.py +3 -3
- helm/benchmark/window_services/test_t0pp_window_service.py +3 -3
- helm/benchmark/window_services/test_t511b_window_service.py +3 -3
- helm/benchmark/window_services/test_ul2_window_service.py +3 -3
- helm/benchmark/window_services/test_utils.py +1 -1
- helm/benchmark/window_services/test_yalm_window_service.py +3 -3
- helm/benchmark/window_services/yalm_window_service.py +1 -1
- helm/clients/ai21_client.py +3 -3
- helm/clients/aleph_alpha_client.py +1 -1
- helm/clients/audio_language/__init__.py +0 -0
- helm/clients/audio_language/diva_llama_client.py +118 -0
- helm/clients/audio_language/llama_omni_client.py +198 -0
- helm/clients/audio_language/qwen2_audiolm_client.py +188 -0
- helm/clients/audio_language/qwen_audiolm_client.py +150 -0
- helm/clients/auto_client.py +4 -2
- helm/clients/azure_openai_client.py +55 -0
- helm/clients/bedrock_client.py +201 -7
- helm/clients/bedrock_utils.py +33 -0
- helm/clients/clip_scorers/clip_scorer.py +1 -1
- helm/clients/clip_scorers/multilingual_clip_scorer.py +1 -1
- helm/clients/cohere_client.py +3 -3
- helm/clients/google_client.py +1 -1
- helm/clients/http_model_client.py +1 -1
- helm/clients/huggingface_client.py +10 -18
- helm/clients/ibm_client.py +267 -0
- helm/clients/image_generation/adobe_vision_client.py +1 -1
- helm/clients/image_generation/aleph_alpha_image_generation_client.py +1 -1
- helm/clients/image_generation/cogview2/sr_pipeline/__init__.py +3 -3
- helm/clients/image_generation/cogview2/sr_pipeline/direct_sr.py +5 -2
- helm/clients/image_generation/cogview2/sr_pipeline/iterative_sr.py +5 -2
- helm/clients/image_generation/cogview2/sr_pipeline/sr_group.py +2 -2
- helm/clients/image_generation/cogview2_client.py +1 -1
- helm/clients/image_generation/dalle2_client.py +1 -1
- helm/clients/image_generation/dalle3_client.py +2 -2
- helm/clients/image_generation/dalle_mini/__init__.py +1 -1
- helm/clients/image_generation/dalle_mini/data.py +1 -1
- helm/clients/image_generation/dalle_mini/model/__init__.py +5 -5
- helm/clients/image_generation/dalle_mini/model/configuration.py +1 -1
- helm/clients/image_generation/dalle_mini/model/modeling.py +2 -2
- helm/clients/image_generation/dalle_mini/model/processor.py +4 -4
- helm/clients/image_generation/dalle_mini/model/tokenizer.py +1 -1
- helm/clients/image_generation/dalle_mini/vqgan_jax/__init__.py +1 -1
- helm/clients/image_generation/dalle_mini/vqgan_jax/convert_pt_model_to_jax.py +2 -2
- helm/clients/image_generation/dalle_mini/vqgan_jax/modeling_flax_vqgan.py +1 -1
- helm/clients/image_generation/dalle_mini_client.py +1 -1
- helm/clients/image_generation/deep_floyd_client.py +1 -1
- helm/clients/image_generation/huggingface_diffusers_client.py +1 -1
- helm/clients/image_generation/lexica_client.py +1 -1
- helm/clients/image_generation/mindalle/models/__init__.py +6 -6
- helm/clients/image_generation/mindalle/models/stage1/vqgan.py +1 -1
- helm/clients/image_generation/mindalle/models/stage2/transformer.py +1 -1
- helm/clients/image_generation/mindalle/utils/__init__.py +3 -3
- helm/clients/image_generation/mindalle_client.py +1 -1
- helm/clients/image_generation/together_image_generation_client.py +1 -1
- helm/clients/lit_gpt_client.py +2 -2
- helm/clients/mistral_client.py +62 -18
- helm/clients/nvidia_nim_client.py +0 -3
- helm/clients/openai_client.py +241 -22
- helm/clients/palmyra_client.py +1 -4
- helm/clients/reka_client.py +1 -1
- helm/clients/stanfordhealthcare_azure_openai_client.py +58 -0
- helm/clients/stanfordhealthcare_claude_client.py +31 -0
- helm/clients/stanfordhealthcare_google_client.py +43 -0
- helm/clients/stanfordhealthcare_http_model_client.py +93 -0
- helm/clients/stanfordhealthcare_openai_client.py +62 -0
- helm/clients/stanfordhealthcare_shc_openai_client.py +42 -0
- helm/clients/test_client.py +1 -1
- helm/clients/test_together_client.py +6 -1
- helm/clients/together_client.py +47 -7
- helm/clients/upstage_client.py +23 -0
- helm/clients/vertexai_client.py +39 -13
- helm/clients/vision_language/open_flamingo/__init__.py +2 -2
- helm/clients/vision_language/open_flamingo/src/factory.py +3 -3
- helm/clients/vision_language/open_flamingo/src/flamingo.py +2 -2
- helm/clients/vision_language/open_flamingo/src/flamingo_lm.py +2 -2
- helm/clients/vision_language/qwen2_vlm_client.py +175 -0
- helm/clients/vllm_client.py +4 -6
- helm/clients/yi_client.py +0 -3
- helm/common/audio_utils.py +111 -0
- helm/common/file_caches/local_file_cache.py +1 -1
- helm/common/file_caches/test_local_file_cache.py +1 -1
- helm/common/images_utils.py +2 -2
- helm/common/media_object.py +2 -2
- helm/common/multimodal_request_utils.py +26 -0
- helm/common/reeval_parameters.py +12 -0
- helm/common/request.py +6 -2
- helm/common/response_format.py +18 -0
- helm/common/test_media_object.py +1 -1
- helm/config/model_deployments.yaml +1112 -19
- helm/config/model_metadata.yaml +985 -44
- helm/config/tokenizer_configs.yaml +379 -3
- helm/proxy/cli.py +2 -2
- helm/proxy/example_queries.py +1 -1
- helm/proxy/server.py +11 -4
- helm/proxy/services/remote_service.py +1 -1
- helm/proxy/services/server_service.py +1 -1
- helm/proxy/services/test_remote_service.py +2 -2
- helm/proxy/services/test_service.py +1 -1
- helm/proxy/static/general.js +122 -0
- helm/proxy/static/help.html +99 -0
- helm/proxy/static/index.css +57 -0
- helm/proxy/static/index.html +40 -0
- helm/proxy/static/index.js +456 -0
- helm/proxy/static/info-icon.png +0 -0
- helm/proxy/test_retry.py +1 -1
- helm/proxy/token_counters/auto_token_counter.py +1 -1
- helm/tokenizers/aleph_alpha_tokenizer.py +1 -1
- helm/tokenizers/caching_tokenizer.py +2 -30
- helm/tokenizers/http_model_tokenizer.py +1 -1
- helm/tokenizers/huggingface_tokenizer.py +2 -2
- helm/tokenizers/lit_gpt_tokenizer.py +1 -1
- helm/tokenizers/test_anthropic_tokenizer.py +6 -2
- helm/tokenizers/test_huggingface_tokenizer.py +1 -1
- helm/tokenizers/test_yalm_tokenizer.py +1 -1
- helm/tokenizers/tiktoken_tokenizer.py +1 -1
- helm/tokenizers/tokenizer.py +3 -1
- helm/tokenizers/yalm_tokenizer.py +3 -3
- helm/tokenizers/yalm_tokenizer_data/test_yalm_tokenizer.py +1 -1
- crfm_helm-0.5.4.dist-info/METADATA +0 -350
- crfm_helm-0.5.4.dist-info/RECORD +0 -697
- helm/benchmark/metrics/bhasa_metrics_specs.py +0 -10
- helm/benchmark/static_build/assets/01-694cb9b7.png +0 -0
- helm/benchmark/static_build/assets/accenture-6f97eeda.png +0 -0
- helm/benchmark/static_build/assets/ai21-0eb91ec3.png +0 -0
- helm/benchmark/static_build/assets/aisingapore-6dfc9acf.png +0 -0
- helm/benchmark/static_build/assets/aleph-alpha-7ce10034.png +0 -0
- helm/benchmark/static_build/assets/anthropic-70d8bc39.png +0 -0
- helm/benchmark/static_build/assets/bigscience-7f0400c0.png +0 -0
- helm/benchmark/static_build/assets/cohere-3550c6cb.png +0 -0
- helm/benchmark/static_build/assets/cresta-9e22b983.png +0 -0
- helm/benchmark/static_build/assets/cuhk-8c5631e9.png +0 -0
- helm/benchmark/static_build/assets/eleutherai-b9451114.png +0 -0
- helm/benchmark/static_build/assets/google-06d997ad.png +0 -0
- helm/benchmark/static_build/assets/index-05c76bb1.css +0 -1
- helm/benchmark/static_build/assets/index-3ee38b3d.js +0 -10
- helm/benchmark/static_build/assets/meta-5580e9f1.png +0 -0
- helm/benchmark/static_build/assets/microsoft-f5ee5016.png +0 -0
- helm/benchmark/static_build/assets/mistral-18e1be23.png +0 -0
- helm/benchmark/static_build/assets/nvidia-86fa75c1.png +0 -0
- helm/benchmark/static_build/assets/openai-3f8653e4.png +0 -0
- helm/benchmark/static_build/assets/scb10x-204bd786.png +0 -0
- helm/benchmark/static_build/assets/tii-24de195c.png +0 -0
- helm/benchmark/static_build/assets/together-a665a35b.png +0 -0
- helm/benchmark/static_build/assets/tsinghua-keg-97d4b395.png +0 -0
- helm/benchmark/static_build/assets/wellsfargo-a86a6c4a.png +0 -0
- helm/benchmark/static_build/assets/yandex-38e09d70.png +0 -0
- helm/tokenizers/anthropic_tokenizer.py +0 -52
- {crfm_helm-0.5.4.dist-info → crfm_helm-0.5.5.dist-info}/entry_points.txt +0 -0
- {crfm_helm-0.5.4.dist-info → crfm_helm-0.5.5.dist-info/licenses}/LICENSE +0 -0
- {crfm_helm-0.5.4.dist-info → crfm_helm-0.5.5.dist-info}/top_level.txt +0 -0
helm/benchmark/scenarios/med_dialog_scenario.py

```diff
@@ -3,7 +3,7 @@ import os
 from typing import List
 
 from helm.common.general import ensure_directory_exists, ensure_file_downloaded
-from .scenario import Scenario, Instance, Reference, ALL_SPLITS, CORRECT_TAG, Input, Output
+from helm.benchmark.scenarios.scenario import Scenario, Instance, Reference, ALL_SPLITS, CORRECT_TAG, Input, Output
 
 
 class MedDialogScenario(Scenario):
@@ -90,11 +90,7 @@ class MedDialogScenario(Scenario):
     """
 
     name = "med_dialog"
-    description = (
-        "The MedDialog dataset (English) contains conversations between doctors and patients. "
-        "It has 0.26 million dialogues. The data is continuously growing and more dialogues will be added. "
-        "The raw dialogues are from healthcaremagic.com and icliniq.com."
-    )
+    description = "A collection of doctor-patient conversations with corresponding summaries."
     tags = ["dialogue", "biomedical"]
 
     def __init__(self, subset: str):
@@ -109,24 +105,26 @@ class MedDialogScenario(Scenario):
         instances: List[Instance] = []
 
         for split in ALL_SPLITS:
-… (17 removed lines whose content was not captured in the extracted view)
+            # Limit to zero shot setting
+            if split == "test":
+                split_file_name: str = f"{split}.json"
+                split_path: str = os.path.join(data_path, split_file_name)
+                ensure_file_downloaded(
+                    source_url="https://worksheets.codalab.org/rest/bundles/0x82f0c47f6d3e4462ae9ef8ea39eebe64/"
+                    f"contents/blob/{self.subset}/{split_file_name}",
+                    target_path=split_path,
+                    unpack=False,
+                )
+
+                with open(split_path, "r") as f:
+                    examples: List = json.load(f)["data"]
+                for example in examples:
+                    instances.append(
+                        Instance(
+                            input=Input(text=example["src"]),
+                            references=[Reference(Output(text=example["tgt"]), tags=[CORRECT_TAG])],
+                            split=split,
+                        )
                     )
-                )
 
         return instances
```
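The rewritten loader above maps each JSON record's `src`/`tgt` pair onto a HELM `Instance`. As a rough, hedged sketch (not part of the release) of that mapping for a single record, using only names that appear in this diff and assuming `TEST_SPLIT` is the constant for the "test" split:

```python
# Minimal sketch: one invented record mapped the way the new loader maps them.
from helm.benchmark.scenarios.scenario import (
    Instance,
    Input,
    Output,
    Reference,
    CORRECT_TAG,
    TEST_SPLIT,
)

# Hypothetical record; the "src"/"tgt" field names come from the diff above.
record = {"src": "Patient: I have had a cough for two weeks ...", "tgt": "Doctor: ..."}
instance = Instance(
    input=Input(text=record["src"]),
    references=[Reference(Output(text=record["tgt"]), tags=[CORRECT_TAG])],
    split=TEST_SPLIT,
)
print(instance.input.text)
```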
helm/benchmark/scenarios/med_mcqa_scenario.py

```diff
@@ -3,7 +3,16 @@ import os
 from typing import Dict, List
 
 from helm.common.general import ensure_file_downloaded
-from .scenario import …
+from helm.benchmark.scenarios.scenario import (
+    Scenario,
+    Instance,
+    Reference,
+    CORRECT_TAG,
+    TRAIN_SPLIT,
+    VALID_SPLIT,
+    Input,
+    Output,
+)
 
 
 class MedMCQAScenario(Scenario):
```

helm/benchmark/scenarios/med_paragraph_simplification_scenario.py

```diff
@@ -2,7 +2,16 @@ import os
 from typing import List
 
 from helm.common.general import ensure_directory_exists, ensure_file_downloaded
-from .scenario import …
+from helm.benchmark.scenarios.scenario import (
+    Scenario,
+    Instance,
+    Reference,
+    ALL_SPLITS,
+    CORRECT_TAG,
+    VALID_SPLIT,
+    Input,
+    Output,
+)
 
 
 class MedParagraphSimplificationScenario(Scenario):
```

helm/benchmark/scenarios/med_qa_scenario.py

```diff
@@ -3,7 +3,16 @@ import os
 from typing import Dict, List
 
 from helm.common.general import ensure_file_downloaded
-from .scenario import …
+from helm.benchmark.scenarios.scenario import (
+    Scenario,
+    Instance,
+    Reference,
+    ALL_SPLITS,
+    CORRECT_TAG,
+    VALID_SPLIT,
+    Input,
+    Output,
+)
 
 
 class MedQAScenario(Scenario):
```
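These three hunks apply the same refactor as the MedDialog change: relative `from .scenario import ...` imports become absolute imports from `helm.benchmark.scenarios.scenario`. For orientation, here is a minimal, hypothetical scenario written against that convention; the class and its data are illustrative only, not part of this release:

```python
from typing import List

from helm.benchmark.scenarios.scenario import (
    Scenario,
    Instance,
    Reference,
    CORRECT_TAG,
    TEST_SPLIT,
    Input,
    Output,
)


class ToyQAScenario(Scenario):
    """A two-example sketch of the Scenario interface used in the diffs above."""

    name = "toy_qa"
    description = "Hypothetical toy question answering scenario."
    tags = ["question_answering"]

    def get_instances(self, output_path: str) -> List[Instance]:
        pairs = [("What is 2 + 2?", "4"), ("What is the capital of France?", "Paris")]
        return [
            Instance(
                input=Input(text=question),
                references=[Reference(Output(text=answer), tags=[CORRECT_TAG])],
                split=TEST_SPLIT,
            )
            for question, answer in pairs
        ]
```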
helm/benchmark/scenarios/medalign_scenario.py (new file)

```diff
@@ -0,0 +1,88 @@
+from typing import List
+
+from helm.benchmark.scenarios.scenario import (
+    Scenario,
+    Instance,
+    Reference,
+    TEST_SPLIT,
+    CORRECT_TAG,
+    PassageQuestionInput,
+    Output,
+)
+from helm.benchmark.scenarios.medalign_scenario_helper import return_dataset_dataframe  # type: ignore
+
+
+class MedalignScenario(Scenario):
+    """
+    Scenario defining the MedAlign task as defined in the following work by Fleming et al:
+    @article{fleming2023medalign,
+      title={MedAlign: A Clinician-Generated Dataset for Instruction Following with Electronic Medical Records},
+      author={Scott L. Fleming
+      and Alejandro Lozano
+      and William J. Haberkorn
+      and Jenelle A. Jindal
+      and Eduardo P. Reis
+      and Rahul Thapa
+      and Louis Blankemeier
+      and Julian Z. Genkins
+      and Ethan Steinberg
+      and Ashwin Nayak
+      and Birju S. Patel
+      and Chia-Chun Chiang
+      and Alison Callahan
+      and Zepeng Huo
+      and Sergios Gatidis
+      and Scott J. Adams
+      and Oluseyi Fayanju
+      and Shreya J. Shah
+      and Thomas Savage
+      and Ethan Goh
+      and Akshay S. Chaudhari
+      and Nima Aghaeepour
+      and Christopher Sharp
+      and Michael A. Pfeffer
+      and Percy Liang
+      and Jonathan H. Chen
+      and Keith E. Morse
+      and Emma P. Brunskill
+      and Jason A. Fries
+      and Nigam H. Shah},
+      journal={arXiv preprint arXiv:2308.14089},
+      year={2023}
+    }
+    Each instance includes:
+    - input: the instruction and patient record
+    - reference: the clinical 'gold standard' completion for the instruction for the given patient record
+    This is a clinical instruction-following task, wherein a generative language model must follow
+    the instructions using the provided patient record. As explained in the MedAlign work, each example
+    is guaranteed to be completable for the given patient record.
+    This task is evaluated using COMET and BERTScore metrics.
+    """
+
+    name = "medalign"
+    description = "A dataset that asks models to answer questions/follow instructions over longitudinal EHR."
+    tags = ["knowledge", "reasoning", "biomedical"]
+
+    def __init__(self, max_length: int):
+        super().__init__()
+        self.max_length = max_length
+
+    def process_tsv(self, data) -> List[Instance]:
+        instances: List[Instance] = []
+        for index, row in data.iterrows():
+            question = row["prompt"]
+            ground_truth_answer = row["clinician_response"]
+
+            prompt = PassageQuestionInput(passage="", question=question)
+
+            instance = Instance(
+                input=prompt,
+                references=[Reference(Output(text=ground_truth_answer), tags=[CORRECT_TAG])],
+                split=TEST_SPLIT,
+            )
+            instances.append(instance)
+        return instances
+
+    def get_instances(self, output_path: str) -> List[Instance]:
+        dataset = return_dataset_dataframe(self.max_length)
+        return self.process_tsv(dataset)
```
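A hedged usage sketch for the new scenario (not from the release). Note that `get_instances` delegates to `return_dataset_dataframe`, which, per the helper module below, reads from hard-coded `/share/pi/nigam/...` paths, so this only runs where those datasets are mounted; `max_length=16384` is an illustrative value:

```python
from helm.benchmark.scenarios.medalign_scenario import MedalignScenario

scenario = MedalignScenario(max_length=16384)
# output_path is accepted for interface compatibility but unused by this scenario.
instances = scenario.get_instances(output_path="/tmp/medalign")
print(len(instances), instances[0].input.text[:200])
```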
helm/benchmark/scenarios/medalign_scenario_helper.py (new file)

```diff
@@ -0,0 +1,429 @@
+# flake8: noqa
+# type: ignore
+# fmt: off
+
+import ast
+import datetime
+import transformers
+import langchain
+import langchain.prompts
+import lxml.etree
+import os
+import pandas as pd
+import re
+import tiktoken
+
+from langchain_community.retrievers import BM25Retriever
+from tqdm import tqdm
+from typing import Any, Dict, Optional, Union, Callable
+from langchain.schema import Document
+import langchain_community
+
+
+
+def get_instructions(path_to_instructions: str) -> Dict[int, Dict[str, Any]]:
+    """
+    Builds map from Instruction ID to instruction details
+
+    The needed information for creating the map is accomplished by reading
+    a CSV file from the user-specified path.
+
+    The CSV file is expected to contain at least the following columns:
+    - instruction_id: The ID of the instruction.
+    - question: The text of the instruction.
+    - person_id: The ID of the associated patient.
+    - is_selected_ehr: A flag indicating whether the instruction is selected.
+
+    See https://stanfordmedicine.box.com/s/0om9qav2sklb9vaitn0ibye65vgbfx0e
+
+    Parameters:
+    path_to_instructions (str): Path to CSV file containing instructions.
+
+    Returns:
+    Dict[int, Dict[str, Any]]: A dictionary mapping instruction IDs to a
+    dictionary containing instruction text and associated patient ID.
+
+    Raises:
+    FileNotFoundError: If the specified file does not exist.
+    ValueError: If the CSV file does not contain the expected columns.
+    """
+    if not os.path.exists(path_to_instructions):
+        raise FileNotFoundError(
+            f"The specified file {path_to_instructions} does not exist."
+        )
+
+    instructions_df = pd.read_csv(path_to_instructions, sep='\t')
+    required_columns = {
+        "instruction_id",
+        "question",
+        "person_id",
+    }
+    if not required_columns.issubset(instructions_df.columns):
+        raise ValueError(
+            f"The CSV file is missing one or more of the required columns: {required_columns}"
+        )
+
+    selected_instructions_df = instructions_df  #.query("is_selected_ehr == 'yes'")
+    instructions_map = {
+        row["instruction_id"]: {
+            "instruction": row["question"],
+            "patient_id": row["person_id"],
+        }
+        for _, row in selected_instructions_df.iterrows()
+    }
+    return instructions_map
+
+
+def extract_patient_id_from_fname(fname: str) -> Optional[int]:
+    """
+    Extracts and returns the patient ID from a given filename.
+
+    The function expects filenames in the format 'EHR_<patient_id>.xml',
+    where <patient_id> is a sequence of digits.
+
+    Parameters:
+    fname (str): The filename from which to extract the patient ID.
+
+    Returns:
+    Optional[int]: The extracted patient ID as an integer, or None if
+    the filename doesn't match the expected format.
+    """
+    name=fname.split('.')[0]
+    return int(name)
+
+
+def get_ehrs(path_to_ehrs: str) -> Dict[int, str]:
+    """
+    Builds a map from Instruction ID to EHR (Electronic Health Record) timeline.
+
+    EHR timelines are in string format and EHR files are read in from the
+    user-specified directory. Each file in the directory should be named
+    'EHR_<patient_id>.xml', where <patient_id> is a sequence of digits.
+
+    See https://stanfordmedicine.box.com/s/r28wfwwude9rpjtu0szhzegmku8qv2pe
+
+    Parameters:
+    path_to_ehrs (str): The path to the directory containing the EHR files.
+
+    Returns:
+    Dict[int, str]: A dictionary mapping patient IDs to EHR timelines.
+
+    Raises:
+    FileNotFoundError: If the specified directory does not exist.
+    """
+    if not os.path.isdir(path_to_ehrs):
+        raise FileNotFoundError(
+            f"The specified directory {path_to_ehrs} does not exist."
+        )
+
+    ehr_map = {}
+    for fname in os.listdir(path_to_ehrs):
+        pt_id = extract_patient_id_from_fname(fname)
+        if pt_id is None:
+            print(
+                f"Warning: File '{fname}' does not match the expected format "
+                "and will be skipped."
+            )
+            continue
+
+        file_path = os.path.join(path_to_ehrs, fname)
+        with open(file_path, encoding="utf-8", mode="r") as f:
+            ehr = f.read()
+
+        ehr_map[pt_id] = ehr
+    return ehr_map
+
+
+def get_tokenizer(tokenizer_name: str) -> Callable:
+    """
+    Returns a tokenizer based on the given tokenizer name.
+
+    Parameters:
+    tokenizer_name (str): The name of the tokenizer. Acceptable values are:
+    - "tiktoken"
+    - "chatgpt"
+    - "gpt-3.5-turbo"
+    - "gpt-4"
+    - "gpt-4-turbo"
+    - "gpt-4o"
+    - "cl100k_base"
+    - Any valid tokenizer name recognized by the transformers library.
+
+    Returns:
+    Callable: The tokenizer instance.
+    """
+    if tokenizer_name.lower() in [
+        "tiktoken",
+        "chatgpt",
+        "gpt-3.5-turbo",
+        "gpt-4",
+        "gpt-4-turbo",
+        "gpt-4o",
+        "cl100k_base",
+    ]:
+        return tiktoken.get_encoding("cl100k_base")
+    print(tokenizer_name)
+    return transformers.AutoTokenizer.from_pretrained(tokenizer_name, legacy=False)
+
+
+def retrieve_most_relevant_visits(ehr_visit_strs, query, target_length, tokenizer):
+    """
+    Retrieve and filter relevant EHR visits based on a query and target length.
+
+    This function retrieves electronic health record (EHR) visit strings, sorts them
+    by relevance using the BM25Retriever, and constructs a list of final documents
+    that fit within a specified character length. The final list ensures that the
+    most important visit isn't cut off and is sorted chronologically.
+
+    Parameters:
+    ehr_visit_strs (list of str): List of EHR visit strings.
+    query (str): Query string to retrieve relevant visits.
+    target_length (int): Maximum total token count for the final list of documents.
+    tokenizer (Callable): Tokenizer that converts text to tokens (used for tracking context length)
+
+    Returns:
+    list[str]: List of EHR visit strings sorted chronologically and constrained by the target length.
+    """
+    ehr_visits=re.split(r'(?=</encounter>\n)',ehr_visit_strs)
+    langchain_docs = [
+        langchain.schema.Document(page_content=doc) for doc in ehr_visits  #broken since ehr_visit_strs is one string of all visits
+    ]
+    # `k` is the number of documents to retrieve
+    # We retrieve everything and just use the BM25Retriever to sort the documents
+    retriever = langchain_community.retrievers.BM25Retriever.from_documents(
+        langchain_docs, k=len(langchain_docs)
+    )
+
+    # Invoking the retriever means the most relevant documents are sorted first
+    sorted_docs = retriever.invoke(query)
+
+    # Define the regex pattern to find the start time
+    # pattern = r'start="([\d/]+ [\d:]+)"'
+    pattern = r'start="([\d/]+ [\d:]+ ?[APM]{0,2})"'
+
+    docs = []
+    dts = []
+
+    # Find the startime of the document
+    for doc in sorted_docs:
+        doc_content = doc.page_content
+        start_dt_match = re.search(pattern, doc_content)
+        if start_dt_match:
+            start_dt = start_dt_match.group(1)
+            parsed = False
+            # Try different date formats
+            for fmt in (
+                "%m/%d/%y %I:%M %p",
+                "%m/%d/%Y %I:%M %p",
+                "%m/%d/%y %H:%M",
+                "%m/%d/%Y %H:%M",
+            ):
+                try:
+                    dts.append(datetime.datetime.strptime(start_dt, fmt))
+                    parsed = True
+                    break
+                except ValueError:
+                    continue
+            if not parsed:
+                print(f"Error parsing date: {start_dt}")
+                continue
+        else:
+            print(f"Start time not found., {doc_content}")
+            dts.append(datetime.datetime.min)
+        docs.append(doc_content)
+
+    final_docs = []
+    current_length = 0
+
+    # Add documents until we exceed the allocated context length
+    for i in range(len(docs)):
+        doc_content = docs[i]
+        doc_length = len(tokenizer.encode(doc_content))
+        final_docs.append((dts[i], doc_content))
+        current_length += doc_length
+        if current_length > target_length:
+            break
+
+    # Sort final_docs chronologically
+    final_docs.sort(key=lambda x: x[0])
+
+    # Extract only the document content for the final output
+    final_docs_content = [doc_content for _, doc_content in final_docs]
+
+    return final_docs_content
+
+
+
+def pack_and_trim_prompts(
+    instructions: Dict[int, Dict[str, str]],
+    ehrs: Dict[int, str],
+    prompt_template: langchain.prompts.PromptTemplate,
+    context_length: int,
+    generation_length: int,
+    tokenizer: Any,
+    use_RAG: bool = True,
+    verbose: bool = False,
+    include_ehr: bool = True,
+) -> Dict[int, str]:
+    """
+    Returns:
+    A map from Instruction ID to prompt
+    """
+    prompts_map = {}
+    for instruction_id in tqdm(instructions.keys()):
+        instruction = instructions[instruction_id]["instruction"]
+        patient_id = int(instructions[instruction_id]["patient_id"])
+        relevant_ehr = ehrs[patient_id]
+
+        # Calculate how many tokens of EHR we can include in the prompt
+        num_tokens_instruction = len(tokenizer.encode(instruction))
+        num_tokens_prompt_template = len(tokenizer.encode(prompt_template.template))
+        if include_ehr:
+            target_ehr_length = context_length - generation_length - num_tokens_prompt_template - num_tokens_instruction
+        else:
+            target_ehr_length = 0
+        if target_ehr_length <= 0:
+            prompt_with_truncated_ehr = prompt_template.format(question=instruction, ehr="")
+        else:
+            if use_RAG:
+                # Return a list of the most relevant visit strings
+                most_relevant_visits = retrieve_most_relevant_visits(
+                    ehr_visit_strs=relevant_ehr,
+                    query=instruction,
+                    target_length=target_ehr_length,
+                    tokenizer=tokenizer,
+                )
+                relevant_ehr = "\n".join(most_relevant_visits)
+
+            # Do a first pass with a fast tokenizer
+            fast_tokenizer = tiktoken.get_encoding("cl100k_base")
+            fast_encoded = fast_tokenizer.encode(relevant_ehr)
+            if len(fast_encoded) <= target_ehr_length:
+                fast_encoded_truncated = fast_encoded[-(2 * target_ehr_length) :]
+                fast_truncated_ehr = fast_tokenizer.decode(fast_encoded_truncated)
+
+            # Then do a second pass with the actual tokenizer
+            encoded_ehr = tokenizer.encode(fast_truncated_ehr)
+            truncated_encoded_ehr = encoded_ehr[-target_ehr_length:]
+            truncated_ehr = tokenizer.decode(truncated_encoded_ehr)
+            prompt_with_truncated_ehr = prompt_template.format(question=instruction, ehr=truncated_ehr)
+
+        prompts_map[instruction_id] = prompt_with_truncated_ehr
+
+        if verbose:
+            print(prompt_with_truncated_ehr)
+            print("~" * 20)
+    return prompts_map
+
+
+def preprocess_prompts(
+    target_context_length,
+    generation_length,
+    path_to_instructions,
+    path_to_ehrs,
+    use_RAG,
+    include_ehr,
+    tokenizer,
+    codes_only=False,
+    notes_only=False,
+):
+    print(
+        f"\n\twith target context length = {target_context_length} "
+        f"\n\twith target generation length = {generation_length} "
+    )
+
+    # FETCH INSTRUCTIONS
+    print("Fetching instructions...")
+    instructions = get_instructions(path_to_instructions)
+
+    # FETCH RELEVANT EHRs #
+    print("Fetching patient EHR timelines...")
+    ehrs = get_ehrs(path_to_ehrs)
+
+    # LOAD TOKENIZER #
+    print("Loading tokenizer...")
+    tokenizer = get_tokenizer(tokenizer)
+
+    # CONSTRUCT & TRUNCATE PROMPTS #
+    print("Constructing prompts using instructions and EHRs...")
+    prompt_string="Instruction: Answer the following question based on the EHR:\n\nEHR: {ehr}\n\nQuestion: {question}\n\nAnswer:"
+    prompt_template = langchain.prompts.PromptTemplate.from_template(prompt_string)
+    filled_prompts = pack_and_trim_prompts(
+        instructions=instructions,
+        ehrs=ehrs,
+        prompt_template=prompt_template,
+        context_length=target_context_length,
+        generation_length=generation_length,
+        tokenizer=tokenizer,
+        use_RAG=use_RAG,
+        verbose=False,
+        include_ehr=include_ehr,
+    )
+    assert filled_prompts, f"No prompts were found for length: {target_context_length}. Try again with a larger length."
+    # SAVE CONSTRUCTED PROMPTS TO DISK
+    df_rows = []
+    for instruction_id in tqdm(filled_prompts.keys()):
+        row = {}
+        row["instruction_id"] = instruction_id
+        patient_id = instructions[instruction_id]["patient_id"]
+        row["patient_id"] = patient_id
+        row["instruction"] = instructions[instruction_id]["instruction"]
+        row["ehr"] = "".join(ehrs[patient_id])
+        row["prompt"] = filled_prompts[instruction_id]
+        row["context_length"] = target_context_length
+        row["generation_length"] = generation_length
+        df_rows.append(row)
+
+    prompts_df = pd.DataFrame(df_rows)
+    instructionid_to_prompt_map = (
+        prompts_df[["instruction_id", "prompt"]].set_index("instruction_id").to_dict().get("prompt")
+    )
+    instructionid_to_prompt_df = (
+        pd.DataFrame.from_dict(instructionid_to_prompt_map, orient="index", columns=["prompt"])
+        .reset_index()
+        .rename(columns={"index": "instruction_id"})
+    )
+
+    print("...Prompt construction complete")
+    return instructionid_to_prompt_df
+
+
+def add_reference_responses(prompts_df, path_to_reference_responses) -> pd.DataFrame:
+    """
+    Processes a single file for evaluation.
+
+    Parameters:
+    file_path (str): Path to the file to be processed.
+    args (argparse.Namespace): Command line arguments passed to the script.
+
+    Returns:
+    pd.DataFrame: DataFrame containing the processed data.
+    """
+    gold_df = pd.read_csv(path_to_reference_responses)
+    gold_df = gold_df.query("annotator_num == 'Annotator_1'")
+    gold_df = gold_df[["instruction_id", "clinician_response"]]
+    merged_df = gold_df.merge(prompts_df, on="instruction_id", how="inner")
+    return merged_df
+
+
+def return_dataset_dataframe(max_length: int) -> pd.DataFrame:
+    target_context_length = max_length
+    generation_length = 256
+    path_to_instructions = "/share/pi/nigam/datasets/medalign_release_fixes/clinician-reviewed-model-responses.tsv"
+    path_to_ehrs = "/share/pi/nigam/datasets/medalign_release_fixes/medalign_ehr_xml"
+    path_to_reference_responses = "/share/pi/nigam/scottyf/clinician-instruction-responses.csv"
+    use_RAG = False
+    include_ehr = True
+    tokenizer = "tiktoken"
+
+    instructionid_to_prompt_df = preprocess_prompts(
+        target_context_length=target_context_length,
+        generation_length=generation_length,
+        path_to_instructions=path_to_instructions,
+        path_to_ehrs=path_to_ehrs,
+        use_RAG=use_RAG,
+        include_ehr=include_ehr,
+        tokenizer=tokenizer,
+    )
+    medalign_dataframe = add_reference_responses(instructionid_to_prompt_df, path_to_reference_responses)
+    return medalign_dataframe
```
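Taken together, `pack_and_trim_prompts` and `retrieve_most_relevant_visits` implement a budget-then-rank-then-pack pipeline: the EHR token budget is the context window minus the reserved generation length and the token counts of the template and instruction; BM25 orders visits by relevance; visits are accumulated until the budget is spent and then restored to chronological order. Below is a self-contained sketch of that pipeline on invented data. It assumes the `langchain-community`, `rank_bm25`, and `tiktoken` packages are installed; none of the numbers or strings come from the release:

```python
import tiktoken
from langchain.schema import Document
from langchain_community.retrievers import BM25Retriever

tokenizer = tiktoken.get_encoding("cl100k_base")  # the "tiktoken" default used above

# Token budget, as computed in pack_and_trim_prompts (illustrative numbers).
context_length = 512
generation_length = 256
template = (
    "Instruction: Answer the following question based on the EHR:\n\n"
    "EHR: {ehr}\n\nQuestion: {question}\n\nAnswer:"
)
instruction = "How were the patient's migraines treated?"
target_ehr_length = (
    context_length
    - generation_length
    - len(tokenizer.encode(template))
    - len(tokenizer.encode(instruction))
)

# Rank invented visit strings by BM25 relevance, as in retrieve_most_relevant_visits.
visits = [
    '<encounter start="01/05/23 09:00 AM">annual physical, normal findings</encounter>',
    '<encounter start="03/12/23 02:30 PM">reports migraines; prescribed sumatriptan</encounter>',
    '<encounter start="06/20/23 11:15 AM">migraine follow-up, symptoms improved</encounter>',
]
docs = [Document(page_content=v) for v in visits]
retriever = BM25Retriever.from_documents(docs, k=len(docs))  # k=len(docs): sort only
ranked = retriever.invoke(instruction)

# Pack the most relevant visits into the budget, then restore chronological order.
kept, used = [], 0
for doc in ranked:
    kept.append(doc.page_content)
    used += len(tokenizer.encode(doc.page_content))
    if used > target_ehr_length:
        break
kept.sort(key=visits.index)  # stand-in for the datetime parse-and-sort in the helper
print(target_ehr_length, kept)
```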