scorebook 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scorebook/__init__.py +2 -1
- scorebook/evaluator.py +269 -118
- scorebook/exceptions.py +54 -0
- scorebook/inference/__init__.py +0 -4
- scorebook/inference/bedrock.py +305 -0
- scorebook/inference/openai.py +75 -37
- scorebook/inference/vertex.py +295 -0
- scorebook/types/__init__.py +2 -1
- scorebook/types/eval_dataset.py +56 -0
- scorebook/types/eval_result.py +7 -3
- scorebook/types/eval_run_spec.py +28 -0
- scorebook/types/inference_pipeline.py +5 -2
- scorebook/utils/__init__.py +2 -1
- scorebook/utils/build_prompt.py +52 -0
- scorebook/utils/jinja_helpers.py +146 -0
- scorebook/utils/logging_utils.py +1 -0
- scorebook/utils/progress_bars.py +91 -34
- {scorebook-0.0.1.dist-info → scorebook-0.0.3.dist-info}/METADATA +11 -1
- scorebook-0.0.3.dist-info/RECORD +31 -0
- scorebook-0.0.1.dist-info/RECORD +0 -24
- {scorebook-0.0.1.dist-info → scorebook-0.0.3.dist-info}/LICENSE +0 -0
- {scorebook-0.0.1.dist-info → scorebook-0.0.3.dist-info}/WHEEL +0 -0
|
@@ -0,0 +1,295 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Google Cloud Vertex AI batch inference implementation for Scorebook.
|
|
3
|
+
|
|
4
|
+
This module provides utilities for running batch inference using Google Cloud
|
|
5
|
+
Vertex AI Gemini models, supporting large-scale asynchronous processing. It handles
|
|
6
|
+
API communication, request formatting, response processing, and Cloud Storage operations.
|
|
7
|
+
"""
|
|
8
|
+
|
|
9
|
+
import asyncio
|
|
10
|
+
import json
|
|
11
|
+
import os
|
|
12
|
+
import tempfile
|
|
13
|
+
from datetime import datetime
|
|
14
|
+
from typing import Any, Dict, List, Optional, Union
|
|
15
|
+
|
|
16
|
+
import fsspec
|
|
17
|
+
import pandas as pd
|
|
18
|
+
from google import genai
|
|
19
|
+
from google.cloud import storage
|
|
20
|
+
from google.genai import types
|
|
21
|
+
from tqdm.asyncio import tqdm
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
async def responses(
    items: List[
        Union[
            str,
            List[str],
            types.Content,
            List[types.Content],
            types.FunctionCall,
            List[types.FunctionCall],
            types.Part,
            List[types.Part],
        ]
    ],
    model: str,
    client: Optional[genai.Client] = None,
    project_id: Optional[str] = None,
    location: str = "us-central1",
    system_instruction: Optional[str] = None,
    **hyperparameters: Any,
) -> List[types.GenerateContentResponse]:
    """Process multiple inference requests using Google Cloud Vertex AI.

    One ``generate_content`` request is sent per item. Requests go through the
    SDK's async client (``client.aio``), so the event loop is not blocked while
    waiting on the network. They are awaited one after another, so at most one
    request is in flight at a time and the output order matches ``items``.

    Args:
        items: List of preprocessed items to process; each is passed as ``contents``.
        model: Gemini model ID to use (e.g., 'gemini-2.0-flash-001').
        client: Optional Vertex AI client instance. When omitted, a Vertex AI
            client using API version "v1" is created for ``project_id``/``location``.
        project_id: Google Cloud Project ID. If None, uses GOOGLE_CLOUD_PROJECT env var.
            Only consulted when ``client`` is None.
        location: Google Cloud region (default: 'us-central1').
        system_instruction: Optional system instruction to guide model behavior.
        hyperparameters: Additional ``GenerateContentConfig`` fields for every request.
            A ``system_instruction`` key here overrides the argument of that name.

    Returns:
        List of raw model responses, one per item, in input order.

    Raises:
        ValueError: If no client is given and no project ID is available.
    """
    if client is None:
        if project_id is None:
            project_id = os.getenv("GOOGLE_CLOUD_PROJECT")
        if not project_id:
            raise ValueError(
                "Project ID must be provided or set in GOOGLE_CLOUD_PROJECT "
                "environment variable"
            )

        client = genai.Client(
            vertexai=True,
            project=project_id,
            location=location,
            http_options=types.HttpOptions(api_version="v1"),
        )

    # Only build a config when there is something to put in it.
    config: Optional[types.GenerateContentConfig] = None
    if system_instruction or hyperparameters:
        config_dict: Dict[str, Any] = {}
        if system_instruction:
            config_dict["system_instruction"] = system_instruction
        if hyperparameters:
            # Applied after system_instruction, so an explicit hyperparameter wins.
            config_dict.update(hyperparameters)
        config = types.GenerateContentConfig(**config_dict)

    results: List[types.GenerateContentResponse] = []
    for item in items:
        # The async client yields control to the event loop while the request
        # is in flight, unlike the blocking ``client.models`` API.
        response = await client.aio.models.generate_content(
            model=model,
            contents=item,
            config=config,
        )
        results.append(response)

    return results
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
async def batch(
    items: List[Any],
    model: str,
    project_id: Optional[str] = None,
    location: str = "us-central1",
    input_bucket: Optional[str] = None,
    output_bucket: Optional[str] = None,
    **hyperparameters: Any,
) -> List[Any]:
    """Process multiple inference requests in batch using Google Cloud Vertex AI.

    The items are written as a JSONL request file into ``input_bucket`` and a
    batch prediction job is started that writes into ``output_bucket``. The job
    is polled until it finishes, and then its predictions are downloaded and parsed.

    Args:
        items: List of preprocessed items to process, one request per item.
        model: Gemini model ID to use (e.g., 'gemini-2.0-flash-001').
        project_id: Google Cloud Project ID. If None, uses GOOGLE_CLOUD_PROJECT env var.
        location: Google Cloud region (default: 'us-central1').
        input_bucket: GCS location for the request file, as ``gs://bucket[/prefix]``
            or ``bucket[/prefix]`` (required).
        output_bucket: GCS destination passed to the batch job (required);
            presumably a ``gs://`` URI — confirm against the Vertex AI docs.
        hyperparameters: Generation parameters sent with every request.

    Returns:
        The parsed predictions: the response text extracted from each row of the
        job's ``predictions.jsonl``. An empty ``items`` list returns ``[]``
        without uploading anything or starting a job.

    Raises:
        ValueError: If no project ID is available, or either bucket is missing.
    """
    # Set up project ID
    if project_id is None:
        project_id = os.getenv("GOOGLE_CLOUD_PROJECT")
    if not project_id:
        raise ValueError(
            "Project ID must be provided or set in GOOGLE_CLOUD_PROJECT "
            "environment variable"
        )

    if not input_bucket or not output_bucket:
        raise ValueError("Both input_bucket and output_bucket must be provided")

    # Nothing to do: avoid uploading an empty file and starting a (billed) job.
    if not items:
        return []

    client = genai.Client(vertexai=True, project=project_id, location=location)

    input_uri = await _upload_batch(items, input_bucket, model, project_id, **hyperparameters)
    batch_job = await _start_batch_job(client, model, input_uri, output_bucket)

    # Blocks (asynchronously) until the job reaches a terminal state.
    await _wait_for_completion(client, batch_job, len(items))

    return await _get_batch_results(batch_job)
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def _build_batch_request(item: Any, hyperparameters: Dict[str, Any]) -> Dict[str, Any]:
    """Build one Vertex AI batch-request record (``{"request": {...}}``) for ``item``.

    A list item is treated as chat messages: dicts with a ``"content"`` key and an
    optional ``"role"`` key (default ``"user"``). ``"system"`` messages become the
    request's ``systemInstruction``, ``"assistant"`` messages become ``"model"``
    turns, and any other role becomes a ``"user"`` turn. Any non-list item is sent
    as a single user turn containing ``str(item)``.

    Raises:
        ValueError: If ``item`` is an empty list.
    """
    request: Dict[str, Any] = {}
    if isinstance(item, list):
        if not item:
            raise ValueError("Cannot build a batch request from an empty message list")
        contents: List[Dict[str, Any]] = []
        system_parts: List[Dict[str, Any]] = []
        for message in item:
            part = {"text": message["content"]}
            role = message.get("role", "user")
            if role == "system":
                system_parts.append(part)
            else:
                # Gemini names the assistant role "model".
                contents.append(
                    {"role": "model" if role == "assistant" else "user", "parts": [part]}
                )
        if contents:
            if system_parts:
                request["systemInstruction"] = {"parts": system_parts}
        else:
            # Only system messages were given: send them as the user turn rather
            # than an empty ``contents`` list (presumably rejected by the API).
            contents = [{"role": "user", "parts": system_parts}]
        request["contents"] = contents
    else:
        request["contents"] = [{"role": "user", "parts": [{"text": str(item)}]}]

    # Only add generationConfig if hyperparameters are provided
    if hyperparameters:
        request["generationConfig"] = hyperparameters
    return {"request": request}


async def _upload_batch(
    items: List[Any], input_bucket: str, model: str, project_id: str, **hyperparameters: Any
) -> str:
    """Create a JSONL file from preprocessed items and upload it to GCS for batch processing.

    Args:
        items: Preprocessed items; each becomes one request line (see
            ``_build_batch_request``).
        input_bucket: Destination as ``gs://bucket[/prefix]`` or ``bucket[/prefix]``;
            leading and trailing slashes around the path are ignored.
        model: Not used when building the file.
        project_id: Project used for the Cloud Storage client.
        hyperparameters: Sent as every request's ``generationConfig``.

    Returns:
        The ``gs://`` URI of the uploaded JSONL file.

    Raises:
        Exception: If the upload to Cloud Storage fails (the cause is chained).
    """
    import uuid  # local: only needed to make the blob name unique

    timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
    # The random suffix keeps runs started within the same second from
    # overwriting each other's input file.
    file_name = f"batch_input_{timestamp}_{uuid.uuid4().hex[:8]}.jsonl"

    # Parse bucket name and optional path prefix from input_bucket.
    bucket_path = input_bucket[len("gs://") :] if input_bucket.startswith("gs://") else input_bucket
    bucket_name, _, bucket_prefix = bucket_path.strip("/").partition("/")
    blob_name = f"{bucket_prefix}/{file_name}" if bucket_prefix else file_name

    fd, file_path = tempfile.mkstemp(suffix=".jsonl")
    try:
        # Writing happens inside the try so the temp file is removed even if
        # serialising an item fails.
        with os.fdopen(fd, "w", encoding="utf-8") as f:
            for item in items:
                f.write(json.dumps(_build_batch_request(item, hyperparameters)) + "\n")

        try:
            gcs_client = storage.Client(project=project_id)
            blob = gcs_client.bucket(bucket_name).blob(blob_name)
            with open(file_path, "rb") as f:
                blob.upload_from_file(f)
        except Exception as e:
            raise Exception(f"Failed to upload file to GCS: {e}") from e
    finally:
        os.unlink(file_path)

    return f"gs://{bucket_name}/{blob_name}"
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
async def _start_batch_job(
    client: genai.Client, model: str, input_uri: str, output_bucket: str
) -> Any:
    """Submit a Vertex AI batch prediction job and hand back the created job.

    Args:
        client: Vertex AI client used to create the job.
        model: Model ID the job runs.
        input_uri: ``gs://`` URI of the JSONL request file.
        output_bucket: Destination given to the job config as ``dest``.

    Returns:
        The job object returned by ``client.batches.create``.
    """
    job_config = types.CreateBatchJobConfig(dest=output_bucket)
    return client.batches.create(model=model, src=input_uri, config=job_config)
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
async def _wait_for_completion(
    client: genai.Client, batch_job: Any, total_items: int, poll_interval: float = 30
) -> None:
    """Poll a batch job until it finishes, showing its state on a progress bar.

    Args:
        client: Client used to refresh the job status.
        batch_job: The job to wait for; only its ``name`` is used to look it up.
        total_items: Number of submitted requests, used as the progress-bar total.
        poll_interval: Seconds to wait between status checks (default 30).

    Raises:
        Exception: If the job ends FAILED, CANCELLED, PAUSED or EXPIRED.
    """
    pbar = tqdm(total=total_items, desc="Batch processing", unit="requests")
    try:
        while True:
            # Refresh job status
            batch_job = client.batches.get(name=batch_job.name)
            state = batch_job.state
            # Guard against a job that reports no state yet.
            state_name = state.name if state is not None else "JOB_STATE_UNSPECIFIED"

            pbar.set_postfix(status=state_name.replace("JOB_STATE_", ""))
            pbar.refresh()

            if state_name == "JOB_STATE_SUCCEEDED":
                pbar.n = pbar.total
                pbar.set_postfix(status="COMPLETED")
                return
            if state_name == "JOB_STATE_PARTIALLY_SUCCEEDED":
                # Terminal: the job has finished, though some requests failed.
                pbar.n = pbar.total
                pbar.set_postfix(status="PARTIALLY_SUCCEEDED")
                return
            if state_name == "JOB_STATE_FAILED":
                error_msg = getattr(batch_job, "error", None) or "Unknown error"
                raise Exception(f"Batch job failed: {error_msg}")
            if state_name in ("JOB_STATE_CANCELLED", "JOB_STATE_PAUSED", "JOB_STATE_EXPIRED"):
                # EXPIRED is terminal too; without this check the loop polled forever.
                raise Exception(f"Batch job was {state_name}")

            # Wait before checking again
            await asyncio.sleep(poll_interval)
    finally:
        # Close the bar on success, failure, or task cancellation alike.
        pbar.close()
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
def _extract_first_text(response: Any) -> Optional[str]:
    """Return the text of the first part of the first candidate, or None if there is none."""
    # Rows whose request failed carry no response object (pandas fills the cell,
    # e.g. with NaN), so anything that is not a dict has no text.
    if not isinstance(response, dict):
        return None
    candidates = response.get("candidates") or []
    if not candidates:
        return None
    parts = (candidates[0].get("content") or {}).get("parts") or []
    if not parts:
        return None
    text = parts[0].get("text")
    return text if isinstance(text, str) else None


async def _get_batch_results(batch_job: Any) -> List[str]:
    """Download and parse batch results from GCS.

    Args:
        batch_job: The finished job; ``batch_job.dest.gcs_uri`` is its output prefix.

    Returns:
        One string per row of the predictions file: the text of the first part of
        the first candidate. A row without such text (failed request, no
        candidates, non-text part) yields ``""``, and a RuntimeWarning reports how
        many rows did so. This keeps one bad row from discarding the whole batch.

    Raises:
        Exception: If no predictions file exists under the output prefix.
    """
    import warnings  # local: only needed to report unusable rows

    fs = fsspec.filesystem("gcs")

    # Find predictions file - the pattern is: dest_uri/prediction-model-*/predictions.jsonl
    output_uri = batch_job.dest.gcs_uri.rstrip("/")
    file_paths = fs.glob(f"{output_uri}/prediction-model-*/predictions.jsonl")
    if not file_paths:
        raise Exception("No predictions file found in output bucket")

    # The glob results are re-prefixed with "gs://" before reading.
    df = pd.read_json(f"gs://{file_paths[0]}", lines=True)

    # NOTE(review): rows are returned in file order; confirm that Vertex AI keeps
    # input order in predictions.jsonl, otherwise match rows on the "request" column.
    responses = df["response"] if "response" in df.columns else [None] * len(df)

    results: List[str] = []
    missing = 0
    for response in responses:
        text = _extract_first_text(response)
        if text is None:
            missing += 1
            text = ""
        results.append(text)

    if missing:
        warnings.warn(
            f"{missing} of {len(results)} batch predictions had no text response; "
            "using empty strings for them.",
            RuntimeWarning,
            stacklevel=2,
        )
    return results
|
scorebook/types/__init__.py
CHANGED
|
@@ -7,5 +7,6 @@ and evaluation results.
|
|
|
7
7
|
|
|
8
8
|
from scorebook.types.eval_dataset import EvalDataset
|
|
9
9
|
from scorebook.types.eval_result import EvalResult
|
|
10
|
+
from scorebook.types.eval_run_spec import EvalRunSpec
|
|
10
11
|
|
|
11
|
-
__all__ = ["EvalDataset", "EvalResult"]
|
|
12
|
+
__all__ = ["EvalDataset", "EvalResult", "EvalRunSpec"]
|
scorebook/types/eval_dataset.py
CHANGED
|
@@ -4,6 +4,7 @@ import csv
|
|
|
4
4
|
import json
|
|
5
5
|
from typing import Any, Dict, Iterator, List, Optional, Type, Union
|
|
6
6
|
|
|
7
|
+
import yaml
|
|
7
8
|
from datasets import Dataset as HuggingFaceDataset
|
|
8
9
|
from datasets import DatasetDict as HuggingFaceDatasetDict
|
|
9
10
|
from datasets import load_dataset
|
|
@@ -21,6 +22,7 @@ class EvalDataset:
|
|
|
21
22
|
label: str,
|
|
22
23
|
metrics: Union[str, Type[MetricBase], List[Union[str, Type[MetricBase]]]],
|
|
23
24
|
hf_dataset: HuggingFaceDataset,
|
|
25
|
+
prompt_template: Optional[str] = None,
|
|
24
26
|
):
|
|
25
27
|
"""
|
|
26
28
|
Create a new scorebook evaluation dataset instance.
|
|
@@ -29,11 +31,13 @@ class EvalDataset:
|
|
|
29
31
|
:param label: The label field of the dataset.
|
|
30
32
|
:param metrics: The specified metrics associated with the dataset.
|
|
31
33
|
:param hf_dataset: The dataset as a hugging face dataset object.
|
|
34
|
+
:param prompt_template: Optional prompt template for building prompts from dataset items.
|
|
32
35
|
"""
|
|
33
36
|
self.name: str = name
|
|
34
37
|
self.label: str = label
|
|
35
38
|
self.metrics: List[MetricBase] = self._resolve_metrics(metrics)
|
|
36
39
|
self._hf_dataset: Optional[HuggingFaceDataset] = hf_dataset
|
|
40
|
+
self.prompt_template: Optional[str] = prompt_template
|
|
37
41
|
|
|
38
42
|
def __len__(self) -> int:
|
|
39
43
|
"""Return the number of items in the dataset."""
|
|
@@ -82,6 +86,12 @@ class EvalDataset:
|
|
|
82
86
|
raise ValueError("Dataset is not initialized")
|
|
83
87
|
return iter(self._hf_dataset)
|
|
84
88
|
|
|
89
|
+
def shuffle(self, seed: Optional[int] = None) -> None:
    """Randomly shuffle the dataset items in place.

    Hugging Face ``Dataset.shuffle`` returns a new, shuffled dataset rather than
    modifying the existing one, so the result is stored back on this instance.

    Args:
        seed: Optional seed for a reproducible order; None gives a random order.

    Raises:
        ValueError: If the underlying dataset is not initialized.
    """
    if self._hf_dataset is None:
        raise ValueError("Dataset is not initialized")
    self._hf_dataset = self._hf_dataset.shuffle(seed=seed)
|
|
94
|
+
|
|
85
95
|
@property
|
|
86
96
|
def items(self) -> List[Any]:
|
|
87
97
|
"""Return a list of all examples in the dataset."""
|
|
@@ -286,6 +296,52 @@ class EvalDataset:
|
|
|
286
296
|
|
|
287
297
|
return cls(name=path, label=label, metrics=metrics, hf_dataset=hf_dataset)
|
|
288
298
|
|
|
299
|
+
@classmethod
def from_yaml(cls, file_path: str) -> "EvalDataset":
    """Instantiate an EvalDataset from a YAML file.

    The YAML file must contain a mapping with these required keys:
        - name: Hugging Face dataset path (passed to ``from_huggingface`` as ``path``)
        - label: The field used as the evaluation label
        - metrics: List of metrics to evaluate
    and these optional keys:
        - split: Split name to load
        - config: Dataset configuration name (passed to ``from_huggingface`` as ``name``)
        - template: Prompt template, stored on ``prompt_template``

    Args:
        file_path: Path to the YAML file; validated with ``expected_suffix=".yaml"``.

    Returns:
        An EvalDataset instance configured according to the YAML file.

    Raises:
        ValueError: If the file is not valid YAML, its top level is not a
            mapping, or it is missing required fields.
    """
    path = validate_path(file_path, expected_suffix=".yaml")

    try:
        with path.open("r", encoding="utf-8") as f:
            config = yaml.safe_load(f)
    except yaml.YAMLError as e:
        raise ValueError(f"Invalid YAML in {file_path}: {e}") from e

    # An empty file loads as None, and a top-level list or scalar has no keys.
    # Reject both here instead of failing with a TypeError on the lookups below.
    if not isinstance(config, dict):
        raise ValueError(
            f"YAML config in {file_path} must be a mapping, got {type(config).__name__}"
        )

    # Validate required fields
    required_fields = ["name", "label", "metrics"]
    missing_fields = [field for field in required_fields if field not in config]
    if missing_fields:
        raise ValueError(f"Missing required fields in YAML config: {', '.join(missing_fields)}")

    # Load the dataset from Hugging Face
    dataset = cls.from_huggingface(
        path=config["name"],
        label=config["label"],
        metrics=config["metrics"],
        split=config.get("split"),  # Optional field
        name=config.get("config"),  # Optional field
    )

    # Add template if provided
    if "template" in config:
        dataset.prompt_template = config["template"]

    return dataset
|
|
344
|
+
|
|
289
345
|
@staticmethod
|
|
290
346
|
def _resolve_metrics(
|
|
291
347
|
metrics: Union[
|
scorebook/types/eval_result.py
CHANGED
|
@@ -42,6 +42,7 @@ class EvalResult:
|
|
|
42
42
|
result = {
|
|
43
43
|
"item_id": idx,
|
|
44
44
|
"dataset_name": self.eval_dataset.name,
|
|
45
|
+
"inference_output": self.inference_outputs[idx],
|
|
45
46
|
**{
|
|
46
47
|
metric: self.metric_scores[metric]["item_scores"][idx]
|
|
47
48
|
for metric in metric_names
|
|
@@ -74,13 +75,13 @@ class EvalResult:
|
|
|
74
75
|
def to_dict(self) -> Dict[str, Any]:
    """Return a dictionary representing the evaluation results.

    Returns:
        A dict with two keys:
            - ``"aggregate_results"``: a one-element list holding the dataset's
              ``hyperparams`` attribute (``{}`` if absent) merged with
              ``aggregate_scores``; on key collisions the aggregate score wins.
            - ``"item_results"``: a new list containing the entries of ``item_scores``.
    """
    return {
        "aggregate_results": [
            {
                **getattr(self.eval_dataset, "hyperparams", {}),
                **self.aggregate_scores,
            }
        ],
        "item_results": list(self.item_scores),
    }
|
|
85
86
|
|
|
86
87
|
def to_csv(self, file_path: str) -> None:
|
|
@@ -113,7 +114,10 @@ class EvalResult:
|
|
|
113
114
|
writer.writerow(row)
|
|
114
115
|
|
|
115
116
|
def to_json(self, file_path: str) -> None:
|
|
116
|
-
"""Save evaluation results to a JSON file in structured format
|
|
117
|
+
"""Save evaluation results to a JSON file in a structured format.
|
|
118
|
+
|
|
119
|
+
The JSON file will contain both aggregate & item results, produced by the to_dict() method.
|
|
120
|
+
"""
|
|
117
121
|
Path(file_path).parent.mkdir(parents=True, exist_ok=True)
|
|
118
122
|
with open(file_path, "w") as f:
|
|
119
123
|
json.dump(self.to_dict(), f, indent=2)
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
"""Evaluation run specification types for Scorebook."""
|
|
2
|
+
|
|
3
|
+
from typing import Any, Dict, List, NamedTuple
|
|
4
|
+
|
|
5
|
+
from scorebook.types import EvalDataset
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class EvalRunSpec(NamedTuple):
    """One evaluation run: a single dataset paired with a single hyperparameter set."""

    # Position of the dataset in the caller's dataset list.
    dataset_idx: int
    # The dataset evaluated by this run.
    eval_dataset: EvalDataset
    # Items to run inference on (presumably drawn from eval_dataset — confirm with caller).
    items: List[Dict[str, Any]]
    # Expected labels; presumably aligned index-by-index with ``items``.
    labels: List[Any]
    # Hyperparameter values used for this run.
    hyperparams: Dict[str, Any]
    # Position of this hyperparameter set in the caller's sweep.
    hp_idx: int

    def __str__(self) -> str:
        """Summarize the run as dataset/hyperparameter indices, dataset name and hyperparameters."""
        joined_params = ", ".join(f"{name}={value}" for name, value in self.hyperparams.items())
        segments = (
            f"EvalRunSpec(dataset_idx={self.dataset_idx},",
            f" hp_idx={self.hp_idx},",
            f" dataset_name='{self.eval_dataset.name}',",
            f" hyperparams=[{joined_params}]",
            ")",
        )
        return "".join(segments)
