scorebook 0.0.2__tar.gz → 0.0.3__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. {scorebook-0.0.2 → scorebook-0.0.3}/PKG-INFO +1 -1
  2. {scorebook-0.0.2 → scorebook-0.0.3}/pyproject.toml +1 -1
  3. scorebook-0.0.3/src/scorebook/evaluator.py +379 -0
  4. scorebook-0.0.3/src/scorebook/exceptions.py +54 -0
  5. {scorebook-0.0.2 → scorebook-0.0.3}/src/scorebook/inference/openai.py +75 -37
  6. {scorebook-0.0.2 → scorebook-0.0.3}/src/scorebook/types/__init__.py +2 -1
  7. {scorebook-0.0.2 → scorebook-0.0.3}/src/scorebook/types/eval_dataset.py +6 -0
  8. scorebook-0.0.3/src/scorebook/types/eval_run_spec.py +28 -0
  9. {scorebook-0.0.2 → scorebook-0.0.3}/src/scorebook/types/inference_pipeline.py +2 -2
  10. scorebook-0.0.3/src/scorebook/utils/logging_utils.py +1 -0
  11. scorebook-0.0.3/src/scorebook/utils/progress_bars.py +146 -0
  12. scorebook-0.0.2/src/scorebook/evaluator.py +0 -271
  13. scorebook-0.0.2/src/scorebook/utils/progress_bars.py +0 -89
  14. {scorebook-0.0.2 → scorebook-0.0.3}/LICENSE +0 -0
  15. {scorebook-0.0.2 → scorebook-0.0.3}/README.md +0 -0
  16. {scorebook-0.0.2 → scorebook-0.0.3}/src/scorebook/__init__.py +0 -0
  17. {scorebook-0.0.2 → scorebook-0.0.3}/src/scorebook/inference/__init__.py +0 -0
  18. {scorebook-0.0.2 → scorebook-0.0.3}/src/scorebook/inference/bedrock.py +0 -0
  19. {scorebook-0.0.2 → scorebook-0.0.3}/src/scorebook/inference/portkey.py +0 -0
  20. {scorebook-0.0.2 → scorebook-0.0.3}/src/scorebook/inference/vertex.py +0 -0
  21. {scorebook-0.0.2 → scorebook-0.0.3}/src/scorebook/metrics/__init__.py +0 -0
  22. {scorebook-0.0.2 → scorebook-0.0.3}/src/scorebook/metrics/accuracy.py +0 -0
  23. {scorebook-0.0.2 → scorebook-0.0.3}/src/scorebook/metrics/metric_base.py +0 -0
  24. {scorebook-0.0.2 → scorebook-0.0.3}/src/scorebook/metrics/metric_registry.py +0 -0
  25. {scorebook-0.0.2 → scorebook-0.0.3}/src/scorebook/metrics/precision.py +0 -0
  26. {scorebook-0.0.2 → scorebook-0.0.3}/src/scorebook/types/eval_result.py +0 -0
  27. {scorebook-0.0.2 → scorebook-0.0.3}/src/scorebook/utils/__init__.py +0 -0
  28. {scorebook-0.0.2 → scorebook-0.0.3}/src/scorebook/utils/async_utils.py +0 -0
  29. {scorebook-0.0.2 → scorebook-0.0.3}/src/scorebook/utils/build_prompt.py +0 -0
  30. {scorebook-0.0.2 → scorebook-0.0.3}/src/scorebook/utils/io_helpers.py +0 -0
  31. {scorebook-0.0.2 → scorebook-0.0.3}/src/scorebook/utils/jinja_helpers.py +0 -0
  32. {scorebook-0.0.2 → scorebook-0.0.3}/src/scorebook/utils/mappers.py +0 -0
  33. {scorebook-0.0.2 → scorebook-0.0.3}/src/scorebook/utils/transform_helpers.py +0 -0
{scorebook-0.0.2 → scorebook-0.0.3}/PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: scorebook
- Version: 0.0.2
+ Version: 0.0.3
  Summary: A Python project for LLM evaluation.
  Author: Euan Campbell
  Author-email: euan@trismik.com
{scorebook-0.0.2 → scorebook-0.0.3}/pyproject.toml
@@ -14,7 +14,7 @@ dependencies = [
  ]

  [tool.poetry]
- version = "0.0.2" # base version
+ version = "0.0.3" # base version
  packages = [{ include = "scorebook", from = "src" }]


scorebook-0.0.3/src/scorebook/evaluator.py
@@ -0,0 +1,379 @@
+ """
+ Model evaluation functionality for the Scorebook framework.
+
+ This module provides the core evaluation logic to assess model predictions
+ against ground truth labels using configurable metrics. It supports:
+
+ - Batch evaluation of models across multiple datasets
+ - Flexible metric computation and aggregation
+ - Optional parameter sweeping and experiment tracking
+ - Customizable inference functions
+
+ The main entry point is the `evaluate()` function which handles running
+ models on datasets and computing metric scores.
+ """
+
+ import asyncio
+ import logging
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
+
+ from scorebook.exceptions import (
+     DataMismatchError,
+     MetricComputationError,
+     ParallelExecutionError,
+     ParameterValidationError,
+ )
+ from scorebook.types import EvalDataset, EvalResult, EvalRunSpec
+ from scorebook.utils import evaluation_progress, expand_dict, is_awaitable
+
+ logger = logging.getLogger(__name__)
+
+
+ def evaluate(
+     inference_callable: Callable,
+     eval_datasets: Union[str, EvalDataset, List[Union[str, EvalDataset]]],
+     hyperparameters: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
+     experiment_id: Optional[str] = None,
+     project_id: Optional[str] = None,
+     parallel: bool = False,
+     return_dict: bool = True,
+     return_aggregates: bool = True,
+     return_items: bool = False,
+     return_output: bool = False,
+     sample_size: Optional[int] = None,
+ ) -> Union[Dict, List]:
+     """
+     Evaluate model predictions using specified metrics on given datasets.
+
+     This function runs the provided inference callable on one or more evaluation datasets,
+     computes metric scores, and returns the evaluation results. It supports batch processing,
+     parameter sweeping, and different result formatting options.
+
+     Args:
+         inference_callable: A callable function or object that takes (items, hyperparameters)
+             and returns predictions. Can be a regular function, async function,
+             or callable instance (like a class with __call__ method).
+         eval_datasets: One or more evaluation datasets to run evaluation on. Can be:
+             - A single EvalDataset instance
+             - A list of EvalDataset instances
+             - A string identifier (for future dataset registry support)
+             - A list of string identifiers
+         hyperparameters: Optional dictionary containing hyperparameter sweep configuration.
+         experiment_id: Optional string identifier for tracking multiple evaluation runs.
+         return_dict: If True, returns eval results as a dict
+         return_aggregates: If True, returns aggregate scores for each dataset
+         return_items: If True, returns individual items for each dataset
+         return_output: If True, returns model outputs for each dataset item evaluated
+         sample_size: If set, only return a sample of the dataset items (for debugging)
+         parallel: If True, run inference functions in parallel (requires all functions to be async)
+
+     Returns:
+         Dictionary mapping dataset names to their evaluation results. For each dataset,
+         returns a dictionary containing:
+             - items: List of EvalResult objects with predictions and ground truth
+             - metrics: Dictionary mapping metric names to their computed scores
+
+     Example:
+
+         python
+         dataset = EvalDataset.from_huggingface("dataset_name", label="answer", metrics=[Precision])
+         def inference_fn(items):
+             # Model inference logic here - process all items at once
+             return [prediction for item in items]
+
+         results = evaluate(inference_fn, dataset, item_limit=100)
+     """
+
+     logger.info(
+         "Starting evaluation: experiment_id=%s, project_id=%s, parallel=%s",
+         experiment_id,
+         project_id,
+         parallel,
+     )
+
+     return asyncio.run(
+         _evaluate_async(
+             inference_callable=inference_callable,
+             eval_datasets=eval_datasets,
+             hyperparameters=hyperparameters,
+             experiment_id=experiment_id,
+             project_id=project_id,
+             parallel=parallel,
+             return_dict=return_dict,
+             return_aggregates=return_aggregates,
+             return_items=return_items,
+             return_output=return_output,
+             sample_size=sample_size,
+         )
+     )
+
+
+ async def _evaluate_async(
+     inference_callable: Callable,
+     eval_datasets: Union[str, EvalDataset, List[Union[str, EvalDataset]]],
+     hyperparameters: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
+     experiment_id: Optional[str] = None,
+     project_id: Optional[str] = None,
+     return_dict: bool = True,
+     return_aggregates: bool = True,
+     return_items: bool = False,
+     return_output: bool = False,
+     parallel: bool = False,
+     sample_size: Optional[int] = None,
+ ) -> Union[Dict, List]:
+     _validate_parameters(locals())
+     datasets, adaptive_datasets = _prepare_datasets(eval_datasets, sample_size)
+     hyperparameters = _prepare_hyperparameters(hyperparameters)
+
+     logger.info(
+         "Prepared %d datasets and %d hyperparameter configurations",
+         len(datasets),
+         len(hyperparameters),
+     )
+
+     runs = _build_runs(datasets, hyperparameters)
+     runs.sort(key=lambda run: (run.dataset_idx, run.hp_idx))
+
+     logger.info("Created %d evaluation runs", len(runs))
+
+     with evaluation_progress(datasets, len(hyperparameters), parallel, len(runs)) as progress_bars:
+         if parallel:
+             eval_results = await _run_parallel(inference_callable, runs, progress_bars)
+         else:
+             eval_results = await _run_sequential(inference_callable, runs, progress_bars)
+
+     logger.info("Evaluation completed successfully")
+
+     return _format_results(
+         eval_results, return_dict, return_aggregates, return_items, return_output
+     )
+
+
+ # ===== ORCHESTRATION PATHS =====
+
+
+ async def _run_parallel(
+     inference_callable: Callable,
+     runs: List[EvalRunSpec],
+     progress_bars: Any,
+ ) -> List[EvalResult]:
+     logger.debug("Running inference in parallel")
+
+     async def worker(run: EvalRunSpec) -> Tuple[EvalRunSpec, EvalResult]:
+         er = await _execute_run(inference_callable, run)
+         progress_bars.on_eval_run_completed(run.dataset_idx)
+         return run, er
+
+     pairs = await asyncio.gather(*[worker(r) for r in runs])
+     # Return in canonical (dataset_idx, hp_idx) order for stability
+     pairs.sort(key=lambda p: (p[0].dataset_idx, p[0].hp_idx))
+     return [er for _, er in pairs]
+
+
+ async def _run_sequential(
+     inference_callable: Callable,
+     runs: List[EvalRunSpec],
+     progress_bars: Any,
+ ) -> List[EvalResult]:
+     logger.debug("Running inference sequentially")
+     results: List[EvalResult] = []
+     for run in runs:
+         er = await _execute_run(inference_callable, run)
+         results.append(er)
+         progress_bars.on_hyperparam_completed(run.dataset_idx)
+     return results
+
+
+ # ===== EVALUATION EXECUTIONS =====
+
+
+ async def _execute_run(inference_callable: Callable, run: EvalRunSpec) -> EvalResult:
+     logger.debug("Executing run for %s", run)
+
+     outputs = await _run_inference_callable(inference_callable, run.items, run.hyperparams)
+     logger.debug("Inference completed for run %s", run)
+
+     metric_scores = _score_metrics(run.eval_dataset, outputs, run.labels)
+     logger.debug("Metrics computed for run %s. - scores: %s", run, list(metric_scores.keys()))
+
+     return EvalResult(run.eval_dataset, outputs, metric_scores, run.hyperparams)
+
+
+ # ===== HELPER FUNCTIONS =====
+
+
+ def _validate_parameters(params: Dict[str, Any]) -> None:
+     """Validate all parameters for evaluation."""
+
+     if params["return_dict"] and not params["return_aggregates"] and not params["return_items"]:
+         raise ParameterValidationError(
+             "When return_dict=True, at least one of return_aggregates or return_items must be True"
+         )
+
+     if params["parallel"] and not is_awaitable(params["inference_callable"]):
+         raise ParallelExecutionError(
+             "parallel=True requires the inference_callable to be async. "
+             "Please make your inference function async or set parallel=False."
+         )
+
+
+ def _prepare_datasets(
+     datasets: Union[str, EvalDataset, List[Union[str, EvalDataset]]],
+     sample_size: Optional[int] = None,
+ ) -> Tuple[List[EvalDataset], List[str]]:
+     """Prepare and separate input datasets into classic and adaptive evaluation datasets."""
+
+     # Ensure datasets is always a list for consistent processing
+     if not isinstance(datasets, list):
+         datasets = [datasets]
+
+     # Extract classical datasets TODO: handle other types (string registry)
+     classic_eval_datasets = [dataset for dataset in datasets if isinstance(dataset, EvalDataset)]
+
+     # Reduce datasets to a random sample
+     if sample_size:
+         logger.info("Sampling datasets to %d items each", sample_size)
+         for dataset in classic_eval_datasets:
+             dataset.shuffle()
+             if len(dataset) > sample_size:
+                 original_size = len(dataset)
+                 dataset._hf_dataset = dataset._hf_dataset.select(range(sample_size))
+                 logger.debug(
+                     "Sampled dataset '%s' from %d to %d items",
+                     dataset.name,
+                     original_size,
+                     sample_size,
+                 )
+
+     # Extract adaptive dataset strings
+     adaptive_eval_datasets = [
+         dataset.replace(":adaptive", "")
+         for dataset in datasets
+         if isinstance(dataset, str) and dataset.endswith(":adaptive")
+     ]
+
+     logger.info("Evaluating on classic datasets: %s", [ds.name for ds in classic_eval_datasets])
+     logger.info("Evaluating on adaptive datasets: %s", adaptive_eval_datasets)
+
+     return classic_eval_datasets, adaptive_eval_datasets
+
+
+ def _prepare_hyperparameters(
+     hyperparameters: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]]
+ ) -> List[Dict[str, Any]]:
+     """Prepare hyperparameters for evaluation by returning a list of hyper-param configs."""
+     if hyperparameters is None:
+         return [{}]
+     if not isinstance(hyperparameters, list):
+         expanded: List[Dict[str, Any]] = expand_dict(hyperparameters or {})
+         return expanded
+
+     logger.info("Evaluating with hyperparameters: %s", hyperparameters)
+
+     return hyperparameters
+
+
+ async def _run_inference_callable(
+     inference_callable: Callable,
+     items: List[Dict[str, Any]],
+     hyperparams: Dict[str, Any],
+ ) -> Any:
+     if is_awaitable(inference_callable):
+         return await inference_callable(items, **hyperparams)
+     else:
+         return inference_callable(items, **hyperparams)
+
+
+ def _build_runs(
+     datasets: List[EvalDataset],
+     hyperparameters: List[Dict[str, Any]],
+ ) -> List[EvalRunSpec]:
+     """Build RunSpec objects for each dataset/hyperparameter combination."""
+     runs: List[EvalRunSpec] = []
+     for d_idx, ds in enumerate(datasets):
+         items = ds.items
+         labels = [item.get(ds.label) for item in items]
+         for hp_idx, hp in enumerate(hyperparameters):
+             run_spec = EvalRunSpec(d_idx, ds, items, labels, hp, hp_idx)
+             logger.debug("Built RunSpec: %s", run_spec)
+             runs.append(run_spec)
+     return runs
+
+
+ def _score_metrics(
+     eval_dataset: EvalDataset, outputs: List[Any], labels: List[Any]
+ ) -> Dict[str, Dict[str, Any]]:
+     """Compute metric scores for a given dataset and inference outputs."""
+     metric_scores: Dict[str, Dict[str, Any]] = {}
+
+     if len(outputs) != len(labels):
+         raise DataMismatchError(len(outputs), len(labels), eval_dataset.name)
+
+     for metric in eval_dataset.metrics:
+         try:
+             aggregate_scores, item_scores = metric.score(outputs, labels)
+             metric_scores[metric.name] = {
+                 "aggregate_scores": aggregate_scores,
+                 "item_scores": item_scores,
+             }
+         except Exception as e:
+             logger.error(
+                 "Failed to compute metric '%s' for dataset '%s': %s",
+                 metric.name,
+                 eval_dataset.name,
+                 str(e),
+             )
+             raise MetricComputationError(metric.name, eval_dataset.name, e)
+
+     return metric_scores
+
+
+ def _format_results(
+     eval_results: List[EvalResult],
+     return_dict: bool,
+     return_aggregates: bool,
+     return_items: bool,
+     return_output: bool,
+ ) -> Union[Dict, List]:
+
+     # Return results as a dict
+     if return_dict:
+
+         # Include both aggregate and item scores in dict returned
+         if return_aggregates and return_items:
+             results: Dict[str, List[Dict[str, Any]]] = {"aggregate_results": [], "item_results": []}
+             for eval_result in eval_results:
+                 eval_result_dict = eval_result.to_dict()
+                 results["aggregate_results"].extend(eval_result_dict["aggregate_results"])
+                 if return_output:
+                     results["item_results"].extend(eval_result_dict["item_results"])
+                 else:
+                     results["item_results"].extend(
+                         [
+                             {k: v for k, v in item.items() if k != "inference_output"}
+                             for item in eval_result_dict["item_results"]
+                         ]
+                     )
+             return results
+
+         # Include only aggregate scores in dict returned
+         elif return_aggregates:
+             return [eval_result.aggregate_scores for eval_result in eval_results]
+
+         # Include only item scores in dict returned
+         else:
+             if return_output:
+                 return [item for eval_result in eval_results for item in eval_result.item_scores]
+             else:
+                 return [
+                     {k: v for k, v in item.items() if k != "inference_output"}
+                     for eval_result in eval_results
+                     for item in eval_result.item_scores
+                 ]
+
+     # Return results as an EvalResult object
+     else:
+         out: Dict[str, List[EvalResult]] = {}
+         for er in eval_results:
+             out.setdefault(er.eval_dataset.name, []).append(er)
+         return out
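Note: for orientation, here is a minimal, hypothetical sketch of how the new evaluate() entry point above can be driven. The dataset name, label field, metric import path, and the re-export of evaluate from the package root are assumptions for illustration, not values taken from this diff.

    from scorebook import evaluate                      # assumed re-export in scorebook/__init__.py
    from scorebook.metrics.precision import Precision   # import path assumed
    from scorebook.types import EvalDataset

    # Hypothetical dataset with an "answer" label column.
    dataset = EvalDataset.from_huggingface("demo/qa", label="answer", metrics=[Precision])

    async def inference_fn(items, **hyperparameters):
        # Stand-in model call: return one prediction per item.
        return ["4" for _ in items]

    results = evaluate(
        inference_fn,
        dataset,
        hyperparameters={"temperature": [0.0, 0.7]},  # assumed to expand into two sweep configs
        parallel=True,                                # requires an async inference callable
        sample_size=50,
    )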
scorebook-0.0.3/src/scorebook/exceptions.py
@@ -0,0 +1,54 @@
+ """
+ Custom exceptions for the Scorebook framework.
+
+ This module defines specific exception types used throughout the Scorebook
+ evaluation framework to provide clear error handling and debugging information.
+ """
+
+
+ class ScoreBookError(Exception):
+     """Base exception class for all Scorebook-related errors."""
+
+
+ class EvaluationError(ScoreBookError):
+     """Raised when there are errors during model evaluation."""
+
+
+ class ParameterValidationError(ScoreBookError):
+     """Raised when invalid parameters are provided to evaluation functions."""
+
+
+ class InferenceError(EvaluationError):
+     """Raised when there are errors during model inference."""
+
+
+ class MetricComputationError(EvaluationError):
+     """Raised when metric computation fails."""
+
+     def __init__(self, metric_name: str, dataset_name: str, original_error: Exception):
+         """Initialize metric computation error."""
+         self.metric_name = metric_name
+         self.dataset_name = dataset_name
+         self.original_error = original_error
+         super().__init__(
+             f"Failed to compute metric '{metric_name}' for dataset "
+             f"'{dataset_name}': {original_error}"
+         )
+
+
+ class DataMismatchError(EvaluationError):
+     """Raised when there's a mismatch between outputs and expected labels."""
+
+     def __init__(self, outputs_count: int, labels_count: int, dataset_name: str):
+         """Initialize data mismatch error."""
+         self.outputs_count = outputs_count
+         self.labels_count = labels_count
+         self.dataset_name = dataset_name
+         super().__init__(
+             f"Output count ({outputs_count}) doesn't match label count ({labels_count}) "
+             f"for dataset '{dataset_name}'"
+         )
+
+
+ class ParallelExecutionError(ScoreBookError):
+     """Raised when parallel execution requirements are not met."""
{scorebook-0.0.2 → scorebook-0.0.3}/src/scorebook/inference/openai.py
@@ -8,17 +8,19 @@ API communication, request formatting, and response processing.

  import asyncio
  import json
+ import logging
  import tempfile
  from typing import Any, List

- from openai import OpenAI
- from tqdm.asyncio import tqdm
+ from openai import AsyncOpenAI
+
+ logger = logging.getLogger(__name__)


  async def responses(
      items: List[Any], model: str = "gpt-4.1-nano", client: Any = None, **hyperparameters: Any
  ) -> List[Any]:
-     """Process multiple inference requests using OpenAI's API.
+     """Process multiple inference requests using OpenAI's Async API.

      This asynchronous function handles multiple inference requests,
      manages the API communication, and processes the responses.
@@ -35,13 +37,67 @@ async def responses(
      Raises:
          NotImplementedError: Currently not implemented.
      """
-     if client is None:
-         client = OpenAI()
+     logger.debug("OpenAI responses function called with %d items", len(items))
+     logger.debug("Using model: %s", model)
+     logger.debug("Hyperparameters: %s", hyperparameters)

-     results = []
-     for item in items:
-         response = client.responses.create(model=model, input=item)
-         results.append(response)
+     if client is None:
+         logger.debug("Creating new AsyncOpenAI client")
+         client = AsyncOpenAI()
+
+     # Create all tasks concurrently for true parallelism
+     tasks = []
+     for i, item in enumerate(items):
+         logger.debug(
+             "Processing item %d: %s",
+             i,
+             str(item)[:100] + "..." if len(str(item)) > 100 else str(item),
+         )
+
+         # Handle string input from preprocessor - convert to proper messages format
+         if isinstance(item, str):
+             # Convert the string format to proper OpenAI messages array
+             messages = [{"role": "user", "content": item}]
+             logger.debug(
+                 "Converted string to messages format: %s",
+                 (
+                     messages[0]["content"][:100] + "..."
+                     if len(messages[0]["content"]) > 100
+                     else messages[0]["content"]
+                 ),
+             )
+         elif isinstance(item, list):
+             # Already in proper messages format
+             messages = item
+             logger.debug("Item %d already in messages format", i)
+         else:
+             # Fallback: treat as user message
+             messages = [{"role": "user", "content": str(item)}]
+             logger.debug("Item %d converted to fallback format", i)
+
+         logger.debug("Creating OpenAI task %d with messages: %s", i, messages)
+         task = client.chat.completions.create(model=model, messages=messages, **hyperparameters)
+         tasks.append(task)
+
+     logger.debug("Created %d tasks, waiting for OpenAI responses...", len(tasks))
+     # Wait for all requests to complete in parallel
+     results = await asyncio.gather(*tasks)
+     logger.debug("Received %d responses from OpenAI", len(results))
+
+     for i, result in enumerate(results):
+         logger.debug("Response %d type: %s", i, type(result))
+         try:
+             if hasattr(result, "choices") and result.choices:
+                 content = result.choices[0].message.content
+                 logger.debug(
+                     "Response %d content: %s",
+                     i,
+                     content[:100] + "..." if content and len(content) > 100 else content,
+                 )
+             else:
+                 logger.debug("Response %d has no choices or unexpected format", i)
+         except Exception as e:
+             logger.error("Error logging response %d: %s", i, e)

      return results

@@ -70,40 +126,23 @@ async def batch(
          NotImplementedError: Currently not implemented.
      """
      if client is None:
-         client = OpenAI()
+         client = AsyncOpenAI()

-     file_id = _upload_batch(items, client)
-     batch_id = _start_batch(file_id, client)
-
-     # Initialize progress bar
-     pbar = tqdm(total=len(items), desc="Batch processing", unit="requests")
+     file_id = await _upload_batch(items, client)
+     batch_id = await _start_batch(file_id, client)

      awaiting_batch = True
      while awaiting_batch:
          batch_object = await _get_batch(batch_id, client)
          batch_status = batch_object.status

-         if hasattr(batch_object, "request_counts") and batch_object.request_counts:
-             completed = batch_object.request_counts.completed
-             total = batch_object.request_counts.total
-             pbar.n = completed
-             pbar.set_postfix(status=batch_status, completed=f"{completed}/{total}")
-         else:
-             pbar.set_postfix(status=batch_status)
-
-         pbar.refresh()
-
          if batch_status == "completed":
              awaiting_batch = False
-             pbar.n = pbar.total
-             pbar.set_postfix(status="completed")
          elif batch_status == "failed":
              raise Exception("Batch processing failed")
          else:
              await asyncio.sleep(60)

-     pbar.close()
-
      # Get the final batch object to access output_file_id
      final_batch_object = await _get_batch(batch_id, client)
@@ -112,7 +151,7 @@ async def batch(
      return batch_result


- def _upload_batch(items: List[Any], client: Any) -> str:
+ async def _upload_batch(items: List[Any], client: Any) -> str:
      """Create a .jsonl file from preprocessed items and upload to OpenAI for batch processing.

      Args:
@@ -121,10 +160,9 @@ def _upload_batch(items: List[Any], client: Any) -> str:
      Returns:
          The file ID returned by OpenAI after uploading.
      """
-     print("Uploading batch...")
      # Instantiate OpenAI client
      if client is None:
-         client = OpenAI()
+         client = AsyncOpenAI()

      # Create temp .jsonl file
      with tempfile.NamedTemporaryFile(mode="w+", suffix=".jsonl", delete=False) as f:
@@ -141,13 +179,13 @@ def _upload_batch(items: List[Any], client: Any) -> str:

      # Upload file to OpenAI
      with open(file_path, "rb") as upload_file:
-         response = client.files.create(file=upload_file, purpose="batch")
+         response = await client.files.create(file=upload_file, purpose="batch")

      return str(response.id)


- def _start_batch(file_id: str, client: Any) -> str:
-     batch_response = client.batches.create(
+ async def _start_batch(file_id: str, client: Any) -> str:
+     batch_response = await client.batches.create(
          input_file_id=file_id,
          endpoint="/v1/chat/completions",
          completion_window="24h",
@@ -156,13 +194,13 @@ def _start_batch(file_id: str, client: Any) -> str:


  async def _get_batch(batch_id: str, client: Any) -> Any:
-     batch_object = client.batches.retrieve(batch_id)
+     batch_object = await client.batches.retrieve(batch_id)
      return batch_object


  async def _get_results_file(output_file_id: str, client: Any) -> List[str]:
      """Download and parse the batch results file from OpenAI."""
-     response = client.files.content(output_file_id)
+     response = await client.files.content(output_file_id)

      # Parse the JSONL content
      content = response.content.decode("utf-8")
{scorebook-0.0.2 → scorebook-0.0.3}/src/scorebook/types/__init__.py
@@ -7,5 +7,6 @@ and evaluation results.

  from scorebook.types.eval_dataset import EvalDataset
  from scorebook.types.eval_result import EvalResult
+ from scorebook.types.eval_run_spec import EvalRunSpec

- __all__ = ["EvalDataset", "EvalResult"]
+ __all__ = ["EvalDataset", "EvalResult", "EvalRunSpec"]
{scorebook-0.0.2 → scorebook-0.0.3}/src/scorebook/types/eval_dataset.py
@@ -86,6 +86,12 @@ class EvalDataset:
              raise ValueError("Dataset is not initialized")
          return iter(self._hf_dataset)

+     def shuffle(self) -> None:
+         """Randomly shuffle the dataset items."""
+         if self._hf_dataset is None:
+             raise ValueError("Dataset is not initialized")
+         self._hf_dataset.shuffle()
+
      @property
      def items(self) -> List[Any]:
          """Return a list of all examples in the dataset."""
scorebook-0.0.3/src/scorebook/types/eval_run_spec.py
@@ -0,0 +1,28 @@
+ """Evaluation run specification types for Scorebook."""
+
+ from typing import Any, Dict, List, NamedTuple
+
+ from scorebook.types import EvalDataset
+
+
+ class EvalRunSpec(NamedTuple):
+     """Represents a single evaluation run configuration."""
+
+     dataset_idx: int
+     eval_dataset: EvalDataset
+     items: List[Dict[str, Any]]
+     labels: List[Any]
+     hyperparams: Dict[str, Any]
+     hp_idx: int
+
+     def __str__(self) -> str:
+         """Return a formatted string summary of the evaluation run specification."""
+         hyperparams_str = ", ".join([f"{k}={v}" for k, v in self.hyperparams.items()])
+
+         return (
+             f"EvalRunSpec(dataset_idx={self.dataset_idx},"
+             f" hp_idx={self.hp_idx},"
+             f" dataset_name='{self.eval_dataset.name}',"
+             f" hyperparams=[{hyperparams_str}]"
+             f")"
+         )
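Note: a small, hypothetical sketch of constructing an EvalRunSpec directly to see its string summary; since NamedTuple fields are not validated at runtime, a stand-in object replaces a real EvalDataset here.

    from types import SimpleNamespace
    from scorebook.types import EvalRunSpec

    fake_dataset = SimpleNamespace(name="demo")  # stand-in; real code passes an EvalDataset
    spec = EvalRunSpec(
        dataset_idx=0,
        eval_dataset=fake_dataset,
        items=[{"question": "2+2?", "answer": "4"}],
        labels=["4"],
        hyperparams={"temperature": 0.0},
        hp_idx=0,
    )
    print(spec)
    # EvalRunSpec(dataset_idx=0, hp_idx=0, dataset_name='demo', hyperparams=[temperature=0.0])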
{scorebook-0.0.2 → scorebook-0.0.3}/src/scorebook/types/inference_pipeline.py
@@ -57,7 +57,7 @@ class InferencePipeline:
              List of processed outputs after running through the complete pipeline
          """
          if self.preprocessor:
-             input_items = [self.preprocessor(item, hyperparameters) for item in items]
+             input_items = [self.preprocessor(item, **hyperparameters) for item in items]
          else:
              input_items = items

@@ -68,7 +68,7 @@ class InferencePipeline:

          if self.postprocessor:
              return [
-                 self.postprocessor(inference_output, hyperparameters)
+                 self.postprocessor(inference_output, **hyperparameters)
                  for inference_output in inference_outputs
              ]
          else:
scorebook-0.0.3/src/scorebook/utils/logging_utils.py
@@ -0,0 +1 @@
+ """Logging utilities for Scorebook evaluation framework."""
scorebook-0.0.3/src/scorebook/utils/progress_bars.py
@@ -0,0 +1,146 @@
+ """Progress bar utilities for evaluation tracking."""
+
+ from contextlib import contextmanager
+ from typing import Any, Generator, List, Optional
+
+ from tqdm import tqdm
+
+
+ class EvaluationProgressBars:
+     """Manages nested progress bars for evaluation tracking."""
+
+     def __init__(
+         self, datasets: List[Any], hyperparam_count: int, parallel: bool, total_eval_runs: int
+     ) -> None:
+         """Initialize progress bar manager.
+
+         Args:
+             datasets: List of datasets being evaluated
+             hyperparam_count: Number of hyperparameter configurations per dataset
+             parallel: Whether running in parallel mode
+             total_eval_runs: Total number of EvalRunSpecs (dataset_count * hyperparam_count)
+         """
+         self.datasets = datasets
+         self.hyperparam_count = hyperparam_count
+         self.parallel = parallel
+         self.total_eval_runs = total_eval_runs
+
+         self.dataset_pbar: Optional[tqdm] = None
+         self.hyperparam_pbar: Optional[tqdm] = None
+
+         # Track progress per dataset
+         self.current_dataset_idx = 0
+         self.completed_hyperparams_per_dataset: dict[int, int] = {}
+         self.completed_eval_runs = 0
+
+     def start_progress_bars(self) -> None:
+         """Start both progress bars."""
+         # Top level: Datasets
+         self.dataset_pbar = tqdm(
+             total=len(self.datasets),
+             desc="Datasets ",
+             unit="dataset",
+             position=0,
+             leave=True,
+             ncols=80,
+             bar_format="{desc} {percentage:3.0f}%|{bar:40}| {n_fmt}/{total_fmt}",
+         )
+
+         # Bottom level: Hyperparameters/Eval runs
+         if self.parallel:
+             # In parallel mode: show eval runs completed out of total
+             self.hyperparam_pbar = tqdm(
+                 total=self.total_eval_runs,
+                 desc="Eval Runs ",
+                 unit="run",
+                 position=1,
+                 leave=False,
+                 ncols=80,
+                 bar_format="{desc} {percentage:3.0f}%|{bar:40}| {n_fmt}/{total_fmt}",
+             )
+         else:
+             # In sequential mode: show hyperparams per dataset
+             self.hyperparam_pbar = tqdm(
+                 total=self.hyperparam_count,
+                 desc="Hyperparams",
+                 unit="config",
+                 position=1,
+                 leave=False,
+                 ncols=80,
+                 bar_format="{desc} {percentage:3.0f}%|{bar:40}| {n_fmt}/{total_fmt}",
+             )
+
+     def on_eval_run_completed(self, dataset_idx: int) -> None:
+         """Update progress when an eval run (EvalRunSpec) completes in parallel mode."""
+         if not self.parallel:
+             return
+
+         self.completed_eval_runs += 1
+         if self.hyperparam_pbar:
+             self.hyperparam_pbar.update(1)
+
+         # Track how many runs completed for this dataset
+         self.completed_hyperparams_per_dataset[dataset_idx] = (
+             self.completed_hyperparams_per_dataset.get(dataset_idx, 0) + 1
+         )
+
+         # Check if this dataset is complete
+         if self.completed_hyperparams_per_dataset[dataset_idx] == self.hyperparam_count:
+             if self.dataset_pbar:
+                 self.dataset_pbar.update(1)
+
+     def on_hyperparam_completed(self, dataset_idx: int) -> None:
+         """Update progress when a hyperparameter config completes in sequential mode."""
+         if self.parallel:
+             return
+
+         if self.hyperparam_pbar:
+             self.hyperparam_pbar.update(1)
+
+         # Track completed hyperparams for this dataset
+         self.completed_hyperparams_per_dataset[dataset_idx] = (
+             self.completed_hyperparams_per_dataset.get(dataset_idx, 0) + 1
+         )
+
+         # Check if this dataset is complete
+         if self.completed_hyperparams_per_dataset[dataset_idx] == self.hyperparam_count:
+             # Update dataset progress
+             if self.dataset_pbar:
+                 self.dataset_pbar.update(1)
+
+             # Reset hyperparameter progress for next dataset (if any)
+             if dataset_idx < len(self.datasets) - 1:
+                 if self.hyperparam_pbar:
+                     self.hyperparam_pbar.reset()
+
+     def close_progress_bars(self) -> None:
+         """Close both progress bars."""
+         if self.hyperparam_pbar:
+             self.hyperparam_pbar.close()
+             self.hyperparam_pbar = None
+         if self.dataset_pbar:
+             self.dataset_pbar.close()
+             self.dataset_pbar = None
+
+
+ @contextmanager
+ def evaluation_progress(
+     datasets: List[Any], hyperparam_count: int, parallel: bool, total_eval_runs: int
+ ) -> Generator[EvaluationProgressBars, None, None]:
+     """Context manager for evaluation progress bars.
+
+     Args:
+         datasets: List of datasets being evaluated
+         hyperparam_count: Number of hyperparameter configurations per dataset
+         parallel: Whether running in parallel mode
+         total_eval_runs: Total number of EvalRunSpecs
+
+     Yields:
+         EvaluationProgressBars: Progress bar manager instance
+     """
+     progress_bars = EvaluationProgressBars(datasets, hyperparam_count, parallel, total_eval_runs)
+     progress_bars.start_progress_bars()
+     try:
+         yield progress_bars
+     finally:
+         progress_bars.close_progress_bars()
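Note: a minimal, hypothetical sketch of driving the evaluation_progress context manager added above in sequential mode; the dataset names are stand-ins (evaluate() passes EvalDataset objects) and the sleep is a placeholder for one inference-plus-scoring run.

    import time
    from scorebook.utils import evaluation_progress

    datasets = ["dataset_a", "dataset_b"]  # stand-ins for EvalDataset instances
    hyperparam_count = 3

    with evaluation_progress(
        datasets, hyperparam_count, parallel=False,
        total_eval_runs=len(datasets) * hyperparam_count,
    ) as bars:
        for dataset_idx, _ in enumerate(datasets):
            for _ in range(hyperparam_count):
                time.sleep(0.1)  # placeholder for an actual evaluation run
                bars.on_hyperparam_completed(dataset_idx)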
scorebook-0.0.2/src/scorebook/evaluator.py
@@ -1,271 +0,0 @@
- """
- Model evaluation functionality for the Scorebook framework.
-
- This module provides the core evaluation logic to assess model predictions
- against ground truth labels using configurable metrics. It supports:
-
- - Batch evaluation of models across multiple datasets
- - Flexible metric computation and aggregation
- - Optional parameter sweeping and experiment tracking
- - Customizable inference functions
-
- The main entry point is the `evaluate()` function which handles running
- models on datasets and computing metric scores.
- """
-
- import asyncio
- from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
-
- from scorebook.types.eval_dataset import EvalDataset
- from scorebook.types.eval_result import EvalResult
- from scorebook.utils import evaluation_progress, expand_dict, is_awaitable
-
-
- async def _evaluate_async(
-     inference_callable: Callable,
-     eval_datasets: Union[str, EvalDataset, List[Union[str, EvalDataset]]],
-     hyperparameters: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
-     experiment_id: Optional[str] = None,
-     return_dict: bool = True,
-     return_aggregates: bool = True,
-     return_items: bool = False,
-     return_output: bool = False,
-     sample_size: Optional[int] = None,
- ) -> Union[Dict, List]:
-     """Run inference across datasets/hyperparams, compute metrics, and format results."""
-
-     # Validate parameters
-     if return_dict and not return_aggregates and not return_items:
-         raise ValueError(
-             "When return_dict=True, at least one of return_aggregates or return_items must be True"
-         )
-
-     normalized_datasets = _normalize_datasets(eval_datasets)
-
-     if hyperparameters is None:
-         hyperparam_grid: List[Dict[str, Any]] = [{}]
-     elif not isinstance(hyperparameters, list):
-         hyperparam_grid = _expand_hyperparams(hyperparameters)
-     else:
-         hyperparam_grid = hyperparameters
-
-     eval_results: List[EvalResult] = []
-
-     with evaluation_progress(normalized_datasets, len(hyperparam_grid)) as progress_bars:
-         # Loop through datasets, then hyperparameters for clear progress tracking
-         for dataset_idx, eval_dataset in enumerate(normalized_datasets):
-             with progress_bars.hyperparam_progress_context():
-                 # Run inference for each hyperparameter configuration on this dataset
-                 for hp_idx, hyperparam_config in enumerate(hyperparam_grid):
-
-                     if sample_size:
-                         items = _get_items_sample(eval_dataset.items, sample_size)
-                     else:
-                         items = eval_dataset.items
-
-                     labels = _get_labels_for_items(items, eval_dataset.label)
-
-                     # 1) Run inference
-                     outputs = await _run_inference_callable(
-                         inference_callable, items, hyperparam_config
-                     )
-
-                     # 2) Score metrics
-                     metric_scores = _score_metrics(eval_dataset, outputs, labels)
-
-                     # 3) Wrap into EvalResult
-                     eval_results.append(
-                         EvalResult(eval_dataset, outputs, metric_scores, hyperparam_config)
-                     )
-
-                     # Update inner progress bar
-                     progress_bars.update_hyperparam_progress()
-
-             # Update the outer progress bar
-             progress_bars.update_dataset_progress()
-
-     # TODO: experiment_id handling (left as passthrough to preserve behavior)
-     if experiment_id:
-         pass
-
-     # 4) Format as requested
-     return _format_results(
-         eval_results, return_dict, return_aggregates, return_items, return_output
-     )
-
-
- def evaluate(
-     inference_callable: Callable,
-     eval_datasets: Union[str, EvalDataset, List[Union[str, EvalDataset]]],
-     hyperparameters: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
-     experiment_id: Optional[str] = None,
-     return_dict: bool = True,
-     return_aggregates: bool = True,
-     return_items: bool = False,
-     return_output: bool = False,
-     sample_size: Optional[int] = None,
- ) -> Union[Dict, List]:
-     """
-     Evaluate model predictions using specified metrics on given datasets.
-
-     This function runs the provided inference callable on one or more evaluation datasets,
-     computes metric scores, and returns the evaluation results. It supports batch processing,
-     parameter sweeping, and different result formatting options.
-
-     Args:
-         inference_callable: A callable function or object that takes (items, hyperparameters)
-             and returns predictions. Can be a regular function, async function,
-             or callable instance (like a class with __call__ method).
-         eval_datasets: One or more evaluation datasets to run evaluation on. Can be:
-             - A single EvalDataset instance
-             - A list of EvalDataset instances
-             - A string identifier (for future dataset registry support)
-             - A list of string identifiers
-         hyperparameters: Optional dictionary containing hyperparameter sweep configuration.
-         experiment_id: Optional string identifier for tracking multiple evaluation runs.
-         return_dict: If True, returns eval results as a dict
-         return_aggregates: If True, returns aggregate scores for each dataset
-         return_items: If True, returns individual items for each dataset
-         return_output: If True, returns model outputs for each dataset item evaluated
-         sample_size: If set, only return a sample of the dataset items (for debugging)
-
-     Returns:
-         Dictionary mapping dataset names to their evaluation results. For each dataset,
-         returns a dictionary containing:
-             - items: List of EvalResult objects with predictions and ground truth
-             - metrics: Dictionary mapping metric names to their computed scores
-
-     Example:
-
-         python
-         dataset = EvalDataset.from_huggingface("dataset_name", label="answer", metrics=[Precision])
-         def inference_fn(items):
-             # Model inference logic here - process all items at once
-             return [prediction for item in items]
-
-         results = evaluate(inference_fn, dataset, item_limit=100)
-     """
-     return asyncio.run(
-         _evaluate_async(
-             inference_callable=inference_callable,
-             eval_datasets=eval_datasets,
-             hyperparameters=hyperparameters,
-             experiment_id=experiment_id,
-             return_dict=return_dict,
-             return_aggregates=return_aggregates,
-             return_items=return_items,
-             return_output=return_output,
-             sample_size=sample_size,
-         )
-     )
-
-
- # ===== Helper Functions =====
-
-
- def _normalize_datasets(
-     datasets: Union[str, EvalDataset, List[Union[str, EvalDataset]]]
- ) -> List[EvalDataset]:
-     if not isinstance(datasets, list):
-         datasets = [datasets]
-     # TODO: handle other types (string registry, etc.)
-     return [d for d in datasets if isinstance(d, EvalDataset)]
-
-
- def _expand_hyperparams(hyperparameters: Optional[Dict[str, Any]]) -> Any:
-     return expand_dict(hyperparameters or {})
-
-
- def _get_items_sample(
-     items: List[Dict[str, Any]], item_limit: Optional[int]
- ) -> List[Dict[str, Any]]:
-     return items[:item_limit] if item_limit else items
-
-
- def _get_labels_for_items(items: List[Dict[str, Any]], label_key: str) -> List[Any]:
-     return [item.get(label_key) for item in items]
-
-
- async def _run_inference_callable(
-     inference_callable: Callable,
-     items: List[Dict[str, Any]],
-     hyperparams: Dict[str, Any],
- ) -> Any:
-     if is_awaitable(inference_callable):
-         return await inference_callable(items, **hyperparams)
-     else:
-         return inference_callable(items, **hyperparams)
-
-
- # Yields (eval_dataset, items, labels, hyperparams) for every dataset x hyperparam combo.
- def _iter_dataset_jobs(
-     datasets: List[EvalDataset],
-     hyperparam_grid: List[Dict[str, Any]],
-     sample_size: Optional[int],
- ) -> Iterable[Tuple[EvalDataset, List[Dict[str, Any]], List[Any], Dict[str, Any]]]:
-     for eval_dataset in datasets:
-         for hp in hyperparam_grid:
-             items = _get_items_sample(eval_dataset.items, sample_size)
-             labels = _get_labels_for_items(items, eval_dataset.label)
-             yield eval_dataset, items, labels, hp
-
-
- def _score_metrics(
-     eval_dataset: EvalDataset, outputs: List[Any], labels: List[Any]
- ) -> Dict[str, Dict[str, Any]]:
-     metric_scores: Dict[str, Dict[str, Any]] = {}
-     for metric in eval_dataset.metrics:
-         aggregate_scores, item_scores = metric.score(outputs, labels)
-         metric_scores[metric.name] = {
-             "aggregate_scores": aggregate_scores,
-             "item_scores": item_scores,
-         }
-     return metric_scores
-
-
- def _format_results(
-     eval_results: List[EvalResult],
-     return_dict: bool,
-     return_aggregates: bool,
-     return_items: bool,
-     return_output: bool,
- ) -> Union[Dict, List]:
-
-     # Return results as a dict
-     if return_dict:
-
-         # Include both aggregate and item scores in dict returned
-         if return_aggregates and return_items:
-             results: Dict[str, List[Dict[str, Any]]] = {"aggregate_results": [], "item_results": []}
-             for eval_result in eval_results:
-                 eval_result_dict = eval_result.to_dict()
-                 results["aggregate_results"].extend(eval_result_dict["aggregate_results"])
-                 if return_output:
-                     results["item_results"].extend(eval_result_dict["item_results"])
-                 else:
-                     results["item_results"].extend(
-                         [
-                             {k: v for k, v in item.items() if k != "inference_output"}
-                             for item in eval_result_dict["item_results"]
-                         ]
-                     )
-             return results
-
-         # Include only aggregate scores in dict returned
-         elif return_aggregates:
-             return [eval_result.aggregate_scores for eval_result in eval_results]
-
-         # Include only item scores in dict returned
-         else:
-             if return_output:
-                 return [item for eval_result in eval_results for item in eval_result.item_scores]
-             else:
-                 return [
-                     {k: v for k, v in item.items() if k != "inference_output"}
-                     for eval_result in eval_results
-                     for item in eval_result.item_scores
-                 ]
-
-     # Return results as an EvalResult object
-     else:
-         return {er.eval_dataset.name: er for er in eval_results}
scorebook-0.0.2/src/scorebook/utils/progress_bars.py
@@ -1,89 +0,0 @@
- """Progress bar utilities for evaluation tracking."""
-
- from contextlib import contextmanager
- from typing import Any, Generator, List, Optional
-
- from tqdm import tqdm
-
-
- class EvaluationProgressBars:
-     """Manages nested progress bars for evaluation tracking."""
-
-     def __init__(self, datasets: List[Any], hyperparam_count: int) -> None:
-         """Initialize progress bar manager.
-
-         Args:
-             datasets: List of datasets being evaluated
-             hyperparam_count: Number of hyperparameter configurations per dataset
-         """
-         self.datasets = datasets
-         self.hyperparam_count = hyperparam_count
-         self.dataset_pbar: Optional[tqdm] = None
-         self.hyperparam_pbar: Optional[tqdm] = None
-
-     def start_dataset_progress(self) -> None:
-         """Start the outer progress bar for datasets."""
-         self.dataset_pbar = tqdm(
-             total=len(self.datasets),
-             desc="Datasets ",
-             unit="dataset",
-             position=0,
-             leave=True,
-             ncols=80,
-             bar_format="{desc} {percentage:3.0f}%|{bar:40}| {n_fmt}/{total_fmt}",
-         )
-
-     def update_dataset_progress(self) -> None:
-         """Update the dataset progress bar."""
-         if self.dataset_pbar:
-             self.dataset_pbar.update(1)
-
-     def close_dataset_progress(self) -> None:
-         """Close the dataset progress bar."""
-         if self.dataset_pbar:
-             self.dataset_pbar.close()
-             self.dataset_pbar = None
-
-     @contextmanager
-     def hyperparam_progress_context(self) -> Generator[tqdm, None, None]:
-         """Context manager for hyperparameter progress bar."""
-         self.hyperparam_pbar = tqdm(
-             total=self.hyperparam_count,
-             desc="Hyperparams",
-             unit="config",
-             position=1,
-             leave=False,
-             ncols=80,
-             bar_format="{desc} {percentage:3.0f}%|{bar:40}| {n_fmt}/{total_fmt}",
-         )
-         try:
-             yield self.hyperparam_pbar
-         finally:
-             self.hyperparam_pbar.close()
-             self.hyperparam_pbar = None
-
-     def update_hyperparam_progress(self) -> None:
-         """Update the hyperparameter progress bar."""
-         if self.hyperparam_pbar:
-             self.hyperparam_pbar.update(1)
-
-
- @contextmanager
- def evaluation_progress(
-     datasets: List[Any], hyperparam_count: int
- ) -> Generator[EvaluationProgressBars, None, None]:
-     """Context manager for evaluation progress bars.
-
-     Args:
-         datasets: List of datasets being evaluated
-         hyperparam_count: Number of hyperparameter configurations per dataset
-
-     Yields:
-         EvaluationProgressBars: Progress bar manager instance
-     """
-     progress_bars = EvaluationProgressBars(datasets, hyperparam_count)
-     progress_bars.start_dataset_progress()
-     try:
-         yield progress_bars
-     finally:
-         progress_bars.close_dataset_progress()