ostruct-cli 0.7.2__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff compares publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
- ostruct/cli/__init__.py +21 -3
- ostruct/cli/base_errors.py +1 -1
- ostruct/cli/cli.py +66 -1983
- ostruct/cli/click_options.py +460 -28
- ostruct/cli/code_interpreter.py +238 -0
- ostruct/cli/commands/__init__.py +32 -0
- ostruct/cli/commands/list_models.py +128 -0
- ostruct/cli/commands/quick_ref.py +50 -0
- ostruct/cli/commands/run.py +137 -0
- ostruct/cli/commands/update_registry.py +71 -0
- ostruct/cli/config.py +277 -0
- ostruct/cli/cost_estimation.py +134 -0
- ostruct/cli/errors.py +310 -6
- ostruct/cli/exit_codes.py +1 -0
- ostruct/cli/explicit_file_processor.py +548 -0
- ostruct/cli/field_utils.py +69 -0
- ostruct/cli/file_info.py +42 -9
- ostruct/cli/file_list.py +301 -102
- ostruct/cli/file_search.py +455 -0
- ostruct/cli/file_utils.py +47 -13
- ostruct/cli/mcp_integration.py +541 -0
- ostruct/cli/model_creation.py +150 -1
- ostruct/cli/model_validation.py +204 -0
- ostruct/cli/progress_reporting.py +398 -0
- ostruct/cli/registry_updates.py +14 -9
- ostruct/cli/runner.py +1418 -0
- ostruct/cli/schema_utils.py +113 -0
- ostruct/cli/services.py +626 -0
- ostruct/cli/template_debug.py +748 -0
- ostruct/cli/template_debug_help.py +162 -0
- ostruct/cli/template_env.py +15 -6
- ostruct/cli/template_filters.py +55 -3
- ostruct/cli/template_optimizer.py +474 -0
- ostruct/cli/template_processor.py +1080 -0
- ostruct/cli/template_rendering.py +69 -34
- ostruct/cli/token_validation.py +286 -0
- ostruct/cli/types.py +78 -0
- ostruct/cli/unattended_operation.py +269 -0
- ostruct/cli/validators.py +386 -3
- {ostruct_cli-0.7.2.dist-info → ostruct_cli-0.8.0.dist-info}/LICENSE +2 -0
- ostruct_cli-0.8.0.dist-info/METADATA +633 -0
- ostruct_cli-0.8.0.dist-info/RECORD +69 -0
- {ostruct_cli-0.7.2.dist-info → ostruct_cli-0.8.0.dist-info}/WHEEL +1 -1
- ostruct_cli-0.7.2.dist-info/METADATA +0 -370
- ostruct_cli-0.7.2.dist-info/RECORD +0 -45
- {ostruct_cli-0.7.2.dist-info → ostruct_cli-0.8.0.dist-info}/entry_points.txt +0 -0
ostruct/cli/runner.py
ADDED
@@ -0,0 +1,1418 @@
"""Async execution engine for ostruct CLI operations."""

import copy
import json
import logging
import os
from pathlib import Path
from typing import Any, AsyncGenerator, Dict, List, Optional, Type, Union

from openai import AsyncOpenAI, OpenAIError
from openai_model_registry import ModelRegistry
from pydantic import BaseModel

from .code_interpreter import CodeInterpreterManager
from .config import OstructConfig
from .cost_estimation import calculate_cost_estimate, format_cost_breakdown
from .errors import (
    APIErrorMapper,
    CLIError,
    SchemaValidationError,
    StreamInterruptedError,
    StreamParseError,
)
from .exit_codes import ExitCode
from .explicit_file_processor import ProcessingResult
from .file_search import FileSearchManager
from .mcp_integration import MCPConfiguration, MCPServerManager
from .progress_reporting import (
    configure_progress_reporter,
    get_progress_reporter,
    report_success,
)
from .serialization import LogSerializer
from .services import ServiceContainer
from .types import CLIParams
from .unattended_operation import (
    UnattendedOperationManager,
)


# Error classes for streaming operations (duplicated from cli.py for now)
class APIResponseError(Exception):
    pass


class EmptyResponseError(Exception):
    pass


class InvalidResponseFormatError(Exception):
    pass


class StreamBufferError(Exception):
    pass


def make_strict(obj: Any) -> None:
    """Transform Pydantic schema for Responses API strict mode.

    This function recursively adds 'additionalProperties: false' to all object types
    in a JSON schema to make it compatible with OpenAI's strict mode requirement.

    Args:
        obj: The schema object to transform (modified in-place)
    """
    if isinstance(obj, dict):
        if obj.get("type") == "object" and "additionalProperties" not in obj:
            obj["additionalProperties"] = False
        for value in obj.values():
            make_strict(value)
    elif isinstance(obj, list):
        for item in obj:
            make_strict(item)


def supports_structured_output(model: str) -> bool:
    """Check if model supports structured output."""
    try:
        registry = ModelRegistry.get_instance()
        capabilities = registry.get_capabilities(model)
        return getattr(capabilities, "supports_structured_output", True)
    except Exception:
        # Default to True for backward compatibility
        return True


logger = logging.getLogger(__name__)


