crfm-helm 0.5.1__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their public registry, and is provided for informational purposes only.

Files changed (98)
  1. {crfm_helm-0.5.1.dist-info → crfm_helm-0.5.2.dist-info}/METADATA +13 -3
  2. {crfm_helm-0.5.1.dist-info → crfm_helm-0.5.2.dist-info}/RECORD +96 -63
  3. helm/benchmark/adaptation/adapter_spec.py +32 -31
  4. helm/benchmark/annotation/air_bench_annotator.py +64 -0
  5. helm/benchmark/annotation/annotator_factory.py +6 -0
  6. helm/benchmark/annotation/live_qa_annotator.py +84 -0
  7. helm/benchmark/annotation/medication_qa_annotator.py +81 -0
  8. helm/benchmark/augmentations/translate_perturbation.py +1 -0
  9. helm/benchmark/huggingface_registration.py +16 -6
  10. helm/benchmark/metrics/air_bench_metrics.py +56 -0
  11. helm/benchmark/metrics/fin_qa_metrics.py +60 -0
  12. helm/benchmark/metrics/fin_qa_metrics_helper.py +398 -0
  13. helm/benchmark/metrics/gpt4v_originality_critique_metrics.py +126 -0
  14. helm/benchmark/metrics/instruction_following_critique_metrics.py +1 -0
  15. helm/benchmark/metrics/live_qa_metrics.py +23 -0
  16. helm/benchmark/metrics/medication_qa_metrics.py +23 -0
  17. helm/benchmark/metrics/prometheus_vision_critique_metrics.py +185 -0
  18. helm/benchmark/metrics/reka_vibe_critique_metrics.py +158 -0
  19. helm/benchmark/metrics/unitxt_metrics.py +20 -10
  20. helm/benchmark/metrics/vision_language/emd_utils.py +4 -0
  21. helm/benchmark/metrics/vision_language/image_metrics.py +29 -71
  22. helm/benchmark/presentation/schema.py +54 -4
  23. helm/benchmark/presentation/test_schema.py +11 -0
  24. helm/benchmark/run.py +16 -2
  25. helm/benchmark/run_expander.py +77 -0
  26. helm/benchmark/run_spec_factory.py +4 -0
  27. helm/benchmark/run_specs/air_bench_run_specs.py +40 -0
  28. helm/benchmark/run_specs/classic_run_specs.py +15 -11
  29. helm/benchmark/run_specs/decodingtrust_run_specs.py +3 -1
  30. helm/benchmark/run_specs/experimental_run_specs.py +33 -0
  31. helm/benchmark/run_specs/finance_run_specs.py +33 -0
  32. helm/benchmark/run_specs/vlm_run_specs.py +168 -45
  33. helm/benchmark/scenarios/air_bench_scenario.py +50 -0
  34. helm/benchmark/scenarios/ci_mcqa_scenario.py +80 -0
  35. helm/benchmark/scenarios/entity_data_imputation_scenario.py +8 -2
  36. helm/benchmark/scenarios/fin_qa_scenario.py +117 -0
  37. helm/benchmark/scenarios/test_air_bench_scenario.py +27 -0
  38. helm/benchmark/scenarios/vision_language/bingo_scenario.py +3 -3
  39. helm/benchmark/scenarios/vision_language/image2structure/image2structure_scenario.py +13 -2
  40. helm/benchmark/scenarios/vision_language/image2structure/latex_scenario.py +1 -5
  41. helm/benchmark/scenarios/vision_language/image2structure/musicsheet_scenario.py +0 -4
  42. helm/benchmark/scenarios/vision_language/image2structure/webpage_scenario.py +4 -2
  43. helm/benchmark/scenarios/vision_language/pairs_scenario.py +6 -5
  44. helm/benchmark/scenarios/vision_language/unicorn_scenario.py +3 -3
  45. helm/benchmark/scenarios/vision_language/vibe_eval_scenario.py +95 -0
  46. helm/benchmark/static/schema_air_bench.yaml +3149 -0
  47. helm/benchmark/static/schema_classic.yaml +3 -59
  48. helm/benchmark/static/schema_finance.yaml +143 -0
  49. helm/benchmark/static/schema_image2structure.yaml +254 -111
  50. helm/benchmark/static/schema_instruction_following.yaml +3 -52
  51. helm/benchmark/static/schema_lite.yaml +3 -61
  52. helm/benchmark/static/schema_medical.yaml +255 -0
  53. helm/benchmark/static/schema_mmlu.yaml +3 -61
  54. helm/benchmark/static/schema_tables.yaml +200 -0
  55. helm/benchmark/static/schema_thai.yaml +223 -0
  56. helm/benchmark/static/schema_unitxt.yaml +3 -61
  57. helm/benchmark/static/{schema_vlm.yaml → schema_vhelm.yaml} +294 -293
  58. helm/benchmark/static/schema_vhelm_lite.yaml +4 -59
  59. helm/benchmark/static_build/assets/air-overview-d2e6c49f.png +0 -0
  60. helm/benchmark/static_build/assets/index-30dbceba.js +10 -0
  61. helm/benchmark/static_build/assets/index-66b02d40.css +1 -0
  62. helm/benchmark/static_build/assets/overview-74aea3d8.png +0 -0
  63. helm/benchmark/static_build/assets/process-flow-bd2eba96.png +0 -0
  64. helm/benchmark/static_build/index.html +2 -2
  65. helm/clients/anthropic_client.py +43 -9
  66. helm/clients/auto_client.py +11 -0
  67. helm/clients/client.py +24 -7
  68. helm/clients/cohere_client.py +98 -3
  69. helm/clients/huggingface_client.py +71 -12
  70. helm/clients/openai_client.py +9 -2
  71. helm/clients/reka_client.py +189 -0
  72. helm/clients/test_client.py +3 -3
  73. helm/clients/test_huggingface_client.py +19 -3
  74. helm/clients/test_together_client.py +72 -2
  75. helm/clients/together_client.py +129 -23
  76. helm/clients/vertexai_client.py +62 -18
  77. helm/clients/vision_language/huggingface_vlm_client.py +1 -0
  78. helm/clients/vision_language/paligemma_client.py +146 -0
  79. helm/clients/vision_language/palmyra_vision_client.py +84 -0
  80. helm/clients/yi_client.py +31 -0
  81. helm/common/critique_request.py +10 -1
  82. helm/common/images_utils.py +19 -0
  83. helm/config/model_deployments.yaml +412 -18
  84. helm/config/model_metadata.yaml +447 -25
  85. helm/config/tokenizer_configs.yaml +93 -1
  86. helm/proxy/critique/model_critique_client.py +32 -4
  87. helm/proxy/services/server_service.py +1 -1
  88. helm/tokenizers/auto_tokenizer.py +1 -1
  89. helm/tokenizers/cohere_tokenizer.py +44 -2
  90. helm/tokenizers/huggingface_tokenizer.py +36 -13
  91. helm/tokenizers/test_cohere_tokenizer.py +39 -0
  92. helm/tokenizers/test_huggingface_tokenizer.py +5 -1
  93. helm/benchmark/static_build/assets/index-737eef9e.js +0 -10
  94. helm/benchmark/static_build/assets/index-878a1094.css +0 -1
  95. {crfm_helm-0.5.1.dist-info → crfm_helm-0.5.2.dist-info}/LICENSE +0 -0
  96. {crfm_helm-0.5.1.dist-info → crfm_helm-0.5.2.dist-info}/WHEEL +0 -0
  97. {crfm_helm-0.5.1.dist-info → crfm_helm-0.5.2.dist-info}/entry_points.txt +0 -0
  98. {crfm_helm-0.5.1.dist-info → crfm_helm-0.5.2.dist-info}/top_level.txt +0 -0

