ostruct-cli 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ostruct/__init__.py +0 -0
- ostruct/cli/__init__.py +19 -0
- ostruct/cli/cache_manager.py +175 -0
- ostruct/cli/cli.py +2033 -0
- ostruct/cli/errors.py +329 -0
- ostruct/cli/file_info.py +316 -0
- ostruct/cli/file_list.py +151 -0
- ostruct/cli/file_utils.py +518 -0
- ostruct/cli/path_utils.py +123 -0
- ostruct/cli/progress.py +105 -0
- ostruct/cli/security.py +311 -0
- ostruct/cli/security_types.py +49 -0
- ostruct/cli/template_env.py +55 -0
- ostruct/cli/template_extensions.py +51 -0
- ostruct/cli/template_filters.py +650 -0
- ostruct/cli/template_io.py +261 -0
- ostruct/cli/template_rendering.py +347 -0
- ostruct/cli/template_schema.py +565 -0
- ostruct/cli/template_utils.py +288 -0
- ostruct/cli/template_validation.py +375 -0
- ostruct/cli/utils.py +31 -0
- ostruct/py.typed +0 -0
- ostruct_cli-0.1.0.dist-info/LICENSE +21 -0
- ostruct_cli-0.1.0.dist-info/METADATA +182 -0
- ostruct_cli-0.1.0.dist-info/RECORD +27 -0
- ostruct_cli-0.1.0.dist-info/WHEEL +4 -0
- ostruct_cli-0.1.0.dist-info/entry_points.txt +3 -0
ostruct/cli/cli.py
ADDED
@@ -0,0 +1,2033 @@
"""Command-line interface for making structured OpenAI API calls."""

import argparse
import asyncio
import json
import logging
import os
import sys
from enum import Enum, IntEnum

if sys.version_info >= (3, 11):
    from enum import StrEnum

from datetime import date, datetime, time
from importlib.metadata import version
from pathlib import Path
from typing import (
    Any,
    Dict,
    List,
    Literal,
    Optional,
    Set,
    Tuple,
    Type,
    TypeVar,
    Union,
    cast,
    get_origin,
    overload,
)

import jinja2
import tiktoken
import yaml
from openai import (
    APIConnectionError,
    AsyncOpenAI,
    AuthenticationError,
    BadRequestError,
    InternalServerError,
    RateLimitError,
)
from openai_structured.client import (
    async_openai_structured_stream,
    supports_structured_output,
)
from openai_structured.errors import (
    APIResponseError,
    EmptyResponseError,
    InvalidResponseFormatError,
    ModelNotSupportedError,
    ModelVersionError,
    OpenAIClientError,
    SchemaFileError,
    SchemaValidationError,
    StreamBufferError,
    StreamInterruptedError,
    StreamParseError,
)
from pydantic import (
    AnyUrl,
    BaseModel,
    ConfigDict,
    EmailStr,
    Field,
    ValidationError,
    create_model,
)
from pydantic.fields import FieldInfo as FieldInfoType
from pydantic.functional_validators import BeforeValidator
from pydantic.types import constr
from typing_extensions import TypeAlias

from .errors import (
    DirectoryNotFoundError,
    FieldDefinitionError,
    FileNotFoundError,
    InvalidJSONError,
    ModelCreationError,
    ModelValidationError,
    NestedModelError,
    PathSecurityError,
    TaskTemplateSyntaxError,
    TaskTemplateVariableError,
    VariableError,
    VariableNameError,
    VariableValueError,
)
from .file_utils import FileInfoList, TemplateValue, collect_files
from .path_utils import validate_path_mapping
from .progress import ProgressContext
from .security import SecurityManager
from .template_env import create_jinja_env
from .template_utils import SystemPromptError, render_template

# Set up logging
logger = logging.getLogger(__name__)

# Configure openai_structured logging based on debug flag
openai_logger = logging.getLogger("openai_structured")
openai_logger.setLevel(logging.DEBUG)  # Allow all messages through to handlers
openai_logger.propagate = False  # Prevent propagation to root logger

# Remove any existing handlers
for handler in openai_logger.handlers:
    openai_logger.removeHandler(handler)

# Create a file handler for openai_structured logger that captures all levels
log_dir = os.path.expanduser("~/.ostruct/logs")
os.makedirs(log_dir, exist_ok=True)
openai_file_handler = logging.FileHandler(
    os.path.join(log_dir, "openai_stream.log")
)
openai_file_handler.setLevel(logging.DEBUG)  # Always capture debug in file
openai_file_handler.setFormatter(
    logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
)
openai_logger.addHandler(openai_file_handler)

# Create a file handler for the main logger that captures all levels
ostruct_file_handler = logging.FileHandler(
    os.path.join(log_dir, "ostruct.log")
)
ostruct_file_handler.setLevel(logging.DEBUG)  # Always capture debug in file
ostruct_file_handler.setFormatter(
    logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
)
logger.addHandler(ostruct_file_handler)

# Constants
DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant."

# Get package version
try:
    __version__ = version("openai-structured")
except Exception:
    __version__ = "unknown"


class ExitCode(IntEnum):
    """Exit codes for the CLI following standard Unix conventions.

    Categories:
    - Success (0-1)
    - User Interruption (2-3)
    - Input/Validation (64-69)
    - I/O and File Access (70-79)
    - API and External Services (80-89)
    - Internal Errors (90-99)
    """

    # Success codes
    SUCCESS = 0

    # User interruption
    INTERRUPTED = 2

    # Input/Validation errors (64-69)
    USAGE_ERROR = 64
    DATA_ERROR = 65
    SCHEMA_ERROR = 66
    VALIDATION_ERROR = 67

    # I/O and File Access errors (70-79)
    IO_ERROR = 70
    FILE_NOT_FOUND = 71
    PERMISSION_ERROR = 72
    SECURITY_ERROR = 73

    # API and External Service errors (80-89)
    API_ERROR = 80
    API_TIMEOUT = 81

    # Internal errors (90-99)
    INTERNAL_ERROR = 90
    UNKNOWN_ERROR = 91
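
Aside (illustrative, not part of cli.py): because ExitCode subclasses IntEnum, its members behave as plain integers, so they can be handed straight to sys.exit and tested numerically by shell callers. A minimal sketch:

    import sys

    def exit_with(code: ExitCode) -> None:
        # An IntEnum member passes through sys.exit() as its integer value,
        # so a shell caller sees $? == 66 for SCHEMA_ERROR.
        sys.exit(code)

    assert ExitCode.SCHEMA_ERROR == 66
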
# Type aliases
FieldType = (
    Any  # Changed from Type[Any] to allow both concrete types and generics
)
FieldDefinition = Tuple[FieldType, FieldInfoType]
ModelType = TypeVar("ModelType", bound=BaseModel)
ItemType: TypeAlias = Type[BaseModel]
ValueType: TypeAlias = Type[Any]


def is_container_type(tp: Type[Any]) -> bool:
    """Check if a type is a container type (list, dict, etc.)."""
    origin = get_origin(tp)
    return origin in (list, dict)


def _create_field(**kwargs: Any) -> FieldInfoType:
    """Create a Pydantic Field with the given kwargs."""
    field: FieldInfoType = Field(**kwargs)
    return field


def _get_type_with_constraints(
    field_schema: Dict[str, Any], field_name: str, base_name: str
) -> FieldDefinition:
    """Get type with constraints from field schema.

    Args:
        field_schema: Field schema dict
        field_name: Name of the field
        base_name: Base name for nested models

    Returns:
        Tuple of (type, field)
    """
    field_type = field_schema.get("type")
    field_kwargs: Dict[str, Any] = {}

    # Add common field metadata
    if "title" in field_schema:
        field_kwargs["title"] = field_schema["title"]
    if "description" in field_schema:
        field_kwargs["description"] = field_schema["description"]
    if "default" in field_schema:
        field_kwargs["default"] = field_schema["default"]
    if "readOnly" in field_schema:
        field_kwargs["frozen"] = field_schema["readOnly"]

    # Handle array type
    if field_type == "array":
        items_schema = field_schema.get("items", {})
        if not items_schema:
            return (List[Any], Field(**field_kwargs))

        # Create nested model with explicit type annotation
        array_item_model = create_dynamic_model(
            items_schema,
            base_name=f"{base_name}_{field_name}_Item",
            show_schema=False,
            debug_validation=False,
        )
        array_type: Type[List[Any]] = List[array_item_model]  # type: ignore[valid-type]
        return (array_type, Field(**field_kwargs))

    # Handle object type
    if field_type == "object":
        # Create nested model with explicit type annotation
        object_model = create_dynamic_model(
            field_schema,
            base_name=f"{base_name}_{field_name}",
            show_schema=False,
            debug_validation=False,
        )
        return (object_model, Field(**field_kwargs))

    # Handle additionalProperties
    if "additionalProperties" in field_schema and isinstance(
        field_schema["additionalProperties"], dict
    ):
        # Create nested model with explicit type annotation
        dict_value_model = create_dynamic_model(
            field_schema["additionalProperties"],
            base_name=f"{base_name}_{field_name}_Value",
            show_schema=False,
            debug_validation=False,
        )
        dict_type: Type[Dict[str, Any]] = Dict[str, dict_value_model]  # type: ignore[valid-type]
        return (dict_type, Field(**field_kwargs))

    # Handle other types
    if field_type == "string":
        field_type_cls: Type[Any] = str

        # Add string-specific constraints to field_kwargs
        if "pattern" in field_schema:
            field_kwargs["pattern"] = field_schema["pattern"]
        if "minLength" in field_schema:
            field_kwargs["min_length"] = field_schema["minLength"]
        if "maxLength" in field_schema:
            field_kwargs["max_length"] = field_schema["maxLength"]

        # Handle special string formats
        if "format" in field_schema:
            if field_schema["format"] == "date-time":
                field_type_cls = datetime
            elif field_schema["format"] == "date":
                field_type_cls = date
            elif field_schema["format"] == "time":
                field_type_cls = time
            elif field_schema["format"] == "email":
                field_type_cls = EmailStr
            elif field_schema["format"] == "uri":
                field_type_cls = AnyUrl

        return (field_type_cls, Field(**field_kwargs))

    if field_type == "number":
        field_type_cls = float

        # Add number-specific constraints to field_kwargs
        if "minimum" in field_schema:
            field_kwargs["ge"] = field_schema["minimum"]
        if "maximum" in field_schema:
            field_kwargs["le"] = field_schema["maximum"]
        if "exclusiveMinimum" in field_schema:
            field_kwargs["gt"] = field_schema["exclusiveMinimum"]
        if "exclusiveMaximum" in field_schema:
            field_kwargs["lt"] = field_schema["exclusiveMaximum"]
        if "multipleOf" in field_schema:
            field_kwargs["multiple_of"] = field_schema["multipleOf"]

        return (field_type_cls, Field(**field_kwargs))

    if field_type == "integer":
        field_type_cls = int

        # Add integer-specific constraints to field_kwargs
        if "minimum" in field_schema:
            field_kwargs["ge"] = field_schema["minimum"]
        if "maximum" in field_schema:
            field_kwargs["le"] = field_schema["maximum"]
        if "exclusiveMinimum" in field_schema:
            field_kwargs["gt"] = field_schema["exclusiveMinimum"]
        if "exclusiveMaximum" in field_schema:
            field_kwargs["lt"] = field_schema["exclusiveMaximum"]
        if "multipleOf" in field_schema:
            field_kwargs["multiple_of"] = field_schema["multipleOf"]

        return (field_type_cls, Field(**field_kwargs))

    if field_type == "boolean":
        return (bool, Field(**field_kwargs))

    if field_type == "null":
        return (type(None), Field(**field_kwargs))

    # Handle enum
    if "enum" in field_schema:
        enum_type = _create_enum_type(field_schema["enum"], field_name)
        return (cast(Type[Any], enum_type), Field(**field_kwargs))

    # Default to Any for unknown types
    return (Any, Field(**field_kwargs))


T = TypeVar("T")
K = TypeVar("K")
V = TypeVar("V")


def estimate_tokens_for_chat(
    messages: List[Dict[str, str]], model: str
) -> int:
    """Estimate the number of tokens in a chat completion."""
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        # Fall back to cl100k_base for unknown models
        encoding = tiktoken.get_encoding("cl100k_base")

    num_tokens = 0
    for message in messages:
        # Add message overhead
        num_tokens += 4  # every message follows <im_start>{role/name}\n{content}<im_end>\n
        for key, value in message.items():
            num_tokens += len(encoding.encode(str(value)))
            if key == "name":  # if there's a name, the role is omitted
                num_tokens += -1  # role is always required and always 1 token
    num_tokens += 2  # every reply is primed with <im_start>assistant
    return num_tokens
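
A usage sketch (illustrative, not part of cli.py): the overhead constants above (4 tokens of framing per message, 2 for the assistant primer) follow the common cl100k chat heuristic, so treat the result as a budget estimate rather than an exact count:

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Summarize this file."},
    ]
    # Framing overhead plus encoded content; exact counts can drift
    # between model revisions.
    estimated = estimate_tokens_for_chat(messages, "gpt-4o")
    print(f"~{estimated} prompt tokens")
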
def get_default_token_limit(model: str) -> int:
    """Get the default token limit for a given model.

    Note: These limits are based on current OpenAI model specifications as of 2024 and may
    need to be updated if OpenAI changes the models' capabilities.

    Args:
        model: The model name (e.g., 'gpt-4o', 'gpt-4o-mini', 'o1')

    Returns:
        The default token limit for the model
    """
    if "o1" in model:
        return 100_000  # o1 supports up to 100K output tokens
    elif "gpt-4o" in model:
        return 16_384  # gpt-4o and gpt-4o-mini support up to 16K output tokens
    else:
        return 4_096  # default fallback


def get_context_window_limit(model: str) -> int:
    """Get the total context window limit for a given model.

    Note: These limits are based on current OpenAI model specifications as of 2024 and may
    need to be updated if OpenAI changes the models' capabilities.

    Args:
        model: The model name (e.g., 'gpt-4o', 'gpt-4o-mini', 'o1')

    Returns:
        The context window limit for the model
    """
    if "o1" in model:
        return 200_000  # o1 supports 200K total context window
    elif "gpt-4o" in model:
        return 128_000  # gpt-4o and gpt-4o-mini support 128K context window
    else:
        return 8_192  # default fallback


def validate_token_limits(
    model: str, total_tokens: int, max_token_limit: Optional[int] = None
) -> None:
    """Validate token counts against model limits.

    Args:
        model: The model name
        total_tokens: Total number of tokens in the prompt
        max_token_limit: Optional user-specified token limit

    Raises:
        ValueError: If token limits are exceeded
    """
    context_limit = get_context_window_limit(model)
    output_limit = (
        max_token_limit
        if max_token_limit is not None
        else get_default_token_limit(model)
    )

    # Check if total tokens exceed context window
    if total_tokens >= context_limit:
        raise ValueError(
            f"Total tokens ({total_tokens:,}) exceed model's context window limit "
            f"of {context_limit:,} tokens"
        )

    # Check if there's enough room for output tokens
    remaining_tokens = context_limit - total_tokens
    if remaining_tokens < output_limit:
        raise ValueError(
            f"Only {remaining_tokens:,} tokens remaining in context window, but "
            f"output may require up to {output_limit:,} tokens"
        )
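
To make the arithmetic concrete (an illustrative aside, not part of cli.py): for gpt-4o the context window is 128,000 tokens and the default output budget is 16,384, so a 115,000-token prompt passes the first check but fails the second, since 128,000 - 115,000 = 13,000 < 16,384:

    try:
        validate_token_limits("gpt-4o", total_tokens=115_000)
    except ValueError as e:
        # Only 13,000 tokens remain, below the 16,384-token output budget.
        print(f"rejected: {e}")

    # An explicit output cap shrinks the required headroom and passes:
    validate_token_limits("gpt-4o", total_tokens=115_000, max_token_limit=8_000)
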
def process_system_prompt(
    task_template: str,
    system_prompt: Optional[str],
    template_context: Dict[str, Any],
    env: jinja2.Environment,
    ignore_task_sysprompt: bool = False,
) -> str:
    """Process system prompt from various sources.

    Args:
        task_template: The task template string
        system_prompt: Optional system prompt string or file path (with @ prefix)
        template_context: Template context for rendering
        env: Jinja2 environment
        ignore_task_sysprompt: Whether to ignore system prompt in task template

    Returns:
        The final system prompt string

    Raises:
        SystemPromptError: If the system prompt cannot be loaded or rendered
        FileNotFoundError: If a prompt file does not exist
        PathSecurityError: If a prompt file path violates security constraints
    """
    # Default system prompt
    default_prompt = "You are a helpful assistant."

    # Try to get system prompt from CLI argument first
    if system_prompt:
        if system_prompt.startswith("@"):
            # Load from file
            path = system_prompt[1:]
            try:
                name, path = validate_path_mapping(f"system_prompt={path}")
                with open(path, "r", encoding="utf-8") as f:
                    system_prompt = f.read().strip()
            except (FileNotFoundError, PathSecurityError) as e:
                raise SystemPromptError(f"Invalid system prompt file: {e}")

        # Render system prompt with template context
        try:
            template = env.from_string(system_prompt)
            return cast(str, template.render(**template_context).strip())
        except jinja2.TemplateError as e:
            raise SystemPromptError(f"Error rendering system prompt: {e}")

    # If not ignoring task template system prompt, try to extract it
    if not ignore_task_sysprompt:
        try:
            # Extract YAML frontmatter
            if task_template.startswith("---\n"):
                end = task_template.find("\n---\n", 4)
                if end != -1:
                    frontmatter = task_template[4:end]
                    try:
                        metadata = yaml.safe_load(frontmatter)
                        if (
                            isinstance(metadata, dict)
                            and "system_prompt" in metadata
                        ):
                            system_prompt = str(metadata["system_prompt"])
                            # Render system prompt with template context
                            try:
                                template = env.from_string(system_prompt)
                                return cast(
                                    str,
                                    template.render(
                                        **template_context
                                    ).strip(),
                                )
                            except jinja2.TemplateError as e:
                                raise SystemPromptError(
                                    f"Error rendering system prompt: {e}"
                                )
                    except yaml.YAMLError as e:
                        raise SystemPromptError(
                            f"Invalid YAML frontmatter: {e}"
                        )

        except Exception as e:
            raise SystemPromptError(
                f"Error extracting system prompt from template: {e}"
            )

    # Fall back to default
    return default_prompt
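
An illustrative aside (not part of cli.py, and assuming create_jinja_env returns a standard jinja2.Environment): with no CLI prompt supplied, the frontmatter branch above both extracts and Jinja-renders the system_prompt key. A minimal sketch with a hypothetical inline template:

    env = create_jinja_env()
    task = "---\nsystem_prompt: You review {{ project }}.\n---\nReview {{ project }}."
    prompt = process_system_prompt(task, None, {"project": "ostruct"}, env)
    assert prompt == "You review ostruct."
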
def validate_variable_mapping(
    mapping: str, is_json: bool = False
) -> tuple[str, Any]:
    """Validate a variable mapping in name=value format."""
    try:
        name, value = mapping.split("=", 1)
        if not name:
            raise VariableNameError(
                f"Empty name in {'JSON ' if is_json else ''}variable mapping"
            )

        if is_json:
            try:
                value = json.loads(value)
            except json.JSONDecodeError as e:
                raise InvalidJSONError(
                    f"Invalid JSON value for variable {name!r}: {value!r}"
                ) from e

        return name, value

    except ValueError as e:
        if "not enough values to unpack" in str(e):
            raise VariableValueError(
                f"Invalid {'JSON ' if is_json else ''}variable mapping "
                f"(expected name=value format): {mapping!r}"
            )
        raise


@overload
def _validate_path_mapping_internal(
    mapping: str,
    is_dir: Literal[True],
    base_dir: Optional[str] = None,
    security_manager: Optional[SecurityManager] = None,
) -> Tuple[str, str]: ...


@overload
def _validate_path_mapping_internal(
    mapping: str,
    is_dir: Literal[False] = False,
    base_dir: Optional[str] = None,
    security_manager: Optional[SecurityManager] = None,
) -> Tuple[str, str]: ...


def _validate_path_mapping_internal(
    mapping: str,
    is_dir: bool = False,
    base_dir: Optional[str] = None,
    security_manager: Optional[SecurityManager] = None,
) -> Tuple[str, str]:
    """Validate a path mapping in the format "name=path".

    Args:
        mapping: The path mapping string (e.g., "myvar=/path/to/file").
        is_dir: Whether the path is expected to be a directory (True) or file (False).
        base_dir: Optional base directory to resolve relative paths against.
        security_manager: Optional security manager to validate paths.

    Returns:
        A (name, path) tuple.

    Raises:
        VariableNameError: If the variable name portion is empty or invalid.
        DirectoryNotFoundError: If is_dir=True and the path is not a directory or doesn't exist.
        FileNotFoundError: If is_dir=False and the path is not a file or doesn't exist.
        PathSecurityError: If the path is inaccessible or outside the allowed directory.
        ValueError: If the format is invalid (missing "=").
        OSError: If there is an underlying OS error (permissions, etc.).
    """
    try:
        if not mapping or "=" not in mapping:
            raise ValueError(
                "Invalid path mapping format. Expected format: name=path"
            )

        name, path = mapping.split("=", 1)
        if not name:
            raise VariableNameError(
                f"Empty name in {'directory' if is_dir else 'file'} mapping"
            )

        if not path:
            raise VariableValueError("Path cannot be empty")

        # Convert to Path object and resolve against base_dir if provided
        path_obj = Path(path)
        if base_dir:
            path_obj = Path(base_dir) / path_obj

        # Resolve the path to catch directory traversal attempts
        try:
            resolved_path = path_obj.resolve()
        except OSError as e:
            raise OSError(f"Failed to resolve path: {e}")

        # Check for directory traversal
        try:
            base_path = (
                Path.cwd() if base_dir is None else Path(base_dir).resolve()
            )
            if not str(resolved_path).startswith(str(base_path)):
                raise PathSecurityError(
                    f"Path {str(path)!r} resolves to {str(resolved_path)!r} which is outside "
                    f"base directory {str(base_path)!r}"
                )
        except OSError as e:
            raise OSError(f"Failed to resolve base path: {e}")

        # Check if path exists
        if not resolved_path.exists():
            if is_dir:
                raise DirectoryNotFoundError(f"Directory not found: {path!r}")
            else:
                raise FileNotFoundError(f"File not found: {path!r}")

        # Check if path is correct type
        if is_dir and not resolved_path.is_dir():
            raise DirectoryNotFoundError(f"Path is not a directory: {path!r}")
        elif not is_dir and not resolved_path.is_file():
            raise FileNotFoundError(f"Path is not a file: {path!r}")

        # Check if path is accessible
        try:
            if is_dir:
                os.listdir(str(resolved_path))
            else:
                with open(str(resolved_path), "r", encoding="utf-8") as f:
                    f.read(1)
        except OSError as e:
            if e.errno == 13:  # Permission denied
                raise PathSecurityError(
                    f"Permission denied accessing path: {path!r}",
                    error_logged=True,
                )
            raise

        if security_manager:
            if not security_manager.is_allowed_file(str(resolved_path)):
                raise PathSecurityError.from_expanded_paths(
                    original_path=str(path),
                    expanded_path=str(resolved_path),
                    base_dir=str(security_manager.base_dir),
                    allowed_dirs=[
                        str(d) for d in security_manager.allowed_dirs
                    ],
                    error_logged=True,
                )

        # Return the original path to maintain relative paths in the output
        return name, path

    except ValueError as e:
        if "not enough values to unpack" in str(e):
            raise VariableValueError(
                f"Invalid {'directory' if is_dir else 'file'} mapping "
                f"(expected name=path format): {mapping!r}"
            )
        raise
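
An aside to make the traversal check concrete (illustrative, not part of cli.py; the paths are hypothetical and POSIX-shaped): a mapping that escapes the base directory via ".." resolves outside base_path and is rejected by the prefix comparison before any type or permission checks run:

    try:
        _validate_path_mapping_internal(
            "config=../../etc/passwd", base_dir="/srv/app"
        )
    except PathSecurityError as e:
        # /srv/app/../../etc/passwd resolves to /etc/passwd, which does
        # not start with /srv/app, so the traversal check raises.
        print(f"blocked: {e}")
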
def validate_task_template(task: str) -> str:
    """Validate and load a task template.

    Args:
        task: The task template string or path to task template file (with @ prefix)

    Returns:
        The task template string

    Raises:
        TaskTemplateVariableError: If the template file cannot be read or is invalid
        TaskTemplateSyntaxError: If the template has invalid syntax
        FileNotFoundError: If the template file does not exist
        PathSecurityError: If the template file path violates security constraints
    """
    template_content = task

    # Check if task is a file path
    if task.startswith("@"):
        path = task[1:]
        try:
            name, path = validate_path_mapping(f"task={path}")
            with open(path, "r", encoding="utf-8") as f:
                template_content = f.read()
        except (FileNotFoundError, PathSecurityError) as e:
            raise TaskTemplateVariableError(f"Invalid task template file: {e}")

    # Validate template syntax
    try:
        env = jinja2.Environment(undefined=jinja2.StrictUndefined)
        env.parse(template_content)
        return template_content
    except jinja2.TemplateSyntaxError as e:
        raise TaskTemplateSyntaxError(
            f"Invalid task template syntax at line {e.lineno}: {e.message}"
        )


def validate_schema_file(
    path: str,
    verbose: bool = False,
) -> Dict[str, Any]:
    """Validate a JSON schema file.

    Args:
        path: Path to the schema file
        verbose: Whether to enable verbose logging

    Returns:
        The validated schema

    Raises:
        SchemaFileError: When file cannot be read
        InvalidJSONError: When file contains invalid JSON
        SchemaValidationError: When schema is invalid
    """
    if verbose:
        logger.info("Validating schema file: %s", path)

    try:
        with open(path) as f:
            schema = json.load(f)
    except FileNotFoundError:
        raise SchemaFileError(f"Schema file not found: {path}")
    except json.JSONDecodeError as e:
        raise InvalidJSONError(f"Invalid JSON in schema file: {e}")
    except Exception as e:
        raise SchemaFileError(f"Failed to read schema file: {e}")

    # Pre-validation structure checks
    if verbose:
        logger.info("Performing pre-validation structure checks")
        logger.debug("Loaded schema: %s", json.dumps(schema, indent=2))

    if not isinstance(schema, dict):
        if verbose:
            logger.error(
                "Schema is not a dictionary: %s", type(schema).__name__
            )
        raise SchemaValidationError("Schema must be a JSON object")

    # Validate schema structure
    if "schema" in schema:
        if verbose:
            logger.debug("Found schema wrapper, validating inner schema")
        inner_schema = schema["schema"]
        if not isinstance(inner_schema, dict):
            if verbose:
                logger.error(
                    "Inner schema is not a dictionary: %s",
                    type(inner_schema).__name__,
                )
            raise SchemaValidationError("Inner schema must be a JSON object")
        if verbose:
            logger.debug("Inner schema validated successfully")
    else:
        if verbose:
            logger.debug("No schema wrapper found, using schema as-is")

    # Return the full schema including wrapper
    return schema
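
For illustration (not part of cli.py): a schema file in the optional wrapped form this function accepts. The wrapper is preserved in the return value; unwrapping is left to later stages:

    import json
    import tempfile

    wrapped = {
        "schema": {
            "type": "object",
            "properties": {"summary": {"type": "string"}},
            "required": ["summary"],
        }
    }
    with tempfile.NamedTemporaryFile(
        "w", suffix=".json", delete=False
    ) as f:
        json.dump(wrapped, f)

    schema = validate_schema_file(f.name, verbose=True)
    assert "schema" in schema  # wrapper kept intact
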
def collect_template_files(
    args: argparse.Namespace,
    security_manager: SecurityManager,
) -> Dict[str, TemplateValue]:
    """Collect files from command line arguments.

    Args:
        args: Parsed command line arguments
        security_manager: Security manager for path validation

    Returns:
        Dictionary mapping variable names to file info objects

    Raises:
        PathSecurityError: If any file paths violate security constraints
        ValueError: If file mappings are invalid or files cannot be accessed
    """
    try:
        result = collect_files(
            file_mappings=args.file,
            pattern_mappings=args.files,
            dir_mappings=args.dir,
            dir_recursive=args.dir_recursive,
            dir_extensions=args.dir_ext.split(",") if args.dir_ext else None,
            security_manager=security_manager,
        )
        return cast(Dict[str, TemplateValue], result)
    except PathSecurityError:
        # Let PathSecurityError propagate without wrapping
        raise
    except (FileNotFoundError, DirectoryNotFoundError) as e:
        # Wrap file-related errors
        raise ValueError(f"File access error: {e}")
    except Exception as e:
        # Check if this is a wrapped security error
        if isinstance(e.__cause__, PathSecurityError):
            raise e.__cause__
        # Wrap unexpected errors
        raise ValueError(f"Error collecting files: {e}")


def collect_simple_variables(args: argparse.Namespace) -> Dict[str, str]:
    """Collect simple string variables from --var arguments.

    Args:
        args: Parsed command line arguments

    Returns:
        Dictionary mapping variable names to string values

    Raises:
        VariableNameError: If a variable name is invalid or duplicate
    """
    variables: Dict[str, str] = {}
    all_names: Set[str] = set()

    if args.var:
        for mapping in args.var:
            try:
                name, value = mapping.split("=", 1)
                if not name.isidentifier():
                    raise VariableNameError(f"Invalid variable name: {name}")
                if name in all_names:
                    raise VariableNameError(f"Duplicate variable name: {name}")
                variables[name] = value
                all_names.add(name)
            except ValueError:
                raise VariableNameError(
                    f"Invalid variable mapping (expected name=value format): {mapping!r}"
                )

    return variables


def collect_json_variables(args: argparse.Namespace) -> Dict[str, Any]:
    """Collect JSON variables from --json-var arguments.

    Args:
        args: Parsed command line arguments

    Returns:
        Dictionary mapping variable names to parsed JSON values

    Raises:
        VariableNameError: If a variable name is invalid or duplicate
        InvalidJSONError: If a JSON value is invalid
    """
    variables: Dict[str, Any] = {}
    all_names: Set[str] = set()

    if args.json_var:
        for mapping in args.json_var:
            try:
                name, json_str = mapping.split("=", 1)
                if not name.isidentifier():
                    raise VariableNameError(f"Invalid variable name: {name}")
                if name in all_names:
                    raise VariableNameError(f"Duplicate variable name: {name}")
                try:
                    value = json.loads(json_str)
                    variables[name] = value
                    all_names.add(name)
                except json.JSONDecodeError as e:
                    raise InvalidJSONError(
                        f"Invalid JSON value for {name}: {str(e)}"
                    )
            except ValueError:
                raise VariableNameError(
                    f"Invalid JSON variable mapping format: {mapping}. Expected name=json"
                )

    return variables
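
An illustrative aside (not part of cli.py): both collectors accept repeated name=value flags, with a hand-built Namespace standing in here for parsed arguments:

    import argparse

    ns = argparse.Namespace(
        var=["env=prod", "region=us"],
        json_var=['retries={"max": 3, "backoff": 1.5}'],
    )
    assert collect_simple_variables(ns) == {"env": "prod", "region": "us"}
    assert collect_json_variables(ns) == {"retries": {"max": 3, "backoff": 1.5}}
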
def create_template_context(
    files: Optional[Dict[str, FileInfoList]] = None,
    variables: Optional[Dict[str, str]] = None,
    json_variables: Optional[Dict[str, Any]] = None,
    security_manager: Optional[SecurityManager] = None,
    stdin_content: Optional[str] = None,
) -> Dict[str, Any]:
    """Create template context from direct inputs.

    Args:
        files: Optional dictionary mapping names to FileInfoList objects
        variables: Optional dictionary of simple string variables
        json_variables: Optional dictionary of JSON variables
        security_manager: Optional security manager for path validation
        stdin_content: Optional content to use for stdin

    Returns:
        Template context dictionary

    Raises:
        PathSecurityError: If any file paths violate security constraints
        VariableError: If variable mappings are invalid
    """
    context: Dict[str, Any] = {}

    # Add file variables
    if files:
        for name, file_list in files.items():
            # For single files, extract the first FileInfo object
            if len(file_list) == 1:
                context[name] = file_list[0]
            else:
                context[name] = file_list

    # Add simple variables
    if variables:
        context.update(variables)

    # Add JSON variables
    if json_variables:
        context.update(json_variables)

    # Add stdin if provided
    if stdin_content is not None:
        context["stdin"] = stdin_content

    return context


def create_template_context_from_args(
    args: argparse.Namespace,
    security_manager: SecurityManager,
) -> Dict[str, Any]:
    """Create template context from command line arguments.

    Args:
        args: Parsed command line arguments
        security_manager: Security manager for path validation

    Returns:
        Template context dictionary

    Raises:
        PathSecurityError: If any file paths violate security constraints
        VariableError: If variable mappings are invalid
        ValueError: If file mappings are invalid or files cannot be accessed
    """
    try:
        # Collect files from arguments
        files = None
        if any([args.file, args.files, args.dir]):
            files = collect_files(
                file_mappings=args.file,
                pattern_mappings=args.files,
                dir_mappings=args.dir,
                dir_recursive=args.dir_recursive,
                dir_extensions=(
                    args.dir_ext.split(",") if args.dir_ext else None
                ),
                security_manager=security_manager,
            )

        # Collect simple variables
        try:
            variables = collect_simple_variables(args)
        except VariableNameError as e:
            raise VariableError(str(e))

        # Collect JSON variables
        json_variables = {}
        if args.json_var:
            for mapping in args.json_var:
                try:
                    name, value = mapping.split("=", 1)
                    if not name.isidentifier():
                        raise VariableNameError(
                            f"Invalid variable name: {name}"
                        )
                    try:
                        json_value = json.loads(value)
                    except json.JSONDecodeError as e:
                        raise InvalidJSONError(
                            f"Invalid JSON value for {name} ({value!r}): {str(e)}"
                        )
                    if name in json_variables:
                        raise VariableNameError(
                            f"Duplicate variable name: {name}"
                        )
                    json_variables[name] = json_value
                except ValueError:
                    raise VariableNameError(
                        f"Invalid JSON variable mapping format: {mapping}. Expected name=json"
                    )

        # Get stdin content if available
        stdin_content = None
        try:
            if not sys.stdin.isatty():
                stdin_content = sys.stdin.read()
        except (OSError, IOError):
            # Skip stdin if it can't be read
            pass

        return create_template_context(
            files=files,
            variables=variables,
            json_variables=json_variables,
            security_manager=security_manager,
            stdin_content=stdin_content,
        )

    except PathSecurityError:
        # Let PathSecurityError propagate without wrapping
        raise
    except (FileNotFoundError, DirectoryNotFoundError) as e:
        # Wrap file-related errors
        raise ValueError(f"File access error: {e}")
    except Exception as e:
        # Check if this is a wrapped security error
        if isinstance(e.__cause__, PathSecurityError):
            raise e.__cause__
        # Wrap unexpected errors
        raise ValueError(f"Error collecting files: {e}")


def validate_security_manager(
    base_dir: Optional[str] = None,
    allowed_dirs: Optional[List[str]] = None,
    allowed_dirs_file: Optional[str] = None,
) -> SecurityManager:
    """Create and validate a security manager.

    Args:
        base_dir: Optional base directory to resolve paths against
        allowed_dirs: Optional list of allowed directory paths
        allowed_dirs_file: Optional path to file containing allowed directories

    Returns:
        Configured SecurityManager instance

    Raises:
        FileNotFoundError: If allowed_dirs_file does not exist
        PathSecurityError: If any paths are outside base directory
    """
    # Convert base_dir to string if it's a Path
    base_dir_str = str(base_dir) if base_dir else None
    security_manager = SecurityManager(base_dir_str)

    if allowed_dirs_file:
        security_manager.add_allowed_dirs_from_file(str(allowed_dirs_file))

    if allowed_dirs:
        for allowed_dir in allowed_dirs:
            security_manager.add_allowed_dir(str(allowed_dir))

    return security_manager


def parse_var(var_str: str) -> Tuple[str, str]:
    """Parse a simple variable string in the format 'name=value'.

    Args:
        var_str: Variable string in format 'name=value'

    Returns:
        Tuple of (name, value)

    Raises:
        VariableNameError: If variable name is empty or invalid
        VariableValueError: If variable format is invalid
    """
    try:
        name, value = var_str.split("=", 1)
        if not name:
            raise VariableNameError("Empty name in variable mapping")
        if not name.isidentifier():
            raise VariableNameError(
                f"Invalid variable name: {name}. Must be a valid Python identifier"
            )
        return name, value
    except ValueError as e:
        if "not enough values to unpack" in str(e):
            raise VariableValueError(
                f"Invalid variable mapping (expected name=value format): {var_str!r}"
            )
        raise


def parse_json_var(var_str: str) -> Tuple[str, Any]:
    """Parse a JSON variable string in the format 'name=json_value'.

    Args:
        var_str: Variable string in format 'name=json_value'

    Returns:
        Tuple of (name, parsed_value)

    Raises:
        VariableNameError: If variable name is empty or invalid
        VariableValueError: If variable format is invalid
        InvalidJSONError: If JSON value is invalid
    """
    try:
        name, json_str = var_str.split("=", 1)
        if not name:
            raise VariableNameError("Empty name in JSON variable mapping")
        if not name.isidentifier():
            raise VariableNameError(
                f"Invalid variable name: {name}. Must be a valid Python identifier"
            )

        try:
            value = json.loads(json_str)
        except json.JSONDecodeError as e:
            raise InvalidJSONError(
                f"Invalid JSON value for variable {name!r}: {json_str!r}"
            ) from e

        return name, value

    except ValueError as e:
        if "not enough values to unpack" in str(e):
            raise VariableValueError(
                f"Invalid JSON variable mapping (expected name=json format): {var_str!r}"
            )
        raise
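
A short aside (illustrative, not part of cli.py): both parsers split on the first '=' only, so values may themselves contain '=', and a missing separator surfaces as VariableValueError rather than a bare ValueError:

    assert parse_var("query=a=b") == ("query", "a=b")
    assert parse_json_var('opts={"k": "="}') == ("opts", {"k": "="})

    try:
        parse_json_var("no-separator-here")
    except VariableValueError as e:
        print(e)  # expected name=json format
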
def _create_enum_type(values: List[Any], field_name: str) -> Type[Enum]:
    """Create an enum type from a list of values.

    Args:
        values: List of enum values
        field_name: Name of the field for enum type name

    Returns:
        Created enum type
    """
    # Determine the value type
    value_types = {type(v) for v in values}

    if len(value_types) > 1:
        # Mixed types, use string representation
        enum_dict = {f"VALUE_{i}": str(v) for i, v in enumerate(values)}
        return type(f"{field_name.title()}Enum", (str, Enum), enum_dict)
    elif value_types == {int}:
        # All integer values
        enum_dict = {f"VALUE_{v}": v for v in values}
        return type(f"{field_name.title()}Enum", (IntEnum,), enum_dict)
    elif value_types == {str}:
        # All string values
        enum_dict = {v.upper().replace(" ", "_"): v for v in values}
        if sys.version_info >= (3, 11):
            return type(f"{field_name.title()}Enum", (StrEnum,), enum_dict)
        else:
            # Other types, use string representation
            return type(f"{field_name.title()}Enum", (str, Enum), enum_dict)

    # Default case: treat as string enum
    enum_dict = {f"VALUE_{i}": str(v) for i, v in enumerate(values)}
    return type(f"{field_name.title()}Enum", (str, Enum), enum_dict)
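
An aside sketching the three branches (illustrative, not part of cli.py): homogeneous ints become an IntEnum, homogeneous strings a StrEnum on Python 3.11+ (a str/Enum hybrid otherwise), and mixed values fall back to stringified VALUE_<index> members:

    Priority = _create_enum_type([1, 2, 3], "priority")
    assert Priority.VALUE_1 == 1  # IntEnum member compares as a plain int

    Color = _create_enum_type(["red", "dark blue"], "color")
    assert Color.DARK_BLUE.value == "dark blue"  # spaces become underscores

    Mixed = _create_enum_type([1, "two"], "level")
    assert Mixed.VALUE_0.value == "1"  # mixed values are stringified
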
def create_argument_parser() -> argparse.ArgumentParser:
    """Create argument parser for CLI."""
    parser = argparse.ArgumentParser(
        description="Make structured OpenAI API calls.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # Debug output options
    debug_group = parser.add_argument_group("Debug Output Options")
    debug_group.add_argument(
        "--show-model-schema",
        action="store_true",
        help="Display the generated Pydantic model schema",
    )
    debug_group.add_argument(
        "--debug-validation",
        action="store_true",
        help="Show detailed schema validation debugging information",
    )
    debug_group.add_argument(
        "--verbose-schema",
        action="store_true",
        help="Enable verbose schema debugging output",
    )
    debug_group.add_argument(
        "--progress-level",
        choices=["none", "basic", "detailed"],
        default="basic",
        help="Set the level of progress reporting (default: basic)",
    )

    # Required arguments
    parser.add_argument(
        "--task",
        required=True,
        help="Task template string or @file",
    )

    # File access arguments
    parser.add_argument(
        "--file",
        action="append",
        default=[],
        help="Map file to variable (name=path)",
        metavar="NAME=PATH",
    )
    parser.add_argument(
        "--files",
        action="append",
        default=[],
        help="Map file pattern to variable (name=pattern)",
        metavar="NAME=PATTERN",
    )
    parser.add_argument(
        "--dir",
        action="append",
        default=[],
        help="Map directory to variable (name=path)",
        metavar="NAME=PATH",
    )
    parser.add_argument(
        "--allowed-dir",
        action="append",
        default=[],
        help="Additional allowed directory or @file",
        metavar="PATH",
    )
    parser.add_argument(
        "--base-dir",
        help="Base directory for file access (defaults to current directory)",
        default=os.getcwd(),
    )
    parser.add_argument(
        "--allowed-dirs-file",
        help="File containing list of allowed directories",
    )
    parser.add_argument(
        "--dir-recursive",
        action="store_true",
        help="Process directories recursively",
    )
    parser.add_argument(
        "--dir-ext",
        help="Comma-separated list of file extensions to include in directory processing",
    )

    # Variable arguments
    parser.add_argument(
        "--var",
        action="append",
        default=[],
        help="Pass simple variables (name=value)",
        metavar="NAME=VALUE",
    )
    parser.add_argument(
        "--json-var",
        action="append",
        default=[],
        help="Pass JSON variables (name=json)",
        metavar="NAME=JSON",
    )

    # System prompt options
    parser.add_argument(
        "--system-prompt",
        help=(
            "System prompt for the model (use @file to load from file, "
            "can also be specified in task template YAML frontmatter)"
        ),
        default=DEFAULT_SYSTEM_PROMPT,
    )
    parser.add_argument(
        "--ignore-task-sysprompt",
        action="store_true",
        help="Ignore system prompt from task template YAML frontmatter",
    )

    # Schema validation
    parser.add_argument(
        "--schema",
        dest="schema_file",
        required=True,
        help="JSON schema file for response validation",
    )
    parser.add_argument(
        "--validate-schema",
        action="store_true",
        help="Validate schema and response",
    )

    # Model configuration
    parser.add_argument(
        "--model",
        default="gpt-4o-2024-08-06",
        help="Model to use",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.0,
        help="Temperature (0.0-2.0)",
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        help="Maximum tokens to generate",
    )
    parser.add_argument(
        "--top-p",
        type=float,
        default=1.0,
        help="Top-p sampling (0.0-1.0)",
    )
    parser.add_argument(
        "--frequency-penalty",
        type=float,
        default=0.0,
        help="Frequency penalty (-2.0-2.0)",
    )
    parser.add_argument(
        "--presence-penalty",
        type=float,
        default=0.0,
        help="Presence penalty (-2.0-2.0)",
    )
    parser.add_argument(
        "--timeout",
        type=float,
        default=60.0,
        help="API timeout in seconds",
    )

    # Output options
    parser.add_argument(
        "--output-file",
        help="Write JSON output to file",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Simulate API call without making request",
    )
    parser.add_argument(
        "--no-progress",
        action="store_true",
        help="Disable progress indicators",
    )

    # Other options
    parser.add_argument(
        "--api-key",
        help="OpenAI API key (overrides env var)",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Enable verbose output",
    )
    parser.add_argument(
        "--debug-openai-stream",
        action="store_true",
        help="Enable low-level debug output for OpenAI streaming (very verbose)",
    )
    parser.add_argument(
        "--version",
        action="version",
        version=f"%(prog)s {__version__}",
    )

    return parser
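
An illustrative aside (not part of cli.py; file names hypothetical): the parser can be exercised directly, and only --task and --schema are required:

    parser = create_argument_parser()
    args = parser.parse_args(
        [
            "--task", "@task.j2",
            "--schema", "schema.json",
            "--file", "diff=changes.patch",
            "--var", "project=ostruct",
            "--dry-run",
        ]
    )
    assert args.schema_file == "schema.json"  # --schema stores to schema_file
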
|
1410
|
+
|
1411
|
+
|
1412
|
+
async def _main() -> ExitCode:
|
1413
|
+
"""Main CLI function.
|
1414
|
+
|
1415
|
+
Returns:
|
1416
|
+
ExitCode: Exit code indicating success or failure
|
1417
|
+
"""
|
1418
|
+
try:
|
1419
|
+
parser = create_argument_parser()
|
1420
|
+
args = parser.parse_args()
|
1421
|
+
|
1422
|
+
# Configure logging
|
1423
|
+
log_level = logging.DEBUG if args.verbose else logging.INFO
|
1424
|
+
logger.setLevel(log_level)
|
1425
|
+
|
1426
|
+
# Create security manager
|
1427
|
+
security_manager = validate_security_manager(
|
1428
|
+
base_dir=args.base_dir,
|
1429
|
+
allowed_dirs=args.allowed_dir,
|
1430
|
+
allowed_dirs_file=args.allowed_dirs_file,
|
1431
|
+
)
|
1432
|
+
|
1433
|
+
# Validate task template
|
1434
|
+
task_template = validate_task_template(args.task)
|
1435
|
+
|
1436
|
+
# Validate schema file
|
1437
|
+
schema = validate_schema_file(args.schema_file, args.verbose)
|
1438
|
+
|
1439
|
+
# Create template context
|
1440
|
+
template_context = create_template_context_from_args(
|
1441
|
+
args, security_manager
|
1442
|
+
)
|
1443
|
+
|
1444
|
+
# Create Jinja environment
|
1445
|
+
env = create_jinja_env()
|
1446
|
+
|
1447
|
+
# Process system prompt
|
1448
|
+
args.system_prompt = process_system_prompt(
|
1449
|
+
task_template,
|
1450
|
+
args.system_prompt,
|
1451
|
+
template_context,
|
1452
|
+
env,
|
1453
|
+
args.ignore_task_sysprompt,
|
1454
|
+
)
|
1455
|
+
|
1456
|
+
# Render task template
|
1457
|
+
rendered_task = render_template(task_template, template_context, env)
|
1458
|
+
logger.info(rendered_task) # Log the rendered template
|
1459
|
+
|
1460
|
+
# If dry run, exit here
|
1461
|
+
if args.dry_run:
|
1462
|
+
logger.info("DRY RUN MODE")
|
1463
|
+
return ExitCode.SUCCESS
|
1464
|
+
|
1465
|
+
# Load and validate schema
|
1466
|
+
try:
|
1467
|
+
logger.debug("[_main] Loading schema from %s", args.schema_file)
|
1468
|
+
schema = validate_schema_file(
|
1469
|
+
args.schema_file, verbose=args.verbose_schema
|
1470
|
+
)
|
1471
|
+
logger.debug("[_main] Creating output model")
|
1472
|
+
output_model = create_dynamic_model(
|
1473
|
+
schema,
|
1474
|
+
base_name="OutputModel",
|
1475
|
+
show_schema=args.show_model_schema,
|
1476
|
+
debug_validation=args.debug_validation,
|
1477
|
+
)
|
1478
|
+
logger.debug("[_main] Successfully created output model")
|
1479
|
+
except (SchemaFileError, InvalidJSONError, SchemaValidationError) as e:
|
1480
|
+
logger.error(str(e))
|
1481
|
+
return ExitCode.SCHEMA_ERROR
|
1482
|
+
except ModelCreationError as e:
|
1483
|
+
logger.error(f"Model creation error: {e}")
|
1484
|
+
return ExitCode.SCHEMA_ERROR
|
1485
|
+
except Exception as e:
|
1486
|
+
logger.error(f"Unexpected error creating model: {e}")
|
1487
|
+
return ExitCode.SCHEMA_ERROR
|
1488
|
+
|
1489
|
+
# Validate model support
|
1490
|
+
try:
|
1491
|
+
supports_structured_output(args.model)
|
1492
|
+
except ModelNotSupportedError as e:
|
1493
|
+
logger.error(str(e))
|
1494
|
+
return ExitCode.DATA_ERROR
|
1495
|
+
except ModelVersionError as e:
|
1496
|
+
logger.error(str(e))
|
1497
|
+
return ExitCode.DATA_ERROR
|
1498
|
+
|
1499
|
+
# Estimate token usage
|
1500
|
+
messages = [
|
1501
|
+
{"role": "system", "content": args.system_prompt},
|
1502
|
+
{"role": "user", "content": rendered_task},
|
1503
|
+
]
|
1504
|
+
total_tokens = estimate_tokens_for_chat(messages, args.model)
|
1505
|
+
context_limit = get_context_window_limit(args.model)
|
1506
|
+
|
1507
|
+
if total_tokens > context_limit:
|
1508
|
+
logger.error(
|
1509
|
+
f"Total tokens ({total_tokens}) exceeds model context limit ({context_limit})"
|
1510
|
+
)
|
1511
|
+
return ExitCode.DATA_ERROR
|
1512
|
+
|
1513
|
+
# Get API key
|
1514
|
+
api_key = args.api_key or os.getenv("OPENAI_API_KEY")
|
1515
|
+
if not api_key:
|
1516
|
+
logger.error(
|
1517
|
+
"No OpenAI API key provided (--api-key or OPENAI_API_KEY env var)"
|
1518
|
+
)
|
1519
|
+
return ExitCode.USAGE_ERROR
|
1520
|
+
|
1521
|
+
# Create OpenAI client
|
1522
|
+
client = AsyncOpenAI(api_key=api_key, timeout=args.timeout)
|
1523
|
+
|
1524
|
+
# Create log callback that matches expected signature
|
1525
|
+
def log_callback(
|
1526
|
+
level: int, message: str, extra: dict[str, Any]
|
1527
|
+
) -> None:
|
1528
|
+
# Only log if debug_openai_stream is enabled
|
1529
|
+
if args.debug_openai_stream:
|
1530
|
+
# Include extra dictionary in the message for both DEBUG and ERROR
|
1531
|
+
if extra: # Only add if there's actually extra data
|
1532
|
+
extra_str = json.dumps(extra, indent=2)
|
1533
|
+
message = f"{message}\nDetails:\n{extra_str}"
|
1534
|
+
openai_logger.log(level, message, extra=extra)
|
1535
|
+
|
1536
|
+
# Make API request
|
1537
|
+
try:
|
1538
|
+
logger.debug("Creating ProgressContext for API response handling")
|
1539
|
+
with ProgressContext(
|
1540
|
+
description="Processing API response",
|
1541
|
+
level=args.progress_level,
|
1542
|
+
) as progress:
|
1543
|
+
logger.debug("Starting API response stream processing")
|
1544
|
+
logger.debug("Debug flag status: %s", args.debug_openai_stream)
|
1545
|
+
logger.debug("OpenAI logger level: %s", openai_logger.level)
|
1546
|
+
for handler in openai_logger.handlers:
|
1547
|
+
logger.debug("Handler level: %s", handler.level)
|
1548
|
+
async for chunk in async_openai_structured_stream(
|
1549
|
+
client=client,
|
1550
|
+
model=args.model,
|
1551
|
+
temperature=args.temperature,
|
1552
|
+
max_tokens=args.max_tokens,
|
1553
|
+
top_p=args.top_p,
|
1554
|
+
frequency_penalty=args.frequency_penalty,
|
1555
|
+
presence_penalty=args.presence_penalty,
|
1556
|
+
system_prompt=args.system_prompt,
|
1557
|
+
user_prompt=rendered_task,
|
1558
|
+
output_schema=output_model,
|
1559
|
+
timeout=args.timeout,
|
1560
|
+
on_log=log_callback,
|
1561
|
+
):
|
1562
|
+
logger.debug("Received API response chunk")
|
1563
|
+
if not chunk:
|
1564
|
+
logger.debug("Empty chunk received, skipping")
|
1565
|
+
continue
|
1566
|
+
|
1567
|
+
# Write output
|
1568
|
+
try:
|
1569
|
+
logger.debug("Starting to process output chunk")
|
1570
|
+
dumped = chunk.model_dump(mode="json")
|
1571
|
+
logger.debug("Successfully dumped chunk to JSON")
|
1572
|
+
logger.debug("Dumped chunk: %s", dumped)
|
1573
|
+
logger.debug(
|
1574
|
+
"Chunk type: %s, length: %d",
|
1575
|
+
type(dumped),
|
1576
|
+
len(json.dumps(dumped)),
|
1577
|
+
)
|
1578
|
+
|
1579
|
+
if args.output_file:
|
1580
|
+
logger.debug(
|
1581
|
+
"Writing to output file: %s", args.output_file
|
1582
|
+
)
|
1583
|
+
try:
|
1584
|
+
with open(
|
1585
|
+
args.output_file, "a", encoding="utf-8"
|
1586
|
+
) as f:
|
1587
|
+
json_str = json.dumps(dumped, indent=2)
|
1588
|
+
logger.debug(
|
1589
|
+
"Writing JSON string of length %d",
|
1590
|
+
len(json_str),
|
1591
|
+
)
|
1592
|
+
f.write(json_str)
|
1593
|
+
f.write("\n")
|
1594
|
+
logger.debug("Successfully wrote to file")
|
1595
|
+
except Exception as e:
|
1596
|
+
logger.error(
|
1597
|
+
"Failed to write to output file: %s", e
|
1598
|
+
)
|
1599
|
+
else:
|
1600
|
+
logger.debug(
|
1601
|
+
"About to call progress.print_output with JSON string"
|
1602
|
+
)
|
1603
|
+
json_str = json.dumps(dumped, indent=2)
|
1604
|
+
logger.debug(
|
1605
|
+
"JSON string length before print_output: %d",
|
1606
|
+
len(json_str),
|
1607
|
+
)
|
1608
|
+
logger.debug(
|
1609
|
+
"First 100 chars of JSON string: %s",
|
1610
|
+
json_str[:100] if json_str else "",
|
1611
|
+
)
|
1612
|
+
progress.print_output(json_str)
|
1613
|
+
logger.debug(
|
1614
|
+
"Completed print_output call for JSON string"
|
1615
|
+
)
|
1616
|
+
|
1617
|
+
logger.debug("Starting progress update")
|
1618
|
+
progress.update()
|
1619
|
+
logger.debug("Completed progress update")
|
1620
|
+
except Exception as e:
|
1621
|
+
logger.error("Failed to process chunk: %s", e)
|
1622
|
+
logger.error("Chunk: %s", chunk)
|
1623
|
+
continue
|
1624
|
+
|
1625
|
+
logger.debug("Finished processing API response stream")
|
1626
|
+
|
1627
|
+
except StreamInterruptedError as e:
|
1628
|
+
logger.error(f"Stream interrupted: {e}")
|
1629
|
+
return ExitCode.API_ERROR
|
1630
|
+
except StreamBufferError as e:
|
1631
|
+
logger.error(f"Stream buffer error: {e}")
|
1632
|
+
return ExitCode.API_ERROR
|
1633
|
+
except StreamParseError as e:
|
1634
|
+
logger.error(f"Stream parse error: {e}")
|
1635
|
+
return ExitCode.API_ERROR
|
1636
|
+
except APIResponseError as e:
|
1637
|
+
logger.error(f"API response error: {e}")
|
1638
|
+
return ExitCode.API_ERROR
|
1639
|
+
except EmptyResponseError as e:
|
1640
|
+
logger.error(f"Empty response error: {e}")
|
1641
|
+
return ExitCode.API_ERROR
|
1642
|
+
except InvalidResponseFormatError as e:
|
1643
|
+
logger.error(f"Invalid response format: {e}")
|
1644
|
+
return ExitCode.API_ERROR
|
1645
|
+
except (APIConnectionError, InternalServerError) as e:
|
1646
|
+
logger.error(f"API connection error: {e}")
|
1647
|
+
return ExitCode.API_ERROR
|
1648
|
+
except RateLimitError as e:
|
1649
|
+
logger.error(f"Rate limit exceeded: {e}")
|
1650
|
+
return ExitCode.API_ERROR
|
1651
|
+
except BadRequestError as e:
|
1652
|
+
logger.error(f"Bad request: {e}")
|
1653
|
+
return ExitCode.API_ERROR
|
1654
|
+
except AuthenticationError as e:
|
1655
|
+
logger.error(f"Authentication failed: {e}")
|
1656
|
+
return ExitCode.API_ERROR
|
1657
|
+
except OpenAIClientError as e:
|
1658
|
+
logger.error(f"OpenAI client error: {e}")
|
1659
|
+
return ExitCode.API_ERROR
|
1660
|
+
except Exception as e:
|
1661
|
+
logger.error(f"Unexpected error: {e}")
|
1662
|
+
return ExitCode.INTERNAL_ERROR
|
1663
|
+
|
1664
|
+
return ExitCode.SUCCESS
|
1665
|
+
|
1666
|
+
except KeyboardInterrupt:
|
1667
|
+
logger.error("Operation cancelled by user")
|
1668
|
+
return ExitCode.INTERRUPTED
|
1669
|
+
except PathSecurityError as e:
|
1670
|
+
# Only log security errors if they haven't been logged already
|
1671
|
+
logger.debug(
|
1672
|
+
"[_main] Caught PathSecurityError: %s (logged=%s)",
|
1673
|
+
str(e),
|
1674
|
+
getattr(e, "has_been_logged", False),
|
1675
|
+
)
|
1676
|
+
if not getattr(e, "has_been_logged", False):
|
1677
|
+
logger.error(str(e))
|
1678
|
+
return ExitCode.SECURITY_ERROR
|
1679
|
+
except ValueError as e:
|
1680
|
+
# Get the original cause of the error
|
1681
|
+
cause = e.__cause__ or e.__context__
|
1682
|
+
if isinstance(cause, PathSecurityError):
|
1683
|
+
logger.debug(
|
1684
|
+
"[_main] Caught wrapped PathSecurityError in ValueError: %s (logged=%s)",
|
1685
|
+
str(cause),
|
1686
|
+
getattr(cause, "has_been_logged", False),
|
1687
|
+
)
|
1688
|
+
# Only log security errors if they haven't been logged already
|
1689
|
+
if not getattr(cause, "has_been_logged", False):
|
1690
|
+
logger.error(str(cause))
|
1691
|
+
return ExitCode.SECURITY_ERROR
|
1692
|
+
else:
|
1693
|
+
logger.debug("[_main] Caught ValueError: %s", str(e))
|
1694
|
+
logger.error(f"Invalid input: {e}")
|
1695
|
+
return ExitCode.DATA_ERROR
|
1696
|
+
except Exception as e:
|
1697
|
+
# Check if this is a wrapped security error
|
1698
|
+
if isinstance(e.__cause__, PathSecurityError):
|
1699
|
+
logger.debug(
|
1700
|
+
"[_main] Caught wrapped PathSecurityError in Exception: %s (logged=%s)",
|
1701
|
+
str(e.__cause__),
|
1702
|
+
getattr(e.__cause__, "has_been_logged", False),
|
1703
|
+
)
|
1704
|
+
# Only log security errors if they haven't been logged already
|
1705
|
+
if not getattr(e.__cause__, "has_been_logged", False):
|
1706
|
+
logger.error(str(e.__cause__))
|
1707
|
+
return ExitCode.SECURITY_ERROR
|
1708
|
+
logger.debug("[_main] Caught unexpected error: %s", str(e))
|
1709
|
+
logger.error(f"Unexpected error: {e}")
|
1710
|
+
return ExitCode.INTERNAL_ERROR
|
1711
|
+
|
1712
|
+
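When --output-file is set, the writer above appends each streamed chunk as pretty-printed JSON followed by a newline, so the file holds concatenated multi-line JSON documents rather than JSON Lines. A minimal sketch of reading such a file back (a hypothetical helper, not part of this package):

import json

def iter_results(path: str):
    # raw_decode walks concatenated JSON documents that a plain
    # json.loads would reject
    decoder = json.JSONDecoder()
    with open(path, encoding="utf-8") as f:
        text = f.read()
    pos = 0
    while pos < len(text):
        obj, pos = decoder.raw_decode(text, pos)
        yield obj
        # skip the newline (and any other whitespace) between documents
        while pos < len(text) and text[pos].isspace():
            pos += 1
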
def main() -> None:
    """CLI entry point that handles all errors."""
    try:
        logger.debug("[main] Starting main execution")
        exit_code = asyncio.run(_main())
        sys.exit(exit_code.value)
    except KeyboardInterrupt:
        logger.error("Operation cancelled by user")
        sys.exit(ExitCode.INTERRUPTED.value)
    except PathSecurityError as e:
        # Only log security errors if they haven't been logged already
        logger.debug(
            "[main] Caught PathSecurityError: %s (logged=%s)",
            str(e),
            getattr(e, "has_been_logged", False),
        )
        if not getattr(e, "has_been_logged", False):
            logger.error(str(e))
        sys.exit(ExitCode.SECURITY_ERROR.value)
    except ValueError as e:
        # Get the original cause of the error
        cause = e.__cause__ or e.__context__
        if isinstance(cause, PathSecurityError):
            logger.debug(
                "[main] Caught wrapped PathSecurityError in ValueError: %s (logged=%s)",
                str(cause),
                getattr(cause, "has_been_logged", False),
            )
            # Only log security errors if they haven't been logged already
            if not getattr(cause, "has_been_logged", False):
                logger.error(str(cause))
            sys.exit(ExitCode.SECURITY_ERROR.value)
        else:
            logger.debug("[main] Caught ValueError: %s", str(e))
            logger.error(f"Invalid input: {e}")
            sys.exit(ExitCode.DATA_ERROR.value)
    except Exception as e:
        # Check if this is a wrapped security error
        if isinstance(e.__cause__, PathSecurityError):
            logger.debug(
                "[main] Caught wrapped PathSecurityError in Exception: %s (logged=%s)",
                str(e.__cause__),
                getattr(e.__cause__, "has_been_logged", False),
            )
            # Only log security errors if they haven't been logged already
            if not getattr(e.__cause__, "has_been_logged", False):
                logger.error(str(e.__cause__))
            sys.exit(ExitCode.SECURITY_ERROR.value)
        logger.debug("[main] Caught unexpected error: %s", str(e))
        logger.error(f"Unexpected error: {e}")
        sys.exit(ExitCode.INTERNAL_ERROR.value)

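Because main() always terminates through sys.exit() with the numeric value of an ExitCode member, a caller embedding the CLI can recover the status from SystemExit. A minimal sketch, using the import path implied by this wheel's layout:

from ostruct.cli.cli import main

try:
    main()
except SystemExit as exc:
    # exc.code carries the ExitCode value that main() passed to sys.exit()
    print(f"ostruct exited with status {exc.code}")
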
# Export public API
__all__ = [
    "ExitCode",
    "estimate_tokens_for_chat",
    "get_context_window_limit",
    "get_default_token_limit",
    "parse_json_var",
    "create_dynamic_model",
    "validate_path_mapping",
    "create_argument_parser",
    "main",
]

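The handlers in _main and main above unwrap security errors through __cause__, which Python sets when an exception is raised with "raise ... from ...". A self-contained illustration of that mechanism (the local PathSecurityError stand-in is hypothetical; the real class is imported at the top of this module):

class PathSecurityError(Exception):  # local stand-in for the imported class
    pass

def load():
    try:
        raise PathSecurityError("path escapes base directory")
    except PathSecurityError as inner:
        raise ValueError("invalid input") from inner

try:
    load()
except ValueError as e:
    # the original security error rides along on __cause__
    assert isinstance(e.__cause__, PathSecurityError)
    print(e.__cause__)
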
def create_dynamic_model(
    schema: Dict[str, Any],
    base_name: str = "DynamicModel",
    show_schema: bool = False,
    debug_validation: bool = False,
) -> Type[BaseModel]:
    """Create a Pydantic model from a JSON schema.

    Args:
        schema: JSON schema dict, can be wrapped in {"schema": ...} format
        base_name: Base name for the model
        show_schema: Whether to show the generated schema
        debug_validation: Whether to enable validation debugging

    Returns:
        Generated Pydantic model class

    Raises:
        ModelCreationError: When model creation fails
        SchemaValidationError: When schema is invalid
    """
    if debug_validation:
        logger.info("Creating dynamic model from schema:")
        logger.info(json.dumps(schema, indent=2))

    try:
        # Handle our wrapper format if present
        if "schema" in schema:
            if debug_validation:
                logger.info("Found schema wrapper, extracting inner schema")
                logger.info(
                    "Original schema: %s", json.dumps(schema, indent=2)
                )
            inner_schema = schema["schema"]
            if not isinstance(inner_schema, dict):
                if debug_validation:
                    logger.info(
                        "Inner schema must be a dictionary, got %s",
                        type(inner_schema),
                    )
                raise SchemaValidationError(
                    "Inner schema must be a dictionary"
                )
            if debug_validation:
                logger.info("Using inner schema:")
                logger.info(json.dumps(inner_schema, indent=2))
            schema = inner_schema

        # Extract required fields after unwrapping; reading them from the
        # wrapper would silently drop the inner schema's "required" list
        required: Set[str] = set(schema.get("required", []))

        # Ensure schema has type field
        if "type" not in schema:
            if debug_validation:
                logger.info("Schema missing type field, assuming object type")
            schema["type"] = "object"

        # Validate root schema is object type
        if schema["type"] != "object":
            if debug_validation:
                logger.error(
                    "Schema type must be 'object', got %s", schema["type"]
                )
            raise SchemaValidationError("Root schema must be of type 'object'")

        # Create model configuration
        config = ConfigDict(
            title=schema.get("title", base_name),
            extra=(
                "forbid"
                if schema.get("additionalProperties") is False
                else "allow"
            ),
            validate_default=True,
            use_enum_values=True,
            arbitrary_types_allowed=True,
            json_schema_extra={
                k: v
                for k, v in schema.items()
                if k
                not in {
                    "type",
                    "properties",
                    "required",
                    "title",
                    "description",
                    "additionalProperties",
                    "readOnly",
                }
            },
        )

        if debug_validation:
            logger.info("Created model configuration:")
            logger.info("  Title: %s", config.get("title"))
            logger.info("  Extra: %s", config.get("extra"))
            logger.info(
                "  Validate Default: %s", config.get("validate_default")
            )
            logger.info("  Use Enum Values: %s", config.get("use_enum_values"))
            logger.info(
                "  Arbitrary Types: %s", config.get("arbitrary_types_allowed")
            )
            logger.info(
                "  JSON Schema Extra: %s", config.get("json_schema_extra")
            )

        # Create field definitions
        field_definitions: Dict[str, FieldDefinition] = {}
        properties = schema.get("properties", {})

        for field_name, field_schema in properties.items():
            try:
                if debug_validation:
                    logger.info("Processing field %s:", field_name)
                    logger.info(
                        "  Schema: %s", json.dumps(field_schema, indent=2)
                    )

                python_type, field = _get_type_with_constraints(
                    field_schema, field_name, base_name
                )

                # Handle optional fields
                if field_name not in required:
                    if debug_validation:
                        logger.info(
                            "Field %s is optional, wrapping in Optional",
                            field_name,
                        )
                    field_type = cast(Type[Any], Optional[python_type])
                else:
                    field_type = python_type
                    if debug_validation:
                        logger.info("Field %s is required", field_name)

                # Create field definition
                field_definitions[field_name] = (field_type, field)

                if debug_validation:
                    logger.info("Successfully created field definition:")
                    logger.info("  Name: %s", field_name)
                    logger.info("  Type: %s", str(field_type))
                    logger.info("  Required: %s", field_name in required)

            except (FieldDefinitionError, NestedModelError) as e:
                if debug_validation:
                    logger.error("Error creating field %s:", field_name)
                    logger.error("  Error type: %s", type(e).__name__)
                    logger.error("  Error message: %s", str(e))
                raise ModelValidationError(base_name, [str(e)])

        # Create the model with the fields
        model = create_model(
            base_name,
            __config__=config,
            **{
                name: (
                    (
                        cast(Type[Any], field_type)
                        if is_container_type(field_type)
                        else field_type
                    ),
                    field,
                )
                for name, (field_type, field) in field_definitions.items()
            },
        )

        if debug_validation:
            logger.info("Successfully created model: %s", model.__name__)
            logger.info("Model config: %s", dict(model.model_config))
            logger.info(
                "Model schema: %s",
                json.dumps(model.model_json_schema(), indent=2),
            )

        # Validate the model's JSON schema
        try:
            model.model_json_schema()
        except ValidationError as e:
            if debug_validation:
                logger.error("Schema validation failed:")
                logger.error("  Error type: %s", type(e).__name__)
                logger.error("  Error message: %s", str(e))
                if hasattr(e, "errors"):
                    logger.error("  Validation errors:")
                    for error in e.errors():
                        logger.error("    - %s", error)
            validation_errors = (
                [str(err) for err in e.errors()]
                if hasattr(e, "errors")
                else [str(e)]
            )
            raise ModelValidationError(base_name, validation_errors)

        return cast(Type[BaseModel], model)

    except Exception as e:
        if debug_validation:
            logger.error("Failed to create model:")
            logger.error("  Error type: %s", type(e).__name__)
            logger.error("  Error message: %s", str(e))
            if hasattr(e, "__cause__"):
                logger.error("  Caused by: %s", str(e.__cause__))
            if hasattr(e, "__context__"):
                logger.error("  Context: %s", str(e.__context__))
            if hasattr(e, "__traceback__"):
                import traceback

                logger.error(
                    "  Traceback:\n%s",
                    "".join(traceback.format_tb(e.__traceback__)),
                )
        raise ModelCreationError(
            f"Failed to create model '{base_name}': {str(e)}"
        )

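A minimal round trip through create_dynamic_model, assuming the scalar handling in _get_type_with_constraints (defined earlier in the module, outside this hunk) maps "string"/"integer" to str/int:

schema = {
    "type": "object",
    "properties": {
        "name": {"type": "string"},
        "age": {"type": "integer"},
    },
    "required": ["name"],
}

Model = create_dynamic_model(schema, base_name="PersonModel")
person = Model.model_validate({"name": "Ada", "age": 36})
print(person.model_dump())  # expected: {'name': 'Ada', 'age': 36}
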
# Validation functions
def pattern(regex: str) -> Any:
    return constr(pattern=regex)


def min_length(length: int) -> Any:
    return BeforeValidator(lambda v: v if len(str(v)) >= length else None)


def max_length(length: int) -> Any:
    return BeforeValidator(lambda v: v if len(str(v)) <= length else None)


def ge(value: Union[int, float]) -> Any:
    return BeforeValidator(lambda v: v if float(v) >= value else None)


def le(value: Union[int, float]) -> Any:
    return BeforeValidator(lambda v: v if float(v) <= value else None)


def gt(value: Union[int, float]) -> Any:
    return BeforeValidator(lambda v: v if float(v) > value else None)


def lt(value: Union[int, float]) -> Any:
    return BeforeValidator(lambda v: v if float(v) < value else None)


def multiple_of(value: Union[int, float]) -> Any:
    return BeforeValidator(lambda v: v if float(v) % value == 0 else None)

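These helpers are shaped for use as typing.Annotated metadata. Note the design choice: an out-of-range value is mapped to None rather than rejected inside the validator, so the failure only surfaces as a type error when the field is not Optional. A sketch using the ge() helper above:

from typing import Annotated

from pydantic import BaseModel, ValidationError

class Price(BaseModel):
    amount: Annotated[float, ge(0)]

Price(amount=12.5)  # in range: value passes through unchanged
try:
    Price(amount=-1)
except ValidationError:
    # ge(0) turned -1 into None, which then fails float validation
    print("rejected")
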
if __name__ == "__main__":
    main()