scorebook 0.0.1.tar.gz → 0.0.2.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {scorebook-0.0.1 → scorebook-0.0.2}/PKG-INFO +11 -1
- {scorebook-0.0.1 → scorebook-0.0.2}/pyproject.toml +5 -1
- {scorebook-0.0.1 → scorebook-0.0.2}/src/scorebook/__init__.py +2 -1
- {scorebook-0.0.1 → scorebook-0.0.2}/src/scorebook/evaluator.py +94 -51
- {scorebook-0.0.1 → scorebook-0.0.2}/src/scorebook/inference/__init__.py +0 -4
- scorebook-0.0.2/src/scorebook/inference/bedrock.py +305 -0
- scorebook-0.0.2/src/scorebook/inference/vertex.py +295 -0
- {scorebook-0.0.1 → scorebook-0.0.2}/src/scorebook/types/eval_dataset.py +50 -0
- {scorebook-0.0.1 → scorebook-0.0.2}/src/scorebook/types/eval_result.py +7 -3
- {scorebook-0.0.1 → scorebook-0.0.2}/src/scorebook/types/inference_pipeline.py +5 -2
- {scorebook-0.0.1 → scorebook-0.0.2}/src/scorebook/utils/__init__.py +2 -1
- scorebook-0.0.2/src/scorebook/utils/build_prompt.py +52 -0
- scorebook-0.0.2/src/scorebook/utils/jinja_helpers.py +146 -0
- {scorebook-0.0.1 → scorebook-0.0.2}/LICENSE +0 -0
- {scorebook-0.0.1 → scorebook-0.0.2}/README.md +0 -0
- {scorebook-0.0.1 → scorebook-0.0.2}/src/scorebook/inference/openai.py +0 -0
- {scorebook-0.0.1 → scorebook-0.0.2}/src/scorebook/inference/portkey.py +0 -0
- {scorebook-0.0.1 → scorebook-0.0.2}/src/scorebook/metrics/__init__.py +0 -0
- {scorebook-0.0.1 → scorebook-0.0.2}/src/scorebook/metrics/accuracy.py +0 -0
- {scorebook-0.0.1 → scorebook-0.0.2}/src/scorebook/metrics/metric_base.py +0 -0
- {scorebook-0.0.1 → scorebook-0.0.2}/src/scorebook/metrics/metric_registry.py +0 -0
- {scorebook-0.0.1 → scorebook-0.0.2}/src/scorebook/metrics/precision.py +0 -0
- {scorebook-0.0.1 → scorebook-0.0.2}/src/scorebook/types/__init__.py +0 -0
- {scorebook-0.0.1 → scorebook-0.0.2}/src/scorebook/utils/async_utils.py +0 -0
- {scorebook-0.0.1 → scorebook-0.0.2}/src/scorebook/utils/io_helpers.py +0 -0
- {scorebook-0.0.1 → scorebook-0.0.2}/src/scorebook/utils/mappers.py +0 -0
- {scorebook-0.0.1 → scorebook-0.0.2}/src/scorebook/utils/progress_bars.py +0 -0
- {scorebook-0.0.1 → scorebook-0.0.2}/src/scorebook/utils/transform_helpers.py +0 -0

{scorebook-0.0.1 → scorebook-0.0.2}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: scorebook
-Version: 0.0.1
+Version: 0.0.2
 Summary: A Python project for LLM evaluation.
 Author: Euan Campbell
 Author-email: euan@trismik.com
@@ -11,16 +11,26 @@ Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Classifier: Programming Language :: Python :: 3.12
 Classifier: Programming Language :: Python :: 3.13
+Provides-Extra: bedrock
 Provides-Extra: examples
 Provides-Extra: openai
 Provides-Extra: portkey
+Provides-Extra: vertex
 Requires-Dist: accelerate ; extra == "examples"
+Requires-Dist: boto3 (==1.40.0) ; extra == "bedrock"
 Requires-Dist: datasets (>=3.6.0)
+Requires-Dist: fsspec[gcs] ; extra == "vertex"
+Requires-Dist: google-cloud-storage ; extra == "vertex"
+Requires-Dist: google-genai ; extra == "vertex"
+Requires-Dist: notebook (>=7.4.5,<8.0.0)
 Requires-Dist: notebook ; extra == "examples"
 Requires-Dist: openai ; extra == "openai"
+Requires-Dist: pandas ; extra == "vertex"
 Requires-Dist: portkey-ai ; extra == "portkey"
+Requires-Dist: python-dotenv ; extra == "bedrock"
 Requires-Dist: python-dotenv ; extra == "openai"
 Requires-Dist: python-dotenv ; extra == "portkey"
+Requires-Dist: python-dotenv ; extra == "vertex"
 Requires-Dist: torch ; extra == "examples"
 Requires-Dist: torchaudio ; extra == "examples"
 Requires-Dist: torchvision ; extra == "examples"

{scorebook-0.0.1 → scorebook-0.0.2}/pyproject.toml
@@ -10,10 +10,11 @@ readme = "README.md"
 requires-python = ">=3.9"
 dependencies = [
     "datasets>=3.6.0",
+    "notebook (>=7.4.5,<8.0.0)",
 ]
 
 [tool.poetry]
-version = "0.0.1"
+version = "0.0.2" # base version
 packages = [{ include = "scorebook", from = "src" }]
 
 
@@ -28,10 +29,13 @@ flake8 = "^7.0.0"
 mypy = "^1.15.0"
 autoflake = "^2.3.1"
 toml = "^0.10.2"
+types-pyyaml = "^6.0.12.20250822"
 
 [project.optional-dependencies]
 openai = ["openai", "python-dotenv"]
 portkey = ["portkey-ai", "python-dotenv"]
+bedrock = ["boto3==1.40.0", "python-dotenv"]
+vertex = ["google-genai", "pandas", "google-cloud-storage", "fsspec[gcs]", "python-dotenv"]
 examples = ["transformers", "torch", "torchvision", "torchaudio", "accelerate", "notebook"]
 
 

{scorebook-0.0.1 → scorebook-0.0.2}/src/scorebook/__init__.py
@@ -11,5 +11,6 @@ __version__ = importlib.metadata.version(__package__ or __name__)
 
 from scorebook.evaluator import evaluate
 from scorebook.types.eval_dataset import EvalDataset
+from scorebook.utils.build_prompt import build_prompt
 
-__all__ = ["EvalDataset", "evaluate"]
+__all__ = ["EvalDataset", "evaluate", "build_prompt"]

{scorebook-0.0.1 → scorebook-0.0.2}/src/scorebook/evaluator.py
@@ -24,17 +24,30 @@ from scorebook.utils import evaluation_progress, expand_dict, is_awaitable
 async def _evaluate_async(
     inference_callable: Callable,
     eval_datasets: Union[str, EvalDataset, List[Union[str, EvalDataset]]],
-    hyperparameters: Optional[Dict[str, Any]] = None,
+    hyperparameters: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
     experiment_id: Optional[str] = None,
-
-
-
+    return_dict: bool = True,
+    return_aggregates: bool = True,
+    return_items: bool = False,
+    return_output: bool = False,
+    sample_size: Optional[int] = None,
 ) -> Union[Dict, List]:
     """Run inference across datasets/hyperparams, compute metrics, and format results."""
-
+
+    # Validate parameters
+    if return_dict and not return_aggregates and not return_items:
+        raise ValueError(
+            "When return_dict=True, at least one of return_aggregates or return_items must be True"
+        )
 
     normalized_datasets = _normalize_datasets(eval_datasets)
-
+
+    if hyperparameters is None:
+        hyperparam_grid: List[Dict[str, Any]] = [{}]
+    elif not isinstance(hyperparameters, list):
+        hyperparam_grid = _expand_hyperparams(hyperparameters)
+    else:
+        hyperparam_grid = hyperparameters
 
     eval_results: List[EvalResult] = []
 
@@ -44,8 +57,13 @@ async def _evaluate_async(
     with progress_bars.hyperparam_progress_context():
         # Run inference for each hyperparameter configuration on this dataset
         for hp_idx, hyperparam_config in enumerate(hyperparam_grid):
-
-
+
+            if sample_size:
+                items = _get_items_sample(eval_dataset.items, sample_size)
+            else:
+                items = eval_dataset.items
+
+            labels = _get_labels_for_items(items, eval_dataset.label)
 
             # 1) Run inference
             outputs = await _run_inference_callable(
@@ -71,17 +89,21 @@ async def _evaluate_async(
         pass
 
     # 4) Format as requested
-    return _format_results(
+    return _format_results(
+        eval_results, return_dict, return_aggregates, return_items, return_output
+    )
 
 
 def evaluate(
     inference_callable: Callable,
     eval_datasets: Union[str, EvalDataset, List[Union[str, EvalDataset]]],
-    hyperparameters: Optional[Dict[str, Any]] = None,
+    hyperparameters: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
     experiment_id: Optional[str] = None,
-
-
-
+    return_dict: bool = True,
+    return_aggregates: bool = True,
+    return_items: bool = False,
+    return_output: bool = False,
+    sample_size: Optional[int] = None,
 ) -> Union[Dict, List]:
     """
     Evaluate model predictions using specified metrics on given datasets.
@@ -101,12 +123,11 @@ def evaluate(
             - A list of string identifiers
         hyperparameters: Optional dictionary containing hyperparameter sweep configuration.
         experiment_id: Optional string identifier for tracking multiple evaluation runs.
-
-
-
-
-
-            - "all": Return both aggregate and per-item scores
+        return_dict: If True, returns eval results as a dict
+        return_aggregates: If True, returns aggregate scores for each dataset
+        return_items: If True, returns individual items for each dataset
+        return_output: If True, returns model outputs for each dataset item evaluated
+        sample_size: If set, only return a sample of the dataset items (for debugging)
 
     Returns:
         Dictionary mapping dataset names to their evaluation results. For each dataset,
@@ -130,9 +151,11 @@ def evaluate(
             eval_datasets=eval_datasets,
             hyperparameters=hyperparameters,
             experiment_id=experiment_id,
-
-
-
+            return_dict=return_dict,
+            return_aggregates=return_aggregates,
+            return_items=return_items,
+            return_output=return_output,
+            sample_size=sample_size,
         )
     )
 
@@ -149,20 +172,17 @@ def _normalize_datasets(
     return [d for d in datasets if isinstance(d, EvalDataset)]
 
 
-def _validate_score_type(score_type: str) -> None:
-    if score_type not in {"aggregate", "item", "all"}:
-        raise ValueError("score_type must be 'aggregate', 'item', or 'all'")
-
-
 def _expand_hyperparams(hyperparameters: Optional[Dict[str, Any]]) -> Any:
     return expand_dict(hyperparameters or {})
 
 
-def
+def _get_items_sample(
+    items: List[Dict[str, Any]], item_limit: Optional[int]
+) -> List[Dict[str, Any]]:
     return items[:item_limit] if item_limit else items
 
 
-def
+def _get_labels_for_items(items: List[Dict[str, Any]], label_key: str) -> List[Any]:
     return [item.get(label_key) for item in items]
 
 
@@ -181,12 +201,12 @@ async def _run_inference_callable(
 def _iter_dataset_jobs(
     datasets: List[EvalDataset],
     hyperparam_grid: List[Dict[str, Any]],
-
+    sample_size: Optional[int],
 ) -> Iterable[Tuple[EvalDataset, List[Dict[str, Any]], List[Any], Dict[str, Any]]]:
     for eval_dataset in datasets:
         for hp in hyperparam_grid:
-            items =
-            labels =
+            items = _get_items_sample(eval_dataset.items, sample_size)
+            labels = _get_labels_for_items(items, eval_dataset.label)
             yield eval_dataset, items, labels, hp
 
 
@@ -204,25 +224,48 @@ def _score_metrics(
 
 
 def _format_results(
-    eval_results: List[EvalResult],
+    eval_results: List[EvalResult],
+    return_dict: bool,
+    return_aggregates: bool,
+    return_items: bool,
+    return_output: bool,
 ) -> Union[Dict, List]:
 
-
+    # Return results as a dict
+    if return_dict:
+
+        # Include both aggregate and item scores in dict returned
+        if return_aggregates and return_items:
+            results: Dict[str, List[Dict[str, Any]]] = {"aggregate_results": [], "item_results": []}
+            for eval_result in eval_results:
+                eval_result_dict = eval_result.to_dict()
+                results["aggregate_results"].extend(eval_result_dict["aggregate_results"])
+                if return_output:
+                    results["item_results"].extend(eval_result_dict["item_results"])
+                else:
+                    results["item_results"].extend(
+                        [
+                            {k: v for k, v in item.items() if k != "inference_output"}
+                            for item in eval_result_dict["item_results"]
+                        ]
+                    )
+            return results
+
+        # Include only aggregate scores in dict returned
+        elif return_aggregates:
+            return [eval_result.aggregate_scores for eval_result in eval_results]
+
+        # Include only item scores in dict returned
+        else:
+            if return_output:
+                return [item for eval_result in eval_results for item in eval_result.item_scores]
+            else:
+                return [
+                    {k: v for k, v in item.items() if k != "inference_output"}
+                    for eval_result in eval_results
+                    for item in eval_result.item_scores
+                ]
+
+    # Return results as an EvalResult object
+    else:
         return {er.eval_dataset.name: er for er in eval_results}
-
-    if score_type == "all":
-        combined: Dict[str, List[Dict[str, Any]]] = {"aggregate": [], "per_sample": []}
-        for er in eval_results:
-            d = er.to_dict()
-            combined["aggregate"].extend(d["aggregate"])
-            combined["per_sample"].extend(d["per_sample"])
-        return combined
-
-    if score_type == "aggregate":
-        return [er.aggregate_scores for er in eval_results]
-
-    if score_type == "item":
-        return [item for er in eval_results for item in er.item_scores]
-
-    # Should be unreachable due to validation
-    return {}
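
Note: a minimal usage sketch of the widened evaluate() surface follows; the inference callable, the YAML config file name, and the hyperparameter values are illustrative assumptions rather than part of this release (from_yaml appears further down in this diff).

from scorebook import EvalDataset, evaluate

def my_model(items, **hyperparameters):
    # Hypothetical inference callable: return one raw output per input item.
    return ["B" for _ in items]

dataset = EvalDataset.from_yaml("my_dataset.yaml")  # hypothetical config file

results = evaluate(
    my_model,
    dataset,
    hyperparameters=[{"temperature": 0.0}, {"temperature": 0.7}],  # a list of configs is now accepted as-is
    return_items=True,    # include per-item scores alongside the aggregates
    return_output=False,  # strip raw "inference_output" fields from item results
    sample_size=10,       # debug runs score only the first 10 items
)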

{scorebook-0.0.1 → scorebook-0.0.2}/src/scorebook/inference/__init__.py
@@ -5,7 +5,3 @@ This module provides functionality for running inference with various models
 and processing their responses. It includes utilities for both single and
 batch inference operations.
 """
-
-from scorebook.inference.openai import batch, responses
-
-__all__ = ["responses", "batch"]

scorebook-0.0.2/src/scorebook/inference/bedrock.py
@@ -0,0 +1,305 @@
+"""
+AWS Bedrock batch inference implementation for Scorebook.
+
+This module provides utilities for running batch inference using AWS Bedrock's
+Model Invocation Jobs, supporting large-scale asynchronous processing. It handles
+API communication, request formatting, response processing, and S3 operations.
+"""
+
+import asyncio
+import json
+import os
+import tempfile
+import uuid
+from datetime import datetime
+from typing import Any, Dict, List, Optional, Tuple
+
+import boto3
+from botocore.config import Config
+from botocore.exceptions import ClientError
+from tqdm.asyncio import tqdm
+
+
+async def batch(
+    items: List[Any],
+    model: Optional[str] = None,
+    aws_region: Optional[str] = None,
+    aws_profile: Optional[str] = None,
+    bucket: Optional[str] = None,
+    input_prefix: Optional[str] = None,
+    output_prefix: Optional[str] = None,
+    role_arn: Optional[str] = None,
+    **hyperparameters: Any,
+) -> List[Any]:
+    """Process multiple inference requests in batch using AWS Bedrock.
+
+    This asynchronous function handles batch processing of inference requests,
+    optimizing for cost and throughput using AWS Bedrock's Model Invocation Jobs.
+
+    Args:
+        items: List of preprocessed items to process.
+        model: Bedrock model ID (e.g., 'us.anthropic.claude-3-5-sonnet-20241022-v2:0').
+        aws_region: AWS region for Bedrock and S3.
+        aws_profile: AWS profile name for authentication.
+        bucket: S3 bucket name for input/output data.
+        input_prefix: S3 prefix for input data.
+        output_prefix: S3 prefix for output data.
+        role_arn: IAM role ARN for Bedrock execution.
+        hyperparameters: Additional parameters for the batch requests.
+
+    Returns:
+        A list of raw model responses.
+    """
+    # Set up AWS session and clients
+    session_kwargs = {}
+    if aws_profile:
+        session_kwargs["profile_name"] = aws_profile
+    if aws_region:
+        session_kwargs["region_name"] = aws_region
+
+    session = boto3.Session(**session_kwargs)
+
+    boto_config = Config(region_name=aws_region, retries={"max_attempts": 10, "mode": "adaptive"})
+
+    s3_client = session.client("s3", config=boto_config)
+    bedrock_client = session.client("bedrock", config=boto_config)
+
+    # Upload batch data
+    input_uri = await _upload_batch(
+        items, s3_client, bucket, input_prefix, model, **hyperparameters
+    )
+
+    # Start batch job
+    job_arn = await _start_batch_job(
+        bedrock_client, model, input_uri, bucket, output_prefix, role_arn
+    )
+
+    # Wait for completion with progress tracking
+    await _wait_for_completion(bedrock_client, job_arn, len(items))
+
+    # Retrieve results
+    results = await _get_batch_results(s3_client, bedrock_client, job_arn)
+
+    return results
+
+
+async def _upload_batch(
+    items: List[Any],
+    s3_client: Any,
+    bucket: Optional[str],
+    input_prefix: Optional[str],
+    model: Optional[str],
+    **hyperparameters: Any,
+) -> str:
+    """Create a JSONL file from preprocessed items and upload to S3 for batch processing."""
+
+    # Generate unique run ID and key
+    run_id = datetime.utcnow().strftime("%Y%m%dT%H%M%S") + "-" + uuid.uuid4().hex[:8]
+
+    if input_prefix:
+        input_key = f"{input_prefix.rstrip('/')}/inputs-{run_id}.jsonl"
+    else:
+        input_key = f"inputs-{run_id}.jsonl"
+
+    # Create temp JSONL file
+    with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
+        for i, item in enumerate(items):
+            # Construct batch request in Bedrock format
+            record = {
+                "recordId": f"rec-{i:04d}",
+                "modelInput": _build_claude_messages_payload(item, **hyperparameters),
+            }
+            f.write(json.dumps(record, separators=(",", ":")) + "\n")
+        file_path = f.name
+
+    # Upload to S3
+    try:
+        body = open(file_path, "rb").read()
+        s3_client.put_object(
+            Bucket=bucket,
+            Key=input_key,
+            Body=body,
+            StorageClass="INTELLIGENT_TIERING",
+            ContentType="application/json",
+        )
+        input_uri = f"s3://{bucket}/{input_key}"
+    except Exception as e:
+        raise Exception(f"Failed to upload file to S3: {e}")
+    finally:
+        # Clean up temp file
+        os.unlink(file_path)
+
+    return input_uri
+
+
+def _build_claude_messages_payload(item: Any, **hyperparameters: Any) -> Dict[str, Any]:
+    """Build Claude messages payload for Bedrock batch processing."""
+
+    # item is a list of messages from our preprocessor
+    messages = item
+
+    # Convert to Bedrock format and extract system message
+    bedrock_messages = []
+    system_content = None
+
+    for msg in messages:
+        if msg["role"] == "system":
+            system_content = msg["content"]
+        else:
+            bedrock_messages.append(
+                {"role": msg["role"], "content": [{"type": "text", "text": msg["content"]}]}
+            )
+
+    payload = {
+        "anthropic_version": "bedrock-2023-05-31",
+        "max_tokens": 256,
+        "messages": bedrock_messages,
+    }
+
+    if system_content:
+        payload["system"] = system_content
+
+    payload.update(hyperparameters)
+    return payload
+
+
+async def _start_batch_job(
+    bedrock_client: Any,
+    model: Optional[str],
+    input_uri: str,
+    bucket: Optional[str],
+    output_prefix: Optional[str],
+    role_arn: Optional[str],
+) -> str:
+    """Start a Bedrock Model Invocation Job."""
+
+    # Generate unique job name and output URI
+    run_id = datetime.utcnow().strftime("%Y%m%dT%H%M%S") + "-" + uuid.uuid4().hex[:8]
+    job_name = f"bedrock-batch-{run_id}"
+
+    if output_prefix:
+        output_uri = f"s3://{bucket}/{output_prefix.rstrip('/')}/job-{run_id}/"
+    else:
+        output_uri = f"s3://{bucket}/job-{run_id}/"
+
+    try:
+        response = bedrock_client.create_model_invocation_job(
+            jobName=job_name,
+            modelId=model,
+            roleArn=role_arn,
+            inputDataConfig={"s3InputDataConfig": {"s3Uri": input_uri}},
+            outputDataConfig={"s3OutputDataConfig": {"s3Uri": output_uri}},
+            tags=[{"key": "project", "value": "scorebook-batch"}],
+        )
+        job_arn: str = response["jobArn"]
+        return job_arn
+    except ClientError as e:
+        error_info = e.response.get("Error", {})
+        raise Exception(f"Failed to create batch job: {error_info}")
+
+
+async def _wait_for_completion(bedrock_client: Any, job_arn: str, total_items: int) -> None:
+    """Wait for batch job completion with progress tracking."""
+
+    # Initialize progress bar
+    pbar = tqdm(total=total_items, desc="Batch processing", unit="requests")
+
+    terminal_states = {"Completed", "Failed", "Stopped"}
+    sleep_time = 15
+
+    while True:
+        try:
+            desc = bedrock_client.get_model_invocation_job(jobIdentifier=job_arn)
+            status = desc["status"]
+
+            # Get progress if available
+            job_state = desc.get("jobState", {})
+            progress = job_state.get("percentComplete")
+
+            # Update progress bar
+            if progress is not None:
+                completed = int((progress / 100) * total_items)
+                pbar.n = completed
+                pbar.set_postfix(status=status, progress=f"{progress}%")
+            else:
+                pbar.set_postfix(status=status)
+
+            pbar.refresh()
+
+            if status in terminal_states:
+                if status == "Completed":
+                    pbar.n = pbar.total
+                    pbar.set_postfix(status="COMPLETED")
+                else:
+                    pbar.close()
+                    error_msg = desc.get("failureMessage", f"Job ended with status {status}")
+                    raise Exception(f"Batch job failed: {error_msg}")
+                break
+
+            # Wait before checking again
+            await asyncio.sleep(sleep_time)
+
+        except Exception as e:
+            pbar.close()
+            raise e
+
+    pbar.close()
+
+
+async def _get_batch_results(s3_client: Any, bedrock_client: Any, job_arn: str) -> List[str]:
+    """Download and parse batch results from S3."""
+
+    # Get job details to find output location
+    desc = bedrock_client.get_model_invocation_job(jobIdentifier=job_arn)
+    output_uri = desc["outputDataConfig"]["s3OutputDataConfig"]["s3Uri"]
+
+    bucket_name, prefix = s3_uri_to_bucket_and_prefix(output_uri)
+
+    # Find the output JSONL file
+    try:
+        response = s3_client.list_objects_v2(Bucket=bucket_name, Prefix=prefix)
+        contents = response.get("Contents", [])
+
+        # Look for the output JSONL file
+        output_key = None
+        for obj in contents:
+            if obj["Key"].endswith(".jsonl.out"):
+                output_key = obj["Key"]
+                break
+
+        if not output_key:
+            raise Exception("No output JSONL file found in S3")
+
+        # Download and parse results
+        obj_response = s3_client.get_object(Bucket=bucket_name, Key=output_key)
+        content = obj_response["Body"].read().decode("utf-8")
+
+        results = []
+        for line in content.strip().split("\n"):
+            if line.strip():
+                result_obj = json.loads(line)
+                # Extract text from Claude response format
+                model_output = result_obj.get("modelOutput", {})
+                content_list = model_output.get("content", [])
+                if content_list and len(content_list) > 0:
+                    text = content_list[0].get("text", "")
+                    results.append(text)
+                else:
+                    results.append("")
+
+        return results
+
+    except Exception as e:
+        raise Exception(f"Failed to retrieve batch results: {e}")
+
+
+def s3_uri_to_bucket_and_prefix(s3_uri: str) -> Tuple[str, str]:
+    """Parse S3 URI to bucket and prefix."""
+    # Parse S3 URI
+    if s3_uri.startswith("s3://"):
+        uri_parts = s3_uri[5:].split("/", 1)
+        bucket_name = uri_parts[0]
+        prefix = uri_parts[1] if len(uri_parts) > 1 else ""
+    else:
+        raise ValueError(f"Invalid S3 URI: {s3_uri}")
+    return bucket_name, prefix
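
Note: a sketch of driving the new Bedrock batch entry point directly (requires the bedrock extra); the region, bucket, and role ARN below are placeholders, and each item is a chat-message list as expected by _build_claude_messages_payload.

import asyncio
from scorebook.inference import bedrock

items = [
    [
        {"role": "system", "content": "Answer with a single letter."},
        {"role": "user", "content": "What is 2 + 2? A) 3  B) 4"},
    ],
]

outputs = asyncio.run(
    bedrock.batch(
        items,
        model="us.anthropic.claude-3-5-sonnet-20241022-v2:0",
        aws_region="us-east-1",                                   # placeholder region
        bucket="my-scorebook-batches",                            # placeholder S3 bucket
        input_prefix="inputs",
        output_prefix="outputs",
        role_arn="arn:aws:iam::123456789012:role/bedrock-batch",  # placeholder execution role
        temperature=0.0,                                          # merged into each request payload
    )
)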

scorebook-0.0.2/src/scorebook/inference/vertex.py
@@ -0,0 +1,295 @@
+"""
+Google Cloud Vertex AI batch inference implementation for Scorebook.
+
+This module provides utilities for running batch inference using Google Cloud
+Vertex AI Gemini models, supporting large-scale asynchronous processing. It handles
+API communication, request formatting, response processing, and Cloud Storage operations.
+"""
+
+import asyncio
+import json
+import os
+import tempfile
+from datetime import datetime
+from typing import Any, Dict, List, Optional, Union
+
+import fsspec
+import pandas as pd
+from google import genai
+from google.cloud import storage
+from google.genai import types
+from tqdm.asyncio import tqdm
+
+
+async def responses(
+    items: List[
+        Union[
+            str,
+            List[str],
+            types.Content,
+            List[types.Content],
+            types.FunctionCall,
+            List[types.FunctionCall],
+            types.Part,
+            List[types.Part],
+        ]
+    ],
+    model: str,
+    client: Optional[genai.Client] = None,
+    project_id: Optional[str] = None,
+    location: str = "us-central1",
+    system_instruction: Optional[str] = None,
+    **hyperparameters: Any,
+) -> List[types.GenerateContentResponse]:
+    """Process multiple inference requests using Google Cloud Vertex AI.
+
+    This asynchronous function handles multiple inference requests,
+    manages the API communication, and processes the responses.
+
+    Args:
+        items: List of preprocessed items to process.
+        model: Gemini model ID to use (e.g., 'gemini-2.0-flash-001').
+        client: Optional Vertex AI client instance.
+        project_id: Google Cloud Project ID. If None, uses GOOGLE_CLOUD_PROJECT env var.
+        location: Google Cloud region (default: 'us-central1').
+        system_instruction: Optional system instruction to guide model behavior.
+        hyperparameters: Additional parameters for the requests.
+
+    Returns:
+        List of raw model responses.
+    """
+    if client is None:
+        if project_id is None:
+            project_id = os.getenv("GOOGLE_CLOUD_PROJECT")
+        if not project_id:
+            raise ValueError(
+                "Project ID must be provided or set in GOOGLE_CLOUD_PROJECT "
+                "environment variable"
+            )
+
+        client = genai.Client(
+            vertexai=True,
+            project=project_id,
+            location=location,
+            http_options=types.HttpOptions(api_version="v1"),
+        )
+
+    # Create config if system_instruction or hyperparameters are provided
+    config = None
+    if system_instruction or hyperparameters:
+        config_dict = {}
+        if system_instruction:
+            config_dict["system_instruction"] = system_instruction
+        if hyperparameters:
+            config_dict.update(hyperparameters)
+        config = types.GenerateContentConfig(**config_dict)
+
+    results = []
+    for item in items:
+        response = client.models.generate_content(
+            model=model,
+            contents=item,
+            config=config,
+        )
+        results.append(response)
+
+    return results
+
+
+async def batch(
+    items: List[Any],
+    model: str,
+    project_id: Optional[str] = None,
+    location: str = "us-central1",
+    input_bucket: Optional[str] = None,
+    output_bucket: Optional[str] = None,
+    **hyperparameters: Any,
+) -> List[Any]:
+    """Process multiple inference requests in batch using Google Cloud Vertex AI.
+
+    This asynchronous function handles batch processing of inference requests,
+    optimizing for cost and throughput using Google Cloud's batch prediction API.
+
+    Args:
+        items: List of preprocessed items to process.
+        model: Gemini model ID to use (e.g., 'gemini-2.0-flash-001').
+        project_id: Google Cloud Project ID. If None, uses GOOGLE_CLOUD_PROJECT env var.
+        location: Google Cloud region (default: 'us-central1').
+        input_bucket: GCS bucket for input data (required).
+        output_bucket: GCS bucket for output data (required).
+        hyperparameters: Additional parameters for the batch requests.
+
+    Returns:
+        A list of raw model responses.
+    """
+    # Set up project ID
+    if project_id is None:
+        project_id = os.getenv("GOOGLE_CLOUD_PROJECT")
+    if not project_id:
+        raise ValueError(
+            "Project ID must be provided or set in GOOGLE_CLOUD_PROJECT " "environment variable"
+        )
+
+    if not input_bucket or not output_bucket:
+        raise ValueError("Both input_bucket and output_bucket must be provided")
+
+    # Initialize client
+    client = genai.Client(vertexai=True, project=project_id, location=location)
+
+    # Upload batch data
+    input_uri = await _upload_batch(items, input_bucket, model, project_id, **hyperparameters)
+
+    # Start batch job
+    batch_job = await _start_batch_job(client, model, input_uri, output_bucket)
+
+    # Wait for completion with progress tracking
+    await _wait_for_completion(client, batch_job, len(items))
+
+    # Retrieve results
+    results = await _get_batch_results(batch_job)
+
+    return results
+
+
+async def _upload_batch(
+    items: List[Any], input_bucket: str, model: str, project_id: str, **hyperparameters: Any
+) -> str:
+    """Create a JSONL file from preprocessed items and upload to GCS for batch processing."""
+
+    # Create temp JSONL file
+    with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
+        for item in items:
+            # Construct batch request in Vertex AI format
+            request_data: Dict[str, Any] = {
+                "request": {
+                    "contents": [
+                        {
+                            "role": "user",
+                            "parts": [
+                                {
+                                    "text": (
+                                        str(item)
+                                        if not isinstance(item, list)
+                                        else item[0]["content"]
+                                    )
+                                }
+                            ],
+                        }
+                    ]
+                }
+            }
+
+            # Only add generationConfig if hyperparameters are provided
+            if hyperparameters:
+                request_data["request"]["generationConfig"] = hyperparameters
+            f.write(json.dumps(request_data) + "\n")
+        file_path = f.name
+
+    # Upload to GCS using Python client
+    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
+
+    # Parse bucket name and path from input_bucket
+    if input_bucket.startswith("gs://"):
+        bucket_path = input_bucket[5:]  # Remove 'gs://' prefix
+    else:
+        bucket_path = input_bucket
+
+    # Split bucket name and path
+    bucket_name = bucket_path.split("/")[0]
+    bucket_prefix = "/".join(bucket_path.split("/")[1:]) if "/" in bucket_path else ""
+
+    # Create blob path
+    blob_name = (
+        f"{bucket_prefix}/batch_input_{timestamp}.jsonl"
+        if bucket_prefix
+        else f"batch_input_{timestamp}.jsonl"
+    )
+
+    # Upload using Google Cloud Storage client
+    try:
+        gcs_client = storage.Client(project=project_id)
+        bucket = gcs_client.bucket(bucket_name)
+        blob = bucket.blob(blob_name)
+
+        with open(file_path, "rb") as f:
+            blob.upload_from_file(f)
+
+        input_uri = f"gs://{bucket_name}/{blob_name}"
+
+    except Exception as e:
+        raise Exception(f"Failed to upload file to GCS: {e}")
+    finally:
+        # Clean up temp file
+        os.unlink(file_path)
+
+    return input_uri
+
+
+async def _start_batch_job(
+    client: genai.Client, model: str, input_uri: str, output_bucket: str
+) -> Any:
+    """Start a batch prediction job."""
+    batch_job = client.batches.create(
+        model=model,
+        src=input_uri,
+        config=types.CreateBatchJobConfig(dest=output_bucket),
+    )
+    return batch_job
+
+
+async def _wait_for_completion(client: genai.Client, batch_job: Any, total_items: int) -> None:
+    """Wait for batch job completion with progress tracking."""
+    # Initialize progress bar
+    pbar = tqdm(total=total_items, desc="Batch processing", unit="requests")
+
+    while True:
+        # Refresh job status
+        batch_job = client.batches.get(name=batch_job.name)
+        state = batch_job.state
+
+        # Update progress bar
+        pbar.set_postfix(status=str(state).replace("JobState.JOB_STATE_", ""))
+        pbar.refresh()
+
+        if state.name == "JOB_STATE_SUCCEEDED":
+            pbar.n = pbar.total
+            pbar.set_postfix(status="COMPLETED")
+            break
+        elif state.name == "JOB_STATE_FAILED":
+            pbar.close()
+            error_msg = getattr(batch_job, "error", "Unknown error")
+            raise Exception(f"Batch job failed: {error_msg}")
+        elif state.name in ["JOB_STATE_CANCELLED", "JOB_STATE_PAUSED"]:
+            pbar.close()
+            raise Exception(f"Batch job was {state.name}")
+
+        # Wait before checking again
+        await asyncio.sleep(30)
+
+    pbar.close()
+
+
+async def _get_batch_results(batch_job: Any) -> List[str]:
+    """Download and parse batch results from GCS."""
+
+    # Set up GCS filesystem
+    fs = fsspec.filesystem("gcs")
+
+    # Find predictions file - the pattern is: dest_uri/prediction-model-*/predictions.jsonl
+    output_uri = batch_job.dest.gcs_uri.rstrip("/")
+    file_paths = fs.glob(f"{output_uri}/prediction-model-*/predictions.jsonl")
+
+    if not file_paths:
+        raise Exception("No predictions file found in output bucket")
+
+    # Load and parse results
+    df = pd.read_json(f"gs://{file_paths[0]}", lines=True)
+
+    results = []
+    for _, row in df.iterrows():
+        # Extract text content from successful responses
+        response = row["response"]
+        text_content = response["candidates"][0]["content"]["parts"][0]["text"]
+        results.append(text_content)
+
+    return results
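
Note: a minimal sketch of the non-batch responses() helper (requires the vertex extra); the project ID is a placeholder, and extra keyword arguments are forwarded into GenerateContentConfig.

import asyncio
from scorebook.inference import vertex

replies = asyncio.run(
    vertex.responses(
        ["What is the capital of France?"],
        model="gemini-2.0-flash-001",
        project_id="my-gcp-project",              # placeholder project
        system_instruction="Answer in one word.",
        temperature=0.0,                          # forwarded via GenerateContentConfig
    )
)
print(replies[0].text)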

{scorebook-0.0.1 → scorebook-0.0.2}/src/scorebook/types/eval_dataset.py
@@ -4,6 +4,7 @@ import csv
 import json
 from typing import Any, Dict, Iterator, List, Optional, Type, Union
 
+import yaml
 from datasets import Dataset as HuggingFaceDataset
 from datasets import DatasetDict as HuggingFaceDatasetDict
 from datasets import load_dataset
@@ -21,6 +22,7 @@ class EvalDataset:
         label: str,
         metrics: Union[str, Type[MetricBase], List[Union[str, Type[MetricBase]]]],
         hf_dataset: HuggingFaceDataset,
+        prompt_template: Optional[str] = None,
     ):
         """
         Create a new scorebook evaluation dataset instance.
@@ -29,11 +31,13 @@ class EvalDataset:
         :param label: The label field of the dataset.
         :param metrics: The specified metrics associated with the dataset.
         :param hf_dataset: The dataset as a hugging face dataset object.
+        :param prompt_template: Optional prompt template for building prompts from dataset items.
         """
         self.name: str = name
         self.label: str = label
         self.metrics: List[MetricBase] = self._resolve_metrics(metrics)
         self._hf_dataset: Optional[HuggingFaceDataset] = hf_dataset
+        self.prompt_template: Optional[str] = prompt_template
 
     def __len__(self) -> int:
         """Return the number of items in the dataset."""
@@ -286,6 +290,52 @@ class EvalDataset:
 
         return cls(name=path, label=label, metrics=metrics, hf_dataset=hf_dataset)
 
+    @classmethod
+    def from_yaml(cls, file_path: str) -> "EvalDataset":
+        """Instantiate an EvalDataset from a YAML file.
+
+        The YAML file should contain configuration for loading a dataset, including:
+        - name: Name of the dataset or Hugging Face dataset path
+        - label: The field used as the evaluation label
+        - metrics: List of metrics to evaluate
+        - split: Optional split name to load
+        - template: Optional prompt template
+
+        Returns:
+            An EvalDataset instance configured according to the YAML file.
+
+        Raises:
+            ValueError: If the YAML file is invalid or missing required fields.
+        """
+        path = validate_path(file_path, expected_suffix=".yaml")
+
+        try:
+            with path.open("r", encoding="utf-8") as f:
+                config = yaml.safe_load(f)
+        except yaml.YAMLError as e:
+            raise ValueError(f"Invalid YAML in {file_path}: {e}") from e
+
+        # Validate required fields
+        required_fields = ["name", "label", "metrics"]
+        missing_fields = [field for field in required_fields if field not in config]
+        if missing_fields:
+            raise ValueError(f"Missing required fields in YAML config: {', '.join(missing_fields)}")
+
+        # Load the dataset from Hugging Face
+        dataset = cls.from_huggingface(
+            path=config["name"],
+            label=config["label"],
+            metrics=config["metrics"],
+            split=config.get("split"),  # Optional field
+            name=config.get("config"),  # Optional field
+        )
+
+        # Add template if provided
+        if "template" in config:
+            dataset.prompt_template = config["template"]
+
+        return dataset
+
     @staticmethod
     def _resolve_metrics(
         metrics: Union[
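
Note: a sketch of the YAML layout from_yaml() expects, written and loaded from Python; the dataset path, label field, and metric name below are assumptions for illustration only.

from pathlib import Path
from scorebook import EvalDataset

Path("gsm8k.yaml").write_text(
    "name: openai/gsm8k\n"                     # hypothetical Hugging Face dataset path
    "label: answer\n"                          # field holding the reference label
    "metrics:\n"
    "  - accuracy\n"                           # assumed registered metric name
    "split: test\n"                            # optional
    "config: main\n"                           # optional, forwarded as the HF config name
    "template: 'Question: {{ question }}'\n"   # optional prompt template
)

dataset = EvalDataset.from_yaml("gsm8k.yaml")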

{scorebook-0.0.1 → scorebook-0.0.2}/src/scorebook/types/eval_result.py
@@ -42,6 +42,7 @@ class EvalResult:
             result = {
                 "item_id": idx,
                 "dataset_name": self.eval_dataset.name,
+                "inference_output": self.inference_outputs[idx],
                 **{
                     metric: self.metric_scores[metric]["item_scores"][idx]
                     for metric in metric_names
@@ -74,13 +75,13 @@ class EvalResult:
     def to_dict(self) -> Dict[str, Any]:
         """Return a dictionary representing the evaluation results."""
         return {
-            "
+            "aggregate_results": [
                 {
                     **getattr(self.eval_dataset, "hyperparams", {}),
                     **self.aggregate_scores,
                 }
             ],
-            "
+            "item_results": [item for item in self.item_scores],
         }
 
     def to_csv(self, file_path: str) -> None:
@@ -113,7 +114,10 @@ class EvalResult:
             writer.writerow(row)
 
     def to_json(self, file_path: str) -> None:
-        """Save evaluation results to a JSON file in structured format
+        """Save evaluation results to a JSON file in a structured format.
+
+        The JSON file will contain both aggregate & item results, produced by the to_dict() method.
+        """
         Path(file_path).parent.mkdir(parents=True, exist_ok=True)
         with open(file_path, "w") as f:
             json.dump(self.to_dict(), f, indent=2)
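
Note: with the renamed keys, to_dict() (and therefore to_json()) now produces roughly the shape below; the field values are invented for illustration.

{
    "aggregate_results": [{"temperature": 0.0, "accuracy": 0.82}],
    "item_results": [
        {"item_id": 0, "dataset_name": "my_dataset", "inference_output": "B", "accuracy": 1},
    ],
}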

{scorebook-0.0.1 → scorebook-0.0.2}/src/scorebook/types/inference_pipeline.py
@@ -57,7 +57,7 @@ class InferencePipeline:
             List of processed outputs after running through the complete pipeline
         """
         if self.preprocessor:
-            input_items = [self.preprocessor(item) for item in items]
+            input_items = [self.preprocessor(item, hyperparameters) for item in items]
         else:
             input_items = items
 
@@ -67,7 +67,10 @@ class InferencePipeline:
         inference_outputs = self.inference_function(input_items, **hyperparameters)
 
         if self.postprocessor:
-            return [
+            return [
+                self.postprocessor(inference_output, hyperparameters)
+                for inference_output in inference_outputs
+            ]
         else:
             return cast(List[Any], inference_outputs)
 
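
Note: pre- and post-processors are now called with the active hyperparameter dict as a second positional argument, so single-argument hooks need updating; a minimal sketch (the "question" field is an assumption).

def preprocessor(item, hyperparameters):
    # hyperparameters is the dict for the current sweep configuration
    return [{"role": "user", "content": item["question"]}]   # assumes a "question" field

def postprocessor(inference_output, hyperparameters):
    return str(inference_output).strip()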

{scorebook-0.0.1 → scorebook-0.0.2}/src/scorebook/utils/__init__.py
@@ -1,8 +1,9 @@
 """Utility functions and common helpers for the Scorebook framework."""
 
 from scorebook.utils.async_utils import is_awaitable
+from scorebook.utils.build_prompt import build_prompt
 from scorebook.utils.io_helpers import validate_path
 from scorebook.utils.progress_bars import evaluation_progress
 from scorebook.utils.transform_helpers import expand_dict
 
-__all__ = ["is_awaitable", "validate_path", "expand_dict", "evaluation_progress"]
+__all__ = ["is_awaitable", "validate_path", "expand_dict", "evaluation_progress", "build_prompt"]

scorebook-0.0.2/src/scorebook/utils/build_prompt.py
@@ -0,0 +1,52 @@
+"""
+Module for building prompt strings using Jinja2 templating.
+
+Provides functionality to render prompts from templates with custom filters
+and global variables, using strict undefined handling for better error detection.
+"""
+
+from typing import Any, Dict, Optional
+
+from jinja2 import BaseLoader, Environment, StrictUndefined
+
+from scorebook.utils.jinja_helpers import default_jinja_filters, default_jinja_globals
+
+
+def build_prompt(
+    prompt_template: str,
+    prompt_args: Dict[str, Any],
+    filters: Optional[Dict[str, Any]] = None,
+    globals_dict: Optional[Dict[str, Any]] = None,
+) -> str:
+    """
+    Build a prompt string from a template and arguments.
+
+    Args:
+        prompt_template: Jinja2 template string
+        prompt_args: Dictionary of arguments to pass to the template
+        filters: Dictionary of Jinja2 filters. Defaults to default_jinja_filters().
+        globals_dict: Dictionary of global functions/variables. Defaults to default_jinja_globals().
+
+    Returns:
+        str: Rendered prompt string
+    """
+
+    # Use defaults if not provided
+    filters = filters or default_jinja_filters()
+    globals_dict = globals_dict or default_jinja_globals()
+
+    # Create a Jinja2 environment with strict undefined handling
+    env = Environment(
+        loader=BaseLoader(),
+        undefined=StrictUndefined,
+        trim_blocks=True,
+        lstrip_blocks=True,
+    )
+
+    # Add filters and globals
+    env.filters.update(filters)
+    env.globals.update(globals_dict)
+
+    # Render the template
+    template = env.from_string(prompt_template)
+    return str(template.render(**prompt_args))
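
Note: build_prompt() is exported from the top-level package; a minimal rendering sketch.

from scorebook import build_prompt

prompt = build_prompt(
    "Question: {{ question }}\nAnswer:",
    {"question": "What is 2 + 2?"},
)
# -> "Question: What is 2 + 2?\nAnswer:"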

scorebook-0.0.2/src/scorebook/utils/jinja_helpers.py
@@ -0,0 +1,146 @@
+"""Jinja2 template helper functions for Scorebook."""
+
+import json
+import re
+from typing import Any, Dict, List
+
+# Helper functions for use in Jinja templates
+
+
+def number_to_letter(index: int, uppercase: bool = True) -> str:
+    """Convert a number to a letter (0->A, 1->B, etc.).
+
+    Args:
+        index: The number to convert to a letter (0-based index, must be 0-25)
+        uppercase: If True, returns uppercase letter; if False, returns lowercase
+
+    Returns:
+        str: A letter from A-Z (or a-z if uppercase is False)
+
+    Raises:
+        ValueError: If index is less than 0 or greater than 25
+    """
+    if not 0 <= index <= 25:
+        raise ValueError("Index must be between 0 and 25 inclusive")
+    letter = chr(65 + index)
+    return letter if uppercase else letter.lower()
+
+
+def letter_to_number(letter: str) -> int:
+    """Convert a letter to a number (A->0, B->1, etc.).
+
+    Args:
+        letter: A single letter character (A-Z or a-z)
+
+    Returns:
+        int: The zero-based position of the letter in the alphabet
+
+    Raises:
+        ValueError: If the input is not a single letter character
+    """
+    if not letter.isalpha() or len(letter) != 1:
+        raise ValueError("Input must be a single letter character")
+    return ord(letter.upper()) - 65
+
+
+def format_list(items: List[Any], separator: str = ", ", last_separator: str = " and ") -> str:
+    """Format a list with proper separators and conjunction.
+
+    Examples:
+        format_list(["a", "b", "c"]) -> "a, b and c"
+        format_list(["a", "b"]) -> "a and b"
+        format_list(["a"]) -> "a"
+    """
+    if not items:
+        return ""
+    if len(items) == 1:
+        return str(items[0])
+    if len(items) == 2:
+        return f"{items[0]}{last_separator}{items[1]}"
+    return f"{separator.join(str(item) for item in items[:-1])}{last_separator}{items[-1]}"
+
+
+def truncate_text(text: str, max_length: int, suffix: str = "...") -> str:
+    """Truncate text to a maximum length with optional suffix."""
+    if len(text) <= max_length:
+        return text
+    return text[: max_length - len(suffix)] + suffix
+
+
+def format_number(number: float, precision: int = 2) -> str:
+    """Format a number with specified decimal places."""
+    return f"{number:.{precision}f}"
+
+
+def extract_initials(text: str) -> str:
+    """Extract initials from a text string.
+
+    Examples:
+        extract_initials("John Doe") -> "JD"
+        extract_initials("Machine Learning Model") -> "MLM"
+    """
+    words = re.findall(r"\b[A-Za-z]", text)
+    return "".join(words).upper()
+
+
+def json_pretty(obj: Any, indent: int = 2) -> str:
+    """Pretty-print an object as JSON."""
+    return json.dumps(obj, indent=indent, ensure_ascii=False)
+
+
+def percentage(value: float, total: float, precision: int = 1) -> str:
+    """Calculate and format a percentage.
+
+    Examples:
+        percentage(25, 100) -> "25.0%"
+        percentage(1, 3, 2) -> "33.33%"
+    """
+    if total == 0:
+        return "0.0%"
+    pct = (value / total) * 100
+    return f"{pct:.{precision}f}%"
+
+
+def ordinal(n: int) -> str:
+    """Convert number to ordinal format like 1st, 2nd, 3rd, etc."""
+    if 10 <= n % 100 <= 20:
+        suffix = "th"
+    else:
+        suffix = {1: "st", 2: "nd", 3: "rd"}.get(n % 10, "th")
+    return f"{n}{suffix}"
+
+
+def default_jinja_globals() -> Dict[str, Any]:
+    """Get default global functions for Jinja templates."""
+    return {
+        "number_to_letter": number_to_letter,
+        "letter_to_number": letter_to_number,
+        "format_list": format_list,
+        "truncate_text": truncate_text,
+        "format_number": format_number,
+        "extract_initials": extract_initials,
+        "json_pretty": json_pretty,
+        "percentage": percentage,
+        "ordinal": ordinal,
+        "max": max,
+        "min": min,
+        "len": len,
+        "abs": abs,
+        "round": round,
+        "sum": sum,
+        "sorted": sorted,
+        "enumerate": enumerate,
+        "zip": zip,
+        "range": range,
+    }
+
+
+def default_jinja_filters() -> Dict[str, Any]:
+    """Get default filters for Jinja templates."""
+    return {
+        "chr": chr,
+        "ord": ord,
+        "abs": abs,
+        "round": round,
+        "len": len,
+    }
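
Note: these helpers are installed as Jinja globals/filters by default, so templates rendered through build_prompt() can call them directly; a multiple-choice sketch with made-up question and choices.

from scorebook import build_prompt

template = (
    "{{ question }}\n"
    "{% for choice in choices %}"
    "{{ number_to_letter(loop.index0) }}) {{ choice }}\n"
    "{% endfor %}"
)
print(build_prompt(template, {"question": "Pick the even number.", "choices": [3, 4, 5]}))
# Pick the even number.
# A) 3
# B) 4
# C) 5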