helm/benchmark/metrics/prometheus_vision_critique_metrics.py
@@ -0,0 +1,185 @@
+ from typing import Dict, List, Optional
+ import re
+
+ from tqdm import tqdm
+
+ from helm.benchmark.adaptation.request_state import RequestState
+ from helm.benchmark.adaptation.scenario_state import ScenarioState
+ from helm.benchmark.adaptation.adapter_spec import AdapterSpec
+ from helm.benchmark.metrics.metric import MetricInterface, MetricResult, PerInstanceStats, add_context
+ from helm.benchmark.metrics.metric_name import MetricContext, MetricName
+ from helm.benchmark.metrics.metric_service import MetricService
+ from helm.benchmark.metrics.statistic import Stat, merge_stat
+ from helm.common.critique_request import CritiqueTaskTemplate, CritiqueQuestionTemplate, CritiqueRequest, QuestionType
+ from helm.common.hierarchical_logger import hlog
+ from helm.common.request import RequestResult, GeneratedOutput
+ from helm.common.media_object import MultimediaObject, IMAGE_TYPE, MediaObject, TEXT_TYPE
+
+
+ class PrometheusVisionCritiqueMetric(MetricInterface):
+ """
+ We compute the same metrics from the Prometheus-Vision: Vision-Language Model as a Judge for
+ Fine-Grained Evaluation paper:
+ https://arxiv.org/pdf/2401.06591.pdf
+
+ In this paper, the output of a Vision-Language Model named Prometheus-Vision is used to evaluate
+ the quality of the output of other Vision-Language Models to be evaluated.
+ """
+
+ # We can add more evaluation aspects here
+ METRIC_NAME: str = "prometheus_vision"
+ METRIC_PROMPT: str = """A chat between a curious human and an artificial intelligence assistant. \
+ The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER:<image>\
+ ###Task Description:
+ An instruction (might include an Input inside it), a response to evaluate, a reference answer that gets a score of 5, \
+ image and a score rubric representing an evaluation criterion is given.
+ 1. Write a detailed feedback that assesses the quality of the response strictly based on the given score rubric, not \
+ evaluating in general.
+ 2. After writing a feedback, write a score that is an integer between 1 and 5. You should refer to the score rubric.
+ 3. The output format should look as follows: Feedback: (write a feedback for criteria) [RESULT] (an integer number \
+ between 1 and 5)
+ 4. Please do not generate any other opening, closing, and explanations.
+
+ ###The instruction to evaluate:
+ {{orig_instruction}}
+
+ ###Response to evaluate:
+ {{orig_response}}
+
+ ###Reference Answer (Score 5):
+ {{orig_reference_answer}}
+
+ ###Score Rubrics:
+ [{{orig_criteria}}]
+ Score 1: {{orig_score1_description}}
+ Score 2: {{orig_score2_description}}
+ Score 3: {{orig_score3_description}}
+ Score 4: {{orig_score4_description}}
+ Score 5: {{orig_score5_description}}
+
+ ###Feedback:
+ ASSISTANT:
+ """
+
+ def __init__(self, num_respondents: int, max_tokens: int):
+ self._num_respondents = num_respondents
+ self._max_tokens = max_tokens
+
+ def __repr__(self) -> str:
+ return "PrometheusVisionCritiqueMetric()"
+
+ def _extract_score_from_prometheus_vision_output(self, evaluator_response: str):
+ evaluator_response = evaluator_response.split("ASSISTANT:")[1]
+ re_match = re.search(r"\s*([1-5])", evaluator_response)
+ if re_match is None:
+ hlog(f"Error parsing answer: {evaluator_response}. Skipping question (and so the respondent entirely)")
+ return None
+ return int(re_match.group(1))
+
+ def evaluate(
+ self,
+ scenario_state: ScenarioState,
+ metric_service: MetricService,
+ eval_cache_path: str,
+ parallelism: int,
+ ) -> MetricResult:
+ request_states: List[RequestState] = scenario_state.request_states
+
+ all_stats: Dict[MetricName, Stat] = {}
+ per_instance_stats: List[PerInstanceStats] = []
+ for request_state in tqdm(request_states):
+ context = MetricContext.from_instance(request_state.instance)
+ stats_without_context = self.evaluate_generation(
+ scenario_state.adapter_spec,
+ request_state,
+ metric_service,
+ eval_cache_path,
+ )
+ stats = [add_context(stat_without_context, context) for stat_without_context in stats_without_context]
+ for stat in stats:
+ merge_stat(all_stats, stat)
+ assert request_state.instance.id is not None
+ per_instance_stats.append(
+ PerInstanceStats(
+ instance_id=request_state.instance.id,
+ perturbation=request_state.instance.perturbation,
+ train_trial_index=request_state.train_trial_index,
+ stats=stats,
+ )
+ )
+ return MetricResult(aggregated_stats=list(all_stats.values()), per_instance_stats=per_instance_stats)
+
+ def evaluate_generation(
+ self,
+ adapter_spec: AdapterSpec,
+ request_state: RequestState,
+ metric_service: MetricService,
+ eval_cache_path: str,
+ ) -> List[Stat]:
+ input_content = request_state.request
+ # Predicted outputs and their prometheus vision scores
+ assert request_state.result is not None
+ request_result: RequestResult = request_state.result
+ # Get input image and generated response for evaluation
+ assert input_content.multimodal_prompt is not None
+ completions: List[GeneratedOutput] = request_result.completions
+ generated_text: str = completions[0].text
+ input_media: MultimediaObject = input_content.multimodal_prompt
+ ref_text: str = request_state.instance.references[0].output.text
+ image_objects: List[MediaObject] = [
+ item for item in input_media.media_objects if item.is_type(IMAGE_TYPE) and item.location
+ ]
+ input_text: Optional[str] = [item for item in input_media.media_objects if item.is_type(TEXT_TYPE)][0].text
+
+ template = CritiqueTaskTemplate(
+ name="vhelm_prometheus_vision",
+ instructions=self.METRIC_PROMPT,
+ num_respondents=self._num_respondents,
+ max_tokens=self._max_tokens,
+ questions=[
+ CritiqueQuestionTemplate(
+ name=self.METRIC_NAME,
+ question_type=QuestionType.FREE_RESPONSE,
+ text="",
+ options=[],
+ media_object=image_objects[0], # we only take the first image as input
+ )
+ ],
+ )
+ request = CritiqueRequest(
+ template=template,
+ fields={
+ "orig_instruction": input_text if input_text is not None else "",
+ "orig_response": generated_text,
+ "orig_reference_answer": ref_text,
+ "orig_criteria": "similarity between the reponse and the reference.",
+ "orig_score1_description": "The model's responses do not follow the instructions provided.",
+ "orig_score2_description": "The resulting response follows the instructions, but the answer \
+ is completely wrong relative to the reference answer.",
+ "orig_score3_description": "The resulting response follows the instructions, but the answer is \
+ partially wrong relative to the reference answer.",
+ "orig_score4_description": "The resulting response follows the instructions, the overall answer \
+ is relatively perfect with only a very few errors.",
+ "orig_score5_description": "The overall answer is completely correct compared to the reference \
+ answer, and conforms to the instructions provided.",
+ },
+ )
+ # send to critique request
+ result = metric_service.make_critique_request(request)
+ if not result or not result.responses:
+ # Skip computing metrics if there aren't any responses yet
+ hlog("Waiting for responses to be generated.")
+ return []
+
+ stats: Dict[str, Stat] = {}
+ for question in template.questions:
+ stats[question.name] = Stat(MetricName(question.name))
+
+ for response in result.responses:
+ for answer_name, answer in response.answers.items():
+ assert isinstance(answer, str)
+ answer_value: float
+ answer_value = self._extract_score_from_prometheus_vision_output(answer)
+ stats[answer_name].add(answer_value)
+
+ return list(stats.values())

