ostruct-cli 0.2.0__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ostruct/cli/cli.py CHANGED
@@ -1,18 +1,12 @@
1
1
  """Command-line interface for making structured OpenAI API calls."""
2
2
 
3
- import argparse
4
3
  import asyncio
5
4
  import json
6
5
  import logging
7
6
  import os
8
7
  import sys
8
+ from dataclasses import dataclass
9
9
  from enum import Enum, IntEnum
10
-
11
- if sys.version_info >= (3, 11):
12
- from enum import StrEnum
13
-
14
- from datetime import date, datetime, time
15
- from pathlib import Path
16
10
  from typing import (
17
11
  Any,
18
12
  Dict,
@@ -29,6 +23,13 @@ from typing import (
29
23
  overload,
30
24
  )
31
25
 
26
+ if sys.version_info >= (3, 11):
27
+ from enum import StrEnum
28
+
29
+ from datetime import date, datetime, time
30
+ from pathlib import Path
31
+
32
+ import click
32
33
  import jinja2
33
34
  import tiktoken
34
35
  import yaml
@@ -42,6 +43,8 @@ from openai import (
42
43
  )
43
44
  from openai_structured.client import (
44
45
  async_openai_structured_stream,
46
+ get_context_window_limit,
47
+ get_default_token_limit,
45
48
  supports_structured_output,
46
49
  )
47
50
  from openai_structured.errors import (
@@ -71,8 +74,11 @@ from pydantic.functional_validators import BeforeValidator
71
74
  from pydantic.types import constr
72
75
  from typing_extensions import TypeAlias
73
76
 
74
- from .. import __version__
77
+ from ostruct.cli.click_options import create_click_command
78
+
79
+ from .. import __version__ # noqa: F401 - Used in package metadata
75
80
  from .errors import (
81
+ CLIError,
76
82
  DirectoryNotFoundError,
77
83
  FieldDefinitionError,
78
84
  FileNotFoundError,
@@ -89,7 +95,6 @@ from .errors import (
89
95
  )
90
96
  from .file_utils import FileInfoList, TemplateValue, collect_files
91
97
  from .path_utils import validate_path_mapping
92
- from .progress import ProgressContext
93
98
  from .security import SecurityManager
94
99
  from .template_env import create_jinja_env
95
100
  from .template_utils import SystemPromptError, render_template
@@ -97,6 +102,45 @@ from .template_utils import SystemPromptError, render_template
97
102
  # Constants
98
103
  DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant."
99
104
 
105
+
@dataclass
class Namespace:
    """Typed stand-in for ``argparse.Namespace``.

    Carries the parsed command-line values as plain attributes so that code
    written against an argparse-style ``args`` object keeps working. Every
    field except ``progress_level`` is required; field order is part of the
    generated ``__init__`` signature and must not be changed.
    """

    # --- Task template -----------------------------------------------------
    task: Optional[str]
    task_file: Optional[str]

    # --- File / directory inputs and access control --------------------------
    file: List[str]
    files: List[str]
    dir: List[str]
    allowed_dir: List[str]
    base_dir: str
    allowed_dir_file: Optional[str]
    dir_recursive: bool
    dir_ext: Optional[str]

    # --- Template variables --------------------------------------------------
    var: List[str]
    json_var: List[str]

    # --- System prompt -------------------------------------------------------
    system_prompt: Optional[str]
    system_prompt_file: Optional[str]
    ignore_task_sysprompt: bool

    # --- Schema and model parameters -----------------------------------------
    schema_file: str
    model: str
    temperature: float
    max_tokens: Optional[int]
    top_p: float
    frequency_penalty: float
    presence_penalty: float
    timeout: float

    # --- Output and run mode -------------------------------------------------
    output_file: Optional[str]
    dry_run: bool
    no_progress: bool

    # --- Credentials and diagnostics -----------------------------------------
    api_key: Optional[str]
    verbose: bool
    debug_openai_stream: bool
    show_model_schema: bool
    debug_validation: bool

    # Only field with a default: falls back to "basic" when not supplied.
    progress_level: str = "basic"
142
+
143
+
100
144
  # Set up logging
101
145
  logger = logging.getLogger(__name__)
102
146
 
@@ -360,65 +404,43 @@ V = TypeVar("V")
360
404
 
361
405
 
362
406
  def estimate_tokens_for_chat(
363
- messages: List[Dict[str, str]], model: str
407
+ messages: List[Dict[str, str]],
408
+ model: str,
409
+ encoder: Any = None,
364
410
  ) -> int:
365
- """Estimate the number of tokens in a chat completion."""
366
- try:
367
- encoding = tiktoken.encoding_for_model(model)
368
- except KeyError:
369
- # Fall back to cl100k_base for unknown models
370
- encoding = tiktoken.get_encoding("cl100k_base")
371
-
372
- num_tokens = 0
373
- for message in messages:
374
- # Add message overhead
375
- num_tokens += 4 # every message follows <im_start>{role/name}\n{content}<im_end>\n
376
- for key, value in message.items():
377
- num_tokens += len(encoding.encode(str(value)))
378
- if key == "name": # if there's a name, the role is omitted
379
- num_tokens += -1 # role is always required and always 1 token
380
- num_tokens += 2 # every reply is primed with <im_start>assistant
381
- return num_tokens
382
-
383
-
384
- def get_default_token_limit(model: str) -> int:
385
- """Get the default token limit for a given model.
386
-
387
- Note: These limits are based on current OpenAI model specifications as of 2024 and may
388
- need to be updated if OpenAI changes the models' capabilities.
411
+ """Estimate the number of tokens in a chat completion.
389
412
 
390
413
  Args:
391
- model: The model name (e.g., 'gpt-4o', 'gpt-4o-mini', 'o1')
392
-
393
- Returns:
394
- The default token limit for the model
414
+ messages: List of chat messages
415
+ model: Model name
416
+ encoder: Optional tiktoken encoder for testing. If provided, only uses encoder.encode() results.
395
417
  """
396
- if "o1" in model:
397
- return 100_000 # o1 supports up to 100K output tokens
398
- elif "gpt-4o" in model:
399
- return 16_384 # gpt-4o and gpt-4o-mini support up to 16K output tokens
400
- else:
401
- return 4_096 # default fallback
402
-
403
-
404
- def get_context_window_limit(model: str) -> int:
405
- """Get the total context window limit for a given model.
406
-
407
- Note: These limits are based on current OpenAI model specifications as of 2024 and may
408
- need to be updated if OpenAI changes the models' capabilities.
409
-
410
- Args:
411
- model: The model name (e.g., 'gpt-4o', 'gpt-4o-mini', 'o1')
412
-
413
- Returns:
414
- The context window limit for the model
415
- """
416
- if "o1" in model:
417
- return 200_000 # o1 supports 200K total context window
418
- elif "gpt-4o" in model:
419
- return 128_000 # gpt-4o and gpt-4o-mini support 128K context window
418
+ if encoder is None:
419
+ try:
420
+ # Try to get the encoding for the specific model
421
+ encoder = tiktoken.get_encoding("o200k_base")
422
+ except KeyError:
423
+ # Fall back to cl100k_base for unknown models
424
+ encoder = tiktoken.get_encoding("cl100k_base")
425
+
426
+ # Use standard token counting logic for real tiktoken encoders
427
+ num_tokens = 0
428
+ for message in messages:
429
+ # Add message overhead
430
+ num_tokens += 4 # every message follows <im_start>{role/name}\n{content}<im_end>\n
431
+ for key, value in message.items():
432
+ num_tokens += len(encoder.encode(str(value)))
433
+ if key == "name": # if there's a name, the role is omitted
434
+ num_tokens -= 1 # role is omitted
435
+ num_tokens += 2 # every reply is primed with <im_start>assistant
436
+ return num_tokens
420
437
  else:
421
- return 8_192 # default fallback
438
+ # For mock encoders in tests, just return the length of encoded content
439
+ num_tokens = 0
440
+ for message in messages:
441
+ for value in message.values():
442
+ num_tokens += len(encoder.encode(str(value)))
443
+ return num_tokens
422
444
 
423
445
 
424
446
  def validate_token_limits(
@@ -460,6 +482,7 @@ def validate_token_limits(
460
482
  def process_system_prompt(
461
483
  task_template: str,
462
484
  system_prompt: Optional[str],
485
+ system_prompt_file: Optional[str],
463
486
  template_context: Dict[str, Any],
464
487
  env: jinja2.Environment,
465
488
  ignore_task_sysprompt: bool = False,
@@ -468,7 +491,8 @@ def process_system_prompt(
468
491
 
469
492
  Args:
470
493
  task_template: The task template string
471
- system_prompt: Optional system prompt string or file path (with @ prefix)
494
+ system_prompt: Optional system prompt string
495
+ system_prompt_file: Optional path to system prompt file
472
496
  template_context: Template context for rendering
473
497
  env: Jinja2 environment
474
498
  ignore_task_sysprompt: Whether to ignore system prompt in task template
@@ -484,18 +508,24 @@ def process_system_prompt(
484
508
  # Default system prompt
485
509
  default_prompt = "You are a helpful assistant."
486
510
 
511
+ # Check for conflicting arguments
512
+ if system_prompt is not None and system_prompt_file is not None:
513
+ raise SystemPromptError(
514
+ "Cannot specify both --system-prompt and --system-prompt-file"
515
+ )
516
+
487
517
  # Try to get system prompt from CLI argument first
488
- if system_prompt:
489
- if system_prompt.startswith("@"):
490
- # Load from file
491
- path = system_prompt[1:]
492
- try:
493
- name, path = validate_path_mapping(f"system_prompt={path}")
494
- with open(path, "r", encoding="utf-8") as f:
495
- system_prompt = f.read().strip()
496
- except (FileNotFoundError, PathSecurityError) as e:
497
- raise SystemPromptError(f"Invalid system prompt file: {e}")
518
+ if system_prompt_file is not None:
519
+ try:
520
+ name, path = validate_path_mapping(
521
+ f"system_prompt={system_prompt_file}"
522
+ )
523
+ with open(path, "r", encoding="utf-8") as f:
524
+ system_prompt = f.read().strip()
525
+ except (FileNotFoundError, PathSecurityError) as e:
526
+ raise SystemPromptError(f"Invalid system prompt file: {e}")
498
527
 
528
+ if system_prompt is not None:
499
529
  # Render system prompt with template context
500
530
  try:
501
531
  template = env.from_string(system_prompt)
@@ -618,30 +648,45 @@ def _validate_path_mapping_internal(
618
648
  ValueError: If the format is invalid (missing "=").
619
649
  OSError: If there is an underlying OS error (permissions, etc.).
620
650
  """
651
+ logger = logging.getLogger(__name__)
652
+ logger.debug("Starting path validation for mapping: %r", mapping)
653
+ logger.debug("Parameters - is_dir: %r, base_dir: %r", is_dir, base_dir)
654
+
621
655
  try:
622
656
  if not mapping or "=" not in mapping:
657
+ logger.debug("Invalid mapping format: %r", mapping)
623
658
  raise ValueError(
624
659
  "Invalid path mapping format. Expected format: name=path"
625
660
  )
626
661
 
627
662
  name, path = mapping.split("=", 1)
663
+ logger.debug("Split mapping - name: %r, path: %r", name, path)
664
+
628
665
  if not name:
666
+ logger.debug("Empty name in mapping")
629
667
  raise VariableNameError(
630
668
  f"Empty name in {'directory' if is_dir else 'file'} mapping"
631
669
  )
632
670
 
633
671
  if not path:
672
+ logger.debug("Empty path in mapping")
634
673
  raise VariableValueError("Path cannot be empty")
635
674
 
636
675
  # Convert to Path object and resolve against base_dir if provided
676
+ logger.debug("Creating Path object for: %r", path)
637
677
  path_obj = Path(path)
638
678
  if base_dir:
679
+ logger.debug("Resolving against base_dir: %r", base_dir)
639
680
  path_obj = Path(base_dir) / path_obj
681
+ logger.debug("Path object created: %r", path_obj)
640
682
 
641
683
  # Resolve the path to catch directory traversal attempts
642
684
  try:
685
+ logger.debug("Attempting to resolve path: %r", path_obj)
643
686
  resolved_path = path_obj.resolve()
687
+ logger.debug("Resolved path: %r", resolved_path)
644
688
  except OSError as e:
689
+ logger.error("Failed to resolve path: %s", e)
645
690
  raise OSError(f"Failed to resolve path: {e}")
646
691
 
647
692
  # Check for directory traversal
@@ -686,7 +731,9 @@ def _validate_path_mapping_internal(
686
731
  raise
687
732
 
688
733
  if security_manager:
689
- if not security_manager.is_allowed_file(str(resolved_path)):
734
+ try:
735
+ security_manager.validate_path(str(resolved_path))
736
+ except PathSecurityError:
690
737
  raise PathSecurityError.from_expanded_paths(
691
738
  original_path=str(path),
692
739
  expanded_path=str(resolved_path),
@@ -709,34 +756,45 @@ def _validate_path_mapping_internal(
709
756
  raise
710
757
 
711
758
 
712
- def validate_task_template(task: str) -> str:
759
+ def validate_task_template(
760
+ task: Optional[str], task_file: Optional[str]
761
+ ) -> str:
713
762
  """Validate and load a task template.
714
763
 
715
764
  Args:
716
- task: The task template string or path to task template file (with @ prefix)
765
+ task: The task template string
766
+ task_file: Path to task template file
717
767
 
718
768
  Returns:
719
769
  The task template string
720
770
 
721
771
  Raises:
722
- TaskTemplateVariableError: If the template file cannot be read or is invalid
772
+ TaskTemplateVariableError: If neither task nor task_file is provided, or if both are provided
723
773
  TaskTemplateSyntaxError: If the template has invalid syntax
724
774
  FileNotFoundError: If the template file does not exist
725
775
  PathSecurityError: If the template file path violates security constraints
726
776
  """
727
- template_content = task
777
+ if task is not None and task_file is not None:
778
+ raise TaskTemplateVariableError(
779
+ "Cannot specify both --task and --task-file"
780
+ )
781
+
782
+ if task is None and task_file is None:
783
+ raise TaskTemplateVariableError(
784
+ "Must specify either --task or --task-file"
785
+ )
728
786
 
729
- # Check if task is a file path
730
- if task.startswith("@"):
731
- path = task[1:]
787
+ template_content: str
788
+ if task_file is not None:
732
789
  try:
733
- name, path = validate_path_mapping(f"task={path}")
790
+ name, path = validate_path_mapping(f"task={task_file}")
734
791
  with open(path, "r", encoding="utf-8") as f:
735
792
  template_content = f.read()
736
793
  except (FileNotFoundError, PathSecurityError) as e:
737
- raise TaskTemplateVariableError(f"Invalid task template file: {e}")
794
+ raise TaskTemplateVariableError(str(e))
795
+ else:
796
+ template_content = task # type: ignore # We know task is str here due to the checks above
738
797
 
739
- # Validate template syntax
740
798
  try:
741
799
  env = jinja2.Environment(undefined=jinja2.StrictUndefined)
742
800
  env.parse(template_content)
@@ -813,7 +871,7 @@ def validate_schema_file(
813
871
 
814
872
 
815
873
  def collect_template_files(
816
- args: argparse.Namespace,
874
+ args: Namespace,
817
875
  security_manager: SecurityManager,
818
876
  ) -> Dict[str, TemplateValue]:
819
877
  """Collect files from command line arguments.
@@ -846,14 +904,17 @@ def collect_template_files(
846
904
  # Wrap file-related errors
847
905
  raise ValueError(f"File access error: {e}")
848
906
  except Exception as e:
907
+ # Don't wrap InvalidJSONError
908
+ if isinstance(e, InvalidJSONError):
909
+ raise
849
910
  # Check if this is a wrapped security error
850
911
  if isinstance(e.__cause__, PathSecurityError):
851
912
  raise e.__cause__
852
- # Wrap unexpected errors
913
+ # Wrap other errors
853
914
  raise ValueError(f"Error collecting files: {e}")
854
915
 
855
916
 
856
- def collect_simple_variables(args: argparse.Namespace) -> Dict[str, str]:
917
+ def collect_simple_variables(args: Namespace) -> Dict[str, str]:
857
918
  """Collect simple string variables from --var arguments.
858
919
 
859
920
  Args:
@@ -886,7 +947,7 @@ def collect_simple_variables(args: argparse.Namespace) -> Dict[str, str]:
886
947
  return variables
887
948
 
888
949
 
889
- def collect_json_variables(args: argparse.Namespace) -> Dict[str, Any]:
950
+ def collect_json_variables(args: Namespace) -> Dict[str, Any]:
890
951
  """Collect JSON variables from --json-var arguments.
891
952
 
892
953
  Args:
@@ -916,7 +977,7 @@ def collect_json_variables(args: argparse.Namespace) -> Dict[str, Any]:
916
977
  all_names.add(name)
917
978
  except json.JSONDecodeError as e:
918
979
  raise InvalidJSONError(
919
- f"Invalid JSON value for {name}: {str(e)}"
980
+ f"Error parsing JSON for variable '{name}': {str(e)}. Input was: {json_str}"
920
981
  )
921
982
  except ValueError:
922
983
  raise VariableNameError(
@@ -972,7 +1033,7 @@ def create_template_context(
972
1033
 
973
1034
 
974
1035
  def create_template_context_from_args(
975
- args: argparse.Namespace,
1036
+ args: "Namespace",
976
1037
  security_manager: SecurityManager,
977
1038
  ) -> Dict[str, Any]:
978
1039
  """Create template context from command line arguments.
@@ -1024,7 +1085,7 @@ def create_template_context_from_args(
1024
1085
  json_value = json.loads(value)
1025
1086
  except json.JSONDecodeError as e:
1026
1087
  raise InvalidJSONError(
1027
- f"Invalid JSON value for {name} ({value!r}): {str(e)}"
1088
+ f"Error parsing JSON for variable '{name}': {str(e)}. Input was: {value}"
1028
1089
  )
1029
1090
  if name in json_variables:
1030
1091
  raise VariableNameError(
@@ -1060,42 +1121,59 @@ def create_template_context_from_args(
1060
1121
  # Wrap file-related errors
1061
1122
  raise ValueError(f"File access error: {e}")
1062
1123
  except Exception as e:
1124
+ # Don't wrap InvalidJSONError
1125
+ if isinstance(e, InvalidJSONError):
1126
+ raise
1063
1127
  # Check if this is a wrapped security error
1064
1128
  if isinstance(e.__cause__, PathSecurityError):
1065
1129
  raise e.__cause__
1066
- # Wrap unexpected errors
1130
+ # Wrap other errors
1067
1131
  raise ValueError(f"Error collecting files: {e}")
1068
1132
 
1069
1133
 
1070
1134
  def validate_security_manager(
1071
1135
  base_dir: Optional[str] = None,
1072
1136
  allowed_dirs: Optional[List[str]] = None,
1073
- allowed_dirs_file: Optional[str] = None,
1137
+ allowed_dir_file: Optional[str] = None,
1074
1138
  ) -> SecurityManager:
1075
- """Create and validate a security manager.
1139
+ """Validate and create security manager.
1076
1140
 
1077
1141
  Args:
1078
- base_dir: Optional base directory to resolve paths against
1079
- allowed_dirs: Optional list of allowed directory paths
1080
- allowed_dirs_file: Optional path to file containing allowed directories
1142
+ base_dir: Base directory for file access. Defaults to current working directory.
1143
+ allowed_dirs: Optional list of additional allowed directories
1144
+ allowed_dir_file: Optional file containing allowed directories
1081
1145
 
1082
1146
  Returns:
1083
1147
  Configured SecurityManager instance
1084
1148
 
1085
1149
  Raises:
1086
- FileNotFoundError: If allowed_dirs_file does not exist
1087
- PathSecurityError: If any paths are outside base directory
1150
+ PathSecurityError: If any paths violate security constraints
1151
+ DirectoryNotFoundError: If any directories do not exist
1088
1152
  """
1089
- # Convert base_dir to string if it's a Path
1090
- base_dir_str = str(base_dir) if base_dir else None
1091
- security_manager = SecurityManager(base_dir_str)
1153
+ # Use current working directory if base_dir is None
1154
+ if base_dir is None:
1155
+ base_dir = os.getcwd()
1092
1156
 
1093
- if allowed_dirs_file:
1094
- security_manager.add_allowed_dirs_from_file(str(allowed_dirs_file))
1157
+ # Create security manager with base directory
1158
+ security_manager = SecurityManager(base_dir)
1095
1159
 
1160
+ # Add explicitly allowed directories
1096
1161
  if allowed_dirs:
1097
- for allowed_dir in allowed_dirs:
1098
- security_manager.add_allowed_dir(str(allowed_dir))
1162
+ for dir_path in allowed_dirs:
1163
+ security_manager.add_allowed_directory(dir_path)
1164
+
1165
+ # Add directories from file if specified
1166
+ if allowed_dir_file:
1167
+ try:
1168
+ with open(allowed_dir_file, "r", encoding="utf-8") as f:
1169
+ for line in f:
1170
+ line = line.strip()
1171
+ if line and not line.startswith("#"):
1172
+ security_manager.add_allowed_directory(line)
1173
+ except OSError as e:
1174
+ raise DirectoryNotFoundError(
1175
+ f"Failed to read allowed directories file: {e}"
1176
+ )
1099
1177
 
1100
1178
  return security_manager
1101
1179
 
@@ -1157,8 +1235,8 @@ def parse_json_var(var_str: str) -> Tuple[str, Any]:
1157
1235
  value = json.loads(json_str)
1158
1236
  except json.JSONDecodeError as e:
1159
1237
  raise InvalidJSONError(
1160
- f"Invalid JSON value for variable {name!r}: {json_str!r}"
1161
- ) from e
1238
+ f"Error parsing JSON for variable '{name}': {str(e)}. Input was: {json_str}"
1239
+ )
1162
1240
 
1163
1241
  return name, value
1164
1242
 
@@ -1205,582 +1283,366 @@ def _create_enum_type(values: List[Any], field_name: str) -> Type[Enum]:
1205
1283
  return type(f"{field_name.title()}Enum", (str, Enum), enum_dict)
1206
1284
 
1207
1285
 
1208
- def create_argument_parser() -> argparse.ArgumentParser:
1209
- """Create argument parser for CLI."""
1210
- parser = argparse.ArgumentParser(
1211
- description="Make structured OpenAI API calls.",
1212
- formatter_class=argparse.RawDescriptionHelpFormatter,
1213
- )
1214
-
1215
- # Debug output options
1216
- debug_group = parser.add_argument_group("Debug Output Options")
1217
- debug_group.add_argument(
1218
- "--show-model-schema",
1219
- action="store_true",
1220
- help="Display the generated Pydantic model schema",
1221
- )
1222
- debug_group.add_argument(
1223
- "--debug-validation",
1224
- action="store_true",
1225
- help="Show detailed schema validation debugging information",
1226
- )
1227
- debug_group.add_argument(
1228
- "--verbose-schema",
1229
- action="store_true",
1230
- help="Enable verbose schema debugging output",
1231
- )
1232
- debug_group.add_argument(
1233
- "--progress-level",
1234
- choices=["none", "basic", "detailed"],
1235
- default="basic",
1236
- help="Set the level of progress reporting (default: basic)",
1237
- )
1238
-
1239
- # Required arguments
1240
- parser.add_argument(
1241
- "--task",
1242
- required=True,
1243
- help="Task template string or @file",
1244
- )
1245
-
1246
- # File access arguments
1247
- parser.add_argument(
1248
- "--file",
1249
- action="append",
1250
- default=[],
1251
- help="Map file to variable (name=path)",
1252
- metavar="NAME=PATH",
1253
- )
1254
- parser.add_argument(
1255
- "--files",
1256
- action="append",
1257
- default=[],
1258
- help="Map file pattern to variable (name=pattern)",
1259
- metavar="NAME=PATTERN",
1260
- )
1261
- parser.add_argument(
1262
- "--dir",
1263
- action="append",
1264
- default=[],
1265
- help="Map directory to variable (name=path)",
1266
- metavar="NAME=PATH",
1267
- )
1268
- parser.add_argument(
1269
- "--allowed-dir",
1270
- action="append",
1271
- default=[],
1272
- help="Additional allowed directory or @file",
1273
- metavar="PATH",
1274
- )
1275
- parser.add_argument(
1276
- "--base-dir",
1277
- help="Base directory for file access (defaults to current directory)",
1278
- default=os.getcwd(),
1279
- )
1280
- parser.add_argument(
1281
- "--allowed-dirs-file",
1282
- help="File containing list of allowed directories",
1283
- )
1284
- parser.add_argument(
1285
- "--dir-recursive",
1286
- action="store_true",
1287
- help="Process directories recursively",
1288
- )
1289
- parser.add_argument(
1290
- "--dir-ext",
1291
- help="Comma-separated list of file extensions to include in directory processing",
1292
- )
1293
-
1294
- # Variable arguments
1295
- parser.add_argument(
1296
- "--var",
1297
- action="append",
1298
- default=[],
1299
- help="Pass simple variables (name=value)",
1300
- metavar="NAME=VALUE",
1301
- )
1302
- parser.add_argument(
1303
- "--json-var",
1304
- action="append",
1305
- default=[],
1306
- help="Pass JSON variables (name=json)",
1307
- metavar="NAME=JSON",
1308
- )
1309
-
1310
- # System prompt options
1311
- parser.add_argument(
1312
- "--system-prompt",
1313
- help=(
1314
- "System prompt for the model (use @file to load from file, "
1315
- "can also be specified in task template YAML frontmatter)"
1316
- ),
1317
- default=DEFAULT_SYSTEM_PROMPT,
1318
- )
1319
- parser.add_argument(
1320
- "--ignore-task-sysprompt",
1321
- action="store_true",
1322
- help="Ignore system prompt from task template YAML frontmatter",
1323
- )
def handle_error(e: Exception) -> None:
    """Print a message for *e* to stderr and exit with a matching status code.

    This function never returns normally; every branch ends in ``sys.exit``.

    Mapping (checked in this order):
        * ``click.UsageError``       -> ``ExitCode.USAGE_ERROR``
        * ``InvalidJSONError``       -> ``ExitCode.DATA_ERROR``
        * ``FileNotFoundError``      -> ``ExitCode.SCHEMA_ERROR``
        * ``TaskTemplateSyntaxError``-> ``ExitCode.INTERNAL_ERROR``
        * any other ``CLIError``     -> ``e.exit_code`` if present, else
          ``ExitCode.INTERNAL_ERROR``
        * anything else              -> ``ExitCode.INTERNAL_ERROR``

    Args:
        e: The exception to report.
    """
    if isinstance(e, click.UsageError):
        # Only a genuinely *missing* parameter gets the "Missing option"
        # text. Other parameter errors (e.g. click.BadParameter for an
        # invalid value) also carry ``param`` and must keep their own
        # message instead of being mislabelled as missing.
        if isinstance(e, click.MissingParameter) and e.param is not None:
            click.echo(f"Missing option '--{e.param.name}'", err=True)
        else:
            # format_message() equals str(e) for plain UsageError and adds
            # the parameter hint for BadParameter-style errors.
            click.echo(e.format_message(), err=True)
        sys.exit(ExitCode.USAGE_ERROR)

    # Errors reported as "<message or fallback>" in red. Order matters: the
    # first matching type wins. FileNotFoundError here is the name imported
    # from .errors at the top of the file, not necessarily the builtin.
    red_cases = (
        (InvalidJSONError, "Invalid JSON", ExitCode.DATA_ERROR),
        (FileNotFoundError, "File not found", ExitCode.SCHEMA_ERROR),
        (
            TaskTemplateSyntaxError,
            "Template syntax error",
            ExitCode.INTERNAL_ERROR,
        ),
    )
    for exc_type, fallback, code in red_cases:
        if isinstance(e, exc_type):
            text = str(e)
            # An exception built from ``None`` stringifies to "None"; show
            # the generic fallback instead of that.
            click.secho(text if text != "None" else fallback, fg="red", err=True)
            sys.exit(code)

    if isinstance(e, CLIError):
        # CLIError and its subclasses know how to display themselves.
        e.show()
        sys.exit(getattr(e, "exit_code", ExitCode.INTERNAL_ERROR))

    click.secho(f"Unexpected error: {str(e)}", fg="red", err=True)
    sys.exit(ExitCode.INTERNAL_ERROR)
1322
+
1323
+
1324
+ async def stream_structured_output(
1325
+ client: AsyncOpenAI,
1326
+ model: str,
1327
+ system_prompt: str,
1328
+ user_prompt: str,
1329
+ output_schema: Type[BaseModel],
1330
+ output_file: Optional[str] = None,
1331
+ **kwargs: Any,
1332
+ ) -> None:
1333
+ """Stream structured output from OpenAI API.
1324
1334
 
1325
- # Schema validation
1326
- parser.add_argument(
1327
- "--schema",
1328
- dest="schema_file",
1329
- required=True,
1330
- help="JSON schema file for response validation",
1331
- )
1332
- parser.add_argument(
1333
- "--validate-schema",
1334
- action="store_true",
1335
- help="Validate schema and response",
1336
- )
1335
+ This function follows the guide's recommendation for a focused async streaming function.
1336
+ It handles the core streaming logic and resource cleanup.
1337
+ """
1338
+ try:
1339
+ # Base models that don't support streaming
1340
+ non_streaming_models = {"o1", "o3"}
1337
1341
 
1338
- # Model configuration
1339
- parser.add_argument(
1340
- "--model",
1341
- default="gpt-4o-2024-08-06",
1342
- help="Model to use",
1343
- )
1344
- parser.add_argument(
1345
- "--temperature",
1346
- type=float,
1347
- default=0.0,
1348
- help="Temperature (0.0-2.0)",
1349
- )
1350
- parser.add_argument(
1351
- "--max-tokens",
1352
- type=int,
1353
- help="Maximum tokens to generate",
1354
- )
1355
- parser.add_argument(
1356
- "--top-p",
1357
- type=float,
1358
- default=1.0,
1359
- help="Top-p sampling (0.0-1.0)",
1360
- )
1361
- parser.add_argument(
1362
- "--frequency-penalty",
1363
- type=float,
1364
- default=0.0,
1365
- help="Frequency penalty (-2.0-2.0)",
1366
- )
1367
- parser.add_argument(
1368
- "--presence-penalty",
1369
- type=float,
1370
- default=0.0,
1371
- help="Presence penalty (-2.0-2.0)",
1372
- )
1373
- parser.add_argument(
1374
- "--timeout",
1375
- type=float,
1376
- default=60.0,
1377
- help="API timeout in seconds",
1378
- )
1342
+ # Check if model supports streaming
1343
+ # o3-mini and o3-mini-high support streaming, base o3 does not
1344
+ use_streaming = model not in non_streaming_models and (
1345
+ not model.startswith("o3") or model.startswith("o3-mini")
1346
+ )
1379
1347
 
1380
- # Output options
1381
- parser.add_argument(
1382
- "--output-file",
1383
- help="Write JSON output to file",
1384
- )
1385
- parser.add_argument(
1386
- "--dry-run",
1387
- action="store_true",
1388
- help="Simulate API call without making request",
1389
- )
1390
- parser.add_argument(
1391
- "--no-progress",
1392
- action="store_true",
1393
- help="Disable progress indicators",
1394
- )
1348
+ # All o1 and o3 models (base and variants) have fixed settings
1349
+ stream_kwargs = {}
1350
+ if not (model.startswith("o1") or model.startswith("o3")):
1351
+ stream_kwargs = kwargs
1352
+
1353
+ if use_streaming:
1354
+ async for chunk in async_openai_structured_stream(
1355
+ client=client,
1356
+ model=model,
1357
+ output_schema=output_schema,
1358
+ system_prompt=system_prompt,
1359
+ user_prompt=user_prompt,
1360
+ **stream_kwargs,
1361
+ ):
1362
+ if not chunk:
1363
+ continue
1364
+
1365
+ # Process and output the chunk
1366
+ dumped = chunk.model_dump(mode="json")
1367
+ json_str = json.dumps(dumped, indent=2)
1368
+
1369
+ if output_file:
1370
+ with open(output_file, "a", encoding="utf-8") as f:
1371
+ f.write(json_str)
1372
+ f.write("\n")
1373
+ f.flush() # Ensure immediate flush to file
1374
+ else:
1375
+ # Print directly to stdout with immediate flush
1376
+ print(json_str, flush=True)
1377
+ else:
1378
+ # For non-streaming models, use regular completion
1379
+ response = await client.chat.completions.create(
1380
+ model=model,
1381
+ messages=[
1382
+ {"role": "system", "content": system_prompt},
1383
+ {"role": "user", "content": user_prompt},
1384
+ ],
1385
+ stream=False,
1386
+ **stream_kwargs,
1387
+ )
1395
1388
 
1396
- # Other options
1397
- parser.add_argument(
1398
- "--api-key",
1399
- help="OpenAI API key (overrides env var)",
1400
- )
1401
- parser.add_argument(
1402
- "--verbose",
1403
- action="store_true",
1404
- help="Enable verbose output",
1405
- )
1406
- parser.add_argument(
1407
- "--debug-openai-stream",
1408
- action="store_true",
1409
- help="Enable low-level debug output for OpenAI streaming (very verbose)",
1410
- )
1411
- parser.add_argument(
1412
- "--version",
1413
- action="version",
1414
- version=f"%(prog)s {__version__}",
1415
- )
1389
+ # Process the single response
1390
+ content = response.choices[0].message.content
1391
+ if content:
1392
+ try:
1393
+ # Parse and validate against schema
1394
+ result = output_schema.model_validate_json(content)
1395
+ json_str = json.dumps(
1396
+ result.model_dump(mode="json"), indent=2
1397
+ )
1416
1398
 
1417
- return parser
1399
+ if output_file:
1400
+ with open(output_file, "w", encoding="utf-8") as f:
1401
+ f.write(json_str)
1402
+ f.write("\n")
1403
+ else:
1404
+ print(json_str, flush=True)
1405
+ except ValidationError as e:
1406
+ raise InvalidResponseFormatError(
1407
+ f"Response validation failed: {e}"
1408
+ )
1409
+ else:
1410
+ raise EmptyResponseError("Model returned empty response")
1411
+
1412
+ except (
1413
+ StreamInterruptedError,
1414
+ StreamBufferError,
1415
+ StreamParseError,
1416
+ APIResponseError,
1417
+ EmptyResponseError,
1418
+ InvalidResponseFormatError,
1419
+ ) as e:
1420
+ logger.error(f"Stream error: {e}")
1421
+ raise
1422
+ finally:
1423
+ # Always ensure client is properly closed
1424
+ await client.close()
1418
1425
 
1419
1426
 
1420
- async def _main() -> ExitCode:
1421
- """Main CLI function.
1427
+ async def run_cli_async(args: Namespace) -> ExitCode:
1428
+ """Async wrapper for CLI operations.
1422
1429
 
1423
- Returns:
1424
- ExitCode: Exit code indicating success or failure
1430
+ This function prepares everything needed for streaming and then calls
1431
+ the focused streaming function.
1425
1432
  """
1426
1433
  try:
1427
- parser = create_argument_parser()
1428
- args = parser.parse_args()
1429
-
1430
- # Configure logging
1431
- log_level = logging.DEBUG if args.verbose else logging.INFO
1432
- logger.setLevel(log_level)
1433
-
1434
- # Create security manager
1434
+ # Validate and prepare all inputs
1435
1435
  security_manager = validate_security_manager(
1436
1436
  base_dir=args.base_dir,
1437
1437
  allowed_dirs=args.allowed_dir,
1438
- allowed_dirs_file=args.allowed_dirs_file,
1438
+ allowed_dir_file=args.allowed_dir_file,
1439
1439
  )
1440
1440
 
1441
- # Validate task template
1442
- task_template = validate_task_template(args.task)
1443
-
1444
- # Validate schema file
1441
+ task_template = validate_task_template(args.task, args.task_file)
1442
+ logger.debug("Validating schema from %s", args.schema_file)
1445
1443
  schema = validate_schema_file(args.schema_file, args.verbose)
1446
-
1447
- # Create template context
1448
1444
  template_context = create_template_context_from_args(
1449
1445
  args, security_manager
1450
1446
  )
1451
-
1452
- # Create Jinja environment
1453
1447
  env = create_jinja_env()
1454
1448
 
1455
- # Process system prompt
1456
- args.system_prompt = process_system_prompt(
1449
+ # Process system prompt and render task
1450
+ system_prompt = process_system_prompt(
1457
1451
  task_template,
1458
1452
  args.system_prompt,
1453
+ args.system_prompt_file,
1459
1454
  template_context,
1460
1455
  env,
1461
1456
  args.ignore_task_sysprompt,
1462
1457
  )
1463
-
1464
- # Render task template
1465
1458
  rendered_task = render_template(task_template, template_context, env)
1466
- logger.info(rendered_task) # Log the rendered template
1459
+ logger.info("Rendered task template: %s", rendered_task)
1467
1460
 
1468
- # If dry run, exit here
1469
1461
  if args.dry_run:
1470
1462
  logger.info("DRY RUN MODE")
1471
1463
  return ExitCode.SUCCESS
1472
1464
 
1473
- # Load and validate schema
1465
+ # Create output model
1466
+ logger.debug("Creating output model")
1474
1467
  try:
1475
- logger.debug("[_main] Loading schema from %s", args.schema_file)
1476
- schema = validate_schema_file(
1477
- args.schema_file, verbose=args.verbose_schema
1478
- )
1479
- logger.debug("[_main] Creating output model")
1480
1468
  output_model = create_dynamic_model(
1481
1469
  schema,
1482
1470
  base_name="OutputModel",
1483
1471
  show_schema=args.show_model_schema,
1484
1472
  debug_validation=args.debug_validation,
1485
1473
  )
1486
- logger.debug("[_main] Successfully created output model")
1487
- except (SchemaFileError, InvalidJSONError, SchemaValidationError) as e:
1488
- logger.error(str(e))
1489
- return ExitCode.SCHEMA_ERROR
1490
- except ModelCreationError as e:
1491
- logger.error(f"Model creation error: {e}")
1492
- return ExitCode.SCHEMA_ERROR
1493
- except Exception as e:
1494
- logger.error(f"Unexpected error creating model: {e}")
1495
- return ExitCode.SCHEMA_ERROR
1496
-
1497
- # Validate model support
1474
+ logger.debug("Successfully created output model")
1475
+ except (
1476
+ SchemaFileError,
1477
+ InvalidJSONError,
1478
+ SchemaValidationError,
1479
+ ModelCreationError,
1480
+ ) as e:
1481
+ logger.error("Schema error: %s", str(e))
1482
+ raise # Let the error propagate with its context
1483
+
1484
+ # Validate model support and token usage
1498
1485
  try:
1499
1486
  supports_structured_output(args.model)
1500
- except ModelNotSupportedError as e:
1501
- logger.error(str(e))
1502
- return ExitCode.DATA_ERROR
1503
- except ModelVersionError as e:
1504
- logger.error(str(e))
1505
- return ExitCode.DATA_ERROR
1506
-
1507
- # Estimate token usage
1487
+ except (ModelNotSupportedError, ModelVersionError) as e:
1488
+ logger.error("Model validation error: %s", str(e))
1489
+ raise # Let the error propagate with its context
1490
+
1508
1491
  messages = [
1509
- {"role": "system", "content": args.system_prompt},
1492
+ {"role": "system", "content": system_prompt},
1510
1493
  {"role": "user", "content": rendered_task},
1511
1494
  ]
1512
1495
  total_tokens = estimate_tokens_for_chat(messages, args.model)
1513
1496
  context_limit = get_context_window_limit(args.model)
1514
-
1515
1497
  if total_tokens > context_limit:
1516
- logger.error(
1517
- f"Total tokens ({total_tokens}) exceeds model context limit ({context_limit})"
1498
+ msg = f"Total tokens ({total_tokens}) exceeds model context limit ({context_limit})"
1499
+ logger.error(msg)
1500
+ raise CLIError(
1501
+ msg,
1502
+ context={
1503
+ "total_tokens": total_tokens,
1504
+ "context_limit": context_limit,
1505
+ },
1518
1506
  )
1519
- return ExitCode.DATA_ERROR
1520
1507
 
1521
- # Get API key
1508
+ # Get API key and create client
1522
1509
  api_key = args.api_key or os.getenv("OPENAI_API_KEY")
1523
1510
  if not api_key:
1524
- logger.error(
1525
- "No OpenAI API key provided (--api-key or OPENAI_API_KEY env var)"
1526
- )
1527
- return ExitCode.USAGE_ERROR
1511
+ msg = "No OpenAI API key provided (--api-key or OPENAI_API_KEY env var)"
1512
+ logger.error(msg)
1513
+ raise CLIError(msg)
1528
1514
 
1529
- # Create OpenAI client
1530
1515
  client = AsyncOpenAI(api_key=api_key, timeout=args.timeout)
1531
1516
 
1532
- # Create log callback that matches expected signature
1517
+ # Create detailed log callback
1533
1518
  def log_callback(
1534
1519
  level: int, message: str, extra: dict[str, Any]
1535
1520
  ) -> None:
1536
- # Only log if debug_openai_stream is enabled
1537
1521
  if args.debug_openai_stream:
1538
- # Include extra dictionary in the message for both DEBUG and ERROR
1539
- if extra: # Only add if there's actually extra data
1522
+ if extra:
1540
1523
  extra_str = json.dumps(extra, indent=2)
1541
1524
  message = f"{message}\nDetails:\n{extra_str}"
1542
- openai_logger.log(level, message, extra=extra)
1525
+ logger.log(level, message, extra=extra)
1543
1526
 
1544
- # Make API request
1527
+ # Stream the output
1545
1528
  try:
1546
- logger.debug("Creating ProgressContext for API response handling")
1547
- with ProgressContext(
1548
- description="Processing API response",
1549
- level=args.progress_level,
1550
- ) as progress:
1551
- logger.debug("Starting API response stream processing")
1552
- logger.debug("Debug flag status: %s", args.debug_openai_stream)
1553
- logger.debug("OpenAI logger level: %s", openai_logger.level)
1554
- for handler in openai_logger.handlers:
1555
- logger.debug("Handler level: %s", handler.level)
1556
- async for chunk in async_openai_structured_stream(
1557
- client=client,
1558
- model=args.model,
1559
- temperature=args.temperature,
1560
- max_tokens=args.max_tokens,
1561
- top_p=args.top_p,
1562
- frequency_penalty=args.frequency_penalty,
1563
- presence_penalty=args.presence_penalty,
1564
- system_prompt=args.system_prompt,
1565
- user_prompt=rendered_task,
1566
- output_schema=output_model,
1567
- timeout=args.timeout,
1568
- on_log=log_callback,
1569
- ):
1570
- logger.debug("Received API response chunk")
1571
- if not chunk:
1572
- logger.debug("Empty chunk received, skipping")
1573
- continue
1574
-
1575
- # Write output
1576
- try:
1577
- logger.debug("Starting to process output chunk")
1578
- dumped = chunk.model_dump(mode="json")
1579
- logger.debug("Successfully dumped chunk to JSON")
1580
- logger.debug("Dumped chunk: %s", dumped)
1581
- logger.debug(
1582
- "Chunk type: %s, length: %d",
1583
- type(dumped),
1584
- len(json.dumps(dumped)),
1585
- )
1586
-
1587
- if args.output_file:
1588
- logger.debug(
1589
- "Writing to output file: %s", args.output_file
1590
- )
1591
- try:
1592
- with open(
1593
- args.output_file, "a", encoding="utf-8"
1594
- ) as f:
1595
- json_str = json.dumps(dumped, indent=2)
1596
- logger.debug(
1597
- "Writing JSON string of length %d",
1598
- len(json_str),
1599
- )
1600
- f.write(json_str)
1601
- f.write("\n")
1602
- logger.debug("Successfully wrote to file")
1603
- except Exception as e:
1604
- logger.error(
1605
- "Failed to write to output file: %s", e
1606
- )
1607
- else:
1608
- logger.debug(
1609
- "About to call progress.print_output with JSON string"
1610
- )
1611
- json_str = json.dumps(dumped, indent=2)
1612
- logger.debug(
1613
- "JSON string length before print_output: %d",
1614
- len(json_str),
1615
- )
1616
- logger.debug(
1617
- "First 100 chars of JSON string: %s",
1618
- json_str[:100] if json_str else "",
1619
- )
1620
- progress.print_output(json_str)
1621
- logger.debug(
1622
- "Completed print_output call for JSON string"
1623
- )
1624
-
1625
- logger.debug("Starting progress update")
1626
- progress.update()
1627
- logger.debug("Completed progress update")
1628
- except Exception as e:
1629
- logger.error("Failed to process chunk: %s", e)
1630
- logger.error("Chunk: %s", chunk)
1631
- continue
1632
-
1633
- logger.debug("Finished processing API response stream")
1634
-
1635
- except StreamInterruptedError as e:
1636
- logger.error(f"Stream interrupted: {e}")
1637
- return ExitCode.API_ERROR
1638
- except StreamBufferError as e:
1639
- logger.error(f"Stream buffer error: {e}")
1640
- return ExitCode.API_ERROR
1641
- except StreamParseError as e:
1642
- logger.error(f"Stream parse error: {e}")
1643
- return ExitCode.API_ERROR
1644
- except APIResponseError as e:
1645
- logger.error(f"API response error: {e}")
1646
- return ExitCode.API_ERROR
1647
- except EmptyResponseError as e:
1648
- logger.error(f"Empty response error: {e}")
1649
- return ExitCode.API_ERROR
1650
- except InvalidResponseFormatError as e:
1651
- logger.error(f"Invalid response format: {e}")
1652
- return ExitCode.API_ERROR
1529
+ await stream_structured_output(
1530
+ client=client,
1531
+ model=args.model,
1532
+ system_prompt=system_prompt,
1533
+ user_prompt=rendered_task,
1534
+ output_schema=output_model,
1535
+ output_file=args.output_file,
1536
+ temperature=args.temperature,
1537
+ max_tokens=args.max_tokens,
1538
+ top_p=args.top_p,
1539
+ frequency_penalty=args.frequency_penalty,
1540
+ presence_penalty=args.presence_penalty,
1541
+ timeout=args.timeout,
1542
+ on_log=log_callback,
1543
+ )
1544
+ return ExitCode.SUCCESS
1545
+ except (
1546
+ StreamInterruptedError,
1547
+ StreamBufferError,
1548
+ StreamParseError,
1549
+ APIResponseError,
1550
+ EmptyResponseError,
1551
+ InvalidResponseFormatError,
1552
+ ) as e:
1553
+ logger.error("Stream error: %s", str(e))
1554
+ raise # Let stream errors propagate
1653
1555
  except (APIConnectionError, InternalServerError) as e:
1654
- logger.error(f"API connection error: {e}")
1655
- return ExitCode.API_ERROR
1556
+ logger.error("API connection error: %s", str(e))
1557
+ raise APIResponseError(str(e)) # Convert to our error type
1656
1558
  except RateLimitError as e:
1657
- logger.error(f"Rate limit exceeded: {e}")
1658
- return ExitCode.API_ERROR
1659
- except BadRequestError as e:
1660
- logger.error(f"Bad request: {e}")
1661
- return ExitCode.API_ERROR
1662
- except AuthenticationError as e:
1663
- logger.error(f"Authentication failed: {e}")
1664
- return ExitCode.API_ERROR
1665
- except OpenAIClientError as e:
1666
- logger.error(f"OpenAI client error: {e}")
1667
- return ExitCode.API_ERROR
1668
- except Exception as e:
1669
- logger.error(f"Unexpected error: {e}")
1670
- return ExitCode.INTERNAL_ERROR
1671
-
1672
- return ExitCode.SUCCESS
1559
+ logger.error("Rate limit exceeded: %s", str(e))
1560
+ raise APIResponseError(str(e)) # Convert to our error type
1561
+ except (BadRequestError, AuthenticationError, OpenAIClientError) as e:
1562
+ logger.error("API client error: %s", str(e))
1563
+ raise APIResponseError(str(e)) # Convert to our error type
1564
+ finally:
1565
+ await client.close()
1673
1566
 
1674
1567
  except KeyboardInterrupt:
1675
- logger.error("Operation cancelled by user")
1568
+ logger.info("Operation cancelled by user")
1676
1569
  return ExitCode.INTERRUPTED
1677
- except PathSecurityError as e:
1678
- # Only log security errors if they haven't been logged already
1679
- logger.debug(
1680
- "[_main] Caught PathSecurityError: %s (logged=%s)",
1681
- str(e),
1682
- getattr(e, "has_been_logged", False),
1683
- )
1684
- if not getattr(e, "has_been_logged", False):
1685
- logger.error(str(e))
1686
- return ExitCode.SECURITY_ERROR
1687
- except ValueError as e:
1688
- # Get the original cause of the error
1689
- cause = e.__cause__ or e.__context__
1690
- if isinstance(cause, PathSecurityError):
1691
- logger.debug(
1692
- "[_main] Caught wrapped PathSecurityError in ValueError: %s (logged=%s)",
1693
- str(cause),
1694
- getattr(cause, "has_been_logged", False),
1695
- )
1696
- # Only log security errors if they haven't been logged already
1697
- if not getattr(cause, "has_been_logged", False):
1698
- logger.error(str(cause))
1699
- return ExitCode.SECURITY_ERROR
1700
- else:
1701
- logger.debug("[_main] Caught ValueError: %s", str(e))
1702
- logger.error(f"Invalid input: {e}")
1703
- return ExitCode.DATA_ERROR
1704
1570
  except Exception as e:
1705
- # Check if this is a wrapped security error
1706
- if isinstance(e.__cause__, PathSecurityError):
1707
- logger.debug(
1708
- "[_main] Caught wrapped PathSecurityError in Exception: %s (logged=%s)",
1709
- str(e.__cause__),
1710
- getattr(e.__cause__, "has_been_logged", False),
1711
- )
1712
- # Only log security errors if they haven't been logged already
1713
- if not getattr(e.__cause__, "has_been_logged", False):
1714
- logger.error(str(e.__cause__))
1715
- return ExitCode.SECURITY_ERROR
1716
- logger.debug("[_main] Caught unexpected error: %s", str(e))
1717
- logger.error(f"Unexpected error: {e}")
1718
- return ExitCode.INTERNAL_ERROR
1571
+ if isinstance(e, CLIError):
1572
+ raise # Let our custom errors propagate
1573
+ logger.exception("Unexpected error")
1574
+ raise CLIError(str(e), context={"error_type": type(e).__name__})
1575
+
1576
+
1577
+ def create_cli() -> click.Command:
1578
+ """Create the CLI command.
1579
+
1580
+ Returns:
1581
+ click.Command: The CLI command object
1582
+ """
1583
+
1584
+ @create_click_command()
1585
+ def cli(**kwargs: Any) -> None:
1586
+ """CLI entry point for structured OpenAI API calls."""
1587
+ try:
1588
+ args = Namespace(**kwargs)
1589
+
1590
+ # Validate required arguments first
1591
+ if not args.task and not args.task_file:
1592
+ raise click.UsageError(
1593
+ "Must specify either --task or --task-file"
1594
+ )
1595
+ if not args.schema_file:
1596
+ raise click.UsageError("Missing option '--schema-file'")
1597
+ if args.task and args.task_file:
1598
+ raise click.UsageError(
1599
+ "Cannot specify both --task and --task-file"
1600
+ )
1601
+ if args.system_prompt and args.system_prompt_file:
1602
+ raise click.UsageError(
1603
+ "Cannot specify both --system-prompt and --system-prompt-file"
1604
+ )
1605
+
1606
+ # Run the async function synchronously
1607
+ exit_code = asyncio.run(run_cli_async(args))
1608
+
1609
+ if exit_code != ExitCode.SUCCESS:
1610
+ error_msg = f"Command failed with exit code {exit_code}"
1611
+ if hasattr(ExitCode, exit_code.name):
1612
+ error_msg = f"{error_msg} ({exit_code.name})"
1613
+ raise CLIError(error_msg, context={"exit_code": exit_code})
1614
+
1615
+ except click.UsageError:
1616
+ # Let Click handle usage errors directly
1617
+ raise
1618
+ except InvalidJSONError:
1619
+ # Let InvalidJSONError propagate directly
1620
+ raise
1621
+ except CLIError:
1622
+ # Let our custom errors propagate with their context
1623
+ raise
1624
+ except Exception as e:
1625
+ # Convert other exceptions to CLIError
1626
+ logger.exception("Unexpected error")
1627
+ raise CLIError(str(e), context={"error_type": type(e).__name__})
1628
+
1629
+ return cli
1719
1630
 
1720
1631
 
1721
1632
  def main() -> None:
1722
- """CLI entry point that handles all errors."""
1723
- try:
1724
- logger.debug("[main] Starting main execution")
1725
- exit_code = asyncio.run(_main())
1726
- sys.exit(exit_code.value)
1727
- except KeyboardInterrupt:
1728
- logger.error("Operation cancelled by user")
1729
- sys.exit(ExitCode.INTERRUPTED.value)
1730
- except PathSecurityError as e:
1731
- # Only log security errors if they haven't been logged already
1732
- logger.debug(
1733
- "[main] Caught PathSecurityError: %s (logged=%s)",
1734
- str(e),
1735
- getattr(e, "has_been_logged", False),
1736
- )
1737
- if not getattr(e, "has_been_logged", False):
1738
- logger.error(str(e))
1739
- sys.exit(ExitCode.SECURITY_ERROR.value)
1740
- except ValueError as e:
1741
- # Get the original cause of the error
1742
- cause = e.__cause__ or e.__context__
1743
- if isinstance(cause, PathSecurityError):
1744
- logger.debug(
1745
- "[main] Caught wrapped PathSecurityError in ValueError: %s (logged=%s)",
1746
- str(cause),
1747
- getattr(cause, "has_been_logged", False),
1748
- )
1749
- # Only log security errors if they haven't been logged already
1750
- if not getattr(cause, "has_been_logged", False):
1751
- logger.error(str(cause))
1752
- sys.exit(ExitCode.SECURITY_ERROR.value)
1753
- else:
1754
- logger.debug("[main] Caught ValueError: %s", str(e))
1755
- logger.error(f"Invalid input: {e}")
1756
- sys.exit(ExitCode.DATA_ERROR.value)
1757
- except Exception as e:
1758
- # Check if this is a wrapped security error
1759
- if isinstance(e.__cause__, PathSecurityError):
1760
- logger.debug(
1761
- "[main] Caught wrapped PathSecurityError in Exception: %s (logged=%s)",
1762
- str(e.__cause__),
1763
- getattr(e.__cause__, "has_been_logged", False),
1764
- )
1765
- # Only log security errors if they haven't been logged already
1766
- if not getattr(e.__cause__, "has_been_logged", False):
1767
- logger.error(str(e.__cause__))
1768
- sys.exit(ExitCode.SECURITY_ERROR.value)
1769
- logger.debug("[main] Caught unexpected error: %s", str(e))
1770
- logger.error(f"Unexpected error: {e}")
1771
- sys.exit(ExitCode.INTERNAL_ERROR.value)
1633
+ """Main entry point for the CLI."""
1634
+ cli = create_cli()
1635
+ cli(standalone_mode=False)
1772
1636
 
1773
1637
 
1774
1638
  # Export public API
1775
1639
  __all__ = [
1776
1640
  "ExitCode",
1777
1641
  "estimate_tokens_for_chat",
1778
- "get_context_window_limit",
1779
- "get_default_token_limit",
1780
1642
  "parse_json_var",
1781
1643
  "create_dynamic_model",
1782
1644
  "validate_path_mapping",
1783
- "create_argument_parser",
1645
+ "create_cli",
1784
1646
  "main",
1785
1647
  ]
1786
1648