crfm-helm 0.5.1__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of crfm-helm might be problematic.

Files changed (98)
  1. {crfm_helm-0.5.1.dist-info → crfm_helm-0.5.2.dist-info}/METADATA +13 -3
  2. {crfm_helm-0.5.1.dist-info → crfm_helm-0.5.2.dist-info}/RECORD +96 -63
  3. helm/benchmark/adaptation/adapter_spec.py +32 -31
  4. helm/benchmark/annotation/air_bench_annotator.py +64 -0
  5. helm/benchmark/annotation/annotator_factory.py +6 -0
  6. helm/benchmark/annotation/live_qa_annotator.py +84 -0
  7. helm/benchmark/annotation/medication_qa_annotator.py +81 -0
  8. helm/benchmark/augmentations/translate_perturbation.py +1 -0
  9. helm/benchmark/huggingface_registration.py +16 -6
  10. helm/benchmark/metrics/air_bench_metrics.py +56 -0
  11. helm/benchmark/metrics/fin_qa_metrics.py +60 -0
  12. helm/benchmark/metrics/fin_qa_metrics_helper.py +398 -0
  13. helm/benchmark/metrics/gpt4v_originality_critique_metrics.py +126 -0
  14. helm/benchmark/metrics/instruction_following_critique_metrics.py +1 -0
  15. helm/benchmark/metrics/live_qa_metrics.py +23 -0
  16. helm/benchmark/metrics/medication_qa_metrics.py +23 -0
  17. helm/benchmark/metrics/prometheus_vision_critique_metrics.py +185 -0
  18. helm/benchmark/metrics/reka_vibe_critique_metrics.py +158 -0
  19. helm/benchmark/metrics/unitxt_metrics.py +20 -10
  20. helm/benchmark/metrics/vision_language/emd_utils.py +4 -0
  21. helm/benchmark/metrics/vision_language/image_metrics.py +29 -71
  22. helm/benchmark/presentation/schema.py +54 -4
  23. helm/benchmark/presentation/test_schema.py +11 -0
  24. helm/benchmark/run.py +16 -2
  25. helm/benchmark/run_expander.py +77 -0
  26. helm/benchmark/run_spec_factory.py +4 -0
  27. helm/benchmark/run_specs/air_bench_run_specs.py +40 -0
  28. helm/benchmark/run_specs/classic_run_specs.py +15 -11
  29. helm/benchmark/run_specs/decodingtrust_run_specs.py +3 -1
  30. helm/benchmark/run_specs/experimental_run_specs.py +33 -0
  31. helm/benchmark/run_specs/finance_run_specs.py +33 -0
  32. helm/benchmark/run_specs/vlm_run_specs.py +168 -45
  33. helm/benchmark/scenarios/air_bench_scenario.py +50 -0
  34. helm/benchmark/scenarios/ci_mcqa_scenario.py +80 -0
  35. helm/benchmark/scenarios/entity_data_imputation_scenario.py +8 -2
  36. helm/benchmark/scenarios/fin_qa_scenario.py +117 -0
  37. helm/benchmark/scenarios/test_air_bench_scenario.py +27 -0
  38. helm/benchmark/scenarios/vision_language/bingo_scenario.py +3 -3
  39. helm/benchmark/scenarios/vision_language/image2structure/image2structure_scenario.py +13 -2
  40. helm/benchmark/scenarios/vision_language/image2structure/latex_scenario.py +1 -5
  41. helm/benchmark/scenarios/vision_language/image2structure/musicsheet_scenario.py +0 -4
  42. helm/benchmark/scenarios/vision_language/image2structure/webpage_scenario.py +4 -2
  43. helm/benchmark/scenarios/vision_language/pairs_scenario.py +6 -5
  44. helm/benchmark/scenarios/vision_language/unicorn_scenario.py +3 -3
  45. helm/benchmark/scenarios/vision_language/vibe_eval_scenario.py +95 -0
  46. helm/benchmark/static/schema_air_bench.yaml +3149 -0
  47. helm/benchmark/static/schema_classic.yaml +3 -59
  48. helm/benchmark/static/schema_finance.yaml +143 -0
  49. helm/benchmark/static/schema_image2structure.yaml +254 -111
  50. helm/benchmark/static/schema_instruction_following.yaml +3 -52
  51. helm/benchmark/static/schema_lite.yaml +3 -61
  52. helm/benchmark/static/schema_medical.yaml +255 -0
  53. helm/benchmark/static/schema_mmlu.yaml +3 -61
  54. helm/benchmark/static/schema_tables.yaml +200 -0
  55. helm/benchmark/static/schema_thai.yaml +223 -0
  56. helm/benchmark/static/schema_unitxt.yaml +3 -61
  57. helm/benchmark/static/{schema_vlm.yaml → schema_vhelm.yaml} +294 -293
  58. helm/benchmark/static/schema_vhelm_lite.yaml +4 -59
  59. helm/benchmark/static_build/assets/air-overview-d2e6c49f.png +0 -0
  60. helm/benchmark/static_build/assets/index-30dbceba.js +10 -0
  61. helm/benchmark/static_build/assets/index-66b02d40.css +1 -0
  62. helm/benchmark/static_build/assets/overview-74aea3d8.png +0 -0
  63. helm/benchmark/static_build/assets/process-flow-bd2eba96.png +0 -0
  64. helm/benchmark/static_build/index.html +2 -2
  65. helm/clients/anthropic_client.py +43 -9
  66. helm/clients/auto_client.py +11 -0
  67. helm/clients/client.py +24 -7
  68. helm/clients/cohere_client.py +98 -3
  69. helm/clients/huggingface_client.py +71 -12
  70. helm/clients/openai_client.py +9 -2
  71. helm/clients/reka_client.py +189 -0
  72. helm/clients/test_client.py +3 -3
  73. helm/clients/test_huggingface_client.py +19 -3
  74. helm/clients/test_together_client.py +72 -2
  75. helm/clients/together_client.py +129 -23
  76. helm/clients/vertexai_client.py +62 -18
  77. helm/clients/vision_language/huggingface_vlm_client.py +1 -0
  78. helm/clients/vision_language/paligemma_client.py +146 -0
  79. helm/clients/vision_language/palmyra_vision_client.py +84 -0
  80. helm/clients/yi_client.py +31 -0
  81. helm/common/critique_request.py +10 -1
  82. helm/common/images_utils.py +19 -0
  83. helm/config/model_deployments.yaml +412 -18
  84. helm/config/model_metadata.yaml +447 -25
  85. helm/config/tokenizer_configs.yaml +93 -1
  86. helm/proxy/critique/model_critique_client.py +32 -4
  87. helm/proxy/services/server_service.py +1 -1
  88. helm/tokenizers/auto_tokenizer.py +1 -1
  89. helm/tokenizers/cohere_tokenizer.py +44 -2
  90. helm/tokenizers/huggingface_tokenizer.py +36 -13
  91. helm/tokenizers/test_cohere_tokenizer.py +39 -0
  92. helm/tokenizers/test_huggingface_tokenizer.py +5 -1
  93. helm/benchmark/static_build/assets/index-737eef9e.js +0 -10
  94. helm/benchmark/static_build/assets/index-878a1094.css +0 -1
  95. {crfm_helm-0.5.1.dist-info → crfm_helm-0.5.2.dist-info}/LICENSE +0 -0
  96. {crfm_helm-0.5.1.dist-info → crfm_helm-0.5.2.dist-info}/WHEEL +0 -0
  97. {crfm_helm-0.5.1.dist-info → crfm_helm-0.5.2.dist-info}/entry_points.txt +0 -0
  98. {crfm_helm-0.5.1.dist-info → crfm_helm-0.5.2.dist-info}/top_level.txt +0 -0
helm/benchmark/annotation/air_bench_annotator.py

@@ -0,0 +1,64 @@
+import datasets
+import os
+import re
+from typing import Any
+
+from helm.common.general import ensure_directory_exists
+from helm.benchmark.adaptation.request_state import RequestState
+from helm.benchmark.annotation.annotator import Annotator
+from helm.clients.auto_client import AutoClient
+from helm.common.request import Request
+
+
+class AIRBench2024Annotator(Annotator):
+    """The AIRBench 2024 autograder."""
+
+    name = "air_bench_2024"
+
+    def __init__(self, auto_client: AutoClient, file_storage_path: str):
+        self._auto_client = auto_client
+        cache_dir = os.path.join(file_storage_path, "data")
+        ensure_directory_exists(cache_dir)
+        dataset = datasets.load_dataset(
+            "stanford-crfm/air-bench-2024", "judge_prompts", split="test", cache_dir=cache_dir
+        )
+        self._category_id_to_judge_prompt = {row["cate-idx"]: row["judge_prompt"] for row in dataset}
+        # Regex pattern is lenient to allow for typos e.g. extra whitespace
+        self._pattern = re.compile("##\s*short_reasoning\s*:(.*)##\s*the_score\s*:(.*)", re.DOTALL)
+
+    def annotate(self, request_state: RequestState) -> Any:
+        assert request_state.result
+        assert len(request_state.result.completions) == 1
+        model_input_text = request_state.request.prompt
+        model_output_text = request_state.result.completions[0].text
+        if not model_output_text.strip():
+            return {"prompt_text": "", "reasoning": "BLOCKED_REQUEST_OR_EMPTY_RESPONSE", "score": 0.0}
+        category_id = request_state.instance.references[0].output.text
+        prompt_template = self._category_id_to_judge_prompt[category_id]
+        # Strip to deal with incorrectly formatted input CSV.
+        # TODO: Stop stripping after CSV is fixed.
+        annotator_prompt = prompt_template.replace("{{QUESTION}}", model_input_text).replace(
+            "{{ANSWER}}", model_output_text
+        )
+        annotator_request = Request(
+            model="openai/gpt-4o-2024-05-13",
+            model_deployment="openai/gpt-4o-2024-05-13",
+            prompt=annotator_prompt,
+            temperature=0.0,
+            max_tokens=64,
+        )
+        annotator_response = self._auto_client.make_request(annotator_request)
+        if not annotator_response.success:
+            raise Exception(f"Annotation request failed: {annotator_response.error}")
+        assert len(annotator_response.completions) == 1
+        annotator_response_text = annotator_response.completions[0].text
+        annotator_response_parts = self._pattern.search(annotator_response_text)
+        if not annotator_response_parts:
+            raise Exception(f"Malformed annotator response: {annotator_response_text}")
+        reasoning = annotator_response_parts[1].strip()
+        try:
+            score = float(annotator_response_parts[2].strip())
+        except ValueError as e:
+            raise Exception(f"Malformed annotator response: {annotator_response_text}") from e
+
+        return {"prompt_text": annotator_prompt, "reasoning": reasoning, "score": score}
helm/benchmark/annotation/annotator_factory.py

@@ -1,6 +1,7 @@
 import os
 from typing import Any, Dict, Mapping, Optional
 
+from helm.clients.auto_client import AutoClient
 from helm.common.credentials_utils import provide_api_key
 from helm.common.cache_backend_config import CacheBackendConfig, CacheConfig
 from helm.common.hierarchical_logger import hlog
@@ -46,6 +47,11 @@ class AnnotatorFactory:
                provider_bindings={
                    "api_key": lambda: provide_api_key(self.credentials, annotator_name),
                    "file_storage_path": lambda: self._get_file_storage_path(annotator_name),
+                   "auto_client": lambda: AutoClient(
+                       credentials=self.credentials,
+                       file_storage_path=self.file_storage_path,
+                       cache_backend_config=self.cache_backend_config,
+                   ),
                },
            )
            annotator = create_object(annotator_spec)
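
The new auto_client provider binding is wrapped in a lambda so that an AutoClient is only constructed for annotators whose constructors actually declare an auto_client parameter. The toy sketch below is not HELM's inject_object_spec_args; it only illustrates, under that assumption, why lazy provider bindings avoid paying for unused dependencies:

import inspect
from typing import Any, Callable, Dict


def build(cls: type, provider_bindings: Dict[str, Callable[[], Any]]) -> Any:
    """Call a provider only if the class's __init__ declares a matching parameter."""
    params = inspect.signature(cls.__init__).parameters
    kwargs = {name: provide() for name, provide in provider_bindings.items() if name in params}
    return cls(**kwargs)


class NeedsClient:
    def __init__(self, auto_client: str):
        self.auto_client = auto_client


class NeedsNothing:
    def __init__(self):
        pass


# The "auto_client" provider runs for NeedsClient but never for NeedsNothing.
build(NeedsClient, {"auto_client": lambda: "expensive-to-build client"})
build(NeedsNothing, {"auto_client": lambda: "expensive-to-build client"})
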
helm/benchmark/annotation/live_qa_annotator.py

@@ -0,0 +1,84 @@
+import os
+import re
+from typing import Any
+
+from helm.common.general import ensure_directory_exists
+from helm.benchmark.adaptation.request_state import RequestState
+from helm.benchmark.annotation.annotator import Annotator
+from helm.clients.auto_client import AutoClient
+from helm.common.request import Request
+
+
+class LiveQAAnnotator(Annotator):
+    """The LiveQA autograder."""
+
+    name = "live_qa"
+
+    def __init__(self, auto_client: AutoClient, file_storage_path: str):
+        self._auto_client = auto_client
+        cache_dir = os.path.join(file_storage_path, "data")
+        ensure_directory_exists(cache_dir)
+        # Regex pattern is lenient to allow for typos e.g. extra whitespace
+        self._pattern = re.compile("##\s*short_reasoning\s*:(.*)##\s*the_score\s*:(.*)", re.DOTALL)
+
+    def annotate(self, request_state: RequestState) -> Any:
+        assert request_state.result
+        assert len(request_state.result.completions) == 1
+        model_input_text = request_state.request.prompt
+        model_output_text = request_state.result.completions[0].text
+        if not model_output_text.strip():
+            return {"prompt_text": "", "reasoning": "BLOCKED_REQUEST_OR_EMPTY_RESPONSE", "score": 0.0}
+        correct_responses = ""
+        for i, reference in enumerate(request_state.instance.references):
+            correct_responses += "Correct Response #" + str(i) + ": " + reference.output.text + "\n"
+        prompt_template = """You are a medical judge will be given a user request and a model response to that request. Your task is to evaluate whether the model response properly answers the request based on a set of correct answers.
+The user's request will be provided in these tags:
+<user_request>
+{{QUESTION}}
+</user_request>
+The model's response will be provided in these tags:
+<model_response>
+{{ANSWER}}
+</model_response>
+Some potential correct responses will be provided in these tags:
+<correct_responses>
+{{CORRECT_RESPONSES}}
+</correct_responses>
+Carefully analyze the <model_response> and determine if the model is correct with respect to one or more of the <correct_responses> on a range from 0 to 1.
+# 1, if the model response fully matches the correct response content and intent for at least one correct response
+# 0.7, if the model response contains correct information but is incomplete relative to the correct response(s)
+# 0.3, if the model response contains some correct information but is incomplete and/or contains incorrect information relative to the correct response(s)
+# 0, if the model response is completely unrelated or incorrect relative to the correct response(s)
+Output your one-sentence concise reasoning in the ##short_reasoning tag and your score in the ##the_score tag. Your reasoning should be less than 20 tokens. The score should be a single number with no other output.
+Format:
+##short_reasoning: <your concise and brief reasoning>
+##the_score: <your score>"""  # noqa: E501
+        annotator_prompt = (
+            prompt_template.strip('"')
+            .strip()
+            .replace("{{QUESTION}}", model_input_text)
+            .replace("{{ANSWER}}", model_output_text)
+            .replace("{{CORRECT_RESPONSES}}", correct_responses)
+        )
+        annotator_request = Request(
+            model="openai/gpt-4-turbo-2024-04-09",
+            model_deployment="openai/gpt-4-turbo-2024-04-09",
+            prompt=annotator_prompt,
+            temperature=0.0,
+            max_tokens=64,
+        )
+        annotator_response = self._auto_client.make_request(annotator_request)
+        if not annotator_response.success:
+            raise Exception(f"Annotation request failed: {annotator_response.error}")
+        assert len(annotator_response.completions) == 1
+        annotator_response_text = annotator_response.completions[0].text
+        annotator_response_parts = self._pattern.search(annotator_response_text)
+        if not annotator_response_parts:
+            raise Exception(f"Malformed annotator response: {annotator_response_text}")
+        reasoning = annotator_response_parts[1].strip()
+        try:
+            score = float(annotator_response_parts[2].strip())
+        except ValueError as e:
+            raise Exception(f"Malformed annotator response: {annotator_response_text}") from e
+
+        return {"prompt_text": annotator_prompt, "reasoning": reasoning, "score": score}
helm/benchmark/annotation/medication_qa_annotator.py

@@ -0,0 +1,81 @@
+import os
+import re
+from typing import Any
+
+from helm.common.general import ensure_directory_exists
+from helm.benchmark.adaptation.request_state import RequestState
+from helm.benchmark.annotation.annotator import Annotator
+from helm.clients.auto_client import AutoClient
+from helm.common.request import Request
+
+
+class MedicationQAAnnotator(Annotator):
+    """The LiveQA autograder."""
+
+    name = "medication_qa"
+
+    def __init__(self, auto_client: AutoClient, file_storage_path: str):
+        self._auto_client = auto_client
+        cache_dir = os.path.join(file_storage_path, "data")
+        ensure_directory_exists(cache_dir)
+        # Regex pattern is lenient to allow for typos e.g. extra whitespace
+        self._pattern = re.compile("##\s*short_reasoning\s*:(.*)##\s*the_score\s*:(.*)", re.DOTALL)
+
+    def annotate(self, request_state: RequestState) -> Any:
+        assert request_state.result
+        assert len(request_state.result.completions) == 1
+        model_input_text = request_state.request.prompt
+        model_output_text = request_state.result.completions[0].text
+        if not model_output_text.strip():
+            return {"prompt_text": "", "reasoning": "BLOCKED_REQUEST_OR_EMPTY_RESPONSE", "score": 0.0}
+        correct_response = request_state.instance.references[0].output.text
+        prompt_template = """You are a medical judge will be given a user request and a model response to that request. Your task is to evaluate whether the model response properly answers the request based on a set of correct answers.
+The user's request will be provided in these tags:
+<user_request>
+{{QUESTION}}
+</user_request>
+The model's response will be provided in these tags:
+<model_response>
+{{ANSWER}}
+</model_response>
+Some potential correct responses will be provided in these tags:
+<correct_response>
+{{CORRECT_RESPONSE}}
+</correct_response>
+Carefully analyze the <model_response> and determine if the model is correct with respect to one or more of the <correct_responses> on a range from 0 to 1.
+# 1, if the model response fully matches the correct response content and intent for at least one correct response
+# 0.5, if the model response contains some correct information but is incomplete and/or contains incorrect information relative to the correct response(s)
+# 0, if the model response is completely unrelated or incorrect relative to the correct response(s)
+Output your one-sentence concise reasoning in the ##short_reasoning tag and your score in the ##the_score tag. Your reasoning should be less than 20 tokens. The score should be a single number with no other output.
+Format:
+##short_reasoning: <your concise and brief reasoning>
+##the_score: <your score>"""  # noqa: E501
+        annotator_prompt = (
+            prompt_template.strip('"')
+            .strip()
+            .replace("{{QUESTION}}", model_input_text)
+            .replace("{{ANSWER}}", model_output_text)
+            .replace("{{CORRECT_RESPONSE}}", correct_response)
+        )
+        annotator_request = Request(
+            model="openai/gpt-4-turbo-2024-04-09",
+            model_deployment="openai/gpt-4-turbo-2024-04-09",
+            prompt=annotator_prompt,
+            temperature=0.0,
+            max_tokens=64,
+        )
+        annotator_response = self._auto_client.make_request(annotator_request)
+        if not annotator_response.success:
+            raise Exception(f"Annotation request failed: {annotator_response.error}")
+        assert len(annotator_response.completions) == 1
+        annotator_response_text = annotator_response.completions[0].text
+        annotator_response_parts = self._pattern.search(annotator_response_text)
+        if not annotator_response_parts:
+            raise Exception(f"Malformed annotator response: {annotator_response_text}")
+        reasoning = annotator_response_parts[1].strip()
+        try:
+            score = float(annotator_response_parts[2].strip())
+        except ValueError as e:
+            raise Exception(f"Malformed annotator response: {annotator_response_text}") from e
+
+        return {"prompt_text": annotator_prompt, "reasoning": reasoning, "score": score}
helm/benchmark/augmentations/translate_perturbation.py

@@ -17,6 +17,7 @@ class TranslatePerturbation(TextPerturbation):
         language_code: str = "zh-CN"
 
     name: str = "translate"
+    should_perturb_references: bool = True
 
     def __init__(self, language_code: str):
         self.language_code: str = language_code
helm/benchmark/huggingface_registration.py

@@ -1,5 +1,5 @@
 import os
-from typing import Optional
+from typing import Optional, Dict, Union
 
 from helm.benchmark.model_deployment_registry import (
     ClientSpec,
@@ -17,14 +17,22 @@ from helm.tokenizers.huggingface_tokenizer import HuggingFaceTokenizer
 
 
 def register_huggingface_model(
-    helm_model_name: str, pretrained_model_name_or_path: str, revision: Optional[str] = None
+    helm_model_name: str,
+    pretrained_model_name_or_path: str,
+    revision: Optional[str] = None,
+    openvino: Optional[bool] = False,
 ) -> None:
-    object_spec_args = {"pretrained_model_name_or_path": pretrained_model_name_or_path}
+    object_spec_args: Dict[str, Union[str, bool]] = {"pretrained_model_name_or_path": pretrained_model_name_or_path}
     if revision:
         object_spec_args["revision"] = revision
+    if openvino:
+        object_spec_args["openvino"] = openvino
 
     # Auto-infer model properties from the tokenizer.
-    with HuggingFaceTokenizer.create_tokenizer(**object_spec_args) as tokenizer:
+    create_tokenizer_args: Dict[str, str] = {"pretrained_model_name_or_path": pretrained_model_name_or_path}
+    if revision:
+        create_tokenizer_args["revision"] = revision
+    with HuggingFaceTokenizer.create_tokenizer(**create_tokenizer_args) as tokenizer:
         max_sequence_length = tokenizer.model_max_length
         end_of_text_token = tokenizer.eos_token or ""
         prefix_token = tokenizer.bos_token or ""
@@ -71,7 +79,7 @@ def register_huggingface_model(
     register_tokenizer_config(tokenizer_config)
 
 
-def register_huggingface_hub_model_from_flag_value(raw_model_string: str) -> None:
+def register_huggingface_hub_model_from_flag_value(raw_model_string: str, openvino=False) -> None:
     raw_model_string_parts = raw_model_string.split("@")
     pretrained_model_name_or_path: str
    revision: Optional[str]
@@ -88,10 +96,11 @@ def register_huggingface_hub_model_from_flag_value(raw_model_string: str) -> Non
         helm_model_name=raw_model_string,
         pretrained_model_name_or_path=pretrained_model_name_or_path,
         revision=revision,
+        openvino=openvino,
     )
 
 
-def register_huggingface_local_model_from_flag_value(path: str) -> None:
+def register_huggingface_local_model_from_flag_value(path: str, openvino=False) -> None:
     if not path:
         raise ValueError("Path to Hugging Face model must be non-empty")
     path_parts = os.path.split(path)
@@ -99,4 +108,5 @@ def register_huggingface_local_model_from_flag_value(path: str) -> None:
     register_huggingface_model(
         helm_model_name=helm_model_name,
         pretrained_model_name_or_path=path,
+        openvino=openvino,
    )
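
register_huggingface_hub_model_from_flag_value keeps the existing name@revision convention for flag values and simply threads the new openvino flag through. A sketch of just that parsing convention, under the assumption that anything after a single "@" is treated as the revision (the model names are only examples, and HELM's error handling may differ):

from typing import Optional, Tuple


def parse_model_flag(raw_model_string: str) -> Tuple[str, Optional[str]]:
    # "org/model" -> no revision; "org/model@rev" -> explicit revision.
    parts = raw_model_string.split("@")
    if len(parts) == 1:
        return parts[0], None
    if len(parts) == 2:
        return parts[0], parts[1]
    raise ValueError(f"Expected at most one '@' in {raw_model_string!r}")


parse_model_flag("stanford-crfm/BioMedLM")       # ("stanford-crfm/BioMedLM", None)
parse_model_flag("stanford-crfm/BioMedLM@main")  # ("stanford-crfm/BioMedLM", "main")
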
helm/benchmark/metrics/air_bench_metrics.py

@@ -0,0 +1,56 @@
+from typing import List
+
+from helm.benchmark.adaptation.adapter_spec import AdapterSpec
+from helm.benchmark.adaptation.request_state import RequestState
+from helm.benchmark.metrics.basic_metrics import compute_request_state_metrics
+from helm.benchmark.metrics.efficiency_metrics import EfficiencyMetric
+from helm.benchmark.metrics.metric import Metric
+from helm.benchmark.metrics.metric_name import MetricName
+from helm.benchmark.metrics.metric_service import MetricService
+from helm.benchmark.metrics.statistic import Stat
+
+
+class AIRBench2024BasicGenerationMetric(Metric):
+    """Replacement for BasicGenerationMetric for AIRBench 2024.
+
+    We call compute_request_state_metrics here because we can't use `BasicGenerationMetric`
+    because we abuse "references" to store metadata rather than true metadata."""
+
+    def __init__(self):
+        super().__init__()
+        self.efficiency_metric = EfficiencyMetric()
+
+    def evaluate_generation(
+        self,
+        adapter_spec: AdapterSpec,
+        request_state: RequestState,
+        metric_service: MetricService,
+        eval_cache_path: str,
+    ) -> List[Stat]:
+        return compute_request_state_metrics(self.efficiency_metric, adapter_spec, request_state, metric_service)
+
+
+class AIRBench2024ScoreMetric(Metric):
+    """Score metrics for AIRBench 2024."""
+
+    def evaluate_generation(
+        self,
+        adapter_spec: AdapterSpec,
+        request_state: RequestState,
+        metric_service: MetricService,
+        eval_cache_path: str,
+    ) -> List[Stat]:
+        assert len(request_state.instance.references) > 1
+        category_text = request_state.instance.references[0].output.text
+        category_parts = category_text.split(".")
+        assert len(category_parts) == 3
+        assert request_state.annotations
+        score = request_state.annotations["air_bench_2024"]["score"]
+        return [
+            Stat(MetricName("air_score")).add(score),
+            Stat(MetricName(f"air_category_{category_parts[0]}_score")).add(score),
+            Stat(MetricName(f"air_category_{category_parts[0]}_{category_parts[1]}_score")).add(score),
+            Stat(MetricName(f"air_category_{category_parts[0]}_{category_parts[1]}_{category_parts[2]}_score")).add(
+                score
+            ),
+        ]
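
AIRBench2024ScoreMetric fans the single judge score out into one overall stat plus one stat per level of the category hierarchy, using the dotted category id stored in the first reference. With an invented category id of "1.2.3", the emitted metric names would be:

category_parts = "1.2.3".split(".")  # invented "cate-idx" value, for illustration

metric_names = [
    "air_score",
    f"air_category_{category_parts[0]}_score",
    f"air_category_{category_parts[0]}_{category_parts[1]}_score",
    f"air_category_{category_parts[0]}_{category_parts[1]}_{category_parts[2]}_score",
]
# ["air_score", "air_category_1_score", "air_category_1_2_score", "air_category_1_2_3_score"]
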
helm/benchmark/metrics/fin_qa_metrics.py

@@ -0,0 +1,60 @@
+import math
+import json
+from typing import List, Union
+
+from helm.benchmark.adaptation.adapter_spec import AdapterSpec
+from helm.benchmark.adaptation.request_state import RequestState
+from helm.benchmark.metrics.metric import Metric
+from helm.benchmark.metrics.metric_name import MetricName
+from helm.benchmark.metrics.metric_service import MetricService
+from helm.benchmark.metrics.statistic import Stat
+from helm.benchmark.metrics.fin_qa_metrics_helper import (  # type: ignore
+    equal_program,
+    eval_program,
+    program_tokenization,
+)
+
+
+def _get_program_accuracy(reference_program: List[str], generated_program: List[str]) -> float:
+    return 1.0 if equal_program(reference_program, generated_program) else 0.0
+
+
+def _get_execution_accuracy(reference_execution: str, generated_program: List[str], table: List[List[str]]) -> float:
+    invalid_flag: int
+    generated_result: Union[str, float]
+    invalid_flag, generated_result = eval_program(generated_program, table)
+    if invalid_flag:
+        return 0.0
+    if reference_execution == "yes" or reference_execution == "no":
+        return 1.0 if reference_execution == generated_result else 0
+    else:
+        if not isinstance(generated_result, float):
+            return 0.0
+        return 1.0 if math.isclose(float(reference_execution), generated_result) else 0
+
+
+class FinQAMetric(Metric):
+    def evaluate_generation(
+        self,
+        adapter_spec: AdapterSpec,
+        request_state: RequestState,
+        metric_service: MetricService,
+        eval_cache_path: str,
+    ) -> List[Stat]:
+        assert len(request_state.instance.references) == 3
+        reference_text = request_state.instance.references[0].output.text
+        reference_program = program_tokenization(reference_text)
+        reference_execution = request_state.instance.references[1].output.text
+        table: List[List[str]] = json.loads(request_state.instance.references[2].output.text)
+
+        assert request_state.result
+        assert len(request_state.result.completions) == 1
+        generated_text = request_state.result.completions[0].text.strip()
+        generated_program = program_tokenization(generated_text)
+
+        return [
+            Stat(MetricName("program_accuracy")).add(_get_program_accuracy(reference_program, generated_program)),
+            Stat(MetricName("execution_accuracy")).add(
+                _get_execution_accuracy(reference_execution, generated_program, table)
+            ),
+        ]
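
The execution-accuracy rule above treats yes/no references as exact string matches and everything else as a float comparison via math.isclose. A standalone sketch of that comparison, with the eval_program dependency factored out and invented example values:

import math
from typing import Union


def compare(reference_execution: str, generated_result: Union[str, float]) -> float:
    # "yes"/"no" references are compared as strings; everything else numerically.
    if reference_execution in ("yes", "no"):
        return 1.0 if reference_execution == generated_result else 0.0
    if not isinstance(generated_result, float):
        return 0.0
    return 1.0 if math.isclose(float(reference_execution), generated_result) else 0.0


compare("yes", "yes")    # 1.0
compare("3.14", 3.14)    # 1.0
compare("3.14", "3.14")  # 0.0 -- a non-float result never matches a numeric reference
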