async def process_mcp_configuration(args: CLIParams) -> MCPServerManager:
    """Process MCP configuration from CLI arguments.

    Args:
        args: CLI parameters containing MCP settings

    Returns:
        MCPServerManager: Configured manager ready for tool integration

    Raises:
        CLIError: If MCP configuration is invalid
    """
    logger.debug("=== MCP Configuration Processing ===")

    # Parse MCP servers from CLI arguments
    servers = []
    for server_spec in args.get("mcp_servers", []):
        try:
            # Parse format: [label@]url
            if "@" in server_spec:
                label, url = server_spec.rsplit("@", 1)
            else:
                url = server_spec
                label = None

            server_config = {"url": url}
            if label:
                server_config["label"] = label

            # Add require_approval setting from CLI
            server_config["require_approval"] = args.get(
                "mcp_require_approval", "never"
            )

            # Parse headers if provided
            mcp_headers = args.get("mcp_headers")
            if mcp_headers:
                try:
                    headers = json.loads(mcp_headers)
                    server_config["headers"] = headers
                except json.JSONDecodeError as e:
                    raise CLIError(
                        f"Invalid JSON in --mcp-headers: {e}",
                        exit_code=ExitCode.USAGE_ERROR,
                    )

            servers.append(server_config)

        except Exception as e:
            raise CLIError(
                f"Failed to parse MCP server spec '{server_spec}': {e}",
                exit_code=ExitCode.USAGE_ERROR,
            )

    # Process allowed tools if specified
    allowed_tools_map = {}
    mcp_allowed_tools = args.get("mcp_allowed_tools", [])
    for tools_spec in mcp_allowed_tools:
        try:
            if ":" not in tools_spec:
                raise ValueError("Format should be server_label:tool1,tool2")
            label, tools_str = tools_spec.split(":", 1)
            tools_list = [tool.strip() for tool in tools_str.split(",")]
            allowed_tools_map[label] = tools_list
        except Exception as e:
            raise CLIError(
                f"Failed to parse MCP allowed tools '{tools_spec}': {e}",
                exit_code=ExitCode.USAGE_ERROR,
            )

    # Apply allowed tools to server configurations
    for server in servers:
        server_label = server.get("label")
        if server_label and server_label in allowed_tools_map:
            server["allowed_tools"] = allowed_tools_map[server_label]  # type: ignore[assignment]

    # Create configuration and manager
    MCPConfiguration(servers)  # Validate configuration
    manager = MCPServerManager(servers)

    # Pre-validate servers for CLI compatibility
    validation_errors = await manager.pre_validate_all_servers()
    if validation_errors:
        error_msg = "MCP server validation failed:\n" + "\n".join(
            f"- {error}" for error in validation_errors
        )
        # Map as MCP error
        mapped_error = APIErrorMapper.map_tool_error(
            "mcp", Exception(error_msg)
        )
        raise mapped_error

    logger.debug(
        "MCP configuration validated successfully with %d servers",
        len(servers),
    )
    return manager


async def process_code_interpreter_configuration(
    args: CLIParams, client: AsyncOpenAI
) -> Optional[Dict[str, Any]]:
    """Process Code Interpreter configuration from CLI arguments.

    Args:
        args: CLI parameters containing Code Interpreter settings
        client: AsyncOpenAI client for file uploads

    Returns:
        Dictionary with Code Interpreter tool config and manager, or None if no files specified

    Raises:
        CLIError: If Code Interpreter configuration is invalid
    """
    logger.debug("=== Code Interpreter Configuration Processing ===")

    # Collect all files to upload
    files_to_upload = []

    # Add individual files (extract paths from tuples)
    code_interpreter_files = args.get("code_interpreter_files", [])
    for file_entry in code_interpreter_files:
        if isinstance(file_entry, tuple):
            # Extract path from (name, path) tuple
            _, file_path = file_entry
            files_to_upload.append(str(file_path))
        else:
            # Handle legacy string format
            files_to_upload.append(str(file_entry))

    # Add files from directories
    for directory in args.get("code_interpreter_dirs", []):
        try:
            dir_path = Path(directory)
            if not dir_path.exists():
                raise CLIError(
                    f"Directory not found: {directory}",
                    exit_code=ExitCode.USAGE_ERROR,
                )

            # Get all files from directory (non-recursive for safety)
            for file_path in dir_path.iterdir():
                if file_path.is_file():
                    files_to_upload.append(str(file_path))

        except Exception as e:
            raise CLIError(
                f"Failed to process directory {directory}: {e}",
                exit_code=ExitCode.USAGE_ERROR,
            )

    # If no files specified, return None
    if not files_to_upload:
        return None

    # Create Code Interpreter manager
    manager = CodeInterpreterManager(client)

    # Validate files before upload
    validation_errors = manager.validate_files_for_upload(files_to_upload)
    if validation_errors:
        error_msg = "Code Interpreter file validation failed:\n" + "\n".join(
            f"- {error}" for error in validation_errors
        )
        raise CLIError(error_msg, exit_code=ExitCode.USAGE_ERROR)

    try:
        # Upload files
        logger.debug(
            f"Uploading {len(files_to_upload)} files for Code Interpreter"
        )
        file_ids = await manager.upload_files_for_code_interpreter(
            files_to_upload
        )

        # Build tool configuration
        # Cast to concrete CodeInterpreterManager to access build_tool_config
        concrete_ci_manager = manager
        if hasattr(concrete_ci_manager, "build_tool_config"):
            ci_tool_config = concrete_ci_manager.build_tool_config(file_ids)
            logger.debug(f"Code Interpreter tool config: {ci_tool_config}")
            return {
                "tool_config": ci_tool_config,
                "manager": manager,
                "file_ids": file_ids,
            }
        else:
            logger.warning(
                "Code Interpreter manager does not have build_tool_config method"
            )
            return None

    except Exception as e:
        logger.error(f"Failed to configure Code Interpreter: {e}")
        # Clean up any uploaded files on error
        await manager.cleanup_uploaded_files()
        # Map tool-specific errors
        mapped_error = APIErrorMapper.map_tool_error("code-interpreter", e)
        raise mapped_error


