scorebook 0.0.1-py3-none-any.whl → 0.0.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,305 @@
+ """
+ AWS Bedrock batch inference implementation for Scorebook.
+
+ This module provides utilities for running batch inference using AWS Bedrock's
+ Model Invocation Jobs, supporting large-scale asynchronous processing. It handles
+ API communication, request formatting, response processing, and S3 operations.
+ """
+
+ import asyncio
+ import json
+ import os
+ import tempfile
+ import uuid
+ from datetime import datetime
+ from typing import Any, Dict, List, Optional, Tuple
+
+ import boto3
+ from botocore.config import Config
+ from botocore.exceptions import ClientError
+ from tqdm.asyncio import tqdm
+
+
+ async def batch(
+     items: List[Any],
+     model: Optional[str] = None,
+     aws_region: Optional[str] = None,
+     aws_profile: Optional[str] = None,
+     bucket: Optional[str] = None,
+     input_prefix: Optional[str] = None,
+     output_prefix: Optional[str] = None,
+     role_arn: Optional[str] = None,
+     **hyperparameters: Any,
+ ) -> List[Any]:
+     """Process multiple inference requests in batch using AWS Bedrock.
+
+     This asynchronous function handles batch processing of inference requests,
+     optimizing for cost and throughput using AWS Bedrock's Model Invocation Jobs.
+
+     Args:
+         items: List of preprocessed items to process.
+         model: Bedrock model ID (e.g., 'us.anthropic.claude-3-5-sonnet-20241022-v2:0').
+         aws_region: AWS region for Bedrock and S3.
+         aws_profile: AWS profile name for authentication.
+         bucket: S3 bucket name for input/output data.
+         input_prefix: S3 prefix for input data.
+         output_prefix: S3 prefix for output data.
+         role_arn: IAM role ARN for Bedrock execution.
+         hyperparameters: Additional parameters for the batch requests.
+
+     Returns:
+         A list of raw model responses.
+     """
+     # Set up AWS session and clients
+     session_kwargs = {}
+     if aws_profile:
+         session_kwargs["profile_name"] = aws_profile
+     if aws_region:
+         session_kwargs["region_name"] = aws_region
+
+     session = boto3.Session(**session_kwargs)
+
+     boto_config = Config(region_name=aws_region, retries={"max_attempts": 10, "mode": "adaptive"})
+
+     s3_client = session.client("s3", config=boto_config)
+     bedrock_client = session.client("bedrock", config=boto_config)
+
+     # Upload batch data
+     input_uri = await _upload_batch(
+         items, s3_client, bucket, input_prefix, model, **hyperparameters
+     )
+
+     # Start batch job
+     job_arn = await _start_batch_job(
+         bedrock_client, model, input_uri, bucket, output_prefix, role_arn
+     )
+
+     # Wait for completion with progress tracking
+     await _wait_for_completion(bedrock_client, job_arn, len(items))
+
+     # Retrieve results
+     results = await _get_batch_results(s3_client, bedrock_client, job_arn)
+
+     return results
+
+
+ async def _upload_batch(
+     items: List[Any],
+     s3_client: Any,
+     bucket: Optional[str],
+     input_prefix: Optional[str],
+     model: Optional[str],
+     **hyperparameters: Any,
+ ) -> str:
+     """Create a JSONL file from preprocessed items and upload to S3 for batch processing."""
+
+     # Generate unique run ID and key
+     run_id = datetime.utcnow().strftime("%Y%m%dT%H%M%S") + "-" + uuid.uuid4().hex[:8]
+
+     if input_prefix:
+         input_key = f"{input_prefix.rstrip('/')}/inputs-{run_id}.jsonl"
+     else:
+         input_key = f"inputs-{run_id}.jsonl"
+
+     # Create temp JSONL file
+     with tempfile.NamedTemporaryFile(mode="w", suffix=".jsonl", delete=False) as f:
+         for i, item in enumerate(items):
+             # Construct batch request in Bedrock format
+             record = {
+                 "recordId": f"rec-{i:04d}",
+                 "modelInput": _build_claude_messages_payload(item, **hyperparameters),
+             }
+             f.write(json.dumps(record, separators=(",", ":")) + "\n")
+         file_path = f.name
+
+     # Upload to S3
+     try:
+         body = open(file_path, "rb").read()
+         s3_client.put_object(
+             Bucket=bucket,
+             Key=input_key,
+             Body=body,
+             StorageClass="INTELLIGENT_TIERING",
+             ContentType="application/json",
+         )
+         input_uri = f"s3://{bucket}/{input_key}"
+     except Exception as e:
+         raise Exception(f"Failed to upload file to S3: {e}")
+     finally:
+         # Clean up temp file
+         os.unlink(file_path)
+
+     return input_uri
+
+
+ def _build_claude_messages_payload(item: Any, **hyperparameters: Any) -> Dict[str, Any]:
+     """Build Claude messages payload for Bedrock batch processing."""
+
+     # item is a list of messages from our preprocessor
+     messages = item
+
+     # Convert to Bedrock format and extract system message
+     bedrock_messages = []
+     system_content = None
+
+     for msg in messages:
+         if msg["role"] == "system":
+             system_content = msg["content"]
+         else:
+             bedrock_messages.append(
+                 {"role": msg["role"], "content": [{"type": "text", "text": msg["content"]}]}
+             )
+
+     payload = {
+         "anthropic_version": "bedrock-2023-05-31",
+         "max_tokens": 256,
+         "messages": bedrock_messages,
+     }
+
+     if system_content:
+         payload["system"] = system_content
+
+     payload.update(hyperparameters)
+     return payload
+
+
+ async def _start_batch_job(
+     bedrock_client: Any,
+     model: Optional[str],
+     input_uri: str,
+     bucket: Optional[str],
+     output_prefix: Optional[str],
+     role_arn: Optional[str],
+ ) -> str:
+     """Start a Bedrock Model Invocation Job."""
+
+     # Generate unique job name and output URI
+     run_id = datetime.utcnow().strftime("%Y%m%dT%H%M%S") + "-" + uuid.uuid4().hex[:8]
+     job_name = f"bedrock-batch-{run_id}"
+
+     if output_prefix:
+         output_uri = f"s3://{bucket}/{output_prefix.rstrip('/')}/job-{run_id}/"
+     else:
+         output_uri = f"s3://{bucket}/job-{run_id}/"
+
+     try:
+         response = bedrock_client.create_model_invocation_job(
+             jobName=job_name,
+             modelId=model,
+             roleArn=role_arn,
+             inputDataConfig={"s3InputDataConfig": {"s3Uri": input_uri}},
+             outputDataConfig={"s3OutputDataConfig": {"s3Uri": output_uri}},
+             tags=[{"key": "project", "value": "scorebook-batch"}],
+         )
+         job_arn: str = response["jobArn"]
+         return job_arn
+     except ClientError as e:
+         error_info = e.response.get("Error", {})
+         raise Exception(f"Failed to create batch job: {error_info}")
+
+
+ async def _wait_for_completion(bedrock_client: Any, job_arn: str, total_items: int) -> None:
+     """Wait for batch job completion with progress tracking."""
+
+     # Initialize progress bar
+     pbar = tqdm(total=total_items, desc="Batch processing", unit="requests")
+
+     terminal_states = {"Completed", "Failed", "Stopped"}
+     sleep_time = 15
+
+     while True:
+         try:
+             desc = bedrock_client.get_model_invocation_job(jobIdentifier=job_arn)
+             status = desc["status"]
+
+             # Get progress if available
+             job_state = desc.get("jobState", {})
+             progress = job_state.get("percentComplete")
+
+             # Update progress bar
+             if progress is not None:
+                 completed = int((progress / 100) * total_items)
+                 pbar.n = completed
+                 pbar.set_postfix(status=status, progress=f"{progress}%")
+             else:
+                 pbar.set_postfix(status=status)
+
+             pbar.refresh()
+
+             if status in terminal_states:
+                 if status == "Completed":
+                     pbar.n = pbar.total
+                     pbar.set_postfix(status="COMPLETED")
+                 else:
+                     pbar.close()
+                     error_msg = desc.get("failureMessage", f"Job ended with status {status}")
+                     raise Exception(f"Batch job failed: {error_msg}")
+                 break
+
+             # Wait before checking again
+             await asyncio.sleep(sleep_time)
+
+         except Exception as e:
+             pbar.close()
+             raise e
+
+     pbar.close()
+
+
+ async def _get_batch_results(s3_client: Any, bedrock_client: Any, job_arn: str) -> List[str]:
+     """Download and parse batch results from S3."""
+
+     # Get job details to find output location
+     desc = bedrock_client.get_model_invocation_job(jobIdentifier=job_arn)
+     output_uri = desc["outputDataConfig"]["s3OutputDataConfig"]["s3Uri"]
+
+     bucket_name, prefix = s3_uri_to_bucket_and_prefix(output_uri)
+
+     # Find the output JSONL file
+     try:
+         response = s3_client.list_objects_v2(Bucket=bucket_name, Prefix=prefix)
+         contents = response.get("Contents", [])
+
+         # Look for the output JSONL file
+         output_key = None
+         for obj in contents:
+             if obj["Key"].endswith(".jsonl.out"):
+                 output_key = obj["Key"]
+                 break
+
+         if not output_key:
+             raise Exception("No output JSONL file found in S3")
+
+         # Download and parse results
+         obj_response = s3_client.get_object(Bucket=bucket_name, Key=output_key)
+         content = obj_response["Body"].read().decode("utf-8")
+
+         results = []
+         for line in content.strip().split("\n"):
+             if line.strip():
+                 result_obj = json.loads(line)
+                 # Extract text from Claude response format
+                 model_output = result_obj.get("modelOutput", {})
+                 content_list = model_output.get("content", [])
+                 if content_list and len(content_list) > 0:
+                     text = content_list[0].get("text", "")
+                     results.append(text)
+                 else:
+                     results.append("")
+
+         return results
+
+     except Exception as e:
+         raise Exception(f"Failed to retrieve batch results: {e}")
+
+
+ def s3_uri_to_bucket_and_prefix(s3_uri: str) -> Tuple[str, str]:
+     """Parse S3 URI to bucket and prefix."""
+     # Parse S3 URI
+     if s3_uri.startswith("s3://"):
+         uri_parts = s3_uri[5:].split("/", 1)
+         bucket_name = uri_parts[0]
+         prefix = uri_parts[1] if len(uri_parts) > 1 else ""
+     else:
+         raise ValueError(f"Invalid S3 URI: {s3_uri}")
+     return bucket_name, prefix
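
The new module above exposes one public coroutine, `batch`, which writes the items to a JSONL file of Bedrock batch records, uploads it to S3, starts a Model Invocation Job, polls the job to completion, and reads the text completions back from the `.jsonl.out` output file. The sketch below shows how a caller might drive it; the import path, region, bucket, and role ARN are placeholders assumed for illustration and are not taken from the diff.

```python
import asyncio

# Import path is an assumption; the diff does not show where the module lives in the package.
from scorebook.inference import aws_bedrock

# Each item is a chat-style message list, which _build_claude_messages_payload
# converts into the Bedrock/Anthropic request format.
items = [
    [
        {"role": "system", "content": "You are a concise assistant."},
        {"role": "user", "content": "Summarise the plot of Hamlet in one sentence."},
    ],
]

results = asyncio.run(
    aws_bedrock.batch(
        items,
        model="us.anthropic.claude-3-5-sonnet-20241022-v2:0",  # example model ID from the docstring
        aws_region="us-east-1",  # placeholder region
        bucket="my-batch-bucket",  # placeholder S3 bucket
        input_prefix="scorebook/inputs",
        output_prefix="scorebook/outputs",
        role_arn="arn:aws:iam::123456789012:role/BedrockBatchRole",  # placeholder role
        max_tokens=512,  # forwarded into the Claude payload via **hyperparameters
    )
)
print(results)  # list of raw text completions, one per input record
```

Note that the real Bedrock batch service enforces its own constraints (for example, a minimum number of records per job), so a production run would submit far more items than this illustrative single record.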
@@ -8,17 +8,19 @@ API communication, request formatting, and response processing.
  
  import asyncio
  import json
+ import logging
  import tempfile
  from typing import Any, List
  
- from openai import OpenAI
- from tqdm.asyncio import tqdm
+ from openai import AsyncOpenAI
+
+ logger = logging.getLogger(__name__)
  
  
  async def responses(
      items: List[Any], model: str = "gpt-4.1-nano", client: Any = None, **hyperparameters: Any
  ) -> List[Any]:
-     """Process multiple inference requests using OpenAI's API.
+     """Process multiple inference requests using OpenAI's Async API.
  
      This asynchronous function handles multiple inference requests,
      manages the API communication, and processes the responses.
@@ -35,13 +37,67 @@ async def responses(
      Raises:
          NotImplementedError: Currently not implemented.
      """
-     if client is None:
-         client = OpenAI()
+     logger.debug("OpenAI responses function called with %d items", len(items))
+     logger.debug("Using model: %s", model)
+     logger.debug("Hyperparameters: %s", hyperparameters)
  
-     results = []
-     for item in items:
-         response = client.responses.create(model=model, input=item)
-         results.append(response)
+     if client is None:
+         logger.debug("Creating new AsyncOpenAI client")
+         client = AsyncOpenAI()
+
+     # Create all tasks concurrently for true parallelism
+     tasks = []
+     for i, item in enumerate(items):
+         logger.debug(
+             "Processing item %d: %s",
+             i,
+             str(item)[:100] + "..." if len(str(item)) > 100 else str(item),
+         )
+
+         # Handle string input from preprocessor - convert to proper messages format
+         if isinstance(item, str):
+             # Convert the string format to proper OpenAI messages array
+             messages = [{"role": "user", "content": item}]
+             logger.debug(
+                 "Converted string to messages format: %s",
+                 (
+                     messages[0]["content"][:100] + "..."
+                     if len(messages[0]["content"]) > 100
+                     else messages[0]["content"]
+                 ),
+             )
+         elif isinstance(item, list):
+             # Already in proper messages format
+             messages = item
+             logger.debug("Item %d already in messages format", i)
+         else:
+             # Fallback: treat as user message
+             messages = [{"role": "user", "content": str(item)}]
+             logger.debug("Item %d converted to fallback format", i)
+
+         logger.debug("Creating OpenAI task %d with messages: %s", i, messages)
+         task = client.chat.completions.create(model=model, messages=messages, **hyperparameters)
+         tasks.append(task)
+
+     logger.debug("Created %d tasks, waiting for OpenAI responses...", len(tasks))
+     # Wait for all requests to complete in parallel
+     results = await asyncio.gather(*tasks)
+     logger.debug("Received %d responses from OpenAI", len(results))
+
+     for i, result in enumerate(results):
+         logger.debug("Response %d type: %s", i, type(result))
+         try:
+             if hasattr(result, "choices") and result.choices:
+                 content = result.choices[0].message.content
+                 logger.debug(
+                     "Response %d content: %s",
+                     i,
+                     content[:100] + "..." if content and len(content) > 100 else content,
+                 )
+             else:
+                 logger.debug("Response %d has no choices or unexpected format", i)
+         except Exception as e:
+             logger.error("Error logging response %d: %s", i, e)
  
      return results
  
@@ -70,40 +126,23 @@ async def batch(
          NotImplementedError: Currently not implemented.
      """
      if client is None:
-         client = OpenAI()
+         client = AsyncOpenAI()
  
-     file_id = _upload_batch(items, client)
-     batch_id = _start_batch(file_id, client)
-
-     # Initialize progress bar
-     pbar = tqdm(total=len(items), desc="Batch processing", unit="requests")
+     file_id = await _upload_batch(items, client)
+     batch_id = await _start_batch(file_id, client)
  
      awaiting_batch = True
      while awaiting_batch:
          batch_object = await _get_batch(batch_id, client)
          batch_status = batch_object.status
  
-         if hasattr(batch_object, "request_counts") and batch_object.request_counts:
-             completed = batch_object.request_counts.completed
-             total = batch_object.request_counts.total
-             pbar.n = completed
-             pbar.set_postfix(status=batch_status, completed=f"{completed}/{total}")
-         else:
-             pbar.set_postfix(status=batch_status)
-
-         pbar.refresh()
-
          if batch_status == "completed":
              awaiting_batch = False
-             pbar.n = pbar.total
-             pbar.set_postfix(status="completed")
          elif batch_status == "failed":
              raise Exception("Batch processing failed")
          else:
              await asyncio.sleep(60)
  
-     pbar.close()
-
      # Get the final batch object to access output_file_id
      final_batch_object = await _get_batch(batch_id, client)
      output_file_id = final_batch_object.output_file_id
@@ -112,7 +151,7 @@ async def batch(
      return batch_result
  
  
- def _upload_batch(items: List[Any], client: Any) -> str:
+ async def _upload_batch(items: List[Any], client: Any) -> str:
      """Create a .jsonl file from preprocessed items and upload to OpenAI for batch processing.
  
      Args:
@@ -121,10 +160,9 @@ def _upload_batch(items: List[Any], client: Any) -> str:
      Returns:
          The file ID returned by OpenAI after uploading.
      """
-     print("Uploading batch...")
      # Instantiate OpenAI client
      if client is None:
-         client = OpenAI()
+         client = AsyncOpenAI()
  
      # Create temp .jsonl file
      with tempfile.NamedTemporaryFile(mode="w+", suffix=".jsonl", delete=False) as f:
@@ -141,13 +179,13 @@ def _upload_batch(items: List[Any], client: Any) -> str:
  
      # Upload file to OpenAI
      with open(file_path, "rb") as upload_file:
-         response = client.files.create(file=upload_file, purpose="batch")
+         response = await client.files.create(file=upload_file, purpose="batch")
  
      return str(response.id)
  
  
- def _start_batch(file_id: str, client: Any) -> str:
-     batch_response = client.batches.create(
+ async def _start_batch(file_id: str, client: Any) -> str:
+     batch_response = await client.batches.create(
          input_file_id=file_id,
          endpoint="/v1/chat/completions",
          completion_window="24h",
@@ -156,13 +194,13 @@ def _start_batch(file_id: str, client: Any) -> str:
  
  
  async def _get_batch(batch_id: str, client: Any) -> Any:
-     batch_object = client.batches.retrieve(batch_id)
+     batch_object = await client.batches.retrieve(batch_id)
      return batch_object
  
  
  async def _get_results_file(output_file_id: str, client: Any) -> List[str]:
      """Download and parse the batch results file from OpenAI."""
-     response = client.files.content(output_file_id)
+     response = await client.files.content(output_file_id)
  
      # Parse the JSONL content
      content = response.content.decode("utf-8")
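
With this change, `responses` fans items out as concurrent `chat.completions` requests through `AsyncOpenAI` and `asyncio.gather`, so callers get per-item parallelism instead of the old sequential loop, and the batch helpers now await the async client throughout. The sketch below shows how a caller might use the rewritten `responses` coroutine; the import path is an assumption (the diff does not name this module's file), and it presumes `OPENAI_API_KEY` is set in the environment.

```python
import asyncio

# Import path is an assumption; the diff does not show this module's location in the package.
from scorebook.inference import openai as openai_inference


async def main() -> None:
    # Items may be plain strings or ready-made message lists; the function wraps
    # strings into a single user message itself.
    items = [
        "What is the capital of France?",
        [{"role": "user", "content": "Name three prime numbers."}],
    ]
    results = await openai_inference.responses(
        items,
        model="gpt-4.1-nano",  # default from the signature
        temperature=0.0,  # forwarded to chat.completions.create via **hyperparameters
    )
    for result in results:
        # Each result is a ChatCompletion object; the text lives in choices[0].message.content.
        print(result.choices[0].message.content)


asyncio.run(main())
```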