scorebook 0.0.11__py3-none-any.whl → 0.0.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scorebook/__init__.py +8 -1
- scorebook/evaluate/_async/evaluate_async.py +100 -125
- scorebook/evaluate/_sync/evaluate.py +100 -126
- scorebook/evaluate/evaluate_helpers.py +24 -24
- scorebook/exceptions.py +6 -2
- scorebook/score/__init__.py +6 -0
- scorebook/score/_async/__init__.py +0 -0
- scorebook/score/_async/score_async.py +145 -0
- scorebook/score/_sync/__init__.py +0 -0
- scorebook/score/_sync/score.py +145 -0
- scorebook/score/score_helpers.py +207 -0
- scorebook/trismik/upload_results.py +254 -0
- scorebook/types.py +33 -54
- scorebook/utils/__init__.py +8 -1
- scorebook/utils/common_helpers.py +41 -0
- scorebook/utils/progress_bars.py +67 -0
- {scorebook-0.0.11.dist-info → scorebook-0.0.12.dist-info}/METADATA +2 -2
- {scorebook-0.0.11.dist-info → scorebook-0.0.12.dist-info}/RECORD +21 -13
- {scorebook-0.0.11.dist-info → scorebook-0.0.12.dist-info}/WHEEL +0 -0
- {scorebook-0.0.11.dist-info → scorebook-0.0.12.dist-info}/entry_points.txt +0 -0
- {scorebook-0.0.11.dist-info → scorebook-0.0.12.dist-info}/licenses/LICENSE +0 -0
scorebook/evaluate/evaluate_helpers.py
CHANGED

@@ -2,9 +2,8 @@
 
 import asyncio
 import dataclasses
-import inspect
 import logging
-from typing import Any, Callable, Dict, Iterable, List,
+from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Union
 
 from trismik._async.client import TrismikAsyncClient
 from trismik._sync.client import TrismikClient
@@ -25,30 +24,34 @@ from scorebook.utils import expand_dict, is_awaitable
 logger = logging.getLogger(__name__)
 
 
-
-
+# TODO: Remove this when backend supports boolean item metrics
+NORMALIZE_METRICS_FOR_UPLOAD = True
 
-    if upload_results == "auto":
-        upload_results = get_token() is not None
-        logger.debug("Auto upload results resolved to: %s", upload_results)
 
-
+def normalize_metric_value(value: Any) -> Any:
+    """Normalize metric values for API upload compatibility.
 
-
-
-
+    TEMPORARY WORKAROUND: The Trismik API currently rejects boolean metric values.
+    This function converts boolean values to floats (True -> 1.0, False -> 0.0)
+    to ensure upload compatibility.
 
     Args:
-
+        value: The metric value to normalize
 
     Returns:
-
+        Float if value is bool, otherwise unchanged
+
+    TODO: Remove this function when backend supports boolean metrics natively.
+    To revert: Set NORMALIZE_METRICS_FOR_UPLOAD = False
     """
-    if
-
+    if not NORMALIZE_METRICS_FOR_UPLOAD:
+        return value
+
+    # Convert booleans to floats for API compatibility
+    if isinstance(value, bool):
+        return float(value)  # True -> 1.0, False -> 0.0
 
-
-    return show_progress
+    return value
 
 
 def validate_parameters(params: Dict[str, Any], caller: Callable[..., Any]) -> None:
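Note on the workaround above: with NORMALIZE_METRICS_FOR_UPLOAD left at its default of True, only booleans are touched and everything else passes through. A minimal sketch (the import path assumes the function lands in evaluate_helpers.py, as the hunk context suggests):

from scorebook.evaluate.evaluate_helpers import normalize_metric_value

normalize_metric_value(True)       # 1.0
normalize_metric_value(False)      # 0.0
normalize_metric_value(0.875)      # 0.875, unchanged
normalize_metric_value("correct")  # "correct", unchanged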
@@ -109,7 +112,7 @@ def prepare_datasets(
 
         # Prepare adaptive datasets
         elif isinstance(dataset, str) and dataset.endswith(":adaptive"):
-            datasets_out.append(AdaptiveEvalDataset(dataset
+            datasets_out.append(AdaptiveEvalDataset(dataset))
 
         # TODO: dataset name string registry
         elif isinstance(dataset, str):
@@ -220,9 +223,9 @@ def build_adaptive_eval_run_spec(
     metadata: Optional[Dict[str, Any]] = None,
 ) -> AdaptiveEvalRunSpec:
     """Build AdaptiveEvalRunSpec objects for a dataset/hyperparameter combination."""
-    dataset
+    # Keep the full dataset name including ":adaptive" suffix for backend API
     adaptive_eval_run_spec = AdaptiveEvalRunSpec(
-
+        adaptive_dataset,
         dataset_index,
         hyperparameter_config,
         hyperparameter_config_index,
@@ -345,10 +348,7 @@ def make_trismik_inference(
     """
 
    # Check if the inference function is async
-    is_async =
-        hasattr(inference_function, "__call__")
-        and inspect.iscoroutinefunction(inference_function.__call__)
-    )
+    is_async = is_awaitable(inference_function)
 
     def sync_trismik_inference_function(eval_items: Any, **kwargs: Any) -> Any:
         # Single TrismikMultipleChoiceTextItem dataclass
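The removed inline check is folded into scorebook.utils.is_awaitable. A rough sketch of what that helper is assumed to do, based on the code it replaces (the helper's real implementation is not part of this diff):

import inspect
from typing import Any, Callable


def is_awaitable_sketch(fn: Callable[..., Any]) -> bool:
    # Treat plain async def functions and callable objects with an
    # async __call__ as awaitable; mirrors the removed inline check.
    return inspect.iscoroutinefunction(fn) or inspect.iscoroutinefunction(
        getattr(fn, "__call__", None)
    )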
scorebook/exceptions.py
CHANGED

@@ -84,10 +84,14 @@ class MetricComputationError(EvaluationError):
         )
 
 
-class
+class ScoreError(ScoreBookError):
+    """Raised when there are errors during scoring."""
+
+
+class DataMismatchError(ScoreError):
     """Raised when there's a mismatch between outputs and expected labels."""
 
-    def __init__(self, outputs_count: int, labels_count: int, dataset_name: str):
+    def __init__(self, outputs_count: int, labels_count: int, dataset_name: str = "Dataset"):
         """Initialize data mismatch error."""
         self.outputs_count = outputs_count
         self.labels_count = labels_count
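DataMismatchError now sits under a new ScoreError base and no longer requires a dataset name, so the scoring helpers can raise it with counts alone. A small sketch of the resulting behaviour:

from scorebook.exceptions import DataMismatchError, ScoreError

try:
    raise DataMismatchError(outputs_count=10, labels_count=8)
except ScoreError as err:  # caught via the new base class
    print(err.outputs_count, err.labels_count)  # 10 8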
scorebook/score/_async/__init__.py
File without changes

scorebook/score/_async/score_async.py
ADDED
@@ -0,0 +1,145 @@
+import logging
+from typing import Any, Dict, List, Literal, Optional, Union, cast
+
+from scorebook.exceptions import DataMismatchError, ParameterValidationError
+from scorebook.score.score_helpers import (
+    calculate_metric_scores_async,
+    format_results,
+    resolve_metrics,
+    validate_items,
+)
+from scorebook.trismik.upload_results import upload_result_async
+from scorebook.types import Metrics
+from scorebook.utils import resolve_show_progress, resolve_upload_results, scoring_progress_context
+
+logger = logging.getLogger(__name__)
+
+
+async def score_async(
+    items: List[Dict[str, Any]],
+    metrics: Metrics,
+    output_column: str = "output",
+    label_column: str = "label",
+    input_column: str = "input",
+    hyperparameters: Optional[Dict[str, Any]] = None,
+    dataset_name: Optional[str] = None,
+    model_name: Optional[str] = None,
+    metadata: Optional[Dict[str, Any]] = None,
+    experiment_id: Optional[str] = None,
+    project_id: Optional[str] = None,
+    upload_results: Union[Literal["auto"], bool] = "auto",
+    show_progress: Optional[bool] = None,
+) -> Dict[str, List[Dict[str, Any]]]:
+    """Score pre-computed model outputs against labels using specified metrics.
+
+    Args:
+        items: List of dictionaries containing model outputs and labels. Each item should
+            have keys matching the output_column and label_column parameters.
+        metrics: Metric(s) to compute. Can be a single Metric class, instance, string name,
+            or a list of any combination of these.
+        output_column: Key in items dictionaries containing model outputs. Defaults to "output".
+        label_column: Key in items dictionaries containing ground truth labels. Defaults to "label".
+        input_column: Key in items dictionaries containing inputs for reference.
+            Defaults to "input".
+        hyperparameters: Optional dictionary of hyperparameters used during inference.
+            Defaults to None.
+        dataset_name: Optional name of the dataset being evaluated. Defaults to None.
+        model_name: Optional name of the model being evaluated. Defaults to None.
+        metadata: Optional dictionary of additional metadata to store with results.
+            Defaults to None.
+        experiment_id: Optional experiment identifier for grouping related runs.
+            Required if upload_results is True. Defaults to None.
+        project_id: Optional Trismik project ID for uploading results.
+            Required if upload_results is True. Defaults to None.
+        upload_results: Whether to upload results to Trismik. Can be True, False, or "auto"
+            (uploads if experiment_id and project_id are provided). Defaults to "auto".
+        show_progress: Whether to display a progress bar during scoring. If None, uses
+            SHOW_PROGRESS_BARS from settings (defaults to True). Defaults to None.
+
+    Returns:
+        Dictionary containing scoring results with keys:
+            - "aggregate_results": List with one dict containing aggregate metric scores
+            - "item_results": List of dicts with per-item scores and data
+    """
+
+    # Resolve and validate parameters
+    upload_results = cast(bool, resolve_upload_results(upload_results))
+    show_progress_bars = resolve_show_progress(show_progress)
+
+    # Validate upload requirements
+    if upload_results and (experiment_id is None or project_id is None):
+        raise ParameterValidationError(
+            "experiment_id and project_id are required to upload a run",
+        )
+
+    # Validate items parameter
+    validate_items(items, output_column, label_column)
+
+    # Validate hyperparameters is a dict (not list)
+    if hyperparameters is not None and not isinstance(hyperparameters, dict):
+        raise ParameterValidationError("hyperparameters must be a dict")
+
+    # Resolve metrics to a list of Metrics
+    metric_instances = resolve_metrics(metrics)
+
+    # Extract outputs and labels from items
+    inputs = [item.get(input_column) for item in items]
+    outputs = [item.get(output_column) for item in items]
+    labels = [item.get(label_column) for item in items]
+
+    # Validate outputs and labels have same length
+    if len(outputs) != len(labels):
+        raise DataMismatchError(len(outputs), len(labels), dataset_name)
+
+    # Compute scores for each metric with progress display
+    with scoring_progress_context(
+        total_metrics=len(metric_instances),
+        enabled=show_progress_bars,
+    ) as progress_bar:
+        metric_scores = await calculate_metric_scores_async(
+            metrics=metric_instances,
+            outputs=outputs,
+            labels=labels,
+            dataset_name=dataset_name,
+            progress_bar=progress_bar,
+        )
+
+    # Build results
+    results: Dict[str, List[Dict[str, Any]]] = format_results(
+        inputs=inputs,
+        outputs=outputs,
+        labels=labels,
+        metric_scores=metric_scores,
+        hyperparameters=hyperparameters,
+        dataset_name=dataset_name,
+    )
+
+    # Upload if requested
+    if upload_results and experiment_id and project_id:
+        try:
+            run_id = await upload_result_async(
+                run_result=results,
+                experiment_id=experiment_id,
+                project_id=project_id,
+                dataset_name=dataset_name,
+                hyperparameters=hyperparameters,
+                metadata=metadata,
+                model_name=model_name,
+            )
+            logger.info(f"Score results uploaded successfully with run_id: {run_id}")
+
+            # Add run_id to aggregate results
+            if results.get("aggregate_results"):
+                results["aggregate_results"][0]["run_id"] = run_id
+
+            # Add run_id to each item result
+            if results.get("item_results"):
+                for item in results["item_results"]:
+                    item["run_id"] = run_id
+
+        except Exception as e:
+            logger.warning(f"Failed to upload score results: {e}")
+            # Don't raise - continue execution even if upload fails
+
+    logger.info("Async scoring complete")
+    return results
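A minimal usage sketch of the new async entry point (the items, the "accuracy" metric name, and the scorebook.score import path are assumptions, not taken from this diff):

import asyncio

from scorebook.score import score_async  # assuming the package re-exports it

items = [
    {"input": "2 + 2", "output": "4", "label": "4"},
    {"input": "3 + 3", "output": "5", "label": "6"},
]

results = asyncio.run(
    score_async(items, metrics="accuracy", dataset_name="toy-math", upload_results=False)
)
print(results["aggregate_results"][0])
print(results["item_results"][0])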
scorebook/score/_sync/__init__.py
File without changes

scorebook/score/_sync/score.py
ADDED
@@ -0,0 +1,145 @@
+import logging
+from typing import Any, Dict, List, Literal, Optional, Union, cast
+
+from scorebook.exceptions import DataMismatchError, ParameterValidationError
+from scorebook.score.score_helpers import (
+    calculate_metric_scores,
+    format_results,
+    resolve_metrics,
+    validate_items,
+)
+from scorebook.trismik.upload_results import upload_result
+from scorebook.types import Metrics
+from scorebook.utils import resolve_show_progress, resolve_upload_results, scoring_progress_context
+
+logger = logging.getLogger(__name__)
+
+
+def score(
+    items: List[Dict[str, Any]],
+    metrics: Metrics,
+    output_column: str = "output",
+    label_column: str = "label",
+    input_column: str = "input",
+    hyperparameters: Optional[Dict[str, Any]] = None,
+    dataset_name: Optional[str] = None,
+    model_name: Optional[str] = None,
+    metadata: Optional[Dict[str, Any]] = None,
+    experiment_id: Optional[str] = None,
+    project_id: Optional[str] = None,
+    upload_results: Union[Literal["auto"], bool] = "auto",
+    show_progress: Optional[bool] = None,
+) -> Dict[str, List[Dict[str, Any]]]:
+    """Score pre-computed model outputs against labels using specified metrics.
+
+    Args:
+        items: List of dictionaries containing model outputs and labels. Each item should
+            have keys matching the output_column and label_column parameters.
+        metrics: Metric(s) to compute. Can be a single Metric class, instance, string name,
+            or a list of any combination of these.
+        output_column: Key in items dictionaries containing model outputs. Defaults to "output".
+        label_column: Key in items dictionaries containing ground truth labels. Defaults to "label".
+        input_column: Key in items dictionaries containing inputs for reference.
+            Defaults to "input".
+        hyperparameters: Optional dictionary of hyperparameters used during inference.
+            Defaults to None.
+        dataset_name: Optional name of the dataset being evaluated. Defaults to None.
+        model_name: Optional name of the model being evaluated. Defaults to None.
+        metadata: Optional dictionary of additional metadata to store with results.
+            Defaults to None.
+        experiment_id: Optional experiment identifier for grouping related runs.
+            Required if upload_results is True. Defaults to None.
+        project_id: Optional Trismik project ID for uploading results.
+            Required if upload_results is True. Defaults to None.
+        upload_results: Whether to upload results to Trismik. Can be True, False, or "auto"
+            (uploads if experiment_id and project_id are provided). Defaults to "auto".
+        show_progress: Whether to display a progress bar during scoring. If None, uses
+            SHOW_PROGRESS_BARS from settings (defaults to True). Defaults to None.
+
+    Returns:
+        Dictionary containing scoring results with keys:
+            - "aggregate_results": List with one dict containing aggregate metric scores
+            - "item_results": List of dicts with per-item scores and data
+    """
+
+    # Resolve and validate parameters
+    upload_results = cast(bool, resolve_upload_results(upload_results))
+    show_progress_bars = resolve_show_progress(show_progress)
+
+    # Validate upload requirements
+    if upload_results and (experiment_id is None or project_id is None):
+        raise ParameterValidationError(
+            "experiment_id and project_id are required to upload a run",
+        )
+
+    # Validate items parameter
+    validate_items(items, output_column, label_column)
+
+    # Validate hyperparameters is a dict (not list)
+    if hyperparameters is not None and not isinstance(hyperparameters, dict):
+        raise ParameterValidationError("hyperparameters must be a dict")
+
+    # Resolve metrics to a list of Metrics
+    metric_instances = resolve_metrics(metrics)
+
+    # Extract outputs and labels from items
+    inputs = [item.get(input_column) for item in items]
+    outputs = [item.get(output_column) for item in items]
+    labels = [item.get(label_column) for item in items]
+
+    # Validate outputs and labels have same length
+    if len(outputs) != len(labels):
+        raise DataMismatchError(len(outputs), len(labels), dataset_name)
+
+    # Compute scores for each metric with progress display
+    with scoring_progress_context(
+        total_metrics=len(metric_instances),
+        enabled=show_progress_bars,
+    ) as progress_bar:
+        metric_scores = calculate_metric_scores(
+            metrics=metric_instances,
+            outputs=outputs,
+            labels=labels,
+            dataset_name=dataset_name,
+            progress_bar=progress_bar,
+        )
+
+    # Build results
+    results: Dict[str, List[Dict[str, Any]]] = format_results(
+        inputs=inputs,
+        outputs=outputs,
+        labels=labels,
+        metric_scores=metric_scores,
+        hyperparameters=hyperparameters,
+        dataset_name=dataset_name,
+    )
+
+    # Upload if requested
+    if upload_results and experiment_id and project_id:
+        try:
+            run_id = upload_result(
+                run_result=results,
+                experiment_id=experiment_id,
+                project_id=project_id,
+                dataset_name=dataset_name,
+                hyperparameters=hyperparameters,
+                metadata=metadata,
+                model_name=model_name,
+            )
+            logger.info(f"Score results uploaded successfully with run_id: {run_id}")
+
+            # Add run_id to aggregate results
+            if results.get("aggregate_results"):
+                results["aggregate_results"][0]["run_id"] = run_id
+
+            # Add run_id to each item result
+            if results.get("item_results"):
+                for item in results["item_results"]:
+                    item["run_id"] = run_id
+
+        except Exception as e:
+            logger.warning(f"Failed to upload score results: {e}")
+            # Don't raise - continue execution even if upload fails
+
+    logger.info("Scoring complete")
+    return results
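The synchronous variant mirrors score_async; a sketch under the same assumptions (import path and metric name are guesses):

from scorebook.score import score  # assuming the package re-exports it

items = [
    {"output": "yes", "label": "yes"},
    {"output": "no", "label": "yes"},
]
results = score(items, metrics="accuracy", dataset_name="toy-qa", upload_results=False)

for row in results["item_results"]:
    print(row["id"], row["output"], row["label"])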
scorebook/score/score_helpers.py
ADDED

@@ -0,0 +1,207 @@
+"""Helper functions shared between score() and score_async()."""
+
+import logging
+from typing import Any, Dict, List, Mapping, Optional, Type, Union
+
+from scorebook.exceptions import DataMismatchError, ParameterValidationError
+from scorebook.metrics.metric_base import MetricBase
+from scorebook.metrics.metric_registry import MetricRegistry
+from scorebook.types import MetricScore
+from scorebook.utils import is_awaitable
+
+logger = logging.getLogger(__name__)
+
+
+def validate_items(items: List[Dict[str, Any]], output_column: str, label_column: str) -> None:
+    """Validate the items parameter."""
+    if not isinstance(items, list):
+        raise ParameterValidationError("items must be a list")
+
+    if len(items) == 0:
+        raise ParameterValidationError("items list cannot be empty")
+
+    required = {output_column, label_column}
+    for idx, item in enumerate(items):
+        if not isinstance(item, Mapping):
+            raise ParameterValidationError(f"Item at index {idx} is not a dict")
+
+        missing = required - item.keys()
+        if missing:
+            for key in sorted(missing):
+                raise ParameterValidationError(f"Item at index {idx} missing required '{key}' key")
+
+
+def resolve_metrics(
+    metrics: Union[
+        str, MetricBase, Type[MetricBase], List[Union[str, MetricBase, Type[MetricBase]]]
+    ]
+) -> List[MetricBase]:
+    """Resolve metrics parameter to list of MetricBase instances."""
+    # Ensure metrics is a list
+    if not isinstance(metrics, list):
+        metrics = [metrics]
+
+    # Resolve each metric
+    metric_instances = []
+    for metric in metrics:
+        if isinstance(metric, str) or (isinstance(metric, type) and issubclass(metric, MetricBase)):
+            # Use MetricRegistry to resolve string names or classes
+            metric_instance = MetricRegistry.get(metric)
+            metric_instances.append(metric_instance)
+        elif isinstance(metric, MetricBase):
+            # Already an instance
+            metric_instances.append(metric)
+        else:
+            raise ParameterValidationError(
+                f"Invalid metric type: {type(metric)}. "
+                "Metrics must be string names, MetricBase classes, or MetricBase instances"
+            )
+
+    return metric_instances
+
+
+async def calculate_metric_scores_async(
+    metrics: List[MetricBase],
+    outputs: List[Any],
+    labels: List[Any],
+    dataset_name: Optional[str],
+    progress_bar: Optional[Any] = None,
+) -> List[MetricScore]:
+    """Calculate metric scores asynchronously (supports both sync and async metrics).
+
+    Args:
+        metrics: List of metric instances to compute scores for.
+        outputs: List of model outputs.
+        labels: List of ground truth labels.
+        dataset_name: Name of the dataset being scored.
+        progress_bar: Optional progress bar to update during computation.
+
+    Returns:
+        List of MetricScore objects containing aggregate and item-level scores.
+
+    Raises:
+        DataMismatchError: If outputs and labels have different lengths.
+    """
+    if len(outputs) != len(labels):
+        raise DataMismatchError(len(outputs), len(labels), dataset_name)
+
+    metric_scores: List[MetricScore] = []
+    for metric in metrics:
+
+        if progress_bar is not None:
+            progress_bar.set_current_metric(metric.name)
+
+        if is_awaitable(metric.score):
+            aggregate_scores, item_scores = await metric.score(outputs, labels)
+        else:
+            aggregate_scores, item_scores = metric.score(outputs, labels)
+
+        metric_scores.append(MetricScore(metric.name, aggregate_scores, item_scores))
+
+        if progress_bar is not None:
+            progress_bar.update(1)
+
+    return metric_scores
+
+
+def calculate_metric_scores(
+    metrics: List[MetricBase],
+    outputs: List[Any],
+    labels: List[Any],
+    dataset_name: Optional[str],
+    progress_bar: Optional[Any] = None,
+) -> List[MetricScore]:
+    """Calculate metric scores synchronously (sync metrics only).
+
+    Args:
+        metrics: List of metric instances to compute scores for.
+        outputs: List of model outputs.
+        labels: List of ground truth labels.
+        dataset_name: Name of the dataset being scored.
+        progress_bar: Optional progress bar to update during computation.
+
+    Returns:
+        List of MetricScore objects containing aggregate and item-level scores.
+
+    Raises:
+        DataMismatchError: If outputs and labels have different lengths.
+        ParameterValidationError: If any metric has an async score method.
+    """
+    if len(outputs) != len(labels):
+        raise DataMismatchError(len(outputs), len(labels), dataset_name)
+
+    metric_scores: List[MetricScore] = []
+    for metric in metrics:
+
+        if progress_bar is not None:
+            progress_bar.set_current_metric(metric.name)
+
+        if is_awaitable(metric.score):
+            raise ParameterValidationError(
+                f"Metric '{metric.name}' has an async score() method. "
+                "Use score_async() instead of score() for async metrics."
+            )
+
+        aggregate_scores, item_scores = metric.score(outputs, labels)
+        metric_scores.append(MetricScore(metric.name, aggregate_scores, item_scores))
+
+        if progress_bar is not None:
+            progress_bar.update(1)
+
+    return metric_scores
+
+
+def format_results(
+    inputs: Optional[List[Any]],
+    outputs: List[Any],
+    labels: List[Any],
+    metric_scores: List[MetricScore],
+    hyperparameters: Optional[Dict[str, Any]] = None,
+    dataset_name: Optional[str] = None,
+) -> Dict[str, List[Dict[str, Any]]]:
+    """Format results dict with both aggregates and items."""
+    # Use defaults if not provided
+    hyperparameters = hyperparameters or {}
+    dataset_name = dataset_name or "scored_items"
+
+    # Build aggregate results
+    aggregate_result = {
+        "dataset": dataset_name,
+        **hyperparameters,
+    }
+
+    # Add aggregate scores from metrics
+    for metric_score in metric_scores:
+        for key, value in metric_score.aggregate_scores.items():
+            score_key = (
+                key if key == metric_score.metric_name else f"{metric_score.metric_name}_{key}"
+            )
+            aggregate_result[score_key] = value
+
+    # Build item results
+    item_results = []
+    for idx in range(len(outputs)):
+        item_result: Dict[str, Any] = {
+            "id": idx,
+            "dataset": dataset_name,
+            "output": outputs[idx],
+            "label": labels[idx],
+            **hyperparameters,
+        }
+
+        # Add input if present
+        if inputs is not None and inputs[idx] is not None:
+            item_result["input"] = inputs[idx]
+
+        # Add item-level metric scores
+        for metric_score in metric_scores:
+            if idx < len(metric_score.item_scores):
+                item_result[metric_score.metric_name] = metric_score.item_scores[idx]
+
+        item_results.append(item_result)
+
+    # Always return both aggregates and items
+    return {
+        "aggregate_results": [aggregate_result],
+        "item_results": item_results,
+    }
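For orientation, format_results above produces a structure along these lines for a two-item run with a single metric named "accuracy" (values and the metric name are illustrative; a run_id key is added to every row only when an upload succeeds):

example = {
    "aggregate_results": [
        {"dataset": "toy-qa", "accuracy": 0.5},
    ],
    "item_results": [
        {"id": 0, "dataset": "toy-qa", "output": "yes", "label": "yes", "input": "q1", "accuracy": 1.0},
        {"id": 1, "dataset": "toy-qa", "output": "no", "label": "yes", "input": "q2", "accuracy": 0.0},
    ],
}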