async def process_file_search_configuration(
    args: CLIParams, client: AsyncOpenAI
) -> Optional[Dict[str, Any]]:
    """Process File Search configuration from CLI arguments.

    Args:
        args: CLI parameters containing File Search settings
        client: AsyncOpenAI client for vector store operations

    Returns:
        Dictionary with File Search tool config and manager, or None if no files specified

    Raises:
        CLIError: If File Search configuration is invalid
    """
    logger.debug("=== File Search Configuration Processing ===")

    # Collect all files to upload
    files_to_upload = []

    # Add individual files (extract paths from tuples)
    file_search_files = args.get("file_search_files", [])
    for file_entry in file_search_files:
        if isinstance(file_entry, tuple):
            # Extract path from (name, path) tuple
            _, file_path = file_entry
            files_to_upload.append(str(file_path))
        else:
            # Handle legacy string format
            files_to_upload.append(str(file_entry))

    # Add files from directories
    for directory in args.get("file_search_dirs", []):
        try:
            dir_path = Path(directory)
            if not dir_path.exists():
                raise CLIError(
                    f"Directory not found: {directory}",
                    exit_code=ExitCode.USAGE_ERROR,
                )

            # Get all files from directory (non-recursive for safety)
            for file_path in dir_path.iterdir():
                if file_path.is_file():
                    files_to_upload.append(str(file_path))

        except Exception as e:
            raise CLIError(
                f"Failed to process directory {directory}: {e}",
                exit_code=ExitCode.USAGE_ERROR,
            )

    # If no files specified, return None
    if not files_to_upload:
        return None

    # Create File Search manager
    manager = FileSearchManager(client)

    # Validate files before upload
    validation_errors = manager.validate_files_for_file_search(files_to_upload)
    if validation_errors:
        error_msg = "File Search file validation failed:\n" + "\n".join(
            f"- {error}" for error in validation_errors
        )
        raise CLIError(error_msg, exit_code=ExitCode.USAGE_ERROR)

    try:
        # Get configuration parameters
        vector_store_name = args.get(
            "file_search_vector_store_name", "ostruct_search"
        )
        retry_count = args.get("file_search_retry_count", 3)
        timeout = args.get("file_search_timeout", 60.0)

        # Create vector store with retry logic
        logger.debug(
            f"Creating vector store '{vector_store_name}' for {len(files_to_upload)} files"
        )
        vector_store_id = await manager.create_vector_store_with_retry(
            name=vector_store_name, max_retries=retry_count
        )

        # Upload files to vector store
        logger.debug(
            f"Uploading {len(files_to_upload)} files to vector store with {retry_count} max retries"
        )
        file_ids = await manager.upload_files_to_vector_store(
            vector_store_id=vector_store_id,
            files=files_to_upload,
            max_retries=retry_count,
        )

        # Wait for vector store to be ready
        logger.debug(
            f"Waiting for vector store indexing (timeout: {timeout}s)"
        )
        is_ready = await manager.wait_for_vector_store_ready(
            vector_store_id=vector_store_id, timeout=timeout
        )

        if not is_ready:
            logger.warning(
                f"Vector store may not be fully indexed within {timeout}s timeout"
            )
            # Continue anyway as indexing is typically instant

        # Build tool configuration
        tool_config = manager.build_tool_config(vector_store_id)

        # Get performance info for user awareness
        perf_info = manager.get_performance_info()
        logger.debug(f"File Search performance info: {perf_info}")

        return {
            "tool_config": tool_config,
            "manager": manager,
            "vector_store_id": vector_store_id,
            "file_ids": file_ids,
            "perf_info": perf_info,
        }

    except Exception as e:
        logger.error(f"Failed to configure File Search: {e}")
        # Clean up any created resources on error
        await manager.cleanup_resources()
        # Map tool-specific errors
        mapped_error = APIErrorMapper.map_tool_error("file-search", e)
        raise mapped_error


