crfm-helm 0.5.5__py3-none-any.whl → 0.5.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (268)
  1. {crfm_helm-0.5.5.dist-info → crfm_helm-0.5.7.dist-info}/METADATA +74 -53
  2. {crfm_helm-0.5.5.dist-info → crfm_helm-0.5.7.dist-info}/RECORD +262 -182
  3. {crfm_helm-0.5.5.dist-info → crfm_helm-0.5.7.dist-info}/WHEEL +1 -1
  4. helm/benchmark/adaptation/adapters/in_context_learning_adapter.py +3 -3
  5. helm/benchmark/adaptation/adapters/test_adapter.py +4 -4
  6. helm/benchmark/annotation/air_bench_annotator.py +2 -2
  7. helm/benchmark/annotation/bigcodebench_annotator.py +3 -3
  8. helm/benchmark/annotation/bird_sql_annotator.py +2 -2
  9. helm/benchmark/annotation/chw_care_plan_annotator.py +7 -12
  10. helm/benchmark/annotation/ehr_sql_annotator.py +2 -2
  11. helm/benchmark/annotation/helpdesk_call_summarization_annotator.py +7 -7
  12. helm/benchmark/annotation/live_qa_annotator.py +1 -1
  13. helm/benchmark/annotation/mimic_bhc_annotator.py +100 -0
  14. helm/benchmark/annotation/model_as_judge.py +12 -16
  15. helm/benchmark/annotation/omni_math_annotator.py +13 -14
  16. helm/benchmark/annotation/wildbench_annotator.py +9 -9
  17. helm/benchmark/executor.py +11 -12
  18. helm/benchmark/metrics/aci_bench_metrics.py +9 -29
  19. helm/benchmark/metrics/bias_word_lists.py +1 -1
  20. helm/benchmark/metrics/chw_care_plan_metrics.py +10 -30
  21. helm/benchmark/metrics/classification_metrics.py +3 -3
  22. helm/benchmark/metrics/cleva_harms_metrics.py +2 -2
  23. helm/benchmark/metrics/codeinsights_code_efficiency_metrics.py +186 -0
  24. helm/benchmark/metrics/codeinsights_code_evaluation_metrics.py +477 -0
  25. helm/benchmark/metrics/codeinsights_correct_code_metrics.py +366 -0
  26. helm/benchmark/metrics/codeinsights_edge_case_metrics.py +92 -0
  27. helm/benchmark/metrics/codeinsights_metric_specs.py +51 -0
  28. helm/benchmark/metrics/comet_metric.py +1 -1
  29. helm/benchmark/metrics/conv_fin_qa_calc_metrics.py +2 -2
  30. helm/benchmark/metrics/copyright_metrics.py +1 -1
  31. helm/benchmark/metrics/decodingtrust_stereotype_bias_metrics.py +1 -1
  32. helm/benchmark/metrics/dischargeme_metrics.py +9 -29
  33. helm/benchmark/metrics/efficiency_metrics.py +3 -3
  34. helm/benchmark/metrics/evaluate_reference_metrics.py +1 -1
  35. helm/benchmark/metrics/gpt4_audio_refusal_metrics.py +145 -0
  36. helm/benchmark/metrics/ifeval_metrics.py +2 -2
  37. helm/benchmark/metrics/image_generation/clip_score_metrics.py +13 -2
  38. helm/benchmark/metrics/image_generation/fractal_dimension/fractal_dimension_util.py +1 -1
  39. helm/benchmark/metrics/kpi_edgar_metrics.py +121 -0
  40. helm/benchmark/metrics/llm_jury_metrics.py +46 -0
  41. helm/benchmark/metrics/lmkt_metric_specs.py +12 -0
  42. helm/benchmark/metrics/lmkt_metrics.py +47 -0
  43. helm/benchmark/metrics/med_dialog_metrics.py +9 -29
  44. helm/benchmark/metrics/medalign_metrics.py +9 -29
  45. helm/benchmark/metrics/medi_qa_metrics.py +9 -29
  46. helm/benchmark/metrics/medication_qa_metrics.py +10 -30
  47. helm/benchmark/metrics/melt_bias_metric.py +234 -0
  48. helm/benchmark/metrics/melt_bias_word_lists.py +1367 -0
  49. helm/benchmark/metrics/melt_metric_specs.py +43 -0
  50. helm/benchmark/metrics/melt_toxicity_metric.py +107 -0
  51. helm/benchmark/metrics/mental_health_metrics.py +9 -29
  52. helm/benchmark/metrics/metric_service.py +11 -11
  53. helm/benchmark/metrics/mimic_bhc_metrics.py +14 -0
  54. helm/benchmark/metrics/mimic_rrs_metrics.py +9 -29
  55. helm/benchmark/metrics/mtsamples_procedures_metrics.py +9 -29
  56. helm/benchmark/metrics/mtsamples_replicate_metrics.py +9 -29
  57. helm/benchmark/metrics/openai_mrcr_metrics.py +52 -0
  58. helm/benchmark/metrics/ruler_qa_metrics.py +34 -0
  59. helm/benchmark/metrics/starr_patient_instructions_metrics.py +9 -29
  60. helm/benchmark/metrics/summac/model_summac.py +2 -3
  61. helm/benchmark/metrics/summarization_metrics.py +2 -1
  62. helm/benchmark/metrics/tokens/test_openai_token_cost_estimator.py +2 -2
  63. helm/benchmark/metrics/toxicity_metrics.py +2 -2
  64. helm/benchmark/metrics/unitxt_metrics.py +3 -4
  65. helm/benchmark/metrics/vision_language/emd_utils.py +4 -2
  66. helm/benchmark/metrics/vision_language/image_utils.py +2 -2
  67. helm/benchmark/model_deployment_registry.py +16 -26
  68. helm/benchmark/presentation/contamination.py +3 -3
  69. helm/benchmark/presentation/create_plots.py +43 -13
  70. helm/benchmark/presentation/run_display.py +13 -0
  71. helm/benchmark/presentation/schema.py +7 -1
  72. helm/benchmark/presentation/summarize.py +84 -61
  73. helm/benchmark/presentation/test_create_plots.py +4 -1
  74. helm/benchmark/reeval_run.py +3 -4
  75. helm/benchmark/reeval_runner.py +3 -3
  76. helm/benchmark/run.py +84 -73
  77. helm/benchmark/run_expander.py +12 -1
  78. helm/benchmark/run_spec_factory.py +7 -6
  79. helm/benchmark/run_specs/arabic_run_specs.py +73 -0
  80. helm/benchmark/run_specs/audio_run_specs.py +52 -8
  81. helm/benchmark/run_specs/bluex_run_specs.py +40 -0
  82. helm/benchmark/run_specs/classic_run_specs.py +0 -53
  83. helm/benchmark/run_specs/codeinsights_run_specs.py +192 -0
  84. helm/benchmark/run_specs/enterprise_run_specs.py +20 -0
  85. helm/benchmark/run_specs/experimental_run_specs.py +31 -1
  86. helm/benchmark/run_specs/healthqa_br_run_specs.py +40 -0
  87. helm/benchmark/run_specs/heim_run_specs.py +3 -1
  88. helm/benchmark/run_specs/lmkt_run_specs.py +144 -0
  89. helm/benchmark/run_specs/long_context_run_specs.py +114 -15
  90. helm/benchmark/run_specs/medhelm_run_specs.py +146 -41
  91. helm/benchmark/run_specs/melt_run_specs.py +783 -0
  92. helm/benchmark/run_specs/multilingual_run_specs.py +50 -0
  93. helm/benchmark/run_specs/speech_disorder_audio_run_specs.py +163 -0
  94. helm/benchmark/run_specs/vlm_run_specs.py +28 -0
  95. helm/benchmark/runner.py +5 -5
  96. helm/benchmark/scenarios/aci_bench_scenario.py +7 -1
  97. helm/benchmark/scenarios/alghafa_scenario.py +126 -0
  98. helm/benchmark/scenarios/arabic_mmlu_scenario.py +78 -0
  99. helm/benchmark/scenarios/aratrust_scenario.py +76 -0
  100. helm/benchmark/scenarios/audio_language/air_bench_chat_scenario.py +3 -1
  101. helm/benchmark/scenarios/audio_language/air_bench_foundation_scenario.py +5 -5
  102. helm/benchmark/scenarios/audio_language/casual_conversations2_scenario.py +1 -1
  103. helm/benchmark/scenarios/audio_language/corebench_scenario.py +77 -0
  104. helm/benchmark/scenarios/audio_language/mustard_scenario.py +1 -1
  105. helm/benchmark/scenarios/audio_language/ultra_suite_asr_classification_scenario.py +104 -0
  106. helm/benchmark/scenarios/audio_language/ultra_suite_asr_transcription_scenario.py +99 -0
  107. helm/benchmark/scenarios/audio_language/ultra_suite_classification_scenario.py +118 -0
  108. helm/benchmark/scenarios/audio_language/ultra_suite_disorder_breakdown_scenario.py +86 -0
  109. helm/benchmark/scenarios/audio_language/ultra_suite_disorder_symptoms_scenario.py +117 -0
  110. helm/benchmark/scenarios/audio_language/vocal_sound_scenario.py +15 -1
  111. helm/benchmark/scenarios/audio_language/voxceleb2_scenario.py +1 -2
  112. helm/benchmark/scenarios/autobencher_capabilities_scenario.py +2 -2
  113. helm/benchmark/scenarios/bluex_scenario.py +66 -0
  114. helm/benchmark/scenarios/chw_care_plan_scenario.py +14 -13
  115. helm/benchmark/scenarios/clear_scenario.py +11 -7
  116. helm/benchmark/scenarios/cleva_scenario.py +1 -1
  117. helm/benchmark/scenarios/codeinsights_code_efficiency_scenario.py +197 -0
  118. helm/benchmark/scenarios/codeinsights_correct_code_scenario.py +78 -0
  119. helm/benchmark/scenarios/codeinsights_edge_case_scenario.py +192 -0
  120. helm/benchmark/scenarios/codeinsights_student_coding_scenario.py +162 -0
  121. helm/benchmark/scenarios/codeinsights_student_mistake_scenario.py +188 -0
  122. helm/benchmark/scenarios/dischargeme_scenario.py +36 -21
  123. helm/benchmark/scenarios/ehr_sql_scenario.py +7 -1
  124. helm/benchmark/scenarios/ehrshot_scenario.py +28 -55
  125. helm/benchmark/scenarios/exams_multilingual_scenario.py +115 -0
  126. helm/benchmark/scenarios/grammar.py +2 -2
  127. helm/benchmark/scenarios/headqa_scenario.py +6 -1
  128. helm/benchmark/scenarios/healthqa_br_scenario.py +80 -0
  129. helm/benchmark/scenarios/infinite_bench_en_mc_scenario.py +90 -0
  130. helm/benchmark/scenarios/infinite_bench_en_qa_scenario.py +85 -0
  131. helm/benchmark/scenarios/{infinite_bench_sum_scenario.py → infinite_bench_en_sum_scenario.py} +10 -13
  132. helm/benchmark/scenarios/kpi_edgar_scenario.py +151 -0
  133. helm/benchmark/scenarios/lmkt_scenarios.py +288 -0
  134. helm/benchmark/scenarios/math_scenario.py +21 -20
  135. helm/benchmark/scenarios/med_dialog_scenario.py +6 -1
  136. helm/benchmark/scenarios/medalign_scenario.py +9 -3
  137. helm/benchmark/scenarios/medalign_scenario_helper.py +27 -130
  138. helm/benchmark/scenarios/medbullets_scenario.py +7 -2
  139. helm/benchmark/scenarios/medcalc_bench_scenario.py +4 -2
  140. helm/benchmark/scenarios/medec_scenario.py +6 -1
  141. helm/benchmark/scenarios/medhallu_scenario.py +7 -1
  142. helm/benchmark/scenarios/medi_qa_scenario.py +10 -4
  143. helm/benchmark/scenarios/medication_qa_scenario.py +7 -1
  144. helm/benchmark/scenarios/melt_ir_scenario.py +171 -0
  145. helm/benchmark/scenarios/melt_knowledge_scenario.py +246 -0
  146. helm/benchmark/scenarios/melt_lm_scenarios.py +252 -0
  147. helm/benchmark/scenarios/melt_scenarios.py +793 -0
  148. helm/benchmark/scenarios/melt_srn_scenario.py +342 -0
  149. helm/benchmark/scenarios/melt_synthetic_reasoning_scenario.py +222 -0
  150. helm/benchmark/scenarios/melt_translation_scenario.py +152 -0
  151. helm/benchmark/scenarios/mental_health_scenario.py +16 -5
  152. helm/benchmark/scenarios/mimic_bhc_scenario.py +13 -8
  153. helm/benchmark/scenarios/mimic_rrs_scenario.py +17 -8
  154. helm/benchmark/scenarios/mimiciv_billing_code_scenario.py +14 -8
  155. helm/benchmark/scenarios/mmlu_pro_scenario.py +1 -1
  156. helm/benchmark/scenarios/mmmlu_scenario.py +85 -0
  157. helm/benchmark/scenarios/mtsamples_procedures_scenario.py +5 -2
  158. helm/benchmark/scenarios/mtsamples_replicate_scenario.py +3 -2
  159. helm/benchmark/scenarios/n2c2_ct_matching_scenario.py +11 -5
  160. helm/benchmark/scenarios/openai_mrcr_scenario.py +79 -0
  161. helm/benchmark/scenarios/pubmed_qa_scenario.py +6 -1
  162. helm/benchmark/scenarios/race_based_med_scenario.py +18 -8
  163. helm/benchmark/scenarios/ruler_qa_scenario_helper.py +2 -2
  164. helm/benchmark/scenarios/ruler_qa_scenarios.py +2 -2
  165. helm/benchmark/scenarios/seahelm_scenario.py +2 -2
  166. helm/benchmark/scenarios/shc_bmt_scenario.py +12 -6
  167. helm/benchmark/scenarios/shc_cdi_scenario.py +11 -6
  168. helm/benchmark/scenarios/shc_conf_scenario.py +12 -6
  169. helm/benchmark/scenarios/shc_ent_scenario.py +11 -6
  170. helm/benchmark/scenarios/shc_gip_scenario.py +13 -5
  171. helm/benchmark/scenarios/shc_privacy_scenario.py +78 -0
  172. helm/benchmark/scenarios/shc_proxy_scenario.py +76 -0
  173. helm/benchmark/scenarios/shc_ptbm_scenario.py +12 -7
  174. helm/benchmark/scenarios/shc_sei_scenario.py +12 -7
  175. helm/benchmark/scenarios/shc_sequoia_scenario.py +13 -5
  176. helm/benchmark/scenarios/starr_patient_instructions_scenario.py +15 -8
  177. helm/benchmark/scenarios/test_alghafa_scenario.py +29 -0
  178. helm/benchmark/scenarios/test_aratrust_scenario.py +21 -0
  179. helm/benchmark/scenarios/test_bluex_scenario.py +59 -0
  180. helm/benchmark/scenarios/test_exams_multilingual_scenario.py +29 -0
  181. helm/benchmark/scenarios/test_healtha_br_scenario.py +57 -0
  182. helm/benchmark/scenarios/test_infinite_bench_en_qa_scenario.py +18 -0
  183. helm/benchmark/scenarios/test_infinite_bench_en_sum_scenario.py +31 -0
  184. helm/benchmark/scenarios/truthful_qa_scenario.py +2 -1
  185. helm/benchmark/scenarios/vision_language/msr_vtt_scenario.py +75 -0
  186. helm/benchmark/server.py +2 -1
  187. helm/benchmark/slurm_jobs.py +1 -2
  188. helm/benchmark/slurm_runner.py +8 -1
  189. helm/benchmark/static/schema_arabic.yaml +228 -0
  190. helm/benchmark/static/schema_audio.yaml +60 -49
  191. helm/benchmark/static/schema_classic.yaml +0 -17
  192. helm/benchmark/static/schema_enterprise.yaml +21 -0
  193. helm/benchmark/static/schema_long_context.yaml +81 -20
  194. helm/benchmark/static/schema_medhelm.yaml +272 -213
  195. helm/benchmark/static/schema_melt.yaml +1257 -0
  196. helm/benchmark/static/schema_slphelm.yaml +162 -0
  197. helm/benchmark/static/schema_vhelm.yaml +26 -26
  198. helm/benchmark/static/schema_video.yaml +219 -0
  199. helm/benchmark/static_build/assets/index-b9779128.css +1 -0
  200. helm/benchmark/static_build/assets/index-e439d5e1.js +10 -0
  201. helm/benchmark/static_build/assets/medhelm-overview-eac29843.png +0 -0
  202. helm/benchmark/static_build/assets/{tremor-9cefc3c5.js → tremor-38a10867.js} +1 -1
  203. helm/benchmark/static_build/index.html +4 -4
  204. helm/benchmark/window_services/encoder_decoder_window_service.py +3 -3
  205. helm/benchmark/window_services/image_generation/clip_window_service.py +1 -3
  206. helm/benchmark/window_services/test_utils.py +3 -4
  207. helm/benchmark/window_services/tokenizer_service.py +7 -8
  208. helm/clients/anthropic_client.py +69 -29
  209. helm/clients/audio_language/diva_llama_client.py +4 -2
  210. helm/clients/audio_language/qwen2_5_omni_client.py +209 -0
  211. helm/clients/audio_language/qwen2_audiolm_client.py +8 -6
  212. helm/clients/audio_language/qwen_audiolm_client.py +4 -2
  213. helm/clients/audio_language/test.py +62 -0
  214. helm/clients/bedrock_client.py +3 -1
  215. helm/clients/client.py +7 -7
  216. helm/clients/grok_client.py +36 -0
  217. helm/clients/huggingface_client.py +42 -3
  218. helm/clients/huggingface_pipeline_client.py +138 -0
  219. helm/clients/image_generation/dalle_mini/model/configuration.py +1 -1
  220. helm/clients/image_generation/dalle_mini/model/modeling.py +1 -1
  221. helm/clients/image_generation/dalle_mini/model/processor.py +1 -1
  222. helm/clients/image_generation/dalle_mini/model/tokenizer.py +1 -1
  223. helm/clients/openai_client.py +102 -55
  224. helm/clients/openai_responses_client.py +176 -0
  225. helm/clients/palmyra_client.py +2 -5
  226. helm/clients/reka_client.py +2 -2
  227. helm/clients/test_huggingface_client.py +3 -3
  228. helm/clients/together_client.py +31 -6
  229. helm/clients/vertexai_client.py +17 -9
  230. helm/clients/vision_language/huggingface_vision2seq_client.py +6 -4
  231. helm/clients/vision_language/huggingface_vlm_client.py +2 -2
  232. helm/clients/vision_language/idefics_client.py +6 -2
  233. helm/clients/vision_language/paligemma_client.py +2 -2
  234. helm/clients/vision_language/qwen2_vlm_client.py +66 -53
  235. helm/clients/vision_language/qwen_vlm_client.py +7 -5
  236. helm/clients/vllm_client.py +43 -7
  237. helm/clients/vllm_granite_thinking_client.py +56 -0
  238. helm/clients/writer_client.py +102 -0
  239. helm/common/context.py +80 -0
  240. helm/common/credentials_utils.py +5 -5
  241. helm/common/critique_request.py +0 -1
  242. helm/common/general.py +9 -2
  243. helm/common/hierarchical_logger.py +104 -12
  244. helm/common/local_context.py +140 -0
  245. helm/common/object_spec.py +23 -8
  246. helm/common/remote_context.py +61 -0
  247. helm/common/request.py +8 -0
  248. helm/common/test_logging.py +94 -0
  249. helm/config/model_deployments.yaml +995 -45
  250. helm/config/model_metadata.yaml +780 -59
  251. helm/config/tokenizer_configs.yaml +224 -3
  252. helm/proxy/cli.py +4 -2
  253. helm/proxy/critique/mechanical_turk_utils.py +1 -1
  254. helm/proxy/retry.py +5 -0
  255. helm/proxy/services/server_service.py +21 -85
  256. helm/tokenizers/grok_tokenizer.py +55 -0
  257. helm/tokenizers/huggingface_tokenizer.py +1 -1
  258. helm/tokenizers/test_grok_tokenizer.py +33 -0
  259. helm/benchmark/metrics/numeracy_metrics.py +0 -72
  260. helm/benchmark/metrics/test_numeracy_metrics.py +0 -95
  261. helm/benchmark/scenarios/numeracy_scenario.py +0 -793
  262. helm/benchmark/scenarios/test_infinite_bench_sum_scenario.py +0 -46
  263. helm/benchmark/static_build/assets/index-262903c1.js +0 -10
  264. helm/benchmark/static_build/assets/index-42060d71.css +0 -1
  265. {crfm_helm-0.5.5.dist-info → crfm_helm-0.5.7.dist-info}/entry_points.txt +0 -0
  266. {crfm_helm-0.5.5.dist-info → crfm_helm-0.5.7.dist-info}/licenses/LICENSE +0 -0
  267. {crfm_helm-0.5.5.dist-info → crfm_helm-0.5.7.dist-info}/top_level.txt +0 -0
  268. /helm/benchmark/static_build/assets/{medhelm-overview-3ddfcd65.png → medhelm-v1-overview-3ddfcd65.png} +0 -0
helm/benchmark/scenarios/codeinsights_student_coding_scenario.py (new file)
@@ -0,0 +1,162 @@
+ from helm.benchmark.scenarios.scenario import Scenario, Instance, Input, Output, Reference, VALID_SPLIT, CORRECT_TAG
+ import pandas as pd
+ import requests
+
+
+ class CodeInsightsStudentCodingScenario(Scenario):
+     name = "codeinsights_student_coding"
+     description = "Mimic student C++ style on foundational questions"
+     tags = ["codeinsights", "c++", "student_coding"]
+
+     def __init__(self, num_testcases: int = 1):
+         super().__init__()
+         self.num_testcases = num_testcases
+
+     def get_instances(self, output_path: str):
+         df = pd.read_csv("https://huggingface.co/datasets/Kazchoko/my_dataset/resolve/main/Scenario1_2_data.csv")
+         student_topic = pd.read_csv(
+             "https://huggingface.co/datasets/Kazchoko/my_dataset/resolve/main/student_performace_by_topic.csv"
+         )
+
+         instances = []
+         for student_id, student_df in df.groupby("student_id"):
+             student_df = student_df.sort_values("timestamp")
+             if len(student_df) < 4:
+                 continue
+             first = student_df.iloc[0]
+             second = student_df.iloc[1]
+             third = student_df.iloc[2]
+             target = student_df.iloc[3]
+
+             # Get test cases for this question
+             question_id = target.get("question_unittest_id", None)
+             question_test_cases = []
+             tc_parsing_success = True
+
+             for testcase_str in target["question_unittests"].split("Unittest")[1:]:
+                 testcase_str = testcase_str[testcase_str.find(":") + 1 :]
+                 input_idx = testcase_str.find("Input:")
+                 std_in_idx = testcase_str.find("STD input:")
+                 output_idx = testcase_str.find("Output:")
+                 if input_idx == -1 or std_in_idx == -1 or output_idx == -1:
+                     tc_parsing_success = False
+                     break
+
+                 testcase = {
+                     "input": testcase_str[input_idx + 6 : std_in_idx].strip(),
+                     "std_in": testcase_str[std_in_idx + 10 : output_idx].strip(),
+                     "output": testcase_str[output_idx + 7 :].strip(),
+                 }
+                 question_test_cases.append(testcase)
+
+             if not tc_parsing_success:
+                 continue
+
+             if len(question_test_cases) < self.num_testcases:
+                 # If not enough test cases, skip this question
+                 continue
+             if self.num_testcases >= 0:
+                 # If more than one test case is requested, only take the first ones
+                 question_test_cases = question_test_cases[: self.num_testcases]
+
+             # Get student pass (0 or 1) for the target question
+             student_correctness_pattern = target.get("pass", None)
+             main_part = int(student_correctness_pattern)  # "1111111111"
+             # Convert each character to an int
+             student_correctness_list = [int(ch) for ch in str(main_part)]  # [1,1,1,1,1,1,1,1,1,1]
+
+             # Student specific topic performance in previous attempts
+             student_level_prompt = f"Student {student_id} has the following performance across topics:\n"
+             topic_performance = student_topic[student_topic["student_id"] == student_id]
+             for _, row in topic_performance.iterrows():
+                 topic = row["topic"]
+                 pass_rate = round(row["pass_rate"], 2)
+                 perfect = round(row["perfect"], 2)
+
+                 student_level_prompt += (
+                     f"- For topic '{topic}', the unit test pass rate is {pass_rate}, "
+                     f"and the rate of passing all unit tests is {perfect}.\n"
+                 )
+
+             prompt = (
+                 "=== Student Profile ===\n"
+                 f"{student_level_prompt}\n"
+                 f"Week: {target['week']}\n"
+                 f"Topic: {target['topic']}\n\n"
+                 "Example 1:\n"
+                 f"Question: {first['question_name']} — {first['question_text']}\n"
+                 "Template:\n"
+                 f"{first['question_template']}\n"
+                 "Your Code:\n"
+                 f"{first['response']}\n\n"
+                 "Example 2:\n"
+                 f"Question: {second['question_name']} — {second['question_text']}\n"
+                 "Template:\n"
+                 f"{second['question_template']}\n"
+                 "Your Code:\n"
+                 f"{second['response']}\n\n"
+                 "Example 3:\n"
+                 f"Question: {third['question_name']} — {third['question_text']}\n"
+                 "Template:\n"
+                 f"{third['question_template']}\n"
+                 "Your Code:\n"
+                 f"{third['response']}\n\n"
+                 "Now, using that same student style, attempt this:\n"
+                 f"Question: {target['question_name']} — {target['question_text']}\n"
+                 f"Unit Test Input: {question_test_cases}\n\n"
+                 if question_test_cases
+                 else ""
+                 "Template:\n"
+                 f"{target['question_template']}\n\n"
+                 "Provide ONLY your C++ implementation following the given template, where the answer will replace the {{ STUDENT_ANSWER }} block in the template. "
+                 "DO NOT reproduce the template part as the generated code would be inserted to the template, "
+                 "and make sure the code is compatible with the Unit Test Input. "
+                 "int main() is always declared already so DO NOT produce that initialization on the code. "
+                 "Ensure your code includes any class definition when needed. "
+                 "Return the code in C++ code block format, and nothing else."
+             )
+             instances.append(
+                 Instance(
+                     id=f"{student_id}_{target['question_unittest_id']}",
+                     input=Input(text=prompt),
+                     references=[Reference(output=Output(text=target["response"]), tags=[CORRECT_TAG])],
+                     extra_data={
+                         "question_template": target["question_template"],
+                         "test_cases": question_test_cases,
+                         "question_id": str(question_id) if question_id else None,
+                         "question_name": target.get("question_name", ""),
+                         "student_id": str(student_id),
+                         "student_correctness_pattern": student_correctness_list,
+                     },
+                     split=VALID_SPLIT,
+                 )
+             )
+         return instances
+
+     def _load_test_cases(self):
+         """
+         Load test cases from external source or return None if not available.
+         This method should be implemented based on where your test cases are stored.
+
+         Expected format:
+         {
+             "question_id": [
+                 {
+                     "unittest": "test_id",
+                     "input": "test input code",
+                     "output": "expected output"
+                 },
+                 ...
+             ],
+             ...
+         }
+         """
+         try:
+             response = requests.get(
+                 "https://huggingface.co/datasets/Kazchoko/my_dataset/resolve/main/test_cases_by_qid.json"
+             )
+             if response.status_code == 200:
+                 return response.json()
+         except Exception as e:
+             print(f"Failed to load test cases from URL: {e}")
+         return {}
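The parser in get_instances above assumes a fixed layout for the question_unittests column: each test case is introduced by the token "Unittest" and contains "Input:", "STD input:", and "Output:" sections. A minimal standalone sketch of that extraction, run on a made-up sample string (the sample text and expected values are illustrative only, not taken from the dataset):

    # Illustrative sketch of the test-case extraction used by the CodeInsights scenarios.
    # The raw string below is hypothetical; only the markers ("Unittest", "Input:",
    # "STD input:", "Output:") come from the parser in the diff above.
    raw = (
        "Unittest 1: Input: add(2, 3) STD input:  Output: 5"
        "Unittest 2: Input: add(-1, 1) STD input:  Output: 0"
    )

    test_cases = []
    for chunk in raw.split("Unittest")[1:]:
        chunk = chunk[chunk.find(":") + 1 :]  # drop the leading "N:" label
        input_idx = chunk.find("Input:")
        std_in_idx = chunk.find("STD input:")
        output_idx = chunk.find("Output:")
        if input_idx == -1 or std_in_idx == -1 or output_idx == -1:
            continue  # malformed block; the scenario skips the whole question in this case
        test_cases.append(
            {
                "input": chunk[input_idx + 6 : std_in_idx].strip(),
                "std_in": chunk[std_in_idx + 10 : output_idx].strip(),
                "output": chunk[output_idx + 7 :].strip(),
            }
        )

    print(test_cases)
    # [{'input': 'add(2, 3)', 'std_in': '', 'output': '5'},
    #  {'input': 'add(-1, 1)', 'std_in': '', 'output': '0'}]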

helm/benchmark/scenarios/codeinsights_student_mistake_scenario.py (new file)
@@ -0,0 +1,188 @@
+ from helm.benchmark.scenarios.scenario import Scenario, Instance, Input, Output, Reference, VALID_SPLIT, CORRECT_TAG
+ import pandas as pd
+ import requests
+
+
+ class CodeInsightsStudentMistakeScenario(Scenario):
+     name = "codeinsights_student_mistake"
+     description = "Mimic how students mistake their C++ codes on foundational questions"
+     tags = ["codeinsights", "c++", "student_mistake"]
+
+     def __init__(self, num_testcases: int = 1):
+         super().__init__()
+         self.num_testcases = num_testcases
+
+     def get_instances(self, output_path: str):
+         df = pd.read_csv("https://huggingface.co/datasets/Kazchoko/my_dataset/resolve/main/Scenario3_data.csv")
+         student_topic = pd.read_csv(
+             "https://huggingface.co/datasets/Kazchoko/my_dataset/resolve/main/student_performace_by_topic.csv"
+         )
+
+         instances = []
+         for student_id, student_df in df.groupby("student_id"):
+             student_df = student_df.sort_values(by=["student_id", "question_unittest_id", "timestamp"])
+             if len(student_df) < 4:
+                 continue
+             first = student_df.iloc[0]
+             second = student_df.iloc[1]
+             third = student_df.iloc[2]
+             target = student_df.iloc[3]
+
+             # Get test cases for this question
+             question_id = target.get("question_unittest_id", None)
+             question_test_cases = []
+             tc_parsing_success = True
+
+             for testcase_str in target["question_unittests"].split("Unittest")[1:]:
+                 testcase_str = testcase_str[testcase_str.find(":") + 1 :]
+                 input_idx = testcase_str.find("Input:")
+                 std_in_idx = testcase_str.find("STD input:")
+                 output_idx = testcase_str.find("Output:")
+                 if input_idx == -1 or std_in_idx == -1 or output_idx == -1:
+                     tc_parsing_success = False
+                     break
+
+                 testcase = {
+                     "input": testcase_str[input_idx + 6 : std_in_idx].strip(),
+                     "std_in": testcase_str[std_in_idx + 10 : output_idx].strip(),
+                     "output": testcase_str[output_idx + 7 :].strip(),
+                 }
+                 question_test_cases.append(testcase)
+
+             if not tc_parsing_success:
+                 continue
+
+             if len(question_test_cases) < self.num_testcases:
+                 # If not enough test cases, skip this question
+                 continue
+             if self.num_testcases >= 0:
+                 # If more than one test case is requested, only take the first ones
+                 question_test_cases = question_test_cases[: self.num_testcases]
+
+             # Get student pass (0 or 1) for the target question
+             student_correctness_pattern = target.get("pass", None)
+             main_part = int(student_correctness_pattern)  # "1111111111"
+             # Convert each character to an int
+             student_correctness_list = [int(ch) for ch in str(main_part)]  # [1,1,1,1,1,1,1,1,1,1]
+
+             # Student specific topic performance in previous attempts
+             student_level_prompt = f"Student {student_id} has the following performance across topics:\n"
+             topic_performance = student_topic[student_topic["student_id"] == student_id]
+             for _, row in topic_performance.iterrows():
+                 topic = row["topic"]
+                 pass_rate = round(row["pass_rate"], 2)
+                 perfect = round(row["perfect"], 2)
+
+                 student_level_prompt += (
+                     f"- For topic '{topic}', the unit test pass rate is {pass_rate}, "
+                     f"and the rate of passing all unit tests is {perfect}.\n"
+                 )
+
+             prompt = (
+                 "=== Student Profile ===\n"
+                 f"{student_level_prompt}\n"
+                 "When students submit a code to the platform, it will be tested by number of unit tests, where"
+                 "- Unit test pass rate = proportion of unit tests passed with the code \n"
+                 "- Full pass rate = proportion of code passing all unit tests\n\n"
+                 "=== Past Mistake Examples ===\n"
+                 "Example 1 (Week {first['week']}, Topic: {first['topic']}):\n"
+                 f"Question: {first['question_name']} — {first['question_text']}\n"
+                 "Template:\n"
+                 f"{first['question_template']}\n"
+                 "Student's Response Code with Error:\n"
+                 f"{first['response_mistake']}\n\n"
+                 "Example 2 (Week {second['week']}, Topic: {second['topic']}):\n"
+                 f"Question: {second['question_name']} — {second['question_text']}\n"
+                 "Template:\n"
+                 f"{second['question_template']}\n"
+                 "Student's Response Code with Error:\n"
+                 f"{second['response_mistake']}\n\n"
+                 "Example 3 (Week {third['week']}, Topic: {third['topic']}):\n"
+                 f"Question: {third['question_name']} — {third['question_text']}\n"
+                 "Template:\n"
+                 f"{third['question_template']}\n"
+                 "Student's Response Code with Error:\n"
+                 f"{third['response_mistake']}\n\n"
+                 "=== New Target Problem ===\n"
+                 f"Week: {target['week']}, Topic: {target['topic']}\n"
+                 f"Question: {target['question_name']} — {target['question_text']}\n"
+                 f"Unit Test Input: {question_test_cases}\n\n"
+                 if question_test_cases
+                 else ""
+                 "Template:\n"
+                 f"{target['question_template']}\n\n"
+                 "⚠**Instructions:**\n"
+                 "1. Mimic your own coding style, naming conventions, indentation, and typical error patterns.\n"
+                 "2. Introduce mistake you are likely to make (e.g., off‐by‐one index, wrong initialization, "
+                 "missing edge case).\n"
+                 "3. Do **not** produce a fully correct solution or add unfamiliar optimizations.\n\n"
+                 "Provide ONLY your C++ implementation following the given template, where the answer will replace the {{ STUDENT_ANSWER }} block in the template. "
+                 "DO NOT reproduce the template part as the generated code would be inserted to the template, "
+                 "and make sure the code is compatible with the Unit Test Input. "
+                 "int main() is always declared already so DO NOT produce that initialization on the code. "
+                 "Ensure your code is includes any class definition when needed. "
+                 "Return the code in C++ code block format, and nothing else."
+             )
+
+             print(f"\n=== DEBUG INFO FOR STUDENT {student_id}, QUESTION {question_id} ===")
+             print(f"Test cases loaded: {len(question_test_cases)}")
+             print(f"Student correctness pattern: {student_correctness_list}")
+             print(f"Original pass field: {target.get('pass', 'MISSING')}")
+             print(f"Question template exists: {'question_template' in target}")
+             print(f"Question name: {target.get('question_name', 'MISSING')}")
+
+             # Also add this validation in your UnitTestAlignmentMetric evaluate_generation method:
+             def evaluate_generation(self, adapter_spec, request_state, metric_service, eval_cache_path):
+                 print("\n=== UNIT TEST METRIC DEBUG ===")
+                 print(f"Has extra_data: {hasattr(request_state.instance, 'extra_data')}")
+                 if hasattr(request_state.instance, "extra_data"):
+                     extra_data = request_state.instance.extra_data
+                     print(f"Extra data keys: {list(extra_data.keys())}")
+                     print(f"Test cases: {len(extra_data.get('test_cases', []))}")
+                     print(f"Student pattern: {extra_data.get('student_correctness_pattern', 'MISSING')}")
+
+             instances.append(
+                 Instance(
+                     id=f"{student_id}_{target['question_unittest_id']}",
+                     input=Input(text=prompt),
+                     references=[Reference(output=Output(text=target["response_mistake"]), tags=[CORRECT_TAG])],
+                     extra_data={
+                         "question_template": target["question_template"],
+                         "test_cases": question_test_cases,
+                         "question_id": str(question_id) if question_id else None,
+                         "question_name": target.get("question_name", ""),
+                         "student_id": str(student_id),
+                         "student_correctness_pattern": student_correctness_list,
+                     },
+                     split=VALID_SPLIT,
+                 )
+             )
+         return instances
+
+     def _load_test_cases(self):
+         """
+         Load test cases from external source or return None if not available.
+         This method should be implemented based on where your test cases are stored.
+
+         Expected format:
+         {
+             "question_id": [
+                 {
+                     "unittest": "test_id",
+                     "input": "test input code",
+                     "output": "expected output"
+                 },
+                 ...
+             ],
+             ...
+         }
+         """
+         try:
+             response = requests.get(
+                 "https://huggingface.co/datasets/Kazchoko/my_dataset/resolve/main/test_cases_by_qid.json"
+             )
+             if response.status_code == 200:
+                 return response.json()
+         except Exception as e:
+             print(f"Failed to load test cases from URL: {e}")
+         return {}
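Both new scenarios are wired in through helm/benchmark/run_specs/codeinsights_run_specs.py and scored by the new codeinsights metrics listed above. For a quick local check they can also be driven directly, since get_instances only needs network access to the hard-coded Hugging Face URLs. A hypothetical smoke test, not part of the package; the output path is a placeholder and is unused by this scenario:

    from helm.benchmark.scenarios.codeinsights_student_mistake_scenario import (
        CodeInsightsStudentMistakeScenario,
    )

    # Build the instances locally and inspect the parsed unit tests for the first one.
    scenario = CodeInsightsStudentMistakeScenario(num_testcases=1)
    instances = scenario.get_instances(output_path="/tmp/codeinsights")  # placeholder path
    print(len(instances))
    print(instances[0].extra_data["test_cases"])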

helm/benchmark/scenarios/dischargeme_scenario.py
@@ -1,5 +1,5 @@
  from typing import List
- from helm.common.general import ensure_directory_exists
+ from helm.common.general import check_file_exists
  from helm.benchmark.scenarios.scenario import (
      Input,
      Scenario,
@@ -21,26 +21,34 @@ def file_preprocessing(data_path: str, task_objective: str) -> pd.DataFrame:
      data_path is directory that contains the downloaded files: '{base_dir}/physionet.org/'
      """
      # Load the first CSV file
-     df_diagnosis = pd.read_csv(
-         f"{data_path}/files/discharge-me/1.3/test_phase_1/diagnosis.csv.gz", compression="gzip", keep_default_na=False
+     diagnosis_path = f"{data_path}/files/discharge-me/1.3/test_phase_1/diagnosis.csv.gz"
+     check_file_exists(
+         diagnosis_path, msg=f"[DischargeMeScenario] Required diagnosis file not found: '{diagnosis_path}'"
      )
-     df_discharge = pd.read_csv(
-         f"{data_path}/files/discharge-me/1.3/test_phase_1/discharge.csv.gz", compression="gzip", keep_default_na=False
+     discharge_path = f"{data_path}/files/discharge-me/1.3/test_phase_1/discharge.csv.gz"
+     check_file_exists(
+         discharge_path, msg=f"[DischargeMeScenario] Required discharge file not found: '{discharge_path}'"
      )
+     target_path = f"{data_path}/files/discharge-me/1.3/test_phase_1/discharge_target.csv.gz"
+     check_file_exists(target_path, msg=f"[DischargeMeScenario] Required target file not found: '{target_path}'")
+     radiology_path = f"{data_path}/files/discharge-me/1.3/test_phase_1/radiology.csv.gz"
+     check_file_exists(
+         radiology_path, msg=f"[DischargeMeScenario] Required radiology file not found: '{radiology_path}'"
+     )
+     ed_path = f"{data_path}/files/discharge-me/1.3/test_phase_1/edstays.csv.gz"
+     check_file_exists(ed_path, msg=f"[DischargeMeScenario] Required ed file not found: '{ed_path}'")
+     triage_path = f"{data_path}/files/discharge-me/1.3/test_phase_1/triage.csv.gz"
+     check_file_exists(triage_path, msg=f"[DischargeMeScenario] Required triage file not found: '{triage_path}'")
+     df_diagnosis = pd.read_csv(diagnosis_path, compression="gzip", keep_default_na=False)
+     df_discharge = pd.read_csv(discharge_path, compression="gzip", keep_default_na=False)
      df_target = pd.read_csv(
-         f"{data_path}/files/discharge-me/1.3/test_phase_1/discharge_target.csv.gz",
+         target_path,
          compression="gzip",
         keep_default_na=False,
      )
-     df_radiology = pd.read_csv(
-         f"{data_path}/files/discharge-me/1.3/test_phase_1/radiology.csv.gz", compression="gzip", keep_default_na=False
-     )
-     df_ed = pd.read_csv(
-         f"{data_path}/files/discharge-me/1.3/test_phase_1/edstays.csv.gz", compression="gzip", keep_default_na=False
-     )
-     df_triage = pd.read_csv(
-         f"{data_path}/files/discharge-me/1.3/test_phase_1/triage.csv.gz", compression="gzip", keep_default_na=False
-     )
+     df_radiology = pd.read_csv(radiology_path, compression="gzip", keep_default_na=False)
+     df_ed = pd.read_csv(ed_path, compression="gzip", keep_default_na=False)
+     df_triage = pd.read_csv(triage_path, compression="gzip", keep_default_na=False)
      df_diagnosis_triage = pd.merge(
          df_diagnosis, df_triage, on="subject_id", how="inner", suffixes=("_df_diagnosis", "_df_triage")
      )
@@ -113,16 +121,23 @@ class DischargeMeScenario(Scenario):
      """

      name = "dischargeme"
-     description = "DischargeMe is a discharge instruction generation dataset and brief hospital course generation \
-     dataset collected from MIMIC-IV data, consindering only the discharge text as well as the radiology report text."
+     description = (
+         "DischargeMe is a benchmark designed to evaluate clinical text generation. It pairs"
+         "discharge summaries and radiology reports from MIMIC-IV with generation tasks"
+         "such as writing discharge instructions or summarizing the brief hospital course. The"
+         "benchmark assesses a model's ability to generate patient-facing documentation that is"
+         "complete, empathetic, and clinically accurate."
+     )
      tags = ["biomedical"]

+     def __init__(self, data_path: str):
+         super().__init__()
+         self.data_path = data_path
+
      def get_instances(self, output_path: str) -> List[Instance]:
-         data_path = "/share/pi/nigam/data/physionet.org"
-         ensure_directory_exists(data_path)
          instances: List[Instance] = []
-         df_bhc = file_preprocessing(data_path, "brief_hospital_course")
-         df_di = file_preprocessing(data_path, "discharge_instructions")
+         df_bhc = file_preprocessing(self.data_path, "brief_hospital_course")
+         df_di = file_preprocessing(self.data_path, "discharge_instructions")

          for i in range(df_bhc.shape[0]):
              prompt_bhc = create_prompt(
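The DischargeMe changes replace the hard-coded "/share/pi/nigam/data/physionet.org" root with a data_path constructor argument and swap ensure_directory_exists for check_file_exists, which presumably raises with the supplied msg when a required PhysioNet file is absent, so an incomplete download fails fast instead of inside pandas. A hypothetical construction under the new signature (the directory is a placeholder; in practice the MedHELM run specs supply it):

    from helm.benchmark.scenarios.dischargeme_scenario import DischargeMeScenario

    # data_path must point at a local PhysioNet mirror containing
    # files/discharge-me/1.3/test_phase_1/*.csv.gz (placeholder path shown).
    scenario = DischargeMeScenario(data_path="/data/physionet.org")
    instances = scenario.get_instances(output_path="/tmp/dischargeme")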

helm/benchmark/scenarios/ehr_sql_scenario.py
@@ -36,7 +36,13 @@ class EhrSqlScenario(Scenario):
      )

      name = "ehr_sql"
-     description = "Given a natural language instruction, generate an SQL query that would be used in clinical research."
+     description = (
+         "EHRSQL is a benchmark designed to evaluate models on generating structured queries"
+         "for clinical research. Each example includes a natural language question and a database"
+         "schema, and the task is to produce an SQL query that would return the correct result"
+         "for a biomedical research objective. This benchmark assesses a model's understanding"
+         "of medical terminology, data structures, and query construction."
+     )
      tags = ["sql", "medical", "reasoning"]

      def setup_database(self, output_path: str) -> str:

helm/benchmark/scenarios/ehrshot_scenario.py
@@ -3,12 +3,11 @@ import os
  import pandas as pd
  import tiktoken

- from filelock import FileLock
  from functools import partial
  from tqdm import tqdm
  from typing import Any, Dict, List, Optional, Mapping

- from helm.common.general import ensure_directory_exists
+ from helm.common.general import check_file_exists, ensure_directory_exists
  from helm.benchmark.scenarios.scenario import (
      TEST_SPLIT,
      Input,
@@ -1411,7 +1410,10 @@

      name = "ehrshot"
      description = (
-         "A dataset given a patient record of EHR codes, classifying if an event will occur at a future date or not."
+         "EHRSHOT is a benchmark designed to evaluate a model's ability to predict future"
+         "clinical events using structured EHR data. Each instance contains a patient's"
+         "historical EHR data and a forward-looking clinical question about whether a particular"
+         "diagnosis, lab result, or hospital event will occur."
      )
      tags = []  # TODO

@@ -1420,24 +1422,32 @@
          "no",
      ]

-     def __init__(self, subject: str, max_length: Optional[int] = None):
+     def __init__(self, subject: str, data_path: str, max_length: Optional[int] = None):
          super().__init__()
          self.subject: str = subject  # same as "task" or "labeling_function"
-         self.path_to_meds_dir: str = "/share/pi/nigam/data/medhelm/ehrshot/meds/"
-         self.path_to_tmp_dir: str = "/share/pi/nigam/data/medhelm/ehrshot/prompts/"
          self.max_length = max_length
+         self.data_path = data_path

-     def create_benchmark(self, n_procs: int = 4) -> Dict[str, str]:
+     def create_benchmark(self, output_path: str, n_procs: int = 4) -> Dict[str, str]:
          """Loads the MEDS dataset and converts it to prompts"""
-
          # Load MEDS EHRSHOT patient timelines
-         df_data = pd.read_parquet(os.path.join(self.path_to_meds_dir, "data/data.parquet"))
-         df_splits = pd.read_parquet(os.path.join(self.path_to_meds_dir, "metadata/subject_splits.parquet"))
-
+         data_parquet_path = os.path.join(self.data_path, "data/data.parquet")
+         check_file_exists(
+             data_parquet_path, msg=f"[EHRSHOTScenario] Required parquet data file not found: '{data_parquet_path}'"
+         )
+         splits_parquet_path = os.path.join(self.data_path, "metadata/subject_splits.parquet")
+         check_file_exists(
+             splits_parquet_path, msg=f"[EHRSHOTScenario] Required splits file not found: '{splits_parquet_path}'"
+         )
+         df_data = pd.read_parquet(data_parquet_path)
+         df_splits = pd.read_parquet(splits_parquet_path)
          # Load MEDS EHRSHOT labels
-         tasks = sorted(os.listdir(os.path.join(self.path_to_meds_dir, "labels")))
+         tasks = sorted(os.listdir(os.path.join(self.data_path, "labels")))
          for t in tasks:
-             path_to_labels: str = os.path.join(self.path_to_meds_dir, "labels", t, "labels.parquet")
+             path_to_labels: str = os.path.join(self.data_path, "labels", t, "labels.parquet")
+             check_file_exists(
+                 path_to_labels, msg=f"[EHRSHOTScenario] Required labels file not found: '{path_to_labels}'"
+             )
              if t != self.subject or not os.path.exists(path_to_labels):
                  continue
              df_labels = pd.read_parquet(path_to_labels)
@@ -1470,18 +1480,16 @@
          df_labels["prompt"] = prompts

          # Save to parquet
-         path_to_output_dir: str = os.path.join(self.path_to_tmp_dir, self.subject)
+         path_to_output_dir: str = os.path.join(output_path, self.subject)
          ensure_directory_exists(path_to_output_dir)
          df_labels.to_parquet(os.path.join(path_to_output_dir, "medhelm_prompts.parquet"))
          return {"status": "success"}

      def get_instances(self, output_path: str) -> List[Instance]:
-         path_to_input_csv: str = os.path.join(self.path_to_tmp_dir, self.subject, "medhelm_prompts.parquet")
-         lock_path = path_to_input_csv + ".lock"
-         with FileLock(lock_path):
-             if not os.path.exists(path_to_input_csv):
-                 print(f"Creating benchmark from SCRATCH for {self.subject}...")
-                 self.create_benchmark()  # Create benchmark from scratch
+         path_to_input_csv: str = os.path.join(output_path, self.subject, "medhelm_prompts.parquet")
+         if not os.path.exists(path_to_input_csv):
+             print(f"Creating benchmark from SCRATCH for {self.subject}...")
+             self.create_benchmark(output_path=output_path)  # Create benchmark from scratch

          # Load data for this task
          df = pd.read_parquet(path_to_input_csv)
@@ -1509,38 +1517,3 @@
          )

          return instances
-
-
- if __name__ == "__main__":
-     # Generate statistics on prompts
-     from transformers import AutoTokenizer
-
-     tokenizer = AutoTokenizer.from_pretrained("gpt2")
-     tqdm.pandas()
-     n_procs: int = 10
-
-     os.makedirs("./ehrshot_stats", exist_ok=True)
-     for t in TASK_FULL_NAMES.keys():
-         # Skip if already exists
-         if os.path.exists(f"./ehrshot_stats/{t}.txt"):
-             print(f"Skipping {t} because it already exists")
-             continue
-
-         # Create benchmark
-         scenario = EHRSHOTScenario(subject=t)
-         scenario.create_benchmark(n_procs=n_procs)
-         instances = scenario.get_instances("test.csv")
-
-         # Calculate prompt token stats
-         path_to_input_csv = os.path.join(scenario.path_to_tmp_dir, scenario.subject, "medhelm_prompts.parquet")
-         df = pd.read_parquet(path_to_input_csv)
-         df["prompt_n_tokens"] = df["prompt"].progress_apply(lambda x: len(tokenizer.encode(x)))
-         with open(f"./ehrshot_stats/{t}.txt", "w") as f:
-             f.write("-" * 100 + "\n")
-             f.write(f"Task: {t}\n")
-             f.write(f"# of instances: {len(instances)}\n")
-             f.write(f"# of positives: {df['boolean_value'].sum()}\n")
-             f.write(f"Size of splits:\n{df['split'].value_counts()}\n")
-             f.write(f"# tokens per prompt:\n{df['prompt_n_tokens'].describe()}\n")
-             f.write("-" * 100 + "\n")
-         df.to_parquet(os.path.join(scenario.path_to_tmp_dir, scenario.subject, "medhelm_prompts.parquet"))
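The EHRSHOT edits follow the same pattern: the FileLock dependency and the cluster-specific meds/prompts directories are removed, the MEDS root is injected as data_path, the generated prompt parquet is cached under HELM's own output_path, and the standalone __main__ statistics script is dropped. A hypothetical invocation under the new signature (the subject name and paths are placeholders; in practice the MedHELM run specs construct this scenario):

    from helm.benchmark.scenarios.ehrshot_scenario import EHRSHOTScenario

    # "guo_los" is one plausible EHRSHOT labeling task; any directory under
    # <data_path>/labels/ would be handled the same way.
    scenario = EHRSHOTScenario(subject="guo_los", data_path="/data/ehrshot/meds", max_length=8192)
    instances = scenario.get_instances(output_path="/tmp/ehrshot_prompts")  # builds the prompt parquet on first run
    print(len(instances))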