crfm-helm 0.5.5__py3-none-any.whl → 0.5.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of crfm-helm might be problematic.
- {crfm_helm-0.5.5.dist-info → crfm_helm-0.5.7.dist-info}/METADATA +74 -53
- {crfm_helm-0.5.5.dist-info → crfm_helm-0.5.7.dist-info}/RECORD +262 -182
- {crfm_helm-0.5.5.dist-info → crfm_helm-0.5.7.dist-info}/WHEEL +1 -1
- helm/benchmark/adaptation/adapters/in_context_learning_adapter.py +3 -3
- helm/benchmark/adaptation/adapters/test_adapter.py +4 -4
- helm/benchmark/annotation/air_bench_annotator.py +2 -2
- helm/benchmark/annotation/bigcodebench_annotator.py +3 -3
- helm/benchmark/annotation/bird_sql_annotator.py +2 -2
- helm/benchmark/annotation/chw_care_plan_annotator.py +7 -12
- helm/benchmark/annotation/ehr_sql_annotator.py +2 -2
- helm/benchmark/annotation/helpdesk_call_summarization_annotator.py +7 -7
- helm/benchmark/annotation/live_qa_annotator.py +1 -1
- helm/benchmark/annotation/mimic_bhc_annotator.py +100 -0
- helm/benchmark/annotation/model_as_judge.py +12 -16
- helm/benchmark/annotation/omni_math_annotator.py +13 -14
- helm/benchmark/annotation/wildbench_annotator.py +9 -9
- helm/benchmark/executor.py +11 -12
- helm/benchmark/metrics/aci_bench_metrics.py +9 -29
- helm/benchmark/metrics/bias_word_lists.py +1 -1
- helm/benchmark/metrics/chw_care_plan_metrics.py +10 -30
- helm/benchmark/metrics/classification_metrics.py +3 -3
- helm/benchmark/metrics/cleva_harms_metrics.py +2 -2
- helm/benchmark/metrics/codeinsights_code_efficiency_metrics.py +186 -0
- helm/benchmark/metrics/codeinsights_code_evaluation_metrics.py +477 -0
- helm/benchmark/metrics/codeinsights_correct_code_metrics.py +366 -0
- helm/benchmark/metrics/codeinsights_edge_case_metrics.py +92 -0
- helm/benchmark/metrics/codeinsights_metric_specs.py +51 -0
- helm/benchmark/metrics/comet_metric.py +1 -1
- helm/benchmark/metrics/conv_fin_qa_calc_metrics.py +2 -2
- helm/benchmark/metrics/copyright_metrics.py +1 -1
- helm/benchmark/metrics/decodingtrust_stereotype_bias_metrics.py +1 -1
- helm/benchmark/metrics/dischargeme_metrics.py +9 -29
- helm/benchmark/metrics/efficiency_metrics.py +3 -3
- helm/benchmark/metrics/evaluate_reference_metrics.py +1 -1
- helm/benchmark/metrics/gpt4_audio_refusal_metrics.py +145 -0
- helm/benchmark/metrics/ifeval_metrics.py +2 -2
- helm/benchmark/metrics/image_generation/clip_score_metrics.py +13 -2
- helm/benchmark/metrics/image_generation/fractal_dimension/fractal_dimension_util.py +1 -1
- helm/benchmark/metrics/kpi_edgar_metrics.py +121 -0
- helm/benchmark/metrics/llm_jury_metrics.py +46 -0
- helm/benchmark/metrics/lmkt_metric_specs.py +12 -0
- helm/benchmark/metrics/lmkt_metrics.py +47 -0
- helm/benchmark/metrics/med_dialog_metrics.py +9 -29
- helm/benchmark/metrics/medalign_metrics.py +9 -29
- helm/benchmark/metrics/medi_qa_metrics.py +9 -29
- helm/benchmark/metrics/medication_qa_metrics.py +10 -30
- helm/benchmark/metrics/melt_bias_metric.py +234 -0
- helm/benchmark/metrics/melt_bias_word_lists.py +1367 -0
- helm/benchmark/metrics/melt_metric_specs.py +43 -0
- helm/benchmark/metrics/melt_toxicity_metric.py +107 -0
- helm/benchmark/metrics/mental_health_metrics.py +9 -29
- helm/benchmark/metrics/metric_service.py +11 -11
- helm/benchmark/metrics/mimic_bhc_metrics.py +14 -0
- helm/benchmark/metrics/mimic_rrs_metrics.py +9 -29
- helm/benchmark/metrics/mtsamples_procedures_metrics.py +9 -29
- helm/benchmark/metrics/mtsamples_replicate_metrics.py +9 -29
- helm/benchmark/metrics/openai_mrcr_metrics.py +52 -0
- helm/benchmark/metrics/ruler_qa_metrics.py +34 -0
- helm/benchmark/metrics/starr_patient_instructions_metrics.py +9 -29
- helm/benchmark/metrics/summac/model_summac.py +2 -3
- helm/benchmark/metrics/summarization_metrics.py +2 -1
- helm/benchmark/metrics/tokens/test_openai_token_cost_estimator.py +2 -2
- helm/benchmark/metrics/toxicity_metrics.py +2 -2
- helm/benchmark/metrics/unitxt_metrics.py +3 -4
- helm/benchmark/metrics/vision_language/emd_utils.py +4 -2
- helm/benchmark/metrics/vision_language/image_utils.py +2 -2
- helm/benchmark/model_deployment_registry.py +16 -26
- helm/benchmark/presentation/contamination.py +3 -3
- helm/benchmark/presentation/create_plots.py +43 -13
- helm/benchmark/presentation/run_display.py +13 -0
- helm/benchmark/presentation/schema.py +7 -1
- helm/benchmark/presentation/summarize.py +84 -61
- helm/benchmark/presentation/test_create_plots.py +4 -1
- helm/benchmark/reeval_run.py +3 -4
- helm/benchmark/reeval_runner.py +3 -3
- helm/benchmark/run.py +84 -73
- helm/benchmark/run_expander.py +12 -1
- helm/benchmark/run_spec_factory.py +7 -6
- helm/benchmark/run_specs/arabic_run_specs.py +73 -0
- helm/benchmark/run_specs/audio_run_specs.py +52 -8
- helm/benchmark/run_specs/bluex_run_specs.py +40 -0
- helm/benchmark/run_specs/classic_run_specs.py +0 -53
- helm/benchmark/run_specs/codeinsights_run_specs.py +192 -0
- helm/benchmark/run_specs/enterprise_run_specs.py +20 -0
- helm/benchmark/run_specs/experimental_run_specs.py +31 -1
- helm/benchmark/run_specs/healthqa_br_run_specs.py +40 -0
- helm/benchmark/run_specs/heim_run_specs.py +3 -1
- helm/benchmark/run_specs/lmkt_run_specs.py +144 -0
- helm/benchmark/run_specs/long_context_run_specs.py +114 -15
- helm/benchmark/run_specs/medhelm_run_specs.py +146 -41
- helm/benchmark/run_specs/melt_run_specs.py +783 -0
- helm/benchmark/run_specs/multilingual_run_specs.py +50 -0
- helm/benchmark/run_specs/speech_disorder_audio_run_specs.py +163 -0
- helm/benchmark/run_specs/vlm_run_specs.py +28 -0
- helm/benchmark/runner.py +5 -5
- helm/benchmark/scenarios/aci_bench_scenario.py +7 -1
- helm/benchmark/scenarios/alghafa_scenario.py +126 -0
- helm/benchmark/scenarios/arabic_mmlu_scenario.py +78 -0
- helm/benchmark/scenarios/aratrust_scenario.py +76 -0
- helm/benchmark/scenarios/audio_language/air_bench_chat_scenario.py +3 -1
- helm/benchmark/scenarios/audio_language/air_bench_foundation_scenario.py +5 -5
- helm/benchmark/scenarios/audio_language/casual_conversations2_scenario.py +1 -1
- helm/benchmark/scenarios/audio_language/corebench_scenario.py +77 -0
- helm/benchmark/scenarios/audio_language/mustard_scenario.py +1 -1
- helm/benchmark/scenarios/audio_language/ultra_suite_asr_classification_scenario.py +104 -0
- helm/benchmark/scenarios/audio_language/ultra_suite_asr_transcription_scenario.py +99 -0
- helm/benchmark/scenarios/audio_language/ultra_suite_classification_scenario.py +118 -0
- helm/benchmark/scenarios/audio_language/ultra_suite_disorder_breakdown_scenario.py +86 -0
- helm/benchmark/scenarios/audio_language/ultra_suite_disorder_symptoms_scenario.py +117 -0
- helm/benchmark/scenarios/audio_language/vocal_sound_scenario.py +15 -1
- helm/benchmark/scenarios/audio_language/voxceleb2_scenario.py +1 -2
- helm/benchmark/scenarios/autobencher_capabilities_scenario.py +2 -2
- helm/benchmark/scenarios/bluex_scenario.py +66 -0
- helm/benchmark/scenarios/chw_care_plan_scenario.py +14 -13
- helm/benchmark/scenarios/clear_scenario.py +11 -7
- helm/benchmark/scenarios/cleva_scenario.py +1 -1
- helm/benchmark/scenarios/codeinsights_code_efficiency_scenario.py +197 -0
- helm/benchmark/scenarios/codeinsights_correct_code_scenario.py +78 -0
- helm/benchmark/scenarios/codeinsights_edge_case_scenario.py +192 -0
- helm/benchmark/scenarios/codeinsights_student_coding_scenario.py +162 -0
- helm/benchmark/scenarios/codeinsights_student_mistake_scenario.py +188 -0
- helm/benchmark/scenarios/dischargeme_scenario.py +36 -21
- helm/benchmark/scenarios/ehr_sql_scenario.py +7 -1
- helm/benchmark/scenarios/ehrshot_scenario.py +28 -55
- helm/benchmark/scenarios/exams_multilingual_scenario.py +115 -0
- helm/benchmark/scenarios/grammar.py +2 -2
- helm/benchmark/scenarios/headqa_scenario.py +6 -1
- helm/benchmark/scenarios/healthqa_br_scenario.py +80 -0
- helm/benchmark/scenarios/infinite_bench_en_mc_scenario.py +90 -0
- helm/benchmark/scenarios/infinite_bench_en_qa_scenario.py +85 -0
- helm/benchmark/scenarios/{infinite_bench_sum_scenario.py → infinite_bench_en_sum_scenario.py} +10 -13
- helm/benchmark/scenarios/kpi_edgar_scenario.py +151 -0
- helm/benchmark/scenarios/lmkt_scenarios.py +288 -0
- helm/benchmark/scenarios/math_scenario.py +21 -20
- helm/benchmark/scenarios/med_dialog_scenario.py +6 -1
- helm/benchmark/scenarios/medalign_scenario.py +9 -3
- helm/benchmark/scenarios/medalign_scenario_helper.py +27 -130
- helm/benchmark/scenarios/medbullets_scenario.py +7 -2
- helm/benchmark/scenarios/medcalc_bench_scenario.py +4 -2
- helm/benchmark/scenarios/medec_scenario.py +6 -1
- helm/benchmark/scenarios/medhallu_scenario.py +7 -1
- helm/benchmark/scenarios/medi_qa_scenario.py +10 -4
- helm/benchmark/scenarios/medication_qa_scenario.py +7 -1
- helm/benchmark/scenarios/melt_ir_scenario.py +171 -0
- helm/benchmark/scenarios/melt_knowledge_scenario.py +246 -0
- helm/benchmark/scenarios/melt_lm_scenarios.py +252 -0
- helm/benchmark/scenarios/melt_scenarios.py +793 -0
- helm/benchmark/scenarios/melt_srn_scenario.py +342 -0
- helm/benchmark/scenarios/melt_synthetic_reasoning_scenario.py +222 -0
- helm/benchmark/scenarios/melt_translation_scenario.py +152 -0
- helm/benchmark/scenarios/mental_health_scenario.py +16 -5
- helm/benchmark/scenarios/mimic_bhc_scenario.py +13 -8
- helm/benchmark/scenarios/mimic_rrs_scenario.py +17 -8
- helm/benchmark/scenarios/mimiciv_billing_code_scenario.py +14 -8
- helm/benchmark/scenarios/mmlu_pro_scenario.py +1 -1
- helm/benchmark/scenarios/mmmlu_scenario.py +85 -0
- helm/benchmark/scenarios/mtsamples_procedures_scenario.py +5 -2
- helm/benchmark/scenarios/mtsamples_replicate_scenario.py +3 -2
- helm/benchmark/scenarios/n2c2_ct_matching_scenario.py +11 -5
- helm/benchmark/scenarios/openai_mrcr_scenario.py +79 -0
- helm/benchmark/scenarios/pubmed_qa_scenario.py +6 -1
- helm/benchmark/scenarios/race_based_med_scenario.py +18 -8
- helm/benchmark/scenarios/ruler_qa_scenario_helper.py +2 -2
- helm/benchmark/scenarios/ruler_qa_scenarios.py +2 -2
- helm/benchmark/scenarios/seahelm_scenario.py +2 -2
- helm/benchmark/scenarios/shc_bmt_scenario.py +12 -6
- helm/benchmark/scenarios/shc_cdi_scenario.py +11 -6
- helm/benchmark/scenarios/shc_conf_scenario.py +12 -6
- helm/benchmark/scenarios/shc_ent_scenario.py +11 -6
- helm/benchmark/scenarios/shc_gip_scenario.py +13 -5
- helm/benchmark/scenarios/shc_privacy_scenario.py +78 -0
- helm/benchmark/scenarios/shc_proxy_scenario.py +76 -0
- helm/benchmark/scenarios/shc_ptbm_scenario.py +12 -7
- helm/benchmark/scenarios/shc_sei_scenario.py +12 -7
- helm/benchmark/scenarios/shc_sequoia_scenario.py +13 -5
- helm/benchmark/scenarios/starr_patient_instructions_scenario.py +15 -8
- helm/benchmark/scenarios/test_alghafa_scenario.py +29 -0
- helm/benchmark/scenarios/test_aratrust_scenario.py +21 -0
- helm/benchmark/scenarios/test_bluex_scenario.py +59 -0
- helm/benchmark/scenarios/test_exams_multilingual_scenario.py +29 -0
- helm/benchmark/scenarios/test_healtha_br_scenario.py +57 -0
- helm/benchmark/scenarios/test_infinite_bench_en_qa_scenario.py +18 -0
- helm/benchmark/scenarios/test_infinite_bench_en_sum_scenario.py +31 -0
- helm/benchmark/scenarios/truthful_qa_scenario.py +2 -1
- helm/benchmark/scenarios/vision_language/msr_vtt_scenario.py +75 -0
- helm/benchmark/server.py +2 -1
- helm/benchmark/slurm_jobs.py +1 -2
- helm/benchmark/slurm_runner.py +8 -1
- helm/benchmark/static/schema_arabic.yaml +228 -0
- helm/benchmark/static/schema_audio.yaml +60 -49
- helm/benchmark/static/schema_classic.yaml +0 -17
- helm/benchmark/static/schema_enterprise.yaml +21 -0
- helm/benchmark/static/schema_long_context.yaml +81 -20
- helm/benchmark/static/schema_medhelm.yaml +272 -213
- helm/benchmark/static/schema_melt.yaml +1257 -0
- helm/benchmark/static/schema_slphelm.yaml +162 -0
- helm/benchmark/static/schema_vhelm.yaml +26 -26
- helm/benchmark/static/schema_video.yaml +219 -0
- helm/benchmark/static_build/assets/index-b9779128.css +1 -0
- helm/benchmark/static_build/assets/index-e439d5e1.js +10 -0
- helm/benchmark/static_build/assets/medhelm-overview-eac29843.png +0 -0
- helm/benchmark/static_build/assets/{tremor-9cefc3c5.js → tremor-38a10867.js} +1 -1
- helm/benchmark/static_build/index.html +4 -4
- helm/benchmark/window_services/encoder_decoder_window_service.py +3 -3
- helm/benchmark/window_services/image_generation/clip_window_service.py +1 -3
- helm/benchmark/window_services/test_utils.py +3 -4
- helm/benchmark/window_services/tokenizer_service.py +7 -8
- helm/clients/anthropic_client.py +69 -29
- helm/clients/audio_language/diva_llama_client.py +4 -2
- helm/clients/audio_language/qwen2_5_omni_client.py +209 -0
- helm/clients/audio_language/qwen2_audiolm_client.py +8 -6
- helm/clients/audio_language/qwen_audiolm_client.py +4 -2
- helm/clients/audio_language/test.py +62 -0
- helm/clients/bedrock_client.py +3 -1
- helm/clients/client.py +7 -7
- helm/clients/grok_client.py +36 -0
- helm/clients/huggingface_client.py +42 -3
- helm/clients/huggingface_pipeline_client.py +138 -0
- helm/clients/image_generation/dalle_mini/model/configuration.py +1 -1
- helm/clients/image_generation/dalle_mini/model/modeling.py +1 -1
- helm/clients/image_generation/dalle_mini/model/processor.py +1 -1
- helm/clients/image_generation/dalle_mini/model/tokenizer.py +1 -1
- helm/clients/openai_client.py +102 -55
- helm/clients/openai_responses_client.py +176 -0
- helm/clients/palmyra_client.py +2 -5
- helm/clients/reka_client.py +2 -2
- helm/clients/test_huggingface_client.py +3 -3
- helm/clients/together_client.py +31 -6
- helm/clients/vertexai_client.py +17 -9
- helm/clients/vision_language/huggingface_vision2seq_client.py +6 -4
- helm/clients/vision_language/huggingface_vlm_client.py +2 -2
- helm/clients/vision_language/idefics_client.py +6 -2
- helm/clients/vision_language/paligemma_client.py +2 -2
- helm/clients/vision_language/qwen2_vlm_client.py +66 -53
- helm/clients/vision_language/qwen_vlm_client.py +7 -5
- helm/clients/vllm_client.py +43 -7
- helm/clients/vllm_granite_thinking_client.py +56 -0
- helm/clients/writer_client.py +102 -0
- helm/common/context.py +80 -0
- helm/common/credentials_utils.py +5 -5
- helm/common/critique_request.py +0 -1
- helm/common/general.py +9 -2
- helm/common/hierarchical_logger.py +104 -12
- helm/common/local_context.py +140 -0
- helm/common/object_spec.py +23 -8
- helm/common/remote_context.py +61 -0
- helm/common/request.py +8 -0
- helm/common/test_logging.py +94 -0
- helm/config/model_deployments.yaml +995 -45
- helm/config/model_metadata.yaml +780 -59
- helm/config/tokenizer_configs.yaml +224 -3
- helm/proxy/cli.py +4 -2
- helm/proxy/critique/mechanical_turk_utils.py +1 -1
- helm/proxy/retry.py +5 -0
- helm/proxy/services/server_service.py +21 -85
- helm/tokenizers/grok_tokenizer.py +55 -0
- helm/tokenizers/huggingface_tokenizer.py +1 -1
- helm/tokenizers/test_grok_tokenizer.py +33 -0
- helm/benchmark/metrics/numeracy_metrics.py +0 -72
- helm/benchmark/metrics/test_numeracy_metrics.py +0 -95
- helm/benchmark/scenarios/numeracy_scenario.py +0 -793
- helm/benchmark/scenarios/test_infinite_bench_sum_scenario.py +0 -46
- helm/benchmark/static_build/assets/index-262903c1.js +0 -10
- helm/benchmark/static_build/assets/index-42060d71.css +0 -1
- {crfm_helm-0.5.5.dist-info → crfm_helm-0.5.7.dist-info}/entry_points.txt +0 -0
- {crfm_helm-0.5.5.dist-info → crfm_helm-0.5.7.dist-info}/licenses/LICENSE +0 -0
- {crfm_helm-0.5.5.dist-info → crfm_helm-0.5.7.dist-info}/top_level.txt +0 -0
- /helm/benchmark/static_build/assets/{medhelm-overview-3ddfcd65.png → medhelm-v1-overview-3ddfcd65.png} +0 -0
helm/benchmark/scenarios/codeinsights_student_coding_scenario.py