async def stream_structured_output(
    client: AsyncOpenAI,
    model: str,
    system_prompt: str,
    user_prompt: str,
    output_schema: Type[BaseModel],
    output_file: Optional[str] = None,
    tools: Optional[List[dict]] = None,
    **kwargs: Any,
) -> AsyncGenerator[BaseModel, None]:
    """Stream structured output from OpenAI API using Responses API.

    This function uses the OpenAI Responses API with strict mode schema validation
    to generate structured output that matches the provided Pydantic model.

    Args:
        client: The OpenAI client to use
        model: The model to use
        system_prompt: The system prompt to use
        user_prompt: The user prompt to use
        output_schema: The Pydantic model to validate responses against
        output_file: Optional file to write output to
        tools: Optional list of tools (e.g., MCP, Code Interpreter) to include
        **kwargs: Additional parameters to pass to the API

    Returns:
        An async generator yielding validated model instances

    Raises:
        ValueError: If the model does not support structured output or parameters are invalid
        StreamInterruptedError: If the stream is interrupted
        APIResponseError: If there is an API error
    """
    try:
        # Check if model supports structured output using our stub function
        if not supports_structured_output(model):
            raise ValueError(
                f"Model {model} does not support structured output with json_schema response format. "
                "Please use a model that supports structured output."
            )

        # Extract non-model parameters
        on_log = kwargs.pop("on_log", None)

        # Handle model-specific parameters
        stream_kwargs = {}
        registry = ModelRegistry.get_instance()
        capabilities = registry.get_capabilities(model)

        # Validate and include supported parameters
        for param_name, value in kwargs.items():
            if param_name in capabilities.supported_parameters:
                # Validate the parameter value
                capabilities.validate_parameter(param_name, value)
                stream_kwargs[param_name] = value
            else:
                logger.warning(
                    f"Parameter {param_name} is not supported by model {model} and will be ignored"
                )

        # Prepare schema for strict mode
        schema = output_schema.model_json_schema()
        strict_schema = copy.deepcopy(schema)
        make_strict(strict_schema)

        # Generate schema name from model class name
        schema_name = output_schema.__name__.lower()

        # Combine system and user prompts into a single input string
        combined_prompt = f"{system_prompt}\n\n{user_prompt}"

        # Prepare API call parameters
        api_params = {
            "model": model,
            "input": combined_prompt,
            "text": {
                "format": {
                    "type": "json_schema",
                    "name": schema_name,
                    "schema": strict_schema,
                    "strict": True,
                }
            },
            "stream": True,
            **stream_kwargs,
        }

        # Add tools if provided
        if tools:
            api_params["tools"] = tools
            logger.debug("Tools: %s", json.dumps(tools, indent=2))

        # Log the API request details
        logger.debug("Making OpenAI Responses API request with:")
        logger.debug("Model: %s", model)
        logger.debug("Combined prompt: %s", combined_prompt)
        logger.debug("Parameters: %s", json.dumps(stream_kwargs, indent=2))
        logger.debug("Schema: %s", json.dumps(strict_schema, indent=2))
        logger.debug("Tools being passed to API: %s", tools)
        logger.debug(
            "Complete API params: %s",
            json.dumps(api_params, indent=2, default=str),
        )

        # Use the Responses API with streaming
        response = await client.responses.create(**api_params)

        # Process streaming response
        accumulated_content = ""
        async for chunk in response:
            if on_log:
                on_log(logging.DEBUG, f"Received chunk: {chunk}", {})

            # Check for tool calls (including web search)
            if hasattr(chunk, "choices") and chunk.choices:
                choice = chunk.choices[0]
                # Log tool calls if present
                if (
                    hasattr(choice, "delta")
                    and hasattr(choice.delta, "tool_calls")
                    and choice.delta.tool_calls
                ):
                    for tool_call in choice.delta.tool_calls:
                        if (
                            hasattr(tool_call, "type")
                            and tool_call.type == "web_search_preview"
                        ):
                            tool_id = getattr(tool_call, "id", "unknown")
                            logger.debug(
                                f"Web search tool invoked (id={tool_id})"
                            )
                        elif hasattr(tool_call, "function") and hasattr(
                            tool_call.function, "name"
                        ):
                            # Handle other tool types for completeness
                            tool_name = tool_call.function.name
                            tool_id = getattr(tool_call, "id", "unknown")
                            logger.debug(
                                f"Tool '{tool_name}' invoked (id={tool_id})"
                            )

            # Handle different response formats based on the chunk structure
            content_added = False

            # Try different possible response formats
            if hasattr(chunk, "choices") and chunk.choices:
                # Standard chat completion format
                choice = chunk.choices[0]
                if (
                    hasattr(choice, "delta")
                    and hasattr(choice.delta, "content")
                    and choice.delta.content
                ):
                    accumulated_content += choice.delta.content
                    content_added = True
                elif (
                    hasattr(choice, "message")
                    and hasattr(choice.message, "content")
                    and choice.message.content
                ):
                    accumulated_content += choice.message.content
                    content_added = True
            elif hasattr(chunk, "response") and hasattr(
                chunk.response, "body"
            ):
                # Responses API format
                accumulated_content += chunk.response.body
                content_added = True
            elif hasattr(chunk, "content"):
                # Direct content
                accumulated_content += chunk.content
                content_added = True
            elif hasattr(chunk, "text"):
                # Text content
                accumulated_content += chunk.text
                content_added = True

            if on_log and content_added:
                on_log(
                    logging.DEBUG,
                    f"Added content, total length: {len(accumulated_content)}",
                    {},
                )

            # Try to parse and validate accumulated content as complete JSON
            try:
                if accumulated_content.strip():
                    # Attempt to parse as complete JSON
                    data = json.loads(accumulated_content.strip())
                    validated = output_schema.model_validate(data)
                    yield validated
                    # Reset for next complete response (if any)
                    accumulated_content = ""
            except (json.JSONDecodeError, ValueError):
                # Not yet complete JSON, continue accumulating
                continue

        # Handle any remaining content
        if accumulated_content.strip():
            try:
                data = json.loads(accumulated_content.strip())
                validated = output_schema.model_validate(data)
                yield validated
            except (json.JSONDecodeError, ValueError) as e:
                logger.error(f"Failed to parse final accumulated content: {e}")
                raise StreamParseError(
                    f"Failed to parse response as valid JSON: {e}"
                )

    except Exception as e:
        # Map OpenAI errors using the error mapper

        if isinstance(e, OpenAIError):
            mapped_error = APIErrorMapper.map_openai_error(e)
            logger.error(f"OpenAI API error mapped: {mapped_error}")
            raise mapped_error

        # Handle special schema array error with detailed guidance
        if "Invalid schema for response_format" in str(
            e
        ) and 'type: "array"' in str(e):
            error_msg = (
                "OpenAI API Schema Error: The schema must have a root type of 'object', not 'array'. "
                "To fix this:\n"
                "1. Wrap your array in an object property, e.g.:\n"
                "   {\n"
                '     "type": "object",\n'
                '     "properties": {\n'
                '       "items": {\n'
                '         "type": "array",\n'
                '         "items": { ... your array items schema ... }\n'
                "       }\n"
                "     }\n"
                "   }\n"
                "2. Make sure to update your template to handle the wrapper object."
            )
            logger.error(error_msg)
            raise InvalidResponseFormatError(error_msg)

        # For non-OpenAI errors, create appropriate CLIErrors
        error_msg = str(e).lower()
        if (
            "context_length_exceeded" in error_msg
            or "maximum context length" in error_msg
        ):
            raise CLIError(
                f"Context length exceeded: {str(e)}",
                exit_code=ExitCode.API_ERROR,
            )
        elif "rate_limit" in error_msg or "429" in str(e):
            raise CLIError(
                f"Rate limit exceeded: {str(e)}", exit_code=ExitCode.API_ERROR
            )
        elif "invalid_api_key" in error_msg:
            raise CLIError(
                f"Invalid API key: {str(e)}", exit_code=ExitCode.API_ERROR
            )
        else:
            logger.error(f"Unmapped API error: {e}")
            raise APIResponseError(str(e))
    finally:
        # Note: We don't close the client here as it may be reused
        # The caller is responsible for client lifecycle management
        pass


