scorebook-0.0.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
scorebook/__init__.py ADDED
@@ -0,0 +1,15 @@
+ """
+ Scorebook package.
+
+ A Python framework for evaluating model predictions against datasets and metrics.
+ """
+
+ import importlib.metadata
+
+ # Resolve the version from the installed package metadata (originating from pyproject.toml)
+ __version__ = importlib.metadata.version(__package__ or __name__)
+
+ from scorebook.evaluator import evaluate
+ from scorebook.types.eval_dataset import EvalDataset
+
+ __all__ = ["EvalDataset", "evaluate"]
scorebook/evaluator.py ADDED
@@ -0,0 +1,228 @@
+ """
+ Model evaluation functionality for the Scorebook framework.
+
+ This module provides the core evaluation logic to assess model predictions
+ against ground truth labels using configurable metrics. It supports:
+
+ - Batch evaluation of models across multiple datasets
+ - Flexible metric computation and aggregation
+ - Optional parameter sweeping and experiment tracking
+ - Customizable inference functions
+
+ The main entry point is the `evaluate()` function which handles running
+ models on datasets and computing metric scores.
+ """
+
+ import asyncio
+ from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
+
+ from scorebook.types.eval_dataset import EvalDataset
+ from scorebook.types.eval_result import EvalResult
+ from scorebook.utils import evaluation_progress, expand_dict, is_awaitable
+
+
+ async def _evaluate_async(
+     inference_callable: Callable,
+     eval_datasets: Union[str, EvalDataset, List[Union[str, EvalDataset]]],
+     hyperparameters: Optional[Dict[str, Any]] = None,
+     experiment_id: Optional[str] = None,
+     item_limit: Optional[int] = None,
+     return_type: str = "dict",
+     score_type: str = "aggregate",
+ ) -> Union[Dict, List]:
+     """Run inference across datasets/hyperparams, compute metrics, and format results."""
+     _validate_score_type(score_type)
+
+     normalized_datasets = _normalize_datasets(eval_datasets)
+     hyperparam_grid = _expand_hyperparams(hyperparameters)
+
+     eval_results: List[EvalResult] = []
+
+     with evaluation_progress(normalized_datasets, len(hyperparam_grid)) as progress_bars:
+         # Loop through datasets, then hyperparameters for clear progress tracking
+         for dataset_idx, eval_dataset in enumerate(normalized_datasets):
+             with progress_bars.hyperparam_progress_context():
+                 # Run inference for each hyperparameter configuration on this dataset
+                 for hp_idx, hyperparam_config in enumerate(hyperparam_grid):
+                     items = _clip_items(eval_dataset.items, item_limit)
+                     labels = _labels_for(items, eval_dataset.label)
+
+                     # 1) Run inference
+                     outputs = await _run_inference_callable(
+                         inference_callable, items, hyperparam_config
+                     )
+
+                     # 2) Score metrics
+                     metric_scores = _score_metrics(eval_dataset, outputs, labels)
+
+                     # 3) Wrap into EvalResult
+                     eval_results.append(
+                         EvalResult(eval_dataset, outputs, metric_scores, hyperparam_config)
+                     )
+
+                     # Update inner progress bar
+                     progress_bars.update_hyperparam_progress()
+
+             # Update the outer progress bar
+             progress_bars.update_dataset_progress()
+
+     # TODO: experiment_id handling (left as passthrough to preserve behavior)
+     if experiment_id:
+         pass
+
+     # 4) Format as requested
+     return _format_results(eval_results, return_type, score_type)
+
+
+ def evaluate(
+     inference_callable: Callable,
+     eval_datasets: Union[str, EvalDataset, List[Union[str, EvalDataset]]],
+     hyperparameters: Optional[Dict[str, Any]] = None,
+     experiment_id: Optional[str] = None,
+     item_limit: Optional[int] = None,
+     return_type: str = "dict",
+     score_type: str = "aggregate",
+ ) -> Union[Dict, List]:
+     """
+     Evaluate model predictions using specified metrics on given datasets.
+
+     This function runs the provided inference callable on one or more evaluation datasets,
+     computes metric scores, and returns the evaluation results. It supports batch processing,
+     parameter sweeping, and different result formatting options.
+
+     Args:
+         inference_callable: A callable function or object that takes (items, hyperparameters)
+             and returns predictions. Can be a regular function, async function,
+             or callable instance (like a class with __call__ method).
+         eval_datasets: One or more evaluation datasets to run evaluation on. Can be:
+             - A single EvalDataset instance
+             - A list of EvalDataset instances
+             - A string identifier (for future dataset registry support)
+             - A list of string identifiers
+         hyperparameters: Optional dictionary containing hyperparameter sweep configuration.
+         experiment_id: Optional string identifier for tracking multiple evaluation runs.
+         item_limit: Optional integer limiting the number of items to evaluate per dataset.
+         return_type: Format of the return value. Currently only "dict" is supported.
+         score_type: Type of score aggregation to return. Options:
+             - "aggregate": Return aggregated metrics
+             - "item": Return per-item scores
+             - "all": Return both aggregate and per-item scores
+
+     Returns:
+         Dictionary mapping dataset names to their evaluation results. For each dataset,
+         returns a dictionary containing:
+         - items: List of EvalResult objects with predictions and ground truth
+         - metrics: Dictionary mapping metric names to their computed scores
+
+     Example:
+
+         dataset = EvalDataset.from_huggingface("dataset_name", label="answer", metrics=[Precision])
+         def inference_fn(items):
+             # Model inference logic here - process all items at once
+             return [prediction for item in items]
+
+         results = evaluate(inference_fn, dataset, item_limit=100)
+     """
+     return asyncio.run(
+         _evaluate_async(
+             inference_callable=inference_callable,
+             eval_datasets=eval_datasets,
+             hyperparameters=hyperparameters,
+             experiment_id=experiment_id,
+             item_limit=item_limit,
+             return_type=return_type,
+             score_type=score_type,
+         )
+     )
+
+
+ # ===== Helper Functions =====
+
+
+ def _normalize_datasets(
+     datasets: Union[str, EvalDataset, List[Union[str, EvalDataset]]]
+ ) -> List[EvalDataset]:
+     if not isinstance(datasets, list):
+         datasets = [datasets]
+     # TODO: handle other types (string registry, etc.)
+     return [d for d in datasets if isinstance(d, EvalDataset)]
+
+
+ def _validate_score_type(score_type: str) -> None:
+     if score_type not in {"aggregate", "item", "all"}:
+         raise ValueError("score_type must be 'aggregate', 'item', or 'all'")
+
+
+ def _expand_hyperparams(hyperparameters: Optional[Dict[str, Any]]) -> Any:
+     return expand_dict(hyperparameters or {})
+
+
+ def _clip_items(items: List[Dict[str, Any]], item_limit: Optional[int]) -> List[Dict[str, Any]]:
+     return items[:item_limit] if item_limit else items
+
+
+ def _labels_for(items: List[Dict[str, Any]], label_key: str) -> List[Any]:
+     return [item.get(label_key) for item in items]
+
+
+ async def _run_inference_callable(
+     inference_callable: Callable,
+     items: List[Dict[str, Any]],
+     hyperparams: Dict[str, Any],
+ ) -> Any:
+     if is_awaitable(inference_callable):
+         return await inference_callable(items, **hyperparams)
+     else:
+         return inference_callable(items, **hyperparams)
+
+
+ # Yields (eval_dataset, items, labels, hyperparams) for every dataset x hyperparam combo.
+ def _iter_dataset_jobs(
+     datasets: List[EvalDataset],
+     hyperparam_grid: List[Dict[str, Any]],
+     item_limit: Optional[int],
+ ) -> Iterable[Tuple[EvalDataset, List[Dict[str, Any]], List[Any], Dict[str, Any]]]:
+     for eval_dataset in datasets:
+         for hp in hyperparam_grid:
+             items = _clip_items(eval_dataset.items, item_limit)
+             labels = _labels_for(items, eval_dataset.label)
+             yield eval_dataset, items, labels, hp
+
+
+ def _score_metrics(
+     eval_dataset: EvalDataset, outputs: List[Any], labels: List[Any]
+ ) -> Dict[str, Dict[str, Any]]:
+     metric_scores: Dict[str, Dict[str, Any]] = {}
+     for metric in eval_dataset.metrics:
+         aggregate_scores, item_scores = metric.score(outputs, labels)
+         metric_scores[metric.name] = {
+             "aggregate_scores": aggregate_scores,
+             "item_scores": item_scores,
+         }
+     return metric_scores
+
+
+ def _format_results(
+     eval_results: List[EvalResult], return_type: str, score_type: str
+ ) -> Union[Dict, List]:
+     # Any non-"dict" return type falls back to raw EvalResult objects keyed by dataset name.
+     if return_type != "dict":
+         return {er.eval_dataset.name: er for er in eval_results}
+
+     if score_type == "all":
+         combined: Dict[str, List[Dict[str, Any]]] = {"aggregate": [], "per_sample": []}
+         for er in eval_results:
+             d = er.to_dict()
+             combined["aggregate"].extend(d["aggregate"])
+             combined["per_sample"].extend(d["per_sample"])
+         return combined
+
+     if score_type == "aggregate":
+         return [er.aggregate_scores for er in eval_results]
+
+     if score_type == "item":
+         return [item for er in eval_results for item in er.item_scores]
+
+     # Should be unreachable due to validation
+     return {}
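
The evaluator above is the heart of the package: it validates the score type, expands the hyperparameter grid, runs the inference callable once per dataset/configuration pair, scores each of the dataset's metrics, and formats the collected EvalResult objects. A minimal usage sketch follows (not part of the package); the dataset identifier is a placeholder, the echo-style inference function is purely illustrative, and it assumes `expand_dict` expands list-valued hyperparameters into a grid of configurations:

```python
from scorebook import EvalDataset, evaluate
from scorebook.metrics import Accuracy

# Placeholder dataset id; any dataset with an "answer" column would do.
dataset = EvalDataset.from_huggingface("org/my_eval_set", label="answer", metrics=[Accuracy])


def inference_fn(items, **hyperparams):
    # Each hyperparameter configuration arrives as keyword arguments (e.g. temperature).
    # A real implementation would call a model; echoing the label keeps the sketch self-contained.
    return [item.get("answer") for item in items]


# Two temperature values -> two configurations per dataset.
results = evaluate(
    inference_fn,
    dataset,
    hyperparameters={"temperature": [0.0, 0.7]},
    item_limit=50,
)
print(results)  # one aggregate-score dict per dataset x hyperparameter configuration
```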
scorebook/inference/__init__.py ADDED
@@ -0,0 +1,11 @@
+ """
+ Inference module for model execution and predictions.
+
+ This module provides functionality for running inference with various models
+ and processing their responses. It includes utilities for both single and
+ batch inference operations.
+ """
+
+ from scorebook.inference.openai import batch, responses
+
+ __all__ = ["responses", "batch"]
scorebook/inference/openai.py ADDED
@@ -0,0 +1,185 @@
+ """
+ OpenAI inference implementation for Scorebook.
+
+ This module provides utilities for running inference using OpenAI's models,
+ supporting both single response and batch inference operations. It handles
+ API communication, request formatting, and response processing.
+ """
+
+ import asyncio
+ import json
+ import tempfile
+ from typing import Any, List
+
+ from openai import OpenAI
+ from tqdm.asyncio import tqdm
+
+
+ async def responses(
+     items: List[Any], model: str = "gpt-4.1-nano", client: Any = None, **hyperparameters: Any
+ ) -> List[Any]:
+     """Process multiple inference requests using OpenAI's API.
+
+     This asynchronous function handles multiple inference requests,
+     manages the API communication, and processes the responses.
+
+     Args:
+         items: List of preprocessed items to process.
+         model: OpenAI model to use.
+         client: Optional OpenAI client instance.
+         hyperparameters: Dictionary of hyperparameters for inference (currently not applied
+             to the requests).
+
+     Returns:
+         List of raw model responses.
+     """
+     if client is None:
+         client = OpenAI()
+
+     results = []
+     for item in items:
+         response = client.responses.create(model=model, input=item)
+         results.append(response)
+
+     return results
+
+
+ async def batch(
+     items: List[Any],
+     model: str = "gpt-4.1-nano",
+     client: Any = None,
+     **hyperparameters: Any,
+ ) -> List[Any]:
+     """Process multiple inference requests in batch using OpenAI's API.
+
+     This asynchronous function handles batch processing of inference requests,
+     optimizing for throughput while respecting API rate limits.
+
+     Args:
+         items: List of preprocessed items to process. Each item is uploaded verbatim as a
+             /v1/chat/completions request body.
+         model: OpenAI model to use (currently not injected into the requests).
+         client: Optional OpenAI client instance.
+         hyperparameters: Dictionary of hyperparameters for inference (currently unused).
+
+     Returns:
+         A list of response message contents, one string per request (empty string when a
+         result line cannot be parsed).
+     """
+     if client is None:
+         client = OpenAI()
+
+     file_id = _upload_batch(items, client)
+     batch_id = _start_batch(file_id, client)
+
+     # Initialize progress bar
+     pbar = tqdm(total=len(items), desc="Batch processing", unit="requests")
+
+     awaiting_batch = True
+     while awaiting_batch:
+         batch_object = await _get_batch(batch_id, client)
+         batch_status = batch_object.status
+
+         if hasattr(batch_object, "request_counts") and batch_object.request_counts:
+             completed = batch_object.request_counts.completed
+             total = batch_object.request_counts.total
+             pbar.n = completed
+             pbar.set_postfix(status=batch_status, completed=f"{completed}/{total}")
+         else:
+             pbar.set_postfix(status=batch_status)
+
+         pbar.refresh()
+
+         if batch_status == "completed":
+             awaiting_batch = False
+             pbar.n = pbar.total
+             pbar.set_postfix(status="completed")
+         elif batch_status == "failed":
+             raise Exception("Batch processing failed")
+         else:
+             await asyncio.sleep(60)
+
+     pbar.close()
+
+     # Get the final batch object to access output_file_id
+     final_batch_object = await _get_batch(batch_id, client)
+     output_file_id = final_batch_object.output_file_id
+
+     batch_result = await _get_results_file(output_file_id, client)
+     return batch_result
+
+
+ def _upload_batch(items: List[Any], client: Any) -> str:
+     """Create a .jsonl file from preprocessed items and upload to OpenAI for batch processing.
+
+     Args:
+         items: A list of preprocessed items, each representing a single dataset eval item.
+         client: OpenAI client instance used for the upload.
+
+     Returns:
+         The file ID returned by OpenAI after uploading.
+     """
+     print("Uploading batch...")
+     # Fall back to a default OpenAI client if none was provided
+     if client is None:
+         client = OpenAI()
+
+     # Create temp .jsonl file
+     with tempfile.NamedTemporaryFile(mode="w+", suffix=".jsonl", delete=False) as f:
+         for i, item in enumerate(items):
+             # Construct each batch line; the item itself is used verbatim as the request body
+             payload = {
+                 "custom_id": f"request-{i}",
+                 "method": "POST",
+                 "url": "/v1/chat/completions",
+                 "body": item,
+             }
+             f.write(json.dumps(payload) + "\n")
+         file_path = f.name
+
+     # Upload file to OpenAI
+     with open(file_path, "rb") as upload_file:
+         response = client.files.create(file=upload_file, purpose="batch")
+
+     return str(response.id)
+
+
+ def _start_batch(file_id: str, client: Any) -> str:
+     batch_response = client.batches.create(
+         input_file_id=file_id,
+         endpoint="/v1/chat/completions",
+         completion_window="24h",
+     )
+     return str(batch_response.id)
+
+
+ async def _get_batch(batch_id: str, client: Any) -> Any:
+     batch_object = client.batches.retrieve(batch_id)
+     return batch_object
+
+
+ async def _get_results_file(output_file_id: str, client: Any) -> List[str]:
+     """Download and parse the batch results file from OpenAI."""
+     response = client.files.content(output_file_id)
+
+     # Parse the JSONL content
+     content = response.content.decode("utf-8")
+     results = []
+
+     for line in content.strip().split("\n"):
+         if line.strip():
+             result_obj = json.loads(line)
+             # Extract the response from the batch result structure
+             if "response" in result_obj and "body" in result_obj["response"]:
+                 response_body = result_obj["response"]["body"]
+                 if "choices" in response_body and len(response_body["choices"]) > 0:
+                     message_content = response_body["choices"][0]["message"]["content"]
+                     results.append(message_content)
+                 else:
+                     results.append("")
+             else:
+                 results.append("")
+
+     return results
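
Note that `_upload_batch` writes each item verbatim into the `body` field of a `/v1/chat/completions` batch line, so callers of `batch()` are expected to pass fully formed Chat Completions request bodies; the `model` and extra hyperparameter arguments are not merged into the payload. A hedged sketch of such a call (not part of the package), assuming `OPENAI_API_KEY` is set and using the module's default model purely for illustration:

```python
import asyncio

from scorebook.inference.openai import batch

# Each item is a complete Chat Completions request body, including the model.
items = [
    {
        "model": "gpt-4.1-nano",
        "messages": [{"role": "user", "content": "What is 2 + 2?"}],
        "max_tokens": 16,
    },
    {
        "model": "gpt-4.1-nano",
        "messages": [{"role": "user", "content": "Name a prime number."}],
        "max_tokens": 16,
    },
]

# batch() uploads a JSONL file, polls every 60 seconds until the batch finishes,
# then returns the extracted message contents (empty strings for unparseable rows).
answers = asyncio.run(batch(items))
print(answers)
```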
scorebook/inference/portkey.py ADDED
@@ -0,0 +1,186 @@
+ """
+ Portkey inference implementation for Scorebook.
+
+ This module provides utilities for running inference using Portkey's API,
+ supporting both single response and batch inference operations. It handles
+ API communication, request formatting, and response processing.
+ """
+
+ import asyncio
+ import json
+ import os
+ import tempfile
+ from typing import Any, List, Optional
+
+ from portkey_ai import AsyncPortkey
+ from tqdm.asyncio import tqdm
+
+
+ async def responses(
+     items: List[Any], model: str, client: Optional[AsyncPortkey] = None, **hyperparameters: Any
+ ) -> List[Any]:
+     """Process multiple inference requests using Portkey's API.
+
+     This asynchronous function handles multiple inference requests,
+     manages the API communication, and processes the responses.
+
+     Args:
+         items: List of preprocessed items to process.
+         model: Model to use via Portkey.
+         client: Optional Portkey client instance.
+         hyperparameters: Dictionary of hyperparameters for inference (currently not applied
+             to the requests).
+
+     Returns:
+         List of raw model responses.
+     """
+
+     if client is None:
+         client = AsyncPortkey(api_key=os.getenv("PORTKEY_API_KEY"))
+
+     results = []
+     for item in items:
+         response = await client.chat.completions.create(
+             model=model,
+             messages=item if isinstance(item, list) else [{"role": "user", "content": str(item)}],
+         )
+         results.append(response)
+
+     return results
+
+
+ async def batch(
+     items: List[Any],
+     model: str,
+     client: Optional[AsyncPortkey] = None,
+     **hyperparameters: Any,
+ ) -> List[Any]:
+     """Process multiple inference requests in batch using Portkey's API.
+
+     This asynchronous function handles batch processing of inference requests,
+     optimizing for throughput while respecting API rate limits.
+
+     Args:
+         items: List of preprocessed items to process.
+         model: Model to use via Portkey, in "provider/model" form.
+         client: Optional Portkey client instance.
+         hyperparameters: Additional parameters merged into each request body.
+
+     Returns:
+         A list of response message contents, one string per request.
+     """
+
+     # Expect "provider/model"; split only on the first "/" so model names may contain slashes.
+     provider, model = model.split("/", 1)
+
+     if client is None:
+         client = AsyncPortkey(provider=provider, api_key=os.getenv("PORTKEY_API_KEY"))
+
+     file_id = await _upload_batch(items, client, model, **hyperparameters)
+     batch_id = await _start_batch(file_id, client)
+
+     # Initialize progress bar
+     pbar = tqdm(total=len(items), desc="Batch processing", unit="requests")
+
+     awaiting_batch = True
+
+     while awaiting_batch:
+         batch_object = await _get_batch(batch_id, client)
+         batch_status = batch_object.status
+
+         if hasattr(batch_object, "request_counts") and batch_object.request_counts:
+             completed = batch_object.request_counts.completed
+             total = batch_object.request_counts.total
+             pbar.n = completed
+             pbar.set_postfix(status=batch_status, completed=f"{completed}/{total}")
+         else:
+             pbar.set_postfix(status=batch_status)
+
+         pbar.refresh()
+
+         if batch_status == "completed":
+             awaiting_batch = False
+             pbar.n = pbar.total
+             pbar.set_postfix(status="completed")
+         elif batch_status == "failed":
+             raise Exception("Batch processing failed")
+         else:
+             await asyncio.sleep(60)
+
+     pbar.close()
+
+     # Use the final batch object to access output_file_id
+     output_file_id = batch_object.output_file_id
+
+     batch_result = await _get_results_file(output_file_id, client)
+     return batch_result
+
+
+ async def _upload_batch(
+     items: List[Any], client: AsyncPortkey, model: str, **hyperparameters: Any
+ ) -> str:
+     """Create a .jsonl file from preprocessed items and upload to Portkey for batch processing.
+
+     Args:
+         items: A list of preprocessed items, each representing a single dataset eval item.
+         client: Portkey client instance.
+         model: Model to use for batch processing.
+         hyperparameters: Additional parameters for the batch requests.
+
+     Returns:
+         The file ID returned by Portkey after uploading.
+     """
+     print("Uploading batch...")
+
+     # Create temp .jsonl file
+     with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
+         for i, item in enumerate(items):
+             # Construct each batch line
+             payload = {
+                 "custom_id": f"request-{i}",
+                 "method": "POST",
+                 "url": "/v1/chat/completions",
+                 "body": {
+                     "model": model,
+                     "messages": (
+                         item if isinstance(item, list) else [{"role": "user", "content": str(item)}]
+                     ),
+                     **hyperparameters,
+                 },
+             }
+             f.write(json.dumps(payload) + "\n")
+         file_path = f.name
+
+     # Upload file to Portkey
+     with open(file_path, "rb") as upload_file:
+         response = await client.files.create(file=upload_file, purpose="batch")
+
+     return str(response.id)
+
+
+ async def _start_batch(file_id: str, client: Any) -> str:
+     batch_response = await client.batches.create(
+         input_file_id=file_id,
+         endpoint="/v1/chat/completions",
+         completion_window="24h",
+     )
+     return str(batch_response.id)
+
+
+ async def _get_batch(batch_id: str, client: Any) -> Any:
+     batch_object = await client.batches.retrieve(batch_id)
+     return batch_object
+
+
+ async def _get_results_file(output_file_id: str, client: Any) -> List[str]:
+     """Download and parse the batch results file from Portkey."""
+     response = await client.files.content(output_file_id)
+
+     # Parse the JSONL content
+     content = response.content.decode("utf-8")
+     results = []
+
+     for line in content.strip().split("\n"):
+         result_obj = json.loads(line)
+         message_content = result_obj["response"]["body"]["choices"][0]["message"]["content"]
+         results.append(message_content)
+
+     return results
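
Unlike the OpenAI helper, the Portkey `batch()` builds the request body itself: the `model` argument follows a `provider/model` convention, plain-string items are wrapped as a single user message, and extra keyword hyperparameters are merged into every request body. A minimal sketch (not part of the package), assuming `PORTKEY_API_KEY` is set and that `openai/gpt-4o-mini` is a valid provider/model pair for the account:

```python
import asyncio

from scorebook.inference.portkey import batch

# Plain strings are wrapped as a single user message by _upload_batch.
items = ["What is 2 + 2?", "Name a prime number."]

# "openai/gpt-4o-mini" is split into provider and model; temperature is merged
# into each request body. Returns one message-content string per item.
answers = asyncio.run(batch(items, model="openai/gpt-4o-mini", temperature=0.0))
print(answers)
```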
scorebook/metrics/__init__.py ADDED
@@ -0,0 +1,18 @@
+ """
+ Metrics for evaluating model predictions.
+
+ This module provides a collection of evaluation metrics for comparing model outputs
+ against ground truth labels. Available metrics include standard classification and
+ generation metrics like accuracy, precision, recall, F1-score, etc.
+
+ Metrics can be accessed by name through the `get_metrics()` function or used
+ directly by instantiating specific metric classes. All metrics implement a
+ common interface for scoring predictions against references.
+ """
+
+ from scorebook.metrics.accuracy import Accuracy
+ from scorebook.metrics.metric_base import MetricBase
+ from scorebook.metrics.metric_registry import MetricRegistry
+ from scorebook.metrics.precision import Precision
+
+ __all__ = ["MetricBase", "Precision", "Accuracy", "MetricRegistry"]
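
The metrics package exposes a small plug-in surface: subclass `MetricBase`, implement `score(outputs, labels)` returning an `(aggregate_scores, item_scores)` pair, and register the class with `MetricRegistry`. A hedged sketch of a custom metric (not part of the package); it assumes `MetricRegistry.register()` works as a plain class decorator, as in the `Accuracy` implementation that follows, and that `MetricBase` imposes no further requirements beyond `score()`:

```python
from typing import Any, Dict, List, Tuple

from scorebook.metrics.metric_base import MetricBase
from scorebook.metrics.metric_registry import MetricRegistry


@MetricRegistry.register()
class ExactMatch(MetricBase):
    """Illustrative metric: case-insensitive exact string match."""

    @staticmethod
    def score(outputs: List[Any], labels: List[Any]) -> Tuple[Dict[str, Any], List[Any]]:
        if len(outputs) != len(labels):
            raise ValueError("Number of outputs must match number of labels")
        if not outputs:
            return {"exact_match": 0.0}, []

        # Per-item booleans, then their mean as the aggregate score.
        item_scores = [
            str(output).strip().lower() == str(label).strip().lower()
            for output, label in zip(outputs, labels)
        ]
        return {"exact_match": sum(item_scores) / len(item_scores)}, item_scores
```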
scorebook/metrics/accuracy.py ADDED
@@ -0,0 +1,42 @@
+ """Accuracy metric implementation for Scorebook."""
+
+ from typing import Any, Dict, List, Tuple
+
+ from scorebook.metrics.metric_base import MetricBase
+ from scorebook.metrics.metric_registry import MetricRegistry
+
+
+ @MetricRegistry.register()
+ class Accuracy(MetricBase):
+     """Accuracy metric for evaluating model predictions of any type.
+
+     Accuracy = correct predictions / total predictions
+     """
+
+     @staticmethod
+     def score(outputs: List[Any], labels: List[Any]) -> Tuple[Dict[str, Any], List[Any]]:
+         """Calculate accuracy score between predictions and references.
+
+         Args:
+             outputs: A list of inference outputs.
+             labels: A list of ground truth labels.
+
+         Returns:
+             The aggregate accuracy score for all items (correct predictions / total predictions).
+             The item scores for each output-label pair (true/false).
+         """
+         if len(outputs) != len(labels):
+             raise ValueError("Number of outputs must match number of labels")
+
+         if not outputs:  # Handle empty lists
+             return {"accuracy": 0.0}, []
+
+         # Calculate item scores
+         item_scores = [output == label for output, label in zip(outputs, labels)]
+
+         # Calculate aggregate score
+         correct_predictions = sum(item_scores)
+         total_predictions = len(outputs)
+         aggregate_scores = {"accuracy": correct_predictions / total_predictions}
+
+         return aggregate_scores, item_scores
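
As a quick sanity check, `Accuracy.score` can be called directly on two parallel lists; the evaluator stores the returned pair under `aggregate_scores` and `item_scores` for each metric:

```python
from scorebook.metrics import Accuracy

outputs = ["4", "7", "11"]
labels = ["4", "7", "13"]

aggregate, per_item = Accuracy.score(outputs, labels)
print(aggregate)  # {'accuracy': 0.666...}  (2 of 3 correct)
print(per_item)   # [True, True, False]
```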