ostruct-cli 0.4.0__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ostruct/cli/base_errors.py +183 -0
- ostruct/cli/cli.py +879 -592
- ostruct/cli/click_options.py +320 -202
- ostruct/cli/errors.py +273 -134
- ostruct/cli/exit_codes.py +18 -0
- ostruct/cli/file_info.py +30 -14
- ostruct/cli/file_list.py +4 -10
- ostruct/cli/file_utils.py +43 -35
- ostruct/cli/path_utils.py +32 -4
- ostruct/cli/schema_validation.py +213 -0
- ostruct/cli/security/allowed_checker.py +8 -0
- ostruct/cli/security/base.py +46 -0
- ostruct/cli/security/errors.py +83 -103
- ostruct/cli/security/security_manager.py +22 -9
- ostruct/cli/serialization.py +25 -0
- ostruct/cli/template_filters.py +5 -3
- ostruct/cli/template_rendering.py +46 -22
- ostruct/cli/template_utils.py +12 -4
- ostruct/cli/template_validation.py +26 -8
- ostruct/cli/token_utils.py +43 -0
- ostruct/cli/validators.py +109 -0
- ostruct_cli-0.6.0.dist-info/METADATA +404 -0
- ostruct_cli-0.6.0.dist-info/RECORD +43 -0
- {ostruct_cli-0.4.0.dist-info → ostruct_cli-0.6.0.dist-info}/WHEEL +1 -1
- ostruct_cli-0.4.0.dist-info/METADATA +0 -186
- ostruct_cli-0.4.0.dist-info/RECORD +0 -36
- {ostruct_cli-0.4.0.dist-info → ostruct_cli-0.6.0.dist-info}/LICENSE +0 -0
- {ostruct_cli-0.4.0.dist-info → ostruct_cli-0.6.0.dist-info}/entry_points.txt +0 -0
ostruct/cli/cli.py
CHANGED
@@ -5,10 +5,10 @@ import json
|
|
5
5
|
import logging
|
6
6
|
import os
|
7
7
|
import sys
|
8
|
-
from dataclasses import dataclass
|
9
8
|
from enum import Enum, IntEnum
|
10
9
|
from typing import (
|
11
10
|
Any,
|
11
|
+
AsyncGenerator,
|
12
12
|
Dict,
|
13
13
|
List,
|
14
14
|
Literal,
|
@@ -16,6 +16,7 @@ from typing import (
|
|
16
16
|
Set,
|
17
17
|
Tuple,
|
18
18
|
Type,
|
19
|
+
TypedDict,
|
19
20
|
TypeVar,
|
20
21
|
Union,
|
21
22
|
cast,
|
@@ -31,20 +32,10 @@ from pathlib import Path
|
|
31
32
|
|
32
33
|
import click
|
33
34
|
import jinja2
|
34
|
-
import tiktoken
|
35
35
|
import yaml
|
36
|
-
from openai import
|
37
|
-
APIConnectionError,
|
38
|
-
AsyncOpenAI,
|
39
|
-
AuthenticationError,
|
40
|
-
BadRequestError,
|
41
|
-
InternalServerError,
|
42
|
-
RateLimitError,
|
43
|
-
)
|
36
|
+
from openai import AsyncOpenAI
|
44
37
|
from openai_structured.client import (
|
45
38
|
async_openai_structured_stream,
|
46
|
-
get_context_window_limit,
|
47
|
-
get_default_token_limit,
|
48
39
|
supports_structured_output,
|
49
40
|
)
|
50
41
|
from openai_structured.errors import (
|
@@ -54,12 +45,9 @@ from openai_structured.errors import (
|
|
54
45
|
ModelNotSupportedError,
|
55
46
|
ModelVersionError,
|
56
47
|
OpenAIClientError,
|
57
|
-
SchemaFileError,
|
58
|
-
SchemaValidationError,
|
59
48
|
StreamBufferError,
|
60
|
-
StreamInterruptedError,
|
61
|
-
StreamParseError,
|
62
49
|
)
|
50
|
+
from openai_structured.model_registry import ModelRegistry
|
63
51
|
from pydantic import (
|
64
52
|
AnyUrl,
|
65
53
|
BaseModel,
|
@@ -74,61 +62,63 @@ from pydantic.functional_validators import BeforeValidator
|
|
74
62
|
from pydantic.types import constr
|
75
63
|
from typing_extensions import TypeAlias
|
76
64
|
|
77
|
-
from ostruct.cli.click_options import
|
65
|
+
from ostruct.cli.click_options import all_options
|
66
|
+
from ostruct.cli.exit_codes import ExitCode
|
78
67
|
|
79
68
|
from .. import __version__ # noqa: F401 - Used in package metadata
|
80
69
|
from .errors import (
|
81
70
|
CLIError,
|
82
71
|
DirectoryNotFoundError,
|
83
72
|
FieldDefinitionError,
|
84
|
-
FileNotFoundError,
|
85
73
|
InvalidJSONError,
|
86
74
|
ModelCreationError,
|
87
75
|
ModelValidationError,
|
88
76
|
NestedModelError,
|
77
|
+
OstructFileNotFoundError,
|
89
78
|
PathSecurityError,
|
79
|
+
SchemaFileError,
|
80
|
+
SchemaValidationError,
|
81
|
+
StreamInterruptedError,
|
82
|
+
StreamParseError,
|
90
83
|
TaskTemplateSyntaxError,
|
91
84
|
TaskTemplateVariableError,
|
92
|
-
VariableError,
|
93
85
|
VariableNameError,
|
94
86
|
VariableValueError,
|
95
87
|
)
|
96
|
-
from .file_utils import FileInfoList,
|
88
|
+
from .file_utils import FileInfoList, collect_files
|
97
89
|
from .path_utils import validate_path_mapping
|
98
90
|
from .security import SecurityManager
|
91
|
+
from .serialization import LogSerializer
|
99
92
|
from .template_env import create_jinja_env
|
100
93
|
from .template_utils import SystemPromptError, render_template
|
94
|
+
from .token_utils import estimate_tokens_with_encoding
|
101
95
|
|
102
96
|
# Constants
|
103
97
|
DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant."
|
104
98
|
|
105
99
|
|
106
|
-
|
107
|
-
|
108
|
-
"""Compatibility class to mimic argparse.Namespace for existing code."""
|
100
|
+
class CLIParams(TypedDict, total=False):
|
101
|
+
"""Type-safe CLI parameters."""
|
109
102
|
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
103
|
+
files: List[
|
104
|
+
Tuple[str, str]
|
105
|
+
] # List of (name, path) tuples from Click's nargs=2
|
106
|
+
dir: List[
|
107
|
+
Tuple[str, str]
|
108
|
+
] # List of (name, dir) tuples from Click's nargs=2
|
109
|
+
patterns: List[
|
110
|
+
Tuple[str, str]
|
111
|
+
] # List of (name, pattern) tuples from Click's nargs=2
|
112
|
+
allowed_dirs: List[str]
|
116
113
|
base_dir: str
|
117
114
|
allowed_dir_file: Optional[str]
|
118
|
-
|
119
|
-
dir_ext: Optional[str]
|
115
|
+
recursive: bool
|
120
116
|
var: List[str]
|
121
117
|
json_var: List[str]
|
122
118
|
system_prompt: Optional[str]
|
123
119
|
system_prompt_file: Optional[str]
|
124
120
|
ignore_task_sysprompt: bool
|
125
|
-
schema_file: str
|
126
121
|
model: str
|
127
|
-
temperature: float
|
128
|
-
max_tokens: Optional[int]
|
129
|
-
top_p: float
|
130
|
-
frequency_penalty: float
|
131
|
-
presence_penalty: float
|
132
122
|
timeout: float
|
133
123
|
output_file: Optional[str]
|
134
124
|
dry_run: bool
|
@@ -138,7 +128,16 @@ class Namespace:
|
|
138
128
|
debug_openai_stream: bool
|
139
129
|
show_model_schema: bool
|
140
130
|
debug_validation: bool
|
141
|
-
|
131
|
+
temperature: Optional[float]
|
132
|
+
max_output_tokens: Optional[int]
|
133
|
+
top_p: Optional[float]
|
134
|
+
frequency_penalty: Optional[float]
|
135
|
+
presence_penalty: Optional[float]
|
136
|
+
reasoning_effort: Optional[str]
|
137
|
+
progress_level: str
|
138
|
+
task_file: Optional[str]
|
139
|
+
task: Optional[str]
|
140
|
+
schema_file: str
|
142
141
|
|
143
142
|
|
144
143
|
# Set up logging
|
@@ -176,45 +175,6 @@ ostruct_file_handler.setFormatter(
|
|
176
175
|
logger.addHandler(ostruct_file_handler)
|
177
176
|
|
178
177
|
|
179
|
-
class ExitCode(IntEnum):
|
180
|
-
"""Exit codes for the CLI following standard Unix conventions.
|
181
|
-
|
182
|
-
Categories:
|
183
|
-
- Success (0-1)
|
184
|
-
- User Interruption (2-3)
|
185
|
-
- Input/Validation (64-69)
|
186
|
-
- I/O and File Access (70-79)
|
187
|
-
- API and External Services (80-89)
|
188
|
-
- Internal Errors (90-99)
|
189
|
-
"""
|
190
|
-
|
191
|
-
# Success codes
|
192
|
-
SUCCESS = 0
|
193
|
-
|
194
|
-
# User interruption
|
195
|
-
INTERRUPTED = 2
|
196
|
-
|
197
|
-
# Input/Validation errors (64-69)
|
198
|
-
USAGE_ERROR = 64
|
199
|
-
DATA_ERROR = 65
|
200
|
-
SCHEMA_ERROR = 66
|
201
|
-
VALIDATION_ERROR = 67
|
202
|
-
|
203
|
-
# I/O and File Access errors (70-79)
|
204
|
-
IO_ERROR = 70
|
205
|
-
FILE_NOT_FOUND = 71
|
206
|
-
PERMISSION_ERROR = 72
|
207
|
-
SECURITY_ERROR = 73
|
208
|
-
|
209
|
-
# API and External Service errors (80-89)
|
210
|
-
API_ERROR = 80
|
211
|
-
API_TIMEOUT = 81
|
212
|
-
|
213
|
-
# Internal errors (90-99)
|
214
|
-
INTERNAL_ERROR = 90
|
215
|
-
UNKNOWN_ERROR = 91
|
216
|
-
|
217
|
-
|
218
178
|
# Type aliases
|
219
179
|
FieldType = (
|
220
180
|
Any # Changed from Type[Any] to allow both concrete types and generics
|
@@ -281,7 +241,7 @@ def _get_type_with_constraints(
|
|
281
241
|
show_schema=False,
|
282
242
|
debug_validation=False,
|
283
243
|
)
|
284
|
-
array_type: Type[List[Any]] = List[array_item_model] # type: ignore
|
244
|
+
array_type: Type[List[Any]] = List[array_item_model] # type: ignore
|
285
245
|
return (array_type, Field(**field_kwargs))
|
286
246
|
|
287
247
|
# For non-object items, use the type directly
|
@@ -403,64 +363,17 @@ K = TypeVar("K")
|
|
403
363
|
V = TypeVar("V")
|
404
364
|
|
405
365
|
|
406
|
-
def estimate_tokens_for_chat(
|
407
|
-
messages: List[Dict[str, str]],
|
408
|
-
model: str,
|
409
|
-
encoder: Any = None,
|
410
|
-
) -> int:
|
411
|
-
"""Estimate the number of tokens in a chat completion.
|
412
|
-
|
413
|
-
Args:
|
414
|
-
messages: List of chat messages
|
415
|
-
model: Model name
|
416
|
-
encoder: Optional tiktoken encoder for testing. If provided, only uses encoder.encode() results.
|
417
|
-
"""
|
418
|
-
if encoder is None:
|
419
|
-
try:
|
420
|
-
# Try to get the encoding for the specific model
|
421
|
-
encoder = tiktoken.get_encoding("o200k_base")
|
422
|
-
except KeyError:
|
423
|
-
# Fall back to cl100k_base for unknown models
|
424
|
-
encoder = tiktoken.get_encoding("cl100k_base")
|
425
|
-
|
426
|
-
# Use standard token counting logic for real tiktoken encoders
|
427
|
-
num_tokens = 0
|
428
|
-
for message in messages:
|
429
|
-
# Add message overhead
|
430
|
-
num_tokens += 4 # every message follows <im_start>{role/name}\n{content}<im_end>\n
|
431
|
-
for key, value in message.items():
|
432
|
-
num_tokens += len(encoder.encode(str(value)))
|
433
|
-
if key == "name": # if there's a name, the role is omitted
|
434
|
-
num_tokens -= 1 # role is omitted
|
435
|
-
num_tokens += 2 # every reply is primed with <im_start>assistant
|
436
|
-
return num_tokens
|
437
|
-
else:
|
438
|
-
# For mock encoders in tests, just return the length of encoded content
|
439
|
-
num_tokens = 0
|
440
|
-
for message in messages:
|
441
|
-
for value in message.values():
|
442
|
-
num_tokens += len(encoder.encode(str(value)))
|
443
|
-
return num_tokens
|
444
|
-
|
445
|
-
|
446
366
|
def validate_token_limits(
|
447
367
|
model: str, total_tokens: int, max_token_limit: Optional[int] = None
|
448
368
|
) -> None:
|
449
|
-
"""Validate token counts against model limits.
|
450
|
-
|
451
|
-
|
452
|
-
|
453
|
-
total_tokens: Total number of tokens in the prompt
|
454
|
-
max_token_limit: Optional user-specified token limit
|
455
|
-
|
456
|
-
Raises:
|
457
|
-
ValueError: If token limits are exceeded
|
458
|
-
"""
|
459
|
-
context_limit = get_context_window_limit(model)
|
369
|
+
"""Validate token counts against model limits."""
|
370
|
+
registry = ModelRegistry()
|
371
|
+
capabilities = registry.get_capabilities(model)
|
372
|
+
context_limit = capabilities.context_window
|
460
373
|
output_limit = (
|
461
374
|
max_token_limit
|
462
375
|
if max_token_limit is not None
|
463
|
-
else
|
376
|
+
else capabilities.max_output_tokens
|
464
377
|
)
|
465
378
|
|
466
379
|
# Check if total tokens exceed context window
|
@@ -522,8 +435,12 @@ def process_system_prompt(
|
|
522
435
|
)
|
523
436
|
with open(path, "r", encoding="utf-8") as f:
|
524
437
|
system_prompt = f.read().strip()
|
525
|
-
except
|
526
|
-
raise SystemPromptError(
|
438
|
+
except OstructFileNotFoundError as e:
|
439
|
+
raise SystemPromptError(
|
440
|
+
f"Failed to load system prompt file: {e}"
|
441
|
+
) from e
|
442
|
+
except PathSecurityError as e:
|
443
|
+
raise SystemPromptError(f"Invalid system prompt file: {e}") from e
|
527
444
|
|
528
445
|
if system_prompt is not None:
|
529
446
|
# Render system prompt with template context
|
@@ -591,7 +508,8 @@ def validate_variable_mapping(
|
|
591
508
|
value = json.loads(value)
|
592
509
|
except json.JSONDecodeError as e:
|
593
510
|
raise InvalidJSONError(
|
594
|
-
f"Invalid JSON value for variable {name!r}: {value!r}"
|
511
|
+
f"Invalid JSON value for variable {name!r}: {value!r}",
|
512
|
+
context={"variable_name": name},
|
595
513
|
) from e
|
596
514
|
|
597
515
|
return name, value
|
@@ -787,11 +705,20 @@ def validate_task_template(
|
|
787
705
|
template_content: str
|
788
706
|
if task_file is not None:
|
789
707
|
try:
|
790
|
-
|
791
|
-
with open(path, "r", encoding="utf-8") as f:
|
708
|
+
with open(task_file, "r", encoding="utf-8") as f:
|
792
709
|
template_content = f.read()
|
793
|
-
except
|
794
|
-
raise TaskTemplateVariableError(
|
710
|
+
except FileNotFoundError:
|
711
|
+
raise TaskTemplateVariableError(
|
712
|
+
f"Task template file not found: {task_file}"
|
713
|
+
)
|
714
|
+
except PermissionError:
|
715
|
+
raise TaskTemplateVariableError(
|
716
|
+
f"Permission denied reading task template file: {task_file}"
|
717
|
+
)
|
718
|
+
except Exception as e:
|
719
|
+
raise TaskTemplateVariableError(
|
720
|
+
f"Error reading task template file: {e}"
|
721
|
+
)
|
795
722
|
else:
|
796
723
|
template_content = task # type: ignore # We know task is str here due to the checks above
|
797
724
|
|
@@ -809,10 +736,10 @@ def validate_schema_file(
|
|
809
736
|
path: str,
|
810
737
|
verbose: bool = False,
|
811
738
|
) -> Dict[str, Any]:
|
812
|
-
"""Validate a JSON schema file.
|
739
|
+
"""Validate and load a JSON schema file.
|
813
740
|
|
814
741
|
Args:
|
815
|
-
path: Path to
|
742
|
+
path: Path to schema file
|
816
743
|
verbose: Whether to enable verbose logging
|
817
744
|
|
818
745
|
Returns:
|
@@ -827,14 +754,42 @@ def validate_schema_file(
|
|
827
754
|
logger.info("Validating schema file: %s", path)
|
828
755
|
|
829
756
|
try:
|
830
|
-
|
831
|
-
|
757
|
+
logger.debug("Opening schema file: %s", path)
|
758
|
+
with open(path, "r", encoding="utf-8") as f:
|
759
|
+
logger.debug("Loading JSON from schema file")
|
760
|
+
try:
|
761
|
+
schema = json.load(f)
|
762
|
+
logger.debug(
|
763
|
+
"Successfully loaded JSON: %s",
|
764
|
+
json.dumps(schema, indent=2),
|
765
|
+
)
|
766
|
+
except json.JSONDecodeError as e:
|
767
|
+
logger.error("JSON decode error in %s: %s", path, str(e))
|
768
|
+
logger.debug(
|
769
|
+
"Error details - line: %d, col: %d, msg: %s",
|
770
|
+
e.lineno,
|
771
|
+
e.colno,
|
772
|
+
e.msg,
|
773
|
+
)
|
774
|
+
raise InvalidJSONError(
|
775
|
+
f"Invalid JSON in schema file {path}: {e}",
|
776
|
+
context={"schema_path": path},
|
777
|
+
) from e
|
832
778
|
except FileNotFoundError:
|
833
|
-
|
834
|
-
|
835
|
-
raise
|
779
|
+
msg = f"Schema file not found: {path}"
|
780
|
+
logger.error(msg)
|
781
|
+
raise SchemaFileError(msg, schema_path=path)
|
782
|
+
except PermissionError:
|
783
|
+
msg = f"Permission denied reading schema file: {path}"
|
784
|
+
logger.error(msg)
|
785
|
+
raise SchemaFileError(msg, schema_path=path)
|
836
786
|
except Exception as e:
|
837
|
-
|
787
|
+
if isinstance(e, InvalidJSONError):
|
788
|
+
raise
|
789
|
+
msg = f"Failed to read schema file {path}: {e}"
|
790
|
+
logger.error(msg)
|
791
|
+
logger.debug("Unexpected error details: %s", str(e))
|
792
|
+
raise SchemaFileError(msg, schema_path=path) from e
|
838
793
|
|
839
794
|
# Pre-validation structure checks
|
840
795
|
if verbose:
|
@@ -842,11 +797,9 @@ def validate_schema_file(
|
|
842
797
|
logger.debug("Loaded schema: %s", json.dumps(schema, indent=2))
|
843
798
|
|
844
799
|
if not isinstance(schema, dict):
|
845
|
-
|
846
|
-
|
847
|
-
|
848
|
-
)
|
849
|
-
raise SchemaValidationError("Schema must be a JSON object")
|
800
|
+
msg = f"Schema in {path} must be a JSON object"
|
801
|
+
logger.error(msg)
|
802
|
+
raise SchemaValidationError(msg, context={"path": path})
|
850
803
|
|
851
804
|
# Validate schema structure
|
852
805
|
if "schema" in schema:
|
@@ -854,30 +807,37 @@ def validate_schema_file(
|
|
854
807
|
logger.debug("Found schema wrapper, validating inner schema")
|
855
808
|
inner_schema = schema["schema"]
|
856
809
|
if not isinstance(inner_schema, dict):
|
857
|
-
|
858
|
-
|
859
|
-
|
860
|
-
type(inner_schema).__name__,
|
861
|
-
)
|
862
|
-
raise SchemaValidationError("Inner schema must be a JSON object")
|
810
|
+
msg = f"Inner schema in {path} must be a JSON object"
|
811
|
+
logger.error(msg)
|
812
|
+
raise SchemaValidationError(msg, context={"path": path})
|
863
813
|
if verbose:
|
864
814
|
logger.debug("Inner schema validated successfully")
|
815
|
+
logger.debug(
|
816
|
+
"Inner schema: %s", json.dumps(inner_schema, indent=2)
|
817
|
+
)
|
865
818
|
else:
|
866
819
|
if verbose:
|
867
820
|
logger.debug("No schema wrapper found, using schema as-is")
|
821
|
+
logger.debug("Schema: %s", json.dumps(schema, indent=2))
|
822
|
+
|
823
|
+
# Additional schema validation
|
824
|
+
if "type" not in schema.get("schema", schema):
|
825
|
+
msg = f"Schema in {path} must specify a type"
|
826
|
+
logger.error(msg)
|
827
|
+
raise SchemaValidationError(msg, context={"path": path})
|
868
828
|
|
869
829
|
# Return the full schema including wrapper
|
870
830
|
return schema
|
871
831
|
|
872
832
|
|
873
833
|
def collect_template_files(
|
874
|
-
args:
|
834
|
+
args: CLIParams,
|
875
835
|
security_manager: SecurityManager,
|
876
|
-
) -> Dict[str,
|
836
|
+
) -> Dict[str, Union[FileInfoList, str, List[str], Dict[str, str]]]:
|
877
837
|
"""Collect files from command line arguments.
|
878
838
|
|
879
839
|
Args:
|
880
|
-
args:
|
840
|
+
args: Command line arguments
|
881
841
|
security_manager: Security manager for path validation
|
882
842
|
|
883
843
|
Returns:
|
@@ -888,15 +848,31 @@ def collect_template_files(
|
|
888
848
|
ValueError: If file mappings are invalid or files cannot be accessed
|
889
849
|
"""
|
890
850
|
try:
|
891
|
-
|
892
|
-
|
893
|
-
|
894
|
-
|
895
|
-
|
896
|
-
|
851
|
+
# Get files, directories, and patterns from args - they are already tuples from Click's nargs=2
|
852
|
+
files = list(
|
853
|
+
args.get("files", [])
|
854
|
+
) # List of (name, path) tuples from Click
|
855
|
+
dirs = args.get("dir", []) # List of (name, dir) tuples from Click
|
856
|
+
patterns = args.get(
|
857
|
+
"patterns", []
|
858
|
+
) # List of (name, pattern) tuples from Click
|
859
|
+
|
860
|
+
# Collect files from directories and patterns
|
861
|
+
dir_files = collect_files(
|
862
|
+
file_mappings=cast(List[Tuple[str, Union[str, Path]]], files),
|
863
|
+
dir_mappings=cast(List[Tuple[str, Union[str, Path]]], dirs),
|
864
|
+
pattern_mappings=cast(
|
865
|
+
List[Tuple[str, Union[str, Path]]], patterns
|
866
|
+
),
|
867
|
+
dir_recursive=args.get("recursive", False),
|
897
868
|
security_manager=security_manager,
|
898
869
|
)
|
899
|
-
|
870
|
+
|
871
|
+
# Combine results
|
872
|
+
return cast(
|
873
|
+
Dict[str, Union[FileInfoList, str, List[str], Dict[str, str]]],
|
874
|
+
dir_files,
|
875
|
+
)
|
900
876
|
except PathSecurityError:
|
901
877
|
# Let PathSecurityError propagate without wrapping
|
902
878
|
raise
|
@@ -914,11 +890,11 @@ def collect_template_files(
|
|
914
890
|
raise ValueError(f"Error collecting files: {e}")
|
915
891
|
|
916
892
|
|
917
|
-
def collect_simple_variables(args:
|
893
|
+
def collect_simple_variables(args: CLIParams) -> Dict[str, str]:
|
918
894
|
"""Collect simple string variables from --var arguments.
|
919
895
|
|
920
896
|
Args:
|
921
|
-
args:
|
897
|
+
args: Command line arguments
|
922
898
|
|
923
899
|
Returns:
|
924
900
|
Dictionary mapping variable names to string values
|
@@ -929,10 +905,15 @@ def collect_simple_variables(args: Namespace) -> Dict[str, str]:
|
|
929
905
|
variables: Dict[str, str] = {}
|
930
906
|
all_names: Set[str] = set()
|
931
907
|
|
932
|
-
if args.var:
|
933
|
-
for mapping in args
|
908
|
+
if args.get("var"):
|
909
|
+
for mapping in args["var"]:
|
934
910
|
try:
|
935
|
-
|
911
|
+
# Handle both tuple format and string format
|
912
|
+
if isinstance(mapping, tuple):
|
913
|
+
name, value = mapping
|
914
|
+
else:
|
915
|
+
name, value = mapping.split("=", 1)
|
916
|
+
|
936
917
|
if not name.isidentifier():
|
937
918
|
raise VariableNameError(f"Invalid variable name: {name}")
|
938
919
|
if name in all_names:
|
@@ -947,11 +928,11 @@ def collect_simple_variables(args: Namespace) -> Dict[str, str]:
|
|
947
928
|
return variables
|
948
929
|
|
949
930
|
|
950
|
-
def collect_json_variables(args:
|
931
|
+
def collect_json_variables(args: CLIParams) -> Dict[str, Any]:
|
951
932
|
"""Collect JSON variables from --json-var arguments.
|
952
933
|
|
953
934
|
Args:
|
954
|
-
args:
|
935
|
+
args: Command line arguments
|
955
936
|
|
956
937
|
Returns:
|
957
938
|
Dictionary mapping variable names to parsed JSON values
|
@@ -963,53 +944,52 @@ def collect_json_variables(args: Namespace) -> Dict[str, Any]:
|
|
963
944
|
variables: Dict[str, Any] = {}
|
964
945
|
all_names: Set[str] = set()
|
965
946
|
|
966
|
-
if args.json_var:
|
967
|
-
for mapping in args
|
947
|
+
if args.get("json_var"):
|
948
|
+
for mapping in args["json_var"]:
|
968
949
|
try:
|
969
|
-
|
950
|
+
# Handle both tuple format and string format
|
951
|
+
if isinstance(mapping, tuple):
|
952
|
+
name, value = (
|
953
|
+
mapping # Value is already parsed by Click validator
|
954
|
+
)
|
955
|
+
else:
|
956
|
+
try:
|
957
|
+
name, json_str = mapping.split("=", 1)
|
958
|
+
except ValueError:
|
959
|
+
raise VariableNameError(
|
960
|
+
f"Invalid JSON variable mapping format: {mapping}. Expected name=json"
|
961
|
+
)
|
962
|
+
try:
|
963
|
+
value = json.loads(json_str)
|
964
|
+
except json.JSONDecodeError as e:
|
965
|
+
raise InvalidJSONError(
|
966
|
+
f"Invalid JSON value for variable '{name}': {json_str}",
|
967
|
+
context={"variable_name": name},
|
968
|
+
) from e
|
969
|
+
|
970
970
|
if not name.isidentifier():
|
971
971
|
raise VariableNameError(f"Invalid variable name: {name}")
|
972
972
|
if name in all_names:
|
973
973
|
raise VariableNameError(f"Duplicate variable name: {name}")
|
974
|
-
|
975
|
-
|
976
|
-
|
977
|
-
|
978
|
-
|
979
|
-
raise InvalidJSONError(
|
980
|
-
f"Error parsing JSON for variable '{name}': {str(e)}. Input was: {json_str}"
|
981
|
-
)
|
982
|
-
except ValueError:
|
983
|
-
raise VariableNameError(
|
984
|
-
f"Invalid JSON variable mapping format: {mapping}. Expected name=json"
|
985
|
-
)
|
974
|
+
|
975
|
+
variables[name] = value
|
976
|
+
all_names.add(name)
|
977
|
+
except (VariableNameError, InvalidJSONError):
|
978
|
+
raise
|
986
979
|
|
987
980
|
return variables
|
988
981
|
|
989
982
|
|
990
983
|
def create_template_context(
|
991
|
-
files: Optional[
|
984
|
+
files: Optional[
|
985
|
+
Dict[str, Union[FileInfoList, str, List[str], Dict[str, str]]]
|
986
|
+
] = None,
|
992
987
|
variables: Optional[Dict[str, str]] = None,
|
993
988
|
json_variables: Optional[Dict[str, Any]] = None,
|
994
989
|
security_manager: Optional[SecurityManager] = None,
|
995
990
|
stdin_content: Optional[str] = None,
|
996
991
|
) -> Dict[str, Any]:
|
997
|
-
"""Create template context from
|
998
|
-
|
999
|
-
Args:
|
1000
|
-
files: Optional dictionary mapping names to FileInfoList objects
|
1001
|
-
variables: Optional dictionary of simple string variables
|
1002
|
-
json_variables: Optional dictionary of JSON variables
|
1003
|
-
security_manager: Optional security manager for path validation
|
1004
|
-
stdin_content: Optional content to use for stdin
|
1005
|
-
|
1006
|
-
Returns:
|
1007
|
-
Template context dictionary
|
1008
|
-
|
1009
|
-
Raises:
|
1010
|
-
PathSecurityError: If any file paths violate security constraints
|
1011
|
-
VariableError: If variable mappings are invalid
|
1012
|
-
"""
|
992
|
+
"""Create template context from files and variables."""
|
1013
993
|
context: Dict[str, Any] = {}
|
1014
994
|
|
1015
995
|
# Add file variables
|
@@ -1032,14 +1012,14 @@ def create_template_context(
|
|
1032
1012
|
return context
|
1033
1013
|
|
1034
1014
|
|
1035
|
-
def create_template_context_from_args(
|
1036
|
-
args:
|
1015
|
+
async def create_template_context_from_args(
|
1016
|
+
args: CLIParams,
|
1037
1017
|
security_manager: SecurityManager,
|
1038
1018
|
) -> Dict[str, Any]:
|
1039
1019
|
"""Create template context from command line arguments.
|
1040
1020
|
|
1041
1021
|
Args:
|
1042
|
-
args:
|
1022
|
+
args: Command line arguments
|
1043
1023
|
security_manager: Security manager for path validation
|
1044
1024
|
|
1045
1025
|
Returns:
|
@@ -1052,50 +1032,13 @@ def create_template_context_from_args(
|
|
1052
1032
|
"""
|
1053
1033
|
try:
|
1054
1034
|
# Collect files from arguments
|
1055
|
-
files =
|
1056
|
-
if any([args.file, args.files, args.dir]):
|
1057
|
-
files = collect_files(
|
1058
|
-
file_mappings=args.file,
|
1059
|
-
pattern_mappings=args.files,
|
1060
|
-
dir_mappings=args.dir,
|
1061
|
-
dir_recursive=args.dir_recursive,
|
1062
|
-
dir_extensions=(
|
1063
|
-
args.dir_ext.split(",") if args.dir_ext else None
|
1064
|
-
),
|
1065
|
-
security_manager=security_manager,
|
1066
|
-
)
|
1035
|
+
files = collect_template_files(args, security_manager)
|
1067
1036
|
|
1068
1037
|
# Collect simple variables
|
1069
|
-
|
1070
|
-
variables = collect_simple_variables(args)
|
1071
|
-
except VariableNameError as e:
|
1072
|
-
raise VariableError(str(e))
|
1038
|
+
variables = collect_simple_variables(args)
|
1073
1039
|
|
1074
1040
|
# Collect JSON variables
|
1075
|
-
json_variables =
|
1076
|
-
if args.json_var:
|
1077
|
-
for mapping in args.json_var:
|
1078
|
-
try:
|
1079
|
-
name, value = mapping.split("=", 1)
|
1080
|
-
if not name.isidentifier():
|
1081
|
-
raise VariableNameError(
|
1082
|
-
f"Invalid variable name: {name}"
|
1083
|
-
)
|
1084
|
-
try:
|
1085
|
-
json_value = json.loads(value)
|
1086
|
-
except json.JSONDecodeError as e:
|
1087
|
-
raise InvalidJSONError(
|
1088
|
-
f"Error parsing JSON for variable '{name}': {str(e)}. Input was: {value}"
|
1089
|
-
)
|
1090
|
-
if name in json_variables:
|
1091
|
-
raise VariableNameError(
|
1092
|
-
f"Duplicate variable name: {name}"
|
1093
|
-
)
|
1094
|
-
json_variables[name] = json_value
|
1095
|
-
except ValueError:
|
1096
|
-
raise VariableNameError(
|
1097
|
-
f"Invalid JSON variable mapping format: {mapping}. Expected name=json"
|
1098
|
-
)
|
1041
|
+
json_variables = collect_json_variables(args)
|
1099
1042
|
|
1100
1043
|
# Get stdin content if available
|
1101
1044
|
stdin_content = None
|
@@ -1106,7 +1049,7 @@ def create_template_context_from_args(
|
|
1106
1049
|
# Skip stdin if it can't be read
|
1107
1050
|
pass
|
1108
1051
|
|
1109
|
-
|
1052
|
+
context = create_template_context(
|
1110
1053
|
files=files,
|
1111
1054
|
variables=variables,
|
1112
1055
|
json_variables=json_variables,
|
@@ -1114,6 +1057,11 @@ def create_template_context_from_args(
|
|
1114
1057
|
stdin_content=stdin_content,
|
1115
1058
|
)
|
1116
1059
|
|
1060
|
+
# Add current model to context
|
1061
|
+
context["current_model"] = args["model"]
|
1062
|
+
|
1063
|
+
return context
|
1064
|
+
|
1117
1065
|
except PathSecurityError:
|
1118
1066
|
# Let PathSecurityError propagate without wrapping
|
1119
1067
|
raise
|
@@ -1235,7 +1183,8 @@ def parse_json_var(var_str: str) -> Tuple[str, Any]:
|
|
1235
1183
|
value = json.loads(json_str)
|
1236
1184
|
except json.JSONDecodeError as e:
|
1237
1185
|
raise InvalidJSONError(
|
1238
|
-
f"Error parsing JSON for variable '{name}': {str(e)}. Input was: {json_str}"
|
1186
|
+
f"Error parsing JSON for variable '{name}': {str(e)}. Input was: {json_str}",
|
1187
|
+
context={"variable_name": name},
|
1239
1188
|
)
|
1240
1189
|
|
1241
1190
|
return name, value
|
@@ -1284,41 +1233,96 @@ def _create_enum_type(values: List[Any], field_name: str) -> Type[Enum]:
|
|
1284
1233
|
|
1285
1234
|
|
1286
1235
|
def handle_error(e: Exception) -> None:
|
1287
|
-
"""Handle errors
|
1236
|
+
"""Handle CLI errors and display appropriate messages.
|
1237
|
+
|
1238
|
+
Maintains specific error type handling while reducing duplication.
|
1239
|
+
Provides enhanced debug logging for CLI errors.
|
1240
|
+
"""
|
1241
|
+
# 1. Determine error type and message
|
1288
1242
|
if isinstance(e, click.UsageError):
|
1289
|
-
|
1290
|
-
|
1291
|
-
|
1292
|
-
|
1293
|
-
|
1294
|
-
|
1295
|
-
|
1296
|
-
|
1297
|
-
|
1298
|
-
elif isinstance(e,
|
1299
|
-
|
1300
|
-
|
1301
|
-
click.secho(msg, fg="red", err=True)
|
1302
|
-
sys.exit(ExitCode.DATA_ERROR)
|
1303
|
-
elif isinstance(e, FileNotFoundError):
|
1304
|
-
# Use the original error message if available
|
1305
|
-
msg = str(e) if str(e) != "None" else "File not found"
|
1306
|
-
click.secho(msg, fg="red", err=True)
|
1307
|
-
sys.exit(ExitCode.SCHEMA_ERROR)
|
1308
|
-
elif isinstance(e, TaskTemplateSyntaxError):
|
1309
|
-
# Use the original error message if available
|
1310
|
-
msg = str(e) if str(e) != "None" else "Template syntax error"
|
1311
|
-
click.secho(msg, fg="red", err=True)
|
1312
|
-
sys.exit(ExitCode.INTERNAL_ERROR)
|
1243
|
+
msg = f"Usage error: {str(e)}"
|
1244
|
+
exit_code = ExitCode.USAGE_ERROR
|
1245
|
+
elif isinstance(e, SchemaFileError):
|
1246
|
+
# Preserve specific schema error handling
|
1247
|
+
msg = str(e) # Use existing __str__ formatting
|
1248
|
+
exit_code = ExitCode.SCHEMA_ERROR
|
1249
|
+
elif isinstance(e, (InvalidJSONError, json.JSONDecodeError)):
|
1250
|
+
msg = f"Invalid JSON error: {str(e)}"
|
1251
|
+
exit_code = ExitCode.DATA_ERROR
|
1252
|
+
elif isinstance(e, SchemaValidationError):
|
1253
|
+
msg = f"Schema validation error: {str(e)}"
|
1254
|
+
exit_code = ExitCode.VALIDATION_ERROR
|
1313
1255
|
elif isinstance(e, CLIError):
|
1314
|
-
# Use
|
1315
|
-
e.
|
1316
|
-
|
1317
|
-
|
1256
|
+
msg = str(e) # Use existing __str__ formatting
|
1257
|
+
exit_code = ExitCode(e.exit_code) # Convert int to ExitCode
|
1258
|
+
else:
|
1259
|
+
msg = f"Unexpected error: {str(e)}"
|
1260
|
+
exit_code = ExitCode.INTERNAL_ERROR
|
1261
|
+
|
1262
|
+
# 2. Debug logging
|
1263
|
+
if isinstance(e, CLIError) and logger.isEnabledFor(logging.DEBUG):
|
1264
|
+
# Format context fields with lowercase keys and simple values
|
1265
|
+
context_str = ""
|
1266
|
+
if hasattr(e, "context"):
|
1267
|
+
for key, value in sorted(e.context.items()):
|
1268
|
+
if key not in {
|
1269
|
+
"timestamp",
|
1270
|
+
"host",
|
1271
|
+
"version",
|
1272
|
+
"python_version",
|
1273
|
+
}:
|
1274
|
+
context_str += f"{key.lower()}: {value}\n"
|
1275
|
+
|
1276
|
+
logger.debug(
|
1277
|
+
"Error details:\n"
|
1278
|
+
f"Type: {type(e).__name__}\n"
|
1279
|
+
f"{context_str.rstrip()}"
|
1318
1280
|
)
|
1281
|
+
elif not isinstance(e, click.UsageError):
|
1282
|
+
logger.error(msg, exc_info=True)
|
1319
1283
|
else:
|
1320
|
-
|
1321
|
-
|
1284
|
+
logger.error(msg)
|
1285
|
+
|
1286
|
+
# 3. User output
|
1287
|
+
click.secho(msg, fg="red", err=True)
|
1288
|
+
sys.exit(exit_code)
|
1289
|
+
|
1290
|
+
|
1291
|
+
def validate_model_parameters(model: str, params: Dict[str, Any]) -> None:
|
1292
|
+
"""Validate model parameters against model capabilities.
|
1293
|
+
|
1294
|
+
Args:
|
1295
|
+
model: The model name to validate parameters for
|
1296
|
+
params: Dictionary of parameter names and values to validate
|
1297
|
+
|
1298
|
+
Raises:
|
1299
|
+
CLIError: If any parameters are not supported by the model
|
1300
|
+
"""
|
1301
|
+
try:
|
1302
|
+
capabilities = ModelRegistry().get_capabilities(model)
|
1303
|
+
for param_name, value in params.items():
|
1304
|
+
try:
|
1305
|
+
capabilities.validate_parameter(param_name, value)
|
1306
|
+
except OpenAIClientError as e:
|
1307
|
+
logger.error(
|
1308
|
+
"Validation failed for model %s: %s", model, str(e)
|
1309
|
+
)
|
1310
|
+
raise CLIError(
|
1311
|
+
str(e),
|
1312
|
+
exit_code=ExitCode.VALIDATION_ERROR,
|
1313
|
+
context={
|
1314
|
+
"model": model,
|
1315
|
+
"param": param_name,
|
1316
|
+
"value": value,
|
1317
|
+
},
|
1318
|
+
)
|
1319
|
+
except (ModelNotSupportedError, ModelVersionError) as e:
|
1320
|
+
logger.error("Model validation failed: %s", str(e))
|
1321
|
+
raise CLIError(
|
1322
|
+
str(e),
|
1323
|
+
exit_code=ExitCode.VALIDATION_ERROR,
|
1324
|
+
context={"model": model},
|
1325
|
+
)
|
1322
1326
|
|
1323
1327
|
|
1324
1328
|
async def stream_structured_output(
|
@@ -1329,91 +1333,103 @@ async def stream_structured_output(
|
|
1329
1333
|
output_schema: Type[BaseModel],
|
1330
1334
|
output_file: Optional[str] = None,
|
1331
1335
|
**kwargs: Any,
|
1332
|
-
) -> None:
|
1336
|
+
) -> AsyncGenerator[BaseModel, None]:
|
1333
1337
|
"""Stream structured output from OpenAI API.
|
1334
1338
|
|
1335
1339
|
This function follows the guide's recommendation for a focused async streaming function.
|
1336
1340
|
It handles the core streaming logic and resource cleanup.
|
1337
|
-
"""
|
1338
|
-
try:
|
1339
|
-
# Base models that don't support streaming
|
1340
|
-
non_streaming_models = {"o1", "o3"}
|
1341
1341
|
|
1342
|
-
|
1343
|
-
|
1344
|
-
|
1345
|
-
|
1346
|
-
|
1342
|
+
Args:
|
1343
|
+
client: The OpenAI client to use
|
1344
|
+
model: The model to use
|
1345
|
+
system_prompt: The system prompt to use
|
1346
|
+
user_prompt: The user prompt to use
|
1347
|
+
output_schema: The Pydantic model to validate responses against
|
1348
|
+
output_file: Optional file to write output to
|
1349
|
+
**kwargs: Additional parameters to pass to the API
|
1347
1350
|
|
1348
|
-
|
1349
|
-
|
1350
|
-
|
1351
|
-
|
1352
|
-
|
1353
|
-
|
1354
|
-
|
1355
|
-
|
1356
|
-
|
1357
|
-
|
1358
|
-
|
1359
|
-
|
1360
|
-
|
1361
|
-
|
1362
|
-
if not chunk:
|
1363
|
-
continue
|
1364
|
-
|
1365
|
-
# Process and output the chunk
|
1366
|
-
dumped = chunk.model_dump(mode="json")
|
1367
|
-
json_str = json.dumps(dumped, indent=2)
|
1368
|
-
|
1369
|
-
if output_file:
|
1370
|
-
with open(output_file, "a", encoding="utf-8") as f:
|
1371
|
-
f.write(json_str)
|
1372
|
-
f.write("\n")
|
1373
|
-
f.flush() # Ensure immediate flush to file
|
1374
|
-
else:
|
1375
|
-
# Print directly to stdout with immediate flush
|
1376
|
-
print(json_str, flush=True)
|
1377
|
-
else:
|
1378
|
-
# For non-streaming models, use regular completion
|
1379
|
-
response = await client.chat.completions.create(
|
1380
|
-
model=model,
|
1381
|
-
messages=[
|
1382
|
-
{"role": "system", "content": system_prompt},
|
1383
|
-
{"role": "user", "content": user_prompt},
|
1384
|
-
],
|
1385
|
-
stream=False,
|
1386
|
-
**stream_kwargs,
|
1351
|
+
Returns:
|
1352
|
+
An async generator yielding validated model instances
|
1353
|
+
|
1354
|
+
Raises:
|
1355
|
+
ValueError: If the model does not support structured output or parameters are invalid
|
1356
|
+
StreamInterruptedError: If the stream is interrupted
|
1357
|
+
APIResponseError: If there is an API error
|
1358
|
+
"""
|
1359
|
+
try:
|
1360
|
+
# Check if model supports structured output using openai_structured's function
|
1361
|
+
if not supports_structured_output(model):
|
1362
|
+
raise ValueError(
|
1363
|
+
f"Model {model} does not support structured output with json_schema response format. "
|
1364
|
+
"Please use a model that supports structured output."
|
1387
1365
|
)
|
1388
1366
|
|
1389
|
-
|
1390
|
-
|
1391
|
-
if content:
|
1392
|
-
try:
|
1393
|
-
# Parse and validate against schema
|
1394
|
-
result = output_schema.model_validate_json(content)
|
1395
|
-
json_str = json.dumps(
|
1396
|
-
result.model_dump(mode="json"), indent=2
|
1397
|
-
)
|
1367
|
+
# Extract non-model parameters
|
1368
|
+
on_log = kwargs.pop("on_log", None)
|
1398
1369
|
|
1399
|
-
|
1400
|
-
|
1401
|
-
|
1402
|
-
|
1403
|
-
|
1404
|
-
|
1405
|
-
|
1406
|
-
|
1407
|
-
|
1408
|
-
|
1370
|
+
# Handle model-specific parameters
|
1371
|
+
stream_kwargs = {}
|
1372
|
+
registry = ModelRegistry()
|
1373
|
+
capabilities = registry.get_capabilities(model)
|
1374
|
+
|
1375
|
+
# Validate and include supported parameters
|
1376
|
+
for param_name, value in kwargs.items():
|
1377
|
+
if param_name in capabilities.supported_parameters:
|
1378
|
+
# Validate the parameter value
|
1379
|
+
capabilities.validate_parameter(param_name, value)
|
1380
|
+
stream_kwargs[param_name] = value
|
1409
1381
|
else:
|
1410
|
-
|
1382
|
+
logger.warning(
|
1383
|
+
f"Parameter {param_name} is not supported by model {model} and will be ignored"
|
1384
|
+
)
|
1411
1385
|
|
1386
|
+
# Log the API request details
|
1387
|
+
logger.debug("Making OpenAI API request with:")
|
1388
|
+
logger.debug("Model: %s", model)
|
1389
|
+
logger.debug("System prompt: %s", system_prompt)
|
1390
|
+
logger.debug("User prompt: %s", user_prompt)
|
1391
|
+
logger.debug("Parameters: %s", json.dumps(stream_kwargs, indent=2))
|
1392
|
+
logger.debug("Schema: %s", output_schema.model_json_schema())
|
1393
|
+
|
1394
|
+
# Use the async generator from openai_structured directly
|
1395
|
+
async for chunk in async_openai_structured_stream(
|
1396
|
+
client=client,
|
1397
|
+
model=model,
|
1398
|
+
system_prompt=system_prompt,
|
1399
|
+
user_prompt=user_prompt,
|
1400
|
+
output_schema=output_schema,
|
1401
|
+
on_log=on_log, # Pass non-model parameters directly to the function
|
1402
|
+
**stream_kwargs, # Pass only validated model parameters
|
1403
|
+
):
|
1404
|
+
yield chunk
|
1405
|
+
|
1406
|
+
except APIResponseError as e:
|
1407
|
+
if "Invalid schema for response_format" in str(
|
1408
|
+
e
|
1409
|
+
) and 'type: "array"' in str(e):
|
1410
|
+
error_msg = (
|
1411
|
+
"OpenAI API Schema Error: The schema must have a root type of 'object', not 'array'. "
|
1412
|
+
"To fix this:\n"
|
1413
|
+
"1. Wrap your array in an object property, e.g.:\n"
|
1414
|
+
" {\n"
|
1415
|
+
' "type": "object",\n'
|
1416
|
+
' "properties": {\n'
|
1417
|
+
' "items": {\n'
|
1418
|
+
' "type": "array",\n'
|
1419
|
+
' "items": { ... your array items schema ... }\n'
|
1420
|
+
" }\n"
|
1421
|
+
" }\n"
|
1422
|
+
" }\n"
|
1423
|
+
"2. Make sure to update your template to handle the wrapper object."
|
1424
|
+
)
|
1425
|
+
logger.error(error_msg)
|
1426
|
+
raise InvalidResponseFormatError(error_msg)
|
1427
|
+
logger.error(f"API error: {e}")
|
1428
|
+
raise
|
1412
1429
|
except (
|
1413
1430
|
StreamInterruptedError,
|
1414
1431
|
StreamBufferError,
|
1415
1432
|
StreamParseError,
|
1416
|
-
APIResponseError,
|
1417
1433
|
EmptyResponseError,
|
1418
1434
|
InvalidResponseFormatError,
|
1419
1435
|
) as e:
|
@@ -1424,149 +1440,457 @@ async def stream_structured_output(
|
|
1424
1440
|
await client.close()
|
1425
1441
|
|
1426
1442
|
|
1427
|
-
|
1428
|
-
|
1443
|
+
# Root command group. Subcommands attach to it via ``@cli.command()`` and the
# ``--version`` option reports ``__version__``.
# NOTE: click renders this docstring as the group's --help text, so its
# wording is user-facing output rather than internal documentation.
# NOTE(review): inner indentation of the help text could not be recovered
# from this view - confirm against the repository before reformatting.
@click.group()
@click.version_option(version=__version__)
def cli() -> None:
    """ostruct CLI - Make structured OpenAI API calls.

    ostruct allows you to invoke OpenAI Structured Output to produce structured JSON
    output using templates and JSON schemas. It provides support for file handling, variable
    substitution, and output validation.

    For detailed documentation, visit: https://ostruct.readthedocs.io

    Examples:

    # Basic usage with a template and schema

    ostruct run task.j2 schema.json -V name=value

    # Process files with recursive directory scanning

    ostruct run template.j2 schema.json -f code main.py -d src ./src -R

    # Use JSON variables and custom model parameters

    ostruct run task.j2 schema.json -J config='{"env":"prod"}' -m o3-mini
    """
    # The group itself does no work; behaviour lives in its subcommands.
    pass
|
1469
|
+
|
1470
|
+
|
1471
|
+
# NOTE: click renders the docstring below as the command's --help text, so it
# is kept verbatim.
# NOTE(review): inner indentation of the Examples section could not be
# recovered from this view - confirm against the repository.
@cli.command()
@click.argument("task_template", type=click.Path(exists=True))
@click.argument("schema_file", type=click.Path(exists=True))
@all_options
@click.pass_context
def run(
    ctx: click.Context,
    task_template: str,
    schema_file: str,
    **kwargs: Any,
) -> None:
    """Run a structured task with template and schema.

    TASK_TEMPLATE is the path to your Jinja2 template file that defines the task.
    SCHEMA_FILE is the path to your JSON schema file that defines the expected output structure.

    The command supports various options for file handling, variable definition,
    model configuration, and output control. Use --help to see all available options.

    Examples:
    # Basic usage
    ostruct run task.j2 schema.json

    # Process multiple files
    ostruct run task.j2 schema.json -f code main.py -f test tests/test_main.py

    # Scan directories recursively
    ostruct run task.j2 schema.json -d src ./src -R

    # Define variables
    ostruct run task.j2 schema.json -V debug=true -J config='{"env":"prod"}'

    # Configure model
    ostruct run task.j2 schema.json -m gpt-4 --temperature 0.7 --max-output-tokens 1000

    # Control output
    ostruct run task.j2 schema.json --output-file result.json --verbose
    """
    try:
        # Convert Click parameters to typed dict
        params: CLIParams = {
            "task_file": task_template,
            "task": None,
            "schema_file": schema_file,
        }
        # Copy over only the options that CLIParams declares; anything else
        # click passes in is dropped.
        valid_keys = set(CLIParams.__annotations__)
        for k, v in kwargs.items():
            if k in valid_keys:
                params[k] = v  # type: ignore[literal-required]

        # Run the async function synchronously on a dedicated event loop.
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            exit_code = loop.run_until_complete(run_cli_async(params))
            sys.exit(int(exit_code))
        finally:
            # Finalize async generators that were abandoned mid-iteration
            # (e.g. the response stream after an interrupt) so their
            # ``finally`` blocks run before the loop goes away.
            try:
                loop.run_until_complete(loop.shutdown_asyncgens())
            finally:
                loop.close()

    # SystemExit is not an Exception subclass, so the sys.exit() above passes
    # straight through the handlers below.
    except (
        CLIError,
        InvalidJSONError,
        SchemaFileError,
        SchemaValidationError,
    ) as e:
        handle_error(e)
        # Fallback exit in case handle_error returns instead of exiting.
        sys.exit(getattr(e, "exit_code", ExitCode.INTERNAL_ERROR))
    except click.UsageError as e:
        handle_error(e)
        sys.exit(ExitCode.USAGE_ERROR)
    except Exception as e:
        handle_error(e)
        sys.exit(ExitCode.INTERNAL_ERROR)
|
1547
|
+
|
1548
|
+
|
1549
|
+
# The helpers that follow implement the individual phases that
# run_cli_async drives in order.
|
1551
|
+
|
1552
|
+
|
1553
|
+
async def validate_model_params(args: CLIParams) -> Dict[str, Any]:
    """Collect the model tuning options that were set and validate them.

    Args:
        args: Command line arguments

    Returns:
        Dictionary holding only the options whose value is not ``None``,
        after ``validate_model_parameters`` has accepted them for
        ``args["model"]``

    Raises:
        CLIError: If model parameters are invalid
    """
    option_names = (
        "temperature",
        "max_output_tokens",
        "top_p",
        "frequency_penalty",
        "presence_penalty",
        "reasoning_effort",
    )
    params: Dict[str, Any] = {}
    for option in option_names:
        chosen = args.get(option)
        # Options left unset are omitted rather than passed on as None.
        if chosen is not None:
            params[option] = chosen

    validate_model_parameters(args["model"], params)
    return params
|
1577
|
+
|
1578
|
+
|
1579
|
+
async def validate_inputs(
    args: CLIParams,
) -> Tuple[
    SecurityManager, str, Dict[str, Any], Dict[str, Any], jinja2.Environment
]:
    """Run the input-validation phase and hand back its validated pieces.

    The steps run in a fixed order: security manager, task template, schema
    file, template context, Jinja2 environment.

    Args:
        args: Command line arguments

    Returns:
        Tuple of (security manager, task template string, schema dictionary,
        template context dictionary, Jinja2 environment)

    Raises:
        CLIError: For various validation errors
    """
    logger.debug("=== Input Validation Phase ===")

    manager = validate_security_manager(
        base_dir=args.get("base_dir"),
        allowed_dirs=args.get("allowed_dirs"),
        allowed_dir_file=args.get("allowed_dir_file"),
    )
    template_text = validate_task_template(
        args.get("task"), args.get("task_file")
    )

    schema_path = args["schema_file"]
    logger.debug("Validating schema from %s", schema_path)
    schema = validate_schema_file(schema_path, args.get("verbose", False))

    # The context is built with the security manager created above.
    context = await create_template_context_from_args(args, manager)

    return manager, template_text, schema, context, create_jinja_env()
|
1620
|
+
|
1621
|
+
|
1622
|
+
async def process_templates(
    args: CLIParams,
    task_template: str,
    template_context: Dict[str, Any],
    env: jinja2.Environment,
) -> Tuple[str, str]:
    """Produce the system prompt and the rendered user prompt.

    Args:
        args: Command line arguments
        task_template: Validated task template
        template_context: Template context dictionary
        env: Jinja2 environment

    Returns:
        Tuple of (system_prompt, user_prompt)

    Raises:
        CLIError: For template processing errors
    """
    logger.debug("=== Template Processing Phase ===")

    # Gather the system-prompt options up front.
    inline_prompt = args.get("system_prompt")
    prompt_file = args.get("system_prompt_file")
    skip_task_sysprompt = args.get("ignore_task_sysprompt", False)

    system_prompt = process_system_prompt(
        task_template,
        inline_prompt,
        prompt_file,
        template_context,
        env,
        skip_task_sysprompt,
    )
    # The user prompt is the task template rendered with the same context.
    return system_prompt, render_template(task_template, template_context, env)
|
1653
|
+
|
1440
1654
|
|
1441
|
-
|
1442
|
-
|
1443
|
-
|
1444
|
-
|
1445
|
-
|
1655
|
+
async def validate_model_and_schema(
    args: CLIParams,
    schema: Dict[str, Any],
    system_prompt: str,
    user_prompt: str,
) -> Tuple[Type[BaseModel], List[Dict[str, str]], int, ModelRegistry]:
    """Build the output model, confirm model support, and check token budget.

    Args:
        args: Command line arguments
        schema: Schema dictionary
        system_prompt: Processed system prompt
        user_prompt: Processed user prompt

    Returns:
        Tuple of (output_model, messages, total_tokens, registry)

    Raises:
        CLIError: When the estimated token count exceeds the context window
        ModelNotSupportedError: When the model lacks structured output support
        ModelCreationError: When model creation fails
        SchemaValidationError: When schema is invalid
    """
    logger.debug("=== Model & Schema Validation Phase ===")
    try:
        output_model = create_dynamic_model(
            schema,
            show_schema=args.get("show_model_schema", False),
            debug_validation=args.get("debug_validation", False),
        )
    except (
        SchemaFileError,
        InvalidJSONError,
        SchemaValidationError,
        ModelCreationError,
    ) as e:
        # Log here, then let the original error travel up unchanged.
        logger.error("Schema error: %s", str(e))
        raise
    else:
        logger.debug("Successfully created output model")

    model_name = args["model"]
    if not supports_structured_output(model_name):
        msg = f"Model {model_name} does not support structured output"
        logger.error(msg)
        raise ModelNotSupportedError(msg)

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    total_tokens = estimate_tokens_with_encoding(messages, model_name)

    registry = ModelRegistry()
    context_limit = registry.get_capabilities(model_name).context_window

    # Happy path first: the prompts fit into the model's context window.
    if not total_tokens > context_limit:
        return output_model, messages, total_tokens, registry

    msg = f"Total tokens ({total_tokens}) exceeds model context limit ({context_limit})"
    logger.error(msg)
    raise CLIError(
        msg,
        context={
            "total_tokens": total_tokens,
            "context_limit": context_limit,
        },
    )
|
1464
1721
|
|
1465
|
-
# Create output model
|
1466
|
-
logger.debug("Creating output model")
|
1467
|
-
try:
|
1468
|
-
output_model = create_dynamic_model(
|
1469
|
-
schema,
|
1470
|
-
base_name="OutputModel",
|
1471
|
-
show_schema=args.show_model_schema,
|
1472
|
-
debug_validation=args.debug_validation,
|
1473
|
-
)
|
1474
|
-
logger.debug("Successfully created output model")
|
1475
|
-
except (
|
1476
|
-
SchemaFileError,
|
1477
|
-
InvalidJSONError,
|
1478
|
-
SchemaValidationError,
|
1479
|
-
ModelCreationError,
|
1480
|
-
) as e:
|
1481
|
-
logger.error("Schema error: %s", str(e))
|
1482
|
-
raise # Let the error propagate with its context
|
1483
|
-
|
1484
|
-
# Validate model support and token usage
|
1485
|
-
try:
|
1486
|
-
supports_structured_output(args.model)
|
1487
|
-
except (ModelNotSupportedError, ModelVersionError) as e:
|
1488
|
-
logger.error("Model validation error: %s", str(e))
|
1489
|
-
raise # Let the error propagate with its context
|
1490
|
-
|
1491
|
-
messages = [
|
1492
|
-
{"role": "system", "content": system_prompt},
|
1493
|
-
{"role": "user", "content": rendered_task},
|
1494
|
-
]
|
1495
|
-
total_tokens = estimate_tokens_for_chat(messages, args.model)
|
1496
|
-
context_limit = get_context_window_limit(args.model)
|
1497
|
-
if total_tokens > context_limit:
|
1498
|
-
msg = f"Total tokens ({total_tokens}) exceeds model context limit ({context_limit})"
|
1499
|
-
logger.error(msg)
|
1500
|
-
raise CLIError(
|
1501
|
-
msg,
|
1502
|
-
context={
|
1503
|
-
"total_tokens": total_tokens,
|
1504
|
-
"context_limit": context_limit,
|
1505
|
-
},
|
1506
|
-
)
|
1507
1722
|
|
1508
|
-
|
1509
|
-
|
1510
|
-
|
1511
|
-
|
1512
|
-
|
1513
|
-
|
1723
|
+
def _render_output_buffer(output_buffer: List[BaseModel]) -> str:
    """Serialize buffered responses: one response as an object, else a JSON array."""
    if len(output_buffer) == 1:
        return output_buffer[0].model_dump_json(indent=2)

    # NOTE(review): prefix reproduced exactly as it appears in this view;
    # whitespace may have been collapsed - confirm the intended width.
    element_indent = " "
    elements = [
        element_indent
        + response.model_dump_json(indent=2).replace(
            "\n", "\n" + element_indent
        )
        for response in output_buffer
    ]
    return "[\n" + ",\n".join(elements) + "\n]"


async def execute_model(
    args: CLIParams,
    params: Dict[str, Any],
    output_model: Type[BaseModel],
    system_prompt: str,
    user_prompt: str,
) -> ExitCode:
    """Execute the model and handle the response.

    Streams responses from ``stream_structured_output`` into a buffer, then
    writes the buffer either to ``args["output_file"]`` or to stdout. A single
    response is written as one JSON object; any other count is written as a
    JSON array. The OpenAI client is closed on every exit path.

    Args:
        args: Command line arguments
        params: Validated model parameters; forwarded to the streaming call
        output_model: Generated Pydantic model
        system_prompt: Processed system prompt
        user_prompt: Processed user prompt

    Returns:
        Exit code indicating success or failure

    Raises:
        CLIError: With ``ExitCode.API_ERROR`` when no API key is available or
            a stream/API error occurs, and with ``ExitCode.UNKNOWN_ERROR``
            for any other exception during execution
    """
    logger.debug("=== Execution Phase ===")
    api_key = args.get("api_key") or os.getenv("OPENAI_API_KEY")
    if not api_key:
        msg = "No API key provided. Set OPENAI_API_KEY environment variable or use --api-key"
        logger.error(msg)
        raise CLIError(msg, exit_code=ExitCode.API_ERROR)

    client = AsyncOpenAI(api_key=api_key, timeout=args.get("timeout", 60.0))

    # Create detailed log callback. It only emits when --debug-openai-stream
    # is set and always logs at debug level; ``level`` is not consulted.
    # NOTE(review): the pairing of the trailing ``else`` branches was
    # reconstructed from statement order - confirm against the repository.
    def log_callback(level: int, message: str, extra: dict[str, Any]) -> None:
        if args.get("debug_openai_stream", False):
            if extra:
                extra_str = LogSerializer.serialize_log_extra(extra)
                if extra_str:
                    logger.debug("%s\nExtra:\n%s", message, extra_str)
                else:
                    logger.debug("%s\nExtra: Failed to serialize", message)
            else:
                logger.debug(message)

    try:
        output_buffer: List[BaseModel] = []

        # Stream the response. The validated model parameters are forwarded
        # so that options such as temperature actually reach the API call.
        async for response in stream_structured_output(
            client=client,
            model=args["model"],
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            output_schema=output_model,
            output_file=args.get("output_file"),
            on_log=log_callback,
            **params,
        ):
            output_buffer.append(response)

        # Handle final output: render once, then pick the destination.
        rendered = _render_output_buffer(output_buffer)
        output_file = args.get("output_file")
        if output_file:
            # Explicit encoding keeps the file independent of the platform
            # default.
            with open(output_file, "w", encoding="utf-8") as f:
                f.write(rendered)
        else:
            print(rendered)

        return ExitCode.SUCCESS

    except (
        StreamInterruptedError,
        StreamBufferError,
        StreamParseError,
        APIResponseError,
        EmptyResponseError,
        InvalidResponseFormatError,
    ) as e:
        logger.error("Stream error: %s", str(e))
        raise CLIError(str(e), exit_code=ExitCode.API_ERROR) from e
    except Exception as e:
        logger.exception("Unexpected error during streaming")
        raise CLIError(str(e), exit_code=ExitCode.UNKNOWN_ERROR) from e
    finally:
        await client.close()
|
1832
|
+
|
1833
|
+
|
1834
|
+
async def run_cli_async(args: CLIParams) -> ExitCode:
|
1835
|
+
"""Async wrapper for CLI operations.
|
1836
|
+
|
1837
|
+
Returns:
|
1838
|
+
Exit code to return from the CLI
|
1839
|
+
|
1840
|
+
Raises:
|
1841
|
+
CLIError: For various error conditions
|
1842
|
+
KeyboardInterrupt: When operation is cancelled by user
|
1843
|
+
"""
|
1844
|
+
try:
|
1845
|
+
# 0. Model Parameter Validation
|
1846
|
+
logger.debug("=== Model Parameter Validation ===")
|
1847
|
+
params = await validate_model_params(args)
|
1848
|
+
|
1849
|
+
# 1. Input Validation Phase
|
1850
|
+
security_manager, task_template, schema, template_context, env = (
|
1851
|
+
await validate_inputs(args)
|
1852
|
+
)
|
1853
|
+
|
1854
|
+
# 2. Template Processing Phase
|
1855
|
+
system_prompt, user_prompt = await process_templates(
|
1856
|
+
args, task_template, template_context, env
|
1857
|
+
)
|
1858
|
+
|
1859
|
+
# 3. Model & Schema Validation Phase
|
1860
|
+
output_model, messages, total_tokens, registry = (
|
1861
|
+
await validate_model_and_schema(
|
1862
|
+
args, schema, system_prompt, user_prompt
|
1543
1863
|
)
|
1864
|
+
)
|
1865
|
+
|
1866
|
+
# 4. Dry Run Output Phase
|
1867
|
+
if args.get("dry_run", False):
|
1868
|
+
logger.info("\n=== Dry Run Summary ===")
|
1869
|
+
logger.info("✓ Template rendered successfully")
|
1870
|
+
logger.info("✓ Schema validation passed")
|
1871
|
+
logger.info("✓ Model compatibility validated")
|
1872
|
+
logger.info(
|
1873
|
+
f"✓ Token count: {total_tokens}/{registry.get_capabilities(args['model']).context_window}"
|
1874
|
+
)
|
1875
|
+
|
1876
|
+
if args.get("verbose", False):
|
1877
|
+
logger.info("\nSystem Prompt:")
|
1878
|
+
logger.info("-" * 40)
|
1879
|
+
logger.info(system_prompt)
|
1880
|
+
logger.info("\nRendered Template:")
|
1881
|
+
logger.info("-" * 40)
|
1882
|
+
logger.info(user_prompt)
|
1883
|
+
|
1544
1884
|
return ExitCode.SUCCESS
|
1545
|
-
|
1546
|
-
|
1547
|
-
|
1548
|
-
|
1549
|
-
|
1550
|
-
EmptyResponseError,
|
1551
|
-
InvalidResponseFormatError,
|
1552
|
-
) as e:
|
1553
|
-
logger.error("Stream error: %s", str(e))
|
1554
|
-
raise # Let stream errors propagate
|
1555
|
-
except (APIConnectionError, InternalServerError) as e:
|
1556
|
-
logger.error("API connection error: %s", str(e))
|
1557
|
-
raise APIResponseError(str(e)) # Convert to our error type
|
1558
|
-
except RateLimitError as e:
|
1559
|
-
logger.error("Rate limit exceeded: %s", str(e))
|
1560
|
-
raise APIResponseError(str(e)) # Convert to our error type
|
1561
|
-
except (BadRequestError, AuthenticationError, OpenAIClientError) as e:
|
1562
|
-
logger.error("API client error: %s", str(e))
|
1563
|
-
raise APIResponseError(str(e)) # Convert to our error type
|
1564
|
-
finally:
|
1565
|
-
await client.close()
|
1885
|
+
|
1886
|
+
# 5. Execution Phase
|
1887
|
+
return await execute_model(
|
1888
|
+
args, params, output_model, system_prompt, user_prompt
|
1889
|
+
)
|
1566
1890
|
|
1567
1891
|
except KeyboardInterrupt:
|
1568
1892
|
logger.info("Operation cancelled by user")
|
1569
|
-
|
1893
|
+
raise
|
1570
1894
|
except Exception as e:
|
1571
1895
|
if isinstance(e, CLIError):
|
1572
1896
|
raise # Let our custom errors propagate
|
@@ -1580,65 +1904,35 @@ def create_cli() -> click.Command:
|
|
1580
1904
|
Returns:
|
1581
1905
|
click.Command: The CLI command object
|
1582
1906
|
"""
|
1583
|
-
|
1584
|
-
@create_click_command()
|
1585
|
-
def cli(**kwargs: Any) -> None:
|
1586
|
-
"""CLI entry point for structured OpenAI API calls."""
|
1587
|
-
try:
|
1588
|
-
args = Namespace(**kwargs)
|
1589
|
-
|
1590
|
-
# Validate required arguments first
|
1591
|
-
if not args.task and not args.task_file:
|
1592
|
-
raise click.UsageError(
|
1593
|
-
"Must specify either --task or --task-file"
|
1594
|
-
)
|
1595
|
-
if not args.schema_file:
|
1596
|
-
raise click.UsageError("Missing option '--schema-file'")
|
1597
|
-
if args.task and args.task_file:
|
1598
|
-
raise click.UsageError(
|
1599
|
-
"Cannot specify both --task and --task-file"
|
1600
|
-
)
|
1601
|
-
if args.system_prompt and args.system_prompt_file:
|
1602
|
-
raise click.UsageError(
|
1603
|
-
"Cannot specify both --system-prompt and --system-prompt-file"
|
1604
|
-
)
|
1605
|
-
|
1606
|
-
# Run the async function synchronously
|
1607
|
-
exit_code = asyncio.run(run_cli_async(args))
|
1608
|
-
|
1609
|
-
if exit_code != ExitCode.SUCCESS:
|
1610
|
-
error_msg = f"Command failed with exit code {exit_code}"
|
1611
|
-
if hasattr(ExitCode, exit_code.name):
|
1612
|
-
error_msg = f"{error_msg} ({exit_code.name})"
|
1613
|
-
raise CLIError(error_msg, context={"exit_code": exit_code})
|
1614
|
-
|
1615
|
-
except click.UsageError:
|
1616
|
-
# Let Click handle usage errors directly
|
1617
|
-
raise
|
1618
|
-
except InvalidJSONError:
|
1619
|
-
# Let InvalidJSONError propagate directly
|
1620
|
-
raise
|
1621
|
-
except CLIError:
|
1622
|
-
# Let our custom errors propagate with their context
|
1623
|
-
raise
|
1624
|
-
except Exception as e:
|
1625
|
-
# Convert other exceptions to CLIError
|
1626
|
-
logger.exception("Unexpected error")
|
1627
|
-
raise CLIError(str(e), context={"error_type": type(e).__name__})
|
1628
|
-
|
1629
|
-
return cli
|
1907
|
+
return cli # The decorator already returns a Command
|
1630
1908
|
|
1631
1909
|
|
1632
1910
|
def main() -> None:
    """Console entry point: invoke the click group and map errors to exit codes.

    click runs with ``standalone_mode=False`` so exceptions reach this
    function, where they are reported through ``handle_error`` and turned
    into a process exit code.
    """
    try:
        cli(standalone_mode=False)
    except (
        CLIError,
        InvalidJSONError,
        SchemaFileError,
        SchemaValidationError,
    ) as e:
        handle_error(e)
        # Use the error's own exit code when it carries one.
        sys.exit(getattr(e, "exit_code", ExitCode.INTERNAL_ERROR))
    except click.UsageError as e:
        handle_error(e)
        sys.exit(ExitCode.USAGE_ERROR)
    except Exception as e:
        # Anything unrecognised is reported as an internal error.
        handle_error(e)
        sys.exit(ExitCode.INTERNAL_ERROR)
|
1636
1930
|
|
1637
1931
|
|
1638
1932
|
# Export public API
|
1639
1933
|
__all__ = [
|
1640
1934
|
"ExitCode",
|
1641
|
-
"
|
1935
|
+
"estimate_tokens_with_encoding",
|
1642
1936
|
"parse_json_var",
|
1643
1937
|
"create_dynamic_model",
|
1644
1938
|
"validate_path_mapping",
|
@@ -1656,26 +1950,23 @@ def create_dynamic_model(
|
|
1656
1950
|
"""Create a Pydantic model from a JSON schema.
|
1657
1951
|
|
1658
1952
|
Args:
|
1659
|
-
schema: JSON schema
|
1660
|
-
base_name:
|
1661
|
-
show_schema: Whether to show the generated schema
|
1662
|
-
debug_validation: Whether to
|
1953
|
+
schema: JSON schema to create model from
|
1954
|
+
base_name: Name for the model class
|
1955
|
+
show_schema: Whether to show the generated model schema
|
1956
|
+
debug_validation: Whether to show detailed validation errors
|
1663
1957
|
|
1664
1958
|
Returns:
|
1665
|
-
|
1959
|
+
Type[BaseModel]: The generated Pydantic model class
|
1666
1960
|
|
1667
1961
|
Raises:
|
1668
|
-
|
1669
|
-
SchemaValidationError:
|
1962
|
+
ModelValidationError: If the schema is invalid
|
1963
|
+
SchemaValidationError: If the schema violates OpenAI requirements
|
1670
1964
|
"""
|
1671
1965
|
if debug_validation:
|
1672
1966
|
logger.info("Creating dynamic model from schema:")
|
1673
1967
|
logger.info(json.dumps(schema, indent=2))
|
1674
1968
|
|
1675
1969
|
try:
|
1676
|
-
# Extract required fields
|
1677
|
-
required: Set[str] = set(schema.get("required", []))
|
1678
|
-
|
1679
1970
|
# Handle our wrapper format if present
|
1680
1971
|
if "schema" in schema:
|
1681
1972
|
if debug_validation:
|
@@ -1698,32 +1989,15 @@ def create_dynamic_model(
|
|
1698
1989
|
logger.info(json.dumps(inner_schema, indent=2))
|
1699
1990
|
schema = inner_schema
|
1700
1991
|
|
1701
|
-
#
|
1702
|
-
|
1703
|
-
if debug_validation:
|
1704
|
-
logger.info("Schema missing type field, assuming object type")
|
1705
|
-
schema["type"] = "object"
|
1992
|
+
# Validate against OpenAI requirements
|
1993
|
+
from .schema_validation import validate_openai_schema
|
1706
1994
|
|
1707
|
-
|
1708
|
-
if schema["type"] != "object":
|
1709
|
-
if debug_validation:
|
1710
|
-
logger.info(
|
1711
|
-
"Converting non-object root schema to object wrapper"
|
1712
|
-
)
|
1713
|
-
schema = {
|
1714
|
-
"type": "object",
|
1715
|
-
"properties": {"value": schema},
|
1716
|
-
"required": ["value"],
|
1717
|
-
}
|
1995
|
+
validate_openai_schema(schema)
|
1718
1996
|
|
1719
1997
|
# Create model configuration
|
1720
1998
|
config = ConfigDict(
|
1721
1999
|
title=schema.get("title", base_name),
|
1722
|
-
extra=
|
1723
|
-
"forbid"
|
1724
|
-
if schema.get("additionalProperties") is False
|
1725
|
-
else "allow"
|
1726
|
-
),
|
2000
|
+
extra="forbid", # OpenAI requires additionalProperties: false
|
1727
2001
|
validate_default=True,
|
1728
2002
|
use_enum_values=True,
|
1729
2003
|
arbitrary_types_allowed=True,
|
@@ -1758,18 +2032,17 @@ def create_dynamic_model(
|
|
1758
2032
|
" JSON Schema Extra: %s", config.get("json_schema_extra")
|
1759
2033
|
)
|
1760
2034
|
|
1761
|
-
#
|
1762
|
-
field_definitions: Dict[str, FieldDefinition] = {}
|
2035
|
+
# Process schema properties into fields
|
1763
2036
|
properties = schema.get("properties", {})
|
2037
|
+
required = schema.get("required", [])
|
1764
2038
|
|
2039
|
+
field_definitions: Dict[str, Tuple[Type[Any], FieldInfoType]] = {}
|
1765
2040
|
for field_name, field_schema in properties.items():
|
1766
|
-
|
1767
|
-
|
1768
|
-
|
1769
|
-
logger.info(
|
1770
|
-
" Schema: %s", json.dumps(field_schema, indent=2)
|
1771
|
-
)
|
2041
|
+
if debug_validation:
|
2042
|
+
logger.info("Processing field %s:", field_name)
|
2043
|
+
logger.info(" Schema: %s", json.dumps(field_schema, indent=2))
|
1772
2044
|
|
2045
|
+
try:
|
1773
2046
|
python_type, field = _get_type_with_constraints(
|
1774
2047
|
field_schema, field_name, base_name
|
1775
2048
|
)
|
@@ -1804,22 +2077,24 @@ def create_dynamic_model(
|
|
1804
2077
|
raise ModelValidationError(base_name, [str(e)])
|
1805
2078
|
|
1806
2079
|
# Create the model with the fields
|
1807
|
-
|
1808
|
-
|
1809
|
-
|
1810
|
-
|
1811
|
-
|
1812
|
-
|
1813
|
-
|
1814
|
-
|
1815
|
-
|
1816
|
-
|
1817
|
-
|
1818
|
-
|
1819
|
-
|
1820
|
-
},
|
2080
|
+
field_defs: Dict[str, Any] = {
|
2081
|
+
name: (
|
2082
|
+
(
|
2083
|
+
cast(Type[Any], field_type)
|
2084
|
+
if is_container_type(field_type)
|
2085
|
+
else field_type
|
2086
|
+
),
|
2087
|
+
field,
|
2088
|
+
)
|
2089
|
+
for name, (field_type, field) in field_definitions.items()
|
2090
|
+
}
|
2091
|
+
model: Type[BaseModel] = create_model(
|
2092
|
+
base_name, __config__=config, **field_defs
|
1821
2093
|
)
|
1822
2094
|
|
2095
|
+
# Set the model config after creation
|
2096
|
+
model.model_config = config
|
2097
|
+
|
1823
2098
|
if debug_validation:
|
1824
2099
|
logger.info("Successfully created model: %s", model.__name__)
|
1825
2100
|
logger.info("Model config: %s", dict(model.model_config))
|
@@ -1832,28 +2107,38 @@ def create_dynamic_model(
|
|
1832
2107
|
try:
|
1833
2108
|
model.model_json_schema()
|
1834
2109
|
except ValidationError as e:
|
1835
|
-
if debug_validation:
|
1836
|
-
logger.error("Schema validation failed:")
|
1837
|
-
logger.error(" Error type: %s", type(e).__name__)
|
1838
|
-
logger.error(" Error message: %s", str(e))
|
1839
|
-
if hasattr(e, "errors"):
|
1840
|
-
logger.error(" Validation errors:")
|
1841
|
-
for error in e.errors():
|
1842
|
-
logger.error(" - %s", error)
|
1843
2110
|
validation_errors = (
|
1844
2111
|
[str(err) for err in e.errors()]
|
1845
2112
|
if hasattr(e, "errors")
|
1846
2113
|
else [str(e)]
|
1847
2114
|
)
|
2115
|
+
if debug_validation:
|
2116
|
+
logger.error("Schema validation failed:")
|
2117
|
+
logger.error(" Error type: %s", type(e).__name__)
|
2118
|
+
logger.error(" Error message: %s", str(e))
|
1848
2119
|
raise ModelValidationError(base_name, validation_errors)
|
1849
2120
|
|
1850
|
-
return
|
2121
|
+
return model
|
2122
|
+
|
2123
|
+
except SchemaValidationError as e:
|
2124
|
+
# Always log basic error info
|
2125
|
+
logger.error("Schema validation error: %s", str(e))
|
2126
|
+
|
2127
|
+
# Log additional debug info if requested
|
2128
|
+
if debug_validation:
|
2129
|
+
logger.error(" Error type: %s", type(e).__name__)
|
2130
|
+
logger.error(" Error details: %s", str(e))
|
2131
|
+
# Always raise schema validation errors directly
|
2132
|
+
raise
|
1851
2133
|
|
1852
2134
|
except Exception as e:
|
2135
|
+
# Always log basic error info
|
2136
|
+
logger.error("Model creation error: %s", str(e))
|
2137
|
+
|
2138
|
+
# Log additional debug info if requested
|
1853
2139
|
if debug_validation:
|
1854
|
-
logger.error("Failed to create model:")
|
1855
2140
|
logger.error(" Error type: %s", type(e).__name__)
|
1856
|
-
logger.error(" Error
|
2141
|
+
logger.error(" Error details: %s", str(e))
|
1857
2142
|
if hasattr(e, "__cause__"):
|
1858
2143
|
logger.error(" Caused by: %s", str(e.__cause__))
|
1859
2144
|
if hasattr(e, "__context__"):
|
@@ -1865,9 +2150,11 @@ def create_dynamic_model(
|
|
1865
2150
|
" Traceback:\n%s",
|
1866
2151
|
"".join(traceback.format_tb(e.__traceback__)),
|
1867
2152
|
)
|
2153
|
+
# Always wrap other errors as ModelCreationError
|
1868
2154
|
raise ModelCreationError(
|
1869
|
-
f"Failed to create model
|
1870
|
-
|
2155
|
+
f"Failed to create model {base_name}",
|
2156
|
+
context={"error": str(e)},
|
2157
|
+
) from e
|
1871
2158
|
|
1872
2159
|
|
1873
2160
|
# Validation functions
|