@@ -0,0 +1,162 @@
+from helm.benchmark.scenarios.scenario import Scenario, Instance, Input, Output, Reference, VALID_SPLIT, CORRECT_TAG
+import pandas as pd
+import requests
+
+
+class CodeInsightsStudentCodingScenario(Scenario):
+    name = "codeinsights_student_coding"
+    description = "Mimic student C++ style on foundational questions"
+    tags = ["codeinsights", "c++", "student_coding"]
+
+    def __init__(self, num_testcases: int = 1):
+        super().__init__()
+        self.num_testcases = num_testcases
+
+    def get_instances(self, output_path: str):
+        df = pd.read_csv("https://huggingface.co/datasets/Kazchoko/my_dataset/resolve/main/Scenario1_2_data.csv")
+        student_topic = pd.read_csv(
+            "https://huggingface.co/datasets/Kazchoko/my_dataset/resolve/main/student_performace_by_topic.csv"
+        )
+
+        instances = []
+        for student_id, student_df in df.groupby("student_id"):
+            student_df = student_df.sort_values("timestamp")
+            if len(student_df) < 4:
+                continue
+            first = student_df.iloc[0]
+            second = student_df.iloc[1]
+            third = student_df.iloc[2]
+            target = student_df.iloc[3]
+
+            # Get test cases for this question
+            question_id = target.get("question_unittest_id", None)
+            question_test_cases = []
+            tc_parsing_success = True
+
+            for testcase_str in target["question_unittests"].split("Unittest")[1:]:
+                testcase_str = testcase_str[testcase_str.find(":") + 1 :]
+                input_idx = testcase_str.find("Input:")
+                std_in_idx = testcase_str.find("STD input:")
+                output_idx = testcase_str.find("Output:")
+                if input_idx == -1 or std_in_idx == -1 or output_idx == -1:
+                    tc_parsing_success = False
+                    break
+
+                testcase = {
+                    "input": testcase_str[input_idx + 6 : std_in_idx].strip(),
+                    "std_in": testcase_str[std_in_idx + 10 : output_idx].strip(),
+                    "output": testcase_str[output_idx + 7 :].strip(),
+                }
+                question_test_cases.append(testcase)
+
+            if not tc_parsing_success:
+                continue
+
+            if len(question_test_cases) < self.num_testcases:
+                # If not enough test cases, skip this question
+                continue
+            if self.num_testcases >= 0:
+                # If more than one test case is requested, only take the first ones
+                question_test_cases = question_test_cases[: self.num_testcases]
+
+            # Get student pass (0 or 1) for the target question
+            student_correctness_pattern = target.get("pass", None)
+            main_part = int(student_correctness_pattern)  # "1111111111"
+            # Convert each character to an int
+            student_correctness_list = [int(ch) for ch in str(main_part)]  # [1,1,1,1,1,1,1,1,1,1]
+
+            # Student specific topic performance in previous attempts
+            student_level_prompt = f"Student {student_id} has the following performance across topics:\n"
+            topic_performance = student_topic[student_topic["student_id"] == student_id]
+            for _, row in topic_performance.iterrows():
+                topic = row["topic"]
+                pass_rate = round(row["pass_rate"], 2)
+                perfect = round(row["perfect"], 2)
+
+                student_level_prompt += (
+                    f"- For topic '{topic}', the unit test pass rate is {pass_rate}, "
+                    f"and the rate of passing all unit tests is {perfect}.\n"
+                )
+
+            prompt = (
+                "=== Student Profile ===\n"
+                f"{student_level_prompt}\n"
+                f"Week: {target['week']}\n"
+                f"Topic: {target['topic']}\n\n"
+                "Example 1:\n"
+                f"Question: {first['question_name']} — {first['question_text']}\n"
+                "Template:\n"
+                f"{first['question_template']}\n"
+                "Your Code:\n"
+                f"{first['response']}\n\n"
+                "Example 2:\n"
+                f"Question: {second['question_name']} — {second['question_text']}\n"
+                "Template:\n"
+                f"{second['question_template']}\n"
+                "Your Code:\n"
+                f"{second['response']}\n\n"
+                "Example 3:\n"
+                f"Question: {third['question_name']} — {third['question_text']}\n"
+                "Template:\n"
+                f"{third['question_template']}\n"
+                "Your Code:\n"
+                f"{third['response']}\n\n"
+                "Now, using that same student style, attempt this:\n"
+                f"Question: {target['question_name']} — {target['question_text']}\n"
+                f"Unit Test Input: {question_test_cases}\n\n"
+                if question_test_cases
+                else ""
+                "Template:\n"
+                f"{target['question_template']}\n\n"
+                "Provide ONLY your C++ implementation following the given template, where the answer will replace the {{ STUDENT_ANSWER }} block in the template. "
+                "DO NOT reproduce the template part as the generated code would be inserted to the template, "
+                "and make sure the code is compatible with the Unit Test Input. "
+                "int main() is always declared already so DO NOT produce that initialization on the code. "
+                "Ensure your code includes any class definition when needed. "
+                "Return the code in C++ code block format, and nothing else."
+            )
+            instances.append(
+                Instance(
+                    id=f"{student_id}_{target['question_unittest_id']}",
+                    input=Input(text=prompt),
+                    references=[Reference(output=Output(text=target["response"]), tags=[CORRECT_TAG])],
+                    extra_data={
+                        "question_template": target["question_template"],
+                        "test_cases": question_test_cases,
+                        "question_id": str(question_id) if question_id else None,
+                        "question_name": target.get("question_name", ""),
+                        "student_id": str(student_id),
+                        "student_correctness_pattern": student_correctness_list,
+                    },
+                    split=VALID_SPLIT,
+                )
+            )
+        return instances
+
+    def _load_test_cases(self):
+        """
+        Load test cases from external source or return None if not available.
+        This method should be implemented based on where your test cases are stored.
+
+        Expected format:
+        {
+            "question_id": [
+                {
+                    "unittest": "test_id",
+                    "input": "test input code",
+                    "output": "expected output"
+                },
+                ...
+            ],
+            ...
+        }
+        """
+        try:
+            response = requests.get(
+                "https://huggingface.co/datasets/Kazchoko/my_dataset/resolve/main/test_cases_by_qid.json"
+            )
+            if response.status_code == 200:
+                return response.json()
+        except Exception as e:
+            print(f"Failed to load test cases from URL: {e}")
+        return {}
helm/benchmark/scenarios/codeinsights_student_mistake_scenario.py