|
|
@@ -57,7 +57,7 @@ class InferencePipeline:
|
|
|
57
57
|
List of processed outputs after running through the complete pipeline
|
|
58
58
|
"""
|
|
59
59
|
if self.preprocessor:
|
|
60
|
-
input_items = [self.preprocessor(item) for item in items]
|
|
60
|
+
input_items = [self.preprocessor(item, **hyperparameters) for item in items]
|
|
61
61
|
else:
|
|
62
62
|
input_items = items
|
|
63
63
|
|
|
@@ -67,7 +67,10 @@ class InferencePipeline:
|
|
|
67
67
|
inference_outputs = self.inference_function(input_items, **hyperparameters)
|
|
68
68
|
|
|
69
69
|
if self.postprocessor:
|
|
70
|
-
return [
|
|
70
|
+
return [
|
|
71
|
+
self.postprocessor(inference_output, **hyperparameters)
|
|
72
|
+
for inference_output in inference_outputs
|
|
73
|
+
]
|
|
71
74
|
else:
|
|
72
75
|
return cast(List[Any], inference_outputs)
|
|
73
76
|
|
scorebook/utils/__init__.py
CHANGED
|
@@ -1,8 +1,9 @@
|
|
|
1
1
|
"""Utility functions and common helpers for the Scorebook framework."""
|
|
2
2
|
|
|
3
3
|
from scorebook.utils.async_utils import is_awaitable
|
|
4
|
+
from scorebook.utils.build_prompt import build_prompt
|
|
4
5
|
from scorebook.utils.io_helpers import validate_path
|
|
5
6
|
from scorebook.utils.progress_bars import evaluation_progress
|
|
6
7
|
from scorebook.utils.transform_helpers import expand_dict
|
|
7
8
|
|
|
8
|
-
__all__ = ["is_awaitable", "validate_path", "expand_dict", "evaluation_progress"]
|
|
9
|
+
__all__ = ["is_awaitable", "validate_path", "expand_dict", "evaluation_progress", "build_prompt"]
|
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Module for building prompt strings using Jinja2 templating.
|
|
3
|
+
|
|
4
|
+
Provides functionality to render prompts from templates with custom filters
|
|
5
|
+
and global variables, using strict undefined handling for better error detection.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
from typing import Any, Dict, Optional
|
|
9
|
+
|
|
10
|
+
from jinja2 import BaseLoader, Environment, StrictUndefined
|
|
11
|
+
|
|
12
|
+
from scorebook.utils.jinja_helpers import default_jinja_filters, default_jinja_globals
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def build_prompt(
    prompt_template: str,
    prompt_args: Dict[str, Any],
    filters: Optional[Dict[str, Any]] = None,
    globals_dict: Optional[Dict[str, Any]] = None,
) -> str:
    """
    Build a prompt string from a template and arguments.

    Args:
        prompt_template: Jinja2 template string
        prompt_args: Dictionary of arguments to pass to the template
        filters: Jinja2 filters to register on top of Jinja's built-in filters
            (same-named built-ins are overridden). Defaults to default_jinja_filters()
            when None; an explicit mapping, even an empty one, is used as given.
        globals_dict: Global functions/variables for the template. Defaults to
            default_jinja_globals() when None; an explicit mapping is used as given.

    Returns:
        str: Rendered prompt string

    Raises:
        jinja2.exceptions.TemplateSyntaxError: If the template cannot be parsed.
        jinja2.exceptions.UndefinedError: If the template uses a variable that is
            not provided (the environment uses StrictUndefined).
    """
    # Check for None explicitly so that a caller passing {} gets no extra
    # filters/globals, rather than the defaults.
    if filters is None:
        filters = default_jinja_filters()
    if globals_dict is None:
        globals_dict = default_jinja_globals()

    # Strict undefined handling turns missing template variables into errors
    # instead of silently rendering empty strings.
    env = Environment(
        loader=BaseLoader(),
        undefined=StrictUndefined,
        trim_blocks=True,
        lstrip_blocks=True,
    )
    env.filters.update(filters)
    env.globals.update(globals_dict)

    template = env.from_string(prompt_template)
    return str(template.render(**prompt_args))
|
|
@@ -0,0 +1,146 @@
|
|
|
1
|
+
"""Jinja2 template helper functions for Scorebook."""
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import re
|
|
5
|
+
from typing import Any, Dict, List
|
|
6
|
+
|
|
7
|
+
# Helper functions for use in Jinja templates
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def number_to_letter(index: int, uppercase: bool = True) -> str:
    """Map a zero-based alphabet position to its letter (0 -> A, 25 -> Z).

    Args:
        index: Alphabet position; must satisfy 0 <= index <= 25.
        uppercase: Return "A"-"Z" when True, "a"-"z" when False.

    Returns:
        str: The single corresponding letter.

    Raises:
        ValueError: If index lies outside 0..25.
    """
    within_alphabet = 0 <= index <= 25
    if not within_alphabet:
        raise ValueError("Index must be between 0 and 25 inclusive")
    first_letter = "A" if uppercase else "a"
    return chr(ord(first_letter) + index)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def letter_to_number(letter: str) -> int:
    """Convert a letter to a number (A->0, B->1, etc.).

    Args:
        letter: A single ASCII letter character (A-Z or a-z)

    Returns:
        int: The zero-based position of the letter in the alphabet

    Raises:
        ValueError: If the input is not a single ASCII letter character
    """
    # str.isalpha() alone accepts non-ASCII letters such as "é" (which would map
    # far past 25) or "ß" (whose upper() is two characters), so also require
    # ASCII. The isinstance check turns non-str input into the same ValueError.
    if (
        not isinstance(letter, str)
        or len(letter) != 1
        or not letter.isascii()
        or not letter.isalpha()
    ):
        raise ValueError("Input must be a single letter character")
    return ord(letter.upper()) - ord("A")
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
def format_list(items: List[Any], separator: str = ", ", last_separator: str = " and ") -> str:
    """Join items into readable text, using ``last_separator`` before the final item.

    Examples:
        format_list(["a", "b", "c"]) -> "a, b and c"
        format_list(["a", "b"]) -> "a and b"
        format_list(["a"]) -> "a"
        format_list([]) -> ""
    """
    if not items:
        return ""
    *leading, final = items
    if not leading:
        return str(final)
    if len(leading) == 1:
        return f"{leading[0]}{last_separator}{final}"
    head = separator.join(map(str, leading))
    return f"{head}{last_separator}{final}"
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def truncate_text(text: str, max_length: int, suffix: str = "...") -> str:
    """Truncate text to at most ``max_length`` characters, marking the cut with ``suffix``.

    Text that already fits is returned unchanged. Otherwise the text is cut and
    ``suffix`` is appended so the result is exactly ``max_length`` characters long.
    If ``max_length`` is shorter than ``suffix`` there is no room for the suffix,
    so the text is simply cut to ``max_length`` characters.

    Raises:
        ValueError: If max_length is negative.
    """
    if max_length < 0:
        raise ValueError("max_length must be non-negative")
    if len(text) <= max_length:
        return text
    if max_length < len(suffix):
        # max_length - len(suffix) would be negative and slice from the END of
        # the string, producing a result longer than max_length.
        return text[:max_length]
    return text[: max_length - len(suffix)] + suffix
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def format_number(number: float, precision: int = 2) -> str:
    """Render ``number`` in fixed-point notation with ``precision`` decimal places."""
    spec = f".{precision}f"
    return format(number, spec)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def extract_initials(text: str) -> str:
    """Return the uppercased first ASCII letter of each word in ``text``.

    A "word" starts wherever the regex ``\\b[A-Za-z]`` matches, so hyphenated
    parts count separately.

    Examples:
        extract_initials("John Doe") -> "JD"
        extract_initials("Machine Learning Model") -> "MLM"
    """
    initials = "".join(match.group() for match in re.finditer(r"\b[A-Za-z]", text))
    return initials.upper()
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
def json_pretty(obj: Any, indent: int = 2) -> str:
    """Serialize ``obj`` as indented JSON, leaving non-ASCII characters unescaped."""
    encoder = json.JSONEncoder(ensure_ascii=False, indent=indent)
    return encoder.encode(obj)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def percentage(value: float, total: float, precision: int = 1) -> str:
    """Calculate and format a percentage.

    A zero ``total`` yields 0% rendered with the same ``precision`` as every
    other result, so the output format does not change with the input.

    Examples:
        percentage(25, 100) -> "25.0%"
        percentage(1, 3, 2) -> "33.33%"
        percentage(1, 0, 2) -> "0.00%"
    """
    if total == 0:
        return f"{0.0:.{precision}f}%"
    pct = (value / total) * 100
    return f"{pct:.{precision}f}%"
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def ordinal(n: int) -> str:
    """Convert a number to ordinal format like 1st, 2nd, 3rd, etc.

    The suffix is chosen from the magnitude of ``n``, so negative numbers read
    naturally (-1 -> "-1st"). Numbers ending in 11, 12 or 13 take "th".
    """
    magnitude = abs(n)
    if 11 <= magnitude % 100 <= 13:
        suffix = "th"
    else:
        suffix = {1: "st", 2: "nd", 3: "rd"}.get(magnitude % 10, "th")
    return f"{n}{suffix}"
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
def default_jinja_globals() -> Dict[str, Any]:
    """Build the global names available to every prompt template.

    Returns:
        A fresh dict containing this module's formatting helpers followed by a
        set of Python builtins, each under its usual name.
    """
    template_helpers: Dict[str, Any] = {
        "number_to_letter": number_to_letter,
        "letter_to_number": letter_to_number,
        "format_list": format_list,
        "truncate_text": truncate_text,
        "format_number": format_number,
        "extract_initials": extract_initials,
        "json_pretty": json_pretty,
        "percentage": percentage,
        "ordinal": ordinal,
    }
    python_builtins: Dict[str, Any] = {
        "max": max,
        "min": min,
        "len": len,
        "abs": abs,
        "round": round,
        "sum": sum,
        "sorted": sorted,
        "enumerate": enumerate,
        "zip": zip,
        "range": range,
    }
    return {**template_helpers, **python_builtins}
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
def default_jinja_filters() -> Dict[str, Any]:
    """Build the extra filters registered on prompt template environments.

    Returns:
        A fresh dict mapping filter names to Python builtins, so a template can
        write e.g. ``{{ 65 | chr }}``.
    """
    registered: Dict[str, Any] = {}
    for filter_name, func in (
        ("chr", chr),
        ("ord", ord),
        ("abs", abs),
        ("round", round),
        ("len", len),
    ):
        registered[filter_name] = func
    return registered
|