crfm-helm 0.5.5__py3-none-any.whl → 0.5.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {crfm_helm-0.5.5.dist-info → crfm_helm-0.5.7.dist-info}/METADATA +74 -53
- {crfm_helm-0.5.5.dist-info → crfm_helm-0.5.7.dist-info}/RECORD +262 -182
- {crfm_helm-0.5.5.dist-info → crfm_helm-0.5.7.dist-info}/WHEEL +1 -1
- helm/benchmark/adaptation/adapters/in_context_learning_adapter.py +3 -3
- helm/benchmark/adaptation/adapters/test_adapter.py +4 -4
- helm/benchmark/annotation/air_bench_annotator.py +2 -2
- helm/benchmark/annotation/bigcodebench_annotator.py +3 -3
- helm/benchmark/annotation/bird_sql_annotator.py +2 -2
- helm/benchmark/annotation/chw_care_plan_annotator.py +7 -12
- helm/benchmark/annotation/ehr_sql_annotator.py +2 -2
- helm/benchmark/annotation/helpdesk_call_summarization_annotator.py +7 -7
- helm/benchmark/annotation/live_qa_annotator.py +1 -1
- helm/benchmark/annotation/mimic_bhc_annotator.py +100 -0
- helm/benchmark/annotation/model_as_judge.py +12 -16
- helm/benchmark/annotation/omni_math_annotator.py +13 -14
- helm/benchmark/annotation/wildbench_annotator.py +9 -9
- helm/benchmark/executor.py +11 -12
- helm/benchmark/metrics/aci_bench_metrics.py +9 -29
- helm/benchmark/metrics/bias_word_lists.py +1 -1
- helm/benchmark/metrics/chw_care_plan_metrics.py +10 -30
- helm/benchmark/metrics/classification_metrics.py +3 -3
- helm/benchmark/metrics/cleva_harms_metrics.py +2 -2
- helm/benchmark/metrics/codeinsights_code_efficiency_metrics.py +186 -0
- helm/benchmark/metrics/codeinsights_code_evaluation_metrics.py +477 -0
- helm/benchmark/metrics/codeinsights_correct_code_metrics.py +366 -0
- helm/benchmark/metrics/codeinsights_edge_case_metrics.py +92 -0
- helm/benchmark/metrics/codeinsights_metric_specs.py +51 -0
- helm/benchmark/metrics/comet_metric.py +1 -1
- helm/benchmark/metrics/conv_fin_qa_calc_metrics.py +2 -2
- helm/benchmark/metrics/copyright_metrics.py +1 -1
- helm/benchmark/metrics/decodingtrust_stereotype_bias_metrics.py +1 -1
- helm/benchmark/metrics/dischargeme_metrics.py +9 -29
- helm/benchmark/metrics/efficiency_metrics.py +3 -3
- helm/benchmark/metrics/evaluate_reference_metrics.py +1 -1
- helm/benchmark/metrics/gpt4_audio_refusal_metrics.py +145 -0
- helm/benchmark/metrics/ifeval_metrics.py +2 -2
- helm/benchmark/metrics/image_generation/clip_score_metrics.py +13 -2
- helm/benchmark/metrics/image_generation/fractal_dimension/fractal_dimension_util.py +1 -1
- helm/benchmark/metrics/kpi_edgar_metrics.py +121 -0
- helm/benchmark/metrics/llm_jury_metrics.py +46 -0
- helm/benchmark/metrics/lmkt_metric_specs.py +12 -0
- helm/benchmark/metrics/lmkt_metrics.py +47 -0
- helm/benchmark/metrics/med_dialog_metrics.py +9 -29
- helm/benchmark/metrics/medalign_metrics.py +9 -29
- helm/benchmark/metrics/medi_qa_metrics.py +9 -29
- helm/benchmark/metrics/medication_qa_metrics.py +10 -30
- helm/benchmark/metrics/melt_bias_metric.py +234 -0
- helm/benchmark/metrics/melt_bias_word_lists.py +1367 -0
- helm/benchmark/metrics/melt_metric_specs.py +43 -0
- helm/benchmark/metrics/melt_toxicity_metric.py +107 -0
- helm/benchmark/metrics/mental_health_metrics.py +9 -29
- helm/benchmark/metrics/metric_service.py +11 -11
- helm/benchmark/metrics/mimic_bhc_metrics.py +14 -0
- helm/benchmark/metrics/mimic_rrs_metrics.py +9 -29
- helm/benchmark/metrics/mtsamples_procedures_metrics.py +9 -29
- helm/benchmark/metrics/mtsamples_replicate_metrics.py +9 -29
- helm/benchmark/metrics/openai_mrcr_metrics.py +52 -0
- helm/benchmark/metrics/ruler_qa_metrics.py +34 -0
- helm/benchmark/metrics/starr_patient_instructions_metrics.py +9 -29
- helm/benchmark/metrics/summac/model_summac.py +2 -3
- helm/benchmark/metrics/summarization_metrics.py +2 -1
- helm/benchmark/metrics/tokens/test_openai_token_cost_estimator.py +2 -2
- helm/benchmark/metrics/toxicity_metrics.py +2 -2
- helm/benchmark/metrics/unitxt_metrics.py +3 -4
- helm/benchmark/metrics/vision_language/emd_utils.py +4 -2
- helm/benchmark/metrics/vision_language/image_utils.py +2 -2
- helm/benchmark/model_deployment_registry.py +16 -26
- helm/benchmark/presentation/contamination.py +3 -3
- helm/benchmark/presentation/create_plots.py +43 -13
- helm/benchmark/presentation/run_display.py +13 -0
- helm/benchmark/presentation/schema.py +7 -1
- helm/benchmark/presentation/summarize.py +84 -61
- helm/benchmark/presentation/test_create_plots.py +4 -1
- helm/benchmark/reeval_run.py +3 -4
- helm/benchmark/reeval_runner.py +3 -3
- helm/benchmark/run.py +84 -73
- helm/benchmark/run_expander.py +12 -1
- helm/benchmark/run_spec_factory.py +7 -6
- helm/benchmark/run_specs/arabic_run_specs.py +73 -0
- helm/benchmark/run_specs/audio_run_specs.py +52 -8
- helm/benchmark/run_specs/bluex_run_specs.py +40 -0
- helm/benchmark/run_specs/classic_run_specs.py +0 -53
- helm/benchmark/run_specs/codeinsights_run_specs.py +192 -0
- helm/benchmark/run_specs/enterprise_run_specs.py +20 -0
- helm/benchmark/run_specs/experimental_run_specs.py +31 -1
- helm/benchmark/run_specs/healthqa_br_run_specs.py +40 -0
- helm/benchmark/run_specs/heim_run_specs.py +3 -1
- helm/benchmark/run_specs/lmkt_run_specs.py +144 -0
- helm/benchmark/run_specs/long_context_run_specs.py +114 -15
- helm/benchmark/run_specs/medhelm_run_specs.py +146 -41
- helm/benchmark/run_specs/melt_run_specs.py +783 -0
- helm/benchmark/run_specs/multilingual_run_specs.py +50 -0
- helm/benchmark/run_specs/speech_disorder_audio_run_specs.py +163 -0
- helm/benchmark/run_specs/vlm_run_specs.py +28 -0
- helm/benchmark/runner.py +5 -5
- helm/benchmark/scenarios/aci_bench_scenario.py +7 -1
- helm/benchmark/scenarios/alghafa_scenario.py +126 -0
- helm/benchmark/scenarios/arabic_mmlu_scenario.py +78 -0
- helm/benchmark/scenarios/aratrust_scenario.py +76 -0
- helm/benchmark/scenarios/audio_language/air_bench_chat_scenario.py +3 -1
- helm/benchmark/scenarios/audio_language/air_bench_foundation_scenario.py +5 -5
- helm/benchmark/scenarios/audio_language/casual_conversations2_scenario.py +1 -1
- helm/benchmark/scenarios/audio_language/corebench_scenario.py +77 -0
- helm/benchmark/scenarios/audio_language/mustard_scenario.py +1 -1
- helm/benchmark/scenarios/audio_language/ultra_suite_asr_classification_scenario.py +104 -0
- helm/benchmark/scenarios/audio_language/ultra_suite_asr_transcription_scenario.py +99 -0
- helm/benchmark/scenarios/audio_language/ultra_suite_classification_scenario.py +118 -0
- helm/benchmark/scenarios/audio_language/ultra_suite_disorder_breakdown_scenario.py +86 -0
- helm/benchmark/scenarios/audio_language/ultra_suite_disorder_symptoms_scenario.py +117 -0
- helm/benchmark/scenarios/audio_language/vocal_sound_scenario.py +15 -1
- helm/benchmark/scenarios/audio_language/voxceleb2_scenario.py +1 -2
- helm/benchmark/scenarios/autobencher_capabilities_scenario.py +2 -2
- helm/benchmark/scenarios/bluex_scenario.py +66 -0
- helm/benchmark/scenarios/chw_care_plan_scenario.py +14 -13
- helm/benchmark/scenarios/clear_scenario.py +11 -7
- helm/benchmark/scenarios/cleva_scenario.py +1 -1
- helm/benchmark/scenarios/codeinsights_code_efficiency_scenario.py +197 -0
- helm/benchmark/scenarios/codeinsights_correct_code_scenario.py +78 -0
- helm/benchmark/scenarios/codeinsights_edge_case_scenario.py +192 -0
- helm/benchmark/scenarios/codeinsights_student_coding_scenario.py +162 -0
- helm/benchmark/scenarios/codeinsights_student_mistake_scenario.py +188 -0
- helm/benchmark/scenarios/dischargeme_scenario.py +36 -21
- helm/benchmark/scenarios/ehr_sql_scenario.py +7 -1
- helm/benchmark/scenarios/ehrshot_scenario.py +28 -55
- helm/benchmark/scenarios/exams_multilingual_scenario.py +115 -0
- helm/benchmark/scenarios/grammar.py +2 -2
- helm/benchmark/scenarios/headqa_scenario.py +6 -1
- helm/benchmark/scenarios/healthqa_br_scenario.py +80 -0
- helm/benchmark/scenarios/infinite_bench_en_mc_scenario.py +90 -0
- helm/benchmark/scenarios/infinite_bench_en_qa_scenario.py +85 -0
- helm/benchmark/scenarios/{infinite_bench_sum_scenario.py → infinite_bench_en_sum_scenario.py} +10 -13
- helm/benchmark/scenarios/kpi_edgar_scenario.py +151 -0
- helm/benchmark/scenarios/lmkt_scenarios.py +288 -0
- helm/benchmark/scenarios/math_scenario.py +21 -20
- helm/benchmark/scenarios/med_dialog_scenario.py +6 -1
- helm/benchmark/scenarios/medalign_scenario.py +9 -3
- helm/benchmark/scenarios/medalign_scenario_helper.py +27 -130
- helm/benchmark/scenarios/medbullets_scenario.py +7 -2
- helm/benchmark/scenarios/medcalc_bench_scenario.py +4 -2
- helm/benchmark/scenarios/medec_scenario.py +6 -1
- helm/benchmark/scenarios/medhallu_scenario.py +7 -1
- helm/benchmark/scenarios/medi_qa_scenario.py +10 -4
- helm/benchmark/scenarios/medication_qa_scenario.py +7 -1
- helm/benchmark/scenarios/melt_ir_scenario.py +171 -0
- helm/benchmark/scenarios/melt_knowledge_scenario.py +246 -0
- helm/benchmark/scenarios/melt_lm_scenarios.py +252 -0
- helm/benchmark/scenarios/melt_scenarios.py +793 -0
- helm/benchmark/scenarios/melt_srn_scenario.py +342 -0
- helm/benchmark/scenarios/melt_synthetic_reasoning_scenario.py +222 -0
- helm/benchmark/scenarios/melt_translation_scenario.py +152 -0
- helm/benchmark/scenarios/mental_health_scenario.py +16 -5
- helm/benchmark/scenarios/mimic_bhc_scenario.py +13 -8
- helm/benchmark/scenarios/mimic_rrs_scenario.py +17 -8
- helm/benchmark/scenarios/mimiciv_billing_code_scenario.py +14 -8
- helm/benchmark/scenarios/mmlu_pro_scenario.py +1 -1
- helm/benchmark/scenarios/mmmlu_scenario.py +85 -0
- helm/benchmark/scenarios/mtsamples_procedures_scenario.py +5 -2
- helm/benchmark/scenarios/mtsamples_replicate_scenario.py +3 -2
- helm/benchmark/scenarios/n2c2_ct_matching_scenario.py +11 -5
- helm/benchmark/scenarios/openai_mrcr_scenario.py +79 -0
- helm/benchmark/scenarios/pubmed_qa_scenario.py +6 -1
- helm/benchmark/scenarios/race_based_med_scenario.py +18 -8
- helm/benchmark/scenarios/ruler_qa_scenario_helper.py +2 -2
- helm/benchmark/scenarios/ruler_qa_scenarios.py +2 -2
- helm/benchmark/scenarios/seahelm_scenario.py +2 -2
- helm/benchmark/scenarios/shc_bmt_scenario.py +12 -6
- helm/benchmark/scenarios/shc_cdi_scenario.py +11 -6
- helm/benchmark/scenarios/shc_conf_scenario.py +12 -6
- helm/benchmark/scenarios/shc_ent_scenario.py +11 -6
- helm/benchmark/scenarios/shc_gip_scenario.py +13 -5
- helm/benchmark/scenarios/shc_privacy_scenario.py +78 -0
- helm/benchmark/scenarios/shc_proxy_scenario.py +76 -0
- helm/benchmark/scenarios/shc_ptbm_scenario.py +12 -7
- helm/benchmark/scenarios/shc_sei_scenario.py +12 -7
- helm/benchmark/scenarios/shc_sequoia_scenario.py +13 -5
- helm/benchmark/scenarios/starr_patient_instructions_scenario.py +15 -8
- helm/benchmark/scenarios/test_alghafa_scenario.py +29 -0
- helm/benchmark/scenarios/test_aratrust_scenario.py +21 -0
- helm/benchmark/scenarios/test_bluex_scenario.py +59 -0
- helm/benchmark/scenarios/test_exams_multilingual_scenario.py +29 -0
- helm/benchmark/scenarios/test_healtha_br_scenario.py +57 -0
- helm/benchmark/scenarios/test_infinite_bench_en_qa_scenario.py +18 -0
- helm/benchmark/scenarios/test_infinite_bench_en_sum_scenario.py +31 -0
- helm/benchmark/scenarios/truthful_qa_scenario.py +2 -1
- helm/benchmark/scenarios/vision_language/msr_vtt_scenario.py +75 -0
- helm/benchmark/server.py +2 -1
- helm/benchmark/slurm_jobs.py +1 -2
- helm/benchmark/slurm_runner.py +8 -1
- helm/benchmark/static/schema_arabic.yaml +228 -0
- helm/benchmark/static/schema_audio.yaml +60 -49
- helm/benchmark/static/schema_classic.yaml +0 -17
- helm/benchmark/static/schema_enterprise.yaml +21 -0
- helm/benchmark/static/schema_long_context.yaml +81 -20
- helm/benchmark/static/schema_medhelm.yaml +272 -213
- helm/benchmark/static/schema_melt.yaml +1257 -0
- helm/benchmark/static/schema_slphelm.yaml +162 -0
- helm/benchmark/static/schema_vhelm.yaml +26 -26
- helm/benchmark/static/schema_video.yaml +219 -0
- helm/benchmark/static_build/assets/index-b9779128.css +1 -0
- helm/benchmark/static_build/assets/index-e439d5e1.js +10 -0
- helm/benchmark/static_build/assets/medhelm-overview-eac29843.png +0 -0
- helm/benchmark/static_build/assets/{tremor-9cefc3c5.js → tremor-38a10867.js} +1 -1
- helm/benchmark/static_build/index.html +4 -4
- helm/benchmark/window_services/encoder_decoder_window_service.py +3 -3
- helm/benchmark/window_services/image_generation/clip_window_service.py +1 -3
- helm/benchmark/window_services/test_utils.py +3 -4
- helm/benchmark/window_services/tokenizer_service.py +7 -8
- helm/clients/anthropic_client.py +69 -29
- helm/clients/audio_language/diva_llama_client.py +4 -2
- helm/clients/audio_language/qwen2_5_omni_client.py +209 -0
- helm/clients/audio_language/qwen2_audiolm_client.py +8 -6
- helm/clients/audio_language/qwen_audiolm_client.py +4 -2
- helm/clients/audio_language/test.py +62 -0
- helm/clients/bedrock_client.py +3 -1
- helm/clients/client.py +7 -7
- helm/clients/grok_client.py +36 -0
- helm/clients/huggingface_client.py +42 -3
- helm/clients/huggingface_pipeline_client.py +138 -0
- helm/clients/image_generation/dalle_mini/model/configuration.py +1 -1
- helm/clients/image_generation/dalle_mini/model/modeling.py +1 -1
- helm/clients/image_generation/dalle_mini/model/processor.py +1 -1
- helm/clients/image_generation/dalle_mini/model/tokenizer.py +1 -1
- helm/clients/openai_client.py +102 -55
- helm/clients/openai_responses_client.py +176 -0
- helm/clients/palmyra_client.py +2 -5
- helm/clients/reka_client.py +2 -2
- helm/clients/test_huggingface_client.py +3 -3
- helm/clients/together_client.py +31 -6
- helm/clients/vertexai_client.py +17 -9
- helm/clients/vision_language/huggingface_vision2seq_client.py +6 -4
- helm/clients/vision_language/huggingface_vlm_client.py +2 -2
- helm/clients/vision_language/idefics_client.py +6 -2
- helm/clients/vision_language/paligemma_client.py +2 -2
- helm/clients/vision_language/qwen2_vlm_client.py +66 -53
- helm/clients/vision_language/qwen_vlm_client.py +7 -5
- helm/clients/vllm_client.py +43 -7
- helm/clients/vllm_granite_thinking_client.py +56 -0
- helm/clients/writer_client.py +102 -0
- helm/common/context.py +80 -0
- helm/common/credentials_utils.py +5 -5
- helm/common/critique_request.py +0 -1
- helm/common/general.py +9 -2
- helm/common/hierarchical_logger.py +104 -12
- helm/common/local_context.py +140 -0
- helm/common/object_spec.py +23 -8
- helm/common/remote_context.py +61 -0
- helm/common/request.py +8 -0
- helm/common/test_logging.py +94 -0
- helm/config/model_deployments.yaml +995 -45
- helm/config/model_metadata.yaml +780 -59
- helm/config/tokenizer_configs.yaml +224 -3
- helm/proxy/cli.py +4 -2
- helm/proxy/critique/mechanical_turk_utils.py +1 -1
- helm/proxy/retry.py +5 -0
- helm/proxy/services/server_service.py +21 -85
- helm/tokenizers/grok_tokenizer.py +55 -0
- helm/tokenizers/huggingface_tokenizer.py +1 -1
- helm/tokenizers/test_grok_tokenizer.py +33 -0
- helm/benchmark/metrics/numeracy_metrics.py +0 -72
- helm/benchmark/metrics/test_numeracy_metrics.py +0 -95
- helm/benchmark/scenarios/numeracy_scenario.py +0 -793
- helm/benchmark/scenarios/test_infinite_bench_sum_scenario.py +0 -46
- helm/benchmark/static_build/assets/index-262903c1.js +0 -10
- helm/benchmark/static_build/assets/index-42060d71.css +0 -1
- {crfm_helm-0.5.5.dist-info → crfm_helm-0.5.7.dist-info}/entry_points.txt +0 -0
- {crfm_helm-0.5.5.dist-info → crfm_helm-0.5.7.dist-info}/licenses/LICENSE +0 -0
- {crfm_helm-0.5.5.dist-info → crfm_helm-0.5.7.dist-info}/top_level.txt +0 -0
- /helm/benchmark/static_build/assets/{medhelm-overview-3ddfcd65.png → medhelm-v1-overview-3ddfcd65.png} +0 -0
--- /dev/null
+++ b/helm/benchmark/scenarios/lmkt_scenarios.py
@@ -0,0 +1,288 @@
+"""Cultural alignment evaluation scenario based on Vietnam World Values Survey responses."""
+
+import os
+import json
+import random
+from typing import List
+from datasets import load_dataset
+from huggingface_hub import snapshot_download
+
+from helm.common.hierarchical_logger import hlog, hwarn
+from helm.benchmark.scenarios.scenario import (
+    Scenario,
+    Instance,
+    Reference,
+    TEST_SPLIT,
+    CORRECT_TAG,
+    Input,
+    Output,
+)
+
+SUPPORTED_LANGUAGES = ["en", "vi"]
+
+
+class CulturalValueUnderstandingWVSScenario(Scenario):
+    """Cultural values understanding evaluation based on Vietnam World Values Survey responses."""
+
+    name = "cultural_value_understanding_wvs"
+    description = "Evaluates model understanding of cultural values from WVS Wave 7"
+    tags = ["cultural_value_understanding"]
+
+    def __init__(
+        self,
+        language: str,
+        personas_filename: str = "vn_personas_with_filtered_responses.json",
+        questions_filename: str = "wvs_questions_dump.en.json",
+        repo_id: str = "stair-lab/cultural-alignment-dataset",
+        num_personas: int = 10,
+        num_question_variants: int = 2,
+        include_few_shot_examples: bool = False,
+    ):
+        super().__init__()
+        self.language = language
+        self.personas_filename = personas_filename
+        self.questions_filename = questions_filename
+        self.repo_id = repo_id
+        self.num_personas = 10 if isinstance(num_personas, str) else num_personas
+        self.num_question_variants = min(num_question_variants, 4)
+        self.include_few_shot_examples = include_few_shot_examples
+
+        if self.language not in SUPPORTED_LANGUAGES:
+            raise ValueError(f"Unsupported language: {self.language}. Supported languages are: {SUPPORTED_LANGUAGES}")
+
+    def get_few_shot_examples(self) -> str:
+        """Return a few-shot learning examples formatted as a string."""
+        examples = """Here are some examples:
+
+Example 1:
+You are a 35 year old male from Vietnam. You have higher education and consider yourself middle class. You are married.
+
+How important is family in your life?
+
+1. Very important
+2. Rather important
+3. Not very important
+4. Not at all important
+
+Please select the option number that best represents your view.
+1
+
+Example 2:
+You are a 28 year old female from Vietnam. You have middle education and consider yourself working class. You are single.
+
+Would you say that most people can be trusted or that you need to be very careful in dealing with people?
+
+1. Most people can be trusted
+2. Need to be very careful
+
+Please select the option number that best represents your view.
+2
+
+Example 3:
+You are a 50 year old male from Vietnam. You have lower education and consider yourself lower class. You are married.
+
+Do you think that homosexuality is justifiable?
+
+1. Never justifiable
+2. Rarely justifiable
+3. Sometimes justifiable
+4. Always justifiable
+
+Please select the option number that best represents your view.
+1
+
+Now answer the following question:
+""" # noqa: E501
+        return examples
+
+    def get_instances(self, output_path: str) -> List[Instance]:
+        """Generate test instances from Vietnam personas and WVS questions."""
+        instances: List[Instance] = []
+
+        # Download files from Hugging Face Hub
+        repo_local_path = snapshot_download(
+            repo_id=self.repo_id, repo_type="dataset", revision="fe54b6f5d75cfca5377707cd7199e39f517e3a1f"
+        )
+
+        # Load the downloaded files
+        with open(os.path.join(repo_local_path, self.personas_filename), "r", encoding="utf-8") as f:
+            personas = json.load(f)
+
+        with open(os.path.join(repo_local_path, self.questions_filename), "r", encoding="utf-8") as f:
+            questions = json.load(f)
+
+        # Get few-shot examples
+        few_shot_examples = self.get_few_shot_examples() if self.include_few_shot_examples else ""
+
+        # Sample personas
+        sampled_personas = random.sample(personas, min(self.num_personas, len(personas)))
+
+        # Create instances for each persona and question
+        for persona in sampled_personas:
+            # Get demographic info for persona description
+            persona_desc = (
+                f"You are a {persona.get('age', 'adult')} year old {persona.get('sex', 'person')} from Vietnam. "
+            )
+            persona_desc += f"You have {persona.get('education', 'some')} education and consider yourself {persona.get('social_class', 'middle class')}. " # noqa: E501
+            persona_desc += f"You are {persona.get('marital_status', 'single')}."
+
+            # Process each question this persona answered
+            for qid, human_response in persona.get("responses", {}).items():
+                # Skip if no human response or if it's 0 (which might be a "Don't know" response)
+                if human_response is None:
+                    continue
+
+                # Convert human_response to int (if possible)
+                try:
+                    human_response_int = int(human_response)
+                except (ValueError, TypeError):
+                    # Skip if human_response can't be converted to int
+                    continue
+
+                # Get question info
+                question_data = questions.get(qid, {})
+                if not question_data:
+                    continue
+
+                # Get options directly from question data
+                q_options = question_data.get("options", [])
+                if not q_options:
+                    continue
+
+                # Skip if human_response is out of range
+                if human_response_int < 0 or human_response_int > len(q_options):
+                    continue
+
+                # Special handling for "Don't know" or zero responses
+                if human_response_int == 0:
+                    # Some questions might encode "Don't know" as 0
+                    # Skip for now, or you could add special handling
+                    continue
+
+                # Use the predefined question variations
+                question_variants = question_data.get("questions", [])
+                if not question_variants:
+                    question_variants = [f"Question {qid}: {question_data.get('description', '')}"]
+
+                # Use the specified number of variants
+                variants_to_use = min(self.num_question_variants, len(question_variants))
+                selected_variants = question_variants[:variants_to_use]
+
+                # Create instances for each selected question variant
+                for q_text in selected_variants:
+                    # Format the prompt with or without few-shot examples
+                    if self.include_few_shot_examples:
+                        prompt = f"{few_shot_examples}{persona_desc}\n\n{q_text}\n\n"
+                    else:
+                        prompt = f"{persona_desc}\n\n{q_text}\n\n"
+
+                    # Add options from question data - with numbers, not letters
+                    for i, opt in enumerate(q_options, 1):
+                        prompt += f"{i}. {opt}\n"
+
+                    prompt += "\nPlease select the option number that best represents your view. Return only the option number. Do not return anything else." # noqa: E501
+
+                    # Create a reference with just the human response number
+                    # We don't create multiple references, just use the actual human response
+                    reference = Reference(Output(text=str(human_response_int)), tags=[CORRECT_TAG])
+
+                    # Create the instance
+                    instance = Instance(
+                        input=Input(text=prompt),
+                        references=[reference],
+                        split=TEST_SPLIT,
+                    )
+                    instances.append(instance)
+
+        if not instances:
+            hwarn("No valid instances were created. Check the input data and parameters.")
+        else:
+            hlog(f"Created {len(instances)} instances for cultural value understanding scenario.")
+
+        return instances
+
+
+class SocialNormApplicationNormADScenario(Scenario):
+    """Social norms application evaluation based on Vietnam World Values Survey responses."""
+
+    name = "social_norm_application_normad"
+    description = "Evaluates model understanding of social norms from WVS Wave 7"
+    tags = ["social_norm_application"]
+
+    def __init__(
+        self,
+        language: str,
+    ):
+        super().__init__()
+        self.language = language
+
+        if self.language not in SUPPORTED_LANGUAGES:
+            raise ValueError(f"Unsupported language: {self.language}. Supported languages are: {SUPPORTED_LANGUAGES}")
+
+    def get_instances(self, output_path: str) -> List[Instance]:
+        dataset = load_dataset(
+            "akhilayerukola/NormAd", split="train", revision="311049f5610e932792672e4e13e7df62b6290336"
+        )
+
+        instances: List[Instance] = []
+
+        for item in dataset:
+            input_text = (
+                f"Assumming that below norms and story are from {item['Country'].capitalize()}:\n"
+                f"{item['Background']}\n\n{item['Story']}\n"
+                "Return one of the following options: 'yes', 'no', or 'neutral'. Do not return anything else.\n\n"
+            )
+
+            instance = Instance(
+                input=Input(text=input_text),
+                references=[Reference(Output(text=item["Gold Label"]), tags=[CORRECT_TAG])],
+                split=TEST_SPLIT,
+            )
+            instances.append(instance)
+
+        return instances
+
+
+class SocialNormExplanationNormADScenario(Scenario):
+    """Social norms explanation evaluation based on Vietnam World Values Survey responses."""
+
+    name = "social_norm_explanation_normad"
+    description = "Evaluates model understanding of social norms from WVS Wave 7"
+    tags = ["social_norm_explanation"]
+
+    def __init__(
+        self,
+        language: str,
+    ):
+        super().__init__()
+        self.language = language
+
+        if self.language not in SUPPORTED_LANGUAGES:
+            raise ValueError(f"Unsupported language: {self.language}. Supported languages are: {SUPPORTED_LANGUAGES}")
+
+    def get_instances(self, output_path: str) -> List[Instance]:
+        dataset = load_dataset(
+            "akhilayerukola/NormAd", split="train", revision="311049f5610e932792672e4e13e7df62b6290336"
+        )
+
+        instances: List[Instance] = []
+
+        for item in dataset:
+            if not item["Explanation"]:
+                continue
+
+            input_text = (
+                f"Assumming that below norms and story are from {item['Country'].capitalize()}:\n"
+                f"{item['Background']}\n\n{item['Story']}\n"
+                f"The answer is {item['Gold Label']}. "
+                "Briefly explain the reasoning behind this answer in one or two sentences.\n\n"
+            )
+
+            instance = Instance(
+                input=Input(text=input_text),
+                references=[Reference(Output(text=item["Explanation"]), tags=[CORRECT_TAG])],
+                split=TEST_SPLIT,
+            )
+            instances.append(instance)
+
+        return instances
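
The new scenario module pairs with helm/benchmark/run_specs/lmkt_run_specs.py, which this release also adds (144 lines per the file list above). As a rough sketch of how a scenario like this is typically wired into HELM, a run spec function builds a ScenarioSpec pointing at the scenario class. This is an illustration only: the spec name, adapter settings, and metric choices below are assumptions, not the shipped contents of lmkt_run_specs.py.

    # Sketch only: a plausible HELM run spec for the scenario above.
    # Assumed details: the spec name, adapter settings, and metrics.
    from helm.benchmark.adaptation.common_adapter_specs import get_generation_adapter_spec
    from helm.benchmark.metrics.common_metric_specs import get_exact_match_metric_specs
    from helm.benchmark.run_spec import RunSpec, run_spec_function
    from helm.benchmark.scenarios.scenario import ScenarioSpec

    @run_spec_function("cultural_value_understanding_wvs")
    def get_cultural_value_understanding_wvs_spec(language: str = "en") -> RunSpec:
        scenario_spec = ScenarioSpec(
            class_name="helm.benchmark.scenarios.lmkt_scenarios.CulturalValueUnderstandingWVSScenario",
            args={"language": language},
        )
        # The scenario asks for a bare option number, so a small generation budget suffices.
        adapter_spec = get_generation_adapter_spec(max_tokens=5, num_outputs=1)
        return RunSpec(
            name=f"cultural_value_understanding_wvs:language={language}",
            scenario_spec=scenario_spec,
            adapter_spec=adapter_spec,
            metric_specs=get_exact_match_metric_specs(),
            groups=["cultural_value_understanding_wvs"],
        )

A spec registered this way would then be runnable with something along the lines of helm-run --run-entries cultural_value_understanding_wvs:language=en --suite my-suite --max-eval-instances 10.
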
--- a/helm/benchmark/scenarios/math_scenario.py
+++ b/helm/benchmark/scenarios/math_scenario.py
@@ -18,13 +18,14 @@ from helm.benchmark.scenarios.scenario import (
 
 
 def remove_boxed(string: str) -> Optional[str]:
-    """Source: https://github.com/hendrycks/math
+    r"""Source: https://github.com/hendrycks/math
 
-    Extract the text within a
+    Extract the text within a \boxed{...} environment.
 
     Example:
-
-
+    >>> from helm.benchmark.scenarios.math_scenario import *  # NOQA
+    >>> remove_boxed(r'\boxed{\frac{2}{3}}')
+    '\\frac{2}{3}'
     """
     left = "\\boxed{"
     try:
@@ -68,17 +69,17 @@ def last_boxed_only_string(string: str) -> Optional[str]:
 
 
 def _fix_fracs(string: str) -> str:
-    """Source: https://github.com/hendrycks/math
+    r"""Source: https://github.com/hendrycks/math
 
     Reformat fractions.
 
     Examples:
-
-
-
-
-
-
+    >>> _fix_fracs(r"\frac1b")
+    '\\frac{1}{b}'
+    >>> _fix_fracs(r"\frac12")
+    '\\frac{1}{2}'
+    >>> _fix_fracs(r"\frac1{72}")
+    '\\frac{1}{72}'
     """
     substrs = string.split("\\frac")
     new_str = substrs[0]
@@ -112,13 +113,13 @@ def _fix_fracs(string: str) -> str:
 
 
 def _fix_a_slash_b(string: str) -> str:
-    """Source: https://github.com/hendrycks/math
+    r"""Source: https://github.com/hendrycks/math
 
     Reformat fractions formatted as a/b to \\frac{a}{b}.
 
     Example:
-
-
+    >>> _fix_a_slash_b(r"2/3")
+    '\\frac{2}{3}'
     """
     if len(string.split("/")) != 2:
         return string
@@ -149,13 +150,13 @@ def _remove_right_units(string: str) -> str:
 
 
 def _fix_sqrt(string: str) -> str:
-    """Source: https://github.com/hendrycks/math
+    r"""Source: https://github.com/hendrycks/math
 
     Reformat square roots.
 
     Example:
-
-
+    >>> _fix_sqrt("\\sqrt3")
+    '\\sqrt{3}'
     """
     if "\\sqrt" not in string:
         return string
@@ -210,7 +211,7 @@ def _strip_string(string: str) -> str:
 
     # remove percentage
     string = string.replace("\\%", "")
-    string = string.replace("\%", "")
+    string = string.replace(r"\%", "")
 
     # " 0." equivalent to " ." and "{0." equivalent to "{." Alternatively, add "0" if "." is the start of the string
     string = string.replace(" .", " 0.")
@@ -391,13 +392,13 @@ class MATHScenario(Scenario):
         for split, split_name in zip([TRAIN_SPLIT, TEST_SPLIT], ["train", "test"]):
             if split == TRAIN_SPLIT and self.use_official_examples:
                 train_instances = [
-                    ("What is
+                    ("What is $\\left(\\frac{7}{8}\\right)^3 \\cdot \\left(\\frac{7}{8}\\right)^{-3}$?", "1"),
                     (
                         "In how many ways can 4 books be selected from a shelf of 6 books"
                         + " if the order in which the books are selected does not matter?",
                         "15",
                     ),
-                    ("Find the distance between the points $(2,1,-4)$ and $(5,8,-3).$", "
+                    ("Find the distance between the points $(2,1,-4)$ and $(5,8,-3).$", "\\sqrt{59}"),
                     (
                         "The faces of an octahedral die are labeled with digits $1$ through $8$."
                         + " What is the probability, expressed as a common fraction,"
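
Every hunk above makes the same two-part fix: the docstrings gain an r prefix, and the doctest examples are restored. The raw-string prefix is what makes those examples safe to write: in a plain docstring, backslash sequences such as \b (in \boxed) and \f (in \frac) are interpreted as escape characters rather than kept literally. A minimal illustration, independent of the HELM codebase:

    # Why the r""" prefix matters for LaTeX-heavy docstrings.
    plain = "\boxed{\frac{2}{3}}"  # "\b" becomes a backspace, "\f" a form feed
    raw = r"\boxed{\frac{2}{3}}"   # backslashes are preserved verbatim

    print(len(plain), len(raw))  # 17 19 -- two characters were swallowed as escapes
    assert raw == "\\boxed{\\frac{2}{3}}"

With raw docstrings, doctest can match the quoted expected outputs (such as '\\frac{2}{3}') exactly.
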
--- a/helm/benchmark/scenarios/med_dialog_scenario.py
+++ b/helm/benchmark/scenarios/med_dialog_scenario.py
@@ -90,7 +90,12 @@ class MedDialogScenario(Scenario):
     """
 
     name = "med_dialog"
-    description =
+    description = (
+        "MedDialog is a benchmark of real-world doctor-patient conversations focused on health-related"
+        "concerns and advice. Each dialogue is paired with a one-sentence summary"
+        "that reflects the core patient question or exchange. The benchmark evaluates a model's"
+        "ability to condense medical dialogue into concise, informative summaries."
+    )
     tags = ["dialogue", "biomedical"]
 
     def __init__(self, subset: str):
--- a/helm/benchmark/scenarios/medalign_scenario.py
+++ b/helm/benchmark/scenarios/medalign_scenario.py
@@ -60,12 +60,18 @@ class MedalignScenario(Scenario):
     """
 
     name = "medalign"
-    description =
+    description = (
+        "MedAlign is a benchmark that evaluates a model's ability to interpret and follow"
+        "instructions grounded in longitudinal electronic health records (EHR). Each instance"
+        "includes an event-stream style patient record and a natural language question or task,"
+        "requiring clinically informed reading comprehension and reasoning."
+    )
     tags = ["knowledge", "reasoning", "biomedical"]
 
-    def __init__(self, max_length: int):
+    def __init__(self, max_length: int, data_path: str):
         super().__init__()
         self.max_length = max_length
+        self.data_path = data_path
 
     def process_tsv(self, data) -> List[Instance]:
         instances: List[Instance] = []
@@ -84,5 +90,5 @@ class MedalignScenario(Scenario):
         return instances
 
     def get_instances(self, output_path: str) -> List[Instance]:
-        dataset = return_dataset_dataframe(self.max_length)
+        dataset = return_dataset_dataframe(self.max_length, self.data_path)
         return self.process_tsv(dataset)
--- a/helm/benchmark/scenarios/medalign_scenario_helper.py
+++ b/helm/benchmark/scenarios/medalign_scenario_helper.py
@@ -2,23 +2,15 @@
 # type: ignore
 # fmt: off
 
-import ast
-import datetime
 import transformers
-import langchain
-import langchain.prompts
-import lxml.etree
 import os
 import pandas as pd
-import re
 import tiktoken
 
-from langchain_community.retrievers import BM25Retriever
 from tqdm import tqdm
-from typing import Any, Dict, Optional,
-from langchain.schema import Document
-import langchain_community
+from typing import Any, Dict, Optional, Callable
 
+from helm.common.general import check_file_exists
 
 
 def get_instructions(path_to_instructions: str) -> Dict[int, Dict[str, Any]]:
@@ -166,102 +158,13 @@ def get_tokenizer(tokenizer_name: str) -> Callable:
     return transformers.AutoTokenizer.from_pretrained(tokenizer_name, legacy=False)
 
 
-def retrieve_most_relevant_visits(ehr_visit_strs, query, target_length, tokenizer):
-    """
-    Retrieve and filter relevant EHR visits based on a query and target length.
-
-    This function retrieves electronic health record (EHR) visit strings, sorts them
-    by relevance using the BM25Retriever, and constructs a list of final documents
-    that fit within a specified character length. The final list ensures that the
-    most important visit isn't cut off and is sorted chronologically.
-
-    Parameters:
-        ehr_visit_strs (list of str): List of EHR visit strings.
-        query (str): Query string to retrieve relevant visits.
-        target_length (int): Maximum total token count for the final list of documents.
-        tokenizer (Callable): Tokenizer that converts text to tokens (used for tracking context length)
-
-    Returns:
-        list[str]: List of EHR visit strings sorted chronologically and constrained by the target length.
-    """
-    ehr_visits=re.split(r'(?=</encounter>\n)',ehr_visit_strs)
-    langchain_docs = [
-        langchain.schema.Document(page_content=doc) for doc in ehr_visits #broken since ehr_visit_strs is one string of all visits
-    ]
-    # `k` is the number of documents to retrieve
-    # We retrieve everything and just use the BM25Retriever to sort the documents
-    retriever = langchain_community.retrievers.BM25Retriever.from_documents(
-        langchain_docs, k=len(langchain_docs)
-    )
-
-    # Invoking the retriever means the most relevant documents are sorted first
-    sorted_docs = retriever.invoke(query)
-
-    # Define the regex pattern to find the start time
-    # pattern = r'start="([\d/]+ [\d:]+)"'
-    pattern = r'start="([\d/]+ [\d:]+ ?[APM]{0,2})"'
-
-    docs = []
-    dts = []
-
-    # Find the startime of the document
-    for doc in sorted_docs:
-        doc_content = doc.page_content
-        start_dt_match = re.search(pattern, doc_content)
-        if start_dt_match:
-            start_dt = start_dt_match.group(1)
-            parsed = False
-            # Try different date formats
-            for fmt in (
-                "%m/%d/%y %I:%M %p",
-                "%m/%d/%Y %I:%M %p",
-                "%m/%d/%y %H:%M",
-                "%m/%d/%Y %H:%M",
-            ):
-                try:
-                    dts.append(datetime.datetime.strptime(start_dt, fmt))
-                    parsed = True
-                    break
-                except ValueError:
-                    continue
-            if not parsed:
-                print(f"Error parsing date: {start_dt}")
-                continue
-        else:
-            print(f"Start time not found., {doc_content}")
-            dts.append(datetime.datetime.min)
-        docs.append(doc_content)
-
-    final_docs = []
-    current_length = 0
-
-    # Add documents until we exceed the allocated context length
-    for i in range(len(docs)):
-        doc_content = docs[i]
-        doc_length = len(tokenizer.encode(doc_content))
-        final_docs.append((dts[i], doc_content))
-        current_length += doc_length
-        if current_length > target_length:
-            break
-
-    # Sort final_docs chronologically
-    final_docs.sort(key=lambda x: x[0])
-
-    # Extract only the document content for the final output
-    final_docs_content = [doc_content for _, doc_content in final_docs]
-
-    return final_docs_content
-
-
-
 def pack_and_trim_prompts(
     instructions: Dict[int, Dict[str, str]],
     ehrs: Dict[int, str],
-
+    prompt_string: str,
     context_length: int,
     generation_length: int,
     tokenizer: Any,
-    use_RAG: bool = True,
     verbose: bool = False,
     include_ehr: bool = True,
 ) -> Dict[int, str]:
@@ -275,26 +178,15 @@ def pack_and_trim_prompts(
         patient_id = int(instructions[instruction_id]["patient_id"])
         relevant_ehr = ehrs[patient_id]
 
-        # Calculate how many tokens of EHR we can include in the prompt
         num_tokens_instruction = len(tokenizer.encode(instruction))
-        num_tokens_prompt_template = len(tokenizer.encode(
+        num_tokens_prompt_template = len(tokenizer.encode(prompt_string))
         if include_ehr:
            target_ehr_length = context_length - generation_length - num_tokens_prompt_template - num_tokens_instruction
        else:
            target_ehr_length = 0
        if target_ehr_length <= 0:
-            prompt_with_truncated_ehr =
+            prompt_with_truncated_ehr = prompt_string.format(question=instruction, ehr="")
         else:
-            if use_RAG:
-                # Return a list of the most relevant visit strings
-                most_relevant_visits = retrieve_most_relevant_visits(
-                    ehr_visit_strs=relevant_ehr,
-                    query=instruction,
-                    target_length=target_ehr_length,
-                    tokenizer=tokenizer,
-                )
-                relevant_ehr = "\n".join(most_relevant_visits)
-
             # Do a first pass with a fast tokenizer
             fast_tokenizer = tiktoken.get_encoding("cl100k_base")
             fast_encoded = fast_tokenizer.encode(relevant_ehr)
@@ -306,13 +198,17 @@ def pack_and_trim_prompts(
                 encoded_ehr = tokenizer.encode(fast_truncated_ehr)
                 truncated_encoded_ehr = encoded_ehr[-target_ehr_length:]
                 truncated_ehr = tokenizer.decode(truncated_encoded_ehr)
-                prompt_with_truncated_ehr =
+                prompt_with_truncated_ehr = prompt_string.format(question=instruction, ehr=truncated_ehr)
+            else:
+                # If the fast encoding is still too long, just use the full EHR up to allowed length
+                truncated_ehr = fast_tokenizer.decode(fast_encoded[-target_ehr_length:])
+                prompt_with_truncated_ehr = prompt_string.format(question=instruction, ehr=truncated_ehr)
 
-
+        prompts_map[instruction_id] = prompt_with_truncated_ehr
 
-
-
-
+        if verbose:
+            print(prompt_with_truncated_ehr)
+            print("~" * 20)
     return prompts_map
 
 
@@ -321,7 +217,6 @@ def preprocess_prompts(
     generation_length,
     path_to_instructions,
     path_to_ehrs,
-    use_RAG,
     include_ehr,
     tokenizer,
     codes_only=False,
@@ -346,16 +241,18 @@ def preprocess_prompts(
 
     # CONSTRUCT & TRUNCATE PROMPTS #
     print("Constructing prompts using instructions and EHRs...")
-    prompt_string=
-
+    prompt_string = (
+        "Instruction: Answer the following question based on the EHR:\n\n"
+        "EHR: {ehr}\n\nQuestion: {question}\n\nAnswer:"
+    )
+
     filled_prompts = pack_and_trim_prompts(
         instructions=instructions,
         ehrs=ehrs,
-
+        prompt_string=prompt_string,
         context_length=target_context_length,
         generation_length=generation_length,
         tokenizer=tokenizer,
-        use_RAG=use_RAG,
         verbose=False,
         include_ehr=include_ehr,
     )
@@ -399,20 +296,21 @@ def add_reference_responses(prompts_df, path_to_reference_responses) -> pd.DataFrame:
     Returns:
         pd.DataFrame: DataFrame containing the processed data.
     """
-    gold_df = pd.read_csv(path_to_reference_responses)
+    gold_df = pd.read_csv(path_to_reference_responses, sep='\t')
     gold_df = gold_df.query("annotator_num == 'Annotator_1'")
     gold_df = gold_df[["instruction_id", "clinician_response"]]
     merged_df = gold_df.merge(prompts_df, on="instruction_id", how="inner")
     return merged_df
 
 
-def return_dataset_dataframe(max_length: int) -> pd.DataFrame:
+def return_dataset_dataframe(max_length: int, data_path: str) -> pd.DataFrame:
     target_context_length = max_length
     generation_length = 256
-    path_to_instructions = "
-
-
-
+    path_to_instructions = os.path.join(data_path, "clinician-reviewed-model-responses.tsv")
+    check_file_exists(path_to_instructions, msg=f"[MedAlignScenario] Required instructions file not found: '{path_to_instructions}'")
+    path_to_ehrs = os.path.join(data_path, "medalign_ehr_xml")
+    path_to_reference_responses = os.path.join(data_path, "clinician-instruction-responses.tsv")
+    check_file_exists(path_to_reference_responses, msg=f"[MedAlignScenario] Required clinician responses file not found: '{path_to_reference_responses}'")
     include_ehr = True
     tokenizer = "tiktoken"
 
@@ -421,7 +319,6 @@ def return_dataset_dataframe(max_length: int) -> pd.DataFrame:
         generation_length=generation_length,
         path_to_instructions=path_to_instructions,
         path_to_ehrs=path_to_ehrs,
-        use_RAG=use_RAG,
         include_ehr=include_ehr,
         tokenizer=tokenizer,
     )
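
Taken together, these hunks remove the BM25 retrieval path (and its langchain dependency) from the MedAlign helper: the EHR is now always tail-truncated to the token budget left after the prompt template, instruction, and generation length are subtracted from the context window, and the data files are loaded from a caller-supplied data_path instead of a hardcoded location. A standalone sketch of the truncation step, assuming tiktoken is installed (the budget value is made up):

    # Sketch of the tail-truncation used by pack_and_trim_prompts: keep the
    # last `budget` tokens of the record, i.e. the most recent visits.
    import tiktoken

    enc = tiktoken.get_encoding("cl100k_base")

    def truncate_ehr_tail(ehr: str, budget: int) -> str:
        if budget <= 0:
            return ""
        tokens = enc.encode(ehr)
        return enc.decode(tokens[-budget:])

    ehr = "visit 1: flu, resolved. " * 500 + "visit 2: follow-up, stable."
    print(truncate_ehr_tail(ehr, budget=16))  # the tail of the record survives

The shipped code still does a first pass with the fast cl100k_base encoding before re-encoding with the model's own tokenizer, so very long records are not tokenized twice by a slow tokenizer.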