@@ -0,0 +1,188 @@
+from helm.benchmark.scenarios.scenario import Scenario, Instance, Input, Output, Reference, VALID_SPLIT, CORRECT_TAG
+import pandas as pd
+import requests
+
+
+class CodeInsightsStudentMistakeScenario(Scenario):
+    name = "codeinsights_student_mistake"
+    description = "Mimic how students mistake their C++ codes on foundational questions"
+    tags = ["codeinsights", "c++", "student_mistake"]
+
+    def __init__(self, num_testcases: int = 1):
+        super().__init__()
+        self.num_testcases = num_testcases
+
+    def get_instances(self, output_path: str):
+        df = pd.read_csv("https://huggingface.co/datasets/Kazchoko/my_dataset/resolve/main/Scenario3_data.csv")
+        student_topic = pd.read_csv(
+            "https://huggingface.co/datasets/Kazchoko/my_dataset/resolve/main/student_performace_by_topic.csv"
+        )
+
+        instances = []
+        for student_id, student_df in df.groupby("student_id"):
+            student_df = student_df.sort_values(by=["student_id", "question_unittest_id", "timestamp"])
+            if len(student_df) < 4:
+                continue
+            first = student_df.iloc[0]
+            second = student_df.iloc[1]
+            third = student_df.iloc[2]
+            target = student_df.iloc[3]
+
+            # Get test cases for this question
+            question_id = target.get("question_unittest_id", None)
+            question_test_cases = []
+            tc_parsing_success = True
+
+            for testcase_str in target["question_unittests"].split("Unittest")[1:]:
+                testcase_str = testcase_str[testcase_str.find(":") + 1 :]
+                input_idx = testcase_str.find("Input:")
+                std_in_idx = testcase_str.find("STD input:")
+                output_idx = testcase_str.find("Output:")
+                if input_idx == -1 or std_in_idx == -1 or output_idx == -1:
+                    tc_parsing_success = False
+                    break
+
+                testcase = {
+                    "input": testcase_str[input_idx + 6 : std_in_idx].strip(),
+                    "std_in": testcase_str[std_in_idx + 10 : output_idx].strip(),
+                    "output": testcase_str[output_idx + 7 :].strip(),
+                }
+                question_test_cases.append(testcase)
+
+            if not tc_parsing_success:
+                continue
+
+            if len(question_test_cases) < self.num_testcases:
+                # If not enough test cases, skip this question
+                continue
+            if self.num_testcases >= 0:
+                # If more than one test case is requested, only take the first ones
+                question_test_cases = question_test_cases[: self.num_testcases]
+
+            # Get student pass (0 or 1) for the target question
+            student_correctness_pattern = target.get("pass", None)
+            main_part = int(student_correctness_pattern)  # "1111111111"
+            # Convert each character to an int
+            student_correctness_list = [int(ch) for ch in str(main_part)]  # [1,1,1,1,1,1,1,1,1,1]
+
+            # Student specific topic performance in previous attempts
+            student_level_prompt = f"Student {student_id} has the following performance across topics:\n"
+            topic_performance = student_topic[student_topic["student_id"] == student_id]
+            for _, row in topic_performance.iterrows():
+                topic = row["topic"]
+                pass_rate = round(row["pass_rate"], 2)
+                perfect = round(row["perfect"], 2)
+
+                student_level_prompt += (
+                    f"- For topic '{topic}', the unit test pass rate is {pass_rate}, "
+                    f"and the rate of passing all unit tests is {perfect}.\n"
+                )
+
+            prompt = (
+                "=== Student Profile ===\n"
+                f"{student_level_prompt}\n"
+                "When students submit a code to the platform, it will be tested by number of unit tests, where"
+                "- Unit test pass rate = proportion of unit tests passed with the code \n"
+                "- Full pass rate = proportion of code passing all unit tests\n\n"
+                "=== Past Mistake Examples ===\n"
+                "Example 1 (Week {first['week']}, Topic: {first['topic']}):\n"
+                f"Question: {first['question_name']} — {first['question_text']}\n"
+                "Template:\n"
+                f"{first['question_template']}\n"
+                "Student's Response Code with Error:\n"
+                f"{first['response_mistake']}\n\n"
+                "Example 2 (Week {second['week']}, Topic: {second['topic']}):\n"
+                f"Question: {second['question_name']} — {second['question_text']}\n"
+                "Template:\n"
+                f"{second['question_template']}\n"
+                "Student's Response Code with Error:\n"
+                f"{second['response_mistake']}\n\n"
+                "Example 3 (Week {third['week']}, Topic: {third['topic']}):\n"
+                f"Question: {third['question_name']} — {third['question_text']}\n"
+                "Template:\n"
+                f"{third['question_template']}\n"
+                "Student's Response Code with Error:\n"
+                f"{third['response_mistake']}\n\n"
+                "=== New Target Problem ===\n"
+                f"Week: {target['week']}, Topic: {target['topic']}\n"
+                f"Question: {target['question_name']} — {target['question_text']}\n"
+                f"Unit Test Input: {question_test_cases}\n\n"
+                if question_test_cases
+                else ""
+                "Template:\n"
+                f"{target['question_template']}\n\n"
+                "⚠**Instructions:**\n"
+                "1. Mimic your own coding style, naming conventions, indentation, and typical error patterns.\n"
+                "2. Introduce mistake you are likely to make (e.g., off‐by‐one index, wrong initialization, "
+                "missing edge case).\n"
+                "3. Do **not** produce a fully correct solution or add unfamiliar optimizations.\n\n"
+                "Provide ONLY your C++ implementation following the given template, where the answer will replace the {{ STUDENT_ANSWER }} block in the template. "
+                "DO NOT reproduce the template part as the generated code would be inserted to the template, "
+                "and make sure the code is compatible with the Unit Test Input. "
+                "int main() is always declared already so DO NOT produce that initialization on the code. "
+                "Ensure your code is includes any class definition when needed. "
+                "Return the code in C++ code block format, and nothing else."
+            )
+
+            print(f"\n=== DEBUG INFO FOR STUDENT {student_id}, QUESTION {question_id} ===")
+            print(f"Test cases loaded: {len(question_test_cases)}")
+            print(f"Student correctness pattern: {student_correctness_list}")
+            print(f"Original pass field: {target.get('pass', 'MISSING')}")
+            print(f"Question template exists: {'question_template' in target}")
+            print(f"Question name: {target.get('question_name', 'MISSING')}")
+
+            # Also add this validation in your UnitTestAlignmentMetric evaluate_generation method:
+            def evaluate_generation(self, adapter_spec, request_state, metric_service, eval_cache_path):
+                print("\n=== UNIT TEST METRIC DEBUG ===")
+                print(f"Has extra_data: {hasattr(request_state.instance, 'extra_data')}")
+                if hasattr(request_state.instance, "extra_data"):
+                    extra_data = request_state.instance.extra_data
+                    print(f"Extra data keys: {list(extra_data.keys())}")
+                    print(f"Test cases: {len(extra_data.get('test_cases', []))}")
+                    print(f"Student pattern: {extra_data.get('student_correctness_pattern', 'MISSING')}")
+
+            instances.append(
+                Instance(
+                    id=f"{student_id}_{target['question_unittest_id']}",
+                    input=Input(text=prompt),
+                    references=[Reference(output=Output(text=target["response_mistake"]), tags=[CORRECT_TAG])],
+                    extra_data={
+                        "question_template": target["question_template"],
+                        "test_cases": question_test_cases,
+                        "question_id": str(question_id) if question_id else None,
+                        "question_name": target.get("question_name", ""),
+                        "student_id": str(student_id),
+                        "student_correctness_pattern": student_correctness_list,
+                    },
+                    split=VALID_SPLIT,
+                )
+            )
+        return instances
+
+    def _load_test_cases(self):
+        """
+        Load test cases from external source or return None if not available.
+        This method should be implemented based on where your test cases are stored.
+
+        Expected format:
+        {
+            "question_id": [
+                {
+                    "unittest": "test_id",
+                    "input": "test input code",
+                    "output": "expected output"
+                },
+                ...
+            ],
+            ...
+        }
+        """
+        try:
+            response = requests.get(
+                "https://huggingface.co/datasets/Kazchoko/my_dataset/resolve/main/test_cases_by_qid.json"
+            )
+            if response.status_code == 200:
+                return response.json()
+        except Exception as e:
+            print(f"Failed to load test cases from URL: {e}")
+        return {}
helm/benchmark/scenarios/dischargeme_scenario.py