# Note: validation functions are defined in cli.py to avoid circular imports


async def process_templates(
    args: CLIParams,
    task_template: str,
    template_context: Any,
    env: Any,
    template_path: str,
) -> tuple[str, str]:
    """Process templates.

    This function will be moved from cli.py later.
    For now, we import it from the main cli module to avoid circular imports.
    """
    # Import here to avoid circular dependency during refactoring
    from .template_processor import process_templates as _process_templates

    return await _process_templates(
        args, task_template, template_context, env, template_path
    )


async def execute_model(
    args: CLIParams,
    params: Dict[str, Any],
    output_model: Type[BaseModel],
    system_prompt: str,
    user_prompt: str,
) -> ExitCode:
    """Execute the model and handle the response.

    Args:
        args: Command line arguments
        params: Validated model parameters
        output_model: Generated Pydantic model
        system_prompt: Processed system prompt
        user_prompt: Processed user prompt

    Returns:
        Exit code indicating success or failure

    Raises:
        CLIError: For execution errors
        UnattendedOperationTimeoutError: For operation timeouts
    """
    logger.debug("=== Execution Phase ===")

    # Initialize unattended operation manager
    timeout_seconds = int(args.get("timeout", 3600))
    operation_manager = UnattendedOperationManager(timeout_seconds)

    # Pre-validate unattended compatibility
    # Note: MCP validation is handled during MCP configuration processing
    # mcp_servers = args.get("mcp_servers", [])
    # if mcp_servers:
    #     validator = UnattendedCompatibilityValidator()
    #     validation_errors = validator.validate_mcp_servers(mcp_servers)
    #     if validation_errors:
    #         error_msg = "Unattended operation compatibility errors:\n" + "\n".join(
    #             f"  • {err}" for err in validation_errors
    #         )
    #         logger.error(error_msg)
    #         raise CLIError(error_msg, exit_code=ExitCode.VALIDATION_ERROR)

    api_key = args.get("api_key") or os.getenv("OPENAI_API_KEY")
    if not api_key:
        msg = "No API key provided. Set OPENAI_API_KEY environment variable or use --api-key"
        logger.error(msg)
        raise CLIError(msg, exit_code=ExitCode.API_ERROR)

    client = AsyncOpenAI(
        api_key=api_key, timeout=min(args.get("timeout", 60.0), 300.0)
    )  # Cap at 5 min for client timeout

    # Create service container for dependency management
    services = ServiceContainer(client, args)

    # Initialize variables that will be used in nested functions
    code_interpreter_info = None
    file_search_info = None

    # Create detailed log callback
    def log_callback(level: int, message: str, extra: dict[str, Any]) -> None:
        if args.get("debug_openai_stream", False):
            if extra:
                extra_str = LogSerializer.serialize_log_extra(extra)
                if extra_str:
                    logger.debug("%s\nExtra:\n%s", message, extra_str)
                else:
                    logger.debug("%s\nExtra: Failed to serialize", message)
            else:
                logger.debug(message)

    async def execute_main_operation() -> ExitCode:
        """Main execution operation wrapped for timeout handling."""
        # Create output buffer
        output_buffer = []

        # Process tool configurations
        tools = []
        nonlocal code_interpreter_info, file_search_info

        # Process MCP configuration if provided
        if services.is_configured("mcp"):
            mcp_manager = await services.get_mcp_manager()
            if mcp_manager:
                tools.extend(mcp_manager.get_tools_for_responses_api())

        # Get routing result from explicit file processor
        routing_result = args.get("_routing_result")
        if routing_result is not None and not isinstance(
            routing_result, ProcessingResult
        ):
            routing_result = None  # Invalid type, treat as None
        routing_result_typed: Optional[ProcessingResult] = routing_result

        # Process Code Interpreter configuration if enabled
        if (
            routing_result_typed
            and "code-interpreter" in routing_result_typed.enabled_tools
        ):
            code_interpreter_files = routing_result_typed.validated_files.get(
                "code-interpreter", []
            )
            if code_interpreter_files:
                # Override args with routed files for Code Interpreter processing
                ci_args = dict(args)
                ci_args["code_interpreter_files"] = code_interpreter_files
                ci_args["code_interpreter_dirs"] = (
                    []
                )  # Files already expanded from dirs
                ci_args["code_interpreter"] = (
                    True  # Enable for service container
                )

                # Create temporary service container with updated args
                ci_services = ServiceContainer(client, ci_args)  # type: ignore[arg-type]
                code_interpreter_manager = (
                    await ci_services.get_code_interpreter_manager()
                )
                if code_interpreter_manager:
                    # Get the uploaded file IDs from the manager
                    if (
                        hasattr(code_interpreter_manager, "uploaded_file_ids")
                        and code_interpreter_manager.uploaded_file_ids
                    ):
                        file_ids = code_interpreter_manager.uploaded_file_ids
                        # Cast to concrete CodeInterpreterManager to access build_tool_config
                        concrete_ci_manager = code_interpreter_manager
                        if hasattr(concrete_ci_manager, "build_tool_config"):
                            ci_tool_config = (
                                concrete_ci_manager.build_tool_config(file_ids)
                            )
                            logger.debug(
                                f"Code Interpreter tool config: {ci_tool_config}"
                            )
                            code_interpreter_info = {
                                "manager": code_interpreter_manager,
                                "tool_config": ci_tool_config,
                            }
                            tools.append(ci_tool_config)
                    else:
                        logger.warning(
                            "Code Interpreter manager has no uploaded file IDs"
                        )

        # Process File Search configuration if enabled
        if (
            routing_result_typed
            and "file-search" in routing_result_typed.enabled_tools
        ):
            file_search_files = routing_result_typed.validated_files.get(
                "file-search", []
            )
            if file_search_files:
                # Override args with routed files for File Search processing
                fs_args = dict(args)
                fs_args["file_search_files"] = file_search_files
                fs_args["file_search_dirs"] = (
                    []
                )  # Files already expanded from dirs
                fs_args["file_search"] = True  # Enable for service container

                # Create temporary service container with updated args
                fs_services = ServiceContainer(client, fs_args)  # type: ignore[arg-type]
                file_search_manager = (
                    await fs_services.get_file_search_manager()
                )
                if file_search_manager:
                    # Get the vector store ID from the manager's created vector stores
                    # The most recent one should be the one we need
                    if (
                        hasattr(file_search_manager, "created_vector_stores")
                        and file_search_manager.created_vector_stores
                    ):
                        vector_store_id = (
                            file_search_manager.created_vector_stores[-1]
                        )
                        # Cast to concrete FileSearchManager to access build_tool_config
                        concrete_fs_manager = file_search_manager
                        if hasattr(concrete_fs_manager, "build_tool_config"):
                            fs_tool_config = (
                                concrete_fs_manager.build_tool_config(
                                    vector_store_id
                                )
                            )
                            logger.debug(
                                f"File Search tool config: {fs_tool_config}"
                            )
                            file_search_info = {
                                "manager": file_search_manager,
                                "tool_config": fs_tool_config,
                            }
                            tools.append(fs_tool_config)
                    else:
                        logger.warning(
                            "File Search manager has no created vector stores"
                        )

        # Process Web Search configuration if enabled
        # Check CLI flags first, then fall back to config defaults
        web_search_from_cli = args.get("web_search", False)
        no_web_search_from_cli = args.get("no_web_search", False)

        # Load configuration to check defaults
        from typing import cast

        config_path = cast(Union[str, Path, None], args.get("config"))
        config = OstructConfig.load(config_path)
        web_search_config = config.get_web_search_config()

        # Determine if web search should be enabled
        web_search_enabled = False
        if web_search_from_cli:
            # Explicit --web-search flag takes precedence
            web_search_enabled = True
        elif no_web_search_from_cli:
            # Explicit --no-web-search flag disables
            web_search_enabled = False
        else:
            # Use config default
            web_search_enabled = web_search_config.enable_by_default

        if web_search_enabled:
            # Import validation function
            from .model_validation import validate_web_search_compatibility

            # Check model compatibility
            compatibility_warning = validate_web_search_compatibility(
                args["model"], True
            )
            if compatibility_warning:
                logger.warning(compatibility_warning)
                # For now, we'll warn but still allow the user to proceed
                # In the future, this could be made stricter based on user feedback

            # Check for Azure OpenAI endpoint guard-rail
            api_base = os.getenv("OPENAI_API_BASE", "")
            if "azure.com" in api_base.lower():
                logger.warning(
                    "Web search is not currently supported or may be unreliable with Azure OpenAI endpoints and has been disabled."
                )
            else:
                web_tool_config: Dict[str, Any] = {
                    "type": "web_search_preview"
                }

                # Add user_location if provided via CLI or config
                user_country = args.get("user_country")
                user_city = args.get("user_city")
                user_region = args.get("user_region")

                # Fall back to config if not provided via CLI
                if (
                    not any([user_country, user_city, user_region])
                    and web_search_config.user_location
                ):
                    user_country = web_search_config.user_location.country
                    user_city = web_search_config.user_location.city
                    user_region = web_search_config.user_location.region

                if user_country or user_city or user_region:
                    user_location: Dict[str, Any] = {"type": "approximate"}
                    if user_country:
                        user_location["country"] = user_country
                    if user_city:
                        user_location["city"] = user_city
                    if user_region:
                        user_location["region"] = user_region

                    web_tool_config["user_location"] = user_location

                # Add search_context_size if provided via CLI or config
                search_context_size = (
                    args.get("search_context_size")
                    or web_search_config.search_context_size
                )
                if search_context_size:
                    web_tool_config["search_context_size"] = (
                        search_context_size
                    )

                tools.append(web_tool_config)
                logger.debug(f"Web Search tool config: {web_tool_config}")

        # Debug log the final tools array
        logger.debug(f"Final tools array being passed to API: {tools}")

        # Stream the response
        logger.debug(f"Tools being passed to API: {tools}")
        async for response in stream_structured_output(
            client=client,
            model=args["model"],
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            output_schema=output_model,
            output_file=args.get("output_file"),
            on_log=log_callback,
            tools=tools,
        ):
            output_buffer.append(response)

        # Handle final output
        output_file = args.get("output_file")
        if output_file:
            with open(output_file, "w") as f:
                if len(output_buffer) == 1:
                    f.write(output_buffer[0].model_dump_json(indent=2))
                else:
                    # Build complete JSON array as a single string
                    json_output = "[\n"
                    for i, response in enumerate(output_buffer):
                        if i > 0:
                            json_output += ",\n"
                        json_output += "  " + response.model_dump_json(
                            indent=2
                        ).replace("\n", "\n  ")
                    json_output += "\n]"
                    f.write(json_output)
        else:
            # Write to stdout when no output file is specified
            if len(output_buffer) == 1:
                print(output_buffer[0].model_dump_json(indent=2))
            else:
                # Build complete JSON array as a single string
                json_output = "[\n"
                for i, response in enumerate(output_buffer):
                    if i > 0:
                        json_output += ",\n"
                    json_output += "  " + response.model_dump_json(
                        indent=2
                    ).replace("\n", "\n  ")
                json_output += "\n]"
                print(json_output)

        # Handle file downloads from Code Interpreter if any were generated
        if (
            code_interpreter_info
            and hasattr(response, "file_ids")
            and response.file_ids
        ):
            try:
                download_dir = args.get(
                    "code_interpreter_download_dir", "./downloads"
                )
                manager = code_interpreter_info["manager"]
                # Type ignore since we know this is a CodeInterpreterManager
                downloaded_files = await manager.download_generated_files(  # type: ignore[attr-defined]
                    response.file_ids, download_dir
                )
                if downloaded_files:
                    logger.info(
                        f"Downloaded {len(downloaded_files)} generated files to {download_dir}"
                    )
                    for file_path in downloaded_files:
                        logger.info(f"  - {file_path}")
            except Exception as e:
                logger.warning(f"Failed to download generated files: {e}")

        return ExitCode.SUCCESS

    # Execute main operation with timeout safeguards
    try:
        result = await operation_manager.execute_with_safeguards(
            execute_main_operation, "model execution"
        )
        # The result should be an ExitCode from execute_main_operation
        return result  # type: ignore[no-any-return]
    except (
        StreamInterruptedError,
        StreamBufferError,
        StreamParseError,
        APIResponseError,
        EmptyResponseError,
        InvalidResponseFormatError,
    ) as e:
        logger.error("Stream error: %s", str(e))
        raise CLIError(str(e), exit_code=ExitCode.API_ERROR)
    except Exception as e:
        logger.exception("Unexpected error during streaming")
        raise CLIError(str(e), exit_code=ExitCode.UNKNOWN_ERROR)
    finally:
        # Clean up Code Interpreter files if requested
        if code_interpreter_info and args.get(
            "code_interpreter_cleanup", True
        ):
            try:
                manager = code_interpreter_info["manager"]
                # Type ignore since we know this is a CodeInterpreterManager
                await manager.cleanup_uploaded_files()  # type: ignore[attr-defined]
                logger.debug("Cleaned up Code Interpreter uploaded files")
            except Exception as e:
                logger.warning(
                    f"Failed to clean up Code Interpreter files: {e}"
                )

        # Clean up File Search resources if requested
        if file_search_info and args.get("file_search_cleanup", True):
            try:
                manager = file_search_info["manager"]
                # Type ignore since we know this is a FileSearchManager
                await manager.cleanup_resources()  # type: ignore[attr-defined]
                logger.debug("Cleaned up File Search vector stores and files")
            except Exception as e:
                logger.warning(
                    f"Failed to clean up File Search resources: {e}"
                )

        # Clean up service container
        try:
            await services.cleanup()
            logger.debug("Cleaned up service container")
        except Exception as e:
            logger.warning(f"Failed to clean up service container: {e}")

        await client.close()