helm/benchmark/metrics/reka_vibe_critique_metrics.py
@@ -0,0 +1,158 @@
+ from typing import Dict, List, Optional
+ import re
+
+ from helm.benchmark.adaptation.request_state import RequestState
+ from helm.benchmark.adaptation.scenario_state import ScenarioState
+ from helm.benchmark.adaptation.adapter_spec import AdapterSpec
+ from helm.benchmark.metrics.metric import MetricInterface, MetricResult, PerInstanceStats, add_context
+ from helm.benchmark.metrics.metric_name import MetricContext, MetricName
+ from helm.benchmark.metrics.metric_service import MetricService
+ from helm.benchmark.metrics.statistic import Stat, merge_stat
+ from helm.common.critique_request import CritiqueTaskTemplate, CritiqueQuestionTemplate, CritiqueRequest, QuestionType
+ from helm.common.hierarchical_logger import hlog
+ from helm.common.request import RequestResult, GeneratedOutput
+ from helm.common.media_object import MultimediaObject, IMAGE_TYPE, TEXT_TYPE, MediaObject
+
+
+ class RekaVibeCritiqueMetric(MetricInterface):
+ """
+ Critique evaluation for evaluating the correctness of generated response given the image and
+ reference by Reka-vibe-eval.
+ """
+
+ # We can add more evaluation aspects here
+ VIBE_EVAL_NAME: str = "reka_vibe"
+ REKA_VIBE_PROMPT_WITH_IMAGE: str = """\
+ [Question]
+ {{prompt}}
+
+ [Assistant Response]
+ {{generation}}
+
+ [Ground Truth Response]
+ {{reference}}
+
+ [System]
+ Rate whether the assistant response correctly matches the ground truth, in regards to the image above.
+ The rating should be 1-5, where 1 is incorrect and 5 is correct.
+ Your response should be in the format:
+ Short Explanation: (explanation in only one sentence)
+ Rating: (int)"""
+
+ def __init__(self, num_respondents: int, max_tokens: int):
+ self._num_respondents = num_respondents
+ self._max_tokens = max_tokens
+
+ def __repr__(self) -> str:
+ return "RekaVibeCritiqueMetric()"
+
+ def _extract_score_from_reka_output(self, evaluator_response: str):
+ """
+ Extract the score from the evaluator response. Refer to the official Vibe-Eval implementation:
+ https://github.com/reka-ai/reka-vibe-eval/blob/3852d4712da172a7b85dddeffc4f9c3482a6f4c9/evaluate.py#L159-#L164
+ """
+ re_match = re.search(r"Rating:\s*([1-5])", evaluator_response)
+ if re_match is None:
+ hlog(f"Error parsing answer: {evaluator_response}. Skipping question (and so the respondent entirely)")
+ return None
+ return int(re_match.group(1))
+
+ def evaluate(
+ self,
+ scenario_state: ScenarioState,
+ metric_service: MetricService,
+ eval_cache_path: str,
+ parallelism: int,
+ ) -> MetricResult:
+ request_states: List[RequestState] = scenario_state.request_states
+
+ all_stats: Dict[MetricName, Stat] = {}
+ per_instance_stats: List[PerInstanceStats] = []
+ for request_state in request_states:
+ context = MetricContext.from_instance(request_state.instance)
+ stats_without_context = self.evaluate_generation(
+ scenario_state.adapter_spec,
+ request_state,
+ metric_service,
+ eval_cache_path,
+ )
+ stats = [add_context(stat_without_context, context) for stat_without_context in stats_without_context]
+ for stat in stats:
+ merge_stat(all_stats, stat)
+ assert request_state.instance.id is not None
+ per_instance_stats.append(
+ PerInstanceStats(
+ instance_id=request_state.instance.id,
+ perturbation=request_state.instance.perturbation,
+ train_trial_index=request_state.train_trial_index,
+ stats=stats,
+ )
+ )
+ return MetricResult(aggregated_stats=list(all_stats.values()), per_instance_stats=per_instance_stats)
+
+ def evaluate_generation(
+ self,
+ adapter_spec: AdapterSpec,
+ request_state: RequestState,
+ metric_service: MetricService,
+ eval_cache_path: str,
+ ) -> List[Stat]:
+ input_content = request_state.request
+ # Predicted outputs and their originality scores
+ assert request_state.result is not None
+ request_result: RequestResult = request_state.result
+ # Get input image and generated response for the originality evaluation
+ assert input_content.multimodal_prompt is not None
+ completions: List[GeneratedOutput] = request_result.completions
+ generated_text: str = completions[0].text
+ input_media: MultimediaObject = input_content.multimodal_prompt
+ ref_text: str = request_state.instance.references[0].output.text
+
+ image_objects: List[MediaObject] = [
+ item for item in input_media.media_objects if item.is_type(IMAGE_TYPE) and item.location
+ ]
+ input_text: Optional[str] = [item for item in input_media.media_objects if item.is_type(TEXT_TYPE)][0].text
+
+ template = CritiqueTaskTemplate(
+ name="vhelm_vibe_eval",
+ instructions=self.REKA_VIBE_PROMPT_WITH_IMAGE,
+ num_respondents=self._num_respondents,
+ max_tokens=self._max_tokens,
+ questions=[
+ CritiqueQuestionTemplate(
+ name=self.VIBE_EVAL_NAME,
+ question_type=QuestionType.FREE_RESPONSE,
+ text="",
+ options=[],
+ media_object=image_objects[0], # we only take the first image as input
+ )
+ ],
+ )
+
+ request = CritiqueRequest(
+ template=template,
+ fields={
+ "prompt": input_text if input_text is not None else "",
+ "generation": generated_text,
+ "reference": ref_text,
+ },
+ )
+
+ # send to critique request
+ result = metric_service.make_critique_request(request)
+ if not result or not result.responses:
+ # Skip computing metrics if there aren't any responses yet
+ hlog("Waiting for responses to be generated.")
+ return []
+ stats: Dict[str, Stat] = {}
+ for question in template.questions:
+ stats[question.name] = Stat(MetricName(question.name))
+
+ for response in result.responses:
+ for answer_name, answer in response.answers.items():
+ assert isinstance(answer, str)
+ answer_value: float
+ answer_value = self._extract_score_from_reka_output(answer)
+ stats[answer_name].add(answer_value)
+
+ return list(stats.values())

