scorebook-0.0.1-py3-none-any.whl → scorebook-0.0.3-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- scorebook/__init__.py +2 -1
- scorebook/evaluator.py +269 -118
- scorebook/exceptions.py +54 -0
- scorebook/inference/__init__.py +0 -4
- scorebook/inference/bedrock.py +305 -0
- scorebook/inference/openai.py +75 -37
- scorebook/inference/vertex.py +295 -0
- scorebook/types/__init__.py +2 -1
- scorebook/types/eval_dataset.py +56 -0
- scorebook/types/eval_result.py +7 -3
- scorebook/types/eval_run_spec.py +28 -0
- scorebook/types/inference_pipeline.py +5 -2
- scorebook/utils/__init__.py +2 -1
- scorebook/utils/build_prompt.py +52 -0
- scorebook/utils/jinja_helpers.py +146 -0
- scorebook/utils/logging_utils.py +1 -0
- scorebook/utils/progress_bars.py +91 -34
- {scorebook-0.0.1.dist-info → scorebook-0.0.3.dist-info}/METADATA +11 -1
- scorebook-0.0.3.dist-info/RECORD +31 -0
- scorebook-0.0.1.dist-info/RECORD +0 -24
- {scorebook-0.0.1.dist-info → scorebook-0.0.3.dist-info}/LICENSE +0 -0
- {scorebook-0.0.1.dist-info → scorebook-0.0.3.dist-info}/WHEEL +0 -0
scorebook/inference/bedrock.py
ADDED
@@ -0,0 +1,305 @@
+"""
+AWS Bedrock batch inference implementation for Scorebook.
+
+This module provides utilities for running batch inference using AWS Bedrock's
+Model Invocation Jobs, supporting large-scale asynchronous processing. It handles
+API communication, request formatting, response processing, and S3 operations.
+"""
+
+import asyncio
+import json
+import os
+import tempfile
+import uuid
+from datetime import datetime
+from typing import Any, Dict, List, Optional, Tuple
+
+import boto3
+from botocore.config import Config
+from botocore.exceptions import ClientError
+from tqdm.asyncio import tqdm
+
+
+async def batch(
+    items: List[Any],
+    model: Optional[str] = None,
+    aws_region: Optional[str] = None,
+    aws_profile: Optional[str] = None,
+    bucket: Optional[str] = None,
+    input_prefix: Optional[str] = None,
+    output_prefix: Optional[str] = None,
+    role_arn: Optional[str] = None,
+    **hyperparameters: Any,
+) -> List[Any]:
+    """Process multiple inference requests in batch using AWS Bedrock.
+
+    This asynchronous function handles batch processing of inference requests,
+    optimizing for cost and throughput using AWS Bedrock's Model Invocation Jobs.
+
+    Args:
+        items: List of preprocessed items to process.
+        model: Bedrock model ID (e.g., 'us.anthropic.claude-3-5-sonnet-20241022-v2:0').
+        aws_region: AWS region for Bedrock and S3.
+        aws_profile: AWS profile name for authentication.
+        bucket: S3 bucket name for input/output data.
+        input_prefix: S3 prefix for input data.
+        output_prefix: S3 prefix for output data.
+        role_arn: IAM role ARN for Bedrock execution.
+        hyperparameters: Additional parameters for the batch requests.
+
+    Returns:
+        A list of raw model responses.
+    """
+    # Set up AWS session and clients
+    session_kwargs = {}
+    if aws_profile:
+        session_kwargs["profile_name"] = aws_profile
+    if aws_region:
+        session_kwargs["region_name"] = aws_region
+
+    session = boto3.Session(**session_kwargs)
+
+    boto_config = Config(region_name=aws_region, retries={"max_attempts": 10, "mode": "adaptive"})
+
+    s3_client = session.client("s3", config=boto_config)
+    bedrock_client = session.client("bedrock", config=boto_config)
+
+    # Upload batch data
+    input_uri = await _upload_batch(
+        items, s3_client, bucket, input_prefix, model, **hyperparameters
+    )
+
+    # Start batch job
+    job_arn = await _start_batch_job(
+        bedrock_client, model, input_uri, bucket, output_prefix, role_arn
+    )
+
+    # Wait for completion with progress tracking
+    await _wait_for_completion(bedrock_client, job_arn, len(items))
+
+    # Retrieve results
+    results = await _get_batch_results(s3_client, bedrock_client, job_arn)
+
+    return results
+
+
+async def _upload_batch(
+    items: List[Any],
+    s3_client: Any,
+    bucket: Optional[str],
+    input_prefix: Optional[str],
+    model: Optional[str],
+    **hyperparameters: Any,
+) -> str:
+    """Create a JSONL file from preprocessed items and upload to S3 for batch processing."""
+
+    # Generate unique run ID and key
+    run_id = datetime.utcnow().strftime("%Y%m%dT%H%M%S") + "-" + uuid.uuid4().hex[:8]
+
+    if input_prefix:
+        input_key = f"{input_prefix.rstrip('/')}/inputs-{run_id}.jsonl"
+    else:
+        input_key = f"inputs-{run_id}.jsonl"
+
+    # Create temp JSONL file
+    with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
+        for i, item in enumerate(items):
+            # Construct batch request in Bedrock format
+            record = {
+                "recordId": f"rec-{i:04d}",
+                "modelInput": _build_claude_messages_payload(item, **hyperparameters),
+            }
+            f.write(json.dumps(record, separators=(",", ":")) + "\n")
+        file_path = f.name
+
+    # Upload to S3
+    try:
+        body = open(file_path, "rb").read()
+        s3_client.put_object(
+            Bucket=bucket,
+            Key=input_key,
+            Body=body,
+            StorageClass="INTELLIGENT_TIERING",
+            ContentType="application/json",
+        )
+        input_uri = f"s3://{bucket}/{input_key}"
+    except Exception as e:
+        raise Exception(f"Failed to upload file to S3: {e}")
+    finally:
+        # Clean up temp file
+        os.unlink(file_path)
+
+    return input_uri
+
+
+def _build_claude_messages_payload(item: Any, **hyperparameters: Any) -> Dict[str, Any]:
+    """Build Claude messages payload for Bedrock batch processing."""
+
+    # item is a list of messages from our preprocessor
+    messages = item
+
+    # Convert to Bedrock format and extract system message
+    bedrock_messages = []
+    system_content = None
+
+    for msg in messages:
+        if msg["role"] == "system":
+            system_content = msg["content"]
+        else:
+            bedrock_messages.append(
+                {"role": msg["role"], "content": [{"type": "text", "text": msg["content"]}]}
+            )
+
+    payload = {
+        "anthropic_version": "bedrock-2023-05-31",
+        "max_tokens": 256,
+        "messages": bedrock_messages,
+    }
+
+    if system_content:
+        payload["system"] = system_content
+
+    payload.update(hyperparameters)
+    return payload
+
+
+async def _start_batch_job(
+    bedrock_client: Any,
+    model: Optional[str],
+    input_uri: str,
+    bucket: Optional[str],
+    output_prefix: Optional[str],
+    role_arn: Optional[str],
+) -> str:
+    """Start a Bedrock Model Invocation Job."""
+
+    # Generate unique job name and output URI
+    run_id = datetime.utcnow().strftime("%Y%m%dT%H%M%S") + "-" + uuid.uuid4().hex[:8]
+    job_name = f"bedrock-batch-{run_id}"
+
+    if output_prefix:
+        output_uri = f"s3://{bucket}/{output_prefix.rstrip('/')}/job-{run_id}/"
+    else:
+        output_uri = f"s3://{bucket}/job-{run_id}/"
+
+    try:
+        response = bedrock_client.create_model_invocation_job(
+            jobName=job_name,
+            modelId=model,
+            roleArn=role_arn,
+            inputDataConfig={"s3InputDataConfig": {"s3Uri": input_uri}},
+            outputDataConfig={"s3OutputDataConfig": {"s3Uri": output_uri}},
+            tags=[{"key": "project", "value": "scorebook-batch"}],
+        )
+        job_arn: str = response["jobArn"]
+        return job_arn
+    except ClientError as e:
+        error_info = e.response.get("Error", {})
+        raise Exception(f"Failed to create batch job: {error_info}")
+
+
+async def _wait_for_completion(bedrock_client: Any, job_arn: str, total_items: int) -> None:
+    """Wait for batch job completion with progress tracking."""
+
+    # Initialize progress bar
+    pbar = tqdm(total=total_items, desc="Batch processing", unit="requests")
+
+    terminal_states = {"Completed", "Failed", "Stopped"}
+    sleep_time = 15
+
+    while True:
+        try:
+            desc = bedrock_client.get_model_invocation_job(jobIdentifier=job_arn)
+            status = desc["status"]
+
+            # Get progress if available
+            job_state = desc.get("jobState", {})
+            progress = job_state.get("percentComplete")
+
+            # Update progress bar
+            if progress is not None:
+                completed = int((progress / 100) * total_items)
+                pbar.n = completed
+                pbar.set_postfix(status=status, progress=f"{progress}%")
+            else:
+                pbar.set_postfix(status=status)
+
+            pbar.refresh()
+
+            if status in terminal_states:
+                if status == "Completed":
+                    pbar.n = pbar.total
+                    pbar.set_postfix(status="COMPLETED")
+                else:
+                    pbar.close()
+                    error_msg = desc.get("failureMessage", f"Job ended with status {status}")
+                    raise Exception(f"Batch job failed: {error_msg}")
+                break
+
+            # Wait before checking again
+            await asyncio.sleep(sleep_time)
+
+        except Exception as e:
+            pbar.close()
+            raise e
+
+    pbar.close()
+
+
+async def _get_batch_results(s3_client: Any, bedrock_client: Any, job_arn: str) -> List[str]:
+    """Download and parse batch results from S3."""
+
+    # Get job details to find output location
+    desc = bedrock_client.get_model_invocation_job(jobIdentifier=job_arn)
+    output_uri = desc["outputDataConfig"]["s3OutputDataConfig"]["s3Uri"]
+
+    bucket_name, prefix = s3_uri_to_bucket_and_prefix(output_uri)
+
+    # Find the output JSONL file
+    try:
+        response = s3_client.list_objects_v2(Bucket=bucket_name, Prefix=prefix)
+        contents = response.get("Contents", [])
+
+        # Look for the output JSONL file
+        output_key = None
+        for obj in contents:
+            if obj["Key"].endswith(".jsonl.out"):
+                output_key = obj["Key"]
+                break
+
+        if not output_key:
+            raise Exception("No output JSONL file found in S3")
+
+        # Download and parse results
+        obj_response = s3_client.get_object(Bucket=bucket_name, Key=output_key)
+        content = obj_response["Body"].read().decode("utf-8")
+
+        results = []
+        for line in content.strip().split("\n"):
+            if line.strip():
+                result_obj = json.loads(line)
+                # Extract text from Claude response format
+                model_output = result_obj.get("modelOutput", {})
+                content_list = model_output.get("content", [])
+                if content_list and len(content_list) > 0:
+                    text = content_list[0].get("text", "")
+                    results.append(text)
+                else:
+                    results.append("")
+
+        return results
+
+    except Exception as e:
+        raise Exception(f"Failed to retrieve batch results: {e}")
+
+
+def s3_uri_to_bucket_and_prefix(s3_uri: str) -> Tuple[str, str]:
+    """Parse S3 URI to bucket and prefix."""
+    # Parse S3 URI
+    if s3_uri.startswith("s3://"):
+        uri_parts = s3_uri[5:].split("/", 1)
+        bucket_name = uri_parts[0]
+        prefix = uri_parts[1] if len(uri_parts) > 1 else ""
+    else:
+        raise ValueError(f"Invalid S3 URI: {s3_uri}")
+    return bucket_name, prefix
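
A minimal, hypothetical usage sketch of the new Bedrock batch entry point above (not part of the package diff). The bucket, prefixes, IAM role ARN, region, and prompt are placeholder values; importing scorebook.inference.bedrock directly assumes the submodule is importable as packaged, and AWS credentials come from the ambient configuration.

import asyncio

from scorebook.inference import bedrock

# Each item is a chat-style message list, matching what _build_claude_messages_payload expects.
items = [
    [
        {"role": "system", "content": "You are a concise grader."},
        {"role": "user", "content": "Score this answer from 1 to 5: 'Paris is the capital of France.'"},
    ],
]

# Placeholder AWS resources -- substitute a real bucket and Bedrock batch execution role.
replies = asyncio.run(
    bedrock.batch(
        items,
        model="us.anthropic.claude-3-5-sonnet-20241022-v2:0",
        aws_region="us-east-1",
        bucket="example-scorebook-batches",
        input_prefix="inputs",
        output_prefix="outputs",
        role_arn="arn:aws:iam::123456789012:role/ExampleBedrockBatchRole",
    )
)
print(replies)  # list of raw text completions, one per input item
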
scorebook/inference/openai.py
CHANGED
@@ -8,17 +8,19 @@ API communication, request formatting, and response processing.
 
 import asyncio
 import json
+import logging
 import tempfile
 from typing import Any, List
 
-from openai import
-
+from openai import AsyncOpenAI
+
+logger = logging.getLogger(__name__)
 
 
 async def responses(
     items: List[Any], model: str = "gpt-4.1-nano", client: Any = None, **hyperparameters: Any
 ) -> List[Any]:
-    """Process multiple inference requests using OpenAI's API.
+    """Process multiple inference requests using OpenAI's Async API.
 
     This asynchronous function handles multiple inference requests,
     manages the API communication, and processes the responses.
@@ -35,13 +37,67 @@ async def responses(
     Raises:
         NotImplementedError: Currently not implemented.
     """
-
-
+    logger.debug("OpenAI responses function called with %d items", len(items))
+    logger.debug("Using model: %s", model)
+    logger.debug("Hyperparameters: %s", hyperparameters)
 
-
-
-
-
+    if client is None:
+        logger.debug("Creating new AsyncOpenAI client")
+        client = AsyncOpenAI()
+
+    # Create all tasks concurrently for true parallelism
+    tasks = []
+    for i, item in enumerate(items):
+        logger.debug(
+            "Processing item %d: %s",
+            i,
+            str(item)[:100] + "..." if len(str(item)) > 100 else str(item),
+        )
+
+        # Handle string input from preprocessor - convert to proper messages format
+        if isinstance(item, str):
+            # Convert the string format to proper OpenAI messages array
+            messages = [{"role": "user", "content": item}]
+            logger.debug(
+                "Converted string to messages format: %s",
+                (
+                    messages[0]["content"][:100] + "..."
+                    if len(messages[0]["content"]) > 100
+                    else messages[0]["content"]
+                ),
+            )
+        elif isinstance(item, list):
+            # Already in proper messages format
+            messages = item
+            logger.debug("Item %d already in messages format", i)
+        else:
+            # Fallback: treat as user message
+            messages = [{"role": "user", "content": str(item)}]
+            logger.debug("Item %d converted to fallback format", i)
+
+        logger.debug("Creating OpenAI task %d with messages: %s", i, messages)
+        task = client.chat.completions.create(model=model, messages=messages, **hyperparameters)
+        tasks.append(task)
+
+    logger.debug("Created %d tasks, waiting for OpenAI responses...", len(tasks))
+    # Wait for all requests to complete in parallel
+    results = await asyncio.gather(*tasks)
+    logger.debug("Received %d responses from OpenAI", len(results))
+
+    for i, result in enumerate(results):
+        logger.debug("Response %d type: %s", i, type(result))
+        try:
+            if hasattr(result, "choices") and result.choices:
+                content = result.choices[0].message.content
+                logger.debug(
+                    "Response %d content: %s",
+                    i,
+                    content[:100] + "..." if content and len(content) > 100 else content,
+                )
+            else:
+                logger.debug("Response %d has no choices or unexpected format", i)
+        except Exception as e:
+            logger.error("Error logging response %d: %s", i, e)
 
     return results
 
@@ -70,40 +126,23 @@ async def batch(
         NotImplementedError: Currently not implemented.
     """
     if client is None:
-        client =
+        client = AsyncOpenAI()
 
-    file_id = _upload_batch(items, client)
-    batch_id = _start_batch(file_id, client)
-
-    # Initialize progress bar
-    pbar = tqdm(total=len(items), desc="Batch processing", unit="requests")
+    file_id = await _upload_batch(items, client)
+    batch_id = await _start_batch(file_id, client)
 
     awaiting_batch = True
     while awaiting_batch:
         batch_object = await _get_batch(batch_id, client)
         batch_status = batch_object.status
 
-        if hasattr(batch_object, "request_counts") and batch_object.request_counts:
-            completed = batch_object.request_counts.completed
-            total = batch_object.request_counts.total
-            pbar.n = completed
-            pbar.set_postfix(status=batch_status, completed=f"{completed}/{total}")
-        else:
-            pbar.set_postfix(status=batch_status)
-
-        pbar.refresh()
-
         if batch_status == "completed":
             awaiting_batch = False
-            pbar.n = pbar.total
-            pbar.set_postfix(status="completed")
        elif batch_status == "failed":
             raise Exception("Batch processing failed")
         else:
             await asyncio.sleep(60)
 
-    pbar.close()
-
     # Get the final batch object to access output_file_id
     final_batch_object = await _get_batch(batch_id, client)
     output_file_id = final_batch_object.output_file_id
@@ -112,7 +151,7 @@ async def batch(
     return batch_result
 
 
-def _upload_batch(items: List[Any], client: Any) -> str:
+async def _upload_batch(items: List[Any], client: Any) -> str:
     """Create a .jsonl file from preprocessed items and upload to OpenAI for batch processing.
 
     Args:
@@ -121,10 +160,9 @@ def _upload_batch(items: List[Any], client: Any) -> str:
     Returns:
         The file ID returned by OpenAI after uploading.
     """
-    print("Uploading batch...")
     # Instantiate OpenAI client
     if client is None:
-        client =
+        client = AsyncOpenAI()
 
     # Create temp .jsonl file
     with tempfile.NamedTemporaryFile(mode="w+", suffix=".jsonl", delete=False) as f:
@@ -141,13 +179,13 @@ def _upload_batch(items: List[Any], client: Any) -> str:
 
     # Upload file to OpenAI
     with open(file_path, "rb") as upload_file:
-        response = client.files.create(file=upload_file, purpose="batch")
+        response = await client.files.create(file=upload_file, purpose="batch")
 
     return str(response.id)
 
 
-def _start_batch(file_id: str, client: Any) -> str:
-    batch_response = client.batches.create(
+async def _start_batch(file_id: str, client: Any) -> str:
+    batch_response = await client.batches.create(
         input_file_id=file_id,
         endpoint="/v1/chat/completions",
         completion_window="24h",
@@ -156,13 +194,13 @@ def _start_batch(file_id: str, client: Any) -> str:
 
 
 async def _get_batch(batch_id: str, client: Any) -> Any:
-    batch_object = client.batches.retrieve(batch_id)
+    batch_object = await client.batches.retrieve(batch_id)
     return batch_object
 
 
 async def _get_results_file(output_file_id: str, client: Any) -> List[str]:
     """Download and parse the batch results file from OpenAI."""
-    response = client.files.content(output_file_id)
+    response = await client.files.content(output_file_id)
 
     # Parse the JSONL content
     content = response.content.decode("utf-8")