async def run_cli_async(args: CLIParams) -> ExitCode:
    """Async wrapper for CLI operations.

    Args:
        args: CLI parameters.

    Returns:
        Exit code.

    Raises:
        CLIError: For errors during CLI operations.
    """
    try:
        # 0. Configure Debug Logging
        from .template_debug import configure_debug_logging

        configure_debug_logging(
            verbose=bool(args.get("verbose", False)),
            debug=bool(args.get("debug", False))
            or bool(args.get("debug_templates", False)),
        )

        # 0a. Handle Debug Help Request
        if args.get("help_debug", False):
            from .template_debug_help import show_template_debug_help

            show_template_debug_help()
            return ExitCode.SUCCESS

        # 0. Configure Progress Reporting
        configure_progress_reporter(
            verbose=args.get("verbose", False),
            progress_level=args.get("progress_level", "basic"),
        )
        progress_reporter = get_progress_reporter()

        # 0. Model Parameter Validation
        progress_reporter.report_phase("Validating configuration", "🔧")
        logger.debug("=== Model Parameter Validation ===")
        # Import here to avoid circular dependency
        from .model_validation import validate_model_params

        params = await validate_model_params(args)

        # 1. Input Validation Phase (includes schema validation)
        progress_reporter.report_phase("Processing input files", "📂")
        # Import here to avoid circular dependency
        from .validators import validate_inputs

        (
            security_manager,
            task_template,
            schema,
            template_context,
            env,
            template_path,
        ) = await validate_inputs(args)

        # Report file routing decisions
        routing_result = args.get("_routing_result")
        if routing_result is not None and not isinstance(
            routing_result, ProcessingResult
        ):
            routing_result = None  # Invalid type, treat as None
        routing_result_typed: Optional[ProcessingResult] = routing_result
        if routing_result_typed:
            template_files = routing_result_typed.validated_files.get(
                "template", []
            )
            container_files = routing_result_typed.validated_files.get(
                "code-interpreter", []
            )
            vector_files = routing_result_typed.validated_files.get(
                "file-search", []
            )
            progress_reporter.report_file_routing(
                template_files, container_files, vector_files
            )

        # 2. Template Processing Phase
        progress_reporter.report_phase("Rendering template", "📝")
        system_prompt, user_prompt = await process_templates(
            args, task_template, template_context, env, template_path or ""
        )

        # 3. Model & Schema Validation Phase
        progress_reporter.report_phase("Validating model and schema", "✅")
        # Import here to avoid circular dependency
        from .model_validation import validate_model_and_schema

        (
            output_model,
            messages,
            total_tokens,
            registry,
        ) = await validate_model_and_schema(
            args,
            schema,
            system_prompt,
            user_prompt,
            template_context,
        )

        # Report validation results
        if registry is not None:
            capabilities = registry.get_capabilities(args["model"])
            progress_reporter.report_validation_results(
                schema_valid=True,  # If we got here, schema is valid
                template_valid=True,  # If we got here, template is valid
                token_count=total_tokens,
                token_limit=capabilities.context_window,
            )
        else:
            # Fallback for test environments where registry might be None
            progress_reporter.report_validation_results(
                schema_valid=True,  # If we got here, schema is valid
                template_valid=True,  # If we got here, template is valid
                token_count=total_tokens,
                token_limit=128000,  # Default fallback
            )

        # 3a. Web Search Compatibility Validation
        if args.get("web_search", False) and not args.get(
            "no_web_search", False
        ):
            from .model_validation import validate_web_search_compatibility

            compatibility_warning = validate_web_search_compatibility(
                args["model"], True
            )
            if compatibility_warning:
                logger.warning(compatibility_warning)
                # For production usage, consider making this an error instead of warning
                # raise CLIError(compatibility_warning, exit_code=ExitCode.VALIDATION_ERROR)

        # 4. Dry Run Output Phase - Moved after all validations
        if args.get("dry_run", False):
            report_success(
                "Dry run completed successfully - all validations passed"
            )

            # Calculate cost estimate
            if registry is not None:
                capabilities = registry.get_capabilities(args["model"])
                estimated_cost = calculate_cost_estimate(
                    model=args["model"],
                    input_tokens=total_tokens,
                    output_tokens=capabilities.max_output_tokens,
                    registry=registry,
                )

                # Enhanced dry-run output with cost estimation
                cost_breakdown = format_cost_breakdown(
                    model=args["model"],
                    input_tokens=total_tokens,
                    output_tokens=capabilities.max_output_tokens,
                    total_cost=estimated_cost,
                    context_window=capabilities.context_window,
                )
            else:
                # Fallback for test environments
                cost_breakdown = f"Token Analysis\nModel: {args['model']}\nInput tokens: {total_tokens}\nRegistry not available in test environment"
            print(cost_breakdown)

            # Show template content based on debug flags
            from .template_debug import show_template_content

            show_template_content(
                system_prompt=system_prompt,
                user_prompt=user_prompt,
                show_templates=bool(args.get("show_templates", False)),
                debug=bool(args.get("debug", False))
                or bool(args.get("debug_templates", False)),
            )

            # Legacy verbose support for backward compatibility
            if (
                args.get("verbose", False)
                and not args.get("debug", False)
                and not args.get("show_templates", False)
            ):
                logger.info("\nSystem Prompt:")
                logger.info("-" * 40)
                logger.info(system_prompt)
                logger.info("\nRendered Template:")
                logger.info("-" * 40)
                logger.info(user_prompt)

            # Return success only if we got here (no validation errors)
            return ExitCode.SUCCESS

        # 5. Execution Phase
        progress_reporter.report_phase("Generating response", "🤖")
        return await execute_model(
            args, params, output_model, system_prompt, user_prompt
        )

    except KeyboardInterrupt:
        logger.info("Operation cancelled by user")
        raise
    except SchemaValidationError as e:
        # Ensure schema validation errors are properly propagated with the correct exit code
        logger.error("Schema validation error: %s", str(e))
        raise  # Re-raise the SchemaValidationError to preserve the error chain
    except Exception as e:
        if isinstance(e, CLIError):
            raise  # Let our custom errors propagate
        logger.exception("Unexpected error")
        raise CLIError(str(e), context={"error_type": type(e).__name__})


