scorebook 0.0.9-py3-none-any.whl → 0.0.11-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. scorebook/__init__.py +14 -6
  2. scorebook/cli/auth.py +1 -1
  3. scorebook/eval_datasets/__init__.py +5 -0
  4. scorebook/eval_datasets/eval_dataset.py +719 -0
  5. scorebook/evaluate/__init__.py +15 -0
  6. scorebook/evaluate/_async/__init__.py +0 -0
  7. scorebook/evaluate/_async/evaluate_async.py +443 -0
  8. scorebook/evaluate/_sync/__init__.py +0 -0
  9. scorebook/evaluate/_sync/evaluate.py +443 -0
  10. scorebook/evaluate/evaluate_helpers.py +388 -0
  11. scorebook/exceptions.py +48 -0
  12. scorebook/inference/__init__.py +4 -0
  13. scorebook/inference/clients/__init__.py +8 -0
  14. scorebook/inference/{bedrock.py → clients/bedrock.py} +1 -1
  15. scorebook/inference/{openai.py → clients/openai.py} +35 -23
  16. scorebook/inference/{portkey.py → clients/portkey.py} +1 -1
  17. scorebook/inference/{vertex.py → clients/vertex.py} +1 -1
  18. scorebook/{inference_pipeline.py → inference/inference_pipeline.py} +66 -4
  19. scorebook/settings.py +21 -0
  20. scorebook/trismik/__init__.py +10 -0
  21. scorebook/types.py +8 -5
  22. scorebook/utils/__init__.py +11 -4
  23. scorebook/utils/async_utils.py +20 -1
  24. scorebook/utils/io_helpers.py +18 -5
  25. scorebook/utils/progress_bars.py +739 -96
  26. scorebook/utils/{build_prompt.py → render_template.py} +13 -12
  27. {scorebook-0.0.9.dist-info → scorebook-0.0.11.dist-info}/METADATA +4 -4
  28. scorebook-0.0.11.dist-info/RECORD +42 -0
  29. scorebook/eval_dataset.py +0 -404
  30. scorebook/evaluate.py +0 -623
  31. scorebook/trismik_services/__init__.py +0 -6
  32. scorebook/trismik_services/adaptive_testing_service.py +0 -141
  33. scorebook/trismik_services/upload_classic_eval_run.py +0 -102
  34. scorebook-0.0.9.dist-info/RECORD +0 -36
  35. /scorebook/{trismik_services/login.py → trismik/credentials.py} +0 -0
  36. {scorebook-0.0.9.dist-info → scorebook-0.0.11.dist-info}/WHEEL +0 -0
  37. {scorebook-0.0.9.dist-info → scorebook-0.0.11.dist-info}/entry_points.txt +0 -0
  38. {scorebook-0.0.9.dist-info → scorebook-0.0.11.dist-info}/licenses/LICENSE +0 -0
scorebook/evaluate/evaluate_helpers.py ADDED
@@ -0,0 +1,388 @@
+"""Helper utilities shared by synchronous and asynchronous evaluation flows."""
+
+import asyncio
+import dataclasses
+import inspect
+import logging
+from typing import Any, Callable, Dict, Iterable, List, Literal, Mapping, Optional, Union
+
+from trismik._async.client import TrismikAsyncClient
+from trismik._sync.client import TrismikClient
+from trismik.types import TrismikMultipleChoiceTextItem
+
+from scorebook import EvalDataset
+from scorebook.exceptions import (
+    DataMismatchError,
+    MetricComputationError,
+    ParameterValidationError,
+    ScoreBookError,
+)
+from scorebook.settings import TRISMIK_SERVICE_URL
+from scorebook.trismik.credentials import get_token
+from scorebook.types import AdaptiveEvalDataset, AdaptiveEvalRunSpec, EvalResult, EvalRunSpec
+from scorebook.utils import expand_dict, is_awaitable
+
+logger = logging.getLogger(__name__)
+
+
+def resolve_upload_results(upload_results: Union[Literal["auto"], bool]) -> bool:
+    """Resolve the upload_results parameter based on trismik login status."""
+
+    if upload_results == "auto":
+        upload_results = get_token() is not None
+        logger.debug("Auto upload results resolved to: %s", upload_results)
+
+    return upload_results
+
+
+def resolve_show_progress(show_progress: Optional[bool]) -> bool:
+    """Resolve whether to show progress bars.
+
+    Args:
+        show_progress: Explicit setting (None uses default from settings)
+
+    Returns:
+        bool: Whether to show progress bars
+    """
+    if show_progress is None:
+        from scorebook.settings import SHOW_PROGRESS_BARS
+
+        return bool(SHOW_PROGRESS_BARS)
+    return show_progress
+
+
+def validate_parameters(params: Dict[str, Any], caller: Callable[..., Any]) -> None:
+    """Validate all parameters for evaluation."""
+
+    caller_is_async = is_awaitable(caller)
+
+    # Sync evaluate() should only accept sync inference functions
+    if not caller_is_async and is_awaitable(params.get("inference")):
+        raise ParameterValidationError(
+            "evaluate() only accepts synchronous inference functions. "
+            "Use evaluate_async() for async inference functions."
+        )
+
+    # Async evaluate_async() should only accept async inference functions
+    if caller_is_async and not is_awaitable(params.get("inference")):
+        raise ParameterValidationError(
+            "evaluate_async() only accepts asynchronous inference functions. "
+            "Use evaluate() for sync inference functions."
+        )
+
+    # If returning a dict, it must contain items and/or aggregates
+    if params["return_dict"] and not params["return_aggregates"] and not params["return_items"]:
+        raise ParameterValidationError(
+            "When return_dict=True, at least one of return_aggregates or return_items must be True"
+        )
+
+    # If uploading results, experiment_id and project_id must be specified
+    if params["upload_results"]:
+        if params["experiment_id"] is None or params["project_id"] is None:
+            raise ParameterValidationError(
+                "experiment_id and project_id are required for upload_results=True"
+            )
+
+    logger.debug("Parameter validation successful")
+
+
+def prepare_datasets(
+    datasets: Union[str, EvalDataset, List[Union[str, EvalDataset]]],
+    sample_size: Optional[int] = None,
+) -> List[Union[EvalDataset, AdaptiveEvalDataset]]:
+    """Prepare and separate input datasets into classic and adaptive evaluation datasets."""
+
+    # Ensure datasets is always a list for consistent processing
+    if not isinstance(datasets, list):
+        datasets = [datasets]
+
+    datasets_out: List[Union[EvalDataset, AdaptiveEvalDataset]] = []
+    for dataset in datasets:
+
+        # Prepare classic datasets
+        if isinstance(dataset, EvalDataset):
+
+            if sample_size is not None:
+                dataset = dataset.sample(sample_size)
+
+            datasets_out.append(dataset)
+
+        # Prepare adaptive datasets
+        elif isinstance(dataset, str) and dataset.endswith(":adaptive"):
+            datasets_out.append(AdaptiveEvalDataset(dataset.replace(":adaptive", "")))
+
+        # TODO: dataset name string registry
+        elif isinstance(dataset, str):
+            pass
+
+        else:
+            raise ParameterValidationError(f"Unrecognized dataset type: {type(dataset)}")
+
+    return datasets_out
+
+
+def prepare_hyperparameter_configs(
+    hyperparameters: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]]
+) -> List[Dict[str, Any]]:
+    """Prepare hyperparameters for evaluation by returning a list of hyper-param configs."""
+    if hyperparameters is None:
+        return [{}]
+    if not isinstance(hyperparameters, list):
+        expanded: List[Dict[str, Any]] = expand_dict(hyperparameters or {})
+        return expanded
+
+    logger.info("Evaluating with hyperparameters: %s", hyperparameters)
+
+    return hyperparameters
+
+
+def build_eval_run_specs(
+    datasets: List[Union[EvalDataset, str]],
+    hyperparameters: Any,
+    experiment_id: Optional[str],
+    project_id: Optional[str],
+    metadata: Optional[Dict[str, Any]] = None,
+) -> List[Union[EvalRunSpec, AdaptiveEvalRunSpec]]:
+    """Build All RunSpec objects for each dataset/hyperparameter combination."""
+
+    eval_run_specs: List[Union[EvalRunSpec, AdaptiveEvalRunSpec]] = []
+    for dataset_index, dataset in enumerate(datasets):
+        for hyperparameters_index, hyperparameter_config in enumerate(hyperparameters):
+
+            # Create classic eval run spec
+            if isinstance(dataset, EvalDataset):
+                eval_run_specs.append(
+                    build_classic_eval_run_spec(
+                        dataset, dataset_index, hyperparameter_config, hyperparameters_index
+                    )
+                )
+
+            # Create adaptive eval run spec from string
+            elif isinstance(dataset, AdaptiveEvalDataset):
+                if not experiment_id or not project_id:
+                    raise ScoreBookError(
+                        "experiment_id and project_id are required for adaptive evaluations"
+                    )
+                eval_run_specs.append(
+                    build_adaptive_eval_run_spec(
+                        dataset.name,
+                        dataset_index,
+                        hyperparameter_config,
+                        hyperparameters_index,
+                        experiment_id,
+                        project_id,
+                        metadata,
+                    )
+                )
+
+            # Log warning - should never happen
+            else:
+                logger.warning("Unrecognized dataset type: %s", dataset)
+
+    return eval_run_specs
+
+
+def build_classic_eval_run_spec(
+    dataset: EvalDataset,
+    dataset_index: int,
+    hyperparameters: Dict[str, Any],
+    hyperparameters_index: int,
+) -> EvalRunSpec:
+    """Build EvalRunSpec objects for a classic dataset and hyperparameter combination.
+
+    Extracts input and label values from the appropriate columns in the dataset.
+    The column names are determined by dataset.input and dataset.label,
+    which may be original field names (e.g., "question", "answer") or computed
+    column names (e.g., "*input", "*label") if templates were used.
+    """
+    # Extract inputs and labels using the dataset's column specifications
+    inputs = dataset[dataset.input]  # Returns List[Any]
+    labels = dataset[dataset.label]  # Returns List[Any]
+    eval_run_spec = EvalRunSpec(
+        dataset,
+        dataset_index,
+        hyperparameters,
+        hyperparameters_index,
+        inputs,
+        labels,
+    )
+    logger.debug("Built EvalRunSpec: %s", eval_run_spec)
+    return eval_run_spec
+
+
+def build_adaptive_eval_run_spec(
+    adaptive_dataset: str,
+    dataset_index: int,
+    hyperparameter_config: Dict[str, Any],
+    hyperparameter_config_index: int,
+    experiment_id: str,
+    project_id: str,
+    metadata: Optional[Dict[str, Any]] = None,
+) -> AdaptiveEvalRunSpec:
+    """Build AdaptiveEvalRunSpec objects for a dataset/hyperparameter combination."""
+    dataset = adaptive_dataset.replace(":adaptive", "")
+    adaptive_eval_run_spec = AdaptiveEvalRunSpec(
+        dataset,
+        dataset_index,
+        hyperparameter_config,
+        hyperparameter_config_index,
+        experiment_id,
+        project_id,
+        metadata,
+    )
+    logger.debug("Built AdaptiveEvalRunSpec: %s", adaptive_eval_run_spec)
+    return adaptive_eval_run_spec
+
+
+def score_metrics(
+    eval_dataset: EvalDataset, outputs: List[Any], labels: List[Any]
+) -> Dict[str, Dict[str, Any]]:
+    """Compute metric scores for a given dataset and inference outputs."""
+    metric_scores: Dict[str, Dict[str, Any]] = {}
+
+    if len(outputs) != len(labels):
+        raise DataMismatchError(len(outputs), len(labels), eval_dataset.name)
+
+    for metric in eval_dataset.metrics:
+        try:
+            aggregate_scores, item_scores = metric.score(outputs, labels)
+            metric_scores[metric.name] = {
+                "aggregate_scores": aggregate_scores,
+                "item_scores": item_scores,
+            }
+        except Exception as e:
+            logger.error(
+                "Failed to compute metric '%s' for dataset '%s': %s",
+                metric.name,
+                eval_dataset.name,
+                str(e),
+            )
+            raise MetricComputationError(metric.name, eval_dataset.name, e)
+
+    return metric_scores
+
+
+def create_trismik_async_client() -> TrismikAsyncClient:
+    """Create a new async Trismik client instance."""
+    api_key = get_token()
+    logger.debug("Creating new async Trismik client")
+    return TrismikAsyncClient(service_url=TRISMIK_SERVICE_URL, api_key=api_key)
+
+
+def create_trismik_sync_client() -> TrismikClient:
+    """Create a new sync Trismik client instance."""
+    api_key = get_token()
+    logger.debug("Creating new sync Trismik client")
+    return TrismikClient(service_url=TRISMIK_SERVICE_URL, api_key=api_key)
+
+
+def get_model_name(
+    inference_callable: Optional[Callable] = None, metadata: Optional[Dict[str, Any]] = None
+) -> str:
+    """Determine a model's name with the fallback "Model"."""
+
+    # First priority: metadata.model
+    if metadata and "model" in metadata:
+        return str(metadata["model"])
+
+    # Second priority: inference_pipeline.model (if callable is an InferencePipeline)
+    if inference_callable and hasattr(inference_callable, "model"):
+        return str(inference_callable.model)
+
+    # Fallback: "Model"
+    return "Model"
+
+
+def format_results(
+    eval_result: EvalResult,
+    return_dict: bool,
+    return_aggregates: bool,
+    return_items: bool,
+    return_output: bool,
+) -> Union[EvalResult, Dict, List]:
+    """Format an `EvalResult` into the requested output structure."""
+
+    # Return results as a dict
+    if return_dict:
+        results = {}
+
+        if return_aggregates:
+            results["aggregate_results"] = eval_result.aggregate_scores
+
+        if return_items:
+            item_scores = eval_result.item_scores
+
+            # Remove inference output if not requested
+            if not return_output:
+                for item in item_scores:
+                    item.pop("output", None)
+
+            results["item_results"] = item_scores
+
+        # If both are requested, return the combined structure
+        if return_aggregates and return_items:
+            return results
+        # If only aggregates requested, return just the list
+        elif return_aggregates:
+            return results["aggregate_results"]
+        # If only items requested, return just the list
+        else:
+            return results["item_results"]
+
+    # Return results as an EvalResult object
+    else:
+        return eval_result
+
+
+def make_trismik_inference(
+    inference_function: Callable[..., Any],
+    return_list: bool = False,
+) -> Callable[[Any], Any]:
+    """Wrap an inference function for flexible input handling.
+
+    Takes a function expecting list[dict] and makes it accept single dict
+    or TrismikMultipleChoiceTextItem.
+    """
+
+    # Check if the inference function is async
+    is_async = inspect.iscoroutinefunction(inference_function) or (
+        hasattr(inference_function, "__call__")
+        and inspect.iscoroutinefunction(inference_function.__call__)
+    )
+
+    def sync_trismik_inference_function(eval_items: Any, **kwargs: Any) -> Any:
+        # Single TrismikMultipleChoiceTextItem dataclass
+        if isinstance(eval_items, TrismikMultipleChoiceTextItem):
+            eval_item_dict = dataclasses.asdict(eval_items)
+            results = inference_function([eval_item_dict], **kwargs)
+            if is_async:
+                results = asyncio.run(results)
+            return results if return_list else results[0]
+
+        # Single item (a mapping)
+        if isinstance(eval_items, Mapping):
+            results = inference_function([eval_items], **kwargs)
+            if is_async:
+                results = asyncio.run(results)
+            return results if return_list else results[0]
+
+        # Iterable of items (but not a string/bytes)
+        if isinstance(eval_items, Iterable) and not isinstance(eval_items, (str, bytes)):
+            # Convert any TrismikMultipleChoiceTextItem instances to dicts
+            converted_items = []
+            for item in eval_items:
+                if isinstance(item, TrismikMultipleChoiceTextItem):
+                    converted_items.append(dataclasses.asdict(item))
+                else:
+                    converted_items.append(item)
+            results = inference_function(converted_items, **kwargs)
+            if is_async:
+                results = asyncio.run(results)
+            return results
+
+        raise TypeError(
+            "Expected a single item (Mapping[str, Any] or TrismikMultipleChoiceTextItem) "
+            "or an iterable of such items."
+        )
+
+    return sync_trismik_inference_function
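For orientation, here is a minimal usage sketch of the new make_trismik_inference wrapper, based only on the code in the hunk above; fake_inference and the "question" item shape are illustrative stand-ins, not part of the package API.

from scorebook.evaluate.evaluate_helpers import make_trismik_inference

def fake_inference(items, **kwargs):
    # Stand-in for a Scorebook inference function: list of dict-like items in, one output per item out.
    return [f"answer to: {item['question']}" for item in items]

wrapped = make_trismik_inference(fake_inference)

# A single mapping is wrapped in a list and only the first result is returned.
print(wrapped({"question": "2 + 2?"}))

# An iterable of mappings passes through and the full result list is returned.
print(wrapped([{"question": "a"}, {"question": "b"}]))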
scorebook/exceptions.py CHANGED
@@ -10,6 +10,54 @@ class ScoreBookError(Exception):
     """Base exception class for all Scorebook-related errors."""


+class EvalDatasetError(ScoreBookError):
+    """Base exception class for all EvalDataset errors."""
+
+
+class DatasetConfigurationError(EvalDatasetError):
+    """Raised when dataset configuration is invalid (e.g., mutually exclusive parameters)."""
+
+
+class MissingFieldError(EvalDatasetError):
+    """Raised when required field is missing from dataset."""
+
+    def __init__(self, field_name: str, field_type: str, available_fields: list[str]):
+        """Initialize missing field error with structured context."""
+        self.field_name = field_name
+        self.field_type = field_type  # "input" or "label"
+        self.available_fields = available_fields
+        super().__init__(
+            f"{field_type.capitalize()} field '{field_name}' not found. "
+            f"Available fields: {', '.join(available_fields)}"
+        )
+
+
+class DatasetLoadError(EvalDatasetError):
+    """Raised when dataset fails to load from source (file or remote)."""
+
+
+class DatasetParseError(EvalDatasetError):
+    """Raised when dataset file cannot be parsed (CSV, JSON, YAML)."""
+
+
+class DatasetNotInitializedError(EvalDatasetError):
+    """Raised when operations are attempted on uninitialized dataset."""
+
+
+class DatasetSampleError(EvalDatasetError):
+    """Raised when sampling parameters are invalid."""
+
+    def __init__(self, sample_size: int, dataset_size: int, dataset_name: str):
+        """Initialize dataset sample error with structured context."""
+        self.sample_size = sample_size
+        self.dataset_size = dataset_size
+        self.dataset_name = dataset_name
+        super().__init__(
+            f"Sample size {sample_size} exceeds dataset size {dataset_size} "
+            f"for dataset '{dataset_name}'"
+        )
+
+
 class EvaluationError(ScoreBookError):
     """Raised when there are errors during model evaluation."""

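A minimal sketch of how callers might use the richer exception hierarchy added above; load_dataset_somehow is a hypothetical placeholder, while the exception classes and their attributes come from the hunk itself.

from scorebook.exceptions import DatasetSampleError, EvalDatasetError, MissingFieldError

try:
    dataset = load_dataset_somehow()  # hypothetical: any code that builds an EvalDataset
    dataset = dataset.sample(10_000)
except MissingFieldError as err:
    # Structured context instead of parsing the message string.
    print(err.field_type, err.field_name, err.available_fields)
except DatasetSampleError as err:
    print(f"asked for {err.sample_size} items, dataset has {err.dataset_size} ({err.dataset_name})")
except EvalDatasetError as err:
    # Catch-all for the remaining dataset errors (load, parse, configuration, ...).
    print(f"dataset error: {err}")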
scorebook/inference/__init__.py CHANGED
@@ -5,3 +5,7 @@ This module provides functionality for running inference with various models
 and processing their responses. It includes utilities for both single and
 batch inference operations.
 """
+
+from scorebook.inference.inference_pipeline import InferencePipeline
+
+__all__ = ["InferencePipeline"]
scorebook/inference/clients/__init__.py ADDED
@@ -0,0 +1,8 @@
+"""
+Inference clients for various LLM providers.
+
+This module provides client implementations for different LLM providers including
+OpenAI, AWS Bedrock, Google Vertex AI, and Portkey.
+"""
+
+__all__ = ["bedrock", "openai", "portkey", "vertex"]
scorebook/inference/{bedrock.py → clients/bedrock.py} RENAMED
@@ -17,7 +17,7 @@ from typing import Any, Dict, List, Optional, Tuple
 import boto3
 from botocore.config import Config
 from botocore.exceptions import ClientError
-from tqdm.asyncio import tqdm
+from tqdm.auto import tqdm


 async def batch(
scorebook/inference/{openai.py → clients/openai.py} RENAMED
@@ -10,7 +10,7 @@ import asyncio
 import json
 import logging
 import tempfile
-from typing import Any, List
+from typing import Any, List, Optional

 from openai import AsyncOpenAI

@@ -18,7 +18,10 @@ logger = logging.getLogger(__name__)


 async def responses(
-    items: List[Any], model: str = "gpt-4.1-nano", client: Any = None, **hyperparameters: Any
+    items: List[Any],
+    model: str = "gpt-4.1-nano",
+    client: Optional[AsyncOpenAI] = None,
+    **hyperparameters: Any,
 ) -> List[Any]:
     """Process multiple inference requests using OpenAI's Async API.

@@ -28,23 +31,28 @@
     Args:
         items: List of preprocessed items to process.
         model: OpenAI model to use.
-        client: Optional OpenAI client instance.
+        client: Optional OpenAI client instance. If not provided, creates a new client
+            with automatic cleanup using a context manager.
         hyperparameters: Dictionary of hyperparameters for inference.

     Returns:
         List of raw model responses.
-
-    Raises:
-        NotImplementedError: Currently not implemented.
     """
+    if client is None:
+        async with AsyncOpenAI() as client:
+            return await _do_responses(items, model, client, **hyperparameters)
+    else:
+        return await _do_responses(items, model, client, **hyperparameters)
+
+
+async def _do_responses(
+    items: List[Any], model: str, client: AsyncOpenAI, **hyperparameters: Any
+) -> List[Any]:
+    """Process responses internally with provided client."""
     logger.debug("OpenAI responses function called with %d items", len(items))
     logger.debug("Using model: %s", model)
     logger.debug("Hyperparameters: %s", hyperparameters)

-    if client is None:
-        logger.debug("Creating new AsyncOpenAI client")
-        client = AsyncOpenAI()
-
     # Create all tasks concurrently for true parallelism
     tasks = []
     for i, item in enumerate(items):
@@ -127,7 +135,7 @@
 async def batch(
     items: List[Any],
     model: str = "gpt-4.1-nano",
-    client: Any = None,
+    client: Optional[AsyncOpenAI] = None,
     **hyperparameters: Any,
 ) -> List[Any]:
     """Process multiple inference requests in batch using OpenAI's API.
@@ -138,18 +146,24 @@
     Args:
         items: List of preprocessed items to process.
         model: OpenAI model to use.
-        client: Optional OpenAI client instance.
+        client: Optional OpenAI client instance. If not provided, creates a new client
+            with automatic cleanup using a context manager.
         hyperparameters: Dictionary of hyperparameters for inference.

     Returns:
         A list of raw model responses.
-
-    Raises:
-        NotImplementedError: Currently not implemented.
     """
     if client is None:
-        client = AsyncOpenAI()
+        async with AsyncOpenAI() as client:
+            return await _do_batch(items, model, client, **hyperparameters)
+    else:
+        return await _do_batch(items, model, client, **hyperparameters)

+
+async def _do_batch(
+    items: List[Any], model: str, client: AsyncOpenAI, **hyperparameters: Any
+) -> List[Any]:
+    """Process batch internally with provided client."""
     file_id = await _upload_batch(items, client)
     batch_id = await _start_batch(file_id, client)

@@ -173,18 +187,16 @@
     return batch_result


-async def _upload_batch(items: List[Any], client: Any) -> str:
+async def _upload_batch(items: List[Any], client: AsyncOpenAI) -> str:
     """Create a .jsonl file from preprocessed items and upload to OpenAI for batch processing.

     Args:
         items: A list of preprocessed items, each representing a single dataset eval item.
+        client: OpenAI client instance.

     Returns:
         The file ID returned by OpenAI after uploading.
     """
-    # Instantiate OpenAI client
-    if client is None:
-        client = AsyncOpenAI()

     # Create temp .jsonl file
     with tempfile.NamedTemporaryFile(mode="w+", suffix=".jsonl", delete=False) as f:
@@ -206,7 +218,7 @@ async def _upload_batch(items: List[Any], client: Any) -> str:
     return str(response.id)


-async def _start_batch(file_id: str, client: Any) -> str:
+async def _start_batch(file_id: str, client: AsyncOpenAI) -> str:
     batch_response = await client.batches.create(
         input_file_id=file_id,
         endpoint="/v1/chat/completions",
@@ -215,12 +227,12 @@ async def _start_batch(file_id: str, client: Any) -> str:
     return str(batch_response.id)


-async def _get_batch(batch_id: str, client: Any) -> Any:
+async def _get_batch(batch_id: str, client: AsyncOpenAI) -> Any:
     batch_object = await client.batches.retrieve(batch_id)
     return batch_object


-async def _get_results_file(output_file_id: str, client: Any) -> List[str]:
+async def _get_results_file(output_file_id: str, client: AsyncOpenAI) -> List[str]:
     """Download and parse the batch results file from OpenAI."""
     response = await client.files.content(output_file_id)

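A minimal sketch of the optional-client pattern introduced in the openai client above; the item payload and model name are assumptions for illustration, not values mandated by the module.

import asyncio
from openai import AsyncOpenAI
from scorebook.inference.clients import openai as openai_client

items = [{"messages": [{"role": "user", "content": "Say hello."}]}]  # assumed item shape

async def main():
    # No client passed: the module opens its own AsyncOpenAI client in a
    # context manager and closes it when the call returns.
    first = await openai_client.responses(items, model="gpt-4.1-nano")

    # Caller-owned client: reused across calls and cleaned up by the caller.
    async with AsyncOpenAI() as client:
        second = await openai_client.responses(items, model="gpt-4.1-nano", client=client)

    return first, second

asyncio.run(main())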
scorebook/inference/{portkey.py → clients/portkey.py} RENAMED
@@ -13,7 +13,7 @@ import tempfile
 from typing import Any, List, Optional

 from portkey_ai import AsyncPortkey
-from tqdm.asyncio import tqdm
+from tqdm.auto import tqdm


 async def responses(
scorebook/inference/{vertex.py → clients/vertex.py} RENAMED
@@ -18,7 +18,7 @@ import pandas as pd
 from google import genai
 from google.cloud import storage
 from google.genai import types
-from tqdm.asyncio import tqdm
+from tqdm.auto import tqdm


 async def responses(