ostruct-cli 0.3.0-py3-none-any.whl → 0.5.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ostruct/cli/base_errors.py +183 -0
- ostruct/cli/cli.py +830 -585
- ostruct/cli/click_options.py +338 -211
- ostruct/cli/errors.py +214 -227
- ostruct/cli/exit_codes.py +18 -0
- ostruct/cli/file_info.py +126 -69
- ostruct/cli/file_list.py +191 -72
- ostruct/cli/file_utils.py +132 -97
- ostruct/cli/path_utils.py +86 -77
- ostruct/cli/security/__init__.py +32 -0
- ostruct/cli/security/allowed_checker.py +55 -0
- ostruct/cli/security/base.py +46 -0
- ostruct/cli/security/case_manager.py +75 -0
- ostruct/cli/security/errors.py +164 -0
- ostruct/cli/security/normalization.py +161 -0
- ostruct/cli/security/safe_joiner.py +211 -0
- ostruct/cli/security/security_manager.py +366 -0
- ostruct/cli/security/symlink_resolver.py +483 -0
- ostruct/cli/security/types.py +108 -0
- ostruct/cli/security/windows_paths.py +404 -0
- ostruct/cli/serialization.py +25 -0
- ostruct/cli/template_filters.py +13 -8
- ostruct/cli/template_rendering.py +46 -22
- ostruct/cli/template_utils.py +12 -4
- ostruct/cli/template_validation.py +26 -8
- ostruct/cli/token_utils.py +43 -0
- ostruct/cli/validators.py +109 -0
- {ostruct_cli-0.3.0.dist-info → ostruct_cli-0.5.0.dist-info}/METADATA +64 -24
- ostruct_cli-0.5.0.dist-info/RECORD +42 -0
- {ostruct_cli-0.3.0.dist-info → ostruct_cli-0.5.0.dist-info}/WHEEL +1 -1
- ostruct/cli/security.py +0 -964
- ostruct/cli/security_types.py +0 -46
- ostruct_cli-0.3.0.dist-info/RECORD +0 -28
- {ostruct_cli-0.3.0.dist-info → ostruct_cli-0.5.0.dist-info}/LICENSE +0 -0
- {ostruct_cli-0.3.0.dist-info → ostruct_cli-0.5.0.dist-info}/entry_points.txt +0 -0
ostruct/cli/cli.py
CHANGED
@@ -5,10 +5,10 @@ import json
 import logging
 import os
 import sys
-from dataclasses import dataclass
 from enum import Enum, IntEnum
 from typing import (
     Any,
+    AsyncGenerator,
     Dict,
     List,
     Literal,
@@ -16,6 +16,7 @@ from typing import (
     Set,
     Tuple,
     Type,
+    TypedDict,
     TypeVar,
     Union,
     cast,
@@ -31,16 +32,8 @@ from pathlib import Path
 
 import click
 import jinja2
-import tiktoken
 import yaml
-from openai import (
-    APIConnectionError,
-    AsyncOpenAI,
-    AuthenticationError,
-    BadRequestError,
-    InternalServerError,
-    RateLimitError,
-)
+from openai import AsyncOpenAI
 from openai_structured.client import (
     async_openai_structured_stream,
     supports_structured_output,
@@ -52,12 +45,9 @@ from openai_structured.errors import (
     ModelNotSupportedError,
     ModelVersionError,
     OpenAIClientError,
-    SchemaFileError,
-    SchemaValidationError,
     StreamBufferError,
-    StreamInterruptedError,
-    StreamParseError,
 )
+from openai_structured.model_registry import ModelRegistry
 from pydantic import (
     AnyUrl,
     BaseModel,
@@ -72,61 +62,60 @@ from pydantic.functional_validators import BeforeValidator
 from pydantic.types import constr
 from typing_extensions import TypeAlias
 
-from ostruct.cli.click_options import
+from ostruct.cli.click_options import all_options
+from ostruct.cli.exit_codes import ExitCode
 
 from .. import __version__  # noqa: F401 - Used in package metadata
 from .errors import (
     CLIError,
     DirectoryNotFoundError,
     FieldDefinitionError,
-    FileNotFoundError,
     InvalidJSONError,
     ModelCreationError,
     ModelValidationError,
     NestedModelError,
+    OstructFileNotFoundError,
     PathSecurityError,
+    SchemaFileError,
+    SchemaValidationError,
+    StreamInterruptedError,
+    StreamParseError,
     TaskTemplateSyntaxError,
     TaskTemplateVariableError,
-    VariableError,
     VariableNameError,
     VariableValueError,
 )
-from .file_utils import FileInfoList,
+from .file_utils import FileInfoList, collect_files
 from .path_utils import validate_path_mapping
 from .security import SecurityManager
+from .serialization import LogSerializer
 from .template_env import create_jinja_env
 from .template_utils import SystemPromptError, render_template
+from .token_utils import estimate_tokens_with_encoding
 
 # Constants
 DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant."
 
 
-
-
-    """Compatibility class to mimic argparse.Namespace for existing code."""
+class CLIParams(TypedDict, total=False):
+    """Type-safe CLI parameters."""
 
-
-
-
-
-
-
+    files: List[
+        Tuple[str, str]
+    ]  # List of (name, path) tuples from Click's nargs=2
+    dir: List[
+        Tuple[str, str]
+    ]  # List of (name, dir) tuples from Click's nargs=2
+    allowed_dirs: List[str]
     base_dir: str
     allowed_dir_file: Optional[str]
-
-    dir_ext: Optional[str]
+    recursive: bool
     var: List[str]
     json_var: List[str]
     system_prompt: Optional[str]
     system_prompt_file: Optional[str]
     ignore_task_sysprompt: bool
-    schema_file: str
     model: str
-    temperature: float
-    max_tokens: Optional[int]
-    top_p: float
-    frequency_penalty: float
-    presence_penalty: float
     timeout: float
     output_file: Optional[str]
     dry_run: bool
@@ -136,7 +125,16 @@ class Namespace:
     debug_openai_stream: bool
     show_model_schema: bool
     debug_validation: bool
-
+    temperature: Optional[float]
+    max_output_tokens: Optional[int]
+    top_p: Optional[float]
+    frequency_penalty: Optional[float]
+    presence_penalty: Optional[float]
+    reasoning_effort: Optional[str]
+    progress_level: str
+    task_file: Optional[str]
+    task: Optional[str]
+    schema_file: str
 
 
 # Set up logging
@@ -174,45 +172,6 @@ ostruct_file_handler.setFormatter(
 logger.addHandler(ostruct_file_handler)
 
 
-class ExitCode(IntEnum):
-    """Exit codes for the CLI following standard Unix conventions.
-
-    Categories:
-    - Success (0-1)
-    - User Interruption (2-3)
-    - Input/Validation (64-69)
-    - I/O and File Access (70-79)
-    - API and External Services (80-89)
-    - Internal Errors (90-99)
-    """
-
-    # Success codes
-    SUCCESS = 0
-
-    # User interruption
-    INTERRUPTED = 2
-
-    # Input/Validation errors (64-69)
-    USAGE_ERROR = 64
-    DATA_ERROR = 65
-    SCHEMA_ERROR = 66
-    VALIDATION_ERROR = 67
-
-    # I/O and File Access errors (70-79)
-    IO_ERROR = 70
-    FILE_NOT_FOUND = 71
-    PERMISSION_ERROR = 72
-    SECURITY_ERROR = 73
-
-    # API and External Service errors (80-89)
-    API_ERROR = 80
-    API_TIMEOUT = 81
-
-    # Internal errors (90-99)
-    INTERNAL_ERROR = 90
-    UNKNOWN_ERROR = 91
-
-
 # Type aliases
 FieldType = (
     Any  # Changed from Type[Any] to allow both concrete types and generics
@@ -279,7 +238,7 @@ def _get_type_with_constraints(
             show_schema=False,
             debug_validation=False,
         )
-        array_type: Type[List[Any]] = List[array_item_model]  # type: ignore
+        array_type: Type[List[Any]] = List[array_item_model]  # type: ignore
         return (array_type, Field(**field_kwargs))
 
     # For non-object items, use the type directly
@@ -401,106 +360,17 @@ K = TypeVar("K")
 V = TypeVar("V")
 
 
-def estimate_tokens_for_chat(
-    messages: List[Dict[str, str]],
-    model: str,
-    encoder: Any = None,
-) -> int:
-    """Estimate the number of tokens in a chat completion.
-
-    Args:
-        messages: List of chat messages
-        model: Model name
-        encoder: Optional tiktoken encoder for testing. If provided, only uses encoder.encode() results.
-    """
-    if encoder is None:
-        try:
-            # Try to get the encoding for the specific model
-            encoder = tiktoken.get_encoding("o200k_base")
-        except KeyError:
-            # Fall back to cl100k_base for unknown models
-            encoder = tiktoken.get_encoding("cl100k_base")
-
-        # Use standard token counting logic for real tiktoken encoders
-        num_tokens = 0
-        for message in messages:
-            # Add message overhead
-            num_tokens += 4  # every message follows <im_start>{role/name}\n{content}<im_end>\n
-            for key, value in message.items():
-                num_tokens += len(encoder.encode(str(value)))
-                if key == "name":  # if there's a name, the role is omitted
-                    num_tokens -= 1  # role is omitted
-        num_tokens += 2  # every reply is primed with <im_start>assistant
-        return num_tokens
-    else:
-        # For mock encoders in tests, just return the length of encoded content
-        num_tokens = 0
-        for message in messages:
-            for value in message.values():
-                num_tokens += len(encoder.encode(str(value)))
-        return num_tokens
-
-
-def get_default_token_limit(model: str) -> int:
-    """Get the default token limit for a given model.
-
-    Note: These limits are based on current OpenAI model specifications as of 2024 and may
-    need to be updated if OpenAI changes the models' capabilities.
-
-    Args:
-        model: The model name (e.g., 'gpt-4o', 'o1-mini', 'o3-mini')
-
-    Returns:
-        The default token limit for the model
-    """
-    if "o1-" in model:
-        return 100_000  # o1-mini supports up to 100K output tokens
-    elif "gpt-4o" in model:
-        return 16_384  # gpt-4o supports up to 16K output tokens
-    elif "o3-" in model:
-        return 16_384  # o3-mini supports up to 16K output tokens
-    else:
-        return 4_096  # default fallback
-
-
-def get_context_window_limit(model: str) -> int:
-    """Get the total context window limit for a given model.
-
-    Note: These limits are based on current OpenAI model specifications as of 2024 and may
-    need to be updated if OpenAI changes the models' capabilities.
-
-    Args:
-        model: The model name (e.g., 'gpt-4o', 'o1-mini', 'o3-mini')
-
-    Returns:
-        The context window limit for the model
-    """
-    if "o1-" in model:
-        return 200_000  # o1-mini supports 200K total context window
-    elif "gpt-4o" in model or "o3-" in model:
-        return 128_000  # gpt-4o and o3-mini support 128K context window
-    else:
-        return 8_192  # default fallback
-
-
 def validate_token_limits(
     model: str, total_tokens: int, max_token_limit: Optional[int] = None
 ) -> None:
-    """Validate token counts against model limits.
-
-
-
-        total_tokens: Total number of tokens in the prompt
-        max_token_limit: Optional user-specified token limit
-
-    Raises:
-        ValueError: If token limits are exceeded
-    """
-    context_limit = get_context_window_limit(model)
+    """Validate token counts against model limits."""
+    registry = ModelRegistry()
+    capabilities = registry.get_capabilities(model)
+    context_limit = capabilities.context_window
     output_limit = (
         max_token_limit
         if max_token_limit is not None
-        else
+        else capabilities.max_output_tokens
     )
 
     # Check if total tokens exceed context window
@@ -562,8 +432,12 @@ def process_system_prompt(
             )
             with open(path, "r", encoding="utf-8") as f:
                 system_prompt = f.read().strip()
-        except
-            raise SystemPromptError(
+        except OstructFileNotFoundError as e:
+            raise SystemPromptError(
+                f"Failed to load system prompt file: {e}"
+            ) from e
+        except PathSecurityError as e:
+            raise SystemPromptError(f"Invalid system prompt file: {e}") from e
 
     if system_prompt is not None:
         # Render system prompt with template context
@@ -631,7 +505,8 @@ def validate_variable_mapping(
             value = json.loads(value)
         except json.JSONDecodeError as e:
             raise InvalidJSONError(
-                f"Invalid JSON value for variable {name!r}: {value!r}"
+                f"Invalid JSON value for variable {name!r}: {value!r}",
+                context={"variable_name": name},
             ) from e
 
     return name, value
@@ -771,7 +646,9 @@ def _validate_path_mapping_internal(
         raise
 
     if security_manager:
-
+        try:
+            security_manager.validate_path(str(resolved_path))
+        except PathSecurityError:
             raise PathSecurityError.from_expanded_paths(
                 original_path=str(path),
                 expanded_path=str(resolved_path),
@@ -825,11 +702,20 @@ def validate_task_template(
     template_content: str
     if task_file is not None:
         try:
-
-            with open(path, "r", encoding="utf-8") as f:
+            with open(task_file, "r", encoding="utf-8") as f:
                 template_content = f.read()
-        except
-            raise TaskTemplateVariableError(
+        except FileNotFoundError:
+            raise TaskTemplateVariableError(
+                f"Task template file not found: {task_file}"
+            )
+        except PermissionError:
+            raise TaskTemplateVariableError(
+                f"Permission denied reading task template file: {task_file}"
+            )
+        except Exception as e:
+            raise TaskTemplateVariableError(
+                f"Error reading task template file: {e}"
+            )
     else:
         template_content = task  # type: ignore  # We know task is str here due to the checks above
 
@@ -847,10 +733,10 @@ def validate_schema_file(
     path: str,
     verbose: bool = False,
 ) -> Dict[str, Any]:
-    """Validate a JSON schema file.
+    """Validate and load a JSON schema file.
 
     Args:
-        path: Path to
+        path: Path to schema file
         verbose: Whether to enable verbose logging
 
     Returns:
@@ -865,14 +751,42 @@ def validate_schema_file(
     logger.info("Validating schema file: %s", path)
 
     try:
-
-
+        logger.debug("Opening schema file: %s", path)
+        with open(path, "r", encoding="utf-8") as f:
+            logger.debug("Loading JSON from schema file")
+            try:
+                schema = json.load(f)
+                logger.debug(
+                    "Successfully loaded JSON: %s",
+                    json.dumps(schema, indent=2),
+                )
+            except json.JSONDecodeError as e:
+                logger.error("JSON decode error in %s: %s", path, str(e))
+                logger.debug(
+                    "Error details - line: %d, col: %d, msg: %s",
+                    e.lineno,
+                    e.colno,
+                    e.msg,
+                )
+                raise InvalidJSONError(
+                    f"Invalid JSON in schema file {path}: {e}",
+                    context={"schema_path": path},
+                ) from e
     except FileNotFoundError:
-
-
-        raise
+        msg = f"Schema file not found: {path}"
+        logger.error(msg)
+        raise SchemaFileError(msg, schema_path=path)
+    except PermissionError:
+        msg = f"Permission denied reading schema file: {path}"
+        logger.error(msg)
+        raise SchemaFileError(msg, schema_path=path)
     except Exception as e:
-
+        if isinstance(e, InvalidJSONError):
+            raise
+        msg = f"Failed to read schema file {path}: {e}"
+        logger.error(msg)
+        logger.debug("Unexpected error details: %s", str(e))
+        raise SchemaFileError(msg, schema_path=path) from e
 
     # Pre-validation structure checks
     if verbose:
@@ -880,11 +794,9 @@ def validate_schema_file(
         logger.debug("Loaded schema: %s", json.dumps(schema, indent=2))
 
     if not isinstance(schema, dict):
-
-
-
-        )
-        raise SchemaValidationError("Schema must be a JSON object")
+        msg = f"Schema in {path} must be a JSON object"
+        logger.error(msg)
+        raise SchemaValidationError(msg, schema_path=path)
 
     # Validate schema structure
     if "schema" in schema:
@@ -892,30 +804,37 @@ def validate_schema_file(
             logger.debug("Found schema wrapper, validating inner schema")
         inner_schema = schema["schema"]
         if not isinstance(inner_schema, dict):
-
-
-
-                type(inner_schema).__name__,
-            )
-            raise SchemaValidationError("Inner schema must be a JSON object")
+            msg = f"Inner schema in {path} must be a JSON object"
+            logger.error(msg)
+            raise SchemaValidationError(msg, schema_path=path)
         if verbose:
             logger.debug("Inner schema validated successfully")
+            logger.debug(
+                "Inner schema: %s", json.dumps(inner_schema, indent=2)
+            )
     else:
         if verbose:
             logger.debug("No schema wrapper found, using schema as-is")
+            logger.debug("Schema: %s", json.dumps(schema, indent=2))
+
+    # Additional schema validation
+    if "type" not in schema.get("schema", schema):
+        msg = f"Schema in {path} must specify a type"
+        logger.error(msg)
+        raise SchemaValidationError(msg, schema_path=path)
 
     # Return the full schema including wrapper
     return schema
 
 
 def collect_template_files(
-    args:
+    args: CLIParams,
     security_manager: SecurityManager,
-) -> Dict[str,
+) -> Dict[str, Union[FileInfoList, str, List[str], Dict[str, str]]]:
     """Collect files from command line arguments.
 
     Args:
-        args:
+        args: Command line arguments
         security_manager: Security manager for path validation
 
     Returns:
@@ -926,15 +845,29 @@ def collect_template_files(
         ValueError: If file mappings are invalid or files cannot be accessed
     """
     try:
-
-
-
-
-
-
+        # Get files and directories from args - they are already tuples from Click's nargs=2
+        files = list(
+            args.get("files", [])
+        )  # List of (name, path) tuples from Click
+        dirs = args.get("dir", [])  # List of (name, dir) tuples from Click
+
+        # Collect files from directories
+        dir_files = collect_files(
+            file_mappings=cast(
+                List[Tuple[str, Union[str, Path]]], files
+            ),  # Cast to correct type
+            dir_mappings=cast(
+                List[Tuple[str, Union[str, Path]]], dirs
+            ),  # Cast to correct type
+            dir_recursive=args.get("recursive", False),
             security_manager=security_manager,
         )
-
+
+        # Combine results
+        return cast(
+            Dict[str, Union[FileInfoList, str, List[str], Dict[str, str]]],
+            dir_files,
+        )
     except PathSecurityError:
         # Let PathSecurityError propagate without wrapping
         raise
@@ -952,11 +885,11 @@ def collect_template_files(
         raise ValueError(f"Error collecting files: {e}")
 
 
-def collect_simple_variables(args:
+def collect_simple_variables(args: CLIParams) -> Dict[str, str]:
     """Collect simple string variables from --var arguments.
 
     Args:
-        args:
+        args: Command line arguments
 
     Returns:
         Dictionary mapping variable names to string values
@@ -967,10 +900,15 @@ def collect_simple_variables(args: Namespace) -> Dict[str, str]:
     variables: Dict[str, str] = {}
     all_names: Set[str] = set()
 
-    if args.var:
-        for mapping in args
+    if args.get("var"):
+        for mapping in args["var"]:
             try:
-
+                # Handle both tuple format and string format
+                if isinstance(mapping, tuple):
+                    name, value = mapping
+                else:
+                    name, value = mapping.split("=", 1)
+
                 if not name.isidentifier():
                     raise VariableNameError(f"Invalid variable name: {name}")
                 if name in all_names:
@@ -985,11 +923,11 @@ def collect_simple_variables(args: Namespace) -> Dict[str, str]:
     return variables
 
 
-def collect_json_variables(args:
+def collect_json_variables(args: CLIParams) -> Dict[str, Any]:
     """Collect JSON variables from --json-var arguments.
 
     Args:
-        args:
+        args: Command line arguments
 
     Returns:
         Dictionary mapping variable names to parsed JSON values
@@ -1001,32 +939,46 @@ def collect_json_variables(args: Namespace) -> Dict[str, Any]:
     variables: Dict[str, Any] = {}
     all_names: Set[str] = set()
 
-    if args.json_var:
-        for mapping in args
+    if args.get("json_var"):
+        for mapping in args["json_var"]:
             try:
-
+                # Handle both tuple format and string format
+                if isinstance(mapping, tuple):
+                    name, value = (
+                        mapping  # Value is already parsed by Click validator
+                    )
+                else:
+                    try:
+                        name, json_str = mapping.split("=", 1)
+                    except ValueError:
+                        raise VariableNameError(
+                            f"Invalid JSON variable mapping format: {mapping}. Expected name=json"
+                        )
+                    try:
+                        value = json.loads(json_str)
+                    except json.JSONDecodeError as e:
+                        raise InvalidJSONError(
+                            f"Invalid JSON value for variable '{name}': {json_str}",
+                            context={"variable_name": name},
+                        ) from e
+
                 if not name.isidentifier():
                     raise VariableNameError(f"Invalid variable name: {name}")
                 if name in all_names:
                     raise VariableNameError(f"Duplicate variable name: {name}")
-
-
-
-
-
-                    raise InvalidJSONError(
-                        f"Error parsing JSON for variable '{name}': {str(e)}. Input was: {json_str}"
-                    )
-            except ValueError:
-                raise VariableNameError(
-                    f"Invalid JSON variable mapping format: {mapping}. Expected name=json"
-                )
+
+                variables[name] = value
+                all_names.add(name)
+            except (VariableNameError, InvalidJSONError):
+                raise
 
     return variables
 
 
 def create_template_context(
-    files: Optional[
+    files: Optional[
+        Dict[str, Union[FileInfoList, str, List[str], Dict[str, str]]]
+    ] = None,
     variables: Optional[Dict[str, str]] = None,
     json_variables: Optional[Dict[str, Any]] = None,
     security_manager: Optional[SecurityManager] = None,
@@ -1070,14 +1022,14 @@ def create_template_context(
     return context
 
 
-def create_template_context_from_args(
-    args:
+async def create_template_context_from_args(
+    args: CLIParams,
     security_manager: SecurityManager,
 ) -> Dict[str, Any]:
     """Create template context from command line arguments.
 
     Args:
-        args:
+        args: Command line arguments
         security_manager: Security manager for path validation
 
     Returns:
@@ -1090,50 +1042,13 @@ def create_template_context_from_args(
     """
     try:
         # Collect files from arguments
-        files =
-        if any([args.file, args.files, args.dir]):
-            files = collect_files(
-                file_mappings=args.file,
-                pattern_mappings=args.files,
-                dir_mappings=args.dir,
-                dir_recursive=args.dir_recursive,
-                dir_extensions=(
-                    args.dir_ext.split(",") if args.dir_ext else None
-                ),
-                security_manager=security_manager,
-            )
+        files = collect_template_files(args, security_manager)
 
         # Collect simple variables
-
-        variables = collect_simple_variables(args)
-        except VariableNameError as e:
-            raise VariableError(str(e))
+        variables = collect_simple_variables(args)
 
         # Collect JSON variables
-        json_variables =
-        if args.json_var:
-            for mapping in args.json_var:
-                try:
-                    name, value = mapping.split("=", 1)
-                    if not name.isidentifier():
-                        raise VariableNameError(
-                            f"Invalid variable name: {name}"
-                        )
-                    try:
-                        json_value = json.loads(value)
-                    except json.JSONDecodeError as e:
-                        raise InvalidJSONError(
-                            f"Error parsing JSON for variable '{name}': {str(e)}. Input was: {value}"
-                        )
-                    if name in json_variables:
-                        raise VariableNameError(
-                            f"Duplicate variable name: {name}"
-                        )
-                    json_variables[name] = json_value
-                except ValueError:
-                    raise VariableNameError(
-                        f"Invalid JSON variable mapping format: {mapping}. Expected name=json"
-                    )
+        json_variables = collect_json_variables(args)
 
         # Get stdin content if available
         stdin_content = None
@@ -1144,7 +1059,7 @@ def create_template_context_from_args(
             # Skip stdin if it can't be read
             pass
 
-
+        context = create_template_context(
             files=files,
             variables=variables,
             json_variables=json_variables,
@@ -1152,6 +1067,11 @@ def create_template_context_from_args(
             stdin_content=stdin_content,
         )
 
+        # Add current model to context
+        context["current_model"] = args["model"]
+
+        return context
+
     except PathSecurityError:
         # Let PathSecurityError propagate without wrapping
         raise
@@ -1192,40 +1112,13 @@ def validate_security_manager(
     if base_dir is None:
         base_dir = os.getcwd()
 
-    #
-
-    allowed_dirs = []
-
-    # Add base directory if it exists
-    try:
-        base_dir_path = Path(base_dir).resolve()
-        if not base_dir_path.exists():
-            raise DirectoryNotFoundError(
-                f"Base directory not found: {base_dir}"
-            )
-        if not base_dir_path.is_dir():
-            raise DirectoryNotFoundError(
-                f"Base directory is not a directory: {base_dir}"
-            )
-        all_allowed_dirs = [str(base_dir_path)]
-    except OSError as e:
-        raise DirectoryNotFoundError(f"Invalid base directory: {e}")
+    # Create security manager with base directory
+    security_manager = SecurityManager(base_dir)
 
     # Add explicitly allowed directories
-
-
-
-            if not resolved_path.exists():
-                raise DirectoryNotFoundError(
-                    f"Directory not found: {dir_path}"
-                )
-            if not resolved_path.is_dir():
-                raise DirectoryNotFoundError(
-                    f"Path is not a directory: {dir_path}"
-                )
-            all_allowed_dirs.append(str(resolved_path))
-        except OSError as e:
-            raise DirectoryNotFoundError(f"Invalid directory path: {e}")
+    if allowed_dirs:
+        for dir_path in allowed_dirs:
+            security_manager.add_allowed_directory(dir_path)
 
     # Add directories from file if specified
     if allowed_dir_file:
@@ -1234,28 +1127,13 @@ def validate_security_manager(
                 for line in f:
                     line = line.strip()
                     if line and not line.startswith("#"):
-
-                        resolved_path = Path(line).resolve()
-                        if not resolved_path.exists():
-                            raise DirectoryNotFoundError(
-                                f"Directory not found: {line}"
-                            )
-                        if not resolved_path.is_dir():
-                            raise DirectoryNotFoundError(
-                                f"Path is not a directory: {line}"
-                            )
-                        all_allowed_dirs.append(str(resolved_path))
-        except OSError as e:
-            raise DirectoryNotFoundError(
-                f"Invalid directory path in {allowed_dir_file}: {e}"
-            )
+                        security_manager.add_allowed_directory(line)
         except OSError as e:
             raise DirectoryNotFoundError(
                 f"Failed to read allowed directories file: {e}"
             )
 
-
-    return SecurityManager(base_dir=base_dir, allowed_dirs=all_allowed_dirs)
+    return security_manager
 
 
 def parse_var(var_str: str) -> Tuple[str, str]:
@@ -1315,7 +1193,8 @@ def parse_json_var(var_str: str) -> Tuple[str, Any]:
         value = json.loads(json_str)
     except json.JSONDecodeError as e:
         raise InvalidJSONError(
-            f"Error parsing JSON for variable '{name}': {str(e)}. Input was: {json_str}"
+            f"Error parsing JSON for variable '{name}': {str(e)}. Input was: {json_str}",
+            context={"variable_name": name},
         )
 
     return name, value
@@ -1364,41 +1243,96 @@ def _create_enum_type(values: List[Any], field_name: str) -> Type[Enum]:
 
 
 def handle_error(e: Exception) -> None:
-    """Handle errors
+    """Handle CLI errors and display appropriate messages.
+
+    Maintains specific error type handling while reducing duplication.
+    Provides enhanced debug logging for CLI errors.
+    """
+    # 1. Determine error type and message
     if isinstance(e, click.UsageError):
-
-
-
-
-
-
-
-
-
-    elif isinstance(e,
-
-
-        click.secho(msg, fg="red", err=True)
-        sys.exit(ExitCode.DATA_ERROR)
-    elif isinstance(e, FileNotFoundError):
-        # Use the original error message if available
-        msg = str(e) if str(e) != "None" else "File not found"
-        click.secho(msg, fg="red", err=True)
-        sys.exit(ExitCode.SCHEMA_ERROR)
-    elif isinstance(e, TaskTemplateSyntaxError):
-        # Use the original error message if available
-        msg = str(e) if str(e) != "None" else "Template syntax error"
-        click.secho(msg, fg="red", err=True)
-        sys.exit(ExitCode.INTERNAL_ERROR)
+        msg = f"Usage error: {str(e)}"
+        exit_code = ExitCode.USAGE_ERROR
+    elif isinstance(e, SchemaFileError):
+        # Preserve specific schema error handling
+        msg = str(e)  # Use existing __str__ formatting
+        exit_code = ExitCode.SCHEMA_ERROR
+    elif isinstance(e, (InvalidJSONError, json.JSONDecodeError)):
+        msg = f"Invalid JSON error: {str(e)}"
+        exit_code = ExitCode.DATA_ERROR
+    elif isinstance(e, SchemaValidationError):
+        msg = f"Schema validation error: {str(e)}"
+        exit_code = ExitCode.VALIDATION_ERROR
     elif isinstance(e, CLIError):
-        # Use
-        e.
-
-
+        msg = str(e)  # Use existing __str__ formatting
+        exit_code = ExitCode(e.exit_code)  # Convert int to ExitCode
+    else:
+        msg = f"Unexpected error: {str(e)}"
+        exit_code = ExitCode.INTERNAL_ERROR
+
+    # 2. Debug logging
+    if isinstance(e, CLIError) and logger.isEnabledFor(logging.DEBUG):
+        # Format context fields with lowercase keys and simple values
+        context_str = ""
+        if hasattr(e, "context"):
+            for key, value in sorted(e.context.items()):
+                if key not in {
+                    "timestamp",
+                    "host",
+                    "version",
+                    "python_version",
+                }:
+                    context_str += f"{key.lower()}: {value}\n"
+
+        logger.debug(
+            "Error details:\n"
+            f"Type: {type(e).__name__}\n"
+            f"{context_str.rstrip()}"
         )
+    elif not isinstance(e, click.UsageError):
+        logger.error(msg, exc_info=True)
     else:
-
-
+        logger.error(msg)
+
+    # 3. User output
+    click.secho(msg, fg="red", err=True)
+    sys.exit(exit_code)
+
+
+def validate_model_parameters(model: str, params: Dict[str, Any]) -> None:
+    """Validate model parameters against model capabilities.
+
+    Args:
+        model: The model name to validate parameters for
+        params: Dictionary of parameter names and values to validate
+
+    Raises:
+        CLIError: If any parameters are not supported by the model
+    """
+    try:
+        capabilities = ModelRegistry().get_capabilities(model)
+        for param_name, value in params.items():
+            try:
+                capabilities.validate_parameter(param_name, value)
+            except OpenAIClientError as e:
+                logger.error(
+                    "Validation failed for model %s: %s", model, str(e)
+                )
+                raise CLIError(
+                    str(e),
+                    exit_code=ExitCode.VALIDATION_ERROR,
+                    context={
+                        "model": model,
+                        "param": param_name,
+                        "value": value,
+                    },
+                )
+    except (ModelNotSupportedError, ModelVersionError) as e:
+        logger.error("Model validation failed: %s", str(e))
+        raise CLIError(
+            str(e),
+            exit_code=ExitCode.VALIDATION_ERROR,
+            context={"model": model},
+        )
 
 
 async def stream_structured_output(
@@ -1409,36 +1343,75 @@ async def stream_structured_output(
     output_schema: Type[BaseModel],
     output_file: Optional[str] = None,
     **kwargs: Any,
-) -> None:
+) -> AsyncGenerator[BaseModel, None]:
     """Stream structured output from OpenAI API.
 
     This function follows the guide's recommendation for a focused async streaming function.
     It handles the core streaming logic and resource cleanup.
+
+    Args:
+        client: The OpenAI client to use
+        model: The model to use
+        system_prompt: The system prompt to use
+        user_prompt: The user prompt to use
+        output_schema: The Pydantic model to validate responses against
+        output_file: Optional file to write output to
+        **kwargs: Additional parameters to pass to the API
+
+    Returns:
+        An async generator yielding validated model instances
+
+    Raises:
+        ValueError: If the model does not support structured output or parameters are invalid
+        StreamInterruptedError: If the stream is interrupted
+        APIResponseError: If there is an API error
     """
     try:
+        # Check if model supports structured output using openai_structured's function
+        if not supports_structured_output(model):
+            raise ValueError(
+                f"Model {model} does not support structured output with json_schema response format. "
+                "Please use a model that supports structured output."
+            )
+
+        # Extract non-model parameters
+        on_log = kwargs.pop("on_log", None)
+
+        # Handle model-specific parameters
+        stream_kwargs = {}
+        registry = ModelRegistry()
+        capabilities = registry.get_capabilities(model)
+
+        # Validate and include supported parameters
+        for param_name, value in kwargs.items():
+            if param_name in capabilities.supported_parameters:
+                # Validate the parameter value
+                capabilities.validate_parameter(param_name, value)
+                stream_kwargs[param_name] = value
+            else:
+                logger.warning(
+                    f"Parameter {param_name} is not supported by model {model} and will be ignored"
+                )
+
+        # Log the API request details
+        logger.debug("Making OpenAI API request with:")
+        logger.debug("Model: %s", model)
+        logger.debug("System prompt: %s", system_prompt)
+        logger.debug("User prompt: %s", user_prompt)
+        logger.debug("Parameters: %s", json.dumps(stream_kwargs, indent=2))
+        logger.debug("Schema: %s", output_schema.model_json_schema())
+
+        # Use the async generator from openai_structured directly
         async for chunk in async_openai_structured_stream(
             client=client,
             model=model,
-            output_schema=output_schema,
             system_prompt=system_prompt,
             user_prompt=user_prompt,
-
+            output_schema=output_schema,
+            on_log=on_log,  # Pass non-model parameters directly to the function
+            **stream_kwargs,  # Pass only validated model parameters
         ):
-
-            continue
-
-            # Process and output the chunk
-            dumped = chunk.model_dump(mode="json")
-            json_str = json.dumps(dumped, indent=2)
-
-            if output_file:
-                with open(output_file, "a", encoding="utf-8") as f:
-                    f.write(json_str)
-                    f.write("\n")
-                    f.flush()  # Ensure immediate flush to file
-            else:
-                # Print directly to stdout with immediate flush
-                print(json_str, flush=True)
+            yield chunk
 
     except (
         StreamInterruptedError,
@@ -1455,149 +1428,458 @@ async def stream_structured_output(
|
|
1455
1428
|
await client.close()
|
1456
1429
|
|
1457
1430
|
|
1458
|
-
|
1459
|
-
|
1431
|
+
@click.group()
|
1432
|
+
@click.version_option(version=__version__)
|
1433
|
+
def cli() -> None:
|
1434
|
+
"""ostruct CLI - Make structured OpenAI API calls.
|
1435
|
+
|
1436
|
+
ostruct allows you to invoke OpenAI Structured Output to produce structured JSON
|
1437
|
+
output using templates and JSON schemas. It provides support for file handling, variable
|
1438
|
+
substitution, and output validation.
|
1439
|
+
|
1440
|
+
For detailed documentation, visit: https://ostruct.readthedocs.io
|
1441
|
+
|
1442
|
+
Examples:
|
1443
|
+
|
1444
|
+
# Basic usage with a template and schema
|
1445
|
+
|
1446
|
+
ostruct run task.j2 schema.json -V name=value
|
1460
1447
|
|
1461
|
-
|
1462
|
-
|
1448
|
+
# Process files with recursive directory scanning
|
1449
|
+
|
1450
|
+
ostruct run template.j2 schema.json -f code main.py -d src ./src -R
|
1451
|
+
|
1452
|
+
# Use JSON variables and custom model parameters
|
1453
|
+
|
1454
|
+
ostruct run task.j2 schema.json -J config='{"env":"prod"}' -m o3-mini
|
1455
|
+
"""
|
1456
|
+
pass
|
1457
|
+
|
1458
|
+
|
1459
|
+
@cli.command()
|
1460
|
+
@click.argument("task_template", type=click.Path(exists=True))
|
1461
|
+
@click.argument("schema_file", type=click.Path(exists=True))
|
1462
|
+
@all_options
|
1463
|
+
@click.pass_context
|
1464
|
+
def run(
|
1465
|
+
ctx: click.Context,
|
1466
|
+
task_template: str,
|
1467
|
+
schema_file: str,
|
1468
|
+
**kwargs: Any,
|
1469
|
+
) -> None:
|
1470
|
+
"""Run a structured task with template and schema.
|
1471
|
+
|
1472
|
+
TASK_TEMPLATE is the path to your Jinja2 template file that defines the task.
|
1473
|
+
SCHEMA_FILE is the path to your JSON schema file that defines the expected output structure.
|
1474
|
+
|
1475
|
+
The command supports various options for file handling, variable definition,
|
1476
|
+
model configuration, and output control. Use --help to see all available options.
|
1477
|
+
|
1478
|
+
Examples:
|
1479
|
+
# Basic usage
|
1480
|
+
ostruct run task.j2 schema.json
|
1481
|
+
|
1482
|
+
# Process multiple files
|
1483
|
+
ostruct run task.j2 schema.json -f code main.py -f test tests/test_main.py
|
1484
|
+
|
1485
|
+
# Scan directories recursively
|
1486
|
+
ostruct run task.j2 schema.json -d src ./src -R
|
1487
|
+
|
1488
|
+
# Define variables
|
1489
|
+
ostruct run task.j2 schema.json -V debug=true -J config='{"env":"prod"}'
|
1490
|
+
|
1491
|
+
# Configure model
|
1492
|
+
ostruct run task.j2 schema.json -m gpt-4 --temperature 0.7 --max-output-tokens 1000
|
1493
|
+
|
1494
|
+
# Control output
|
1495
|
+
ostruct run task.j2 schema.json --output-file result.json --verbose
|
1463
1496
|
"""
|
1464
1497
|
try:
|
1465
|
-
#
|
1466
|
-
|
1467
|
-
|
1468
|
-
|
1469
|
-
|
1498
|
+
# Convert Click parameters to typed dict
|
1499
|
+
params: CLIParams = {
|
1500
|
+
"task_file": task_template,
|
1501
|
+
"task": None,
|
1502
|
+
"schema_file": schema_file,
|
1503
|
+
}
|
1504
|
+
# Add only valid keys from kwargs
|
1505
|
+
valid_keys = set(CLIParams.__annotations__.keys())
|
1506
|
+
for k, v in kwargs.items():
|
1507
|
+
if k in valid_keys:
|
1508
|
+
params[k] = v # type: ignore[literal-required]
|
1509
|
+
|
1510
|
+
# Run the async function synchronously
|
1511
|
+
loop = asyncio.new_event_loop()
|
1512
|
+
asyncio.set_event_loop(loop)
|
1513
|
+
try:
|
1514
|
+
exit_code = loop.run_until_complete(run_cli_async(params))
|
1515
|
+
sys.exit(int(exit_code))
|
1516
|
+
finally:
|
1517
|
+
loop.close()
|
1518
|
+
|
1519
|
+
except (
|
1520
|
+
CLIError,
|
1521
|
+
InvalidJSONError,
|
1522
|
+
SchemaFileError,
|
1523
|
+
SchemaValidationError,
|
1524
|
+
) as e:
|
1525
|
+
handle_error(e)
|
1526
|
+
sys.exit(
|
1527
|
+
e.exit_code if hasattr(e, "exit_code") else ExitCode.INTERNAL_ERROR
|
1470
1528
|
)
|
1529
|
+
except click.UsageError as e:
|
1530
|
+
handle_error(e)
|
1531
|
+
sys.exit(ExitCode.USAGE_ERROR)
|
1532
|
+
except Exception as e:
|
1533
|
+
handle_error(e)
|
1534
|
+
sys.exit(ExitCode.INTERNAL_ERROR)
|
1535
|
+
|
1536
|
+
|
1537
|
+
# Remove the old @create_click_command() decorator and cli function definition
|
1538
|
+
# Keep all the other functions and code below this point
|
1539
|
+
|
1540
|
+
|
1541
|
+
async def validate_model_params(args: CLIParams) -> Dict[str, Any]:
|
1542
|
+
"""Validate model parameters and return a dictionary of valid parameters.
|
1543
|
+
|
1544
|
+
Args:
|
1545
|
+
args: Command line arguments
|
1546
|
+
|
1547
|
+
Returns:
|
1548
|
+
Dictionary of validated model parameters
|
1549
|
+
|
1550
|
+
Raises:
|
1551
|
+
CLIError: If model parameters are invalid
|
1552
|
+
"""
|
1553
|
+
params = {
|
1554
|
+
"temperature": args.get("temperature"),
|
1555
|
+
"max_output_tokens": args.get("max_output_tokens"),
|
1556
|
+
"top_p": args.get("top_p"),
|
1557
|
+
"frequency_penalty": args.get("frequency_penalty"),
|
1558
|
+
"presence_penalty": args.get("presence_penalty"),
|
1559
|
+
"reasoning_effort": args.get("reasoning_effort"),
|
1560
|
+
}
|
1561
|
+
# Remove None values
|
1562
|
+
params = {k: v for k, v in params.items() if v is not None}
|
1563
|
+
validate_model_parameters(args["model"], params)
|
1564
|
+
return params
|
1565
|
+
|
1566
|
+
|
1567
|
+
async def validate_inputs(
|
1568
|
+
args: CLIParams,
|
1569
|
+
) -> Tuple[
|
1570
|
+
SecurityManager, str, Dict[str, Any], Dict[str, Any], jinja2.Environment
|
1571
|
+
]:
|
1572
|
+
"""Validate all input parameters and return validated components.
|
1573
|
+
|
1574
|
+
Args:
|
1575
|
+
args: Command line arguments
|
1576
|
+
|
1577
|
+
Returns:
|
1578
|
+
Tuple containing:
|
1579
|
+
- SecurityManager instance
|
1580
|
+
- Task template string
|
1581
|
+
- Schema dictionary
|
1582
|
+
- Template context dictionary
|
1583
|
+
- Jinja2 environment
|
1584
|
+
|
1585
|
+
Raises:
|
1586
|
+
CLIError: For various validation errors
|
1587
|
+
"""
|
1588
|
+
logger.debug("=== Input Validation Phase ===")
|
1589
|
+
security_manager = validate_security_manager(
|
1590
|
+
base_dir=args.get("base_dir"),
|
1591
|
+
allowed_dirs=args.get("allowed_dirs"),
|
1592
|
+
allowed_dir_file=args.get("allowed_dir_file"),
|
1593
|
+
)
|
1594
|
+
|
1595
|
+
task_template = validate_task_template(
|
1596
|
+
args.get("task"), args.get("task_file")
|
1597
|
+
)
|
1598
|
+
logger.debug("Validating schema from %s", args["schema_file"])
|
1599
|
+
schema = validate_schema_file(
|
1600
|
+
args["schema_file"], args.get("verbose", False)
|
1601
|
+
)
|
1602
|
+
template_context = await create_template_context_from_args(
|
1603
|
+
args, security_manager
|
1604
|
+
)
|
1605
|
+
env = create_jinja_env()
|
1471
1606
|
|
1472
|
-
|
1473
|
-
|
1474
|
-
|
1475
|
-
|
1476
|
-
|
1607
|
+
return security_manager, task_template, schema, template_context, env
|
1608
|
+
|
1609
|
+
|
1610
|
+
async def process_templates(
|
1611
|
+
args: CLIParams,
|
1612
|
+
task_template: str,
|
1613
|
+
template_context: Dict[str, Any],
|
1614
|
+
env: jinja2.Environment,
|
1615
|
+
) -> Tuple[str, str]:
|
1616
|
+
"""Process system prompt and user prompt templates.
|
1617
|
+
|
1618
|
+
Args:
|
1619
|
+
args: Command line arguments
|
1620
|
+
task_template: Validated task template
|
1621
|
+
template_context: Template context dictionary
|
1622
|
+
env: Jinja2 environment
|
1623
|
+
|
1624
|
+
Returns:
|
1625
|
+
Tuple of (system_prompt, user_prompt)
|
1626
|
+
|
1627
|
+
Raises:
|
1628
|
+
CLIError: For template processing errors
|
1629
|
+
"""
|
1630
|
+
logger.debug("=== Template Processing Phase ===")
|
1631
|
+
system_prompt = process_system_prompt(
|
1632
|
+
task_template,
|
1633
|
+
args.get("system_prompt"),
|
1634
|
+
args.get("system_prompt_file"),
|
1635
|
+
template_context,
|
1636
|
+
env,
|
1637
|
+
args.get("ignore_task_sysprompt", False),
|
1638
|
+
)
|
1639
|
+
user_prompt = render_template(task_template, template_context, env)
|
1640
|
+
return system_prompt, user_prompt
|
1641
|
+
|
1642
|
+
|
1643
|
+
async def validate_model_and_schema(
|
1644
|
+
args: CLIParams,
|
1645
|
+
schema: Dict[str, Any],
|
1646
|
+
system_prompt: str,
|
1647
|
+
user_prompt: str,
|
1648
|
+
) -> Tuple[Type[BaseModel], List[Dict[str, str]], int, ModelRegistry]:
|
1649
|
+
"""Validate model compatibility and schema, and check token limits.
|
1650
|
+
|
1651
|
+
Args:
|
1652
|
+
args: Command line arguments
|
1653
|
+
schema: Schema dictionary
|
1654
|
+
system_prompt: Processed system prompt
|
1655
|
+
user_prompt: Processed user prompt
|
1656
|
+
|
1657
|
+
Returns:
|
1658
|
+
Tuple of (output_model, messages, total_tokens, registry)
|
1659
|
+
|
1660
|
+
Raises:
|
1661
|
+
CLIError: For validation errors
|
1662
|
+
ModelCreationError: When model creation fails
|
1663
|
+
SchemaValidationError: When schema is invalid
|
1664
|
+
"""
|
1665
|
+
logger.debug("=== Model & Schema Validation Phase ===")
|
1666
|
+
try:
|
1667
|
+
output_model = create_dynamic_model(
|
1668
|
+
schema,
|
1669
|
+
show_schema=args.get("show_model_schema", False),
|
1670
|
+
debug_validation=args.get("debug_validation", False),
|
1477
1671
|
)
|
1478
|
-
|
1479
|
-
|
1480
|
-
|
1481
|
-
|
1482
|
-
|
1483
|
-
|
1484
|
-
|
1485
|
-
|
1486
|
-
|
1487
|
-
|
1672
|
+
logger.debug("Successfully created output model")
|
1673
|
+
except (
|
1674
|
+
SchemaFileError,
|
1675
|
+
InvalidJSONError,
|
1676
|
+
SchemaValidationError,
|
1677
|
+
ModelCreationError,
|
1678
|
+
) as e:
|
1679
|
+
logger.error("Schema error: %s", str(e))
|
1680
|
+
raise
|
1681
|
+
|
1682
|
+
if not supports_structured_output(args["model"]):
|
1683
|
+
msg = f"Model {args['model']} does not support structured output"
|
1684
|
+
logger.error(msg)
|
1685
|
+
raise ModelNotSupportedError(msg)
|
1686
|
+
|
1687
|
+
messages = [
|
1688
|
+
{"role": "system", "content": system_prompt},
|
1689
|
+
{"role": "user", "content": user_prompt},
|
1690
|
+
]
|
1691
|
+
|
1692
|
+
total_tokens = estimate_tokens_with_encoding(messages, args["model"])
|
1693
|
+
registry = ModelRegistry()
|
1694
|
+
capabilities = registry.get_capabilities(args["model"])
|
1695
|
+
context_limit = capabilities.context_window
|
1696
|
+
|
1697
|
+
if total_tokens > context_limit:
|
1698
|
+
msg = f"Total tokens ({total_tokens}) exceeds model context limit ({context_limit})"
|
1699
|
+
logger.error(msg)
|
1700
|
+
raise CLIError(
|
1701
|
+
msg,
|
1702
|
+
context={
|
1703
|
+
"total_tokens": total_tokens,
|
1704
|
+
"context_limit": context_limit,
|
1705
|
+
},
|
1488
1706
|
)
|
1489
|
-
rendered_task = render_template(task_template, template_context, env)
|
1490
|
-
logger.info("Rendered task template: %s", rendered_task)
|
1491
1707
|
|
1492
|
-
|
1493
|
-
logger.info("DRY RUN MODE")
|
1494
|
-
return ExitCode.SUCCESS
|
1708
|
+
return output_model, messages, total_tokens, registry
|
1495
1709
|
|
1496
|
-
# Create output model
|
1497
|
-
logger.debug("Creating output model")
|
1498
|
-
try:
|
1499
|
-
output_model = create_dynamic_model(
|
1500
|
-
schema,
|
1501
|
-
base_name="OutputModel",
|
1502
|
-
show_schema=args.show_model_schema,
|
1503
|
-
debug_validation=args.debug_validation,
|
1504
|
-
)
|
1505
|
-
logger.debug("Successfully created output model")
|
1506
|
-
except (
|
1507
|
-
SchemaFileError,
|
1508
|
-
InvalidJSONError,
|
1509
|
-
SchemaValidationError,
|
1510
|
-
ModelCreationError,
|
1511
|
-
) as e:
|
1512
|
-
logger.error("Schema error: %s", str(e))
|
1513
|
-
raise # Let the error propagate with its context
|
1514
|
-
|
1515
|
-
# Validate model support and token usage
|
1516
|
-
try:
|
1517
|
-
supports_structured_output(args.model)
|
1518
|
-
except (ModelNotSupportedError, ModelVersionError) as e:
|
1519
|
-
logger.error("Model validation error: %s", str(e))
|
1520
|
-
raise # Let the error propagate with its context
|
1521
|
-
|
1522
|
-
messages = [
|
1523
|
-
{"role": "system", "content": system_prompt},
|
1524
|
-
{"role": "user", "content": rendered_task},
|
1525
|
-
]
|
1526
|
-
total_tokens = estimate_tokens_for_chat(messages, args.model)
|
1527
|
-
context_limit = get_context_window_limit(args.model)
|
1528
|
-
if total_tokens > context_limit:
|
1529
|
-
msg = f"Total tokens ({total_tokens}) exceeds model context limit ({context_limit})"
|
1530
|
-
logger.error(msg)
|
1531
|
-
raise CLIError(
|
1532
|
-
msg,
|
1533
|
-
context={
|
1534
|
-
"total_tokens": total_tokens,
|
1535
|
-
"context_limit": context_limit,
|
1536
|
-
},
|
1537
|
-
)
|
1538
1710
|
|
1539
|
-
|
1540
|
-
|
1541
|
-
|
1542
|
-
|
1543
|
-
|
1544
|
-
|
1711
|
+
async def execute_model(
|
1712
|
+
args: CLIParams,
|
1713
|
+
params: Dict[str, Any],
|
1714
|
+
output_model: Type[BaseModel],
|
1715
|
+
system_prompt: str,
|
1716
|
+
user_prompt: str,
|
1717
|
+
) -> ExitCode:
|
1718
|
+
"""Execute the model and handle the response.
|
1545
1719
|
|
1546
|
-
|
1720
|
+
Args:
|
1721
|
+
args: Command line arguments
|
1722
|
+
params: Validated model parameters
|
1723
|
+
output_model: Generated Pydantic model
|
1724
|
+
system_prompt: Processed system prompt
|
1725
|
+
user_prompt: Processed user prompt
|
1547
1726
|
|
1548
|
-
|
1549
|
-
|
1550
|
-
level: int, message: str, extra: dict[str, Any]
|
1551
|
-
) -> None:
|
1552
|
-
if args.debug_openai_stream:
|
1553
|
-
if extra:
|
1554
|
-
-                extra_str = json.dumps(extra, indent=2)
-                message = f"{message}\nDetails:\n{extra_str}"
-        logger.log(level, message, extra=extra)
+    Returns:
+        Exit code indicating success or failure

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    Raises:
+        CLIError: For execution errors
+    """
+    logger.debug("=== Execution Phase ===")
+    api_key = args.get("api_key") or os.getenv("OPENAI_API_KEY")
+    if not api_key:
+        msg = "No API key provided. Set OPENAI_API_KEY environment variable or use --api-key"
+        logger.error(msg)
+        raise CLIError(msg, exit_code=ExitCode.API_ERROR)
+
+    client = AsyncOpenAI(api_key=api_key, timeout=args.get("timeout", 60.0))
+
+    # Create detailed log callback
+    def log_callback(level: int, message: str, extra: dict[str, Any]) -> None:
+        if args.get("debug_openai_stream", False):
+            if extra:
+                extra_str = LogSerializer.serialize_log_extra(extra)
+                if extra_str:
+                    logger.debug("%s\nExtra:\n%s", message, extra_str)
+                else:
+                    logger.debug("%s\nExtra: Failed to serialize", message)
+            else:
+                logger.debug(message)
+
+    try:
+        # Create output buffer
+        output_buffer = []
+
+        # Stream the response
+        async for response in stream_structured_output(
+            client=client,
+            model=args["model"],
+            system_prompt=system_prompt,
+            user_prompt=user_prompt,
+            output_schema=output_model,
+            output_file=args.get("output_file"),
+            **params,  # Only pass validated model parameters
+            on_log=log_callback,  # Pass logging callback separately
+        ):
+            output_buffer.append(response)
+
+        # Handle final output
+        output_file = args.get("output_file")
+        if output_file:
+            with open(output_file, "w") as f:
+                if len(output_buffer) == 1:
+                    f.write(output_buffer[0].model_dump_json(indent=2))
+                else:
+                    # Build complete JSON array as a single string
+                    json_output = "[\n"
+                    for i, response in enumerate(output_buffer):
+                        if i > 0:
+                            json_output += ",\n"
+                        json_output += " " + response.model_dump_json(
+                            indent=2
+                        ).replace("\n", "\n ")
+                    json_output += "\n]"
+                    f.write(json_output)
+        else:
+            # Write to stdout when no output file is specified
+            if len(output_buffer) == 1:
+                print(output_buffer[0].model_dump_json(indent=2))
+            else:
+                # Build complete JSON array as a single string
+                json_output = "[\n"
+                for i, response in enumerate(output_buffer):
+                    if i > 0:
+                        json_output += ",\n"
+                    json_output += " " + response.model_dump_json(
+                        indent=2
+                    ).replace("\n", "\n ")
+                json_output += "\n]"
+                print(json_output)
+
+        return ExitCode.SUCCESS
+
+    except (
+        StreamInterruptedError,
+        StreamBufferError,
+        StreamParseError,
+        APIResponseError,
+        EmptyResponseError,
+        InvalidResponseFormatError,
+    ) as e:
+        logger.error("Stream error: %s", str(e))
+        raise CLIError(str(e), exit_code=ExitCode.API_ERROR)
+    except Exception as e:
+        logger.exception("Unexpected error during streaming")
+        raise CLIError(str(e), exit_code=ExitCode.UNKNOWN_ERROR)
+    finally:
+        await client.close()
+
+
+async def run_cli_async(args: CLIParams) -> ExitCode:
+    """Async wrapper for CLI operations.
+
+    Returns:
+        Exit code to return from the CLI
+
+    Raises:
+        CLIError: For various error conditions
+        KeyboardInterrupt: When operation is cancelled by user
+    """
+    try:
+        # 0. Model Parameter Validation
+        logger.debug("=== Model Parameter Validation ===")
+        params = await validate_model_params(args)
+
+        # 1. Input Validation Phase
+        security_manager, task_template, schema, template_context, env = (
+            await validate_inputs(args)
+        )
+
+        # 2. Template Processing Phase
+        system_prompt, user_prompt = await process_templates(
+            args, task_template, template_context, env
+        )
+
+        # 3. Model & Schema Validation Phase
+        output_model, messages, total_tokens, registry = (
+            await validate_model_and_schema(
+                args, schema, system_prompt, user_prompt
             )
+        )
+
+        # 4. Dry Run Output Phase
+        if args.get("dry_run", False):
+            logger.info("\n=== Dry Run Summary ===")
+            logger.info("✓ Template rendered successfully")
+            logger.info("✓ Schema validation passed")
+            logger.info("✓ Model compatibility validated")
+            logger.info(
+                f"✓ Token count: {total_tokens}/{registry.get_capabilities(args['model']).context_window}"
+            )
+
+            if args.get("verbose", False):
+                logger.info("\nSystem Prompt:")
+                logger.info("-" * 40)
+                logger.info(system_prompt)
+                logger.info("\nRendered Template:")
+                logger.info("-" * 40)
+                logger.info(user_prompt)
+
             return ExitCode.SUCCESS
-
-
-
-
-
-            EmptyResponseError,
-            InvalidResponseFormatError,
-        ) as e:
-            logger.error("Stream error: %s", str(e))
-            raise  # Let stream errors propagate
-        except (APIConnectionError, InternalServerError) as e:
-            logger.error("API connection error: %s", str(e))
-            raise APIResponseError(str(e))  # Convert to our error type
-        except RateLimitError as e:
-            logger.error("Rate limit exceeded: %s", str(e))
-            raise APIResponseError(str(e))  # Convert to our error type
-        except (BadRequestError, AuthenticationError, OpenAIClientError) as e:
-            logger.error("API client error: %s", str(e))
-            raise APIResponseError(str(e))  # Convert to our error type
-        finally:
-            await client.close()
+
+        # 5. Execution Phase
+        return await execute_model(
+            args, params, output_model, system_prompt, user_prompt
+        )

     except KeyboardInterrupt:
         logger.info("Operation cancelled by user")
-
+        raise
     except Exception as e:
         if isinstance(e, CLIError):
             raise  # Let our custom errors propagate
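The hunk above buffers every streamed structured response and then writes either a single JSON object or a hand-built, indented JSON array. A minimal sketch of that output-handling pattern follows; the `Item` model, the `serialize_buffer` helper, and the two-space indent are illustrative assumptions, not ostruct's code.

```python
# Sketch only: mirrors the buffer-to-JSON logic in the diff above, with an
# assumed Item model standing in for the dynamically generated output model.
from typing import List

from pydantic import BaseModel


class Item(BaseModel):
    name: str
    value: int


def serialize_buffer(buffer: List[BaseModel]) -> str:
    """Return one object's JSON as-is, or wrap several objects in a JSON array."""
    if len(buffer) == 1:
        return buffer[0].model_dump_json(indent=2)
    parts = []
    for response in buffer:
        # Indent each rendered object so it nests inside the enclosing array.
        parts.append("  " + response.model_dump_json(indent=2).replace("\n", "\n  "))
    return "[\n" + ",\n".join(parts) + "\n]"


print(serialize_buffer([Item(name="a", value=1), Item(name="b", value=2)]))
```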
@@ -1611,68 +1893,35 @@ def create_cli() -> click.Command:
     Returns:
         click.Command: The CLI command object
     """
-
-    @create_click_command()
-    def cli(**kwargs: Any) -> None:
-        """CLI entry point for structured OpenAI API calls."""
-        try:
-            args = Namespace(**kwargs)
-
-            # Validate required arguments first
-            if not args.task and not args.task_file:
-                raise click.UsageError(
-                    "Must specify either --task or --task-file"
-                )
-            if not args.schema_file:
-                raise click.UsageError("Missing option '--schema-file'")
-            if args.task and args.task_file:
-                raise click.UsageError(
-                    "Cannot specify both --task and --task-file"
-                )
-            if args.system_prompt and args.system_prompt_file:
-                raise click.UsageError(
-                    "Cannot specify both --system-prompt and --system-prompt-file"
-                )
-
-            # Run the async function synchronously
-            exit_code = asyncio.run(run_cli_async(args))
-
-            if exit_code != ExitCode.SUCCESS:
-                error_msg = f"Command failed with exit code {exit_code}"
-                if hasattr(ExitCode, exit_code.name):
-                    error_msg = f"{error_msg} ({exit_code.name})"
-                raise CLIError(error_msg, context={"exit_code": exit_code})
-
-        except click.UsageError:
-            # Let Click handle usage errors directly
-            raise
-        except InvalidJSONError:
-            # Let InvalidJSONError propagate directly
-            raise
-        except CLIError:
-            # Let our custom errors propagate with their context
-            raise
-        except Exception as e:
-            # Convert other exceptions to CLIError
-            logger.exception("Unexpected error")
-            raise CLIError(str(e), context={"error_type": type(e).__name__})
-
-    # The decorated function is a Command, but mypy can't detect this
-    return cast(click.Command, cli)
+    return cli  # The decorator already returns a Command


 def main() -> None:
     """Main entry point for the CLI."""
-
-
+    try:
+        cli(standalone_mode=False)
+    except (
+        CLIError,
+        InvalidJSONError,
+        SchemaFileError,
+        SchemaValidationError,
+    ) as e:
+        handle_error(e)
+        sys.exit(
+            e.exit_code if hasattr(e, "exit_code") else ExitCode.INTERNAL_ERROR
+        )
+    except click.UsageError as e:
+        handle_error(e)
+        sys.exit(ExitCode.USAGE_ERROR)
+    except Exception as e:
+        handle_error(e)
+        sys.exit(ExitCode.INTERNAL_ERROR)


 # Export public API
 __all__ = [
     "ExitCode",
-    "
-    "get_context_window_limit",
-    "get_default_token_limit",
+    "estimate_tokens_with_encoding",
     "parse_json_var",
     "create_dynamic_model",
     "validate_path_mapping",
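The new `main()` above invokes the Click command with `standalone_mode=False` and maps exceptions to explicit exit codes itself. A minimal sketch of that pattern under assumed names (a toy `run` command, a reduced `ExitCode`, and plain `print` standing in for ostruct's `handle_error`):

```python
# Sketch only: click's standalone_mode=False lets the caller translate
# exceptions into process exit codes instead of letting Click call sys.exit.
import sys
from enum import IntEnum

import click


class ExitCode(IntEnum):
    SUCCESS = 0
    INTERNAL_ERROR = 1
    USAGE_ERROR = 2


@click.command()
@click.option("--fail", is_flag=True, help="Simulate a runtime failure.")
def run(fail: bool) -> None:
    if fail:
        raise RuntimeError("simulated failure")
    click.echo("ok")


def main() -> None:
    try:
        run(standalone_mode=False)
    except click.UsageError as e:
        print(f"usage error: {e}", file=sys.stderr)
        sys.exit(ExitCode.USAGE_ERROR)
    except Exception as e:  # catch-all, mirroring the shape of the diff above
        print(f"error: {e}", file=sys.stderr)
        sys.exit(ExitCode.INTERNAL_ERROR)


if __name__ == "__main__":
    main()
```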
@@ -1690,17 +1939,16 @@ def create_dynamic_model(
     """Create a Pydantic model from a JSON schema.

     Args:
-        schema: JSON schema
-        base_name:
-        show_schema: Whether to show the generated schema
-        debug_validation: Whether to
+        schema: JSON schema to create model from
+        base_name: Name for the model class
+        show_schema: Whether to show the generated model schema
+        debug_validation: Whether to show detailed validation errors

     Returns:
-
+        Type[BaseModel]: The generated Pydantic model class

     Raises:
-
-        SchemaValidationError: When schema is invalid
+        ModelValidationError: If the schema is invalid
     """
     if debug_validation:
         logger.info("Creating dynamic model from schema:")
@@ -1792,18 +2040,17 @@ def create_dynamic_model(
                 " JSON Schema Extra: %s", config.get("json_schema_extra")
             )

-        #
-        field_definitions: Dict[str, FieldDefinition] = {}
+        # Process schema properties into fields
         properties = schema.get("properties", {})
+        required = schema.get("required", [])

+        field_definitions: Dict[str, Tuple[Type[Any], FieldInfoType]] = {}
         for field_name, field_schema in properties.items():
-
-
-
-                logger.info(
-                    " Schema: %s", json.dumps(field_schema, indent=2)
-                )
+            if debug_validation:
+                logger.info("Processing field %s:", field_name)
+                logger.info(" Schema: %s", json.dumps(field_schema, indent=2))

+            try:
                 python_type, field = _get_type_with_constraints(
                     field_schema, field_name, base_name
                 )
@@ -1838,22 +2085,24 @@ def create_dynamic_model(
                 raise ModelValidationError(base_name, [str(e)])

         # Create the model with the fields
-
-
-
-
-
-
-
-
-
-
-
-
-
-            },
+        field_defs: Dict[str, Any] = {
+            name: (
+                (
+                    cast(Type[Any], field_type)
+                    if is_container_type(field_type)
+                    else field_type
+                ),
+                field,
+            )
+            for name, (field_type, field) in field_definitions.items()
+        }
+        model: Type[BaseModel] = create_model(
+            base_name, __config__=config, **field_defs
         )

+        # Set the model config after creation
+        model.model_config = config
+
         if debug_validation:
             logger.info("Successfully created model: %s", model.__name__)
             logger.info("Model config: %s", dict(model.model_config))
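The hunk above assembles `(type, FieldInfo)` pairs into `pydantic.create_model` with an explicit `__config__`. A minimal sketch of that technique with a toy schema and a hand-written type map — neither is ostruct's `_get_type_with_constraints` or `is_container_type`:

```python
# Sketch only: build (type, FieldInfo) pairs from a toy JSON schema and pass
# them to pydantic.create_model with an explicit model config.
from typing import Any, Dict, Tuple, Type

from pydantic import BaseModel, ConfigDict, Field, create_model

schema: Dict[str, Any] = {
    "properties": {
        "name": {"type": "string", "description": "Display name"},
        "count": {"type": "integer", "default": 0},
    },
    "required": ["name"],
}

TYPE_MAP: Dict[str, type] = {"string": str, "integer": int, "number": float, "boolean": bool}
required = schema.get("required", [])

field_defs: Dict[str, Tuple[type, Any]] = {}
for name, prop in schema["properties"].items():
    py_type = TYPE_MAP.get(prop.get("type", "string"), str)
    # Required fields get `...` (no default); optional ones keep their schema default.
    default = ... if name in required else prop.get("default")
    field_defs[name] = (py_type, Field(default, description=prop.get("description")))

config = ConfigDict(extra="forbid")
Model: Type[BaseModel] = create_model("DynamicModel", __config__=config, **field_defs)

print(Model(name="widget").model_dump())  # {'name': 'widget', 'count': 0}
```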
@@ -1870,10 +2119,6 @@ def create_dynamic_model(
                 logger.error("Schema validation failed:")
                 logger.error(" Error type: %s", type(e).__name__)
                 logger.error(" Error message: %s", str(e))
-                if hasattr(e, "errors"):
-                    logger.error(" Validation errors:")
-                    for error in e.errors():
-                        logger.error(" - %s", error)
             validation_errors = (
                 [str(err) for err in e.errors()]
                 if hasattr(e, "errors")
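The hunk above drops the per-error debug logging and keeps only the `[str(err) for err in e.errors()]` summary. A minimal sketch of that pydantic pattern with a throwaway model (not the dynamically generated one):

```python
# Sketch only: collect a pydantic ValidationError's individual errors as
# strings, the same shape as the validation_errors list in the diff above.
from pydantic import BaseModel, ValidationError


class Probe(BaseModel):
    name: str
    count: int


try:
    Probe(name=123, count="not-a-number")  # type: ignore[arg-type]
except ValidationError as e:
    validation_errors = [str(err) for err in e.errors()]
    for line in validation_errors:
        print(line)
```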
@@ -1881,7 +2126,7 @@ def create_dynamic_model(
             )
             raise ModelValidationError(base_name, validation_errors)

-        return
+        return model

     except Exception as e:
         if debug_validation: