ostruct-cli 0.4.0__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ostruct/cli/base_errors.py +183 -0
- ostruct/cli/cli.py +822 -543
- ostruct/cli/click_options.py +320 -202
- ostruct/cli/errors.py +222 -128
- ostruct/cli/exit_codes.py +18 -0
- ostruct/cli/file_info.py +30 -14
- ostruct/cli/file_list.py +4 -10
- ostruct/cli/file_utils.py +43 -35
- ostruct/cli/path_utils.py +32 -4
- ostruct/cli/security/allowed_checker.py +8 -0
- ostruct/cli/security/base.py +46 -0
- ostruct/cli/security/errors.py +83 -103
- ostruct/cli/security/security_manager.py +22 -9
- ostruct/cli/serialization.py +25 -0
- ostruct/cli/template_filters.py +5 -3
- ostruct/cli/template_rendering.py +46 -22
- ostruct/cli/template_utils.py +12 -4
- ostruct/cli/template_validation.py +26 -8
- ostruct/cli/token_utils.py +43 -0
- ostruct/cli/validators.py +109 -0
- {ostruct_cli-0.4.0.dist-info → ostruct_cli-0.5.0.dist-info}/METADATA +60 -21
- ostruct_cli-0.5.0.dist-info/RECORD +42 -0
- {ostruct_cli-0.4.0.dist-info → ostruct_cli-0.5.0.dist-info}/WHEEL +1 -1
- ostruct_cli-0.4.0.dist-info/RECORD +0 -36
- {ostruct_cli-0.4.0.dist-info → ostruct_cli-0.5.0.dist-info}/LICENSE +0 -0
- {ostruct_cli-0.4.0.dist-info → ostruct_cli-0.5.0.dist-info}/entry_points.txt +0 -0
ostruct/cli/cli.py
CHANGED
@@ -5,10 +5,10 @@ import json
|
|
5
5
|
import logging
|
6
6
|
import os
|
7
7
|
import sys
|
8
|
-
from dataclasses import dataclass
|
9
8
|
from enum import Enum, IntEnum
|
10
9
|
from typing import (
|
11
10
|
Any,
|
11
|
+
AsyncGenerator,
|
12
12
|
Dict,
|
13
13
|
List,
|
14
14
|
Literal,
|
@@ -16,6 +16,7 @@ from typing import (
|
|
16
16
|
Set,
|
17
17
|
Tuple,
|
18
18
|
Type,
|
19
|
+
TypedDict,
|
19
20
|
TypeVar,
|
20
21
|
Union,
|
21
22
|
cast,
|
@@ -31,20 +32,10 @@ from pathlib import Path
|
|
31
32
|
|
32
33
|
import click
|
33
34
|
import jinja2
|
34
|
-
import tiktoken
|
35
35
|
import yaml
|
36
|
-
from openai import
|
37
|
-
APIConnectionError,
|
38
|
-
AsyncOpenAI,
|
39
|
-
AuthenticationError,
|
40
|
-
BadRequestError,
|
41
|
-
InternalServerError,
|
42
|
-
RateLimitError,
|
43
|
-
)
|
36
|
+
from openai import AsyncOpenAI
|
44
37
|
from openai_structured.client import (
|
45
38
|
async_openai_structured_stream,
|
46
|
-
get_context_window_limit,
|
47
|
-
get_default_token_limit,
|
48
39
|
supports_structured_output,
|
49
40
|
)
|
50
41
|
from openai_structured.errors import (
|
@@ -54,12 +45,9 @@ from openai_structured.errors import (
|
|
54
45
|
ModelNotSupportedError,
|
55
46
|
ModelVersionError,
|
56
47
|
OpenAIClientError,
|
57
|
-
SchemaFileError,
|
58
|
-
SchemaValidationError,
|
59
48
|
StreamBufferError,
|
60
|
-
StreamInterruptedError,
|
61
|
-
StreamParseError,
|
62
49
|
)
|
50
|
+
from openai_structured.model_registry import ModelRegistry
|
63
51
|
from pydantic import (
|
64
52
|
AnyUrl,
|
65
53
|
BaseModel,
|
@@ -74,61 +62,60 @@ from pydantic.functional_validators import BeforeValidator
|
|
74
62
|
from pydantic.types import constr
|
75
63
|
from typing_extensions import TypeAlias
|
76
64
|
|
77
|
-
from ostruct.cli.click_options import
|
65
|
+
from ostruct.cli.click_options import all_options
|
66
|
+
from ostruct.cli.exit_codes import ExitCode
|
78
67
|
|
79
68
|
from .. import __version__ # noqa: F401 - Used in package metadata
|
80
69
|
from .errors import (
|
81
70
|
CLIError,
|
82
71
|
DirectoryNotFoundError,
|
83
72
|
FieldDefinitionError,
|
84
|
-
FileNotFoundError,
|
85
73
|
InvalidJSONError,
|
86
74
|
ModelCreationError,
|
87
75
|
ModelValidationError,
|
88
76
|
NestedModelError,
|
77
|
+
OstructFileNotFoundError,
|
89
78
|
PathSecurityError,
|
79
|
+
SchemaFileError,
|
80
|
+
SchemaValidationError,
|
81
|
+
StreamInterruptedError,
|
82
|
+
StreamParseError,
|
90
83
|
TaskTemplateSyntaxError,
|
91
84
|
TaskTemplateVariableError,
|
92
|
-
VariableError,
|
93
85
|
VariableNameError,
|
94
86
|
VariableValueError,
|
95
87
|
)
|
96
|
-
from .file_utils import FileInfoList,
|
88
|
+
from .file_utils import FileInfoList, collect_files
|
97
89
|
from .path_utils import validate_path_mapping
|
98
90
|
from .security import SecurityManager
|
91
|
+
from .serialization import LogSerializer
|
99
92
|
from .template_env import create_jinja_env
|
100
93
|
from .template_utils import SystemPromptError, render_template
|
94
|
+
from .token_utils import estimate_tokens_with_encoding
|
101
95
|
|
102
96
|
# Constants
|
103
97
|
DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant."
|
104
98
|
|
105
99
|
|
106
|
-
|
107
|
-
|
108
|
-
"""Compatibility class to mimic argparse.Namespace for existing code."""
|
100
|
+
class CLIParams(TypedDict, total=False):
|
101
|
+
"""Type-safe CLI parameters."""
|
109
102
|
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
103
|
+
files: List[
|
104
|
+
Tuple[str, str]
|
105
|
+
] # List of (name, path) tuples from Click's nargs=2
|
106
|
+
dir: List[
|
107
|
+
Tuple[str, str]
|
108
|
+
] # List of (name, dir) tuples from Click's nargs=2
|
109
|
+
allowed_dirs: List[str]
|
116
110
|
base_dir: str
|
117
111
|
allowed_dir_file: Optional[str]
|
118
|
-
|
119
|
-
dir_ext: Optional[str]
|
112
|
+
recursive: bool
|
120
113
|
var: List[str]
|
121
114
|
json_var: List[str]
|
122
115
|
system_prompt: Optional[str]
|
123
116
|
system_prompt_file: Optional[str]
|
124
117
|
ignore_task_sysprompt: bool
|
125
|
-
schema_file: str
|
126
118
|
model: str
|
127
|
-
temperature: float
|
128
|
-
max_tokens: Optional[int]
|
129
|
-
top_p: float
|
130
|
-
frequency_penalty: float
|
131
|
-
presence_penalty: float
|
132
119
|
timeout: float
|
133
120
|
output_file: Optional[str]
|
134
121
|
dry_run: bool
|
@@ -138,7 +125,16 @@ class Namespace:
|
|
138
125
|
debug_openai_stream: bool
|
139
126
|
show_model_schema: bool
|
140
127
|
debug_validation: bool
|
141
|
-
|
128
|
+
temperature: Optional[float]
|
129
|
+
max_output_tokens: Optional[int]
|
130
|
+
top_p: Optional[float]
|
131
|
+
frequency_penalty: Optional[float]
|
132
|
+
presence_penalty: Optional[float]
|
133
|
+
reasoning_effort: Optional[str]
|
134
|
+
progress_level: str
|
135
|
+
task_file: Optional[str]
|
136
|
+
task: Optional[str]
|
137
|
+
schema_file: str
|
142
138
|
|
143
139
|
|
144
140
|
# Set up logging
|
@@ -176,45 +172,6 @@ ostruct_file_handler.setFormatter(
|
|
176
172
|
logger.addHandler(ostruct_file_handler)
|
177
173
|
|
178
174
|
|
179
|
-
class ExitCode(IntEnum):
|
180
|
-
"""Exit codes for the CLI following standard Unix conventions.
|
181
|
-
|
182
|
-
Categories:
|
183
|
-
- Success (0-1)
|
184
|
-
- User Interruption (2-3)
|
185
|
-
- Input/Validation (64-69)
|
186
|
-
- I/O and File Access (70-79)
|
187
|
-
- API and External Services (80-89)
|
188
|
-
- Internal Errors (90-99)
|
189
|
-
"""
|
190
|
-
|
191
|
-
# Success codes
|
192
|
-
SUCCESS = 0
|
193
|
-
|
194
|
-
# User interruption
|
195
|
-
INTERRUPTED = 2
|
196
|
-
|
197
|
-
# Input/Validation errors (64-69)
|
198
|
-
USAGE_ERROR = 64
|
199
|
-
DATA_ERROR = 65
|
200
|
-
SCHEMA_ERROR = 66
|
201
|
-
VALIDATION_ERROR = 67
|
202
|
-
|
203
|
-
# I/O and File Access errors (70-79)
|
204
|
-
IO_ERROR = 70
|
205
|
-
FILE_NOT_FOUND = 71
|
206
|
-
PERMISSION_ERROR = 72
|
207
|
-
SECURITY_ERROR = 73
|
208
|
-
|
209
|
-
# API and External Service errors (80-89)
|
210
|
-
API_ERROR = 80
|
211
|
-
API_TIMEOUT = 81
|
212
|
-
|
213
|
-
# Internal errors (90-99)
|
214
|
-
INTERNAL_ERROR = 90
|
215
|
-
UNKNOWN_ERROR = 91
|
216
|
-
|
217
|
-
|
218
175
|
# Type aliases
|
219
176
|
FieldType = (
|
220
177
|
Any # Changed from Type[Any] to allow both concrete types and generics
|
@@ -281,7 +238,7 @@ def _get_type_with_constraints(
|
|
281
238
|
show_schema=False,
|
282
239
|
debug_validation=False,
|
283
240
|
)
|
284
|
-
array_type: Type[List[Any]] = List[array_item_model] # type: ignore
|
241
|
+
array_type: Type[List[Any]] = List[array_item_model] # type: ignore
|
285
242
|
return (array_type, Field(**field_kwargs))
|
286
243
|
|
287
244
|
# For non-object items, use the type directly
|
@@ -403,64 +360,17 @@ K = TypeVar("K")
|
|
403
360
|
V = TypeVar("V")
|
404
361
|
|
405
362
|
|
406
|
-
def estimate_tokens_for_chat(
|
407
|
-
messages: List[Dict[str, str]],
|
408
|
-
model: str,
|
409
|
-
encoder: Any = None,
|
410
|
-
) -> int:
|
411
|
-
"""Estimate the number of tokens in a chat completion.
|
412
|
-
|
413
|
-
Args:
|
414
|
-
messages: List of chat messages
|
415
|
-
model: Model name
|
416
|
-
encoder: Optional tiktoken encoder for testing. If provided, only uses encoder.encode() results.
|
417
|
-
"""
|
418
|
-
if encoder is None:
|
419
|
-
try:
|
420
|
-
# Try to get the encoding for the specific model
|
421
|
-
encoder = tiktoken.get_encoding("o200k_base")
|
422
|
-
except KeyError:
|
423
|
-
# Fall back to cl100k_base for unknown models
|
424
|
-
encoder = tiktoken.get_encoding("cl100k_base")
|
425
|
-
|
426
|
-
# Use standard token counting logic for real tiktoken encoders
|
427
|
-
num_tokens = 0
|
428
|
-
for message in messages:
|
429
|
-
# Add message overhead
|
430
|
-
num_tokens += 4 # every message follows <im_start>{role/name}\n{content}<im_end>\n
|
431
|
-
for key, value in message.items():
|
432
|
-
num_tokens += len(encoder.encode(str(value)))
|
433
|
-
if key == "name": # if there's a name, the role is omitted
|
434
|
-
num_tokens -= 1 # role is omitted
|
435
|
-
num_tokens += 2 # every reply is primed with <im_start>assistant
|
436
|
-
return num_tokens
|
437
|
-
else:
|
438
|
-
# For mock encoders in tests, just return the length of encoded content
|
439
|
-
num_tokens = 0
|
440
|
-
for message in messages:
|
441
|
-
for value in message.values():
|
442
|
-
num_tokens += len(encoder.encode(str(value)))
|
443
|
-
return num_tokens
|
444
|
-
|
445
|
-
|
446
363
|
def validate_token_limits(
|
447
364
|
model: str, total_tokens: int, max_token_limit: Optional[int] = None
|
448
365
|
) -> None:
|
449
|
-
"""Validate token counts against model limits.
|
450
|
-
|
451
|
-
|
452
|
-
|
453
|
-
total_tokens: Total number of tokens in the prompt
|
454
|
-
max_token_limit: Optional user-specified token limit
|
455
|
-
|
456
|
-
Raises:
|
457
|
-
ValueError: If token limits are exceeded
|
458
|
-
"""
|
459
|
-
context_limit = get_context_window_limit(model)
|
366
|
+
"""Validate token counts against model limits."""
|
367
|
+
registry = ModelRegistry()
|
368
|
+
capabilities = registry.get_capabilities(model)
|
369
|
+
context_limit = capabilities.context_window
|
460
370
|
output_limit = (
|
461
371
|
max_token_limit
|
462
372
|
if max_token_limit is not None
|
463
|
-
else
|
373
|
+
else capabilities.max_output_tokens
|
464
374
|
)
|
465
375
|
|
466
376
|
# Check if total tokens exceed context window
|
@@ -522,8 +432,12 @@ def process_system_prompt(
|
|
522
432
|
)
|
523
433
|
with open(path, "r", encoding="utf-8") as f:
|
524
434
|
system_prompt = f.read().strip()
|
525
|
-
except
|
526
|
-
raise SystemPromptError(
|
435
|
+
except OstructFileNotFoundError as e:
|
436
|
+
raise SystemPromptError(
|
437
|
+
f"Failed to load system prompt file: {e}"
|
438
|
+
) from e
|
439
|
+
except PathSecurityError as e:
|
440
|
+
raise SystemPromptError(f"Invalid system prompt file: {e}") from e
|
527
441
|
|
528
442
|
if system_prompt is not None:
|
529
443
|
# Render system prompt with template context
|
@@ -591,7 +505,8 @@ def validate_variable_mapping(
|
|
591
505
|
value = json.loads(value)
|
592
506
|
except json.JSONDecodeError as e:
|
593
507
|
raise InvalidJSONError(
|
594
|
-
f"Invalid JSON value for variable {name!r}: {value!r}"
|
508
|
+
f"Invalid JSON value for variable {name!r}: {value!r}",
|
509
|
+
context={"variable_name": name},
|
595
510
|
) from e
|
596
511
|
|
597
512
|
return name, value
|
@@ -787,11 +702,20 @@ def validate_task_template(
|
|
787
702
|
template_content: str
|
788
703
|
if task_file is not None:
|
789
704
|
try:
|
790
|
-
|
791
|
-
with open(path, "r", encoding="utf-8") as f:
|
705
|
+
with open(task_file, "r", encoding="utf-8") as f:
|
792
706
|
template_content = f.read()
|
793
|
-
except
|
794
|
-
raise TaskTemplateVariableError(
|
707
|
+
except FileNotFoundError:
|
708
|
+
raise TaskTemplateVariableError(
|
709
|
+
f"Task template file not found: {task_file}"
|
710
|
+
)
|
711
|
+
except PermissionError:
|
712
|
+
raise TaskTemplateVariableError(
|
713
|
+
f"Permission denied reading task template file: {task_file}"
|
714
|
+
)
|
715
|
+
except Exception as e:
|
716
|
+
raise TaskTemplateVariableError(
|
717
|
+
f"Error reading task template file: {e}"
|
718
|
+
)
|
795
719
|
else:
|
796
720
|
template_content = task # type: ignore # We know task is str here due to the checks above
|
797
721
|
|
@@ -809,10 +733,10 @@ def validate_schema_file(
|
|
809
733
|
path: str,
|
810
734
|
verbose: bool = False,
|
811
735
|
) -> Dict[str, Any]:
|
812
|
-
"""Validate a JSON schema file.
|
736
|
+
"""Validate and load a JSON schema file.
|
813
737
|
|
814
738
|
Args:
|
815
|
-
path: Path to
|
739
|
+
path: Path to schema file
|
816
740
|
verbose: Whether to enable verbose logging
|
817
741
|
|
818
742
|
Returns:
|
@@ -827,14 +751,42 @@ def validate_schema_file(
|
|
827
751
|
logger.info("Validating schema file: %s", path)
|
828
752
|
|
829
753
|
try:
|
830
|
-
|
831
|
-
|
754
|
+
logger.debug("Opening schema file: %s", path)
|
755
|
+
with open(path, "r", encoding="utf-8") as f:
|
756
|
+
logger.debug("Loading JSON from schema file")
|
757
|
+
try:
|
758
|
+
schema = json.load(f)
|
759
|
+
logger.debug(
|
760
|
+
"Successfully loaded JSON: %s",
|
761
|
+
json.dumps(schema, indent=2),
|
762
|
+
)
|
763
|
+
except json.JSONDecodeError as e:
|
764
|
+
logger.error("JSON decode error in %s: %s", path, str(e))
|
765
|
+
logger.debug(
|
766
|
+
"Error details - line: %d, col: %d, msg: %s",
|
767
|
+
e.lineno,
|
768
|
+
e.colno,
|
769
|
+
e.msg,
|
770
|
+
)
|
771
|
+
raise InvalidJSONError(
|
772
|
+
f"Invalid JSON in schema file {path}: {e}",
|
773
|
+
context={"schema_path": path},
|
774
|
+
) from e
|
832
775
|
except FileNotFoundError:
|
833
|
-
|
834
|
-
|
835
|
-
raise
|
776
|
+
msg = f"Schema file not found: {path}"
|
777
|
+
logger.error(msg)
|
778
|
+
raise SchemaFileError(msg, schema_path=path)
|
779
|
+
except PermissionError:
|
780
|
+
msg = f"Permission denied reading schema file: {path}"
|
781
|
+
logger.error(msg)
|
782
|
+
raise SchemaFileError(msg, schema_path=path)
|
836
783
|
except Exception as e:
|
837
|
-
|
784
|
+
if isinstance(e, InvalidJSONError):
|
785
|
+
raise
|
786
|
+
msg = f"Failed to read schema file {path}: {e}"
|
787
|
+
logger.error(msg)
|
788
|
+
logger.debug("Unexpected error details: %s", str(e))
|
789
|
+
raise SchemaFileError(msg, schema_path=path) from e
|
838
790
|
|
839
791
|
# Pre-validation structure checks
|
840
792
|
if verbose:
|
@@ -842,11 +794,9 @@ def validate_schema_file(
|
|
842
794
|
logger.debug("Loaded schema: %s", json.dumps(schema, indent=2))
|
843
795
|
|
844
796
|
if not isinstance(schema, dict):
|
845
|
-
|
846
|
-
|
847
|
-
|
848
|
-
)
|
849
|
-
raise SchemaValidationError("Schema must be a JSON object")
|
797
|
+
msg = f"Schema in {path} must be a JSON object"
|
798
|
+
logger.error(msg)
|
799
|
+
raise SchemaValidationError(msg, schema_path=path)
|
850
800
|
|
851
801
|
# Validate schema structure
|
852
802
|
if "schema" in schema:
|
@@ -854,30 +804,37 @@ def validate_schema_file(
|
|
854
804
|
logger.debug("Found schema wrapper, validating inner schema")
|
855
805
|
inner_schema = schema["schema"]
|
856
806
|
if not isinstance(inner_schema, dict):
|
857
|
-
|
858
|
-
|
859
|
-
|
860
|
-
type(inner_schema).__name__,
|
861
|
-
)
|
862
|
-
raise SchemaValidationError("Inner schema must be a JSON object")
|
807
|
+
msg = f"Inner schema in {path} must be a JSON object"
|
808
|
+
logger.error(msg)
|
809
|
+
raise SchemaValidationError(msg, schema_path=path)
|
863
810
|
if verbose:
|
864
811
|
logger.debug("Inner schema validated successfully")
|
812
|
+
logger.debug(
|
813
|
+
"Inner schema: %s", json.dumps(inner_schema, indent=2)
|
814
|
+
)
|
865
815
|
else:
|
866
816
|
if verbose:
|
867
817
|
logger.debug("No schema wrapper found, using schema as-is")
|
818
|
+
logger.debug("Schema: %s", json.dumps(schema, indent=2))
|
819
|
+
|
820
|
+
# Additional schema validation
|
821
|
+
if "type" not in schema.get("schema", schema):
|
822
|
+
msg = f"Schema in {path} must specify a type"
|
823
|
+
logger.error(msg)
|
824
|
+
raise SchemaValidationError(msg, schema_path=path)
|
868
825
|
|
869
826
|
# Return the full schema including wrapper
|
870
827
|
return schema
|
871
828
|
|
872
829
|
|
873
830
|
def collect_template_files(
|
874
|
-
args:
|
831
|
+
args: CLIParams,
|
875
832
|
security_manager: SecurityManager,
|
876
|
-
) -> Dict[str,
|
833
|
+
) -> Dict[str, Union[FileInfoList, str, List[str], Dict[str, str]]]:
|
877
834
|
"""Collect files from command line arguments.
|
878
835
|
|
879
836
|
Args:
|
880
|
-
args:
|
837
|
+
args: Command line arguments
|
881
838
|
security_manager: Security manager for path validation
|
882
839
|
|
883
840
|
Returns:
|
@@ -888,15 +845,29 @@ def collect_template_files(
|
|
888
845
|
ValueError: If file mappings are invalid or files cannot be accessed
|
889
846
|
"""
|
890
847
|
try:
|
891
|
-
|
892
|
-
|
893
|
-
|
894
|
-
|
895
|
-
|
896
|
-
|
848
|
+
# Get files and directories from args - they are already tuples from Click's nargs=2
|
849
|
+
files = list(
|
850
|
+
args.get("files", [])
|
851
|
+
) # List of (name, path) tuples from Click
|
852
|
+
dirs = args.get("dir", []) # List of (name, dir) tuples from Click
|
853
|
+
|
854
|
+
# Collect files from directories
|
855
|
+
dir_files = collect_files(
|
856
|
+
file_mappings=cast(
|
857
|
+
List[Tuple[str, Union[str, Path]]], files
|
858
|
+
), # Cast to correct type
|
859
|
+
dir_mappings=cast(
|
860
|
+
List[Tuple[str, Union[str, Path]]], dirs
|
861
|
+
), # Cast to correct type
|
862
|
+
dir_recursive=args.get("recursive", False),
|
897
863
|
security_manager=security_manager,
|
898
864
|
)
|
899
|
-
|
865
|
+
|
866
|
+
# Combine results
|
867
|
+
return cast(
|
868
|
+
Dict[str, Union[FileInfoList, str, List[str], Dict[str, str]]],
|
869
|
+
dir_files,
|
870
|
+
)
|
900
871
|
except PathSecurityError:
|
901
872
|
# Let PathSecurityError propagate without wrapping
|
902
873
|
raise
|
@@ -914,11 +885,11 @@ def collect_template_files(
|
|
914
885
|
raise ValueError(f"Error collecting files: {e}")
|
915
886
|
|
916
887
|
|
917
|
-
def collect_simple_variables(args:
|
888
|
+
def collect_simple_variables(args: CLIParams) -> Dict[str, str]:
|
918
889
|
"""Collect simple string variables from --var arguments.
|
919
890
|
|
920
891
|
Args:
|
921
|
-
args:
|
892
|
+
args: Command line arguments
|
922
893
|
|
923
894
|
Returns:
|
924
895
|
Dictionary mapping variable names to string values
|
@@ -929,10 +900,15 @@ def collect_simple_variables(args: Namespace) -> Dict[str, str]:
|
|
929
900
|
variables: Dict[str, str] = {}
|
930
901
|
all_names: Set[str] = set()
|
931
902
|
|
932
|
-
if args.var:
|
933
|
-
for mapping in args
|
903
|
+
if args.get("var"):
|
904
|
+
for mapping in args["var"]:
|
934
905
|
try:
|
935
|
-
|
906
|
+
# Handle both tuple format and string format
|
907
|
+
if isinstance(mapping, tuple):
|
908
|
+
name, value = mapping
|
909
|
+
else:
|
910
|
+
name, value = mapping.split("=", 1)
|
911
|
+
|
936
912
|
if not name.isidentifier():
|
937
913
|
raise VariableNameError(f"Invalid variable name: {name}")
|
938
914
|
if name in all_names:
|
@@ -947,11 +923,11 @@ def collect_simple_variables(args: Namespace) -> Dict[str, str]:
|
|
947
923
|
return variables
|
948
924
|
|
949
925
|
|
950
|
-
def collect_json_variables(args:
|
926
|
+
def collect_json_variables(args: CLIParams) -> Dict[str, Any]:
|
951
927
|
"""Collect JSON variables from --json-var arguments.
|
952
928
|
|
953
929
|
Args:
|
954
|
-
args:
|
930
|
+
args: Command line arguments
|
955
931
|
|
956
932
|
Returns:
|
957
933
|
Dictionary mapping variable names to parsed JSON values
|
@@ -963,32 +939,46 @@ def collect_json_variables(args: Namespace) -> Dict[str, Any]:
|
|
963
939
|
variables: Dict[str, Any] = {}
|
964
940
|
all_names: Set[str] = set()
|
965
941
|
|
966
|
-
if args.json_var:
|
967
|
-
for mapping in args
|
942
|
+
if args.get("json_var"):
|
943
|
+
for mapping in args["json_var"]:
|
968
944
|
try:
|
969
|
-
|
945
|
+
# Handle both tuple format and string format
|
946
|
+
if isinstance(mapping, tuple):
|
947
|
+
name, value = (
|
948
|
+
mapping # Value is already parsed by Click validator
|
949
|
+
)
|
950
|
+
else:
|
951
|
+
try:
|
952
|
+
name, json_str = mapping.split("=", 1)
|
953
|
+
except ValueError:
|
954
|
+
raise VariableNameError(
|
955
|
+
f"Invalid JSON variable mapping format: {mapping}. Expected name=json"
|
956
|
+
)
|
957
|
+
try:
|
958
|
+
value = json.loads(json_str)
|
959
|
+
except json.JSONDecodeError as e:
|
960
|
+
raise InvalidJSONError(
|
961
|
+
f"Invalid JSON value for variable '{name}': {json_str}",
|
962
|
+
context={"variable_name": name},
|
963
|
+
) from e
|
964
|
+
|
970
965
|
if not name.isidentifier():
|
971
966
|
raise VariableNameError(f"Invalid variable name: {name}")
|
972
967
|
if name in all_names:
|
973
968
|
raise VariableNameError(f"Duplicate variable name: {name}")
|
974
|
-
|
975
|
-
|
976
|
-
|
977
|
-
|
978
|
-
|
979
|
-
raise InvalidJSONError(
|
980
|
-
f"Error parsing JSON for variable '{name}': {str(e)}. Input was: {json_str}"
|
981
|
-
)
|
982
|
-
except ValueError:
|
983
|
-
raise VariableNameError(
|
984
|
-
f"Invalid JSON variable mapping format: {mapping}. Expected name=json"
|
985
|
-
)
|
969
|
+
|
970
|
+
variables[name] = value
|
971
|
+
all_names.add(name)
|
972
|
+
except (VariableNameError, InvalidJSONError):
|
973
|
+
raise
|
986
974
|
|
987
975
|
return variables
|
988
976
|
|
989
977
|
|
990
978
|
def create_template_context(
|
991
|
-
files: Optional[
|
979
|
+
files: Optional[
|
980
|
+
Dict[str, Union[FileInfoList, str, List[str], Dict[str, str]]]
|
981
|
+
] = None,
|
992
982
|
variables: Optional[Dict[str, str]] = None,
|
993
983
|
json_variables: Optional[Dict[str, Any]] = None,
|
994
984
|
security_manager: Optional[SecurityManager] = None,
|
@@ -1032,14 +1022,14 @@ def create_template_context(
|
|
1032
1022
|
return context
|
1033
1023
|
|
1034
1024
|
|
1035
|
-
def create_template_context_from_args(
|
1036
|
-
args:
|
1025
|
+
async def create_template_context_from_args(
|
1026
|
+
args: CLIParams,
|
1037
1027
|
security_manager: SecurityManager,
|
1038
1028
|
) -> Dict[str, Any]:
|
1039
1029
|
"""Create template context from command line arguments.
|
1040
1030
|
|
1041
1031
|
Args:
|
1042
|
-
args:
|
1032
|
+
args: Command line arguments
|
1043
1033
|
security_manager: Security manager for path validation
|
1044
1034
|
|
1045
1035
|
Returns:
|
@@ -1052,50 +1042,13 @@ def create_template_context_from_args(
|
|
1052
1042
|
"""
|
1053
1043
|
try:
|
1054
1044
|
# Collect files from arguments
|
1055
|
-
files =
|
1056
|
-
if any([args.file, args.files, args.dir]):
|
1057
|
-
files = collect_files(
|
1058
|
-
file_mappings=args.file,
|
1059
|
-
pattern_mappings=args.files,
|
1060
|
-
dir_mappings=args.dir,
|
1061
|
-
dir_recursive=args.dir_recursive,
|
1062
|
-
dir_extensions=(
|
1063
|
-
args.dir_ext.split(",") if args.dir_ext else None
|
1064
|
-
),
|
1065
|
-
security_manager=security_manager,
|
1066
|
-
)
|
1045
|
+
files = collect_template_files(args, security_manager)
|
1067
1046
|
|
1068
1047
|
# Collect simple variables
|
1069
|
-
|
1070
|
-
variables = collect_simple_variables(args)
|
1071
|
-
except VariableNameError as e:
|
1072
|
-
raise VariableError(str(e))
|
1048
|
+
variables = collect_simple_variables(args)
|
1073
1049
|
|
1074
1050
|
# Collect JSON variables
|
1075
|
-
json_variables =
|
1076
|
-
if args.json_var:
|
1077
|
-
for mapping in args.json_var:
|
1078
|
-
try:
|
1079
|
-
name, value = mapping.split("=", 1)
|
1080
|
-
if not name.isidentifier():
|
1081
|
-
raise VariableNameError(
|
1082
|
-
f"Invalid variable name: {name}"
|
1083
|
-
)
|
1084
|
-
try:
|
1085
|
-
json_value = json.loads(value)
|
1086
|
-
except json.JSONDecodeError as e:
|
1087
|
-
raise InvalidJSONError(
|
1088
|
-
f"Error parsing JSON for variable '{name}': {str(e)}. Input was: {value}"
|
1089
|
-
)
|
1090
|
-
if name in json_variables:
|
1091
|
-
raise VariableNameError(
|
1092
|
-
f"Duplicate variable name: {name}"
|
1093
|
-
)
|
1094
|
-
json_variables[name] = json_value
|
1095
|
-
except ValueError:
|
1096
|
-
raise VariableNameError(
|
1097
|
-
f"Invalid JSON variable mapping format: {mapping}. Expected name=json"
|
1098
|
-
)
|
1051
|
+
json_variables = collect_json_variables(args)
|
1099
1052
|
|
1100
1053
|
# Get stdin content if available
|
1101
1054
|
stdin_content = None
|
@@ -1106,7 +1059,7 @@ def create_template_context_from_args(
|
|
1106
1059
|
# Skip stdin if it can't be read
|
1107
1060
|
pass
|
1108
1061
|
|
1109
|
-
|
1062
|
+
context = create_template_context(
|
1110
1063
|
files=files,
|
1111
1064
|
variables=variables,
|
1112
1065
|
json_variables=json_variables,
|
@@ -1114,6 +1067,11 @@ def create_template_context_from_args(
|
|
1114
1067
|
stdin_content=stdin_content,
|
1115
1068
|
)
|
1116
1069
|
|
1070
|
+
# Add current model to context
|
1071
|
+
context["current_model"] = args["model"]
|
1072
|
+
|
1073
|
+
return context
|
1074
|
+
|
1117
1075
|
except PathSecurityError:
|
1118
1076
|
# Let PathSecurityError propagate without wrapping
|
1119
1077
|
raise
|
@@ -1235,7 +1193,8 @@ def parse_json_var(var_str: str) -> Tuple[str, Any]:
|
|
1235
1193
|
value = json.loads(json_str)
|
1236
1194
|
except json.JSONDecodeError as e:
|
1237
1195
|
raise InvalidJSONError(
|
1238
|
-
f"Error parsing JSON for variable '{name}': {str(e)}. Input was: {json_str}"
|
1196
|
+
f"Error parsing JSON for variable '{name}': {str(e)}. Input was: {json_str}",
|
1197
|
+
context={"variable_name": name},
|
1239
1198
|
)
|
1240
1199
|
|
1241
1200
|
return name, value
|
@@ -1284,41 +1243,96 @@ def _create_enum_type(values: List[Any], field_name: str) -> Type[Enum]:
|
|
1284
1243
|
|
1285
1244
|
|
1286
1245
|
def handle_error(e: Exception) -> None:
|
1287
|
-
"""Handle errors
|
1246
|
+
"""Handle CLI errors and display appropriate messages.
|
1247
|
+
|
1248
|
+
Maintains specific error type handling while reducing duplication.
|
1249
|
+
Provides enhanced debug logging for CLI errors.
|
1250
|
+
"""
|
1251
|
+
# 1. Determine error type and message
|
1288
1252
|
if isinstance(e, click.UsageError):
|
1289
|
-
|
1290
|
-
|
1291
|
-
|
1292
|
-
|
1293
|
-
|
1294
|
-
|
1295
|
-
|
1296
|
-
|
1297
|
-
|
1298
|
-
elif isinstance(e,
|
1299
|
-
|
1300
|
-
|
1301
|
-
click.secho(msg, fg="red", err=True)
|
1302
|
-
sys.exit(ExitCode.DATA_ERROR)
|
1303
|
-
elif isinstance(e, FileNotFoundError):
|
1304
|
-
# Use the original error message if available
|
1305
|
-
msg = str(e) if str(e) != "None" else "File not found"
|
1306
|
-
click.secho(msg, fg="red", err=True)
|
1307
|
-
sys.exit(ExitCode.SCHEMA_ERROR)
|
1308
|
-
elif isinstance(e, TaskTemplateSyntaxError):
|
1309
|
-
# Use the original error message if available
|
1310
|
-
msg = str(e) if str(e) != "None" else "Template syntax error"
|
1311
|
-
click.secho(msg, fg="red", err=True)
|
1312
|
-
sys.exit(ExitCode.INTERNAL_ERROR)
|
1253
|
+
msg = f"Usage error: {str(e)}"
|
1254
|
+
exit_code = ExitCode.USAGE_ERROR
|
1255
|
+
elif isinstance(e, SchemaFileError):
|
1256
|
+
# Preserve specific schema error handling
|
1257
|
+
msg = str(e) # Use existing __str__ formatting
|
1258
|
+
exit_code = ExitCode.SCHEMA_ERROR
|
1259
|
+
elif isinstance(e, (InvalidJSONError, json.JSONDecodeError)):
|
1260
|
+
msg = f"Invalid JSON error: {str(e)}"
|
1261
|
+
exit_code = ExitCode.DATA_ERROR
|
1262
|
+
elif isinstance(e, SchemaValidationError):
|
1263
|
+
msg = f"Schema validation error: {str(e)}"
|
1264
|
+
exit_code = ExitCode.VALIDATION_ERROR
|
1313
1265
|
elif isinstance(e, CLIError):
|
1314
|
-
# Use
|
1315
|
-
e.
|
1316
|
-
|
1317
|
-
|
1266
|
+
msg = str(e) # Use existing __str__ formatting
|
1267
|
+
exit_code = ExitCode(e.exit_code) # Convert int to ExitCode
|
1268
|
+
else:
|
1269
|
+
msg = f"Unexpected error: {str(e)}"
|
1270
|
+
exit_code = ExitCode.INTERNAL_ERROR
|
1271
|
+
|
1272
|
+
# 2. Debug logging
|
1273
|
+
if isinstance(e, CLIError) and logger.isEnabledFor(logging.DEBUG):
|
1274
|
+
# Format context fields with lowercase keys and simple values
|
1275
|
+
context_str = ""
|
1276
|
+
if hasattr(e, "context"):
|
1277
|
+
for key, value in sorted(e.context.items()):
|
1278
|
+
if key not in {
|
1279
|
+
"timestamp",
|
1280
|
+
"host",
|
1281
|
+
"version",
|
1282
|
+
"python_version",
|
1283
|
+
}:
|
1284
|
+
context_str += f"{key.lower()}: {value}\n"
|
1285
|
+
|
1286
|
+
logger.debug(
|
1287
|
+
"Error details:\n"
|
1288
|
+
f"Type: {type(e).__name__}\n"
|
1289
|
+
f"{context_str.rstrip()}"
|
1318
1290
|
)
|
1291
|
+
elif not isinstance(e, click.UsageError):
|
1292
|
+
logger.error(msg, exc_info=True)
|
1319
1293
|
else:
|
1320
|
-
|
1321
|
-
|
1294
|
+
logger.error(msg)
|
1295
|
+
|
1296
|
+
# 3. User output
|
1297
|
+
click.secho(msg, fg="red", err=True)
|
1298
|
+
sys.exit(exit_code)
|
1299
|
+
|
1300
|
+
|
1301
|
+
def validate_model_parameters(model: str, params: Dict[str, Any]) -> None:
|
1302
|
+
"""Validate model parameters against model capabilities.
|
1303
|
+
|
1304
|
+
Args:
|
1305
|
+
model: The model name to validate parameters for
|
1306
|
+
params: Dictionary of parameter names and values to validate
|
1307
|
+
|
1308
|
+
Raises:
|
1309
|
+
CLIError: If any parameters are not supported by the model
|
1310
|
+
"""
|
1311
|
+
try:
|
1312
|
+
capabilities = ModelRegistry().get_capabilities(model)
|
1313
|
+
for param_name, value in params.items():
|
1314
|
+
try:
|
1315
|
+
capabilities.validate_parameter(param_name, value)
|
1316
|
+
except OpenAIClientError as e:
|
1317
|
+
logger.error(
|
1318
|
+
"Validation failed for model %s: %s", model, str(e)
|
1319
|
+
)
|
1320
|
+
raise CLIError(
|
1321
|
+
str(e),
|
1322
|
+
exit_code=ExitCode.VALIDATION_ERROR,
|
1323
|
+
context={
|
1324
|
+
"model": model,
|
1325
|
+
"param": param_name,
|
1326
|
+
"value": value,
|
1327
|
+
},
|
1328
|
+
)
|
1329
|
+
except (ModelNotSupportedError, ModelVersionError) as e:
|
1330
|
+
logger.error("Model validation failed: %s", str(e))
|
1331
|
+
raise CLIError(
|
1332
|
+
str(e),
|
1333
|
+
exit_code=ExitCode.VALIDATION_ERROR,
|
1334
|
+
context={"model": model},
|
1335
|
+
)
|
1322
1336
|
|
1323
1337
|
|
1324
1338
|
async def stream_structured_output(
|
@@ -1329,85 +1343,75 @@ async def stream_structured_output(
|
|
1329
1343
|
output_schema: Type[BaseModel],
|
1330
1344
|
output_file: Optional[str] = None,
|
1331
1345
|
**kwargs: Any,
|
1332
|
-
) -> None:
|
1346
|
+
) -> AsyncGenerator[BaseModel, None]:
|
1333
1347
|
"""Stream structured output from OpenAI API.
|
1334
1348
|
|
1335
1349
|
This function follows the guide's recommendation for a focused async streaming function.
|
1336
1350
|
It handles the core streaming logic and resource cleanup.
|
1337
|
-
"""
|
1338
|
-
try:
|
1339
|
-
# Base models that don't support streaming
|
1340
|
-
non_streaming_models = {"o1", "o3"}
|
1341
1351
|
|
1342
|
-
|
1343
|
-
|
1344
|
-
|
1345
|
-
|
1346
|
-
|
1352
|
+
Args:
|
1353
|
+
client: The OpenAI client to use
|
1354
|
+
model: The model to use
|
1355
|
+
system_prompt: The system prompt to use
|
1356
|
+
user_prompt: The user prompt to use
|
1357
|
+
output_schema: The Pydantic model to validate responses against
|
1358
|
+
output_file: Optional file to write output to
|
1359
|
+
**kwargs: Additional parameters to pass to the API
|
1347
1360
|
|
1348
|
-
|
1349
|
-
|
1350
|
-
|
1351
|
-
|
1352
|
-
|
1353
|
-
|
1354
|
-
|
1355
|
-
|
1356
|
-
|
1357
|
-
|
1358
|
-
|
1359
|
-
|
1360
|
-
|
1361
|
-
|
1362
|
-
if not chunk:
|
1363
|
-
continue
|
1364
|
-
|
1365
|
-
# Process and output the chunk
|
1366
|
-
dumped = chunk.model_dump(mode="json")
|
1367
|
-
json_str = json.dumps(dumped, indent=2)
|
1368
|
-
|
1369
|
-
if output_file:
|
1370
|
-
with open(output_file, "a", encoding="utf-8") as f:
|
1371
|
-
f.write(json_str)
|
1372
|
-
f.write("\n")
|
1373
|
-
f.flush() # Ensure immediate flush to file
|
1374
|
-
else:
|
1375
|
-
# Print directly to stdout with immediate flush
|
1376
|
-
print(json_str, flush=True)
|
1377
|
-
else:
|
1378
|
-
# For non-streaming models, use regular completion
|
1379
|
-
response = await client.chat.completions.create(
|
1380
|
-
model=model,
|
1381
|
-
messages=[
|
1382
|
-
{"role": "system", "content": system_prompt},
|
1383
|
-
{"role": "user", "content": user_prompt},
|
1384
|
-
],
|
1385
|
-
stream=False,
|
1386
|
-
**stream_kwargs,
|
1361
|
+
Returns:
|
1362
|
+
An async generator yielding validated model instances
|
1363
|
+
|
1364
|
+
Raises:
|
1365
|
+
ValueError: If the model does not support structured output or parameters are invalid
|
1366
|
+
StreamInterruptedError: If the stream is interrupted
|
1367
|
+
APIResponseError: If there is an API error
|
1368
|
+
"""
|
1369
|
+
try:
|
1370
|
+
# Check if model supports structured output using openai_structured's function
|
1371
|
+
if not supports_structured_output(model):
|
1372
|
+
raise ValueError(
|
1373
|
+
f"Model {model} does not support structured output with json_schema response format. "
|
1374
|
+
"Please use a model that supports structured output."
|
1387
1375
|
)
|
1388
1376
|
|
1389
|
-
|
1390
|
-
|
1391
|
-
if content:
|
1392
|
-
try:
|
1393
|
-
# Parse and validate against schema
|
1394
|
-
result = output_schema.model_validate_json(content)
|
1395
|
-
json_str = json.dumps(
|
1396
|
-
result.model_dump(mode="json"), indent=2
|
1397
|
-
)
|
1377
|
+
# Extract non-model parameters
|
1378
|
+
on_log = kwargs.pop("on_log", None)
|
1398
1379
|
|
1399
|
-
|
1400
|
-
|
1401
|
-
|
1402
|
-
|
1403
|
-
|
1404
|
-
|
1405
|
-
|
1406
|
-
|
1407
|
-
|
1408
|
-
|
1380
|
+
# Handle model-specific parameters
|
1381
|
+
stream_kwargs = {}
|
1382
|
+
registry = ModelRegistry()
|
1383
|
+
capabilities = registry.get_capabilities(model)
|
1384
|
+
|
1385
|
+
# Validate and include supported parameters
|
1386
|
+
for param_name, value in kwargs.items():
|
1387
|
+
if param_name in capabilities.supported_parameters:
|
1388
|
+
# Validate the parameter value
|
1389
|
+
capabilities.validate_parameter(param_name, value)
|
1390
|
+
stream_kwargs[param_name] = value
|
1409
1391
|
else:
|
1410
|
-
|
1392
|
+
logger.warning(
|
1393
|
+
f"Parameter {param_name} is not supported by model {model} and will be ignored"
|
1394
|
+
)
|
1395
|
+
|
1396
|
+
# Log the API request details
|
1397
|
+
logger.debug("Making OpenAI API request with:")
|
1398
|
+
logger.debug("Model: %s", model)
|
1399
|
+
logger.debug("System prompt: %s", system_prompt)
|
1400
|
+
logger.debug("User prompt: %s", user_prompt)
|
1401
|
+
logger.debug("Parameters: %s", json.dumps(stream_kwargs, indent=2))
|
1402
|
+
logger.debug("Schema: %s", output_schema.model_json_schema())
|
1403
|
+
|
1404
|
+
# Use the async generator from openai_structured directly
|
1405
|
+
async for chunk in async_openai_structured_stream(
|
1406
|
+
client=client,
|
1407
|
+
model=model,
|
1408
|
+
system_prompt=system_prompt,
|
1409
|
+
user_prompt=user_prompt,
|
1410
|
+
output_schema=output_schema,
|
1411
|
+
on_log=on_log, # Pass non-model parameters directly to the function
|
1412
|
+
**stream_kwargs, # Pass only validated model parameters
|
1413
|
+
):
|
1414
|
+
yield chunk
|
1411
1415
|
|
1412
1416
|
except (
|
1413
1417
|
StreamInterruptedError,
|
@@ -1424,149 +1428,458 @@ async def stream_structured_output(
|
|
1424
1428
|
await client.close()
|
1425
1429
|
|
1426
1430
|
|
1427
|
-
|
1428
|
-
|
1431
|
+
# Root click command group. Subcommands (such as ``run``) register themselves
# on it with ``@cli.command()``; the group itself performs no work.
# The docstring below doubles as the help text click displays for the group.
@click.group()
@click.version_option(version=__version__)
def cli() -> None:
    """ostruct CLI - Make structured OpenAI API calls.

    ostruct allows you to invoke OpenAI Structured Output to produce structured JSON
    output using templates and JSON schemas. It provides support for file handling, variable
    substitution, and output validation.

    For detailed documentation, visit: https://ostruct.readthedocs.io

    Examples:

        # Basic usage with a template and schema

        ostruct run task.j2 schema.json -V name=value

        # Process files with recursive directory scanning

        ostruct run template.j2 schema.json -f code main.py -d src ./src -R

        # Use JSON variables and custom model parameters

        ostruct run task.j2 schema.json -J config='{"env":"prod"}' -m o3-mini
    """
    pass
|
1457
|
+
|
1458
|
+
|
1459
|
+
@cli.command()
@click.argument("task_template", type=click.Path(exists=True))
@click.argument("schema_file", type=click.Path(exists=True))
@all_options
@click.pass_context
def run(
    ctx: click.Context,
    task_template: str,
    schema_file: str,
    **kwargs: Any,
) -> None:
    """Run a structured task with template and schema.

    TASK_TEMPLATE is the path to your Jinja2 template file that defines the task.
    SCHEMA_FILE is the path to your JSON schema file that defines the expected output structure.

    The command supports various options for file handling, variable definition,
    model configuration, and output control. Use --help to see all available options.

    Examples:
        # Basic usage
        ostruct run task.j2 schema.json

        # Process multiple files
        ostruct run task.j2 schema.json -f code main.py -f test tests/test_main.py

        # Scan directories recursively
        ostruct run task.j2 schema.json -d src ./src -R

        # Define variables
        ostruct run task.j2 schema.json -V debug=true -J config='{"env":"prod"}'

        # Configure model
        ostruct run task.j2 schema.json -m gpt-4 --temperature 0.7 --max-output-tokens 1000

        # Control output
        ostruct run task.j2 schema.json --output-file result.json --verbose
    """
    try:
        # Convert Click parameters to typed dict
        params: CLIParams = {
            "task_file": task_template,
            "task": None,
            "schema_file": schema_file,
        }
        # Add only valid keys from kwargs
        valid_keys = set(CLIParams.__annotations__.keys())
        for k, v in kwargs.items():
            if k in valid_keys:
                params[k] = v  # type: ignore[literal-required]

        # Run the async function synchronously on a dedicated event loop
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)
        try:
            exit_code = loop.run_until_complete(run_cli_async(params))
            sys.exit(int(exit_code))
        finally:
            try:
                # Finalize async generators that are still pending (the
                # streaming path is built on one) before the loop goes away.
                loop.run_until_complete(loop.shutdown_asyncgens())
            finally:
                # Do not leave a closed loop installed as the current loop.
                asyncio.set_event_loop(None)
                loop.close()

    except (
        CLIError,
        InvalidJSONError,
        SchemaFileError,
        SchemaValidationError,
    ) as e:
        handle_error(e)
        # Fallback exit in case handle_error returns instead of exiting.
        sys.exit(
            e.exit_code if hasattr(e, "exit_code") else ExitCode.INTERNAL_ERROR
        )
    except click.UsageError as e:
        handle_error(e)
        sys.exit(ExitCode.USAGE_ERROR)
    except Exception as e:
        handle_error(e)
        sys.exit(ExitCode.INTERNAL_ERROR)
|
1535
|
+
|
1536
|
+
|
1537
|
+
# Remove the old @create_click_command() decorator and cli function definition
|
1538
|
+
# Keep all the other functions and code below this point
|
1440
1539
|
|
1441
|
-
|
1442
|
-
|
1443
|
-
|
1444
|
-
|
1445
|
-
|
1540
|
+
|
1541
|
+
async def validate_model_params(args: CLIParams) -> Dict[str, Any]:
    """Collect the model tuning options that were set and validate them.

    Args:
        args: Command line arguments

    Returns:
        Dictionary holding only the tuning options whose value is not None

    Raises:
        CLIError: If model parameters are invalid
    """
    option_names = (
        "temperature",
        "max_output_tokens",
        "top_p",
        "frequency_penalty",
        "presence_penalty",
        "reasoning_effort",
    )
    selected: Dict[str, Any] = {}
    for option in option_names:
        option_value = args.get(option)
        # Options left unset (None) are not forwarded to the model
        if option_value is not None:
            selected[option] = option_value

    validate_model_parameters(args["model"], selected)
    return selected
|
1565
|
+
|
1566
|
+
|
1567
|
+
async def validate_inputs(
    args: CLIParams,
) -> Tuple[
    SecurityManager, str, Dict[str, Any], Dict[str, Any], jinja2.Environment
]:
    """Run the input validation phase and hand back its results.

    Args:
        args: Command line arguments

    Returns:
        Tuple containing, in order:
        - SecurityManager instance
        - Task template string
        - Schema dictionary
        - Template context dictionary
        - Jinja2 environment

    Raises:
        CLIError: For various validation errors
    """
    logger.debug("=== Input Validation Phase ===")

    manager = validate_security_manager(
        base_dir=args.get("base_dir"),
        allowed_dirs=args.get("allowed_dirs"),
        allowed_dir_file=args.get("allowed_dir_file"),
    )

    template_text = validate_task_template(
        args.get("task"), args.get("task_file")
    )

    schema_path = args["schema_file"]
    logger.debug("Validating schema from %s", schema_path)
    schema_dict = validate_schema_file(schema_path, args.get("verbose", False))

    context = await create_template_context_from_args(args, manager)
    jinja_env = create_jinja_env()

    return manager, template_text, schema_dict, context, jinja_env
|
1608
|
+
|
1609
|
+
|
1610
|
+
async def process_templates(
    args: CLIParams,
    task_template: str,
    template_context: Dict[str, Any],
    env: jinja2.Environment,
) -> Tuple[str, str]:
    """Produce the system prompt and the rendered user prompt.

    Args:
        args: Command line arguments
        task_template: Validated task template
        template_context: Template context dictionary
        env: Jinja2 environment

    Returns:
        Tuple of (system_prompt, user_prompt)

    Raises:
        CLIError: For template processing errors
    """
    logger.debug("=== Template Processing Phase ===")

    inline_prompt = args.get("system_prompt")
    prompt_file = args.get("system_prompt_file")
    skip_task_sysprompt = args.get("ignore_task_sysprompt", False)

    rendered_system = process_system_prompt(
        task_template,
        inline_prompt,
        prompt_file,
        template_context,
        env,
        skip_task_sysprompt,
    )
    # The user prompt is the task template rendered with the same context
    rendered_user = render_template(task_template, template_context, env)
    return rendered_system, rendered_user
|
1641
|
+
|
1642
|
+
|
1643
|
+
async def validate_model_and_schema(
    args: CLIParams,
    schema: Dict[str, Any],
    system_prompt: str,
    user_prompt: str,
) -> Tuple[Type[BaseModel], List[Dict[str, str]], int, ModelRegistry]:
    """Build the output model, check model support and enforce token limits.

    Args:
        args: Command line arguments
        schema: Schema dictionary
        system_prompt: Processed system prompt
        user_prompt: Processed user prompt

    Returns:
        Tuple of (output_model, messages, total_tokens, registry)

    Raises:
        CLIError: When the estimated token count exceeds the context window
        ModelNotSupportedError: When the model lacks structured output support
        ModelCreationError: When model creation fails
        SchemaValidationError: When schema is invalid
    """
    logger.debug("=== Model & Schema Validation Phase ===")

    try:
        output_model = create_dynamic_model(
            schema,
            show_schema=args.get("show_model_schema", False),
            debug_validation=args.get("debug_validation", False),
        )
    except (
        SchemaFileError,
        InvalidJSONError,
        SchemaValidationError,
        ModelCreationError,
    ) as e:
        # Log for diagnostics, then let the error propagate unchanged
        logger.error("Schema error: %s", str(e))
        raise
    else:
        logger.debug("Successfully created output model")

    model_name = args["model"]
    if not supports_structured_output(model_name):
        msg = f"Model {model_name} does not support structured output"
        logger.error(msg)
        raise ModelNotSupportedError(msg)

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]

    total_tokens = estimate_tokens_with_encoding(messages, model_name)
    registry = ModelRegistry()
    context_limit = registry.get_capabilities(model_name).context_window

    # Happy path first: the prompts fit into the model's context window
    if not total_tokens > context_limit:
        return output_model, messages, total_tokens, registry

    msg = f"Total tokens ({total_tokens}) exceeds model context limit ({context_limit})"
    logger.error(msg)
    raise CLIError(
        msg,
        context={
            "total_tokens": total_tokens,
            "context_limit": context_limit,
        },
    )
|
1464
1709
|
|
1465
|
-
# Create output model
|
1466
|
-
logger.debug("Creating output model")
|
1467
|
-
try:
|
1468
|
-
output_model = create_dynamic_model(
|
1469
|
-
schema,
|
1470
|
-
base_name="OutputModel",
|
1471
|
-
show_schema=args.show_model_schema,
|
1472
|
-
debug_validation=args.debug_validation,
|
1473
|
-
)
|
1474
|
-
logger.debug("Successfully created output model")
|
1475
|
-
except (
|
1476
|
-
SchemaFileError,
|
1477
|
-
InvalidJSONError,
|
1478
|
-
SchemaValidationError,
|
1479
|
-
ModelCreationError,
|
1480
|
-
) as e:
|
1481
|
-
logger.error("Schema error: %s", str(e))
|
1482
|
-
raise # Let the error propagate with its context
|
1483
|
-
|
1484
|
-
# Validate model support and token usage
|
1485
|
-
try:
|
1486
|
-
supports_structured_output(args.model)
|
1487
|
-
except (ModelNotSupportedError, ModelVersionError) as e:
|
1488
|
-
logger.error("Model validation error: %s", str(e))
|
1489
|
-
raise # Let the error propagate with its context
|
1490
|
-
|
1491
|
-
messages = [
|
1492
|
-
{"role": "system", "content": system_prompt},
|
1493
|
-
{"role": "user", "content": rendered_task},
|
1494
|
-
]
|
1495
|
-
total_tokens = estimate_tokens_for_chat(messages, args.model)
|
1496
|
-
context_limit = get_context_window_limit(args.model)
|
1497
|
-
if total_tokens > context_limit:
|
1498
|
-
msg = f"Total tokens ({total_tokens}) exceeds model context limit ({context_limit})"
|
1499
|
-
logger.error(msg)
|
1500
|
-
raise CLIError(
|
1501
|
-
msg,
|
1502
|
-
context={
|
1503
|
-
"total_tokens": total_tokens,
|
1504
|
-
"context_limit": context_limit,
|
1505
|
-
},
|
1506
|
-
)
|
1507
1710
|
|
1508
|
-
|
1509
|
-
|
1510
|
-
|
1511
|
-
|
1512
|
-
|
1513
|
-
|
1711
|
+
async def execute_model(
    args: CLIParams,
    params: Dict[str, Any],
    output_model: Type[BaseModel],
    system_prompt: str,
    user_prompt: str,
) -> ExitCode:
    """Execute the model and handle the response.

    Streams responses from ``stream_structured_output`` into a buffer, then
    emits them as JSON: a single object when exactly one response was
    received, otherwise a JSON array. Output goes to ``args["output_file"]``
    when set (written as UTF-8), else to stdout. The client is always closed
    before returning or raising.

    Args:
        args: Command line arguments
        params: Validated model parameters
        output_model: Generated Pydantic model
        system_prompt: Processed system prompt
        user_prompt: Processed user prompt

    Returns:
        Exit code indicating success or failure

    Raises:
        CLIError: When no API key is available, for streaming/API errors
            (``ExitCode.API_ERROR``) and for any other failure during
            execution (``ExitCode.UNKNOWN_ERROR``). The original exception
            is attached as ``__cause__``.
    """
    logger.debug("=== Execution Phase ===")
    api_key = args.get("api_key") or os.getenv("OPENAI_API_KEY")
    if not api_key:
        msg = "No API key provided. Set OPENAI_API_KEY environment variable or use --api-key"
        logger.error(msg)
        raise CLIError(msg, exit_code=ExitCode.API_ERROR)

    client = AsyncOpenAI(api_key=api_key, timeout=args.get("timeout", 60.0))

    # Create detailed log callback
    def log_callback(level: int, message: str, extra: dict[str, Any]) -> None:
        # Stream logging is opt-in; everything is emitted at debug level.
        if args.get("debug_openai_stream", False):
            if extra:
                extra_str = LogSerializer.serialize_log_extra(extra)
                if extra_str:
                    logger.debug("%s\nExtra:\n%s", message, extra_str)
                else:
                    logger.debug("%s\nExtra: Failed to serialize", message)
            else:
                logger.debug(message)

    try:
        # Create output buffer
        output_buffer = []

        # Stream the response
        async for response in stream_structured_output(
            client=client,
            model=args["model"],
            system_prompt=system_prompt,
            user_prompt=user_prompt,
            output_schema=output_model,
            output_file=args.get("output_file"),
            **params,  # Only pass validated model parameters
            on_log=log_callback,  # Pass logging callback separately
        ):
            output_buffer.append(response)

        # Serialize once, before touching the destination, so a
        # serialization failure cannot leave a truncated output file.
        if len(output_buffer) == 1:
            json_output = output_buffer[0].model_dump_json(indent=2)
        else:
            # Build complete JSON array as a single string
            entries = [
                " " + item.model_dump_json(indent=2).replace("\n", "\n ")
                for item in output_buffer
            ]
            json_output = "[\n" + ",\n".join(entries) + "\n]"

        # Handle final output
        output_file = args.get("output_file")
        if output_file:
            # Explicit UTF-8 keeps non-ASCII JSON independent of the
            # platform's default encoding.
            with open(output_file, "w", encoding="utf-8") as f:
                f.write(json_output)
        else:
            # Write to stdout when no output file is specified
            print(json_output)

        return ExitCode.SUCCESS

    except (
        StreamInterruptedError,
        StreamBufferError,
        StreamParseError,
        APIResponseError,
        EmptyResponseError,
        InvalidResponseFormatError,
    ) as e:
        logger.error("Stream error: %s", str(e))
        raise CLIError(str(e), exit_code=ExitCode.API_ERROR) from e
    except Exception as e:
        logger.exception("Unexpected error during streaming")
        raise CLIError(str(e), exit_code=ExitCode.UNKNOWN_ERROR) from e
    finally:
        await client.close()
|
1821
|
+
|
1822
|
+
|
1823
|
+
async def run_cli_async(args: CLIParams) -> ExitCode:
|
1824
|
+
"""Async wrapper for CLI operations.
|
1825
|
+
|
1826
|
+
Returns:
|
1827
|
+
Exit code to return from the CLI
|
1828
|
+
|
1829
|
+
Raises:
|
1830
|
+
CLIError: For various error conditions
|
1831
|
+
KeyboardInterrupt: When operation is cancelled by user
|
1832
|
+
"""
|
1833
|
+
try:
|
1834
|
+
# 0. Model Parameter Validation
|
1835
|
+
logger.debug("=== Model Parameter Validation ===")
|
1836
|
+
params = await validate_model_params(args)
|
1837
|
+
|
1838
|
+
# 1. Input Validation Phase
|
1839
|
+
security_manager, task_template, schema, template_context, env = (
|
1840
|
+
await validate_inputs(args)
|
1841
|
+
)
|
1842
|
+
|
1843
|
+
# 2. Template Processing Phase
|
1844
|
+
system_prompt, user_prompt = await process_templates(
|
1845
|
+
args, task_template, template_context, env
|
1846
|
+
)
|
1847
|
+
|
1848
|
+
# 3. Model & Schema Validation Phase
|
1849
|
+
output_model, messages, total_tokens, registry = (
|
1850
|
+
await validate_model_and_schema(
|
1851
|
+
args, schema, system_prompt, user_prompt
|
1852
|
+
)
|
1853
|
+
)
|
1854
|
+
|
1855
|
+
# 4. Dry Run Output Phase
|
1856
|
+
if args.get("dry_run", False):
|
1857
|
+
logger.info("\n=== Dry Run Summary ===")
|
1858
|
+
logger.info("✓ Template rendered successfully")
|
1859
|
+
logger.info("✓ Schema validation passed")
|
1860
|
+
logger.info("✓ Model compatibility validated")
|
1861
|
+
logger.info(
|
1862
|
+
f"✓ Token count: {total_tokens}/{registry.get_capabilities(args['model']).context_window}"
|
1543
1863
|
)
|
1864
|
+
|
1865
|
+
if args.get("verbose", False):
|
1866
|
+
logger.info("\nSystem Prompt:")
|
1867
|
+
logger.info("-" * 40)
|
1868
|
+
logger.info(system_prompt)
|
1869
|
+
logger.info("\nRendered Template:")
|
1870
|
+
logger.info("-" * 40)
|
1871
|
+
logger.info(user_prompt)
|
1872
|
+
|
1544
1873
|
return ExitCode.SUCCESS
|
1545
|
-
|
1546
|
-
|
1547
|
-
|
1548
|
-
|
1549
|
-
|
1550
|
-
EmptyResponseError,
|
1551
|
-
InvalidResponseFormatError,
|
1552
|
-
) as e:
|
1553
|
-
logger.error("Stream error: %s", str(e))
|
1554
|
-
raise # Let stream errors propagate
|
1555
|
-
except (APIConnectionError, InternalServerError) as e:
|
1556
|
-
logger.error("API connection error: %s", str(e))
|
1557
|
-
raise APIResponseError(str(e)) # Convert to our error type
|
1558
|
-
except RateLimitError as e:
|
1559
|
-
logger.error("Rate limit exceeded: %s", str(e))
|
1560
|
-
raise APIResponseError(str(e)) # Convert to our error type
|
1561
|
-
except (BadRequestError, AuthenticationError, OpenAIClientError) as e:
|
1562
|
-
logger.error("API client error: %s", str(e))
|
1563
|
-
raise APIResponseError(str(e)) # Convert to our error type
|
1564
|
-
finally:
|
1565
|
-
await client.close()
|
1874
|
+
|
1875
|
+
# 5. Execution Phase
|
1876
|
+
return await execute_model(
|
1877
|
+
args, params, output_model, system_prompt, user_prompt
|
1878
|
+
)
|
1566
1879
|
|
1567
1880
|
except KeyboardInterrupt:
|
1568
1881
|
logger.info("Operation cancelled by user")
|
1569
|
-
|
1882
|
+
raise
|
1570
1883
|
except Exception as e:
|
1571
1884
|
if isinstance(e, CLIError):
|
1572
1885
|
raise # Let our custom errors propagate
|
@@ -1580,65 +1893,35 @@ def create_cli() -> click.Command:
|
|
1580
1893
|
Returns:
|
1581
1894
|
click.Command: The CLI command object
|
1582
1895
|
"""
|
1583
|
-
|
1584
|
-
@create_click_command()
|
1585
|
-
def cli(**kwargs: Any) -> None:
|
1586
|
-
"""CLI entry point for structured OpenAI API calls."""
|
1587
|
-
try:
|
1588
|
-
args = Namespace(**kwargs)
|
1589
|
-
|
1590
|
-
# Validate required arguments first
|
1591
|
-
if not args.task and not args.task_file:
|
1592
|
-
raise click.UsageError(
|
1593
|
-
"Must specify either --task or --task-file"
|
1594
|
-
)
|
1595
|
-
if not args.schema_file:
|
1596
|
-
raise click.UsageError("Missing option '--schema-file'")
|
1597
|
-
if args.task and args.task_file:
|
1598
|
-
raise click.UsageError(
|
1599
|
-
"Cannot specify both --task and --task-file"
|
1600
|
-
)
|
1601
|
-
if args.system_prompt and args.system_prompt_file:
|
1602
|
-
raise click.UsageError(
|
1603
|
-
"Cannot specify both --system-prompt and --system-prompt-file"
|
1604
|
-
)
|
1605
|
-
|
1606
|
-
# Run the async function synchronously
|
1607
|
-
exit_code = asyncio.run(run_cli_async(args))
|
1608
|
-
|
1609
|
-
if exit_code != ExitCode.SUCCESS:
|
1610
|
-
error_msg = f"Command failed with exit code {exit_code}"
|
1611
|
-
if hasattr(ExitCode, exit_code.name):
|
1612
|
-
error_msg = f"{error_msg} ({exit_code.name})"
|
1613
|
-
raise CLIError(error_msg, context={"exit_code": exit_code})
|
1614
|
-
|
1615
|
-
except click.UsageError:
|
1616
|
-
# Let Click handle usage errors directly
|
1617
|
-
raise
|
1618
|
-
except InvalidJSONError:
|
1619
|
-
# Let InvalidJSONError propagate directly
|
1620
|
-
raise
|
1621
|
-
except CLIError:
|
1622
|
-
# Let our custom errors propagate with their context
|
1623
|
-
raise
|
1624
|
-
except Exception as e:
|
1625
|
-
# Convert other exceptions to CLIError
|
1626
|
-
logger.exception("Unexpected error")
|
1627
|
-
raise CLIError(str(e), context={"error_type": type(e).__name__})
|
1628
|
-
|
1629
|
-
return cli
|
1896
|
+
return cli # The decorator already returns a Command
|
1630
1897
|
|
1631
1898
|
|
1632
1899
|
def main() -> None:
    """Main entry point for the CLI.

    Invokes the click group with ``standalone_mode=False`` and maps any
    exception to an error report followed by a process exit code.
    """
    known_errors = (
        CLIError,
        InvalidJSONError,
        SchemaFileError,
        SchemaValidationError,
    )
    try:
        cli(standalone_mode=False)
    except known_errors as e:
        handle_error(e)
        # Prefer the error's own exit code when it carries one
        code = getattr(e, "exit_code", ExitCode.INTERNAL_ERROR)
        sys.exit(code)
    except click.UsageError as e:
        handle_error(e)
        sys.exit(ExitCode.USAGE_ERROR)
    except Exception as e:
        handle_error(e)
        sys.exit(ExitCode.INTERNAL_ERROR)
|
1636
1919
|
|
1637
1920
|
|
1638
1921
|
# Export public API
|
1639
1922
|
__all__ = [
|
1640
1923
|
"ExitCode",
|
1641
|
-
"
|
1924
|
+
"estimate_tokens_with_encoding",
|
1642
1925
|
"parse_json_var",
|
1643
1926
|
"create_dynamic_model",
|
1644
1927
|
"validate_path_mapping",
|
@@ -1656,17 +1939,16 @@ def create_dynamic_model(
|
|
1656
1939
|
"""Create a Pydantic model from a JSON schema.
|
1657
1940
|
|
1658
1941
|
Args:
|
1659
|
-
schema: JSON schema
|
1660
|
-
base_name:
|
1661
|
-
show_schema: Whether to show the generated schema
|
1662
|
-
debug_validation: Whether to
|
1942
|
+
schema: JSON schema to create model from
|
1943
|
+
base_name: Name for the model class
|
1944
|
+
show_schema: Whether to show the generated model schema
|
1945
|
+
debug_validation: Whether to show detailed validation errors
|
1663
1946
|
|
1664
1947
|
Returns:
|
1665
|
-
|
1948
|
+
Type[BaseModel]: The generated Pydantic model class
|
1666
1949
|
|
1667
1950
|
Raises:
|
1668
|
-
|
1669
|
-
SchemaValidationError: When schema is invalid
|
1951
|
+
ModelValidationError: If the schema is invalid
|
1670
1952
|
"""
|
1671
1953
|
if debug_validation:
|
1672
1954
|
logger.info("Creating dynamic model from schema:")
|
@@ -1758,18 +2040,17 @@ def create_dynamic_model(
|
|
1758
2040
|
" JSON Schema Extra: %s", config.get("json_schema_extra")
|
1759
2041
|
)
|
1760
2042
|
|
1761
|
-
#
|
1762
|
-
field_definitions: Dict[str, FieldDefinition] = {}
|
2043
|
+
# Process schema properties into fields
|
1763
2044
|
properties = schema.get("properties", {})
|
2045
|
+
required = schema.get("required", [])
|
1764
2046
|
|
2047
|
+
field_definitions: Dict[str, Tuple[Type[Any], FieldInfoType]] = {}
|
1765
2048
|
for field_name, field_schema in properties.items():
|
1766
|
-
|
1767
|
-
|
1768
|
-
|
1769
|
-
logger.info(
|
1770
|
-
" Schema: %s", json.dumps(field_schema, indent=2)
|
1771
|
-
)
|
2049
|
+
if debug_validation:
|
2050
|
+
logger.info("Processing field %s:", field_name)
|
2051
|
+
logger.info(" Schema: %s", json.dumps(field_schema, indent=2))
|
1772
2052
|
|
2053
|
+
try:
|
1773
2054
|
python_type, field = _get_type_with_constraints(
|
1774
2055
|
field_schema, field_name, base_name
|
1775
2056
|
)
|
@@ -1804,22 +2085,24 @@ def create_dynamic_model(
|
|
1804
2085
|
raise ModelValidationError(base_name, [str(e)])
|
1805
2086
|
|
1806
2087
|
# Create the model with the fields
|
1807
|
-
|
1808
|
-
|
1809
|
-
|
1810
|
-
|
1811
|
-
|
1812
|
-
|
1813
|
-
|
1814
|
-
|
1815
|
-
|
1816
|
-
|
1817
|
-
|
1818
|
-
|
1819
|
-
|
1820
|
-
},
|
2088
|
+
field_defs: Dict[str, Any] = {
|
2089
|
+
name: (
|
2090
|
+
(
|
2091
|
+
cast(Type[Any], field_type)
|
2092
|
+
if is_container_type(field_type)
|
2093
|
+
else field_type
|
2094
|
+
),
|
2095
|
+
field,
|
2096
|
+
)
|
2097
|
+
for name, (field_type, field) in field_definitions.items()
|
2098
|
+
}
|
2099
|
+
model: Type[BaseModel] = create_model(
|
2100
|
+
base_name, __config__=config, **field_defs
|
1821
2101
|
)
|
1822
2102
|
|
2103
|
+
# Set the model config after creation
|
2104
|
+
model.model_config = config
|
2105
|
+
|
1823
2106
|
if debug_validation:
|
1824
2107
|
logger.info("Successfully created model: %s", model.__name__)
|
1825
2108
|
logger.info("Model config: %s", dict(model.model_config))
|
@@ -1836,10 +2119,6 @@ def create_dynamic_model(
|
|
1836
2119
|
logger.error("Schema validation failed:")
|
1837
2120
|
logger.error(" Error type: %s", type(e).__name__)
|
1838
2121
|
logger.error(" Error message: %s", str(e))
|
1839
|
-
if hasattr(e, "errors"):
|
1840
|
-
logger.error(" Validation errors:")
|
1841
|
-
for error in e.errors():
|
1842
|
-
logger.error(" - %s", error)
|
1843
2122
|
validation_errors = (
|
1844
2123
|
[str(err) for err in e.errors()]
|
1845
2124
|
if hasattr(e, "errors")
|
@@ -1847,7 +2126,7 @@ def create_dynamic_model(
|
|
1847
2126
|
)
|
1848
2127
|
raise ModelValidationError(base_name, validation_errors)
|
1849
2128
|
|
1850
|
-
return
|
2129
|
+
return model
|
1851
2130
|
|
1852
2131
|
except Exception as e:
|
1853
2132
|
if debug_validation:
|