crfm-helm 0.5.5__py3-none-any.whl → 0.5.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of crfm-helm might be problematic.

Files changed (268)
  1. {crfm_helm-0.5.5.dist-info → crfm_helm-0.5.7.dist-info}/METADATA +74 -53
  2. {crfm_helm-0.5.5.dist-info → crfm_helm-0.5.7.dist-info}/RECORD +262 -182
  3. {crfm_helm-0.5.5.dist-info → crfm_helm-0.5.7.dist-info}/WHEEL +1 -1
  4. helm/benchmark/adaptation/adapters/in_context_learning_adapter.py +3 -3
  5. helm/benchmark/adaptation/adapters/test_adapter.py +4 -4
  6. helm/benchmark/annotation/air_bench_annotator.py +2 -2
  7. helm/benchmark/annotation/bigcodebench_annotator.py +3 -3
  8. helm/benchmark/annotation/bird_sql_annotator.py +2 -2
  9. helm/benchmark/annotation/chw_care_plan_annotator.py +7 -12
  10. helm/benchmark/annotation/ehr_sql_annotator.py +2 -2
  11. helm/benchmark/annotation/helpdesk_call_summarization_annotator.py +7 -7
  12. helm/benchmark/annotation/live_qa_annotator.py +1 -1
  13. helm/benchmark/annotation/mimic_bhc_annotator.py +100 -0
  14. helm/benchmark/annotation/model_as_judge.py +12 -16
  15. helm/benchmark/annotation/omni_math_annotator.py +13 -14
  16. helm/benchmark/annotation/wildbench_annotator.py +9 -9
  17. helm/benchmark/executor.py +11 -12
  18. helm/benchmark/metrics/aci_bench_metrics.py +9 -29
  19. helm/benchmark/metrics/bias_word_lists.py +1 -1
  20. helm/benchmark/metrics/chw_care_plan_metrics.py +10 -30
  21. helm/benchmark/metrics/classification_metrics.py +3 -3
  22. helm/benchmark/metrics/cleva_harms_metrics.py +2 -2
  23. helm/benchmark/metrics/codeinsights_code_efficiency_metrics.py +186 -0
  24. helm/benchmark/metrics/codeinsights_code_evaluation_metrics.py +477 -0
  25. helm/benchmark/metrics/codeinsights_correct_code_metrics.py +366 -0
  26. helm/benchmark/metrics/codeinsights_edge_case_metrics.py +92 -0
  27. helm/benchmark/metrics/codeinsights_metric_specs.py +51 -0
  28. helm/benchmark/metrics/comet_metric.py +1 -1
  29. helm/benchmark/metrics/conv_fin_qa_calc_metrics.py +2 -2
  30. helm/benchmark/metrics/copyright_metrics.py +1 -1
  31. helm/benchmark/metrics/decodingtrust_stereotype_bias_metrics.py +1 -1
  32. helm/benchmark/metrics/dischargeme_metrics.py +9 -29
  33. helm/benchmark/metrics/efficiency_metrics.py +3 -3
  34. helm/benchmark/metrics/evaluate_reference_metrics.py +1 -1
  35. helm/benchmark/metrics/gpt4_audio_refusal_metrics.py +145 -0
  36. helm/benchmark/metrics/ifeval_metrics.py +2 -2
  37. helm/benchmark/metrics/image_generation/clip_score_metrics.py +13 -2
  38. helm/benchmark/metrics/image_generation/fractal_dimension/fractal_dimension_util.py +1 -1
  39. helm/benchmark/metrics/kpi_edgar_metrics.py +121 -0
  40. helm/benchmark/metrics/llm_jury_metrics.py +46 -0
  41. helm/benchmark/metrics/lmkt_metric_specs.py +12 -0
  42. helm/benchmark/metrics/lmkt_metrics.py +47 -0
  43. helm/benchmark/metrics/med_dialog_metrics.py +9 -29
  44. helm/benchmark/metrics/medalign_metrics.py +9 -29
  45. helm/benchmark/metrics/medi_qa_metrics.py +9 -29
  46. helm/benchmark/metrics/medication_qa_metrics.py +10 -30
  47. helm/benchmark/metrics/melt_bias_metric.py +234 -0
  48. helm/benchmark/metrics/melt_bias_word_lists.py +1367 -0
  49. helm/benchmark/metrics/melt_metric_specs.py +43 -0
  50. helm/benchmark/metrics/melt_toxicity_metric.py +107 -0
  51. helm/benchmark/metrics/mental_health_metrics.py +9 -29
  52. helm/benchmark/metrics/metric_service.py +11 -11
  53. helm/benchmark/metrics/mimic_bhc_metrics.py +14 -0
  54. helm/benchmark/metrics/mimic_rrs_metrics.py +9 -29
  55. helm/benchmark/metrics/mtsamples_procedures_metrics.py +9 -29
  56. helm/benchmark/metrics/mtsamples_replicate_metrics.py +9 -29
  57. helm/benchmark/metrics/openai_mrcr_metrics.py +52 -0
  58. helm/benchmark/metrics/ruler_qa_metrics.py +34 -0
  59. helm/benchmark/metrics/starr_patient_instructions_metrics.py +9 -29
  60. helm/benchmark/metrics/summac/model_summac.py +2 -3
  61. helm/benchmark/metrics/summarization_metrics.py +2 -1
  62. helm/benchmark/metrics/tokens/test_openai_token_cost_estimator.py +2 -2
  63. helm/benchmark/metrics/toxicity_metrics.py +2 -2
  64. helm/benchmark/metrics/unitxt_metrics.py +3 -4
  65. helm/benchmark/metrics/vision_language/emd_utils.py +4 -2
  66. helm/benchmark/metrics/vision_language/image_utils.py +2 -2
  67. helm/benchmark/model_deployment_registry.py +16 -26
  68. helm/benchmark/presentation/contamination.py +3 -3
  69. helm/benchmark/presentation/create_plots.py +43 -13
  70. helm/benchmark/presentation/run_display.py +13 -0
  71. helm/benchmark/presentation/schema.py +7 -1
  72. helm/benchmark/presentation/summarize.py +84 -61
  73. helm/benchmark/presentation/test_create_plots.py +4 -1
  74. helm/benchmark/reeval_run.py +3 -4
  75. helm/benchmark/reeval_runner.py +3 -3
  76. helm/benchmark/run.py +84 -73
  77. helm/benchmark/run_expander.py +12 -1
  78. helm/benchmark/run_spec_factory.py +7 -6
  79. helm/benchmark/run_specs/arabic_run_specs.py +73 -0
  80. helm/benchmark/run_specs/audio_run_specs.py +52 -8
  81. helm/benchmark/run_specs/bluex_run_specs.py +40 -0
  82. helm/benchmark/run_specs/classic_run_specs.py +0 -53
  83. helm/benchmark/run_specs/codeinsights_run_specs.py +192 -0
  84. helm/benchmark/run_specs/enterprise_run_specs.py +20 -0
  85. helm/benchmark/run_specs/experimental_run_specs.py +31 -1
  86. helm/benchmark/run_specs/healthqa_br_run_specs.py +40 -0
  87. helm/benchmark/run_specs/heim_run_specs.py +3 -1
  88. helm/benchmark/run_specs/lmkt_run_specs.py +144 -0
  89. helm/benchmark/run_specs/long_context_run_specs.py +114 -15
  90. helm/benchmark/run_specs/medhelm_run_specs.py +146 -41
  91. helm/benchmark/run_specs/melt_run_specs.py +783 -0
  92. helm/benchmark/run_specs/multilingual_run_specs.py +50 -0
  93. helm/benchmark/run_specs/speech_disorder_audio_run_specs.py +163 -0
  94. helm/benchmark/run_specs/vlm_run_specs.py +28 -0
  95. helm/benchmark/runner.py +5 -5
  96. helm/benchmark/scenarios/aci_bench_scenario.py +7 -1
  97. helm/benchmark/scenarios/alghafa_scenario.py +126 -0
  98. helm/benchmark/scenarios/arabic_mmlu_scenario.py +78 -0
  99. helm/benchmark/scenarios/aratrust_scenario.py +76 -0
  100. helm/benchmark/scenarios/audio_language/air_bench_chat_scenario.py +3 -1
  101. helm/benchmark/scenarios/audio_language/air_bench_foundation_scenario.py +5 -5
  102. helm/benchmark/scenarios/audio_language/casual_conversations2_scenario.py +1 -1
  103. helm/benchmark/scenarios/audio_language/corebench_scenario.py +77 -0
  104. helm/benchmark/scenarios/audio_language/mustard_scenario.py +1 -1
  105. helm/benchmark/scenarios/audio_language/ultra_suite_asr_classification_scenario.py +104 -0
  106. helm/benchmark/scenarios/audio_language/ultra_suite_asr_transcription_scenario.py +99 -0
  107. helm/benchmark/scenarios/audio_language/ultra_suite_classification_scenario.py +118 -0
  108. helm/benchmark/scenarios/audio_language/ultra_suite_disorder_breakdown_scenario.py +86 -0
  109. helm/benchmark/scenarios/audio_language/ultra_suite_disorder_symptoms_scenario.py +117 -0
  110. helm/benchmark/scenarios/audio_language/vocal_sound_scenario.py +15 -1
  111. helm/benchmark/scenarios/audio_language/voxceleb2_scenario.py +1 -2
  112. helm/benchmark/scenarios/autobencher_capabilities_scenario.py +2 -2
  113. helm/benchmark/scenarios/bluex_scenario.py +66 -0
  114. helm/benchmark/scenarios/chw_care_plan_scenario.py +14 -13
  115. helm/benchmark/scenarios/clear_scenario.py +11 -7
  116. helm/benchmark/scenarios/cleva_scenario.py +1 -1
  117. helm/benchmark/scenarios/codeinsights_code_efficiency_scenario.py +197 -0
  118. helm/benchmark/scenarios/codeinsights_correct_code_scenario.py +78 -0
  119. helm/benchmark/scenarios/codeinsights_edge_case_scenario.py +192 -0
  120. helm/benchmark/scenarios/codeinsights_student_coding_scenario.py +162 -0
  121. helm/benchmark/scenarios/codeinsights_student_mistake_scenario.py +188 -0
  122. helm/benchmark/scenarios/dischargeme_scenario.py +36 -21
  123. helm/benchmark/scenarios/ehr_sql_scenario.py +7 -1
  124. helm/benchmark/scenarios/ehrshot_scenario.py +28 -55
  125. helm/benchmark/scenarios/exams_multilingual_scenario.py +115 -0
  126. helm/benchmark/scenarios/grammar.py +2 -2
  127. helm/benchmark/scenarios/headqa_scenario.py +6 -1
  128. helm/benchmark/scenarios/healthqa_br_scenario.py +80 -0
  129. helm/benchmark/scenarios/infinite_bench_en_mc_scenario.py +90 -0
  130. helm/benchmark/scenarios/infinite_bench_en_qa_scenario.py +85 -0
  131. helm/benchmark/scenarios/{infinite_bench_sum_scenario.py → infinite_bench_en_sum_scenario.py} +10 -13
  132. helm/benchmark/scenarios/kpi_edgar_scenario.py +151 -0
  133. helm/benchmark/scenarios/lmkt_scenarios.py +288 -0
  134. helm/benchmark/scenarios/math_scenario.py +21 -20
  135. helm/benchmark/scenarios/med_dialog_scenario.py +6 -1
  136. helm/benchmark/scenarios/medalign_scenario.py +9 -3
  137. helm/benchmark/scenarios/medalign_scenario_helper.py +27 -130
  138. helm/benchmark/scenarios/medbullets_scenario.py +7 -2
  139. helm/benchmark/scenarios/medcalc_bench_scenario.py +4 -2
  140. helm/benchmark/scenarios/medec_scenario.py +6 -1
  141. helm/benchmark/scenarios/medhallu_scenario.py +7 -1
  142. helm/benchmark/scenarios/medi_qa_scenario.py +10 -4
  143. helm/benchmark/scenarios/medication_qa_scenario.py +7 -1
  144. helm/benchmark/scenarios/melt_ir_scenario.py +171 -0
  145. helm/benchmark/scenarios/melt_knowledge_scenario.py +246 -0
  146. helm/benchmark/scenarios/melt_lm_scenarios.py +252 -0
  147. helm/benchmark/scenarios/melt_scenarios.py +793 -0
  148. helm/benchmark/scenarios/melt_srn_scenario.py +342 -0
  149. helm/benchmark/scenarios/melt_synthetic_reasoning_scenario.py +222 -0
  150. helm/benchmark/scenarios/melt_translation_scenario.py +152 -0
  151. helm/benchmark/scenarios/mental_health_scenario.py +16 -5
  152. helm/benchmark/scenarios/mimic_bhc_scenario.py +13 -8
  153. helm/benchmark/scenarios/mimic_rrs_scenario.py +17 -8
  154. helm/benchmark/scenarios/mimiciv_billing_code_scenario.py +14 -8
  155. helm/benchmark/scenarios/mmlu_pro_scenario.py +1 -1
  156. helm/benchmark/scenarios/mmmlu_scenario.py +85 -0
  157. helm/benchmark/scenarios/mtsamples_procedures_scenario.py +5 -2
  158. helm/benchmark/scenarios/mtsamples_replicate_scenario.py +3 -2
  159. helm/benchmark/scenarios/n2c2_ct_matching_scenario.py +11 -5
  160. helm/benchmark/scenarios/openai_mrcr_scenario.py +79 -0
  161. helm/benchmark/scenarios/pubmed_qa_scenario.py +6 -1
  162. helm/benchmark/scenarios/race_based_med_scenario.py +18 -8
  163. helm/benchmark/scenarios/ruler_qa_scenario_helper.py +2 -2
  164. helm/benchmark/scenarios/ruler_qa_scenarios.py +2 -2
  165. helm/benchmark/scenarios/seahelm_scenario.py +2 -2
  166. helm/benchmark/scenarios/shc_bmt_scenario.py +12 -6
  167. helm/benchmark/scenarios/shc_cdi_scenario.py +11 -6
  168. helm/benchmark/scenarios/shc_conf_scenario.py +12 -6
  169. helm/benchmark/scenarios/shc_ent_scenario.py +11 -6
  170. helm/benchmark/scenarios/shc_gip_scenario.py +13 -5
  171. helm/benchmark/scenarios/shc_privacy_scenario.py +78 -0
  172. helm/benchmark/scenarios/shc_proxy_scenario.py +76 -0
  173. helm/benchmark/scenarios/shc_ptbm_scenario.py +12 -7
  174. helm/benchmark/scenarios/shc_sei_scenario.py +12 -7
  175. helm/benchmark/scenarios/shc_sequoia_scenario.py +13 -5
  176. helm/benchmark/scenarios/starr_patient_instructions_scenario.py +15 -8
  177. helm/benchmark/scenarios/test_alghafa_scenario.py +29 -0
  178. helm/benchmark/scenarios/test_aratrust_scenario.py +21 -0
  179. helm/benchmark/scenarios/test_bluex_scenario.py +59 -0
  180. helm/benchmark/scenarios/test_exams_multilingual_scenario.py +29 -0
  181. helm/benchmark/scenarios/test_healtha_br_scenario.py +57 -0
  182. helm/benchmark/scenarios/test_infinite_bench_en_qa_scenario.py +18 -0
  183. helm/benchmark/scenarios/test_infinite_bench_en_sum_scenario.py +31 -0
  184. helm/benchmark/scenarios/truthful_qa_scenario.py +2 -1
  185. helm/benchmark/scenarios/vision_language/msr_vtt_scenario.py +75 -0
  186. helm/benchmark/server.py +2 -1
  187. helm/benchmark/slurm_jobs.py +1 -2
  188. helm/benchmark/slurm_runner.py +8 -1
  189. helm/benchmark/static/schema_arabic.yaml +228 -0
  190. helm/benchmark/static/schema_audio.yaml +60 -49
  191. helm/benchmark/static/schema_classic.yaml +0 -17
  192. helm/benchmark/static/schema_enterprise.yaml +21 -0
  193. helm/benchmark/static/schema_long_context.yaml +81 -20
  194. helm/benchmark/static/schema_medhelm.yaml +272 -213
  195. helm/benchmark/static/schema_melt.yaml +1257 -0
  196. helm/benchmark/static/schema_slphelm.yaml +162 -0
  197. helm/benchmark/static/schema_vhelm.yaml +26 -26
  198. helm/benchmark/static/schema_video.yaml +219 -0
  199. helm/benchmark/static_build/assets/index-b9779128.css +1 -0
  200. helm/benchmark/static_build/assets/index-e439d5e1.js +10 -0
  201. helm/benchmark/static_build/assets/medhelm-overview-eac29843.png +0 -0
  202. helm/benchmark/static_build/assets/{tremor-9cefc3c5.js → tremor-38a10867.js} +1 -1
  203. helm/benchmark/static_build/index.html +4 -4
  204. helm/benchmark/window_services/encoder_decoder_window_service.py +3 -3
  205. helm/benchmark/window_services/image_generation/clip_window_service.py +1 -3
  206. helm/benchmark/window_services/test_utils.py +3 -4
  207. helm/benchmark/window_services/tokenizer_service.py +7 -8
  208. helm/clients/anthropic_client.py +69 -29
  209. helm/clients/audio_language/diva_llama_client.py +4 -2
  210. helm/clients/audio_language/qwen2_5_omni_client.py +209 -0
  211. helm/clients/audio_language/qwen2_audiolm_client.py +8 -6
  212. helm/clients/audio_language/qwen_audiolm_client.py +4 -2
  213. helm/clients/audio_language/test.py +62 -0
  214. helm/clients/bedrock_client.py +3 -1
  215. helm/clients/client.py +7 -7
  216. helm/clients/grok_client.py +36 -0
  217. helm/clients/huggingface_client.py +42 -3
  218. helm/clients/huggingface_pipeline_client.py +138 -0
  219. helm/clients/image_generation/dalle_mini/model/configuration.py +1 -1
  220. helm/clients/image_generation/dalle_mini/model/modeling.py +1 -1
  221. helm/clients/image_generation/dalle_mini/model/processor.py +1 -1
  222. helm/clients/image_generation/dalle_mini/model/tokenizer.py +1 -1
  223. helm/clients/openai_client.py +102 -55
  224. helm/clients/openai_responses_client.py +176 -0
  225. helm/clients/palmyra_client.py +2 -5
  226. helm/clients/reka_client.py +2 -2
  227. helm/clients/test_huggingface_client.py +3 -3
  228. helm/clients/together_client.py +31 -6
  229. helm/clients/vertexai_client.py +17 -9
  230. helm/clients/vision_language/huggingface_vision2seq_client.py +6 -4
  231. helm/clients/vision_language/huggingface_vlm_client.py +2 -2
  232. helm/clients/vision_language/idefics_client.py +6 -2
  233. helm/clients/vision_language/paligemma_client.py +2 -2
  234. helm/clients/vision_language/qwen2_vlm_client.py +66 -53
  235. helm/clients/vision_language/qwen_vlm_client.py +7 -5
  236. helm/clients/vllm_client.py +43 -7
  237. helm/clients/vllm_granite_thinking_client.py +56 -0
  238. helm/clients/writer_client.py +102 -0
  239. helm/common/context.py +80 -0
  240. helm/common/credentials_utils.py +5 -5
  241. helm/common/critique_request.py +0 -1
  242. helm/common/general.py +9 -2
  243. helm/common/hierarchical_logger.py +104 -12
  244. helm/common/local_context.py +140 -0
  245. helm/common/object_spec.py +23 -8
  246. helm/common/remote_context.py +61 -0
  247. helm/common/request.py +8 -0
  248. helm/common/test_logging.py +94 -0
  249. helm/config/model_deployments.yaml +995 -45
  250. helm/config/model_metadata.yaml +780 -59
  251. helm/config/tokenizer_configs.yaml +224 -3
  252. helm/proxy/cli.py +4 -2
  253. helm/proxy/critique/mechanical_turk_utils.py +1 -1
  254. helm/proxy/retry.py +5 -0
  255. helm/proxy/services/server_service.py +21 -85
  256. helm/tokenizers/grok_tokenizer.py +55 -0
  257. helm/tokenizers/huggingface_tokenizer.py +1 -1
  258. helm/tokenizers/test_grok_tokenizer.py +33 -0
  259. helm/benchmark/metrics/numeracy_metrics.py +0 -72
  260. helm/benchmark/metrics/test_numeracy_metrics.py +0 -95
  261. helm/benchmark/scenarios/numeracy_scenario.py +0 -793
  262. helm/benchmark/scenarios/test_infinite_bench_sum_scenario.py +0 -46
  263. helm/benchmark/static_build/assets/index-262903c1.js +0 -10
  264. helm/benchmark/static_build/assets/index-42060d71.css +0 -1
  265. {crfm_helm-0.5.5.dist-info → crfm_helm-0.5.7.dist-info}/entry_points.txt +0 -0
  266. {crfm_helm-0.5.5.dist-info → crfm_helm-0.5.7.dist-info}/licenses/LICENSE +0 -0
  267. {crfm_helm-0.5.5.dist-info → crfm_helm-0.5.7.dist-info}/top_level.txt +0 -0
  268. /helm/benchmark/static_build/assets/{medhelm-overview-3ddfcd65.png → medhelm-v1-overview-3ddfcd65.png} +0 -0
helm/benchmark/scenarios/lmkt_scenarios.py (new file)
@@ -0,0 +1,288 @@
+ """Cultural alignment evaluation scenario based on Vietnam World Values Survey responses."""
+
+ import os
+ import json
+ import random
+ from typing import List
+ from datasets import load_dataset
+ from huggingface_hub import snapshot_download
+
+ from helm.common.hierarchical_logger import hlog, hwarn
+ from helm.benchmark.scenarios.scenario import (
+     Scenario,
+     Instance,
+     Reference,
+     TEST_SPLIT,
+     CORRECT_TAG,
+     Input,
+     Output,
+ )
+
+ SUPPORTED_LANGUAGES = ["en", "vi"]
+
+
+ class CulturalValueUnderstandingWVSScenario(Scenario):
+     """Cultural values understanding evaluation based on Vietnam World Values Survey responses."""
+
+     name = "cultural_value_understanding_wvs"
+     description = "Evaluates model understanding of cultural values from WVS Wave 7"
+     tags = ["cultural_value_understanding"]
+
+     def __init__(
+         self,
+         language: str,
+         personas_filename: str = "vn_personas_with_filtered_responses.json",
+         questions_filename: str = "wvs_questions_dump.en.json",
+         repo_id: str = "stair-lab/cultural-alignment-dataset",
+         num_personas: int = 10,
+         num_question_variants: int = 2,
+         include_few_shot_examples: bool = False,
+     ):
+         super().__init__()
+         self.language = language
+         self.personas_filename = personas_filename
+         self.questions_filename = questions_filename
+         self.repo_id = repo_id
+         self.num_personas = 10 if isinstance(num_personas, str) else num_personas
+         self.num_question_variants = min(num_question_variants, 4)
+         self.include_few_shot_examples = include_few_shot_examples
+
+         if self.language not in SUPPORTED_LANGUAGES:
+             raise ValueError(f"Unsupported language: {self.language}. Supported languages are: {SUPPORTED_LANGUAGES}")
+
+     def get_few_shot_examples(self) -> str:
+         """Return a few-shot learning examples formatted as a string."""
+         examples = """Here are some examples:
+
+ Example 1:
+ You are a 35 year old male from Vietnam. You have higher education and consider yourself middle class. You are married.
+
+ How important is family in your life?
+
+ 1. Very important
+ 2. Rather important
+ 3. Not very important
+ 4. Not at all important
+
+ Please select the option number that best represents your view.
+ 1
+
+ Example 2:
+ You are a 28 year old female from Vietnam. You have middle education and consider yourself working class. You are single.
+
+ Would you say that most people can be trusted or that you need to be very careful in dealing with people?
+
+ 1. Most people can be trusted
+ 2. Need to be very careful
+
+ Please select the option number that best represents your view.
+ 2
+
+ Example 3:
+ You are a 50 year old male from Vietnam. You have lower education and consider yourself lower class. You are married.
+
+ Do you think that homosexuality is justifiable?
+
+ 1. Never justifiable
+ 2. Rarely justifiable
+ 3. Sometimes justifiable
+ 4. Always justifiable
+
+ Please select the option number that best represents your view.
+ 1
+
+ Now answer the following question:
+ """ # noqa: E501
+         return examples
+
+     def get_instances(self, output_path: str) -> List[Instance]:
+         """Generate test instances from Vietnam personas and WVS questions."""
+         instances: List[Instance] = []
+
+         # Download files from Hugging Face Hub
+         repo_local_path = snapshot_download(
+             repo_id=self.repo_id, repo_type="dataset", revision="fe54b6f5d75cfca5377707cd7199e39f517e3a1f"
+         )
+
+         # Load the downloaded files
+         with open(os.path.join(repo_local_path, self.personas_filename), "r", encoding="utf-8") as f:
+             personas = json.load(f)
+
+         with open(os.path.join(repo_local_path, self.questions_filename), "r", encoding="utf-8") as f:
+             questions = json.load(f)
+
+         # Get few-shot examples
+         few_shot_examples = self.get_few_shot_examples() if self.include_few_shot_examples else ""
+
+         # Sample personas
+         sampled_personas = random.sample(personas, min(self.num_personas, len(personas)))
+
+         # Create instances for each persona and question
+         for persona in sampled_personas:
+             # Get demographic info for persona description
+             persona_desc = (
+                 f"You are a {persona.get('age', 'adult')} year old {persona.get('sex', 'person')} from Vietnam. "
+             )
+             persona_desc += f"You have {persona.get('education', 'some')} education and consider yourself {persona.get('social_class', 'middle class')}. " # noqa: E501
+             persona_desc += f"You are {persona.get('marital_status', 'single')}."
+
+             # Process each question this persona answered
+             for qid, human_response in persona.get("responses", {}).items():
+                 # Skip if no human response or if it's 0 (which might be a "Don't know" response)
+                 if human_response is None:
+                     continue
+
+                 # Convert human_response to int (if possible)
+                 try:
+                     human_response_int = int(human_response)
+                 except (ValueError, TypeError):
+                     # Skip if human_response can't be converted to int
+                     continue
+
+                 # Get question info
+                 question_data = questions.get(qid, {})
+                 if not question_data:
+                     continue
+
+                 # Get options directly from question data
+                 q_options = question_data.get("options", [])
+                 if not q_options:
+                     continue
+
+                 # Skip if human_response is out of range
+                 if human_response_int < 0 or human_response_int > len(q_options):
+                     continue
+
+                 # Special handling for "Don't know" or zero responses
+                 if human_response_int == 0:
+                     # Some questions might encode "Don't know" as 0
+                     # Skip for now, or you could add special handling
+                     continue
+
+                 # Use the predefined question variations
+                 question_variants = question_data.get("questions", [])
+                 if not question_variants:
+                     question_variants = [f"Question {qid}: {question_data.get('description', '')}"]
+
+                 # Use the specified number of variants
+                 variants_to_use = min(self.num_question_variants, len(question_variants))
+                 selected_variants = question_variants[:variants_to_use]
+
+                 # Create instances for each selected question variant
+                 for q_text in selected_variants:
+                     # Format the prompt with or without few-shot examples
+                     if self.include_few_shot_examples:
+                         prompt = f"{few_shot_examples}{persona_desc}\n\n{q_text}\n\n"
+                     else:
+                         prompt = f"{persona_desc}\n\n{q_text}\n\n"
+
+                     # Add options from question data - with numbers, not letters
+                     for i, opt in enumerate(q_options, 1):
+                         prompt += f"{i}. {opt}\n"
+
+                     prompt += "\nPlease select the option number that best represents your view. Return only the option number. Do not return anything else." # noqa: E501
+
+                     # Create a reference with just the human response number
+                     # We don't create multiple references, just use the actual human response
+                     reference = Reference(Output(text=str(human_response_int)), tags=[CORRECT_TAG])
+
+                     # Create the instance
+                     instance = Instance(
+                         input=Input(text=prompt),
+                         references=[reference],
+                         split=TEST_SPLIT,
+                     )
+                     instances.append(instance)
+
+         if not instances:
+             hwarn("No valid instances were created. Check the input data and parameters.")
+         else:
+             hlog(f"Created {len(instances)} instances for cultural value understanding scenario.")
+
+         return instances
+
+
+ class SocialNormApplicationNormADScenario(Scenario):
+     """Social norms application evaluation based on Vietnam World Values Survey responses."""
+
+     name = "social_norm_application_normad"
+     description = "Evaluates model understanding of social norms from WVS Wave 7"
+     tags = ["social_norm_application"]
+
+     def __init__(
+         self,
+         language: str,
+     ):
+         super().__init__()
+         self.language = language
+
+         if self.language not in SUPPORTED_LANGUAGES:
+             raise ValueError(f"Unsupported language: {self.language}. Supported languages are: {SUPPORTED_LANGUAGES}")
+
+     def get_instances(self, output_path: str) -> List[Instance]:
+         dataset = load_dataset(
+             "akhilayerukola/NormAd", split="train", revision="311049f5610e932792672e4e13e7df62b6290336"
+         )
+
+         instances: List[Instance] = []
+
+         for item in dataset:
+             input_text = (
+                 f"Assumming that below norms and story are from {item['Country'].capitalize()}:\n"
+                 f"{item['Background']}\n\n{item['Story']}\n"
+                 "Return one of the following options: 'yes', 'no', or 'neutral'. Do not return anything else.\n\n"
+             )
+
+             instance = Instance(
+                 input=Input(text=input_text),
+                 references=[Reference(Output(text=item["Gold Label"]), tags=[CORRECT_TAG])],
+                 split=TEST_SPLIT,
+             )
+             instances.append(instance)
+
+         return instances
+
+
+ class SocialNormExplanationNormADScenario(Scenario):
+     """Social norms explanation evaluation based on Vietnam World Values Survey responses."""
+
+     name = "social_norm_explanation_normad"
+     description = "Evaluates model understanding of social norms from WVS Wave 7"
+     tags = ["social_norm_explanation"]
+
+     def __init__(
+         self,
+         language: str,
+     ):
+         super().__init__()
+         self.language = language
+
+         if self.language not in SUPPORTED_LANGUAGES:
+             raise ValueError(f"Unsupported language: {self.language}. Supported languages are: {SUPPORTED_LANGUAGES}")
+
+     def get_instances(self, output_path: str) -> List[Instance]:
+         dataset = load_dataset(
+             "akhilayerukola/NormAd", split="train", revision="311049f5610e932792672e4e13e7df62b6290336"
+         )
+
+         instances: List[Instance] = []
+
+         for item in dataset:
+             if not item["Explanation"]:
+                 continue
+
+             input_text = (
+                 f"Assumming that below norms and story are from {item['Country'].capitalize()}:\n"
+                 f"{item['Background']}\n\n{item['Story']}\n"
+                 f"The answer is {item['Gold Label']}. "
+                 "Briefly explain the reasoning behind this answer in one or two sentences.\n\n"
+             )
+
+             instance = Instance(
+                 input=Input(text=input_text),
+                 references=[Reference(Output(text=item["Explanation"]), tags=[CORRECT_TAG])],
+                 split=TEST_SPLIT,
+             )
+             instances.append(instance)
+
+         return instances
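
The three classes above take only the constructor arguments shown, so the new scenario can be exercised directly. A minimal sketch, assuming a local environment with the datasets and huggingface_hub packages installed; the language, persona count, and output path values are illustrative, not taken from this release:

# Sketch: exercise the new WVS scenario directly; parameter values are assumptions.
from helm.benchmark.scenarios.lmkt_scenarios import CulturalValueUnderstandingWVSScenario

scenario = CulturalValueUnderstandingWVSScenario(language="en", num_personas=5)
# get_instances() downloads the pinned dataset snapshot and builds persona-conditioned prompts.
instances = scenario.get_instances(output_path="./wvs_scratch")
print(len(instances))
print(instances[0].input.text[:200])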
helm/benchmark/scenarios/math_scenario.py
@@ -18,13 +18,14 @@ from helm.benchmark.scenarios.scenario import (
 
 
  def remove_boxed(string: str) -> Optional[str]:
-     """Source: https://github.com/hendrycks/math
+     r"""Source: https://github.com/hendrycks/math
 
-     Extract the text within a \\boxed{...} environment.
+     Extract the text within a \boxed{...} environment.
 
      Example:
-     >>> remove_boxed(\\boxed{\\frac{2}{3}})
-     \\frac{2}{3}
+     >>> from helm.benchmark.scenarios.math_scenario import * # NOQA
+     >>> remove_boxed(r'\boxed{\frac{2}{3}}')
+     '\\frac{2}{3}'
      """
      left = "\\boxed{"
      try:
@@ -68,17 +69,17 @@ def last_boxed_only_string(string: str) -> Optional[str]:
 
 
  def _fix_fracs(string: str) -> str:
-     """Source: https://github.com/hendrycks/math
+     r"""Source: https://github.com/hendrycks/math
 
      Reformat fractions.
 
      Examples:
-     >>> _fix_fracs("\\frac1b")
-     \frac{1}{b}
-     >>> _fix_fracs("\\frac12")
-     \frac{1}{2}
-     >>> _fix_fracs("\\frac1{72}")
-     \frac{1}{72}
+     >>> _fix_fracs(r"\frac1b")
+     '\\frac{1}{b}'
+     >>> _fix_fracs(r"\frac12")
+     '\\frac{1}{2}'
+     >>> _fix_fracs(r"\frac1{72}")
+     '\\frac{1}{72}'
      """
      substrs = string.split("\\frac")
      new_str = substrs[0]
@@ -112,13 +113,13 @@ def _fix_fracs(string: str) -> str:
 
 
  def _fix_a_slash_b(string: str) -> str:
-     """Source: https://github.com/hendrycks/math
+     r"""Source: https://github.com/hendrycks/math
 
      Reformat fractions formatted as a/b to \\frac{a}{b}.
 
      Example:
-     >>> _fix_a_slash_b("2/3")
-     \frac{2}{3}
+     >>> _fix_a_slash_b(r"2/3")
+     '\\frac{2}{3}'
      """
      if len(string.split("/")) != 2:
          return string
@@ -149,13 +150,13 @@ def _remove_right_units(string: str) -> str:
 
 
  def _fix_sqrt(string: str) -> str:
-     """Source: https://github.com/hendrycks/math
+     r"""Source: https://github.com/hendrycks/math
 
      Reformat square roots.
 
      Example:
-     >>> _fix_sqrt("\\sqrt3")
-     \sqrt{3}
+     >>> _fix_sqrt("\\sqrt3")
+     '\\sqrt{3}'
      """
      if "\\sqrt" not in string:
          return string
@@ -210,7 +211,7 @@ def _strip_string(string: str) -> str:
 
      # remove percentage
      string = string.replace("\\%", "")
-     string = string.replace("\%", "")
+     string = string.replace(r"\%", "")
 
      # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
      string = string.replace(" .", " 0.")
@@ -391,13 +392,13 @@ class MATHScenario(Scenario):
          for split, split_name in zip([TRAIN_SPLIT, TEST_SPLIT], ["train", "test"]):
              if split == TRAIN_SPLIT and self.use_official_examples:
                  train_instances = [
-                     ("What is $\left(\\frac{7}{8}\\right)^3 \cdot \left(\\frac{7}{8}\\right)^{-3}$?", "1"),
+                     ("What is $\\left(\\frac{7}{8}\\right)^3 \\cdot \\left(\\frac{7}{8}\\right)^{-3}$?", "1"),
                      (
                          "In how many ways can 4 books be selected from a shelf of 6 books"
                          + " if the order in which the books are selected does not matter?",
                          "15",
                      ),
-                     ("Find the distance between the points $(2,1,-4)$ and $(5,8,-3).$", "\sqrt{59}"),
+                     ("Find the distance between the points $(2,1,-4)$ and $(5,8,-3).$", "\\sqrt{59}"),
                      (
                          "The faces of an octahedral die are labeled with digits $1$ through $8$."
                          + " What is the probability, expressed as a common fraction,"
helm/benchmark/scenarios/med_dialog_scenario.py
@@ -90,7 +90,12 @@ class MedDialogScenario(Scenario):
      """
 
      name = "med_dialog"
-     description = "A collection of doctor-patient conversations with corresponding summaries."
+     description = (
+         "MedDialog is a benchmark of real-world doctor-patient conversations focused on health-related"
+         "concerns and advice. Each dialogue is paired with a one-sentence summary"
+         "that reflects the core patient question or exchange. The benchmark evaluates a model's"
+         "ability to condense medical dialogue into concise, informative summaries."
+     )
      tags = ["dialogue", "biomedical"]
 
      def __init__(self, subset: str):
helm/benchmark/scenarios/medalign_scenario.py
@@ -60,12 +60,18 @@ class MedalignScenario(Scenario):
      """
 
      name = "medalign"
-     description = "A dataset that asks models to answer questions/follow instructions over longitudinal EHR."
+     description = (
+         "MedAlign is a benchmark that evaluates a model's ability to interpret and follow"
+         "instructions grounded in longitudinal electronic health records (EHR). Each instance"
+         "includes an event-stream style patient record and a natural language question or task,"
+         "requiring clinically informed reading comprehension and reasoning."
+     )
      tags = ["knowledge", "reasoning", "biomedical"]
 
-     def __init__(self, max_length: int):
+     def __init__(self, max_length: int, data_path: str):
          super().__init__()
          self.max_length = max_length
+         self.data_path = data_path
 
      def process_tsv(self, data) -> List[Instance]:
          instances: List[Instance] = []
@@ -84,5 +90,5 @@ class MedalignScenario(Scenario):
          return instances
 
      def get_instances(self, output_path: str) -> List[Instance]:
-         dataset = return_dataset_dataframe(self.max_length)
+         dataset = return_dataset_dataframe(self.max_length, self.data_path)
          return self.process_tsv(dataset)
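
Because of the new constructor parameter, MedalignScenario is now instantiated with a local data directory rather than relying on hard-coded paths. A hedged example, where the max_length value and both paths are placeholders:

# Placeholder values: point data_path at a local MedAlign download.
from helm.benchmark.scenarios.medalign_scenario import MedalignScenario

scenario = MedalignScenario(max_length=16000, data_path="/data/medalign")
instances = scenario.get_instances(output_path="./medalign_scratch")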
helm/benchmark/scenarios/medalign_scenario_helper.py
@@ -2,23 +2,15 @@
  # type: ignore
  # fmt: off
 
- import ast
- import datetime
  import transformers
- import langchain
- import langchain.prompts
- import lxml.etree
  import os
  import pandas as pd
- import re
  import tiktoken
 
- from langchain_community.retrievers import BM25Retriever
  from tqdm import tqdm
- from typing import Any, Dict, Optional, Union, Callable
- from langchain.schema import Document
- import langchain_community
+ from typing import Any, Dict, Optional, Callable
 
+ from helm.common.general import check_file_exists
 
 
  def get_instructions(path_to_instructions: str) -> Dict[int, Dict[str, Any]]:
@@ -166,102 +158,13 @@ def get_tokenizer(tokenizer_name: str) -> Callable:
      return transformers.AutoTokenizer.from_pretrained(tokenizer_name, legacy=False)
 
 
- def retrieve_most_relevant_visits(ehr_visit_strs, query, target_length, tokenizer):
-     """
-     Retrieve and filter relevant EHR visits based on a query and target length.
-
-     This function retrieves electronic health record (EHR) visit strings, sorts them
-     by relevance using the BM25Retriever, and constructs a list of final documents
-     that fit within a specified character length. The final list ensures that the
-     most important visit isn't cut off and is sorted chronologically.
-
-     Parameters:
-         ehr_visit_strs (list of str): List of EHR visit strings.
-         query (str): Query string to retrieve relevant visits.
-         target_length (int): Maximum total token count for the final list of documents.
-         tokenizer (Callable): Tokenizer that converts text to tokens (used for tracking context length)
-
-     Returns:
-         list[str]: List of EHR visit strings sorted chronologically and constrained by the target length.
-     """
-     ehr_visits=re.split(r'(?=</encounter>\n)',ehr_visit_strs)
-     langchain_docs = [
-         langchain.schema.Document(page_content=doc) for doc in ehr_visits #broken since ehr_visit_strs is one string of all visits
-     ]
-     # `k` is the number of documents to retrieve
-     # We retrieve everything and just use the BM25Retriever to sort the documents
-     retriever = langchain_community.retrievers.BM25Retriever.from_documents(
-         langchain_docs, k=len(langchain_docs)
-     )
-
-     # Invoking the retriever means the most relevant documents are sorted first
-     sorted_docs = retriever.invoke(query)
-
-     # Define the regex pattern to find the start time
-     # pattern = r'start="([\d/]+ [\d:]+)"'
-     pattern = r'start="([\d/]+ [\d:]+ ?[APM]{0,2})"'
-
-     docs = []
-     dts = []
-
-     # Find the startime of the document
-     for doc in sorted_docs:
-         doc_content = doc.page_content
-         start_dt_match = re.search(pattern, doc_content)
-         if start_dt_match:
-             start_dt = start_dt_match.group(1)
-             parsed = False
-             # Try different date formats
-             for fmt in (
-                 "%m/%d/%y %I:%M %p",
-                 "%m/%d/%Y %I:%M %p",
-                 "%m/%d/%y %H:%M",
-                 "%m/%d/%Y %H:%M",
-             ):
-                 try:
-                     dts.append(datetime.datetime.strptime(start_dt, fmt))
-                     parsed = True
-                     break
-                 except ValueError:
-                     continue
-             if not parsed:
-                 print(f"Error parsing date: {start_dt}")
-                 continue
-         else:
-             print(f"Start time not found., {doc_content}")
-             dts.append(datetime.datetime.min)
-         docs.append(doc_content)
-
-     final_docs = []
-     current_length = 0
-
-     # Add documents until we exceed the allocated context length
-     for i in range(len(docs)):
-         doc_content = docs[i]
-         doc_length = len(tokenizer.encode(doc_content))
-         final_docs.append((dts[i], doc_content))
-         current_length += doc_length
-         if current_length > target_length:
-             break
-
-     # Sort final_docs chronologically
-     final_docs.sort(key=lambda x: x[0])
-
-     # Extract only the document content for the final output
-     final_docs_content = [doc_content for _, doc_content in final_docs]
-
-     return final_docs_content
-
-
-
  def pack_and_trim_prompts(
      instructions: Dict[int, Dict[str, str]],
      ehrs: Dict[int, str],
-     prompt_template: langchain.prompts.PromptTemplate,
+     prompt_string: str,
      context_length: int,
      generation_length: int,
      tokenizer: Any,
-     use_RAG: bool = True,
      verbose: bool = False,
      include_ehr: bool = True,
  ) -> Dict[int, str]:
@@ -275,26 +178,15 @@ def pack_and_trim_prompts(
          patient_id = int(instructions[instruction_id]["patient_id"])
          relevant_ehr = ehrs[patient_id]
 
-         # Calculate how many tokens of EHR we can include in the prompt
          num_tokens_instruction = len(tokenizer.encode(instruction))
-         num_tokens_prompt_template = len(tokenizer.encode(prompt_template.template))
+         num_tokens_prompt_template = len(tokenizer.encode(prompt_string))
          if include_ehr:
              target_ehr_length = context_length - generation_length - num_tokens_prompt_template - num_tokens_instruction
          else:
              target_ehr_length = 0
          if target_ehr_length <= 0:
-             prompt_with_truncated_ehr = prompt_template.format(question=instruction, ehr="")
+             prompt_with_truncated_ehr = prompt_string.format(question=instruction, ehr="")
          else:
-             if use_RAG:
-                 # Return a list of the most relevant visit strings
-                 most_relevant_visits = retrieve_most_relevant_visits(
-                     ehr_visit_strs=relevant_ehr,
-                     query=instruction,
-                     target_length=target_ehr_length,
-                     tokenizer=tokenizer,
-                 )
-                 relevant_ehr = "\n".join(most_relevant_visits)
-
              # Do a first pass with a fast tokenizer
              fast_tokenizer = tiktoken.get_encoding("cl100k_base")
              fast_encoded = fast_tokenizer.encode(relevant_ehr)
@@ -306,13 +198,17 @@ def pack_and_trim_prompts(
                  encoded_ehr = tokenizer.encode(fast_truncated_ehr)
                  truncated_encoded_ehr = encoded_ehr[-target_ehr_length:]
                  truncated_ehr = tokenizer.decode(truncated_encoded_ehr)
-                 prompt_with_truncated_ehr = prompt_template.format(question=instruction, ehr=truncated_ehr)
+                 prompt_with_truncated_ehr = prompt_string.format(question=instruction, ehr=truncated_ehr)
+             else:
+                 # If the fast encoding is still too long, just use the full EHR up to allowed length
+                 truncated_ehr = fast_tokenizer.decode(fast_encoded[-target_ehr_length:])
+                 prompt_with_truncated_ehr = prompt_string.format(question=instruction, ehr=truncated_ehr)
 
-             prompts_map[instruction_id] = prompt_with_truncated_ehr
+         prompts_map[instruction_id] = prompt_with_truncated_ehr
 
-             if verbose:
-                 print(prompt_with_truncated_ehr)
-                 print("~" * 20)
+         if verbose:
+             print(prompt_with_truncated_ehr)
+             print("~" * 20)
      return prompts_map
 
 
@@ -321,7 +217,6 @@ def preprocess_prompts(
      generation_length,
      path_to_instructions,
      path_to_ehrs,
-     use_RAG,
      include_ehr,
      tokenizer,
      codes_only=False,
@@ -346,16 +241,18 @@ def preprocess_prompts(
 
      # CONSTRUCT & TRUNCATE PROMPTS #
      print("Constructing prompts using instructions and EHRs...")
-     prompt_string="Instruction: Answer the following question based on the EHR:\n\nEHR: {ehr}\n\nQuestion: {question}\n\nAnswer:"
-     prompt_template = langchain.prompts.PromptTemplate.from_template(prompt_string)
+     prompt_string = (
+         "Instruction: Answer the following question based on the EHR:\n\n"
+         "EHR: {ehr}\n\nQuestion: {question}\n\nAnswer:"
+     )
+
      filled_prompts = pack_and_trim_prompts(
          instructions=instructions,
          ehrs=ehrs,
-         prompt_template=prompt_template,
+         prompt_string=prompt_string,
          context_length=target_context_length,
          generation_length=generation_length,
          tokenizer=tokenizer,
-         use_RAG=use_RAG,
          verbose=False,
          include_ehr=include_ehr,
      )
@@ -399,20 +296,21 @@ def add_reference_responses(prompts_df, path_to_reference_responses) -> pd.DataF
      Returns:
          pd.DataFrame: DataFrame containing the processed data.
      """
-     gold_df = pd.read_csv(path_to_reference_responses)
+     gold_df = pd.read_csv(path_to_reference_responses, sep='\t')
      gold_df = gold_df.query("annotator_num == 'Annotator_1'")
      gold_df = gold_df[["instruction_id", "clinician_response"]]
      merged_df = gold_df.merge(prompts_df, on="instruction_id", how="inner")
      return merged_df
 
 
- def return_dataset_dataframe(max_length: int) -> pd.DataFrame:
+ def return_dataset_dataframe(max_length: int, data_path: str) -> pd.DataFrame:
      target_context_length = max_length
      generation_length = 256
-     path_to_instructions = "/share/pi/nigam/datasets/medalign_release_fixes/clinician-reviewed-model-responses.tsv"
-     path_to_ehrs = "/share/pi/nigam/datasets/medalign_release_fixes/medalign_ehr_xml"
-     path_to_reference_responses = "/share/pi/nigam/scottyf/clinician-instruction-responses.csv"
-     use_RAG = False
+     path_to_instructions = os.path.join(data_path, "clinician-reviewed-model-responses.tsv")
+     check_file_exists(path_to_instructions, msg=f"[MedAlignScenario] Required instructions file not found: '{path_to_instructions}'")
+     path_to_ehrs = os.path.join(data_path, "medalign_ehr_xml")
+     path_to_reference_responses = os.path.join(data_path, "clinician-instruction-responses.tsv")
+     check_file_exists(path_to_reference_responses, msg=f"[MedAlignScenario] Required clinician responses file not found: '{path_to_reference_responses}'")
      include_ehr = True
      tokenizer = "tiktoken"
 
@@ -421,7 +319,6 @@ def return_dataset_dataframe(max_length: int) -> pd.DataFrame:
          generation_length=generation_length,
          path_to_instructions=path_to_instructions,
          path_to_ehrs=path_to_ehrs,
-         use_RAG=use_RAG,
          include_ehr=include_ehr,
          tokenizer=tokenizer,
      )
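
Taken together, the helper now resolves every input file relative to data_path instead of hard-coded cluster paths, drops the langchain-based retrieval, and reads the clinician responses as TSV. A sketch of the directory layout it appears to expect, inferred from the os.path.join calls above (verify against the MedAlign release; the path below is a placeholder):

# Assumed layout of <data_path> (inferred from this diff, not documented here):
#   clinician-reviewed-model-responses.tsv   # instructions
#   clinician-instruction-responses.tsv      # gold clinician responses, now read with sep='\t'
#   medalign_ehr_xml/                        # per-patient EHR XML files
from helm.benchmark.scenarios.medalign_scenario_helper import return_dataset_dataframe

df = return_dataset_dataframe(max_length=8192, data_path="/data/medalign")  # placeholder path
print(df.columns.tolist())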