ostruct-cli 0.2.0__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ostruct/cli/__init__.py +2 -2
- ostruct/cli/cli.py +466 -604
- ostruct/cli/click_options.py +257 -0
- ostruct/cli/errors.py +234 -183
- ostruct/cli/file_info.py +154 -50
- ostruct/cli/file_list.py +189 -64
- ostruct/cli/file_utils.py +95 -67
- ostruct/cli/path_utils.py +58 -77
- ostruct/cli/security/__init__.py +32 -0
- ostruct/cli/security/allowed_checker.py +47 -0
- ostruct/cli/security/case_manager.py +75 -0
- ostruct/cli/security/errors.py +184 -0
- ostruct/cli/security/normalization.py +161 -0
- ostruct/cli/security/safe_joiner.py +211 -0
- ostruct/cli/security/security_manager.py +353 -0
- ostruct/cli/security/symlink_resolver.py +483 -0
- ostruct/cli/security/types.py +108 -0
- ostruct/cli/security/windows_paths.py +404 -0
- ostruct/cli/template_filters.py +8 -5
- ostruct/cli/template_io.py +4 -2
- {ostruct_cli-0.2.0.dist-info → ostruct_cli-0.4.0.dist-info}/METADATA +9 -6
- ostruct_cli-0.4.0.dist-info/RECORD +36 -0
- ostruct/cli/security.py +0 -323
- ostruct/cli/security_types.py +0 -49
- ostruct_cli-0.2.0.dist-info/RECORD +0 -27
- {ostruct_cli-0.2.0.dist-info → ostruct_cli-0.4.0.dist-info}/LICENSE +0 -0
- {ostruct_cli-0.2.0.dist-info → ostruct_cli-0.4.0.dist-info}/WHEEL +0 -0
- {ostruct_cli-0.2.0.dist-info → ostruct_cli-0.4.0.dist-info}/entry_points.txt +0 -0
ostruct/cli/cli.py
CHANGED
@@ -1,18 +1,12 @@
|
|
1
1
|
"""Command-line interface for making structured OpenAI API calls."""
|
2
2
|
|
3
|
-
import argparse
|
4
3
|
import asyncio
|
5
4
|
import json
|
6
5
|
import logging
|
7
6
|
import os
|
8
7
|
import sys
|
8
|
+
from dataclasses import dataclass
|
9
9
|
from enum import Enum, IntEnum
|
10
|
-
|
11
|
-
if sys.version_info >= (3, 11):
|
12
|
-
from enum import StrEnum
|
13
|
-
|
14
|
-
from datetime import date, datetime, time
|
15
|
-
from pathlib import Path
|
16
10
|
from typing import (
|
17
11
|
Any,
|
18
12
|
Dict,
|
@@ -29,6 +23,13 @@ from typing import (
|
|
29
23
|
overload,
|
30
24
|
)
|
31
25
|
|
26
|
+
if sys.version_info >= (3, 11):
|
27
|
+
from enum import StrEnum
|
28
|
+
|
29
|
+
from datetime import date, datetime, time
|
30
|
+
from pathlib import Path
|
31
|
+
|
32
|
+
import click
|
32
33
|
import jinja2
|
33
34
|
import tiktoken
|
34
35
|
import yaml
|
@@ -42,6 +43,8 @@ from openai import (
|
|
42
43
|
)
|
43
44
|
from openai_structured.client import (
|
44
45
|
async_openai_structured_stream,
|
46
|
+
get_context_window_limit,
|
47
|
+
get_default_token_limit,
|
45
48
|
supports_structured_output,
|
46
49
|
)
|
47
50
|
from openai_structured.errors import (
|
@@ -71,8 +74,11 @@ from pydantic.functional_validators import BeforeValidator
|
|
71
74
|
from pydantic.types import constr
|
72
75
|
from typing_extensions import TypeAlias
|
73
76
|
|
74
|
-
from
|
77
|
+
from ostruct.cli.click_options import create_click_command
|
78
|
+
|
79
|
+
from .. import __version__ # noqa: F401 - Used in package metadata
|
75
80
|
from .errors import (
|
81
|
+
CLIError,
|
76
82
|
DirectoryNotFoundError,
|
77
83
|
FieldDefinitionError,
|
78
84
|
FileNotFoundError,
|
@@ -89,7 +95,6 @@ from .errors import (
|
|
89
95
|
)
|
90
96
|
from .file_utils import FileInfoList, TemplateValue, collect_files
|
91
97
|
from .path_utils import validate_path_mapping
|
92
|
-
from .progress import ProgressContext
|
93
98
|
from .security import SecurityManager
|
94
99
|
from .template_env import create_jinja_env
|
95
100
|
from .template_utils import SystemPromptError, render_template
|
@@ -97,6 +102,45 @@ from .template_utils import SystemPromptError, render_template
|
|
97
102
|
# Constants
|
98
103
|
DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant."
|
99
104
|
|
105
|
+
|
106
|
+
@dataclass
|
107
|
+
class Namespace:
|
108
|
+
"""Compatibility class to mimic argparse.Namespace for existing code."""
|
109
|
+
|
110
|
+
task: Optional[str]
|
111
|
+
task_file: Optional[str]
|
112
|
+
file: List[str]
|
113
|
+
files: List[str]
|
114
|
+
dir: List[str]
|
115
|
+
allowed_dir: List[str]
|
116
|
+
base_dir: str
|
117
|
+
allowed_dir_file: Optional[str]
|
118
|
+
dir_recursive: bool
|
119
|
+
dir_ext: Optional[str]
|
120
|
+
var: List[str]
|
121
|
+
json_var: List[str]
|
122
|
+
system_prompt: Optional[str]
|
123
|
+
system_prompt_file: Optional[str]
|
124
|
+
ignore_task_sysprompt: bool
|
125
|
+
schema_file: str
|
126
|
+
model: str
|
127
|
+
temperature: float
|
128
|
+
max_tokens: Optional[int]
|
129
|
+
top_p: float
|
130
|
+
frequency_penalty: float
|
131
|
+
presence_penalty: float
|
132
|
+
timeout: float
|
133
|
+
output_file: Optional[str]
|
134
|
+
dry_run: bool
|
135
|
+
no_progress: bool
|
136
|
+
api_key: Optional[str]
|
137
|
+
verbose: bool
|
138
|
+
debug_openai_stream: bool
|
139
|
+
show_model_schema: bool
|
140
|
+
debug_validation: bool
|
141
|
+
progress_level: str = "basic" # Default to 'basic' if not specified
|
142
|
+
|
143
|
+
|
100
144
|
# Set up logging
|
101
145
|
logger = logging.getLogger(__name__)
|
102
146
|
|
@@ -360,65 +404,43 @@ V = TypeVar("V")
|
|
360
404
|
|
361
405
|
|
362
406
|
def estimate_tokens_for_chat(
|
363
|
-
messages: List[Dict[str, str]],
|
407
|
+
messages: List[Dict[str, str]],
|
408
|
+
model: str,
|
409
|
+
encoder: Any = None,
|
364
410
|
) -> int:
|
365
|
-
"""Estimate the number of tokens in a chat completion.
|
366
|
-
try:
|
367
|
-
encoding = tiktoken.encoding_for_model(model)
|
368
|
-
except KeyError:
|
369
|
-
# Fall back to cl100k_base for unknown models
|
370
|
-
encoding = tiktoken.get_encoding("cl100k_base")
|
371
|
-
|
372
|
-
num_tokens = 0
|
373
|
-
for message in messages:
|
374
|
-
# Add message overhead
|
375
|
-
num_tokens += 4 # every message follows <im_start>{role/name}\n{content}<im_end>\n
|
376
|
-
for key, value in message.items():
|
377
|
-
num_tokens += len(encoding.encode(str(value)))
|
378
|
-
if key == "name": # if there's a name, the role is omitted
|
379
|
-
num_tokens += -1 # role is always required and always 1 token
|
380
|
-
num_tokens += 2 # every reply is primed with <im_start>assistant
|
381
|
-
return num_tokens
|
382
|
-
|
383
|
-
|
384
|
-
def get_default_token_limit(model: str) -> int:
|
385
|
-
"""Get the default token limit for a given model.
|
386
|
-
|
387
|
-
Note: These limits are based on current OpenAI model specifications as of 2024 and may
|
388
|
-
need to be updated if OpenAI changes the models' capabilities.
|
411
|
+
"""Estimate the number of tokens in a chat completion.
|
389
412
|
|
390
413
|
Args:
|
391
|
-
|
392
|
-
|
393
|
-
|
394
|
-
The default token limit for the model
|
414
|
+
messages: List of chat messages
|
415
|
+
model: Model name
|
416
|
+
encoder: Optional tiktoken encoder for testing. If provided, only uses encoder.encode() results.
|
395
417
|
"""
|
396
|
-
if
|
397
|
-
|
398
|
-
|
399
|
-
|
400
|
-
|
401
|
-
|
402
|
-
|
403
|
-
|
404
|
-
|
405
|
-
|
406
|
-
|
407
|
-
|
408
|
-
|
409
|
-
|
410
|
-
|
411
|
-
|
412
|
-
|
413
|
-
|
414
|
-
|
415
|
-
"""
|
416
|
-
if "o1" in model:
|
417
|
-
return 200_000 # o1 supports 200K total context window
|
418
|
-
elif "gpt-4o" in model:
|
419
|
-
return 128_000 # gpt-4o and gpt-4o-mini support 128K context window
|
418
|
+
if encoder is None:
|
419
|
+
try:
|
420
|
+
# Try to get the encoding for the specific model
|
421
|
+
encoder = tiktoken.get_encoding("o200k_base")
|
422
|
+
except KeyError:
|
423
|
+
# Fall back to cl100k_base for unknown models
|
424
|
+
encoder = tiktoken.get_encoding("cl100k_base")
|
425
|
+
|
426
|
+
# Use standard token counting logic for real tiktoken encoders
|
427
|
+
num_tokens = 0
|
428
|
+
for message in messages:
|
429
|
+
# Add message overhead
|
430
|
+
num_tokens += 4 # every message follows <im_start>{role/name}\n{content}<im_end>\n
|
431
|
+
for key, value in message.items():
|
432
|
+
num_tokens += len(encoder.encode(str(value)))
|
433
|
+
if key == "name": # if there's a name, the role is omitted
|
434
|
+
num_tokens -= 1 # role is omitted
|
435
|
+
num_tokens += 2 # every reply is primed with <im_start>assistant
|
436
|
+
return num_tokens
|
420
437
|
else:
|
421
|
-
return
|
438
|
+
# For mock encoders in tests, just return the length of encoded content
|
439
|
+
num_tokens = 0
|
440
|
+
for message in messages:
|
441
|
+
for value in message.values():
|
442
|
+
num_tokens += len(encoder.encode(str(value)))
|
443
|
+
return num_tokens
|
422
444
|
|
423
445
|
|
424
446
|
def validate_token_limits(
|
@@ -460,6 +482,7 @@ def validate_token_limits(
|
|
460
482
|
def process_system_prompt(
|
461
483
|
task_template: str,
|
462
484
|
system_prompt: Optional[str],
|
485
|
+
system_prompt_file: Optional[str],
|
463
486
|
template_context: Dict[str, Any],
|
464
487
|
env: jinja2.Environment,
|
465
488
|
ignore_task_sysprompt: bool = False,
|
@@ -468,7 +491,8 @@ def process_system_prompt(
|
|
468
491
|
|
469
492
|
Args:
|
470
493
|
task_template: The task template string
|
471
|
-
system_prompt: Optional system prompt string
|
494
|
+
system_prompt: Optional system prompt string
|
495
|
+
system_prompt_file: Optional path to system prompt file
|
472
496
|
template_context: Template context for rendering
|
473
497
|
env: Jinja2 environment
|
474
498
|
ignore_task_sysprompt: Whether to ignore system prompt in task template
|
@@ -484,18 +508,24 @@ def process_system_prompt(
|
|
484
508
|
# Default system prompt
|
485
509
|
default_prompt = "You are a helpful assistant."
|
486
510
|
|
511
|
+
# Check for conflicting arguments
|
512
|
+
if system_prompt is not None and system_prompt_file is not None:
|
513
|
+
raise SystemPromptError(
|
514
|
+
"Cannot specify both --system-prompt and --system-prompt-file"
|
515
|
+
)
|
516
|
+
|
487
517
|
# Try to get system prompt from CLI argument first
|
488
|
-
if
|
489
|
-
|
490
|
-
|
491
|
-
|
492
|
-
|
493
|
-
|
494
|
-
|
495
|
-
|
496
|
-
|
497
|
-
raise SystemPromptError(f"Invalid system prompt file: {e}")
|
518
|
+
if system_prompt_file is not None:
|
519
|
+
try:
|
520
|
+
name, path = validate_path_mapping(
|
521
|
+
f"system_prompt={system_prompt_file}"
|
522
|
+
)
|
523
|
+
with open(path, "r", encoding="utf-8") as f:
|
524
|
+
system_prompt = f.read().strip()
|
525
|
+
except (FileNotFoundError, PathSecurityError) as e:
|
526
|
+
raise SystemPromptError(f"Invalid system prompt file: {e}")
|
498
527
|
|
528
|
+
if system_prompt is not None:
|
499
529
|
# Render system prompt with template context
|
500
530
|
try:
|
501
531
|
template = env.from_string(system_prompt)
|
@@ -618,30 +648,45 @@ def _validate_path_mapping_internal(
|
|
618
648
|
ValueError: If the format is invalid (missing "=").
|
619
649
|
OSError: If there is an underlying OS error (permissions, etc.).
|
620
650
|
"""
|
651
|
+
logger = logging.getLogger(__name__)
|
652
|
+
logger.debug("Starting path validation for mapping: %r", mapping)
|
653
|
+
logger.debug("Parameters - is_dir: %r, base_dir: %r", is_dir, base_dir)
|
654
|
+
|
621
655
|
try:
|
622
656
|
if not mapping or "=" not in mapping:
|
657
|
+
logger.debug("Invalid mapping format: %r", mapping)
|
623
658
|
raise ValueError(
|
624
659
|
"Invalid path mapping format. Expected format: name=path"
|
625
660
|
)
|
626
661
|
|
627
662
|
name, path = mapping.split("=", 1)
|
663
|
+
logger.debug("Split mapping - name: %r, path: %r", name, path)
|
664
|
+
|
628
665
|
if not name:
|
666
|
+
logger.debug("Empty name in mapping")
|
629
667
|
raise VariableNameError(
|
630
668
|
f"Empty name in {'directory' if is_dir else 'file'} mapping"
|
631
669
|
)
|
632
670
|
|
633
671
|
if not path:
|
672
|
+
logger.debug("Empty path in mapping")
|
634
673
|
raise VariableValueError("Path cannot be empty")
|
635
674
|
|
636
675
|
# Convert to Path object and resolve against base_dir if provided
|
676
|
+
logger.debug("Creating Path object for: %r", path)
|
637
677
|
path_obj = Path(path)
|
638
678
|
if base_dir:
|
679
|
+
logger.debug("Resolving against base_dir: %r", base_dir)
|
639
680
|
path_obj = Path(base_dir) / path_obj
|
681
|
+
logger.debug("Path object created: %r", path_obj)
|
640
682
|
|
641
683
|
# Resolve the path to catch directory traversal attempts
|
642
684
|
try:
|
685
|
+
logger.debug("Attempting to resolve path: %r", path_obj)
|
643
686
|
resolved_path = path_obj.resolve()
|
687
|
+
logger.debug("Resolved path: %r", resolved_path)
|
644
688
|
except OSError as e:
|
689
|
+
logger.error("Failed to resolve path: %s", e)
|
645
690
|
raise OSError(f"Failed to resolve path: {e}")
|
646
691
|
|
647
692
|
# Check for directory traversal
|
@@ -686,7 +731,9 @@ def _validate_path_mapping_internal(
|
|
686
731
|
raise
|
687
732
|
|
688
733
|
if security_manager:
|
689
|
-
|
734
|
+
try:
|
735
|
+
security_manager.validate_path(str(resolved_path))
|
736
|
+
except PathSecurityError:
|
690
737
|
raise PathSecurityError.from_expanded_paths(
|
691
738
|
original_path=str(path),
|
692
739
|
expanded_path=str(resolved_path),
|
@@ -709,34 +756,45 @@ def _validate_path_mapping_internal(
|
|
709
756
|
raise
|
710
757
|
|
711
758
|
|
712
|
-
def validate_task_template(
|
759
|
+
def validate_task_template(
|
760
|
+
task: Optional[str], task_file: Optional[str]
|
761
|
+
) -> str:
|
713
762
|
"""Validate and load a task template.
|
714
763
|
|
715
764
|
Args:
|
716
|
-
task: The task template string
|
765
|
+
task: The task template string
|
766
|
+
task_file: Path to task template file
|
717
767
|
|
718
768
|
Returns:
|
719
769
|
The task template string
|
720
770
|
|
721
771
|
Raises:
|
722
|
-
TaskTemplateVariableError: If
|
772
|
+
TaskTemplateVariableError: If neither task nor task_file is provided, or if both are provided
|
723
773
|
TaskTemplateSyntaxError: If the template has invalid syntax
|
724
774
|
FileNotFoundError: If the template file does not exist
|
725
775
|
PathSecurityError: If the template file path violates security constraints
|
726
776
|
"""
|
727
|
-
|
777
|
+
if task is not None and task_file is not None:
|
778
|
+
raise TaskTemplateVariableError(
|
779
|
+
"Cannot specify both --task and --task-file"
|
780
|
+
)
|
781
|
+
|
782
|
+
if task is None and task_file is None:
|
783
|
+
raise TaskTemplateVariableError(
|
784
|
+
"Must specify either --task or --task-file"
|
785
|
+
)
|
728
786
|
|
729
|
-
|
730
|
-
if
|
731
|
-
path = task[1:]
|
787
|
+
template_content: str
|
788
|
+
if task_file is not None:
|
732
789
|
try:
|
733
|
-
name, path = validate_path_mapping(f"task={
|
790
|
+
name, path = validate_path_mapping(f"task={task_file}")
|
734
791
|
with open(path, "r", encoding="utf-8") as f:
|
735
792
|
template_content = f.read()
|
736
793
|
except (FileNotFoundError, PathSecurityError) as e:
|
737
|
-
raise TaskTemplateVariableError(
|
794
|
+
raise TaskTemplateVariableError(str(e))
|
795
|
+
else:
|
796
|
+
template_content = task # type: ignore # We know task is str here due to the checks above
|
738
797
|
|
739
|
-
# Validate template syntax
|
740
798
|
try:
|
741
799
|
env = jinja2.Environment(undefined=jinja2.StrictUndefined)
|
742
800
|
env.parse(template_content)
|
@@ -813,7 +871,7 @@ def validate_schema_file(
|
|
813
871
|
|
814
872
|
|
815
873
|
def collect_template_files(
|
816
|
-
args:
|
874
|
+
args: Namespace,
|
817
875
|
security_manager: SecurityManager,
|
818
876
|
) -> Dict[str, TemplateValue]:
|
819
877
|
"""Collect files from command line arguments.
|
@@ -846,14 +904,17 @@ def collect_template_files(
|
|
846
904
|
# Wrap file-related errors
|
847
905
|
raise ValueError(f"File access error: {e}")
|
848
906
|
except Exception as e:
|
907
|
+
# Don't wrap InvalidJSONError
|
908
|
+
if isinstance(e, InvalidJSONError):
|
909
|
+
raise
|
849
910
|
# Check if this is a wrapped security error
|
850
911
|
if isinstance(e.__cause__, PathSecurityError):
|
851
912
|
raise e.__cause__
|
852
|
-
# Wrap
|
913
|
+
# Wrap other errors
|
853
914
|
raise ValueError(f"Error collecting files: {e}")
|
854
915
|
|
855
916
|
|
856
|
-
def collect_simple_variables(args:
|
917
|
+
def collect_simple_variables(args: Namespace) -> Dict[str, str]:
|
857
918
|
"""Collect simple string variables from --var arguments.
|
858
919
|
|
859
920
|
Args:
|
@@ -886,7 +947,7 @@ def collect_simple_variables(args: argparse.Namespace) -> Dict[str, str]:
|
|
886
947
|
return variables
|
887
948
|
|
888
949
|
|
889
|
-
def collect_json_variables(args:
|
950
|
+
def collect_json_variables(args: Namespace) -> Dict[str, Any]:
|
890
951
|
"""Collect JSON variables from --json-var arguments.
|
891
952
|
|
892
953
|
Args:
|
@@ -916,7 +977,7 @@ def collect_json_variables(args: argparse.Namespace) -> Dict[str, Any]:
|
|
916
977
|
all_names.add(name)
|
917
978
|
except json.JSONDecodeError as e:
|
918
979
|
raise InvalidJSONError(
|
919
|
-
f"
|
980
|
+
f"Error parsing JSON for variable '{name}': {str(e)}. Input was: {json_str}"
|
920
981
|
)
|
921
982
|
except ValueError:
|
922
983
|
raise VariableNameError(
|
@@ -972,7 +1033,7 @@ def create_template_context(
|
|
972
1033
|
|
973
1034
|
|
974
1035
|
def create_template_context_from_args(
|
975
|
-
args:
|
1036
|
+
args: "Namespace",
|
976
1037
|
security_manager: SecurityManager,
|
977
1038
|
) -> Dict[str, Any]:
|
978
1039
|
"""Create template context from command line arguments.
|
@@ -1024,7 +1085,7 @@ def create_template_context_from_args(
|
|
1024
1085
|
json_value = json.loads(value)
|
1025
1086
|
except json.JSONDecodeError as e:
|
1026
1087
|
raise InvalidJSONError(
|
1027
|
-
f"
|
1088
|
+
f"Error parsing JSON for variable '{name}': {str(e)}. Input was: {value}"
|
1028
1089
|
)
|
1029
1090
|
if name in json_variables:
|
1030
1091
|
raise VariableNameError(
|
@@ -1060,42 +1121,59 @@ def create_template_context_from_args(
|
|
1060
1121
|
# Wrap file-related errors
|
1061
1122
|
raise ValueError(f"File access error: {e}")
|
1062
1123
|
except Exception as e:
|
1124
|
+
# Don't wrap InvalidJSONError
|
1125
|
+
if isinstance(e, InvalidJSONError):
|
1126
|
+
raise
|
1063
1127
|
# Check if this is a wrapped security error
|
1064
1128
|
if isinstance(e.__cause__, PathSecurityError):
|
1065
1129
|
raise e.__cause__
|
1066
|
-
# Wrap
|
1130
|
+
# Wrap other errors
|
1067
1131
|
raise ValueError(f"Error collecting files: {e}")
|
1068
1132
|
|
1069
1133
|
|
1070
1134
|
def validate_security_manager(
|
1071
1135
|
base_dir: Optional[str] = None,
|
1072
1136
|
allowed_dirs: Optional[List[str]] = None,
|
1073
|
-
|
1137
|
+
allowed_dir_file: Optional[str] = None,
|
1074
1138
|
) -> SecurityManager:
|
1075
|
-
"""
|
1139
|
+
"""Validate and create security manager.
|
1076
1140
|
|
1077
1141
|
Args:
|
1078
|
-
base_dir:
|
1079
|
-
allowed_dirs: Optional list of allowed
|
1080
|
-
|
1142
|
+
base_dir: Base directory for file access. Defaults to current working directory.
|
1143
|
+
allowed_dirs: Optional list of additional allowed directories
|
1144
|
+
allowed_dir_file: Optional file containing allowed directories
|
1081
1145
|
|
1082
1146
|
Returns:
|
1083
1147
|
Configured SecurityManager instance
|
1084
1148
|
|
1085
1149
|
Raises:
|
1086
|
-
|
1087
|
-
|
1150
|
+
PathSecurityError: If any paths violate security constraints
|
1151
|
+
DirectoryNotFoundError: If any directories do not exist
|
1088
1152
|
"""
|
1089
|
-
#
|
1090
|
-
|
1091
|
-
|
1153
|
+
# Use current working directory if base_dir is None
|
1154
|
+
if base_dir is None:
|
1155
|
+
base_dir = os.getcwd()
|
1092
1156
|
|
1093
|
-
|
1094
|
-
|
1157
|
+
# Create security manager with base directory
|
1158
|
+
security_manager = SecurityManager(base_dir)
|
1095
1159
|
|
1160
|
+
# Add explicitly allowed directories
|
1096
1161
|
if allowed_dirs:
|
1097
|
-
for
|
1098
|
-
security_manager.
|
1162
|
+
for dir_path in allowed_dirs:
|
1163
|
+
security_manager.add_allowed_directory(dir_path)
|
1164
|
+
|
1165
|
+
# Add directories from file if specified
|
1166
|
+
if allowed_dir_file:
|
1167
|
+
try:
|
1168
|
+
with open(allowed_dir_file, "r", encoding="utf-8") as f:
|
1169
|
+
for line in f:
|
1170
|
+
line = line.strip()
|
1171
|
+
if line and not line.startswith("#"):
|
1172
|
+
security_manager.add_allowed_directory(line)
|
1173
|
+
except OSError as e:
|
1174
|
+
raise DirectoryNotFoundError(
|
1175
|
+
f"Failed to read allowed directories file: {e}"
|
1176
|
+
)
|
1099
1177
|
|
1100
1178
|
return security_manager
|
1101
1179
|
|
@@ -1157,8 +1235,8 @@ def parse_json_var(var_str: str) -> Tuple[str, Any]:
|
|
1157
1235
|
value = json.loads(json_str)
|
1158
1236
|
except json.JSONDecodeError as e:
|
1159
1237
|
raise InvalidJSONError(
|
1160
|
-
f"
|
1161
|
-
)
|
1238
|
+
f"Error parsing JSON for variable '{name}': {str(e)}. Input was: {json_str}"
|
1239
|
+
)
|
1162
1240
|
|
1163
1241
|
return name, value
|
1164
1242
|
|
@@ -1205,582 +1283,366 @@ def _create_enum_type(values: List[Any], field_name: str) -> Type[Enum]:
|
|
1205
1283
|
return type(f"{field_name.title()}Enum", (str, Enum), enum_dict)
|
1206
1284
|
|
1207
1285
|
|
1208
|
-
def
|
1209
|
-
"""
|
1210
|
-
|
1211
|
-
|
1212
|
-
|
1213
|
-
|
1214
|
-
|
1215
|
-
|
1216
|
-
|
1217
|
-
|
1218
|
-
|
1219
|
-
|
1220
|
-
|
1221
|
-
|
1222
|
-
|
1223
|
-
"
|
1224
|
-
|
1225
|
-
|
1226
|
-
|
1227
|
-
|
1228
|
-
"
|
1229
|
-
|
1230
|
-
|
1231
|
-
|
1232
|
-
|
1233
|
-
"
|
1234
|
-
|
1235
|
-
|
1236
|
-
|
1237
|
-
|
1238
|
-
|
1239
|
-
|
1240
|
-
|
1241
|
-
|
1242
|
-
|
1243
|
-
|
1244
|
-
|
1245
|
-
|
1246
|
-
|
1247
|
-
|
1248
|
-
|
1249
|
-
|
1250
|
-
|
1251
|
-
|
1252
|
-
|
1253
|
-
|
1254
|
-
|
1255
|
-
|
1256
|
-
action="append",
|
1257
|
-
default=[],
|
1258
|
-
help="Map file pattern to variable (name=pattern)",
|
1259
|
-
metavar="NAME=PATTERN",
|
1260
|
-
)
|
1261
|
-
parser.add_argument(
|
1262
|
-
"--dir",
|
1263
|
-
action="append",
|
1264
|
-
default=[],
|
1265
|
-
help="Map directory to variable (name=path)",
|
1266
|
-
metavar="NAME=PATH",
|
1267
|
-
)
|
1268
|
-
parser.add_argument(
|
1269
|
-
"--allowed-dir",
|
1270
|
-
action="append",
|
1271
|
-
default=[],
|
1272
|
-
help="Additional allowed directory or @file",
|
1273
|
-
metavar="PATH",
|
1274
|
-
)
|
1275
|
-
parser.add_argument(
|
1276
|
-
"--base-dir",
|
1277
|
-
help="Base directory for file access (defaults to current directory)",
|
1278
|
-
default=os.getcwd(),
|
1279
|
-
)
|
1280
|
-
parser.add_argument(
|
1281
|
-
"--allowed-dirs-file",
|
1282
|
-
help="File containing list of allowed directories",
|
1283
|
-
)
|
1284
|
-
parser.add_argument(
|
1285
|
-
"--dir-recursive",
|
1286
|
-
action="store_true",
|
1287
|
-
help="Process directories recursively",
|
1288
|
-
)
|
1289
|
-
parser.add_argument(
|
1290
|
-
"--dir-ext",
|
1291
|
-
help="Comma-separated list of file extensions to include in directory processing",
|
1292
|
-
)
|
1293
|
-
|
1294
|
-
# Variable arguments
|
1295
|
-
parser.add_argument(
|
1296
|
-
"--var",
|
1297
|
-
action="append",
|
1298
|
-
default=[],
|
1299
|
-
help="Pass simple variables (name=value)",
|
1300
|
-
metavar="NAME=VALUE",
|
1301
|
-
)
|
1302
|
-
parser.add_argument(
|
1303
|
-
"--json-var",
|
1304
|
-
action="append",
|
1305
|
-
default=[],
|
1306
|
-
help="Pass JSON variables (name=json)",
|
1307
|
-
metavar="NAME=JSON",
|
1308
|
-
)
|
1309
|
-
|
1310
|
-
# System prompt options
|
1311
|
-
parser.add_argument(
|
1312
|
-
"--system-prompt",
|
1313
|
-
help=(
|
1314
|
-
"System prompt for the model (use @file to load from file, "
|
1315
|
-
"can also be specified in task template YAML frontmatter)"
|
1316
|
-
),
|
1317
|
-
default=DEFAULT_SYSTEM_PROMPT,
|
1318
|
-
)
|
1319
|
-
parser.add_argument(
|
1320
|
-
"--ignore-task-sysprompt",
|
1321
|
-
action="store_true",
|
1322
|
-
help="Ignore system prompt from task template YAML frontmatter",
|
1323
|
-
)
|
1286
|
+
def handle_error(e: Exception) -> None:
|
1287
|
+
"""Handle errors by printing appropriate message and exiting with status code."""
|
1288
|
+
if isinstance(e, click.UsageError):
|
1289
|
+
# For UsageError, preserve the original message format
|
1290
|
+
if hasattr(e, "param") and e.param:
|
1291
|
+
# Missing parameter error
|
1292
|
+
msg = f"Missing option '--{e.param.name}'"
|
1293
|
+
click.echo(msg, err=True)
|
1294
|
+
else:
|
1295
|
+
# Other usage errors (like conflicting options)
|
1296
|
+
click.echo(str(e), err=True)
|
1297
|
+
sys.exit(ExitCode.USAGE_ERROR)
|
1298
|
+
elif isinstance(e, InvalidJSONError):
|
1299
|
+
# Use the original error message if available
|
1300
|
+
msg = str(e) if str(e) != "None" else "Invalid JSON"
|
1301
|
+
click.secho(msg, fg="red", err=True)
|
1302
|
+
sys.exit(ExitCode.DATA_ERROR)
|
1303
|
+
elif isinstance(e, FileNotFoundError):
|
1304
|
+
# Use the original error message if available
|
1305
|
+
msg = str(e) if str(e) != "None" else "File not found"
|
1306
|
+
click.secho(msg, fg="red", err=True)
|
1307
|
+
sys.exit(ExitCode.SCHEMA_ERROR)
|
1308
|
+
elif isinstance(e, TaskTemplateSyntaxError):
|
1309
|
+
# Use the original error message if available
|
1310
|
+
msg = str(e) if str(e) != "None" else "Template syntax error"
|
1311
|
+
click.secho(msg, fg="red", err=True)
|
1312
|
+
sys.exit(ExitCode.INTERNAL_ERROR)
|
1313
|
+
elif isinstance(e, CLIError):
|
1314
|
+
# Use the show method for CLIError and its subclasses
|
1315
|
+
e.show()
|
1316
|
+
sys.exit(
|
1317
|
+
e.exit_code if hasattr(e, "exit_code") else ExitCode.INTERNAL_ERROR
|
1318
|
+
)
|
1319
|
+
else:
|
1320
|
+
click.secho(f"Unexpected error: {str(e)}", fg="red", err=True)
|
1321
|
+
sys.exit(ExitCode.INTERNAL_ERROR)
|
1322
|
+
|
1323
|
+
|
1324
|
+
async def stream_structured_output(
|
1325
|
+
client: AsyncOpenAI,
|
1326
|
+
model: str,
|
1327
|
+
system_prompt: str,
|
1328
|
+
user_prompt: str,
|
1329
|
+
output_schema: Type[BaseModel],
|
1330
|
+
output_file: Optional[str] = None,
|
1331
|
+
**kwargs: Any,
|
1332
|
+
) -> None:
|
1333
|
+
"""Stream structured output from OpenAI API.
|
1324
1334
|
|
1325
|
-
|
1326
|
-
|
1327
|
-
|
1328
|
-
|
1329
|
-
|
1330
|
-
|
1331
|
-
)
|
1332
|
-
parser.add_argument(
|
1333
|
-
"--validate-schema",
|
1334
|
-
action="store_true",
|
1335
|
-
help="Validate schema and response",
|
1336
|
-
)
|
1335
|
+
This function follows the guide's recommendation for a focused async streaming function.
|
1336
|
+
It handles the core streaming logic and resource cleanup.
|
1337
|
+
"""
|
1338
|
+
try:
|
1339
|
+
# Base models that don't support streaming
|
1340
|
+
non_streaming_models = {"o1", "o3"}
|
1337
1341
|
|
1338
|
-
|
1339
|
-
|
1340
|
-
|
1341
|
-
|
1342
|
-
|
1343
|
-
)
|
1344
|
-
parser.add_argument(
|
1345
|
-
"--temperature",
|
1346
|
-
type=float,
|
1347
|
-
default=0.0,
|
1348
|
-
help="Temperature (0.0-2.0)",
|
1349
|
-
)
|
1350
|
-
parser.add_argument(
|
1351
|
-
"--max-tokens",
|
1352
|
-
type=int,
|
1353
|
-
help="Maximum tokens to generate",
|
1354
|
-
)
|
1355
|
-
parser.add_argument(
|
1356
|
-
"--top-p",
|
1357
|
-
type=float,
|
1358
|
-
default=1.0,
|
1359
|
-
help="Top-p sampling (0.0-1.0)",
|
1360
|
-
)
|
1361
|
-
parser.add_argument(
|
1362
|
-
"--frequency-penalty",
|
1363
|
-
type=float,
|
1364
|
-
default=0.0,
|
1365
|
-
help="Frequency penalty (-2.0-2.0)",
|
1366
|
-
)
|
1367
|
-
parser.add_argument(
|
1368
|
-
"--presence-penalty",
|
1369
|
-
type=float,
|
1370
|
-
default=0.0,
|
1371
|
-
help="Presence penalty (-2.0-2.0)",
|
1372
|
-
)
|
1373
|
-
parser.add_argument(
|
1374
|
-
"--timeout",
|
1375
|
-
type=float,
|
1376
|
-
default=60.0,
|
1377
|
-
help="API timeout in seconds",
|
1378
|
-
)
|
1342
|
+
# Check if model supports streaming
|
1343
|
+
# o3-mini and o3-mini-high support streaming, base o3 does not
|
1344
|
+
use_streaming = model not in non_streaming_models and (
|
1345
|
+
not model.startswith("o3") or model.startswith("o3-mini")
|
1346
|
+
)
|
1379
1347
|
|
1380
|
-
|
1381
|
-
|
1382
|
-
"
|
1383
|
-
|
1384
|
-
|
1385
|
-
|
1386
|
-
|
1387
|
-
|
1388
|
-
|
1389
|
-
|
1390
|
-
|
1391
|
-
|
1392
|
-
|
1393
|
-
|
1394
|
-
|
1348
|
+
# All o1 and o3 models (base and variants) have fixed settings
|
1349
|
+
stream_kwargs = {}
|
1350
|
+
if not (model.startswith("o1") or model.startswith("o3")):
|
1351
|
+
stream_kwargs = kwargs
|
1352
|
+
|
1353
|
+
if use_streaming:
|
1354
|
+
async for chunk in async_openai_structured_stream(
|
1355
|
+
client=client,
|
1356
|
+
model=model,
|
1357
|
+
output_schema=output_schema,
|
1358
|
+
system_prompt=system_prompt,
|
1359
|
+
user_prompt=user_prompt,
|
1360
|
+
**stream_kwargs,
|
1361
|
+
):
|
1362
|
+
if not chunk:
|
1363
|
+
continue
|
1364
|
+
|
1365
|
+
# Process and output the chunk
|
1366
|
+
dumped = chunk.model_dump(mode="json")
|
1367
|
+
json_str = json.dumps(dumped, indent=2)
|
1368
|
+
|
1369
|
+
if output_file:
|
1370
|
+
with open(output_file, "a", encoding="utf-8") as f:
|
1371
|
+
f.write(json_str)
|
1372
|
+
f.write("\n")
|
1373
|
+
f.flush() # Ensure immediate flush to file
|
1374
|
+
else:
|
1375
|
+
# Print directly to stdout with immediate flush
|
1376
|
+
print(json_str, flush=True)
|
1377
|
+
else:
|
1378
|
+
# For non-streaming models, use regular completion
|
1379
|
+
response = await client.chat.completions.create(
|
1380
|
+
model=model,
|
1381
|
+
messages=[
|
1382
|
+
{"role": "system", "content": system_prompt},
|
1383
|
+
{"role": "user", "content": user_prompt},
|
1384
|
+
],
|
1385
|
+
stream=False,
|
1386
|
+
**stream_kwargs,
|
1387
|
+
)
|
1395
1388
|
|
1396
|
-
|
1397
|
-
|
1398
|
-
|
1399
|
-
|
1400
|
-
|
1401
|
-
|
1402
|
-
|
1403
|
-
|
1404
|
-
|
1405
|
-
)
|
1406
|
-
parser.add_argument(
|
1407
|
-
"--debug-openai-stream",
|
1408
|
-
action="store_true",
|
1409
|
-
help="Enable low-level debug output for OpenAI streaming (very verbose)",
|
1410
|
-
)
|
1411
|
-
parser.add_argument(
|
1412
|
-
"--version",
|
1413
|
-
action="version",
|
1414
|
-
version=f"%(prog)s {__version__}",
|
1415
|
-
)
|
1389
|
+
# Process the single response
|
1390
|
+
content = response.choices[0].message.content
|
1391
|
+
if content:
|
1392
|
+
try:
|
1393
|
+
# Parse and validate against schema
|
1394
|
+
result = output_schema.model_validate_json(content)
|
1395
|
+
json_str = json.dumps(
|
1396
|
+
result.model_dump(mode="json"), indent=2
|
1397
|
+
)
|
1416
1398
|
|
1417
|
-
|
1399
|
+
if output_file:
|
1400
|
+
with open(output_file, "w", encoding="utf-8") as f:
|
1401
|
+
f.write(json_str)
|
1402
|
+
f.write("\n")
|
1403
|
+
else:
|
1404
|
+
print(json_str, flush=True)
|
1405
|
+
except ValidationError as e:
|
1406
|
+
raise InvalidResponseFormatError(
|
1407
|
+
f"Response validation failed: {e}"
|
1408
|
+
)
|
1409
|
+
else:
|
1410
|
+
raise EmptyResponseError("Model returned empty response")
|
1411
|
+
|
1412
|
+
except (
|
1413
|
+
StreamInterruptedError,
|
1414
|
+
StreamBufferError,
|
1415
|
+
StreamParseError,
|
1416
|
+
APIResponseError,
|
1417
|
+
EmptyResponseError,
|
1418
|
+
InvalidResponseFormatError,
|
1419
|
+
) as e:
|
1420
|
+
logger.error(f"Stream error: {e}")
|
1421
|
+
raise
|
1422
|
+
finally:
|
1423
|
+
# Always ensure client is properly closed
|
1424
|
+
await client.close()
|
1418
1425
|
|
1419
1426
|
|
1420
|
-
async def
|
1421
|
-
"""
|
1427
|
+
async def run_cli_async(args: Namespace) -> ExitCode:
|
1428
|
+
"""Async wrapper for CLI operations.
|
1422
1429
|
|
1423
|
-
|
1424
|
-
|
1430
|
+
This function prepares everything needed for streaming and then calls
|
1431
|
+
the focused streaming function.
|
1425
1432
|
"""
|
1426
1433
|
try:
|
1427
|
-
|
1428
|
-
args = parser.parse_args()
|
1429
|
-
|
1430
|
-
# Configure logging
|
1431
|
-
log_level = logging.DEBUG if args.verbose else logging.INFO
|
1432
|
-
logger.setLevel(log_level)
|
1433
|
-
|
1434
|
-
# Create security manager
|
1434
|
+
# Validate and prepare all inputs
|
1435
1435
|
security_manager = validate_security_manager(
|
1436
1436
|
base_dir=args.base_dir,
|
1437
1437
|
allowed_dirs=args.allowed_dir,
|
1438
|
-
|
1438
|
+
allowed_dir_file=args.allowed_dir_file,
|
1439
1439
|
)
|
1440
1440
|
|
1441
|
-
|
1442
|
-
|
1443
|
-
|
1444
|
-
# Validate schema file
|
1441
|
+
task_template = validate_task_template(args.task, args.task_file)
|
1442
|
+
logger.debug("Validating schema from %s", args.schema_file)
|
1445
1443
|
schema = validate_schema_file(args.schema_file, args.verbose)
|
1446
|
-
|
1447
|
-
# Create template context
|
1448
1444
|
template_context = create_template_context_from_args(
|
1449
1445
|
args, security_manager
|
1450
1446
|
)
|
1451
|
-
|
1452
|
-
# Create Jinja environment
|
1453
1447
|
env = create_jinja_env()
|
1454
1448
|
|
1455
|
-
# Process system prompt
|
1456
|
-
|
1449
|
+
# Process system prompt and render task
|
1450
|
+
system_prompt = process_system_prompt(
|
1457
1451
|
task_template,
|
1458
1452
|
args.system_prompt,
|
1453
|
+
args.system_prompt_file,
|
1459
1454
|
template_context,
|
1460
1455
|
env,
|
1461
1456
|
args.ignore_task_sysprompt,
|
1462
1457
|
)
|
1463
|
-
|
1464
|
-
# Render task template
|
1465
1458
|
rendered_task = render_template(task_template, template_context, env)
|
1466
|
-
logger.info(
|
1459
|
+
logger.info("Rendered task template: %s", rendered_task)
|
1467
1460
|
|
1468
|
-
# If dry run, exit here
|
1469
1461
|
if args.dry_run:
|
1470
1462
|
logger.info("DRY RUN MODE")
|
1471
1463
|
return ExitCode.SUCCESS
|
1472
1464
|
|
1473
|
-
#
|
1465
|
+
# Create output model
|
1466
|
+
logger.debug("Creating output model")
|
1474
1467
|
try:
|
1475
|
-
logger.debug("[_main] Loading schema from %s", args.schema_file)
|
1476
|
-
schema = validate_schema_file(
|
1477
|
-
args.schema_file, verbose=args.verbose_schema
|
1478
|
-
)
|
1479
|
-
logger.debug("[_main] Creating output model")
|
1480
1468
|
output_model = create_dynamic_model(
|
1481
1469
|
schema,
|
1482
1470
|
base_name="OutputModel",
|
1483
1471
|
show_schema=args.show_model_schema,
|
1484
1472
|
debug_validation=args.debug_validation,
|
1485
1473
|
)
|
1486
|
-
logger.debug("
|
1487
|
-
except (
|
1488
|
-
|
1489
|
-
|
1490
|
-
|
1491
|
-
|
1492
|
-
|
1493
|
-
|
1494
|
-
|
1495
|
-
|
1496
|
-
|
1497
|
-
# Validate model support
|
1474
|
+
logger.debug("Successfully created output model")
|
1475
|
+
except (
|
1476
|
+
SchemaFileError,
|
1477
|
+
InvalidJSONError,
|
1478
|
+
SchemaValidationError,
|
1479
|
+
ModelCreationError,
|
1480
|
+
) as e:
|
1481
|
+
logger.error("Schema error: %s", str(e))
|
1482
|
+
raise # Let the error propagate with its context
|
1483
|
+
|
1484
|
+
# Validate model support and token usage
|
1498
1485
|
try:
|
1499
1486
|
supports_structured_output(args.model)
|
1500
|
-
except ModelNotSupportedError as e:
|
1501
|
-
logger.error(str(e))
|
1502
|
-
|
1503
|
-
|
1504
|
-
logger.error(str(e))
|
1505
|
-
return ExitCode.DATA_ERROR
|
1506
|
-
|
1507
|
-
# Estimate token usage
|
1487
|
+
except (ModelNotSupportedError, ModelVersionError) as e:
|
1488
|
+
logger.error("Model validation error: %s", str(e))
|
1489
|
+
raise # Let the error propagate with its context
|
1490
|
+
|
1508
1491
|
messages = [
|
1509
|
-
{"role": "system", "content":
|
1492
|
+
{"role": "system", "content": system_prompt},
|
1510
1493
|
{"role": "user", "content": rendered_task},
|
1511
1494
|
]
|
1512
1495
|
total_tokens = estimate_tokens_for_chat(messages, args.model)
|
1513
1496
|
context_limit = get_context_window_limit(args.model)
|
1514
|
-
|
1515
1497
|
if total_tokens > context_limit:
|
1516
|
-
|
1517
|
-
|
1498
|
+
msg = f"Total tokens ({total_tokens}) exceeds model context limit ({context_limit})"
|
1499
|
+
logger.error(msg)
|
1500
|
+
raise CLIError(
|
1501
|
+
msg,
|
1502
|
+
context={
|
1503
|
+
"total_tokens": total_tokens,
|
1504
|
+
"context_limit": context_limit,
|
1505
|
+
},
|
1518
1506
|
)
|
1519
|
-
return ExitCode.DATA_ERROR
|
1520
1507
|
|
1521
|
-
# Get API key
|
1508
|
+
# Get API key and create client
|
1522
1509
|
api_key = args.api_key or os.getenv("OPENAI_API_KEY")
|
1523
1510
|
if not api_key:
|
1524
|
-
|
1525
|
-
|
1526
|
-
)
|
1527
|
-
return ExitCode.USAGE_ERROR
|
1511
|
+
msg = "No OpenAI API key provided (--api-key or OPENAI_API_KEY env var)"
|
1512
|
+
logger.error(msg)
|
1513
|
+
raise CLIError(msg)
|
1528
1514
|
|
1529
|
-
# Create OpenAI client
|
1530
1515
|
client = AsyncOpenAI(api_key=api_key, timeout=args.timeout)
|
1531
1516
|
|
1532
|
-
# Create log callback
|
1517
|
+
# Create detailed log callback
|
1533
1518
|
def log_callback(
|
1534
1519
|
level: int, message: str, extra: dict[str, Any]
|
1535
1520
|
) -> None:
|
1536
|
-
# Only log if debug_openai_stream is enabled
|
1537
1521
|
if args.debug_openai_stream:
|
1538
|
-
|
1539
|
-
if extra: # Only add if there's actually extra data
|
1522
|
+
if extra:
|
1540
1523
|
extra_str = json.dumps(extra, indent=2)
|
1541
1524
|
message = f"{message}\nDetails:\n{extra_str}"
|
1542
|
-
|
1525
|
+
logger.log(level, message, extra=extra)
|
1543
1526
|
|
1544
|
-
#
|
1527
|
+
# Stream the output
|
1545
1528
|
try:
|
1546
|
-
|
1547
|
-
|
1548
|
-
|
1549
|
-
|
1550
|
-
|
1551
|
-
|
1552
|
-
|
1553
|
-
|
1554
|
-
|
1555
|
-
|
1556
|
-
|
1557
|
-
|
1558
|
-
|
1559
|
-
|
1560
|
-
|
1561
|
-
|
1562
|
-
|
1563
|
-
|
1564
|
-
|
1565
|
-
|
1566
|
-
|
1567
|
-
|
1568
|
-
|
1569
|
-
|
1570
|
-
|
1571
|
-
|
1572
|
-
logger.debug("Empty chunk received, skipping")
|
1573
|
-
continue
|
1574
|
-
|
1575
|
-
# Write output
|
1576
|
-
try:
|
1577
|
-
logger.debug("Starting to process output chunk")
|
1578
|
-
dumped = chunk.model_dump(mode="json")
|
1579
|
-
logger.debug("Successfully dumped chunk to JSON")
|
1580
|
-
logger.debug("Dumped chunk: %s", dumped)
|
1581
|
-
logger.debug(
|
1582
|
-
"Chunk type: %s, length: %d",
|
1583
|
-
type(dumped),
|
1584
|
-
len(json.dumps(dumped)),
|
1585
|
-
)
|
1586
|
-
|
1587
|
-
if args.output_file:
|
1588
|
-
logger.debug(
|
1589
|
-
"Writing to output file: %s", args.output_file
|
1590
|
-
)
|
1591
|
-
try:
|
1592
|
-
with open(
|
1593
|
-
args.output_file, "a", encoding="utf-8"
|
1594
|
-
) as f:
|
1595
|
-
json_str = json.dumps(dumped, indent=2)
|
1596
|
-
logger.debug(
|
1597
|
-
"Writing JSON string of length %d",
|
1598
|
-
len(json_str),
|
1599
|
-
)
|
1600
|
-
f.write(json_str)
|
1601
|
-
f.write("\n")
|
1602
|
-
logger.debug("Successfully wrote to file")
|
1603
|
-
except Exception as e:
|
1604
|
-
logger.error(
|
1605
|
-
"Failed to write to output file: %s", e
|
1606
|
-
)
|
1607
|
-
else:
|
1608
|
-
logger.debug(
|
1609
|
-
"About to call progress.print_output with JSON string"
|
1610
|
-
)
|
1611
|
-
json_str = json.dumps(dumped, indent=2)
|
1612
|
-
logger.debug(
|
1613
|
-
"JSON string length before print_output: %d",
|
1614
|
-
len(json_str),
|
1615
|
-
)
|
1616
|
-
logger.debug(
|
1617
|
-
"First 100 chars of JSON string: %s",
|
1618
|
-
json_str[:100] if json_str else "",
|
1619
|
-
)
|
1620
|
-
progress.print_output(json_str)
|
1621
|
-
logger.debug(
|
1622
|
-
"Completed print_output call for JSON string"
|
1623
|
-
)
|
1624
|
-
|
1625
|
-
logger.debug("Starting progress update")
|
1626
|
-
progress.update()
|
1627
|
-
logger.debug("Completed progress update")
|
1628
|
-
except Exception as e:
|
1629
|
-
logger.error("Failed to process chunk: %s", e)
|
1630
|
-
logger.error("Chunk: %s", chunk)
|
1631
|
-
continue
|
1632
|
-
|
1633
|
-
logger.debug("Finished processing API response stream")
|
1634
|
-
|
1635
|
-
except StreamInterruptedError as e:
|
1636
|
-
logger.error(f"Stream interrupted: {e}")
|
1637
|
-
return ExitCode.API_ERROR
|
1638
|
-
except StreamBufferError as e:
|
1639
|
-
logger.error(f"Stream buffer error: {e}")
|
1640
|
-
return ExitCode.API_ERROR
|
1641
|
-
except StreamParseError as e:
|
1642
|
-
logger.error(f"Stream parse error: {e}")
|
1643
|
-
return ExitCode.API_ERROR
|
1644
|
-
except APIResponseError as e:
|
1645
|
-
logger.error(f"API response error: {e}")
|
1646
|
-
return ExitCode.API_ERROR
|
1647
|
-
except EmptyResponseError as e:
|
1648
|
-
logger.error(f"Empty response error: {e}")
|
1649
|
-
return ExitCode.API_ERROR
|
1650
|
-
except InvalidResponseFormatError as e:
|
1651
|
-
logger.error(f"Invalid response format: {e}")
|
1652
|
-
return ExitCode.API_ERROR
|
1529
|
+
await stream_structured_output(
|
1530
|
+
client=client,
|
1531
|
+
model=args.model,
|
1532
|
+
system_prompt=system_prompt,
|
1533
|
+
user_prompt=rendered_task,
|
1534
|
+
output_schema=output_model,
|
1535
|
+
output_file=args.output_file,
|
1536
|
+
temperature=args.temperature,
|
1537
|
+
max_tokens=args.max_tokens,
|
1538
|
+
top_p=args.top_p,
|
1539
|
+
frequency_penalty=args.frequency_penalty,
|
1540
|
+
presence_penalty=args.presence_penalty,
|
1541
|
+
timeout=args.timeout,
|
1542
|
+
on_log=log_callback,
|
1543
|
+
)
|
1544
|
+
return ExitCode.SUCCESS
|
1545
|
+
except (
|
1546
|
+
StreamInterruptedError,
|
1547
|
+
StreamBufferError,
|
1548
|
+
StreamParseError,
|
1549
|
+
APIResponseError,
|
1550
|
+
EmptyResponseError,
|
1551
|
+
InvalidResponseFormatError,
|
1552
|
+
) as e:
|
1553
|
+
logger.error("Stream error: %s", str(e))
|
1554
|
+
raise # Let stream errors propagate
|
1653
1555
|
except (APIConnectionError, InternalServerError) as e:
|
1654
|
-
logger.error(
|
1655
|
-
|
1556
|
+
logger.error("API connection error: %s", str(e))
|
1557
|
+
raise APIResponseError(str(e)) # Convert to our error type
|
1656
1558
|
except RateLimitError as e:
|
1657
|
-
logger.error(
|
1658
|
-
|
1659
|
-
except BadRequestError as e:
|
1660
|
-
logger.error(
|
1661
|
-
|
1662
|
-
|
1663
|
-
|
1664
|
-
return ExitCode.API_ERROR
|
1665
|
-
except OpenAIClientError as e:
|
1666
|
-
logger.error(f"OpenAI client error: {e}")
|
1667
|
-
return ExitCode.API_ERROR
|
1668
|
-
except Exception as e:
|
1669
|
-
logger.error(f"Unexpected error: {e}")
|
1670
|
-
return ExitCode.INTERNAL_ERROR
|
1671
|
-
|
1672
|
-
return ExitCode.SUCCESS
|
1559
|
+
logger.error("Rate limit exceeded: %s", str(e))
|
1560
|
+
raise APIResponseError(str(e)) # Convert to our error type
|
1561
|
+
except (BadRequestError, AuthenticationError, OpenAIClientError) as e:
|
1562
|
+
logger.error("API client error: %s", str(e))
|
1563
|
+
raise APIResponseError(str(e)) # Convert to our error type
|
1564
|
+
finally:
|
1565
|
+
await client.close()
|
1673
1566
|
|
1674
1567
|
except KeyboardInterrupt:
|
1675
|
-
logger.
|
1568
|
+
logger.info("Operation cancelled by user")
|
1676
1569
|
return ExitCode.INTERRUPTED
|
1677
|
-
except PathSecurityError as e:
|
1678
|
-
# Only log security errors if they haven't been logged already
|
1679
|
-
logger.debug(
|
1680
|
-
"[_main] Caught PathSecurityError: %s (logged=%s)",
|
1681
|
-
str(e),
|
1682
|
-
getattr(e, "has_been_logged", False),
|
1683
|
-
)
|
1684
|
-
if not getattr(e, "has_been_logged", False):
|
1685
|
-
logger.error(str(e))
|
1686
|
-
return ExitCode.SECURITY_ERROR
|
1687
|
-
except ValueError as e:
|
1688
|
-
# Get the original cause of the error
|
1689
|
-
cause = e.__cause__ or e.__context__
|
1690
|
-
if isinstance(cause, PathSecurityError):
|
1691
|
-
logger.debug(
|
1692
|
-
"[_main] Caught wrapped PathSecurityError in ValueError: %s (logged=%s)",
|
1693
|
-
str(cause),
|
1694
|
-
getattr(cause, "has_been_logged", False),
|
1695
|
-
)
|
1696
|
-
# Only log security errors if they haven't been logged already
|
1697
|
-
if not getattr(cause, "has_been_logged", False):
|
1698
|
-
logger.error(str(cause))
|
1699
|
-
return ExitCode.SECURITY_ERROR
|
1700
|
-
else:
|
1701
|
-
logger.debug("[_main] Caught ValueError: %s", str(e))
|
1702
|
-
logger.error(f"Invalid input: {e}")
|
1703
|
-
return ExitCode.DATA_ERROR
|
1704
1570
|
except Exception as e:
|
1705
|
-
|
1706
|
-
|
1707
|
-
|
1708
|
-
|
1709
|
-
|
1710
|
-
|
1711
|
-
|
1712
|
-
|
1713
|
-
|
1714
|
-
|
1715
|
-
|
1716
|
-
|
1717
|
-
|
1718
|
-
|
1571
|
+
if isinstance(e, CLIError):
|
1572
|
+
raise # Let our custom errors propagate
|
1573
|
+
logger.exception("Unexpected error")
|
1574
|
+
raise CLIError(str(e), context={"error_type": type(e).__name__})
|
1575
|
+
|
1576
|
+
|
1577
|
+
def create_cli() -> click.Command:
|
1578
|
+
"""Create the CLI command.
|
1579
|
+
|
1580
|
+
Returns:
|
1581
|
+
click.Command: The CLI command object
|
1582
|
+
"""
|
1583
|
+
|
1584
|
+
@create_click_command()
|
1585
|
+
def cli(**kwargs: Any) -> None:
|
1586
|
+
"""CLI entry point for structured OpenAI API calls."""
|
1587
|
+
try:
|
1588
|
+
args = Namespace(**kwargs)
|
1589
|
+
|
1590
|
+
# Validate required arguments first
|
1591
|
+
if not args.task and not args.task_file:
|
1592
|
+
raise click.UsageError(
|
1593
|
+
"Must specify either --task or --task-file"
|
1594
|
+
)
|
1595
|
+
if not args.schema_file:
|
1596
|
+
raise click.UsageError("Missing option '--schema-file'")
|
1597
|
+
if args.task and args.task_file:
|
1598
|
+
raise click.UsageError(
|
1599
|
+
"Cannot specify both --task and --task-file"
|
1600
|
+
)
|
1601
|
+
if args.system_prompt and args.system_prompt_file:
|
1602
|
+
raise click.UsageError(
|
1603
|
+
"Cannot specify both --system-prompt and --system-prompt-file"
|
1604
|
+
)
|
1605
|
+
|
1606
|
+
# Run the async function synchronously
|
1607
|
+
exit_code = asyncio.run(run_cli_async(args))
|
1608
|
+
|
1609
|
+
if exit_code != ExitCode.SUCCESS:
|
1610
|
+
error_msg = f"Command failed with exit code {exit_code}"
|
1611
|
+
if hasattr(ExitCode, exit_code.name):
|
1612
|
+
error_msg = f"{error_msg} ({exit_code.name})"
|
1613
|
+
raise CLIError(error_msg, context={"exit_code": exit_code})
|
1614
|
+
|
1615
|
+
except click.UsageError:
|
1616
|
+
# Let Click handle usage errors directly
|
1617
|
+
raise
|
1618
|
+
except InvalidJSONError:
|
1619
|
+
# Let InvalidJSONError propagate directly
|
1620
|
+
raise
|
1621
|
+
except CLIError:
|
1622
|
+
# Let our custom errors propagate with their context
|
1623
|
+
raise
|
1624
|
+
except Exception as e:
|
1625
|
+
# Convert other exceptions to CLIError
|
1626
|
+
logger.exception("Unexpected error")
|
1627
|
+
raise CLIError(str(e), context={"error_type": type(e).__name__})
|
1628
|
+
|
1629
|
+
return cli
|
1719
1630
|
|
1720
1631
|
|
1721
1632
|
def main() -> None:
|
1722
|
-
"""
|
1723
|
-
|
1724
|
-
|
1725
|
-
exit_code = asyncio.run(_main())
|
1726
|
-
sys.exit(exit_code.value)
|
1727
|
-
except KeyboardInterrupt:
|
1728
|
-
logger.error("Operation cancelled by user")
|
1729
|
-
sys.exit(ExitCode.INTERRUPTED.value)
|
1730
|
-
except PathSecurityError as e:
|
1731
|
-
# Only log security errors if they haven't been logged already
|
1732
|
-
logger.debug(
|
1733
|
-
"[main] Caught PathSecurityError: %s (logged=%s)",
|
1734
|
-
str(e),
|
1735
|
-
getattr(e, "has_been_logged", False),
|
1736
|
-
)
|
1737
|
-
if not getattr(e, "has_been_logged", False):
|
1738
|
-
logger.error(str(e))
|
1739
|
-
sys.exit(ExitCode.SECURITY_ERROR.value)
|
1740
|
-
except ValueError as e:
|
1741
|
-
# Get the original cause of the error
|
1742
|
-
cause = e.__cause__ or e.__context__
|
1743
|
-
if isinstance(cause, PathSecurityError):
|
1744
|
-
logger.debug(
|
1745
|
-
"[main] Caught wrapped PathSecurityError in ValueError: %s (logged=%s)",
|
1746
|
-
str(cause),
|
1747
|
-
getattr(cause, "has_been_logged", False),
|
1748
|
-
)
|
1749
|
-
# Only log security errors if they haven't been logged already
|
1750
|
-
if not getattr(cause, "has_been_logged", False):
|
1751
|
-
logger.error(str(cause))
|
1752
|
-
sys.exit(ExitCode.SECURITY_ERROR.value)
|
1753
|
-
else:
|
1754
|
-
logger.debug("[main] Caught ValueError: %s", str(e))
|
1755
|
-
logger.error(f"Invalid input: {e}")
|
1756
|
-
sys.exit(ExitCode.DATA_ERROR.value)
|
1757
|
-
except Exception as e:
|
1758
|
-
# Check if this is a wrapped security error
|
1759
|
-
if isinstance(e.__cause__, PathSecurityError):
|
1760
|
-
logger.debug(
|
1761
|
-
"[main] Caught wrapped PathSecurityError in Exception: %s (logged=%s)",
|
1762
|
-
str(e.__cause__),
|
1763
|
-
getattr(e.__cause__, "has_been_logged", False),
|
1764
|
-
)
|
1765
|
-
# Only log security errors if they haven't been logged already
|
1766
|
-
if not getattr(e.__cause__, "has_been_logged", False):
|
1767
|
-
logger.error(str(e.__cause__))
|
1768
|
-
sys.exit(ExitCode.SECURITY_ERROR.value)
|
1769
|
-
logger.debug("[main] Caught unexpected error: %s", str(e))
|
1770
|
-
logger.error(f"Unexpected error: {e}")
|
1771
|
-
sys.exit(ExitCode.INTERNAL_ERROR.value)
|
1633
|
+
"""Main entry point for the CLI."""
|
1634
|
+
cli = create_cli()
|
1635
|
+
cli(standalone_mode=False)
|
1772
1636
|
|
1773
1637
|
|
1774
1638
|
# Export public API
|
1775
1639
|
__all__ = [
|
1776
1640
|
"ExitCode",
|
1777
1641
|
"estimate_tokens_for_chat",
|
1778
|
-
"get_context_window_limit",
|
1779
|
-
"get_default_token_limit",
|
1780
1642
|
"parse_json_var",
|
1781
1643
|
"create_dynamic_model",
|
1782
1644
|
"validate_path_mapping",
|
1783
|
-
"
|
1645
|
+
"create_cli",
|
1784
1646
|
"main",
|
1785
1647
|
]
|
1786
1648
|
|