helm/benchmark/metrics/unitxt_metrics.py
@@ -50,16 +50,20 @@ class UnitxtMetric(MetricInterface):
  for metric_name, metric_score in instance_results.items():
  if metric_name == "score" or metric_name == "score_name":
  continue
- instance_stats.append(
- Stat(
- MetricName(
- name=metric_name,
- split=instance.split,
- sub_split=instance.sub_split,
- perturbation=instance.perturbation,
- )
- ).add(metric_score)
+ stat = Stat(
+ MetricName(
+ name=metric_name,
+ split=instance.split,
+ sub_split=instance.sub_split,
+ perturbation=instance.perturbation,
+ )
  )
+ if isinstance(metric_score, list):
+ for metric_score_element in metric_score:
+ stat = stat.add(metric_score_element)
+ else:
+ stat = stat.add(metric_score)
+ instance_stats.append(stat)
  assert instance.id
  per_instance_stats.append(
  PerInstanceStats(
@@ -77,5 +81,11 @@ class UnitxtMetric(MetricInterface):
  for metric_name, metric_score in global_results.items():
  if metric_name == "score" or metric_name == "score_name":
  continue
- aggregated_stats.append(Stat(MetricName(name=metric_name)).add(metric_score))
+ stat = Stat(MetricName(name=metric_name))
+ if isinstance(metric_score, list):
+ for metric_score_element in metric_score:
+ stat = stat.add(metric_score_element)
+ else:
+ stat = stat.add(metric_score)
+ aggregated_stats.append(stat)
  return MetricResult(aggregated_stats=aggregated_stats, per_instance_stats=per_instance_stats)

helm/benchmark/metrics/vision_language/emd_utils.py
@@ -280,6 +280,10 @@ def compute_emd_recursive(
  assert max_num_patches > 0
  assert 0 < weight_most_frequent_color <= 1

+ # Convert the images to RGB first. Some images have 4 channels (RGBA)
+ img1_PIL = img1_PIL.convert("RGB")
+ img2_PIL = img2_PIL.convert("RGB")
+
  # Resize the images so that there are not too many patches
  # Try to maintain the aspect ratio and resize to a multiple of the patch size
  num_patches = math.ceil(img1_PIL.size[0] / patch_size[0]) * math.ceil(img1_PIL.size[1] / patch_size[1])

helm/benchmark/metrics/vision_language/image_metrics.py
@@ -28,7 +28,7 @@ from helm.benchmark.metrics.vision_language.image_utils import (
  pixel_similarity,
  sift_similarity,
  )
- from helm.benchmark.metrics.vision_language.emd_utils import compute_emd_recursive, get_most_frequent_color
+ from helm.benchmark.metrics.vision_language.emd_utils import compute_emd_recursive, get_most_frequent_color, to_gray

  try:
  from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
@@ -78,9 +78,8 @@ class AnnotatedImageMetrics(Metric):

  # Metric names
  COMPILE_METRIC: str = "compilation_success"
- BLOCK_EARTH_MOVER_SIMILARITY_NORM1: str = "block_emd_similarity_white"
- BLOCK_EARTH_MOVER_SIMILARITY_NORM2: str = "block_emd_similarity_median_color"
- BLOCK_EARTH_MOVER_SIMILARITY: str = "block_emd_similarity"
+ EARTH_MOVER_SIMILARITY = "earth_mover_similarity"
+ BLOCK_EMD: str = "block_emd"
  PIXEL_SIMILARITY: str = "pixel_similarity"
  SIFT_SIMILARITY: str = "sift_similarity"
  LPIPS_SIMILARITY: str = "lpips_similarity"
@@ -108,12 +107,10 @@ class AnnotatedImageMetrics(Metric):
  metrics: List[AnnotatedMetric] = [
  AnnotatedMetric(self.PIXEL_SIMILARITY, pixel_similarity, "image_np_gray"),
  AnnotatedMetric(self.SIFT_SIMILARITY, sift_similarity, "image_np"),
- # Raw block EMD
- AnnotatedMetric(self.BLOCK_EARTH_MOVER_SIMILARITY, self.compute_block_emd_raw, "image_PIL"),
- # Normalized block EMD against white
- AnnotatedMetric(self.BLOCK_EARTH_MOVER_SIMILARITY_NORM1, self.compute_block_emd_white, "image_PIL"),
- # Normalized block EMD against median
- AnnotatedMetric(self.BLOCK_EARTH_MOVER_SIMILARITY_NORM2, self.compute_block_emd_median, "image_PIL"),
+ AnnotatedMetric(self.BLOCK_EMD, self.compute_block_emd_raw, "image_PIL"), # Raw block-EMD
+ AnnotatedMetric(
+ self.EARTH_MOVER_SIMILARITY, self.ems, "image_PIL"
+ ), # Normalized block-EMD against black/white
  AnnotatedMetric(self.LPIPS_SIMILARITY, self.lpips_similarity, "image_PIL"),
  AnnotatedMetric(self.FID_SIMILARITY, self.fid_similarity, "image_PIL"),
  AnnotatedMetric(self.SSIM_SIMILARITY, self.compute_ssim, "image_np_gray"),
@@ -391,9 +388,15 @@ class AnnotatedImageMetrics(Metric):
  features1 = self._get_inception_features(img1_tensor)
  features2 = self._get_inception_features(img2_tensor)

- fid_score = self._calculate_fid(features1, features2)
- normalize_fid: float = np.exp(-fid_score * self.NORMALIZE_FID_FACTOR)
- return normalize_fid
+ # TODO: Justify the value of the constant here or remove this code to only keep the cosine similarity.
+ # fid_score = self._calculate_fid(features1, features2)
+ # normalize_fid: float = np.exp(-fid_score * self.NORMALIZE_FID_FACTOR)
+ # return normalize_fid
+
+ # Use the cosine similarity between the features as a proxy for FID
+ # Return a score between 0 and 1, where 1 is the most similar
+ score = 0.5 * (1 + np.dot(features1[0], features2[0]) / (np.linalg.norm(features1) * np.linalg.norm(features2)))
+ return score

  def compute_ssim(self, generated_image: np.ndarray, reference_image: np.ndarray) -> float:
  """Compute the Structural Similarity Index (SSIM) between the generated and reference images."""
@@ -414,58 +417,7 @@ class AnnotatedImageMetrics(Metric):
  result = _edit_similarity(completion_tokens, truncated_reference_tokens)
  return result

- def compute_block_emd_white(
- self,
- pred_image: Image.Image,
- ref_image: Image.Image,
- threshold_most_frequent_color: float = 0.5,
- patch_size: Tuple[int, int] = (8, 8),
- max_num_patches: int = 100,
- weight_most_frequent_color: float = 0.001,
- use_tqdm: bool = False,
- ):
- """Computes the block Earth Moving Distance (EMD). This attempts to
- speed up EMD for images with huge areas by considering movement/transformatio
- of blocks of pixels. The score is normalized against EMD against white images
- """
-
- def compute_numerator():
- return self.compute_block_emd_raw_wrapper(
- pred_image,
- ref_image,
- threshold_most_frequent_color,
- patch_size,
- max_num_patches,
- weight_most_frequent_color,
- use_tqdm,
- )
-
- def compute_denominator():
- constant_image = Image.new("RGB", ref_image.size, (255, 255, 255)) # default color is white
- value = compute_emd_recursive(
- constant_image,
- ref_image,
- threshold_most_frequent_color,
- patch_size,
- max_num_patches,
- weight_most_frequent_color,
- use_tqdm,
- )
- return {"value": value}
-
- hash_dict = {
- "reference_image": str(AnnotatedImageMetrics.HASH_FUNC(ref_image, hash_size=self.HASH_LENGTH)),
- }
- cache_key_numerator = {"metric_name": f"intermediate_{self.BLOCK_EARTH_MOVER_SIMILARITY}", **hash_dict}
- cache_key_denominator = {"metric_name": f"intermediate_{self.BLOCK_EARTH_MOVER_SIMILARITY_NORM1}", **hash_dict}
-
- assert self._cache is not None
- emd_raw, _ = self._cache.get(cache_key_numerator, compute_numerator)
- emd_base, _ = self._cache.get(cache_key_denominator, compute_denominator)
-
- return 1.0 - emd_raw["value"] / emd_base["value"]
-
- def compute_block_emd_median(
+ def ems(
  self,
  pred_image: Image.Image,
  ref_image: Image.Image,
@@ -493,9 +445,13 @@ class AnnotatedImageMetrics(Metric):
  def compute_denominator():
  ref_img_np = np.array(ref_image)
  (rgb_most_frequent_color, _) = get_most_frequent_color(ref_img_np)
+ grayscale_most_frequent_color = to_gray(rgb_most_frequent_color)[0]

  # Most frequent color as base
- constant_image = Image.new("RGB", ref_image.size, tuple(rgb_most_frequent_color)) # type: ignore
+ if grayscale_most_frequent_color < 127:
+ constant_image = Image.new("RGB", ref_image.size, (255, 255, 255)) # Make it white
+ else:
+ constant_image = Image.new("RGB", ref_image.size, (0, 0, 0)) # Make it black
  value = compute_emd_recursive(
  constant_image,
  ref_image,
@@ -509,9 +465,10 @@ class AnnotatedImageMetrics(Metric):

  hash_dict = {
  "reference_image": str(AnnotatedImageMetrics.HASH_FUNC(ref_image, hash_size=self.HASH_LENGTH)),
+ "generated_image": str(AnnotatedImageMetrics.HASH_FUNC(pred_image, hash_size=self.HASH_LENGTH)),
  }
- cache_key_numerator = {"metric_name": f"intermediate_{self.BLOCK_EARTH_MOVER_SIMILARITY}", **hash_dict}
- cache_key_denominator = {"metric_name": f"intermediate_{self.BLOCK_EARTH_MOVER_SIMILARITY_NORM2}", **hash_dict}
+ cache_key_numerator = {"metric_name": f"intermediate_{self.BLOCK_EMD}", **hash_dict}
+ cache_key_denominator = {"metric_name": "intermediate_ems_extreme_denominator", **hash_dict}

  assert self._cache is not None
  emd_raw, _ = self._cache.get(cache_key_numerator, compute_numerator)
@@ -542,8 +499,9 @@ class AnnotatedImageMetrics(Metric):

  hash_dict = {
  "reference_image": str(AnnotatedImageMetrics.HASH_FUNC(ref_image, hash_size=self.HASH_LENGTH)),
+ "generated_image": str(AnnotatedImageMetrics.HASH_FUNC(pred_image, hash_size=self.HASH_LENGTH)),
  }
- cache_key = {"metric_name": f"intermediate_{self.BLOCK_EARTH_MOVER_SIMILARITY}", **hash_dict}
+ cache_key = {"metric_name": f"intermediate_{self.BLOCK_EMD}", **hash_dict}
  assert self._cache is not None
  emd_raw, _ = self._cache.get(cache_key, compute)

@@ -560,8 +518,8 @@ class AnnotatedImageMetrics(Metric):
  use_tqdm: bool = False,
  ):
  """Computes the block Earth Moving Distance (EMD). This attempts to
- speed up EMD for images with huge areas by considering movement/transformatio
- of blocks of pixels. The score is normalized against EMD against white images
+ speed up EMD for images with huge areas by considering
+ movement/transformation of blocks of pixels.
  """
  emd_value = compute_emd_recursive(
  pred_image,

helm/benchmark/presentation/schema.py
@@ -1,6 +1,9 @@
+ import ast
+ import dataclasses
  from dataclasses import dataclass, field
  from typing import List, Optional, Dict
  import dacite
+ from inspect import cleandoc
  import mako.template
  import yaml
  import importlib_resources as resources
@@ -17,6 +20,11 @@ SCHEMA_YAML_PACKAGE: str = "helm.benchmark.static"
  SCHEMA_CLASSIC_YAML_FILENAME: str = "schema_classic.yaml"


+ _ADAPTER_SPEC_PACKAGE = "helm.benchmark.adaptation"
+ _ADAPTER_SPEC_FILENAME = "adapter_spec.py"
+ _ADAPTER_SPEC_CLASS_NAME = "AdapterSpec"
+
+
  @dataclass(frozen=True)
  class Field:
  """
@@ -198,9 +206,6 @@ class RunGroup(Field):
  class Schema:
  """Specifies information about what to display on the frontend."""

- # Adapter fields (e.g., temperature)
- adapter: List[Field]
-
  # Information about each field
  metrics: List[Field]

@@ -213,6 +218,11 @@ class Schema:
  # Group the scenarios
  run_groups: List[RunGroup]

+ # Adapter fields (e.g., temperature)
+ # Automatically populated from the docstrings in the AdapterSpec class definition.
+ # Should not be specified in the user's YAML file.
+ adapter: Optional[List[Field]] = None
+
  def __post_init__(self):
  self.name_to_metric = {metric.name: metric for metric in self.metrics}
  self.name_to_perturbation = {perturbation.name: perturbation for perturbation in self.perturbations}
@@ -220,6 +230,43 @@ class Schema:
  self.name_to_run_group = {run_group.name: run_group for run_group in self.run_groups}


+ def get_adapter_fields() -> List[Field]:
+ """Generate the adapter fields from the docstrings in the AdapterSpec class definition."""
+ # Unfortunately there is no standard library support for getting docstrings of class fields,
+ # so we have to do the parsing outselves. Fortunately, the parsing is quite straightforward.
+ adapter_spec_path = resources.files(_ADAPTER_SPEC_PACKAGE).joinpath(_ADAPTER_SPEC_FILENAME)
+ with open(adapter_spec_path, "r") as f:
+ contents = f.read()
+ module_node = ast.parse(contents)
+ adapter_spec_node = [
+ node
+ for node in ast.iter_child_nodes(module_node)
+ if isinstance(node, ast.ClassDef) and node.name == _ADAPTER_SPEC_CLASS_NAME
+ ][0]
+ metadata_fields: List[Field] = []
+ field_name: str = ""
+ for node in ast.iter_child_nodes(adapter_spec_node):
+ if isinstance(node, ast.AnnAssign) and isinstance(node.target, ast.Name):
+ # This node is a field definition.
+ # Save the name of the field for later.
+ field_name = node.target.id
+ else:
+ # If this is a docstring that immediately follows a field definition,
+ # output an adapter field with the name set to the field definition and
+ # the description set to the docstring.
+ if (
+ field_name
+ and isinstance(node, ast.Expr)
+ and isinstance(node.value, ast.Constant)
+ and isinstance(node.value.value, str)
+ ):
+ description = cleandoc(node.value.value).replace("\n", " ")
+ metadata_fields.append(Field(name=field_name, description=description))
+ field_name = ""
+
+ return metadata_fields
+
+
  def get_default_schema_path() -> str:
  return resources.files(SCHEMA_YAML_PACKAGE).joinpath(SCHEMA_CLASSIC_YAML_FILENAME)

@@ -229,4 +276,7 @@ def read_schema(schema_path: str) -> Schema:
  hlog(f"Reading schema file {schema_path}...")
  with open(schema_path, "r") as f:
  raw = yaml.safe_load(f)
- return dacite.from_dict(Schema, raw)
+ schema = dacite.from_dict(Schema, raw)
+ if schema.adapter:
+ hlog(f"WARNING: The `adapter` field is deprecated and should be removed from schema file {schema_path}")
+ return dataclasses.replace(schema, adapter=get_adapter_fields())

helm/benchmark/presentation/test_schema.py
@@ -0,0 +1,11 @@
+ from helm.benchmark.presentation.schema import get_adapter_fields
+
+
+ def test_get_adapter_fields() -> None:
+ adapter_fields = get_adapter_fields()
+ assert adapter_fields
+ assert adapter_fields[0].name == "method"
+ assert (
+ adapter_fields[0].description
+ == "The high-level strategy for converting instances into a prompt for the language model."
+ )