crfm-helm 0.5.2__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of crfm-helm might be problematic.
- {crfm_helm-0.5.2.dist-info → crfm_helm-0.5.4.dist-info}/METADATA +81 -112
- {crfm_helm-0.5.2.dist-info → crfm_helm-0.5.4.dist-info}/RECORD +165 -155
- {crfm_helm-0.5.2.dist-info → crfm_helm-0.5.4.dist-info}/WHEEL +1 -1
- helm/benchmark/adaptation/adapters/multiple_choice_joint_adapter.py +12 -5
- helm/benchmark/adaptation/adapters/test_generation_adapter.py +12 -12
- helm/benchmark/adaptation/adapters/test_language_modeling_adapter.py +8 -8
- helm/benchmark/adaptation/adapters/test_multiple_choice_joint_adapter.py +77 -9
- helm/benchmark/adaptation/common_adapter_specs.py +2 -0
- helm/benchmark/annotation/anthropic_red_team_annotator.py +57 -0
- helm/benchmark/annotation/call_center_annotator.py +258 -0
- helm/benchmark/annotation/financebench_annotator.py +79 -0
- helm/benchmark/annotation/harm_bench_annotator.py +55 -0
- helm/benchmark/annotation/{image2structure → image2struct}/latex_compiler_annotator.py +2 -2
- helm/benchmark/annotation/{image2structure → image2struct}/lilypond_compiler_annotator.py +5 -3
- helm/benchmark/annotation/{image2structure → image2struct}/webpage_compiler_annotator.py +5 -5
- helm/benchmark/annotation/live_qa_annotator.py +37 -45
- helm/benchmark/annotation/medication_qa_annotator.py +36 -44
- helm/benchmark/annotation/model_as_judge.py +96 -0
- helm/benchmark/annotation/simple_safety_tests_annotator.py +50 -0
- helm/benchmark/annotation/xstest_annotator.py +100 -0
- helm/benchmark/metrics/annotation_metrics.py +108 -0
- helm/benchmark/metrics/bhasa_metrics.py +188 -0
- helm/benchmark/metrics/bhasa_metrics_specs.py +10 -0
- helm/benchmark/metrics/code_metrics_helper.py +11 -1
- helm/benchmark/metrics/safety_metrics.py +79 -0
- helm/benchmark/metrics/summac/model_summac.py +3 -3
- helm/benchmark/metrics/tokens/test_ai21_token_cost_estimator.py +2 -2
- helm/benchmark/metrics/tokens/test_openai_token_cost_estimator.py +4 -4
- helm/benchmark/metrics/unitxt_metrics.py +17 -3
- helm/benchmark/metrics/vision_language/image_metrics.py +7 -3
- helm/benchmark/metrics/vision_language/image_utils.py +1 -1
- helm/benchmark/model_metadata_registry.py +3 -3
- helm/benchmark/presentation/create_plots.py +1 -1
- helm/benchmark/presentation/schema.py +3 -0
- helm/benchmark/presentation/summarize.py +106 -256
- helm/benchmark/presentation/test_run_entry.py +1 -0
- helm/benchmark/presentation/test_summarize.py +145 -3
- helm/benchmark/run.py +15 -0
- helm/benchmark/run_expander.py +83 -30
- helm/benchmark/run_specs/bhasa_run_specs.py +652 -0
- helm/benchmark/run_specs/call_center_run_specs.py +152 -0
- helm/benchmark/run_specs/decodingtrust_run_specs.py +8 -8
- helm/benchmark/run_specs/experimental_run_specs.py +52 -0
- helm/benchmark/run_specs/finance_run_specs.py +82 -1
- helm/benchmark/run_specs/safety_run_specs.py +154 -0
- helm/benchmark/run_specs/vlm_run_specs.py +100 -24
- helm/benchmark/scenarios/anthropic_red_team_scenario.py +71 -0
- helm/benchmark/scenarios/banking77_scenario.py +51 -0
- helm/benchmark/scenarios/bhasa_scenario.py +1942 -0
- helm/benchmark/scenarios/call_center_scenario.py +84 -0
- helm/benchmark/scenarios/decodingtrust_stereotype_bias_scenario.py +2 -1
- helm/benchmark/scenarios/ewok_scenario.py +116 -0
- helm/benchmark/scenarios/fin_qa_scenario.py +2 -0
- helm/benchmark/scenarios/financebench_scenario.py +53 -0
- helm/benchmark/scenarios/harm_bench_scenario.py +59 -0
- helm/benchmark/scenarios/raft_scenario.py +1 -1
- helm/benchmark/scenarios/scenario.py +1 -1
- helm/benchmark/scenarios/simple_safety_tests_scenario.py +33 -0
- helm/benchmark/scenarios/test_commonsense_scenario.py +21 -0
- helm/benchmark/scenarios/test_ewok_scenario.py +25 -0
- helm/benchmark/scenarios/test_financebench_scenario.py +26 -0
- helm/benchmark/scenarios/test_gsm_scenario.py +31 -0
- helm/benchmark/scenarios/test_legalbench_scenario.py +30 -0
- helm/benchmark/scenarios/test_math_scenario.py +2 -8
- helm/benchmark/scenarios/test_med_qa_scenario.py +30 -0
- helm/benchmark/scenarios/test_mmlu_scenario.py +33 -0
- helm/benchmark/scenarios/test_narrativeqa_scenario.py +73 -0
- helm/benchmark/scenarios/thai_exam_scenario.py +4 -4
- helm/benchmark/scenarios/vision_language/a_okvqa_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/bingo_scenario.py +2 -2
- helm/benchmark/scenarios/vision_language/crossmodal_3600_scenario.py +2 -1
- helm/benchmark/scenarios/vision_language/exams_v_scenario.py +104 -0
- helm/benchmark/scenarios/vision_language/fair_face_scenario.py +136 -0
- helm/benchmark/scenarios/vision_language/flickr30k_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/gqa_scenario.py +2 -2
- helm/benchmark/scenarios/vision_language/hateful_memes_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/chart2csv_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/latex_scenario.py +3 -3
- helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/musicsheet_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/utils_latex.py +31 -39
- helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/webpage/driver.py +1 -1
- helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/webpage/utils.py +1 -1
- helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/webpage_scenario.py +41 -12
- helm/benchmark/scenarios/vision_language/math_vista_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/mementos_scenario.py +3 -3
- helm/benchmark/scenarios/vision_language/mm_safety_bench_scenario.py +2 -2
- helm/benchmark/scenarios/vision_language/mme_scenario.py +21 -18
- helm/benchmark/scenarios/vision_language/mmmu_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/pairs_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/pope_scenario.py +2 -1
- helm/benchmark/scenarios/vision_language/real_world_qa_scenario.py +57 -0
- helm/benchmark/scenarios/vision_language/seed_bench_scenario.py +7 -5
- helm/benchmark/scenarios/vision_language/unicorn_scenario.py +2 -2
- helm/benchmark/scenarios/vision_language/vibe_eval_scenario.py +6 -3
- helm/benchmark/scenarios/vision_language/viz_wiz_scenario.py +1 -1
- helm/benchmark/scenarios/vision_language/vqa_scenario.py +3 -1
- helm/benchmark/scenarios/xstest_scenario.py +35 -0
- helm/benchmark/server.py +1 -6
- helm/benchmark/static/schema_air_bench.yaml +750 -750
- helm/benchmark/static/schema_bhasa.yaml +709 -0
- helm/benchmark/static/schema_call_center.yaml +232 -0
- helm/benchmark/static/schema_cleva.yaml +768 -0
- helm/benchmark/static/schema_decodingtrust.yaml +444 -0
- helm/benchmark/static/schema_ewok.yaml +367 -0
- helm/benchmark/static/schema_finance.yaml +55 -9
- helm/benchmark/static/{schema_image2structure.yaml → schema_image2struct.yaml} +231 -90
- helm/benchmark/static/schema_legal.yaml +566 -0
- helm/benchmark/static/schema_safety.yaml +266 -0
- helm/benchmark/static/schema_tables.yaml +149 -8
- helm/benchmark/static/schema_thai.yaml +21 -0
- helm/benchmark/static/schema_vhelm.yaml +137 -101
- helm/benchmark/static_build/assets/accenture-6f97eeda.png +0 -0
- helm/benchmark/static_build/assets/aisingapore-6dfc9acf.png +0 -0
- helm/benchmark/static_build/assets/cresta-9e22b983.png +0 -0
- helm/benchmark/static_build/assets/cuhk-8c5631e9.png +0 -0
- helm/benchmark/static_build/assets/index-05c76bb1.css +1 -0
- helm/benchmark/static_build/assets/index-3ee38b3d.js +10 -0
- helm/benchmark/static_build/assets/scb10x-204bd786.png +0 -0
- helm/benchmark/static_build/assets/vhelm-aspects-1437d673.png +0 -0
- helm/benchmark/static_build/assets/vhelm-framework-a1ca3f3f.png +0 -0
- helm/benchmark/static_build/assets/vhelm-model-8afb7616.png +0 -0
- helm/benchmark/static_build/assets/wellsfargo-a86a6c4a.png +0 -0
- helm/benchmark/static_build/index.html +2 -2
- helm/benchmark/window_services/test_openai_window_service.py +8 -8
- helm/benchmark/window_services/tokenizer_service.py +0 -5
- helm/clients/ai21_client.py +71 -1
- helm/clients/anthropic_client.py +7 -19
- helm/clients/huggingface_client.py +38 -37
- helm/clients/nvidia_nim_client.py +35 -0
- helm/clients/openai_client.py +18 -4
- helm/clients/palmyra_client.py +24 -0
- helm/clients/perspective_api_client.py +11 -6
- helm/clients/test_client.py +4 -6
- helm/clients/together_client.py +22 -0
- helm/clients/vision_language/open_flamingo_client.py +1 -2
- helm/clients/vision_language/palmyra_vision_client.py +28 -13
- helm/common/cache.py +8 -30
- helm/common/images_utils.py +6 -0
- helm/common/key_value_store.py +9 -9
- helm/common/mongo_key_value_store.py +5 -4
- helm/common/request.py +16 -0
- helm/common/test_cache.py +1 -48
- helm/common/tokenization_request.py +0 -9
- helm/config/model_deployments.yaml +444 -329
- helm/config/model_metadata.yaml +513 -111
- helm/config/tokenizer_configs.yaml +140 -11
- helm/proxy/example_queries.py +14 -21
- helm/proxy/server.py +0 -9
- helm/proxy/services/remote_service.py +0 -6
- helm/proxy/services/server_service.py +6 -20
- helm/proxy/services/service.py +0 -6
- helm/proxy/token_counters/test_auto_token_counter.py +2 -2
- helm/tokenizers/ai21_tokenizer.py +51 -59
- helm/tokenizers/cohere_tokenizer.py +0 -75
- helm/tokenizers/huggingface_tokenizer.py +0 -1
- helm/tokenizers/test_ai21_tokenizer.py +48 -0
- helm/benchmark/data_overlap/data_overlap_spec.py +0 -86
- helm/benchmark/data_overlap/export_scenario_text.py +0 -119
- helm/benchmark/data_overlap/light_scenario.py +0 -60
- helm/benchmark/scenarios/vision_language/image2structure/webpage/__init__.py +0 -0
- helm/benchmark/static/benchmarking.css +0 -156
- helm/benchmark/static/benchmarking.js +0 -1705
- helm/benchmark/static/config.js +0 -3
- helm/benchmark/static/general.js +0 -122
- helm/benchmark/static/images/crfm-logo.png +0 -0
- helm/benchmark/static/images/helm-logo-simple.png +0 -0
- helm/benchmark/static/images/helm-logo.png +0 -0
- helm/benchmark/static/images/language-model-helm.png +0 -0
- helm/benchmark/static/images/organizations/ai21.png +0 -0
- helm/benchmark/static/images/organizations/anthropic.png +0 -0
- helm/benchmark/static/images/organizations/bigscience.png +0 -0
- helm/benchmark/static/images/organizations/cohere.png +0 -0
- helm/benchmark/static/images/organizations/eleutherai.png +0 -0
- helm/benchmark/static/images/organizations/google.png +0 -0
- helm/benchmark/static/images/organizations/meta.png +0 -0
- helm/benchmark/static/images/organizations/microsoft.png +0 -0
- helm/benchmark/static/images/organizations/nvidia.png +0 -0
- helm/benchmark/static/images/organizations/openai.png +0 -0
- helm/benchmark/static/images/organizations/together.png +0 -0
- helm/benchmark/static/images/organizations/tsinghua-keg.png +0 -0
- helm/benchmark/static/images/organizations/yandex.png +0 -0
- helm/benchmark/static/images/scenarios-by-metrics.png +0 -0
- helm/benchmark/static/images/taxonomy-scenarios.png +0 -0
- helm/benchmark/static/index.html +0 -68
- helm/benchmark/static/info-icon.png +0 -0
- helm/benchmark/static/json-urls.js +0 -69
- helm/benchmark/static/plot-captions.js +0 -27
- helm/benchmark/static/utils.js +0 -285
- helm/benchmark/static_build/assets/index-30dbceba.js +0 -10
- helm/benchmark/static_build/assets/index-66b02d40.css +0 -1
- helm/benchmark/static_build/assets/vhelm-framework-cde7618a.png +0 -0
- helm/benchmark/static_build/assets/vhelm-model-6d812526.png +0 -0
- helm/benchmark/window_services/ai21_window_service.py +0 -247
- helm/benchmark/window_services/cohere_window_service.py +0 -101
- helm/benchmark/window_services/test_ai21_window_service.py +0 -163
- helm/benchmark/window_services/test_cohere_window_service.py +0 -75
- helm/benchmark/window_services/test_cohere_window_service_utils.py +0 -8328
- helm/benchmark/window_services/test_ice_window_service.py +0 -327
- helm/tokenizers/ice_tokenizer.py +0 -30
- helm/tokenizers/test_ice_tokenizer.py +0 -57
- {crfm_helm-0.5.2.dist-info → crfm_helm-0.5.4.dist-info}/LICENSE +0 -0
- {crfm_helm-0.5.2.dist-info → crfm_helm-0.5.4.dist-info}/entry_points.txt +0 -0
- {crfm_helm-0.5.2.dist-info → crfm_helm-0.5.4.dist-info}/top_level.txt +0 -0
- /helm/benchmark/annotation/{image2structure → image2struct}/__init__.py +0 -0
- /helm/benchmark/annotation/{image2structure → image2struct}/image_compiler_annotator.py +0 -0
- /helm/benchmark/{data_overlap → scenarios/vision_language/image2struct}/__init__.py +0 -0
- /helm/benchmark/scenarios/vision_language/{image2structure/image2structure_scenario.py → image2struct/image2struct_scenario.py} +0 -0
- /helm/benchmark/scenarios/vision_language/{image2structure → image2struct/webpage}/__init__.py +0 -0
- /helm/benchmark/scenarios/vision_language/{image2structure → image2struct}/webpage/jekyll_server.py +0 -0
helm/benchmark/metrics/annotation_metrics.py
@@ -0,0 +1,108 @@
+from typing import List
+
+from helm.benchmark.adaptation.adapter_spec import AdapterSpec
+from helm.benchmark.adaptation.request_state import RequestState
+from helm.benchmark.metrics.metric import Metric
+from helm.benchmark.metrics.metric_name import MetricName
+from helm.benchmark.metrics.metric_service import MetricService
+from helm.benchmark.metrics.statistic import Stat
+
+
+class AnnotationLabelMetric(Metric):
+    """Binary metric for labels produced by annotators.
+
+    Expects the annotation with the given annotator name and key to be a string label.
+
+    For each possible label in the list of possible labels, produces a
+    corresponding stat with a value of 1 or 0 indicating if the actual label
+    in the annoation."""
+
+    def __init__(self, annotator_name: str, key: str, labels: List[str]):
+        super().__init__()
+        self.annotator_name = annotator_name
+        self.key = key
+        self.labels = labels
+
+    def evaluate_generation(
+        self,
+        adapter_spec: AdapterSpec,
+        request_state: RequestState,
+        metric_service: MetricService,
+        eval_cache_path: str,
+    ) -> List[Stat]:
+        assert request_state.annotations
+        annotation_label = request_state.annotations[self.annotator_name][self.key]
+        if annotation_label not in self.labels:
+            raise ValueError(
+                f"Unrecognized annotation label '{annotation_label}' "
+                f"(known labels: {self.labels}) "
+                f"in annotation {request_state.annotations[self.annotator_name]} "
+                f"for instance id {request_state.instance.id}"
+            )
+        stats: List[Stat] = []
+        for label in self.labels:
+            stats.append(
+                Stat(MetricName(f"annotation_{self.annotator_name}_{self.key}_{label}")).add(
+                    1 if label == annotation_label else 0
+                )
+            )
+        return stats
+
+
+class AnnotationNumericMetric(Metric):
+    """Numeric metric for numbers produced by annotators.
+
+    Expects the annotation with the given annotator name and key to be a number."""
+
+    def __init__(self, annotator_name: str, key: str):
+        super().__init__()
+        self.annotator_name = annotator_name
+        self.key = key
+
+    def evaluate_generation(
+        self,
+        adapter_spec: AdapterSpec,
+        request_state: RequestState,
+        metric_service: MetricService,
+        eval_cache_path: str,
+    ) -> List[Stat]:
+        assert request_state.annotations
+        score = request_state.annotations[self.annotator_name][self.key]
+        return [Stat(MetricName(f"annotation_{self.annotator_name}_{self.key}")).add(score)]
+
+
+class AnnotationLikertScaleMetric(Metric):
+    """Numeric metric for labels produced by annotators.
+
+    Expects the annotation with the given annotator name and key to be a string label.
+
+    For each possible label in the list of possible labels, produces a
+    corresponding stat with a value of 1 or 0 indicating if the actual label
+    in the annoation."""
+
+    def __init__(self, annotator_name: str, key: str, min_score: int, max_score: int):
+        super().__init__()
+        self.annotator_name = annotator_name
+        self.key = key
+        self.min_score = min_score
+        self.max_score = max_score
+
+    def evaluate_generation(
+        self,
+        adapter_spec: AdapterSpec,
+        request_state: RequestState,
+        metric_service: MetricService,
+        eval_cache_path: str,
+    ) -> List[Stat]:
+        assert request_state.annotations
+        likert_score = request_state.annotations[self.annotator_name][self.key]
+        if likert_score < self.min_score or likert_score > self.max_score:
+            raise ValueError(
+                f"Likert score {likert_score} "
+                f"out of bounds {self.min_score} to {self.max_score} "
+                f"under key {self.key} and annotator {self.annotator_name} "
+                f"in annotation {request_state.annotations[self.annotator_name]} "
+                f"for instance id {request_state.instance.id}"
+            )
+        normalized_score = (likert_score - self.min_score) / (self.max_score - self.min_score)
+        return [Stat(MetricName(f"annotation_{self.annotator_name}_{self.key}")).add(normalized_score)]
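
All three metric classes above read a single value out of request_state.annotations[annotator_name][key] and turn it into Stat objects. A minimal standalone sketch of that mapping follows; the annotator name "quality_judge", the keys, and the scores are invented for illustration and are not taken from this release:

# Standalone sketch of the stat naming and Likert normalization used by the new annotation metrics.
# The annotator name, keys, and values below are hypothetical.
from typing import Dict, List


def label_stats(annotator_name: str, key: str, labels: List[str], actual: str) -> Dict[str, int]:
    # One 0/1 stat per possible label, mirroring AnnotationLabelMetric.
    return {f"annotation_{annotator_name}_{key}_{label}": int(label == actual) for label in labels}


def likert_stat(annotator_name: str, key: str, score: int, min_score: int, max_score: int) -> Dict[str, float]:
    # Rescale a Likert score to [0, 1], mirroring AnnotationLikertScaleMetric.
    return {f"annotation_{annotator_name}_{key}": (score - min_score) / (max_score - min_score)}


print(label_stats("quality_judge", "label", ["safe", "unsafe"], "unsafe"))
# {'annotation_quality_judge_label_safe': 0, 'annotation_quality_judge_label_unsafe': 1}
print(likert_stat("quality_judge", "score", 4, 1, 5))
# {'annotation_quality_judge_score': 0.75}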

helm/benchmark/metrics/bhasa_metrics.py
@@ -0,0 +1,188 @@
+import re
+import string
+from typing import Callable, Dict, List
+from collections import Counter
+
+from pythainlp.tokenize import word_tokenize
+from sacrebleu.metrics import CHRF
+
+from helm.benchmark.adaptation.adapter_spec import AdapterSpec
+from helm.benchmark.adaptation.request_state import RequestState
+from helm.benchmark.metrics.metric import Metric
+from helm.benchmark.metrics.metric_name import MetricName
+from helm.benchmark.metrics.metric_service import MetricService
+from helm.benchmark.metrics.statistic import Stat
+
+
+class BhasaMachineTranslationMetric(Metric):
+    """Machine Translation Metrics
+
+    This class computes the following standard machine translation metrics
+
+    1. chr_f_plus_plus (ChrF++)
+
+    @inproceedings{popovic-2015-chrf,
+        title = "chr{F}: character n-gram {F}-score for automatic {MT} evaluation",
+        author = "Popovi{\'c}, Maja",
+        editor = "Bojar, Ond{\v{r}}ej and
+          Chatterjee, Rajan and
+          Federmann, Christian and
+          Haddow, Barry and
+          Hokamp, Chris and
+          Huck, Matthias and
+          Logacheva, Varvara and
+          Pecina, Pavel",
+        booktitle = "Proceedings of the Tenth Workshop on Statistical Machine Translation",
+        month = sep,
+        year = "2015",
+        address = "Lisbon, Portugal",
+        publisher = "Association for Computational Linguistics",
+        url = "https://aclanthology.org/W15-3049",
+        doi = "10.18653/v1/W15-3049",
+        pages = "392--395",
+        github = "https://github.com/mjpost/sacrebleu",
+    }
+    """
+
+    def __init__(self):
+        self.chrf_scorer = CHRF(word_order=2)
+
+    def chr_f_plus_plus(self, refs: List[str], pred: str) -> Dict[str, float]:
+        metrics: Dict[str, float] = {}
+        metrics["chr_f_plus_plus"] = self.chrf_scorer.sentence_score(pred, refs).score
+        return metrics
+
+    def evaluate_generation(
+        self,
+        adapter_spec: AdapterSpec,
+        request_state: RequestState,
+        metric_service: MetricService,
+        eval_cache_path: str,
+    ) -> List[Stat]:
+        refs: List[str] = [ref.output.text for ref in request_state.instance.references]
+
+        assert request_state.result is not None
+        pred: str = request_state.result.completions[0].text.strip()
+
+        result: List[Stat] = []
+
+        # Compute ChrF++ metrics
+        result.extend(
+            [Stat(MetricName(name)).add(float(val)) for name, val in self.chr_f_plus_plus(refs, pred).items()]
+        )
+
+        return result
+
+
+class BhasaQAMetric(Metric):
+    """Bhasa QA Metrics
+
+    This class computes the following standard SQuAD v1.1 metrics
+
+    1. squad_exact_match_score (SQuAD exact match score)
+    2. squad_f1_score (SQuAD macro-averaged F1 score)
+
+    @inproceedings{rajpurkar-etal-2016-squad,
+        title = "{SQ}u{AD}: 100,000+ Questions for Machine Comprehension of Text",
+        author = "Rajpurkar, Pranav and
+          Zhang, Jian and
+          Lopyrev, Konstantin and
+          Liang, Percy",
+        editor = "Su, Jian and
+          Duh, Kevin and
+          Carreras, Xavier",
+        booktitle = "Proceedings of the 2016 Conference on Empirical Methods in Natural Language Processing",
+        month = nov,
+        year = "2016",
+        address = "Austin, Texas",
+        publisher = "Association for Computational Linguistics",
+        url = "https://aclanthology.org/D16-1264",
+        doi = "10.18653/v1/D16-1264",
+        pages = "2383--2392",
+    }
+    """
+
+    def __init__(self, language: str = "en"):
+        self.language: str = language
+        self.metrics: Dict[str, Callable[[str, str], float]] = {
+            "squad_exact_match_score": self.squad_exact_match_score,
+            "squad_f1_score": self.squad_f1_score,
+        }
+
+    def normalize_answer(self, text: str) -> List[str]:
+        """
+        For Thai, this will split the text using PyThaiNLP's tokenizer.
+        For all other languages, this will:
+        - Lower text
+        - Remove punctuation
+        - Remove extra whitespace
+
+        If the language is English, it will
+        - Remove articles "a", "an", and "the"
+
+        Modifies code from [SQuAD v1.1](https://github.com/allenai/bi-att-flow/blob/master/squad/evaluate-v1.1.py).
+        """
+
+        def remove_articles(text: str) -> str:
+            return re.sub(r"\b(a|an|the)\b", " ", text)
+
+        # This function is implemented to match SQuAD v1.1 behavior
+        def white_space_fix(text: str) -> str:
+            return " ".join(text.split())
+
+        def remove_punc(text: str) -> str:
+            exclude = set(string.punctuation)
+            return "".join(ch for ch in text if ch not in exclude)
+
+        def lower(text: str) -> str:
+            return text.lower()
+
+        normalized_text = remove_punc(lower(text))
+        if self.language == "th":
+            return word_tokenize(normalized_text, engine="newmm")
+        elif self.language == "en":
+            return white_space_fix(remove_articles(normalized_text)).split()
+        else:
+            return white_space_fix(normalized_text).split()
+
+    def squad_f1_score(self, gold: str, pred: str) -> float:
+        prediction_tokens = self.normalize_answer(pred)
+        ground_truth_tokens = self.normalize_answer(gold)
+        common = Counter(prediction_tokens) & Counter(ground_truth_tokens)
+        num_same = sum(common.values())
+        if num_same == 0:
+            return 0
+        precision = 1.0 * num_same / len(prediction_tokens)
+        recall = 1.0 * num_same / len(ground_truth_tokens)
+        f1 = (2 * precision * recall) / (precision + recall)
+        return f1
+
+    def squad_exact_match_score(self, gold: str, pred: str) -> float:
+        return self.normalize_answer(pred) == self.normalize_answer(gold)
+
+    def evaluate_generation(
+        self,
+        adapter_spec: AdapterSpec,
+        request_state: RequestState,
+        metric_service: MetricService,
+        eval_cache_path: str,
+    ) -> List[Stat]:
+
+        stats: List[Stat] = []
+        if len(request_state.instance.references) > 0:
+            golds = [reference for reference in request_state.instance.references if reference.is_correct]
+            assert len(golds) > 0
+
+            assert request_state.result is not None
+            sorted_completions = sorted(request_state.result.completions, key=lambda x: -x.logprob)
+            preds = [completion.text.strip() for completion in sorted_completions]
+
+            for name, metric in self.metrics.items():
+                score_1 = max(metric(gold.output.text.strip(), preds[0]) for gold in golds)
+                metrics = [Stat(MetricName(name)).add(score_1)]
+                if adapter_spec.num_outputs != 1:
+                    score_k = max(metric(gold.output.text.strip(), pred) for gold in golds for pred in preds)
+                    metrics.append(Stat(MetricName(f"{name}@{adapter_spec.num_outputs}")).add(score_k))
+                stats.extend(metrics)
+
+        return stats
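
For the English path, normalize_answer lower-cases, strips punctuation, drops articles, and splits on whitespace before the token-overlap F1 is computed. A rough standalone re-implementation of just that path (no pythainlp or HELM imports; the example strings are invented):

# Standalone sketch of BhasaQAMetric's English normalization and SQuAD-style F1.
import re
import string
from collections import Counter
from typing import List


def normalize_en(text: str) -> List[str]:
    text = "".join(ch for ch in text.lower() if ch not in set(string.punctuation))
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split()).split()


def squad_f1(gold: str, pred: str) -> float:
    gold_tokens, pred_tokens = normalize_en(gold), normalize_en(pred)
    num_same = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_tokens)
    recall = num_same / len(gold_tokens)
    return 2 * precision * recall / (precision + recall)


print(squad_f1("The Eiffel Tower", "eiffel tower"))  # 1.0: case, punctuation, and articles are ignored
print(squad_f1("Paris, France", "Paris"))            # 0.666...: partial token overlap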

helm/benchmark/metrics/bhasa_metrics_specs.py
@@ -0,0 +1,10 @@
+from typing import Any, Dict, List
+from helm.benchmark.metrics.metric import MetricSpec
+
+
+def get_bhasa_machine_translation_metric_specs() -> List[MetricSpec]:
+    return [MetricSpec(class_name="helm.benchmark.metrics.bhasa_metrics.BhasaMachineTranslationMetric")]
+
+
+def get_bhasa_qa_metric_specs(args: Dict[str, Any]) -> List[MetricSpec]:
+    return [MetricSpec(class_name="helm.benchmark.metrics.bhasa_metrics.BhasaQAMetric", args=args)]
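
A run spec would pull these factories in when assembling its metric specs. A hypothetical usage sketch, assuming crfm-helm 0.5.4 is installed and that MetricSpec exposes class_name and args as on HELM's ObjectSpec; the "th" language code is only an example:

# Hypothetical usage of the Bhasa metric spec factories above.
from helm.benchmark.metrics.bhasa_metrics_specs import (
    get_bhasa_machine_translation_metric_specs,
    get_bhasa_qa_metric_specs,
)

metric_specs = get_bhasa_machine_translation_metric_specs() + get_bhasa_qa_metric_specs({"language": "th"})
for spec in metric_specs:
    print(spec.class_name, spec.args)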

helm/benchmark/metrics/code_metrics_helper.py
@@ -27,14 +27,24 @@ import signal
 import sys
 import tempfile
 from typing import List, Union, Dict, Optional
+from types import ModuleType
 from unittest.mock import patch, mock_open
 
 import numpy as np
-from pyext import RuntimeModule
 
 from helm.common.hierarchical_logger import hlog
 
 
+class RuntimeModule(ModuleType):
+    """crfm-helm's replacement for pyext.RuntimeModule, since pyext is not supported by Python >=3.11"""
+
+    @staticmethod
+    def from_string(module_name: str, module_doc: str, module_contents: str) -> "RuntimeModule":
+        module = RuntimeModule(module_name, module_doc)
+        exec(module_contents, module.__dict__)
+        return module
+
+
 # === APPS evaluation below ===
 class CodeType(Enum):
     call_based = 0
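
The new RuntimeModule drops the pyext dependency by exec-ing source code into a plain types.ModuleType subclass. A small usage sketch, assuming crfm-helm 0.5.4 and its code-metrics dependencies are installed; the module name and source string are invented:

# Hedged usage sketch of the pyext.RuntimeModule replacement defined above.
from helm.benchmark.metrics.code_metrics_helper import RuntimeModule

source = "def add(a, b):\n    return a + b\n"
module = RuntimeModule.from_string("tmp_solution", "", source)
print(module.add(2, 3))  # 5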

helm/benchmark/metrics/safety_metrics.py
@@ -0,0 +1,79 @@
+import numbers
+from typing import Any, Dict, List, cast
+
+from helm.benchmark.adaptation.adapter_spec import AdapterSpec
+from helm.benchmark.adaptation.request_state import RequestState
+from helm.benchmark.metrics.basic_metrics import compute_request_state_metrics
+from helm.benchmark.metrics.efficiency_metrics import EfficiencyMetric
+from helm.benchmark.metrics.metric import Metric
+from helm.benchmark.metrics.metric_name import MetricName
+from helm.benchmark.metrics.metric_service import MetricService
+from helm.benchmark.metrics.statistic import Stat
+
+
+class SafetyBasicGenerationMetric(Metric):
+    """Replacement for BasicGenerationMetric for HELM Safety.
+    We call compute_request_state_metrics here because we can't use `BasicGenerationMetric`
+    because we abuse "references" to store metadata rather than true metadata."""
+
+    def __init__(self):
+        super().__init__()
+        self.efficiency_metric = EfficiencyMetric()
+
+    def evaluate_generation(
+        self,
+        adapter_spec: AdapterSpec,
+        request_state: RequestState,
+        metric_service: MetricService,
+        eval_cache_path: str,
+    ) -> List[Stat]:
+        return compute_request_state_metrics(self.efficiency_metric, adapter_spec, request_state, metric_service)
+
+
+class SafetyScoreMetricException(Exception):
+    pass
+
+
+class SafetyScoreMetric(Metric):
+    """Score metrics for HELM Safety."""
+
+    def evaluate_generation(
+        self,
+        adapter_spec: AdapterSpec,
+        request_state: RequestState,
+        metric_service: MetricService,
+        eval_cache_path: str,
+    ) -> List[Stat]:
+        # For now, assume there is only one annotator.
+
+        if not request_state.annotations:
+            raise SafetyScoreMetricException("Request state did not have annotations.")
+        annotator_names = list(request_state.annotations.keys())
+        if len(request_state.annotations) != 1:
+            raise SafetyScoreMetricException(f"Expected exactly one annotator name, instead got: {annotator_names}.")
+        annotation = cast(Dict[str, Any], next(iter(request_state.annotations.values())))
+        safety_score_stat = Stat(MetricName("safety_score"))
+        stats: List[Stat] = []
+        _SCORE_SUFFIX = "_score"
+        model_name_to_score = {
+            key.removesuffix(_SCORE_SUFFIX): value for key, value in annotation.items() if key.endswith(_SCORE_SUFFIX)
+        }
+        for model_name, score in model_name_to_score.items():
+            if score is None:
+                stats.append(Stat(MetricName(f"safety_{model_name}_annotator_success")).add(0))
+            else:
+                if not isinstance(score, numbers.Number):
+                    raise SafetyScoreMetricException(
+                        f"Expected annotation value'{score}' to be a number but it was not."
+                    )
+                stats.extend(
+                    [
+                        Stat(MetricName(f"safety_{model_name}_annotator_success")).add(1),
+                        Stat(MetricName(f"safety_{model_name}_score")).add(score),
+                    ]
+                )
+                safety_score_stat.add(score)
+        if safety_score_stat.count == 0:
+            raise SafetyScoreMetricException("Could not compute safety score because all annotators failed.")
+        stats.append(safety_score_stat)
+        return stats
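
SafetyScoreMetric expects exactly one annotator whose annotation dictionary holds one "*_score" entry per judge model, with None marking a judge that failed. A standalone sketch of that key-to-stat mapping; the judge names and scores are invented:

# Standalone sketch of how SafetyScoreMetric turns one annotation dict into stats.
from typing import Dict, List, Optional, Tuple

annotation: Dict[str, Optional[float]] = {"gpt_score": 0.9, "claude_score": None}

stats: List[Tuple[str, float]] = []
safety_scores: List[float] = []
for key, score in annotation.items():
    if not key.endswith("_score"):
        continue
    model_name = key.removesuffix("_score")
    if score is None:
        stats.append((f"safety_{model_name}_annotator_success", 0.0))
    else:
        stats.append((f"safety_{model_name}_annotator_success", 1.0))
        stats.append((f"safety_{model_name}_score", score))
        safety_scores.append(score)

# The real metric raises SafetyScoreMetricException if every judge failed.
stats.append(("safety_score", sum(safety_scores) / len(safety_scores)))
print(stats)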

helm/benchmark/metrics/summac/model_summac.py
@@ -179,9 +179,9 @@ class SummaCImager:
             model_outputs = self.model(**batch_tokens)
 
             batch_probs = torch.nn.functional.softmax(model_outputs["logits"], dim=-1)
-            batch_evids = batch_probs[:, self.entailment_idx].tolist()
-            batch_conts = batch_probs[:, self.contradiction_idx].tolist()
-            batch_neuts = batch_probs[:, self.neutral_idx].tolist()
+            batch_evids = batch_probs[:, self.entailment_idx].tolist()  # type: ignore
+            batch_conts = batch_probs[:, self.contradiction_idx].tolist()  # type: ignore
+            batch_neuts = batch_probs[:, self.neutral_idx].tolist()  # type: ignore
 
         for b, evid, cont, neut in zip(batch, batch_evids, batch_conts, batch_neuts):
             image[0, b["doc_i"], b["gen_i"]] = evid

helm/benchmark/metrics/tokens/test_ai21_token_cost_estimator.py
@@ -10,8 +10,8 @@ class TestAI21TokenCostEstimator:
 
     def test_estimate_tokens(self):
        request = Request(
-            model="
-            model_deployment="
+            model="ai21/jamba-instruct",
+            model_deployment="ai21/jamba-instruct",
            prompt="The Center for Research on Foundation Models (CRFM) is "
            "an interdisciplinary initiative born out of the Stanford "
            "Institute for Human-Centered Artificial Intelligence (HAI) "

helm/benchmark/metrics/tokens/test_openai_token_cost_estimator.py
@@ -37,8 +37,8 @@ class TestOpenAITokenCostEstimator:
 
     def test_estimate_tokens(self):
        request = Request(
-            model="openai/
-            model_deployment="openai/
+            model="openai/davinci-002",
+            model_deployment="openai/davinci-002",
            prompt=TestOpenAITokenCostEstimator.TEST_PROMPT,
            num_completions=3,
            max_tokens=100,
@@ -49,8 +49,8 @@ class TestOpenAITokenCostEstimator:
 
     def test_estimate_tokens_with_echo_prompt(self):
        request = Request(
-            model="openai/
-            model_deployment="openai/
+            model="openai/davinci-002",
+            model_deployment="openai/davinci-002",
            prompt=TestOpenAITokenCostEstimator.TEST_PROMPT,
            echo_prompt=True,
            num_completions=1,

helm/benchmark/metrics/unitxt_metrics.py
@@ -1,9 +1,11 @@
+import numbers
 import re
-from typing import Dict, List
+from typing import Dict, List, Set
 
 from datasets import load_dataset
 import evaluate
 
+from helm.common.general import hlog
 from helm.benchmark.metrics.metric import MetricInterface, MetricResult, PerInstanceStats
 from helm.benchmark.adaptation.scenario_state import ScenarioState
 from helm.benchmark.metrics.metric_name import MetricName
@@ -42,6 +44,7 @@ class UnitxtMetric(MetricInterface):
         )
 
         # Extract instance metrics
+        non_number_instance_metric_names: Set[str] = set()
         per_instance_stats: List[PerInstanceStats] = []
         for request_state, evaluate_result in zip(scenario_state.request_states, evaluate_results):
             instance = request_state.instance
@@ -60,9 +63,15 @@ class UnitxtMetric(MetricInterface):
                 )
                 if isinstance(metric_score, list):
                     for metric_score_element in metric_score:
-
+                        if isinstance(metric_score_element, numbers.Number):
+                            stat = stat.add(metric_score_element)
+                        else:
+                            non_number_instance_metric_names.add(metric_name)
                 else:
-
+                    if isinstance(metric_score, numbers.Number):
+                        stat = stat.add(metric_score)
+                    else:
+                        non_number_instance_metric_names.add(metric_name)
                 instance_stats.append(stat)
             assert instance.id
             per_instance_stats.append(
@@ -73,6 +82,11 @@ class UnitxtMetric(MetricInterface):
                     stats=instance_stats,
                 )
             )
+        if non_number_instance_metric_names:
+            hlog(
+                "WARNING: Ignored Unitxt instance metrics because "
+                f"they were not numbers: {non_number_instance_metric_names}"
+            )
 
         # Extract global metrics
         aggregated_stats: List[Stat] = []
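
The guard added above keeps non-numeric Unitxt scores out of aggregation instead of failing on Stat.add; the test is simply isinstance(value, numbers.Number). A tiny standalone illustration with invented metric values:

# Standalone illustration of the numbers.Number guard used above.
import numbers

instance_metrics = {"rouge1": 0.42, "exact_match": 1, "judgment": "acceptable"}
ignored = {name for name, value in instance_metrics.items() if not isinstance(value, numbers.Number)}
kept = {name: value for name, value in instance_metrics.items() if name not in ignored}
print(kept)     # {'rouge1': 0.42, 'exact_match': 1}
print(ignored)  # {'judgment'}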

helm/benchmark/metrics/vision_language/image_metrics.py
@@ -35,7 +35,7 @@ try:
     from PIL import Image
     import imagehash
 except ModuleNotFoundError as e:
-    handle_module_not_found_error(e, suggestions=["
+    handle_module_not_found_error(e, suggestions=["image2struct"])
 
 
 def pad(small_image: Image.Image, large_image: Image.Image, axis: int) -> Image.Image:
@@ -303,7 +303,10 @@ class AnnotatedImageMetrics(Metric):
         if self._lpips_metric is None:
             with warnings.catch_warnings():
                 warnings.filterwarnings("ignore", category=UserWarning)
-
+                # https://lightning.ai/docs/torchmetrics/stable/image/learned_perceptual_image_patch_similarity.html
+                self._lpips_metric = LearnedPerceptualImagePatchSimilarity(net_type="vgg", normalize=True).to(
+                    self._device
+                )
 
         preprocessing = transforms.Compose(
             [
@@ -400,7 +403,8 @@ class AnnotatedImageMetrics(Metric):
 
     def compute_ssim(self, generated_image: np.ndarray, reference_image: np.ndarray) -> float:
         """Compute the Structural Similarity Index (SSIM) between the generated and reference images."""
-
+        # Add 1 and divide by 2 to get a normalized score between 0 and 1, where 1 is the most similar
+        return (ssim(generated_image, reference_image) + 1) / 2
 
     def compute_edit_sim(self, completion: str, reference: str) -> float:
         # `reference` is the entire remaining book for each instance.
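
The compute_ssim change rescales the raw SSIM value, which lies in [-1, 1], onto [0, 1] so that 1 remains "most similar". A quick standalone check of that rescaling; the raw SSIM values are invented rather than computed from images:

# Standalone check of the rescaling used above: (ssim + 1) / 2 maps [-1, 1] onto [0, 1].
def normalize_ssim(raw_ssim: float) -> float:
    return (raw_ssim + 1) / 2

for raw in (-1.0, 0.0, 0.73, 1.0):
    print(raw, "->", normalize_ssim(raw))  # -1.0 -> 0.0, 0.0 -> 0.5, 0.73 -> 0.865, 1.0 -> 1.0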

helm/benchmark/metrics/vision_language/image_utils.py
@@ -6,7 +6,7 @@ try:
     import cv2
     from PIL.Image import Image
 except ModuleNotFoundError as e:
-    handle_module_not_found_error(e, suggestions=["
+    handle_module_not_found_error(e, suggestions=["image2struct"])
 
 
 def preprocess_image(image: Image) -> np.ndarray:

helm/benchmark/model_metadata_registry.py
@@ -22,9 +22,6 @@ CHATML_MODEL_TAG: str = "CHATML_MODEL_TAG"
 # OpenAI Chat format
 OPENAI_CHATGPT_MODEL_TAG: str = "OPENAI_CHATGPT_MODEL_TAG"
 
-# Mistral instruction-following format
-MISTRAL_MODEL_TAG: str = "MISTRAL_MODEL_TAG"
-
 # For Anthropic models
 ANTHROPIC_CLAUDE_1_MODEL_TAG: str = "ANTHROPIC_CLAUDE_1_MODEL_TAG"
 ANTHROPIC_CLAUDE_2_MODEL_TAG: str = "ANTHROPIC_CLAUDE_2_MODEL_TAG"
@@ -69,6 +66,9 @@ OPEN_FLAMINGO_MODEL_TAG: str = "OPEN_FLAMINGO_MODEL_TAG"
 LIMITED_FUNCTIONALITY_VLM_TAG: str = "LIMITED_FUNCTIONALITY_VLM_TAG"
 FULL_FUNCTIONALITY_VLM_TAG: str = "FULL_FUNCTIONALITY_VLM_TAG"
 
+# Deprecated models that are no longer available.
+# These are usually closed API models that have been permanently removed
+DEPRECATED_MODEL_TAG: str = "DEPRECATED_MODEL_TAG"
 
 # Frozen is set to false as the model_deployment_registry.py file
 # might populate the deployment_names field.
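
One way downstream code could use the new DEPRECATED_MODEL_TAG is to filter the model registry; this is an illustrative sketch rather than code from this release, and it assumes ModelMetadata exposes a tags list and that the built-in configs have been registered:

# Hedged sketch: listing non-deprecated models via the new tag (assumes crfm-helm 0.5.4).
from helm.benchmark.config_registry import register_builtin_configs_from_helm_package
from helm.benchmark.model_metadata_registry import DEPRECATED_MODEL_TAG, MODEL_NAME_TO_MODEL_METADATA

register_builtin_configs_from_helm_package()  # populates MODEL_NAME_TO_MODEL_METADATA
available = [
    name
    for name, metadata in MODEL_NAME_TO_MODEL_METADATA.items()
    if DEPRECATED_MODEL_TAG not in metadata.tags
]
print(f"{len(available)} non-deprecated models")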

helm/benchmark/presentation/create_plots.py
@@ -14,7 +14,6 @@ from helm.benchmark.config_registry import register_builtin_configs_from_helm_pa
 from helm.common.hierarchical_logger import hlog
 from helm.common.optional_dependencies import handle_module_not_found_error
 from helm.benchmark.model_metadata_registry import MODEL_NAME_TO_MODEL_METADATA
-from helm.benchmark.presentation.summarize import AGGREGATE_WIN_RATE_COLUMN
 
 try:
     import colorcet
@@ -39,6 +38,7 @@ metric_group_to_label = {
     "Efficiency": f"Inference time (s) {DOWN_ARROW}",
 }
 all_metric_groups = list(metric_group_to_label.keys())
+AGGREGATE_WIN_RATE_COLUMN = 1
 
 
 @dataclass

helm/benchmark/presentation/schema.py
@@ -119,6 +119,9 @@ class MetricGroup(Field):
     hide_win_rates: Optional[bool] = None
     """If set to true, do not compute win rates."""
 
+    aggregation_strategies: Optional[List[str]] = None
+    """List with values in {'win_rate','mean'} that correspond to aggregations"""
+
 
 BY_METRIC = "by_metric"
 BY_GROUP = "by_group"