@@ -1,5 +1,5 @@
 from typing import List
-from helm.common.general import
+from helm.common.general import check_file_exists
 from helm.benchmark.scenarios.scenario import (
     Input,
     Scenario,
@@ -21,26 +21,34 @@ def file_preprocessing(data_path: str, task_objective: str) -> pd.DataFrame:
     data_path is directory that contains the downloaded files: '{base_dir}/physionet.org/'
     """
     # Load the first CSV file
-
-
+    diagnosis_path = f"{data_path}/files/discharge-me/1.3/test_phase_1/diagnosis.csv.gz"
+    check_file_exists(
+        diagnosis_path, msg=f"[DischargeMeScenario] Required diagnosis file not found: '{diagnosis_path}'"
     )
-
-
+    discharge_path = f"{data_path}/files/discharge-me/1.3/test_phase_1/discharge.csv.gz"
+    check_file_exists(
+        discharge_path, msg=f"[DischargeMeScenario] Required discharge file not found: '{discharge_path}'"
     )
+    target_path = f"{data_path}/files/discharge-me/1.3/test_phase_1/discharge_target.csv.gz"
+    check_file_exists(target_path, msg=f"[DischargeMeScenario] Required target file not found: '{target_path}'")
+    radiology_path = f"{data_path}/files/discharge-me/1.3/test_phase_1/radiology.csv.gz"
+    check_file_exists(
+        radiology_path, msg=f"[DischargeMeScenario] Required radiology file not found: '{radiology_path}'"
+    )
+    ed_path = f"{data_path}/files/discharge-me/1.3/test_phase_1/edstays.csv.gz"
+    check_file_exists(ed_path, msg=f"[DischargeMeScenario] Required ed file not found: '{ed_path}'")
+    triage_path = f"{data_path}/files/discharge-me/1.3/test_phase_1/triage.csv.gz"
+    check_file_exists(triage_path, msg=f"[DischargeMeScenario] Required triage file not found: '{triage_path}'")
+    df_diagnosis = pd.read_csv(diagnosis_path, compression="gzip", keep_default_na=False)
+    df_discharge = pd.read_csv(discharge_path, compression="gzip", keep_default_na=False)
     df_target = pd.read_csv(
-
+        target_path,
         compression="gzip",
         keep_default_na=False,
     )
-    df_radiology = pd.read_csv(
-
-    )
-    df_ed = pd.read_csv(
-        f"{data_path}/files/discharge-me/1.3/test_phase_1/edstays.csv.gz", compression="gzip", keep_default_na=False
-    )
-    df_triage = pd.read_csv(
-        f"{data_path}/files/discharge-me/1.3/test_phase_1/triage.csv.gz", compression="gzip", keep_default_na=False
-    )
+    df_radiology = pd.read_csv(radiology_path, compression="gzip", keep_default_na=False)
+    df_ed = pd.read_csv(ed_path, compression="gzip", keep_default_na=False)
+    df_triage = pd.read_csv(triage_path, compression="gzip", keep_default_na=False)
     df_diagnosis_triage = pd.merge(
         df_diagnosis, df_triage, on="subject_id", how="inner", suffixes=("_df_diagnosis", "_df_triage")
     )
@@ -113,16 +121,23 @@ class DischargeMeScenario(Scenario):
     """

     name = "dischargeme"
-    description =
-
+    description = (
+        "DischargeMe is a benchmark designed to evaluate clinical text generation. It pairs"
+        "discharge summaries and radiology reports from MIMIC-IV with generation tasks"
+        "such as writing discharge instructions or summarizing the brief hospital course. The"
+        "benchmark assesses a model's ability to generate patient-facing documentation that is"
+        "complete, empathetic, and clinically accurate."
+    )
     tags = ["biomedical"]

+    def __init__(self, data_path: str):
+        super().__init__()
+        self.data_path = data_path
+
     def get_instances(self, output_path: str) -> List[Instance]:
-        data_path = "/share/pi/nigam/data/physionet.org"
-        ensure_directory_exists(data_path)
         instances: List[Instance] = []
-        df_bhc = file_preprocessing(data_path, "brief_hospital_course")
-        df_di = file_preprocessing(data_path, "discharge_instructions")
+        df_bhc = file_preprocessing(self.data_path, "brief_hospital_course")
+        df_di = file_preprocessing(self.data_path, "discharge_instructions")

         for i in range(df_bhc.shape[0]):
             prompt_bhc = create_prompt(
helm/benchmark/scenarios/ehr_sql_scenario.py

@@ -36,7 +36,13 @@ class EhrSqlScenario(Scenario):
     )

     name = "ehr_sql"
-    description =
+    description = (
+        "EHRSQL is a benchmark designed to evaluate models on generating structured queries"
+        "for clinical research. Each example includes a natural language question and a database"
+        "schema, and the task is to produce an SQL query that would return the correct result"
+        "for a biomedical research objective. This benchmark assesses a model's understanding"
+        "of medical terminology, data structures, and query construction."
+    )
     tags = ["sql", "medical", "reasoning"]

     def setup_database(self, output_path: str) -> str:
helm/benchmark/scenarios/ehrshot_scenario.py

@@ -3,12 +3,11 @@ import os
 import pandas as pd
 import tiktoken

-from filelock import FileLock
 from functools import partial
 from tqdm import tqdm
 from typing import Any, Dict, List, Optional, Mapping

-from helm.common.general import ensure_directory_exists
+from helm.common.general import check_file_exists, ensure_directory_exists
 from helm.benchmark.scenarios.scenario import (
     TEST_SPLIT,
     Input,
@@ -1411,7 +1410,10 @@ class EHRSHOTScenario(Scenario):

     name = "ehrshot"
     description = (
-        "
+        "EHRSHOT is a benchmark designed to evaluate a model's ability to predict future"
+        "clinical events using structured EHR data. Each instance contains a patient's"
+        "historical EHR data and a forward-looking clinical question about whether a particular"
+        "diagnosis, lab result, or hospital event will occur."
     )
     tags = []  # TODO

@@ -1420,24 +1422,32 @@ class EHRSHOTScenario(Scenario):
         "no",
     ]

-    def __init__(self, subject: str, max_length: Optional[int] = None):
+    def __init__(self, subject: str, data_path: str, max_length: Optional[int] = None):
         super().__init__()
         self.subject: str = subject  # same as "task" or "labeling_function"
-        self.path_to_meds_dir: str = "/share/pi/nigam/data/medhelm/ehrshot/meds/"
-        self.path_to_tmp_dir: str = "/share/pi/nigam/data/medhelm/ehrshot/prompts/"
         self.max_length = max_length
+        self.data_path = data_path

-    def create_benchmark(self, n_procs: int = 4) -> Dict[str, str]:
+    def create_benchmark(self, output_path: str, n_procs: int = 4) -> Dict[str, str]:
         """Loads the MEDS dataset and converts it to prompts"""
-
         # Load MEDS EHRSHOT patient timelines
-
-
-
+        data_parquet_path = os.path.join(self.data_path, "data/data.parquet")
+        check_file_exists(
+            data_parquet_path, msg=f"[EHRSHOTScenario] Required parquet data file not found: '{data_parquet_path}'"
+        )
+        splits_parquet_path = os.path.join(self.data_path, "metadata/subject_splits.parquet")
+        check_file_exists(
+            splits_parquet_path, msg=f"[EHRSHOTScenario] Required splits file not found: '{splits_parquet_path}'"
+        )
+        df_data = pd.read_parquet(data_parquet_path)
+        df_splits = pd.read_parquet(splits_parquet_path)
         # Load MEDS EHRSHOT labels
-        tasks = sorted(os.listdir(os.path.join(self.
+        tasks = sorted(os.listdir(os.path.join(self.data_path, "labels")))
         for t in tasks:
-            path_to_labels: str = os.path.join(self.
+            path_to_labels: str = os.path.join(self.data_path, "labels", t, "labels.parquet")
+            check_file_exists(
+                path_to_labels, msg=f"[EHRSHOTScenario] Required labels file not found: '{path_to_labels}'"
+            )
             if t != self.subject or not os.path.exists(path_to_labels):
                 continue
             df_labels = pd.read_parquet(path_to_labels)
@@ -1470,18 +1480,16 @@ class EHRSHOTScenario(Scenario):
         df_labels["prompt"] = prompts

         # Save to parquet
-        path_to_output_dir: str = os.path.join(
+        path_to_output_dir: str = os.path.join(output_path, self.subject)
         ensure_directory_exists(path_to_output_dir)
         df_labels.to_parquet(os.path.join(path_to_output_dir, "medhelm_prompts.parquet"))
         return {"status": "success"}

     def get_instances(self, output_path: str) -> List[Instance]:
-        path_to_input_csv: str = os.path.join(
-
-
-
-        print(f"Creating benchmark from SCRATCH for {self.subject}...")
-        self.create_benchmark()  # Create benchmark from scratch
+        path_to_input_csv: str = os.path.join(output_path, self.subject, "medhelm_prompts.parquet")
+        if not os.path.exists(path_to_input_csv):
+            print(f"Creating benchmark from SCRATCH for {self.subject}...")
+            self.create_benchmark(output_path=output_path)  # Create benchmark from scratch

         # Load data for this task
         df = pd.read_parquet(path_to_input_csv)
@@ -1509,38 +1517,3 @@ class EHRSHOTScenario(Scenario):
         )

         return instances
-
-
-if __name__ == "__main__":
-    # Generate statistics on prompts
-    from transformers import AutoTokenizer
-
-    tokenizer = AutoTokenizer.from_pretrained("gpt2")
-    tqdm.pandas()
-    n_procs: int = 10
-
-    os.makedirs("./ehrshot_stats", exist_ok=True)
-    for t in TASK_FULL_NAMES.keys():
-        # Skip if already exists
-        if os.path.exists(f"./ehrshot_stats/{t}.txt"):
-            print(f"Skipping {t} because it already exists")
-            continue
-
-        # Create benchmark
-        scenario = EHRSHOTScenario(subject=t)
-        scenario.create_benchmark(n_procs=n_procs)
-        instances = scenario.get_instances("test.csv")
-
-        # Calculate prompt token stats
-        path_to_input_csv = os.path.join(scenario.path_to_tmp_dir, scenario.subject, "medhelm_prompts.parquet")
-        df = pd.read_parquet(path_to_input_csv)
-        df["prompt_n_tokens"] = df["prompt"].progress_apply(lambda x: len(tokenizer.encode(x)))
-        with open(f"./ehrshot_stats/{t}.txt", "w") as f:
-            f.write("-" * 100 + "\n")
-            f.write(f"Task: {t}\n")
-            f.write(f"# of instances: {len(instances)}\n")
-            f.write(f"# of positives: {df['boolean_value'].sum()}\n")
-            f.write(f"Size of splits:\n{df['split'].value_counts()}\n")
-            f.write(f"# tokens per prompt:\n{df['prompt_n_tokens'].describe()}\n")
-            f.write("-" * 100 + "\n")
-        df.to_parquet(os.path.join(scenario.path_to_tmp_dir, scenario.subject, "medhelm_prompts.parquet"))