crfm-helm 0.5.0__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Note: the registry flags this release of crfm-helm as potentially problematic.

Files changed (125)
  1. {crfm_helm-0.5.0.dist-info → crfm_helm-0.5.2.dist-info}/METADATA +19 -5
  2. {crfm_helm-0.5.0.dist-info → crfm_helm-0.5.2.dist-info}/RECORD +121 -76
  3. helm/benchmark/adaptation/adapter_spec.py +32 -31
  4. helm/benchmark/adaptation/adapters/multimodal/in_context_learning_multimodal_adapter.py +1 -0
  5. helm/benchmark/adaptation/adapters/multimodal/multimodal_prompt.py +7 -0
  6. helm/benchmark/adaptation/adapters/multimodal/test_multimodal_prompt.py +2 -0
  7. helm/benchmark/annotation/air_bench_annotator.py +64 -0
  8. helm/benchmark/annotation/annotator_factory.py +6 -0
  9. helm/benchmark/annotation/image2structure/lilypond_compiler_annotator.py +1 -1
  10. helm/benchmark/annotation/live_qa_annotator.py +84 -0
  11. helm/benchmark/annotation/medication_qa_annotator.py +81 -0
  12. helm/benchmark/augmentations/perturbation.py +17 -1
  13. helm/benchmark/augmentations/test_perturbation.py +30 -0
  14. helm/benchmark/augmentations/translate_perturbation.py +1 -0
  15. helm/benchmark/huggingface_registration.py +16 -6
  16. helm/benchmark/metrics/air_bench_metrics.py +56 -0
  17. helm/benchmark/metrics/efficiency_metrics.py +9 -2
  18. helm/benchmark/metrics/evaluate_reference_metrics.py +16 -0
  19. helm/benchmark/metrics/fin_qa_metrics.py +60 -0
  20. helm/benchmark/metrics/fin_qa_metrics_helper.py +398 -0
  21. helm/benchmark/metrics/gpt4v_originality_critique_metrics.py +126 -0
  22. helm/benchmark/metrics/instruction_following_critique_metrics.py +1 -0
  23. helm/benchmark/metrics/live_qa_metrics.py +23 -0
  24. helm/benchmark/metrics/medication_qa_metrics.py +23 -0
  25. helm/benchmark/metrics/prometheus_vision_critique_metrics.py +185 -0
  26. helm/benchmark/metrics/reka_vibe_critique_metrics.py +158 -0
  27. helm/benchmark/metrics/unitxt_metrics.py +20 -10
  28. helm/benchmark/metrics/vision_language/emd_utils.py +4 -0
  29. helm/benchmark/metrics/vision_language/image_metrics.py +104 -21
  30. helm/benchmark/model_metadata_registry.py +5 -1
  31. helm/benchmark/presentation/schema.py +54 -4
  32. helm/benchmark/presentation/test_schema.py +11 -0
  33. helm/benchmark/run.py +16 -2
  34. helm/benchmark/run_expander.py +112 -63
  35. helm/benchmark/run_spec_factory.py +15 -10
  36. helm/benchmark/run_specs/air_bench_run_specs.py +40 -0
  37. helm/benchmark/run_specs/classic_run_specs.py +15 -11
  38. helm/benchmark/run_specs/decodingtrust_run_specs.py +3 -1
  39. helm/benchmark/run_specs/experimental_run_specs.py +33 -0
  40. helm/benchmark/run_specs/finance_run_specs.py +33 -0
  41. helm/benchmark/run_specs/vlm_run_specs.py +444 -65
  42. helm/benchmark/scenarios/air_bench_scenario.py +50 -0
  43. helm/benchmark/scenarios/ci_mcqa_scenario.py +80 -0
  44. helm/benchmark/scenarios/entity_data_imputation_scenario.py +8 -2
  45. helm/benchmark/scenarios/fin_qa_scenario.py +117 -0
  46. helm/benchmark/scenarios/legalbench_scenario.py +6 -2
  47. helm/benchmark/scenarios/math_scenario.py +1 -1
  48. helm/benchmark/scenarios/test_air_bench_scenario.py +27 -0
  49. helm/benchmark/scenarios/vision_language/a_okvqa_scenario.py +83 -0
  50. helm/benchmark/scenarios/vision_language/bingo_scenario.py +3 -3
  51. helm/benchmark/scenarios/vision_language/crossmodal_3600_scenario.py +134 -0
  52. helm/benchmark/scenarios/vision_language/flickr30k_scenario.py +74 -0
  53. helm/benchmark/scenarios/vision_language/gqa_scenario.py +91 -0
  54. helm/benchmark/scenarios/vision_language/hateful_memes_scenario.py +4 -2
  55. helm/benchmark/scenarios/vision_language/image2structure/image2structure_scenario.py +13 -2
  56. helm/benchmark/scenarios/vision_language/image2structure/latex_scenario.py +1 -5
  57. helm/benchmark/scenarios/vision_language/image2structure/musicsheet_scenario.py +1 -5
  58. helm/benchmark/scenarios/vision_language/image2structure/webpage_scenario.py +5 -3
  59. helm/benchmark/scenarios/vision_language/math_vista_scenario.py +117 -0
  60. helm/benchmark/scenarios/vision_language/mm_safety_bench_scenario.py +103 -0
  61. helm/benchmark/scenarios/vision_language/mscoco_captioning_scenario.py +92 -0
  62. helm/benchmark/scenarios/vision_language/mscoco_categorization_scenario.py +117 -0
  63. helm/benchmark/scenarios/vision_language/originality_scenario.py +35 -0
  64. helm/benchmark/scenarios/vision_language/pairs_scenario.py +247 -0
  65. helm/benchmark/scenarios/vision_language/unicorn_scenario.py +3 -3
  66. helm/benchmark/scenarios/vision_language/vibe_eval_scenario.py +95 -0
  67. helm/benchmark/scenarios/vision_language/viz_wiz_scenario.py +2 -2
  68. helm/benchmark/scenarios/vision_language/vqa_scenario.py +4 -2
  69. helm/benchmark/static/schema_air_bench.yaml +3149 -0
  70. helm/benchmark/static/schema_classic.yaml +3 -59
  71. helm/benchmark/static/schema_finance.yaml +143 -0
  72. helm/benchmark/static/schema_image2structure.yaml +447 -0
  73. helm/benchmark/static/schema_instruction_following.yaml +3 -52
  74. helm/benchmark/static/schema_lite.yaml +3 -61
  75. helm/benchmark/static/schema_medical.yaml +255 -0
  76. helm/benchmark/static/schema_mmlu.yaml +3 -61
  77. helm/benchmark/static/schema_tables.yaml +200 -0
  78. helm/benchmark/static/schema_thai.yaml +223 -0
  79. helm/benchmark/static/schema_unitxt.yaml +3 -61
  80. helm/benchmark/static/schema_vhelm.yaml +824 -0
  81. helm/benchmark/static/schema_vhelm_lite.yaml +109 -0
  82. helm/benchmark/static_build/assets/air-overview-d2e6c49f.png +0 -0
  83. helm/benchmark/static_build/assets/index-30dbceba.js +10 -0
  84. helm/benchmark/static_build/assets/index-66b02d40.css +1 -0
  85. helm/benchmark/static_build/assets/overview-74aea3d8.png +0 -0
  86. helm/benchmark/static_build/assets/process-flow-bd2eba96.png +0 -0
  87. helm/benchmark/static_build/index.html +2 -2
  88. helm/clients/anthropic_client.py +78 -14
  89. helm/clients/auto_client.py +11 -0
  90. helm/clients/client.py +24 -7
  91. helm/clients/cohere_client.py +98 -3
  92. helm/clients/huggingface_client.py +71 -12
  93. helm/clients/openai_client.py +11 -5
  94. helm/clients/reka_client.py +189 -0
  95. helm/clients/test_client.py +3 -3
  96. helm/clients/test_huggingface_client.py +19 -3
  97. helm/clients/test_together_client.py +72 -2
  98. helm/clients/together_client.py +199 -2
  99. helm/clients/vertexai_client.py +117 -64
  100. helm/clients/vision_language/huggingface_vision2seq_client.py +145 -0
  101. helm/clients/vision_language/huggingface_vlm_client.py +12 -4
  102. helm/clients/vision_language/idefics_client.py +2 -2
  103. helm/clients/vision_language/paligemma_client.py +146 -0
  104. helm/clients/vision_language/palmyra_vision_client.py +84 -0
  105. helm/clients/yi_client.py +31 -0
  106. helm/common/critique_request.py +10 -1
  107. helm/common/images_utils.py +29 -3
  108. helm/config/model_deployments.yaml +504 -12
  109. helm/config/model_metadata.yaml +579 -52
  110. helm/config/tokenizer_configs.yaml +100 -1
  111. helm/proxy/critique/model_critique_client.py +32 -4
  112. helm/proxy/services/server_service.py +1 -1
  113. helm/tokenizers/auto_tokenizer.py +1 -1
  114. helm/tokenizers/cohere_tokenizer.py +44 -2
  115. helm/tokenizers/huggingface_tokenizer.py +36 -13
  116. helm/tokenizers/test_cohere_tokenizer.py +39 -0
  117. helm/tokenizers/test_huggingface_tokenizer.py +5 -1
  118. helm/benchmark/static/schema_vlm.yaml +0 -576
  119. helm/benchmark/static_build/assets/index-5088afcb.css +0 -1
  120. helm/benchmark/static_build/assets/index-d839df55.js +0 -9
  121. helm/benchmark/test_model_deployment_definition.py +0 -90
  122. {crfm_helm-0.5.0.dist-info → crfm_helm-0.5.2.dist-info}/LICENSE +0 -0
  123. {crfm_helm-0.5.0.dist-info → crfm_helm-0.5.2.dist-info}/WHEEL +0 -0
  124. {crfm_helm-0.5.0.dist-info → crfm_helm-0.5.2.dist-info}/entry_points.txt +0 -0
  125. {crfm_helm-0.5.0.dist-info → crfm_helm-0.5.2.dist-info}/top_level.txt +0 -0
helm/benchmark/metrics/prometheus_vision_critique_metrics.py
@@ -0,0 +1,185 @@
+from typing import Dict, List, Optional
+import re
+
+from tqdm import tqdm
+
+from helm.benchmark.adaptation.request_state import RequestState
+from helm.benchmark.adaptation.scenario_state import ScenarioState
+from helm.benchmark.adaptation.adapter_spec import AdapterSpec
+from helm.benchmark.metrics.metric import MetricInterface, MetricResult, PerInstanceStats, add_context
+from helm.benchmark.metrics.metric_name import MetricContext, MetricName
+from helm.benchmark.metrics.metric_service import MetricService
+from helm.benchmark.metrics.statistic import Stat, merge_stat
+from helm.common.critique_request import CritiqueTaskTemplate, CritiqueQuestionTemplate, CritiqueRequest, QuestionType
+from helm.common.hierarchical_logger import hlog
+from helm.common.request import RequestResult, GeneratedOutput
+from helm.common.media_object import MultimediaObject, IMAGE_TYPE, MediaObject, TEXT_TYPE
+
+
+class PrometheusVisionCritiqueMetric(MetricInterface):
+    """
+    We compute the same metrics from the Prometheus-Vision: Vision-Language Model as a Judge for
+    Fine-Grained Evaluation paper:
+    https://arxiv.org/pdf/2401.06591.pdf
+
+    In this paper, the output of a Vision-Language Model named Prometheus-Vision is used to evaluate
+    the quality of the output of other Vision-Language Models to be evaluated.
+    """
+
+    # We can add more evaluation aspects here
+    METRIC_NAME: str = "prometheus_vision"
+    METRIC_PROMPT: str = """A chat between a curious human and an artificial intelligence assistant. \
+The assistant gives helpful, detailed, and polite answers to the human's questions.\nUSER:<image>\
+###Task Description:
+An instruction (might include an Input inside it), a response to evaluate, a reference answer that gets a score of 5, \
+image and a score rubric representing an evaluation criterion is given.
+1. Write a detailed feedback that assesses the quality of the response strictly based on the given score rubric, not \
+evaluating in general.
+2. After writing a feedback, write a score that is an integer between 1 and 5. You should refer to the score rubric.
+3. The output format should look as follows: Feedback: (write a feedback for criteria) [RESULT] (an integer number \
+between 1 and 5)
+4. Please do not generate any other opening, closing, and explanations.
+
+###The instruction to evaluate:
+{{orig_instruction}}
+
+###Response to evaluate:
+{{orig_response}}
+
+###Reference Answer (Score 5):
+{{orig_reference_answer}}
+
+###Score Rubrics:
+[{{orig_criteria}}]
+Score 1: {{orig_score1_description}}
+Score 2: {{orig_score2_description}}
+Score 3: {{orig_score3_description}}
+Score 4: {{orig_score4_description}}
+Score 5: {{orig_score5_description}}
+
+###Feedback:
+ASSISTANT:
+"""
+
+    def __init__(self, num_respondents: int, max_tokens: int):
+        self._num_respondents = num_respondents
+        self._max_tokens = max_tokens
+
+    def __repr__(self) -> str:
+        return "PrometheusVisionCritiqueMetric()"
+
+    def _extract_score_from_prometheus_vision_output(self, evaluator_response: str):
+        evaluator_response = evaluator_response.split("ASSISTANT:")[1]
+        re_match = re.search(r"\s*([1-5])", evaluator_response)
+        if re_match is None:
+            hlog(f"Error parsing answer: {evaluator_response}. Skipping question (and so the respondent entirely)")
+            return None
+        return int(re_match.group(1))
+
+    def evaluate(
+        self,
+        scenario_state: ScenarioState,
+        metric_service: MetricService,
+        eval_cache_path: str,
+        parallelism: int,
+    ) -> MetricResult:
+        request_states: List[RequestState] = scenario_state.request_states
+
+        all_stats: Dict[MetricName, Stat] = {}
+        per_instance_stats: List[PerInstanceStats] = []
+        for request_state in tqdm(request_states):
+            context = MetricContext.from_instance(request_state.instance)
+            stats_without_context = self.evaluate_generation(
+                scenario_state.adapter_spec,
+                request_state,
+                metric_service,
+                eval_cache_path,
+            )
+            stats = [add_context(stat_without_context, context) for stat_without_context in stats_without_context]
+            for stat in stats:
+                merge_stat(all_stats, stat)
+            assert request_state.instance.id is not None
+            per_instance_stats.append(
+                PerInstanceStats(
+                    instance_id=request_state.instance.id,
+                    perturbation=request_state.instance.perturbation,
+                    train_trial_index=request_state.train_trial_index,
+                    stats=stats,
+                )
+            )
+        return MetricResult(aggregated_stats=list(all_stats.values()), per_instance_stats=per_instance_stats)
+
+    def evaluate_generation(
+        self,
+        adapter_spec: AdapterSpec,
+        request_state: RequestState,
+        metric_service: MetricService,
+        eval_cache_path: str,
+    ) -> List[Stat]:
+        input_content = request_state.request
+        # Predicted outputs and their prometheus vision scores
+        assert request_state.result is not None
+        request_result: RequestResult = request_state.result
+        # Get input image and generated response for evaluation
+        assert input_content.multimodal_prompt is not None
+        completions: List[GeneratedOutput] = request_result.completions
+        generated_text: str = completions[0].text
+        input_media: MultimediaObject = input_content.multimodal_prompt
+        ref_text: str = request_state.instance.references[0].output.text
+        image_objects: List[MediaObject] = [
+            item for item in input_media.media_objects if item.is_type(IMAGE_TYPE) and item.location
+        ]
+        input_text: Optional[str] = [item for item in input_media.media_objects if item.is_type(TEXT_TYPE)][0].text
+
+        template = CritiqueTaskTemplate(
+            name="vhelm_prometheus_vision",
+            instructions=self.METRIC_PROMPT,
+            num_respondents=self._num_respondents,
+            max_tokens=self._max_tokens,
+            questions=[
+                CritiqueQuestionTemplate(
+                    name=self.METRIC_NAME,
+                    question_type=QuestionType.FREE_RESPONSE,
+                    text="",
+                    options=[],
+                    media_object=image_objects[0],  # we only take the first image as input
+                )
+            ],
+        )
+        request = CritiqueRequest(
+            template=template,
+            fields={
+                "orig_instruction": input_text if input_text is not None else "",
+                "orig_response": generated_text,
+                "orig_reference_answer": ref_text,
+                "orig_criteria": "similarity between the response and the reference.",
+                "orig_score1_description": "The model's responses do not follow the instructions provided.",
+                "orig_score2_description": "The resulting response follows the instructions, but the answer \
+is completely wrong relative to the reference answer.",
+                "orig_score3_description": "The resulting response follows the instructions, but the answer is \
+partially wrong relative to the reference answer.",
+                "orig_score4_description": "The resulting response follows the instructions, the overall answer \
+is relatively perfect with only a very few errors.",
+                "orig_score5_description": "The overall answer is completely correct compared to the reference \
+answer, and conforms to the instructions provided.",
+            },
+        )
+        # send to critique request
+        result = metric_service.make_critique_request(request)
+        if not result or not result.responses:
+            # Skip computing metrics if there aren't any responses yet
+            hlog("Waiting for responses to be generated.")
+            return []
+
+        stats: Dict[str, Stat] = {}
+        for question in template.questions:
+            stats[question.name] = Stat(MetricName(question.name))
+
+        for response in result.responses:
+            for answer_name, answer in response.answers.items():
+                assert isinstance(answer, str)
+                answer_value: float
+                answer_value = self._extract_score_from_prometheus_vision_output(answer)
+                stats[answer_name].add(answer_value)
+
+        return list(stats.values())
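The scoring side of this new metric hinges on a small parsing step: the judge is instructed to end with "[RESULT] <n>" after the "ASSISTANT:" marker, and the metric takes the first digit in the 1-5 range that follows that marker. A minimal standalone sketch of that parsing (the sample judge output below is hypothetical, for illustration only):

```python
import re
from typing import Optional


def extract_prometheus_score(evaluator_response: str) -> Optional[int]:
    # Mirrors _extract_score_from_prometheus_vision_output: look only at the text
    # after the "ASSISTANT:" marker and take the first digit between 1 and 5.
    tail = evaluator_response.split("ASSISTANT:")[1]
    match = re.search(r"\s*([1-5])", tail)
    return int(match.group(1)) if match else None


# Hypothetical judge output, for illustration only.
sample = "...###Feedback:\nASSISTANT: Feedback: Mostly matches the reference. [RESULT] 4"
print(extract_prometheus_score(sample))  # 4
```

Because the regex matches the first 1-5 digit after the marker, a digit appearing earlier in the feedback text would be picked up as the score.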
helm/benchmark/metrics/reka_vibe_critique_metrics.py
@@ -0,0 +1,158 @@
+from typing import Dict, List, Optional
+import re
+
+from helm.benchmark.adaptation.request_state import RequestState
+from helm.benchmark.adaptation.scenario_state import ScenarioState
+from helm.benchmark.adaptation.adapter_spec import AdapterSpec
+from helm.benchmark.metrics.metric import MetricInterface, MetricResult, PerInstanceStats, add_context
+from helm.benchmark.metrics.metric_name import MetricContext, MetricName
+from helm.benchmark.metrics.metric_service import MetricService
+from helm.benchmark.metrics.statistic import Stat, merge_stat
+from helm.common.critique_request import CritiqueTaskTemplate, CritiqueQuestionTemplate, CritiqueRequest, QuestionType
+from helm.common.hierarchical_logger import hlog
+from helm.common.request import RequestResult, GeneratedOutput
+from helm.common.media_object import MultimediaObject, IMAGE_TYPE, TEXT_TYPE, MediaObject
+
+
+class RekaVibeCritiqueMetric(MetricInterface):
+    """
+    Critique evaluation for evaluating the correctness of generated response given the image and
+    reference by Reka-vibe-eval.
+    """
+
+    # We can add more evaluation aspects here
+    VIBE_EVAL_NAME: str = "reka_vibe"
+    REKA_VIBE_PROMPT_WITH_IMAGE: str = """\
+[Question]
+{{prompt}}
+
+[Assistant Response]
+{{generation}}
+
+[Ground Truth Response]
+{{reference}}
+
+[System]
+Rate whether the assistant response correctly matches the ground truth, in regards to the image above.
+The rating should be 1-5, where 1 is incorrect and 5 is correct.
+Your response should be in the format:
+Short Explanation: (explanation in only one sentence)
+Rating: (int)"""
+
+    def __init__(self, num_respondents: int, max_tokens: int):
+        self._num_respondents = num_respondents
+        self._max_tokens = max_tokens
+
+    def __repr__(self) -> str:
+        return "RekaVibeCritiqueMetric()"
+
+    def _extract_score_from_reka_output(self, evaluator_response: str):
+        """
+        Extract the score from the evaluator response. Refer to the official Vibe-Eval implementation:
+        https://github.com/reka-ai/reka-vibe-eval/blob/3852d4712da172a7b85dddeffc4f9c3482a6f4c9/evaluate.py#L159-#L164
+        """
+        re_match = re.search(r"Rating:\s*([1-5])", evaluator_response)
+        if re_match is None:
+            hlog(f"Error parsing answer: {evaluator_response}. Skipping question (and so the respondent entirely)")
+            return None
+        return int(re_match.group(1))
+
+    def evaluate(
+        self,
+        scenario_state: ScenarioState,
+        metric_service: MetricService,
+        eval_cache_path: str,
+        parallelism: int,
+    ) -> MetricResult:
+        request_states: List[RequestState] = scenario_state.request_states
+
+        all_stats: Dict[MetricName, Stat] = {}
+        per_instance_stats: List[PerInstanceStats] = []
+        for request_state in request_states:
+            context = MetricContext.from_instance(request_state.instance)
+            stats_without_context = self.evaluate_generation(
+                scenario_state.adapter_spec,
+                request_state,
+                metric_service,
+                eval_cache_path,
+            )
+            stats = [add_context(stat_without_context, context) for stat_without_context in stats_without_context]
+            for stat in stats:
+                merge_stat(all_stats, stat)
+            assert request_state.instance.id is not None
+            per_instance_stats.append(
+                PerInstanceStats(
+                    instance_id=request_state.instance.id,
+                    perturbation=request_state.instance.perturbation,
+                    train_trial_index=request_state.train_trial_index,
+                    stats=stats,
+                )
+            )
+        return MetricResult(aggregated_stats=list(all_stats.values()), per_instance_stats=per_instance_stats)
+
+    def evaluate_generation(
+        self,
+        adapter_spec: AdapterSpec,
+        request_state: RequestState,
+        metric_service: MetricService,
+        eval_cache_path: str,
+    ) -> List[Stat]:
+        input_content = request_state.request
+        # Predicted outputs and their originality scores
+        assert request_state.result is not None
+        request_result: RequestResult = request_state.result
+        # Get input image and generated response for the originality evaluation
+        assert input_content.multimodal_prompt is not None
+        completions: List[GeneratedOutput] = request_result.completions
+        generated_text: str = completions[0].text
+        input_media: MultimediaObject = input_content.multimodal_prompt
+        ref_text: str = request_state.instance.references[0].output.text
+
+        image_objects: List[MediaObject] = [
+            item for item in input_media.media_objects if item.is_type(IMAGE_TYPE) and item.location
+        ]
+        input_text: Optional[str] = [item for item in input_media.media_objects if item.is_type(TEXT_TYPE)][0].text
+
+        template = CritiqueTaskTemplate(
+            name="vhelm_vibe_eval",
+            instructions=self.REKA_VIBE_PROMPT_WITH_IMAGE,
+            num_respondents=self._num_respondents,
+            max_tokens=self._max_tokens,
+            questions=[
+                CritiqueQuestionTemplate(
+                    name=self.VIBE_EVAL_NAME,
+                    question_type=QuestionType.FREE_RESPONSE,
+                    text="",
+                    options=[],
+                    media_object=image_objects[0],  # we only take the first image as input
+                )
+            ],
+        )
+
+        request = CritiqueRequest(
+            template=template,
+            fields={
+                "prompt": input_text if input_text is not None else "",
+                "generation": generated_text,
+                "reference": ref_text,
+            },
+        )
+
+        # send to critique request
+        result = metric_service.make_critique_request(request)
+        if not result or not result.responses:
+            # Skip computing metrics if there aren't any responses yet
+            hlog("Waiting for responses to be generated.")
+            return []
+        stats: Dict[str, Stat] = {}
+        for question in template.questions:
+            stats[question.name] = Stat(MetricName(question.name))
+
+        for response in result.responses:
+            for answer_name, answer in response.answers.items():
+                assert isinstance(answer, str)
+                answer_value: float
+                answer_value = self._extract_score_from_reka_output(answer)
+                stats[answer_name].add(answer_value)
+
+        return list(stats.values())
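Here the Vibe-Eval judge is asked for "Short Explanation: ... / Rating: (int)", and the metric recovers the rating with the same regex as the official Vibe-Eval evaluator linked in the docstring. A minimal sketch (the sample evaluator output is hypothetical, for illustration only):

```python
import re
from typing import Optional


def extract_vibe_rating(evaluator_response: str) -> Optional[int]:
    # Mirrors _extract_score_from_reka_output: find "Rating:" followed by a 1-5 digit.
    match = re.search(r"Rating:\s*([1-5])", evaluator_response)
    return int(match.group(1)) if match else None


# Hypothetical evaluator output, for illustration only.
sample = "Short Explanation: The response matches the ground truth.\nRating: 5"
print(extract_vibe_rating(sample))  # 5
```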
helm/benchmark/metrics/unitxt_metrics.py
@@ -50,16 +50,20 @@ class UnitxtMetric(MetricInterface):
             for metric_name, metric_score in instance_results.items():
                 if metric_name == "score" or metric_name == "score_name":
                     continue
-                instance_stats.append(
-                    Stat(
-                        MetricName(
-                            name=metric_name,
-                            split=instance.split,
-                            sub_split=instance.sub_split,
-                            perturbation=instance.perturbation,
-                        )
-                    ).add(metric_score)
+                stat = Stat(
+                    MetricName(
+                        name=metric_name,
+                        split=instance.split,
+                        sub_split=instance.sub_split,
+                        perturbation=instance.perturbation,
+                    )
                 )
+                if isinstance(metric_score, list):
+                    for metric_score_element in metric_score:
+                        stat = stat.add(metric_score_element)
+                else:
+                    stat = stat.add(metric_score)
+                instance_stats.append(stat)
             assert instance.id
             per_instance_stats.append(
                 PerInstanceStats(
@@ -77,5 +81,11 @@ class UnitxtMetric(MetricInterface):
         for metric_name, metric_score in global_results.items():
             if metric_name == "score" or metric_name == "score_name":
                 continue
-            aggregated_stats.append(Stat(MetricName(name=metric_name)).add(metric_score))
+            stat = Stat(MetricName(name=metric_name))
+            if isinstance(metric_score, list):
+                for metric_score_element in metric_score:
+                    stat = stat.add(metric_score_element)
+            else:
+                stat = stat.add(metric_score)
+            aggregated_stats.append(stat)
         return MetricResult(aggregated_stats=aggregated_stats, per_instance_stats=per_instance_stats)
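The unitxt change above lets a metric score that arrives as a list be folded into the Stat element by element instead of being added as a single value. A small plain-Python illustration of the resulting semantics (MiniStat is a stand-in for helm.benchmark.metrics.statistic.Stat, and the list-valued score is hypothetical):

```python
class MiniStat:
    """Stand-in for helm's Stat: tracks count and running total of added values."""

    def __init__(self) -> None:
        self.count = 0
        self.total = 0.0

    def add(self, value: float) -> "MiniStat":
        self.count += 1
        self.total += value
        return self


stat = MiniStat()
metric_score = [0.25, 0.75, 0.5]  # hypothetical per-reference scores from unitxt
if isinstance(metric_score, list):
    for metric_score_element in metric_score:
        stat = stat.add(metric_score_element)
else:
    stat = stat.add(metric_score)
print(stat.count, stat.total / stat.count)  # 3 0.5
```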
helm/benchmark/metrics/vision_language/emd_utils.py
@@ -280,6 +280,10 @@ def compute_emd_recursive(
     assert max_num_patches > 0
     assert 0 < weight_most_frequent_color <= 1
 
+    # Convert the images to RGB first. Some images have 4 channels (RGBA)
+    img1_PIL = img1_PIL.convert("RGB")
+    img2_PIL = img2_PIL.convert("RGB")
+
     # Resize the images so that there are not too many patches
     # Try to maintain the aspect ratio and resize to a multiple of the patch size
     num_patches = math.ceil(img1_PIL.size[0] / patch_size[0]) * math.ceil(img1_PIL.size[1] / patch_size[1])
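The new guard normalizes inputs to three channels before the patch computations, since RGBA inputs would otherwise carry a fourth (alpha) channel through the color statistics. A quick Pillow sketch of the conversion:

```python
from PIL import Image

rgba = Image.new("RGBA", (16, 16), (255, 0, 0, 128))  # 4-channel input
rgb = rgba.convert("RGB")                              # what the metric now operates on
print(rgba.getbands(), rgb.getbands())  # ('R', 'G', 'B', 'A') ('R', 'G', 'B')
```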
helm/benchmark/metrics/vision_language/image_metrics.py
@@ -28,7 +28,7 @@ from helm.benchmark.metrics.vision_language.image_utils import (
     pixel_similarity,
     sift_similarity,
 )
-from helm.benchmark.metrics.vision_language.emd_utils import compute_emd_recursive
+from helm.benchmark.metrics.vision_language.emd_utils import compute_emd_recursive, get_most_frequent_color, to_gray
 
 try:
     from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
@@ -78,7 +78,8 @@ class AnnotatedImageMetrics(Metric):
 
     # Metric names
     COMPILE_METRIC: str = "compilation_success"
-    EARTH_MOVER_SIMILARITY: str = "earth_mover_similarity"
+    EARTH_MOVER_SIMILARITY = "earth_mover_similarity"
+    BLOCK_EMD: str = "block_emd"
     PIXEL_SIMILARITY: str = "pixel_similarity"
     SIFT_SIMILARITY: str = "sift_similarity"
     LPIPS_SIMILARITY: str = "lpips_similarity"
@@ -106,7 +107,10 @@ class AnnotatedImageMetrics(Metric):
         metrics: List[AnnotatedMetric] = [
             AnnotatedMetric(self.PIXEL_SIMILARITY, pixel_similarity, "image_np_gray"),
             AnnotatedMetric(self.SIFT_SIMILARITY, sift_similarity, "image_np"),
-            AnnotatedMetric(self.EARTH_MOVER_SIMILARITY, self.compute_emd_similarity_recursive, "image_PIL"),
+            AnnotatedMetric(self.BLOCK_EMD, self.compute_block_emd_raw, "image_PIL"),  # Raw block-EMD
+            AnnotatedMetric(
+                self.EARTH_MOVER_SIMILARITY, self.ems, "image_PIL"
+            ),  # Normalized block-EMD against black/white
             AnnotatedMetric(self.LPIPS_SIMILARITY, self.lpips_similarity, "image_PIL"),
             AnnotatedMetric(self.FID_SIMILARITY, self.fid_similarity, "image_PIL"),
             AnnotatedMetric(self.SSIM_SIMILARITY, self.compute_ssim, "image_np_gray"),
@@ -384,9 +388,15 @@ class AnnotatedImageMetrics(Metric):
         features1 = self._get_inception_features(img1_tensor)
         features2 = self._get_inception_features(img2_tensor)
 
-        fid_score = self._calculate_fid(features1, features2)
-        normalize_fid: float = np.exp(-fid_score * self.NORMALIZE_FID_FACTOR)
-        return normalize_fid
+        # TODO: Justify the value of the constant here or remove this code to only keep the cosine similarity.
+        # fid_score = self._calculate_fid(features1, features2)
+        # normalize_fid: float = np.exp(-fid_score * self.NORMALIZE_FID_FACTOR)
+        # return normalize_fid
+
+        # Use the cosine similarity between the features as a proxy for FID
+        # Return a score between 0 and 1, where 1 is the most similar
+        score = 0.5 * (1 + np.dot(features1[0], features2[0]) / (np.linalg.norm(features1) * np.linalg.norm(features2)))
+        return score
 
     def compute_ssim(self, generated_image: np.ndarray, reference_image: np.ndarray) -> float:
         """Compute the Structural Similarity Index (SSIM) between the generated and reference images."""
@@ -407,7 +417,7 @@ class AnnotatedImageMetrics(Metric):
         result = _edit_similarity(completion_tokens, truncated_reference_tokens)
         return result
 
-    def compute_emd_similarity_recursive(
+    def ems(
         self,
         pred_image: Image.Image,
         ref_image: Image.Image,
@@ -417,18 +427,31 @@ class AnnotatedImageMetrics(Metric):
         weight_most_frequent_color: float = 0.001,
         use_tqdm: bool = False,
     ):
-        emd_value = compute_emd_recursive(
-            pred_image,
-            ref_image,
-            threshold_most_frequent_color,
-            patch_size,
-            max_num_patches,
-            weight_most_frequent_color,
-            use_tqdm,
-        )
+        """Same as compute_emd_similarity_recursive EXCEPT that
+        the normalization is against an image of the median color.
+        """
 
-        def do_it():
-            constant_image = Image.new("RGB", ref_image.size, (255, 255, 255))  # default color is white
+        def compute_numerator():
+            return self.compute_block_emd_raw_wrapper(
+                pred_image,
+                ref_image,
+                threshold_most_frequent_color,
+                patch_size,
+                max_num_patches,
+                weight_most_frequent_color,
+                use_tqdm,
+            )
+
+        def compute_denominator():
+            ref_img_np = np.array(ref_image)
+            (rgb_most_frequent_color, _) = get_most_frequent_color(ref_img_np)
+            grayscale_most_frequent_color = to_gray(rgb_most_frequent_color)[0]
+
+            # Most frequent color as base
+            if grayscale_most_frequent_color < 127:
+                constant_image = Image.new("RGB", ref_image.size, (255, 255, 255))  # Make it white
+            else:
+                constant_image = Image.new("RGB", ref_image.size, (0, 0, 0))  # Make it black
             value = compute_emd_recursive(
                 constant_image,
                 ref_image,
@@ -442,9 +465,69 @@ class AnnotatedImageMetrics(Metric):
 
         hash_dict = {
             "reference_image": str(AnnotatedImageMetrics.HASH_FUNC(ref_image, hash_size=self.HASH_LENGTH)),
+            "generated_image": str(AnnotatedImageMetrics.HASH_FUNC(pred_image, hash_size=self.HASH_LENGTH)),
+        }
+        cache_key_numerator = {"metric_name": f"intermediate_{self.BLOCK_EMD}", **hash_dict}
+        cache_key_denominator = {"metric_name": "intermediate_ems_extreme_denominator", **hash_dict}
+
+        assert self._cache is not None
+        emd_raw, _ = self._cache.get(cache_key_numerator, compute_numerator)
+        emd_base, _ = self._cache.get(cache_key_denominator, compute_denominator)
+
+        return 1.0 - emd_raw["value"] / emd_base["value"]
+
+    def compute_block_emd_raw(
+        self,
+        pred_image: Image.Image,
+        ref_image: Image.Image,
+        threshold_most_frequent_color: float = 0.5,
+        patch_size: Tuple[int, int] = (8, 8),
+        max_num_patches: int = 100,
+        weight_most_frequent_color: float = 0.001,
+        use_tqdm: bool = False,
+    ):
+        def compute():
+            return self.compute_block_emd_raw_wrapper(
+                pred_image,
+                ref_image,
+                threshold_most_frequent_color,
+                patch_size,
+                max_num_patches,
+                weight_most_frequent_color,
+                use_tqdm,
+            )
+
+        hash_dict = {
+            "reference_image": str(AnnotatedImageMetrics.HASH_FUNC(ref_image, hash_size=self.HASH_LENGTH)),
+            "generated_image": str(AnnotatedImageMetrics.HASH_FUNC(pred_image, hash_size=self.HASH_LENGTH)),
         }
-        cache_key = {"metric_name": f"intermediate_{self.EARTH_MOVER_SIMILARITY}", **hash_dict}
+        cache_key = {"metric_name": f"intermediate_{self.BLOCK_EMD}", **hash_dict}
         assert self._cache is not None
-        response_metric, _ = self._cache.get(cache_key, do_it)
+        emd_raw, _ = self._cache.get(cache_key, compute)
 
-        return 1.0 - emd_value / response_metric["value"]
+        return emd_raw["value"]
+
+    def compute_block_emd_raw_wrapper(
+        self,
+        pred_image: Image.Image,
+        ref_image: Image.Image,
+        threshold_most_frequent_color: float = 0.5,
+        patch_size: Tuple[int, int] = (8, 8),
+        max_num_patches: int = 100,
+        weight_most_frequent_color: float = 0.001,
+        use_tqdm: bool = False,
+    ):
+        """Computes the block Earth Moving Distance (EMD). This attempts to
+        speed up EMD for images with huge areas by considering
+        movement/transformation of blocks of pixels.
+        """
+        emd_value = compute_emd_recursive(
+            pred_image,
+            ref_image,
+            threshold_most_frequent_color,
+            patch_size,
+            max_num_patches,
+            weight_most_frequent_color,
+            use_tqdm,
+        )
+        return {"value": emd_value}
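Two scoring changes in image_metrics.py are worth restating. First, earth_mover_similarity is now reported as 1 - EMD(pred, ref) / EMD(constant, ref), where the constant image is white when the reference's dominant color is dark and black otherwise, and the raw numerator is also exposed as the new block_emd metric. Second, fid_similarity now uses the cosine similarity between Inception feature vectors, remapped from [-1, 1] to [0, 1]. A short numpy sketch of that remapping (the random vectors are stand-ins for real Inception features):

```python
import numpy as np

rng = np.random.default_rng(0)
features1 = rng.normal(size=(1, 2048))  # stand-in for features of the generated image
features2 = rng.normal(size=(1, 2048))  # stand-in for features of the reference image

cos = np.dot(features1[0], features2[0]) / (np.linalg.norm(features1) * np.linalg.norm(features2))
score = 0.5 * (1 + cos)  # maps cosine similarity from [-1, 1] to [0, 1]
print(0.0 <= score <= 1.0)  # True; 1.0 means identical feature directions
```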
helm/benchmark/model_metadata_registry.py
@@ -32,6 +32,7 @@ ANTHROPIC_CLAUDE_3_MODEL_TAG: str = "ANTHROPIC_CLAUDE_3_MODEL_TAG"
 
 GOOGLE_PALM_2_MODEL_TAG: str = "GOOGLE_PALM_2_MODEL_TAG"
 GOOGLE_GEMINI_MODEL_TAG: str = "GOOGLE_GEMINI_MODEL_TAG"
+GOOGLE_GEMINI_PRO_VISION_V1_TAG: str = "GOOGLE_GEMINI_PRO_VISION_V1_TAG"
 GOOGLE_GEMMA_INSTRUCT_MODEL_TAG: str = "GOOGLE_GEMMA_INSTRUCT_MODEL_TAG"
 
 # Models which emit garbage tokens when temperature=0.
@@ -159,7 +160,10 @@ def register_model_metadata(model_metadata: ModelMetadata) -> None:
 def get_model_metadata(model_name: str) -> ModelMetadata:
     """Return the `ModelMetadata` for the model name."""
     if model_name not in MODEL_NAME_TO_MODEL_METADATA:
-        raise ValueError(f"No model with name: {model_name}")
+        raise ValueError(
+            f"No model metadata for model name: {model_name} - "
+            "did you remember to add this model to model_metadata.yaml?"
+        )
 
     return MODEL_NAME_TO_MODEL_METADATA[model_name]