class OstructRunner:
    """Clean interface for running ostruct operations.

    This class encapsulates the execution logic and provides a clean,
    testable interface for running ostruct operations.
    """

    def __init__(self, args: CLIParams):
        """Initialize the runner with CLI parameters.

        Args:
            args: CLI parameters dictionary
        """
        self.args = args

    async def run(self) -> ExitCode:
        """Main execution entry point.

        Returns:
            Exit code indicating success or failure

        Raises:
            CLIError: For errors during CLI operations
        """
        return await run_cli_async(self.args)

    async def validate_only(self) -> ExitCode:
        """Run validation without executing the model.

        This runs all validation phases and returns without making
        API calls. Useful for dry runs and validation testing.

        Returns:
            Exit code indicating validation success or failure
        """
        # Create a copy of args with dry_run enabled
        validation_args = dict(self.args)
        validation_args["dry_run"] = True
        return await run_cli_async(validation_args)  # type: ignore[arg-type]

    async def execute_with_validation(self) -> ExitCode:
        """Run with full validation and execution.

        This is the standard execution path that includes all
        validation phases followed by model execution.

        Returns:
            Exit code indicating success or failure
        """
        # Ensure dry_run is disabled for full execution
        execution_args = dict(self.args)
        execution_args["dry_run"] = False
        return await run_cli_async(execution_args)  # type: ignore[arg-type]

    def get_configuration_summary(self) -> Dict[str, Any]:
        """Get a summary of the current configuration.

        Returns:
            Dictionary containing configuration information
        """
        return {
            "model": self.args.get("model"),
            "dry_run": self.args.get("dry_run", False),
            "verbose": self.args.get("verbose", False),
            "mcp_servers": len(self.args.get("mcp_servers", [])),
            "code_interpreter_enabled": bool(
                self.args.get("code_interpreter_files")
                or self.args.get("code_interpreter_dirs")
            ),
            "file_search_enabled": bool(
                self.args.get("file_search_files")
                or self.args.get("file_search_dirs")
            ),
            "template_source": (
                "file" if self.args.get("task_file") else "string"
            ),
            "schema_source": (
                "file" if self.args.get("schema_file") else "inline"
            ),
        }