scorebook-0.0.11-py3-none-any.whl → scorebook-0.0.13-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,145 @@
+ import logging
+ from typing import Any, Dict, List, Literal, Optional, Union, cast
+
+ from scorebook.exceptions import DataMismatchError, ParameterValidationError
+ from scorebook.score.score_helpers import (
+     calculate_metric_scores_async,
+     format_results,
+     resolve_metrics,
+     validate_items,
+ )
+ from scorebook.trismik.upload_results import upload_result_async
+ from scorebook.types import Metrics
+ from scorebook.utils import resolve_show_progress, resolve_upload_results, scoring_progress_context
+
+ logger = logging.getLogger(__name__)
+
+
+ async def score_async(
+     items: List[Dict[str, Any]],
+     metrics: Metrics,
+     output_column: str = "output",
+     label_column: str = "label",
+     input_column: str = "input",
+     hyperparameters: Optional[Dict[str, Any]] = None,
+     dataset_name: Optional[str] = None,
+     model_name: Optional[str] = None,
+     metadata: Optional[Dict[str, Any]] = None,
+     experiment_id: Optional[str] = None,
+     project_id: Optional[str] = None,
+     upload_results: Union[Literal["auto"], bool] = "auto",
+     show_progress: Optional[bool] = None,
+ ) -> Dict[str, List[Dict[str, Any]]]:
+     """Score pre-computed model outputs against labels using specified metrics.
+
+     Args:
+         items: List of dictionaries containing model outputs and labels. Each item should
+             have keys matching the output_column and label_column parameters.
+         metrics: Metric(s) to compute. Can be a single Metric class, instance, string name,
+             or a list of any combination of these.
+         output_column: Key in items dictionaries containing model outputs. Defaults to "output".
+         label_column: Key in items dictionaries containing ground truth labels. Defaults to "label".
+         input_column: Key in items dictionaries containing inputs for reference.
+             Defaults to "input".
+         hyperparameters: Optional dictionary of hyperparameters used during inference.
+             Defaults to None.
+         dataset_name: Optional name of the dataset being evaluated. Defaults to None.
+         model_name: Optional name of the model being evaluated. Defaults to None.
+         metadata: Optional dictionary of additional metadata to store with results.
+             Defaults to None.
+         experiment_id: Optional experiment identifier for grouping related runs.
+             Required if upload_results is True. Defaults to None.
+         project_id: Optional Trismik project ID for uploading results.
+             Required if upload_results is True. Defaults to None.
+         upload_results: Whether to upload results to Trismik. Can be True, False, or "auto"
+             (uploads if experiment_id and project_id are provided). Defaults to "auto".
+         show_progress: Whether to display a progress bar during scoring. If None, uses
+             SHOW_PROGRESS_BARS from settings (defaults to True). Defaults to None.
+
+     Returns:
+         Dictionary containing scoring results with keys:
+         - "aggregate_results": List with one dict containing aggregate metric scores
+         - "item_results": List of dicts with per-item scores and data
+     """
+
+     # Resolve and validate parameters
+     upload_results = cast(bool, resolve_upload_results(upload_results))
+     show_progress_bars = resolve_show_progress(show_progress)
+
+     # Validate upload requirements
+     if upload_results and (experiment_id is None or project_id is None):
+         raise ParameterValidationError(
+             "experiment_id and project_id are required to upload a run",
+         )
+
+     # Validate items parameter
+     validate_items(items, output_column, label_column)
+
+     # Validate hyperparameters is a dict (not list)
+     if hyperparameters is not None and not isinstance(hyperparameters, dict):
+         raise ParameterValidationError("hyperparameters must be a dict")
+
+     # Resolve metrics to a list of Metrics
+     metric_instances = resolve_metrics(metrics)
+
+     # Extract outputs and labels from items
+     inputs = [item.get(input_column) for item in items]
+     outputs = [item.get(output_column) for item in items]
+     labels = [item.get(label_column) for item in items]
+
+     # Validate outputs and labels have same length
+     if len(outputs) != len(labels):
+         raise DataMismatchError(len(outputs), len(labels), dataset_name)
+
+     # Compute scores for each metric with progress display
+     with scoring_progress_context(
+         total_metrics=len(metric_instances),
+         enabled=show_progress_bars,
+     ) as progress_bar:
+         metric_scores = await calculate_metric_scores_async(
+             metrics=metric_instances,
+             outputs=outputs,
+             labels=labels,
+             dataset_name=dataset_name,
+             progress_bar=progress_bar,
+         )
+
+     # Build results
+     results: Dict[str, List[Dict[str, Any]]] = format_results(
+         inputs=inputs,
+         outputs=outputs,
+         labels=labels,
+         metric_scores=metric_scores,
+         hyperparameters=hyperparameters,
+         dataset_name=dataset_name,
+     )
+
+     # Upload if requested
+     if upload_results and experiment_id and project_id:
+         try:
+             run_id = await upload_result_async(
+                 run_result=results,
+                 experiment_id=experiment_id,
+                 project_id=project_id,
+                 dataset_name=dataset_name,
+                 hyperparameters=hyperparameters,
+                 metadata=metadata,
+                 model_name=model_name,
+             )
+             logger.info(f"Score results uploaded successfully with run_id: {run_id}")
+
+             # Add run_id to aggregate results
+             if results.get("aggregate_results"):
+                 results["aggregate_results"][0]["run_id"] = run_id
+
+             # Add run_id to each item result
+             if results.get("item_results"):
+                 for item in results["item_results"]:
+                     item["run_id"] = run_id
+
+         except Exception as e:
+             logger.warning(f"Failed to upload score results: {e}")
+             # Don't raise - continue execution even if upload fails
+
+     logger.info("Async scoring complete")
+     return results
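
For reference, a minimal usage sketch for the async entry point shown above. It assumes an "accuracy" metric is registered with MetricRegistry and that score_async is importable from the top-level scorebook package; the metric name, import path, and item values are illustrative and not taken from this diff.

import asyncio

from scorebook import score_async  # import path assumed

items = [
    {"input": "2 + 2", "output": "4", "label": "4"},
    {"input": "3 + 5", "output": "7", "label": "8"},
]


async def main() -> None:
    # upload_results=False skips the Trismik upload path entirely,
    # so no experiment_id or project_id is required.
    results = await score_async(items, metrics="accuracy", upload_results=False)
    print(results["aggregate_results"][0])  # aggregate scores keyed by metric name
    print(results["item_results"][0])       # per-item output, label, and score


asyncio.run(main())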
File without changes
@@ -0,0 +1,145 @@
+ import logging
+ from typing import Any, Dict, List, Literal, Optional, Union, cast
+
+ from scorebook.exceptions import DataMismatchError, ParameterValidationError
+ from scorebook.score.score_helpers import (
+     calculate_metric_scores,
+     format_results,
+     resolve_metrics,
+     validate_items,
+ )
+ from scorebook.trismik.upload_results import upload_result
+ from scorebook.types import Metrics
+ from scorebook.utils import resolve_show_progress, resolve_upload_results, scoring_progress_context
+
+ logger = logging.getLogger(__name__)
+
+
+ def score(
+     items: List[Dict[str, Any]],
+     metrics: Metrics,
+     output_column: str = "output",
+     label_column: str = "label",
+     input_column: str = "input",
+     hyperparameters: Optional[Dict[str, Any]] = None,
+     dataset_name: Optional[str] = None,
+     model_name: Optional[str] = None,
+     metadata: Optional[Dict[str, Any]] = None,
+     experiment_id: Optional[str] = None,
+     project_id: Optional[str] = None,
+     upload_results: Union[Literal["auto"], bool] = "auto",
+     show_progress: Optional[bool] = None,
+ ) -> Dict[str, List[Dict[str, Any]]]:
+     """Score pre-computed model outputs against labels using specified metrics.
+
+     Args:
+         items: List of dictionaries containing model outputs and labels. Each item should
+             have keys matching the output_column and label_column parameters.
+         metrics: Metric(s) to compute. Can be a single Metric class, instance, string name,
+             or a list of any combination of these.
+         output_column: Key in items dictionaries containing model outputs. Defaults to "output".
+         label_column: Key in items dictionaries containing ground truth labels. Defaults to "label".
+         input_column: Key in items dictionaries containing inputs for reference.
+             Defaults to "input".
+         hyperparameters: Optional dictionary of hyperparameters used during inference.
+             Defaults to None.
+         dataset_name: Optional name of the dataset being evaluated. Defaults to None.
+         model_name: Optional name of the model being evaluated. Defaults to None.
+         metadata: Optional dictionary of additional metadata to store with results.
+             Defaults to None.
+         experiment_id: Optional experiment identifier for grouping related runs.
+             Required if upload_results is True. Defaults to None.
+         project_id: Optional Trismik project ID for uploading results.
+             Required if upload_results is True. Defaults to None.
+         upload_results: Whether to upload results to Trismik. Can be True, False, or "auto"
+             (uploads if experiment_id and project_id are provided). Defaults to "auto".
+         show_progress: Whether to display a progress bar during scoring. If None, uses
+             SHOW_PROGRESS_BARS from settings (defaults to True). Defaults to None.
+
+     Returns:
+         Dictionary containing scoring results with keys:
+         - "aggregate_results": List with one dict containing aggregate metric scores
+         - "item_results": List of dicts with per-item scores and data
+     """
+
+     # Resolve and validate parameters
+     upload_results = cast(bool, resolve_upload_results(upload_results))
+     show_progress_bars = resolve_show_progress(show_progress)
+
+     # Validate upload requirements
+     if upload_results and (experiment_id is None or project_id is None):
+         raise ParameterValidationError(
+             "experiment_id and project_id are required to upload a run",
+         )
+
+     # Validate items parameter
+     validate_items(items, output_column, label_column)
+
+     # Validate hyperparameters is a dict (not list)
+     if hyperparameters is not None and not isinstance(hyperparameters, dict):
+         raise ParameterValidationError("hyperparameters must be a dict")
+
+     # Resolve metrics to a list of Metrics
+     metric_instances = resolve_metrics(metrics)
+
+     # Extract outputs and labels from items
+     inputs = [item.get(input_column) for item in items]
+     outputs = [item.get(output_column) for item in items]
+     labels = [item.get(label_column) for item in items]
+
+     # Validate outputs and labels have same length
+     if len(outputs) != len(labels):
+         raise DataMismatchError(len(outputs), len(labels), dataset_name)
+
+     # Compute scores for each metric with progress display
+     with scoring_progress_context(
+         total_metrics=len(metric_instances),
+         enabled=show_progress_bars,
+     ) as progress_bar:
+         metric_scores = calculate_metric_scores(
+             metrics=metric_instances,
+             outputs=outputs,
+             labels=labels,
+             dataset_name=dataset_name,
+             progress_bar=progress_bar,
+         )
+
+     # Build results
+     results: Dict[str, List[Dict[str, Any]]] = format_results(
+         inputs=inputs,
+         outputs=outputs,
+         labels=labels,
+         metric_scores=metric_scores,
+         hyperparameters=hyperparameters,
+         dataset_name=dataset_name,
+     )
+
+     # Upload if requested
+     if upload_results and experiment_id and project_id:
+         try:
+             run_id = upload_result(
+                 run_result=results,
+                 experiment_id=experiment_id,
+                 project_id=project_id,
+                 dataset_name=dataset_name,
+                 hyperparameters=hyperparameters,
+                 metadata=metadata,
+                 model_name=model_name,
+             )
+             logger.info(f"Score results uploaded successfully with run_id: {run_id}")
+
+             # Add run_id to aggregate results
+             if results.get("aggregate_results"):
+                 results["aggregate_results"][0]["run_id"] = run_id
+
+             # Add run_id to each item result
+             if results.get("item_results"):
+                 for item in results["item_results"]:
+                     item["run_id"] = run_id
+
+         except Exception as e:
+             logger.warning(f"Failed to upload score results: {e}")
+             # Don't raise - continue execution even if upload fails
+
+     logger.info("Scoring complete")
+     return results
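
The synchronous variant is called the same way; the sketch below contrasts a local-only run with an uploaded run. The import path, metric name, dataset name, and IDs are placeholders and not part of this diff, and upload_results is passed explicitly so the behaviour does not depend on how "auto" is resolved from settings.

from scorebook import score  # import path assumed

items = [
    {"output": "positive", "label": "positive"},
    {"output": "negative", "label": "positive"},
]

# Local-only scoring: no upload is attempted and no IDs are needed.
local_results = score(items, metrics="accuracy", upload_results=False)

# Uploaded run: with upload_results=True, both experiment_id and project_id are
# required, and on success every aggregate and item result gains a "run_id" key.
uploaded_results = score(
    items,
    metrics="accuracy",
    dataset_name="sentiment-smoke-test",  # placeholder
    experiment_id="exp-123",              # placeholder
    project_id="proj-456",                # placeholder
    upload_results=True,
)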
@@ -0,0 +1,207 @@
+ """Helper functions shared between score() and score_async()."""
+
+ import logging
+ from typing import Any, Dict, List, Mapping, Optional, Type, Union
+
+ from scorebook.exceptions import DataMismatchError, ParameterValidationError
+ from scorebook.metrics.metric_base import MetricBase
+ from scorebook.metrics.metric_registry import MetricRegistry
+ from scorebook.types import MetricScore
+ from scorebook.utils import is_awaitable
+
+ logger = logging.getLogger(__name__)
+
+
+ def validate_items(items: List[Dict[str, Any]], output_column: str, label_column: str) -> None:
+     """Validate the items parameter."""
+     if not isinstance(items, list):
+         raise ParameterValidationError("items must be a list")
+
+     if len(items) == 0:
+         raise ParameterValidationError("items list cannot be empty")
+
+     required = {output_column, label_column}
+     for idx, item in enumerate(items):
+         if not isinstance(item, Mapping):
+             raise ParameterValidationError(f"Item at index {idx} is not a dict")
+
+         missing = required - item.keys()
+         if missing:
+             for key in sorted(missing):
+                 raise ParameterValidationError(f"Item at index {idx} missing required '{key}' key")
+
+
+ def resolve_metrics(
+     metrics: Union[
+         str, MetricBase, Type[MetricBase], List[Union[str, MetricBase, Type[MetricBase]]]
+     ]
+ ) -> List[MetricBase]:
+     """Resolve metrics parameter to list of MetricBase instances."""
+     # Ensure metrics is a list
+     if not isinstance(metrics, list):
+         metrics = [metrics]
+
+     # Resolve each metric
+     metric_instances = []
+     for metric in metrics:
+         if isinstance(metric, str) or (isinstance(metric, type) and issubclass(metric, MetricBase)):
+             # Use MetricRegistry to resolve string names or classes
+             metric_instance = MetricRegistry.get(metric)
+             metric_instances.append(metric_instance)
+         elif isinstance(metric, MetricBase):
+             # Already an instance
+             metric_instances.append(metric)
+         else:
+             raise ParameterValidationError(
+                 f"Invalid metric type: {type(metric)}. "
+                 "Metrics must be string names, MetricBase classes, or MetricBase instances"
+             )
+
+     return metric_instances
+
+
+ async def calculate_metric_scores_async(
+     metrics: List[MetricBase],
+     outputs: List[Any],
+     labels: List[Any],
+     dataset_name: Optional[str],
+     progress_bar: Optional[Any] = None,
+ ) -> List[MetricScore]:
+     """Calculate metric scores asynchronously (supports both sync and async metrics).
+
+     Args:
+         metrics: List of metric instances to compute scores for.
+         outputs: List of model outputs.
+         labels: List of ground truth labels.
+         dataset_name: Name of the dataset being scored.
+         progress_bar: Optional progress bar to update during computation.
+
+     Returns:
+         List of MetricScore objects containing aggregate and item-level scores.
+
+     Raises:
+         DataMismatchError: If outputs and labels have different lengths.
+     """
+     if len(outputs) != len(labels):
+         raise DataMismatchError(len(outputs), len(labels), dataset_name)
+
+     metric_scores: List[MetricScore] = []
+     for metric in metrics:
+
+         if progress_bar is not None:
+             progress_bar.set_current_metric(metric.name)
+
+         if is_awaitable(metric.score):
+             aggregate_scores, item_scores = await metric.score(outputs, labels)
+         else:
+             aggregate_scores, item_scores = metric.score(outputs, labels)
+
+         metric_scores.append(MetricScore(metric.name, aggregate_scores, item_scores))
+
+         if progress_bar is not None:
+             progress_bar.update(1)
+
+     return metric_scores
+
+
+ def calculate_metric_scores(
+     metrics: List[MetricBase],
+     outputs: List[Any],
+     labels: List[Any],
+     dataset_name: Optional[str],
+     progress_bar: Optional[Any] = None,
+ ) -> List[MetricScore]:
+     """Calculate metric scores synchronously (sync metrics only).
+
+     Args:
+         metrics: List of metric instances to compute scores for.
+         outputs: List of model outputs.
+         labels: List of ground truth labels.
+         dataset_name: Name of the dataset being scored.
+         progress_bar: Optional progress bar to update during computation.
+
+     Returns:
+         List of MetricScore objects containing aggregate and item-level scores.
+
+     Raises:
+         DataMismatchError: If outputs and labels have different lengths.
+         ParameterValidationError: If any metric has an async score method.
+     """
+     if len(outputs) != len(labels):
+         raise DataMismatchError(len(outputs), len(labels), dataset_name)
+
+     metric_scores: List[MetricScore] = []
+     for metric in metrics:
+
+         if progress_bar is not None:
+             progress_bar.set_current_metric(metric.name)
+
+         if is_awaitable(metric.score):
+             raise ParameterValidationError(
+                 f"Metric '{metric.name}' has an async score() method. "
+                 "Use score_async() instead of score() for async metrics."
+             )
+
+         aggregate_scores, item_scores = metric.score(outputs, labels)
+         metric_scores.append(MetricScore(metric.name, aggregate_scores, item_scores))
+
+         if progress_bar is not None:
+             progress_bar.update(1)
+
+     return metric_scores
+
+
+ def format_results(
+     inputs: Optional[List[Any]],
+     outputs: List[Any],
+     labels: List[Any],
+     metric_scores: List[MetricScore],
+     hyperparameters: Optional[Dict[str, Any]] = None,
+     dataset_name: Optional[str] = None,
+ ) -> Dict[str, List[Dict[str, Any]]]:
+     """Format results dict with both aggregates and items."""
+     # Use defaults if not provided
+     hyperparameters = hyperparameters or {}
+     dataset_name = dataset_name or "scored_items"
+
+     # Build aggregate results
+     aggregate_result = {
+         "dataset": dataset_name,
+         **hyperparameters,
+     }
+
+     # Add aggregate scores from metrics
+     for metric_score in metric_scores:
+         for key, value in metric_score.aggregate_scores.items():
+             score_key = (
+                 key if key == metric_score.metric_name else f"{metric_score.metric_name}_{key}"
+             )
+             aggregate_result[score_key] = value
+
+     # Build item results
+     item_results = []
+     for idx in range(len(outputs)):
+         item_result: Dict[str, Any] = {
+             "id": idx,
+             "dataset": dataset_name,
+             "output": outputs[idx],
+             "label": labels[idx],
+             **hyperparameters,
+         }
+
+         # Add input if present
+         if inputs is not None and inputs[idx] is not None:
+             item_result["input"] = inputs[idx]
+
+         # Add item-level metric scores
+         for metric_score in metric_scores:
+             if idx < len(metric_score.item_scores):
+                 item_result[metric_score.metric_name] = metric_score.item_scores[idx]
+
+         item_results.append(item_result)
+
+     # Always return both aggregates and items
+     return {
+         "aggregate_results": [aggregate_result],
+         "item_results": item_results,
+     }
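
The helpers above assume only that a metric exposes a name attribute and a score(outputs, labels) method returning a pair of (aggregate_scores dict, item_scores list). A custom sync metric along those lines might look like the sketch below; since MetricBase itself is not part of this diff, its constructor and any abstract methods it declares are assumptions here.

from typing import Any, Dict, List, Tuple

from scorebook.metrics.metric_base import MetricBase


class ExactMatch(MetricBase):  # subclassing details assumed; MetricBase is not shown in this diff
    name = "exact_match"

    def score(self, outputs: List[Any], labels: List[Any]) -> Tuple[Dict[str, float], List[float]]:
        # Per-item scores: 1.0 for an exact match, 0.0 otherwise.
        item_scores = [1.0 if output == label else 0.0 for output, label in zip(outputs, labels)]
        # Aggregate scores keyed by the metric name, matching how format_results()
        # flattens aggregate_scores into the single aggregate row.
        aggregate_scores = {"exact_match": sum(item_scores) / len(item_scores)}
        return aggregate_scores, item_scores


# An instance takes the isinstance(metric, MetricBase) branch in resolve_metrics(),
# so no registry lookup is needed:
#     results = score(items, metrics=ExactMatch())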