ostruct-cli 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ostruct/cli/cli.py CHANGED
@@ -1,18 +1,12 @@
1
1
  """Command-line interface for making structured OpenAI API calls."""
2
2
 
3
- import argparse
4
3
  import asyncio
5
4
  import json
6
5
  import logging
7
6
  import os
8
7
  import sys
8
+ from dataclasses import dataclass
9
9
  from enum import Enum, IntEnum
10
-
11
- if sys.version_info >= (3, 11):
12
- from enum import StrEnum
13
-
14
- from datetime import date, datetime, time
15
- from pathlib import Path
16
10
  from typing import (
17
11
  Any,
18
12
  Dict,
@@ -29,6 +23,13 @@ from typing import (
29
23
  overload,
30
24
  )
31
25
 
26
+ if sys.version_info >= (3, 11):
27
+ from enum import StrEnum
28
+
29
+ from datetime import date, datetime, time
30
+ from pathlib import Path
31
+
32
+ import click
32
33
  import jinja2
33
34
  import tiktoken
34
35
  import yaml
@@ -71,8 +72,11 @@ from pydantic.functional_validators import BeforeValidator
71
72
  from pydantic.types import constr
72
73
  from typing_extensions import TypeAlias
73
74
 
74
- from .. import __version__
75
+ from ostruct.cli.click_options import create_click_command
76
+
77
+ from .. import __version__ # noqa: F401 - Used in package metadata
75
78
  from .errors import (
79
+ CLIError,
76
80
  DirectoryNotFoundError,
77
81
  FieldDefinitionError,
78
82
  FileNotFoundError,
@@ -89,7 +93,6 @@ from .errors import (
89
93
  )
90
94
  from .file_utils import FileInfoList, TemplateValue, collect_files
91
95
  from .path_utils import validate_path_mapping
92
- from .progress import ProgressContext
93
96
  from .security import SecurityManager
94
97
  from .template_env import create_jinja_env
95
98
  from .template_utils import SystemPromptError, render_template
@@ -97,6 +100,45 @@ from .template_utils import SystemPromptError, render_template
97
100
  # Constants
98
101
  DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant."
99
102
 
103
+
104
+ @dataclass
105
+ class Namespace:
106
+ """Compatibility class to mimic argparse.Namespace for existing code."""
107
+
108
+ task: Optional[str]
109
+ task_file: Optional[str]
110
+ file: List[str]
111
+ files: List[str]
112
+ dir: List[str]
113
+ allowed_dir: List[str]
114
+ base_dir: str
115
+ allowed_dir_file: Optional[str]
116
+ dir_recursive: bool
117
+ dir_ext: Optional[str]
118
+ var: List[str]
119
+ json_var: List[str]
120
+ system_prompt: Optional[str]
121
+ system_prompt_file: Optional[str]
122
+ ignore_task_sysprompt: bool
123
+ schema_file: str
124
+ model: str
125
+ temperature: float
126
+ max_tokens: Optional[int]
127
+ top_p: float
128
+ frequency_penalty: float
129
+ presence_penalty: float
130
+ timeout: float
131
+ output_file: Optional[str]
132
+ dry_run: bool
133
+ no_progress: bool
134
+ api_key: Optional[str]
135
+ verbose: bool
136
+ debug_openai_stream: bool
137
+ show_model_schema: bool
138
+ debug_validation: bool
139
+ progress_level: str = "basic" # Default to 'basic' if not specified
140
+
141
+
100
142
  # Set up logging
101
143
  logger = logging.getLogger(__name__)
102
144
 
@@ -360,25 +402,43 @@ V = TypeVar("V")
360
402
 
361
403
 
362
404
  def estimate_tokens_for_chat(
363
- messages: List[Dict[str, str]], model: str
405
+ messages: List[Dict[str, str]],
406
+ model: str,
407
+ encoder: Any = None,
364
408
  ) -> int:
365
- """Estimate the number of tokens in a chat completion."""
366
- try:
367
- encoding = tiktoken.encoding_for_model(model)
368
- except KeyError:
369
- # Fall back to cl100k_base for unknown models
370
- encoding = tiktoken.get_encoding("cl100k_base")
371
-
372
- num_tokens = 0
373
- for message in messages:
374
- # Add message overhead
375
- num_tokens += 4 # every message follows <im_start>{role/name}\n{content}<im_end>\n
376
- for key, value in message.items():
377
- num_tokens += len(encoding.encode(str(value)))
378
- if key == "name": # if there's a name, the role is omitted
379
- num_tokens += -1 # role is always required and always 1 token
380
- num_tokens += 2 # every reply is primed with <im_start>assistant
381
- return num_tokens
409
+ """Estimate the number of tokens in a chat completion.
410
+
411
+ Args:
412
+ messages: List of chat messages
413
+ model: Model name
414
+ encoder: Optional tiktoken encoder for testing. If provided, only uses encoder.encode() results.
415
+ """
416
+ if encoder is None:
417
+ try:
418
+ # Try to get the encoding for the specific model
419
+ encoder = tiktoken.get_encoding("o200k_base")
420
+ except KeyError:
421
+ # Fall back to cl100k_base for unknown models
422
+ encoder = tiktoken.get_encoding("cl100k_base")
423
+
424
+ # Use standard token counting logic for real tiktoken encoders
425
+ num_tokens = 0
426
+ for message in messages:
427
+ # Add message overhead
428
+ num_tokens += 4 # every message follows <im_start>{role/name}\n{content}<im_end>\n
429
+ for key, value in message.items():
430
+ num_tokens += len(encoder.encode(str(value)))
431
+ if key == "name": # if there's a name, the role is omitted
432
+ num_tokens -= 1 # role is omitted
433
+ num_tokens += 2 # every reply is primed with <im_start>assistant
434
+ return num_tokens
435
+ else:
436
+ # For mock encoders in tests, just return the length of encoded content
437
+ num_tokens = 0
438
+ for message in messages:
439
+ for value in message.values():
440
+ num_tokens += len(encoder.encode(str(value)))
441
+ return num_tokens
382
442
 
383
443
 
384
444
  def get_default_token_limit(model: str) -> int:
@@ -388,15 +448,17 @@ def get_default_token_limit(model: str) -> int:
388
448
  need to be updated if OpenAI changes the models' capabilities.
389
449
 
390
450
  Args:
391
- model: The model name (e.g., 'gpt-4o', 'gpt-4o-mini', 'o1')
451
+ model: The model name (e.g., 'gpt-4o', 'o1-mini', 'o3-mini')
392
452
 
393
453
  Returns:
394
454
  The default token limit for the model
395
455
  """
396
- if "o1" in model:
397
- return 100_000 # o1 supports up to 100K output tokens
456
+ if "o1-" in model:
457
+ return 100_000 # o1-mini supports up to 100K output tokens
398
458
  elif "gpt-4o" in model:
399
- return 16_384 # gpt-4o and gpt-4o-mini support up to 16K output tokens
459
+ return 16_384 # gpt-4o supports up to 16K output tokens
460
+ elif "o3-" in model:
461
+ return 16_384 # o3-mini supports up to 16K output tokens
400
462
  else:
401
463
  return 4_096 # default fallback
402
464
 
@@ -408,15 +470,15 @@ def get_context_window_limit(model: str) -> int:
408
470
  need to be updated if OpenAI changes the models' capabilities.
409
471
 
410
472
  Args:
411
- model: The model name (e.g., 'gpt-4o', 'gpt-4o-mini', 'o1')
473
+ model: The model name (e.g., 'gpt-4o', 'o1-mini', 'o3-mini')
412
474
 
413
475
  Returns:
414
476
  The context window limit for the model
415
477
  """
416
- if "o1" in model:
417
- return 200_000 # o1 supports 200K total context window
418
- elif "gpt-4o" in model:
419
- return 128_000 # gpt-4o and gpt-4o-mini support 128K context window
478
+ if "o1-" in model:
479
+ return 200_000 # o1-mini supports 200K total context window
480
+ elif "gpt-4o" in model or "o3-" in model:
481
+ return 128_000 # gpt-4o and o3-mini support 128K context window
420
482
  else:
421
483
  return 8_192 # default fallback
422
484
 
@@ -460,6 +522,7 @@ def validate_token_limits(
460
522
  def process_system_prompt(
461
523
  task_template: str,
462
524
  system_prompt: Optional[str],
525
+ system_prompt_file: Optional[str],
463
526
  template_context: Dict[str, Any],
464
527
  env: jinja2.Environment,
465
528
  ignore_task_sysprompt: bool = False,
@@ -468,7 +531,8 @@ def process_system_prompt(
468
531
 
469
532
  Args:
470
533
  task_template: The task template string
471
- system_prompt: Optional system prompt string or file path (with @ prefix)
534
+ system_prompt: Optional system prompt string
535
+ system_prompt_file: Optional path to system prompt file
472
536
  template_context: Template context for rendering
473
537
  env: Jinja2 environment
474
538
  ignore_task_sysprompt: Whether to ignore system prompt in task template
@@ -484,18 +548,24 @@ def process_system_prompt(
484
548
  # Default system prompt
485
549
  default_prompt = "You are a helpful assistant."
486
550
 
551
+ # Check for conflicting arguments
552
+ if system_prompt is not None and system_prompt_file is not None:
553
+ raise SystemPromptError(
554
+ "Cannot specify both --system-prompt and --system-prompt-file"
555
+ )
556
+
487
557
  # Try to get system prompt from CLI argument first
488
- if system_prompt:
489
- if system_prompt.startswith("@"):
490
- # Load from file
491
- path = system_prompt[1:]
492
- try:
493
- name, path = validate_path_mapping(f"system_prompt={path}")
494
- with open(path, "r", encoding="utf-8") as f:
495
- system_prompt = f.read().strip()
496
- except (FileNotFoundError, PathSecurityError) as e:
497
- raise SystemPromptError(f"Invalid system prompt file: {e}")
558
+ if system_prompt_file is not None:
559
+ try:
560
+ name, path = validate_path_mapping(
561
+ f"system_prompt={system_prompt_file}"
562
+ )
563
+ with open(path, "r", encoding="utf-8") as f:
564
+ system_prompt = f.read().strip()
565
+ except (FileNotFoundError, PathSecurityError) as e:
566
+ raise SystemPromptError(f"Invalid system prompt file: {e}")
498
567
 
568
+ if system_prompt is not None:
499
569
  # Render system prompt with template context
500
570
  try:
501
571
  template = env.from_string(system_prompt)
@@ -618,30 +688,45 @@ def _validate_path_mapping_internal(
618
688
  ValueError: If the format is invalid (missing "=").
619
689
  OSError: If there is an underlying OS error (permissions, etc.).
620
690
  """
691
+ logger = logging.getLogger(__name__)
692
+ logger.debug("Starting path validation for mapping: %r", mapping)
693
+ logger.debug("Parameters - is_dir: %r, base_dir: %r", is_dir, base_dir)
694
+
621
695
  try:
622
696
  if not mapping or "=" not in mapping:
697
+ logger.debug("Invalid mapping format: %r", mapping)
623
698
  raise ValueError(
624
699
  "Invalid path mapping format. Expected format: name=path"
625
700
  )
626
701
 
627
702
  name, path = mapping.split("=", 1)
703
+ logger.debug("Split mapping - name: %r, path: %r", name, path)
704
+
628
705
  if not name:
706
+ logger.debug("Empty name in mapping")
629
707
  raise VariableNameError(
630
708
  f"Empty name in {'directory' if is_dir else 'file'} mapping"
631
709
  )
632
710
 
633
711
  if not path:
712
+ logger.debug("Empty path in mapping")
634
713
  raise VariableValueError("Path cannot be empty")
635
714
 
636
715
  # Convert to Path object and resolve against base_dir if provided
716
+ logger.debug("Creating Path object for: %r", path)
637
717
  path_obj = Path(path)
638
718
  if base_dir:
719
+ logger.debug("Resolving against base_dir: %r", base_dir)
639
720
  path_obj = Path(base_dir) / path_obj
721
+ logger.debug("Path object created: %r", path_obj)
640
722
 
641
723
  # Resolve the path to catch directory traversal attempts
642
724
  try:
725
+ logger.debug("Attempting to resolve path: %r", path_obj)
643
726
  resolved_path = path_obj.resolve()
727
+ logger.debug("Resolved path: %r", resolved_path)
644
728
  except OSError as e:
729
+ logger.error("Failed to resolve path: %s", e)
645
730
  raise OSError(f"Failed to resolve path: {e}")
646
731
 
647
732
  # Check for directory traversal
@@ -709,34 +794,45 @@ def _validate_path_mapping_internal(
709
794
  raise
710
795
 
711
796
 
712
- def validate_task_template(task: str) -> str:
797
+ def validate_task_template(
798
+ task: Optional[str], task_file: Optional[str]
799
+ ) -> str:
713
800
  """Validate and load a task template.
714
801
 
715
802
  Args:
716
- task: The task template string or path to task template file (with @ prefix)
803
+ task: The task template string
804
+ task_file: Path to task template file
717
805
 
718
806
  Returns:
719
807
  The task template string
720
808
 
721
809
  Raises:
722
- TaskTemplateVariableError: If the template file cannot be read or is invalid
810
+ TaskTemplateVariableError: If neither task nor task_file is provided, or if both are provided
723
811
  TaskTemplateSyntaxError: If the template has invalid syntax
724
812
  FileNotFoundError: If the template file does not exist
725
813
  PathSecurityError: If the template file path violates security constraints
726
814
  """
727
- template_content = task
815
+ if task is not None and task_file is not None:
816
+ raise TaskTemplateVariableError(
817
+ "Cannot specify both --task and --task-file"
818
+ )
819
+
820
+ if task is None and task_file is None:
821
+ raise TaskTemplateVariableError(
822
+ "Must specify either --task or --task-file"
823
+ )
728
824
 
729
- # Check if task is a file path
730
- if task.startswith("@"):
731
- path = task[1:]
825
+ template_content: str
826
+ if task_file is not None:
732
827
  try:
733
- name, path = validate_path_mapping(f"task={path}")
828
+ name, path = validate_path_mapping(f"task={task_file}")
734
829
  with open(path, "r", encoding="utf-8") as f:
735
830
  template_content = f.read()
736
831
  except (FileNotFoundError, PathSecurityError) as e:
737
- raise TaskTemplateVariableError(f"Invalid task template file: {e}")
832
+ raise TaskTemplateVariableError(str(e))
833
+ else:
834
+ template_content = task # type: ignore # We know task is str here due to the checks above
738
835
 
739
- # Validate template syntax
740
836
  try:
741
837
  env = jinja2.Environment(undefined=jinja2.StrictUndefined)
742
838
  env.parse(template_content)
@@ -813,7 +909,7 @@ def validate_schema_file(
813
909
 
814
910
 
815
911
  def collect_template_files(
816
- args: argparse.Namespace,
912
+ args: Namespace,
817
913
  security_manager: SecurityManager,
818
914
  ) -> Dict[str, TemplateValue]:
819
915
  """Collect files from command line arguments.
@@ -846,14 +942,17 @@ def collect_template_files(
846
942
  # Wrap file-related errors
847
943
  raise ValueError(f"File access error: {e}")
848
944
  except Exception as e:
945
+ # Don't wrap InvalidJSONError
946
+ if isinstance(e, InvalidJSONError):
947
+ raise
849
948
  # Check if this is a wrapped security error
850
949
  if isinstance(e.__cause__, PathSecurityError):
851
950
  raise e.__cause__
852
- # Wrap unexpected errors
951
+ # Wrap other errors
853
952
  raise ValueError(f"Error collecting files: {e}")
854
953
 
855
954
 
856
- def collect_simple_variables(args: argparse.Namespace) -> Dict[str, str]:
955
+ def collect_simple_variables(args: Namespace) -> Dict[str, str]:
857
956
  """Collect simple string variables from --var arguments.
858
957
 
859
958
  Args:
@@ -886,7 +985,7 @@ def collect_simple_variables(args: argparse.Namespace) -> Dict[str, str]:
886
985
  return variables
887
986
 
888
987
 
889
- def collect_json_variables(args: argparse.Namespace) -> Dict[str, Any]:
988
+ def collect_json_variables(args: Namespace) -> Dict[str, Any]:
890
989
  """Collect JSON variables from --json-var arguments.
891
990
 
892
991
  Args:
@@ -916,7 +1015,7 @@ def collect_json_variables(args: argparse.Namespace) -> Dict[str, Any]:
916
1015
  all_names.add(name)
917
1016
  except json.JSONDecodeError as e:
918
1017
  raise InvalidJSONError(
919
- f"Invalid JSON value for {name}: {str(e)}"
1018
+ f"Error parsing JSON for variable '{name}': {str(e)}. Input was: {json_str}"
920
1019
  )
921
1020
  except ValueError:
922
1021
  raise VariableNameError(
@@ -972,7 +1071,7 @@ def create_template_context(
972
1071
 
973
1072
 
974
1073
  def create_template_context_from_args(
975
- args: argparse.Namespace,
1074
+ args: "Namespace",
976
1075
  security_manager: SecurityManager,
977
1076
  ) -> Dict[str, Any]:
978
1077
  """Create template context from command line arguments.
@@ -1024,7 +1123,7 @@ def create_template_context_from_args(
1024
1123
  json_value = json.loads(value)
1025
1124
  except json.JSONDecodeError as e:
1026
1125
  raise InvalidJSONError(
1027
- f"Invalid JSON value for {name} ({value!r}): {str(e)}"
1126
+ f"Error parsing JSON for variable '{name}': {str(e)}. Input was: {value}"
1028
1127
  )
1029
1128
  if name in json_variables:
1030
1129
  raise VariableNameError(
@@ -1060,44 +1159,103 @@ def create_template_context_from_args(
1060
1159
  # Wrap file-related errors
1061
1160
  raise ValueError(f"File access error: {e}")
1062
1161
  except Exception as e:
1162
+ # Don't wrap InvalidJSONError
1163
+ if isinstance(e, InvalidJSONError):
1164
+ raise
1063
1165
  # Check if this is a wrapped security error
1064
1166
  if isinstance(e.__cause__, PathSecurityError):
1065
1167
  raise e.__cause__
1066
- # Wrap unexpected errors
1168
+ # Wrap other errors
1067
1169
  raise ValueError(f"Error collecting files: {e}")
1068
1170
 
1069
1171
 
1070
1172
  def validate_security_manager(
1071
1173
  base_dir: Optional[str] = None,
1072
1174
  allowed_dirs: Optional[List[str]] = None,
1073
- allowed_dirs_file: Optional[str] = None,
1175
+ allowed_dir_file: Optional[str] = None,
1074
1176
  ) -> SecurityManager:
1075
- """Create and validate a security manager.
1177
+ """Validate and create security manager.
1076
1178
 
1077
1179
  Args:
1078
- base_dir: Optional base directory to resolve paths against
1079
- allowed_dirs: Optional list of allowed directory paths
1080
- allowed_dirs_file: Optional path to file containing allowed directories
1180
+ base_dir: Base directory for file access. Defaults to current working directory.
1181
+ allowed_dirs: Optional list of additional allowed directories
1182
+ allowed_dir_file: Optional file containing allowed directories
1081
1183
 
1082
1184
  Returns:
1083
1185
  Configured SecurityManager instance
1084
1186
 
1085
1187
  Raises:
1086
- FileNotFoundError: If allowed_dirs_file does not exist
1087
- PathSecurityError: If any paths are outside base directory
1188
+ PathSecurityError: If any paths violate security constraints
1189
+ DirectoryNotFoundError: If any directories do not exist
1088
1190
  """
1089
- # Convert base_dir to string if it's a Path
1090
- base_dir_str = str(base_dir) if base_dir else None
1091
- security_manager = SecurityManager(base_dir_str)
1191
+ # Use current working directory if base_dir is None
1192
+ if base_dir is None:
1193
+ base_dir = os.getcwd()
1092
1194
 
1093
- if allowed_dirs_file:
1094
- security_manager.add_allowed_dirs_from_file(str(allowed_dirs_file))
1195
+ # Default to empty list if allowed_dirs is None
1196
+ if allowed_dirs is None:
1197
+ allowed_dirs = []
1095
1198
 
1096
- if allowed_dirs:
1097
- for allowed_dir in allowed_dirs:
1098
- security_manager.add_allowed_dir(str(allowed_dir))
1199
+ # Add base directory if it exists
1200
+ try:
1201
+ base_dir_path = Path(base_dir).resolve()
1202
+ if not base_dir_path.exists():
1203
+ raise DirectoryNotFoundError(
1204
+ f"Base directory not found: {base_dir}"
1205
+ )
1206
+ if not base_dir_path.is_dir():
1207
+ raise DirectoryNotFoundError(
1208
+ f"Base directory is not a directory: {base_dir}"
1209
+ )
1210
+ all_allowed_dirs = [str(base_dir_path)]
1211
+ except OSError as e:
1212
+ raise DirectoryNotFoundError(f"Invalid base directory: {e}")
1099
1213
 
1100
- return security_manager
1214
+ # Add explicitly allowed directories
1215
+ for dir_path in allowed_dirs:
1216
+ try:
1217
+ resolved_path = Path(dir_path).resolve()
1218
+ if not resolved_path.exists():
1219
+ raise DirectoryNotFoundError(
1220
+ f"Directory not found: {dir_path}"
1221
+ )
1222
+ if not resolved_path.is_dir():
1223
+ raise DirectoryNotFoundError(
1224
+ f"Path is not a directory: {dir_path}"
1225
+ )
1226
+ all_allowed_dirs.append(str(resolved_path))
1227
+ except OSError as e:
1228
+ raise DirectoryNotFoundError(f"Invalid directory path: {e}")
1229
+
1230
+ # Add directories from file if specified
1231
+ if allowed_dir_file:
1232
+ try:
1233
+ with open(allowed_dir_file, "r", encoding="utf-8") as f:
1234
+ for line in f:
1235
+ line = line.strip()
1236
+ if line and not line.startswith("#"):
1237
+ try:
1238
+ resolved_path = Path(line).resolve()
1239
+ if not resolved_path.exists():
1240
+ raise DirectoryNotFoundError(
1241
+ f"Directory not found: {line}"
1242
+ )
1243
+ if not resolved_path.is_dir():
1244
+ raise DirectoryNotFoundError(
1245
+ f"Path is not a directory: {line}"
1246
+ )
1247
+ all_allowed_dirs.append(str(resolved_path))
1248
+ except OSError as e:
1249
+ raise DirectoryNotFoundError(
1250
+ f"Invalid directory path in {allowed_dir_file}: {e}"
1251
+ )
1252
+ except OSError as e:
1253
+ raise DirectoryNotFoundError(
1254
+ f"Failed to read allowed directories file: {e}"
1255
+ )
1256
+
1257
+ # Create security manager with all allowed directories
1258
+ return SecurityManager(base_dir=base_dir, allowed_dirs=all_allowed_dirs)
1101
1259
 
1102
1260
 
1103
1261
  def parse_var(var_str: str) -> Tuple[str, str]:
@@ -1157,8 +1315,8 @@ def parse_json_var(var_str: str) -> Tuple[str, Any]:
1157
1315
  value = json.loads(json_str)
1158
1316
  except json.JSONDecodeError as e:
1159
1317
  raise InvalidJSONError(
1160
- f"Invalid JSON value for variable {name!r}: {json_str!r}"
1161
- ) from e
1318
+ f"Error parsing JSON for variable '{name}': {str(e)}. Input was: {json_str}"
1319
+ )
1162
1320
 
1163
1321
  return name, value
1164
1322
 
@@ -1205,570 +1363,308 @@ def _create_enum_type(values: List[Any], field_name: str) -> Type[Enum]:
1205
1363
  return type(f"{field_name.title()}Enum", (str, Enum), enum_dict)
1206
1364
 
1207
1365
 
1208
- def create_argument_parser() -> argparse.ArgumentParser:
1209
- """Create argument parser for CLI."""
1210
- parser = argparse.ArgumentParser(
1211
- description="Make structured OpenAI API calls.",
1212
- formatter_class=argparse.RawDescriptionHelpFormatter,
1213
- )
1214
-
1215
- # Debug output options
1216
- debug_group = parser.add_argument_group("Debug Output Options")
1217
- debug_group.add_argument(
1218
- "--show-model-schema",
1219
- action="store_true",
1220
- help="Display the generated Pydantic model schema",
1221
- )
1222
- debug_group.add_argument(
1223
- "--debug-validation",
1224
- action="store_true",
1225
- help="Show detailed schema validation debugging information",
1226
- )
1227
- debug_group.add_argument(
1228
- "--verbose-schema",
1229
- action="store_true",
1230
- help="Enable verbose schema debugging output",
1231
- )
1232
- debug_group.add_argument(
1233
- "--progress-level",
1234
- choices=["none", "basic", "detailed"],
1235
- default="basic",
1236
- help="Set the level of progress reporting (default: basic)",
1237
- )
1238
-
1239
- # Required arguments
1240
- parser.add_argument(
1241
- "--task",
1242
- required=True,
1243
- help="Task template string or @file",
1244
- )
1245
-
1246
- # File access arguments
1247
- parser.add_argument(
1248
- "--file",
1249
- action="append",
1250
- default=[],
1251
- help="Map file to variable (name=path)",
1252
- metavar="NAME=PATH",
1253
- )
1254
- parser.add_argument(
1255
- "--files",
1256
- action="append",
1257
- default=[],
1258
- help="Map file pattern to variable (name=pattern)",
1259
- metavar="NAME=PATTERN",
1260
- )
1261
- parser.add_argument(
1262
- "--dir",
1263
- action="append",
1264
- default=[],
1265
- help="Map directory to variable (name=path)",
1266
- metavar="NAME=PATH",
1267
- )
1268
- parser.add_argument(
1269
- "--allowed-dir",
1270
- action="append",
1271
- default=[],
1272
- help="Additional allowed directory or @file",
1273
- metavar="PATH",
1274
- )
1275
- parser.add_argument(
1276
- "--base-dir",
1277
- help="Base directory for file access (defaults to current directory)",
1278
- default=os.getcwd(),
1279
- )
1280
- parser.add_argument(
1281
- "--allowed-dirs-file",
1282
- help="File containing list of allowed directories",
1283
- )
1284
- parser.add_argument(
1285
- "--dir-recursive",
1286
- action="store_true",
1287
- help="Process directories recursively",
1288
- )
1289
- parser.add_argument(
1290
- "--dir-ext",
1291
- help="Comma-separated list of file extensions to include in directory processing",
1292
- )
1293
-
1294
- # Variable arguments
1295
- parser.add_argument(
1296
- "--var",
1297
- action="append",
1298
- default=[],
1299
- help="Pass simple variables (name=value)",
1300
- metavar="NAME=VALUE",
1301
- )
1302
- parser.add_argument(
1303
- "--json-var",
1304
- action="append",
1305
- default=[],
1306
- help="Pass JSON variables (name=json)",
1307
- metavar="NAME=JSON",
1308
- )
1309
-
1310
- # System prompt options
1311
- parser.add_argument(
1312
- "--system-prompt",
1313
- help=(
1314
- "System prompt for the model (use @file to load from file, "
1315
- "can also be specified in task template YAML frontmatter)"
1316
- ),
1317
- default=DEFAULT_SYSTEM_PROMPT,
1318
- )
1319
- parser.add_argument(
1320
- "--ignore-task-sysprompt",
1321
- action="store_true",
1322
- help="Ignore system prompt from task template YAML frontmatter",
1323
- )
1324
-
1325
- # Schema validation
1326
- parser.add_argument(
1327
- "--schema",
1328
- dest="schema_file",
1329
- required=True,
1330
- help="JSON schema file for response validation",
1331
- )
1332
- parser.add_argument(
1333
- "--validate-schema",
1334
- action="store_true",
1335
- help="Validate schema and response",
1336
- )
1337
-
1338
- # Model configuration
1339
- parser.add_argument(
1340
- "--model",
1341
- default="gpt-4o-2024-08-06",
1342
- help="Model to use",
1343
- )
1344
- parser.add_argument(
1345
- "--temperature",
1346
- type=float,
1347
- default=0.0,
1348
- help="Temperature (0.0-2.0)",
1349
- )
1350
- parser.add_argument(
1351
- "--max-tokens",
1352
- type=int,
1353
- help="Maximum tokens to generate",
1354
- )
1355
- parser.add_argument(
1356
- "--top-p",
1357
- type=float,
1358
- default=1.0,
1359
- help="Top-p sampling (0.0-1.0)",
1360
- )
1361
- parser.add_argument(
1362
- "--frequency-penalty",
1363
- type=float,
1364
- default=0.0,
1365
- help="Frequency penalty (-2.0-2.0)",
1366
- )
1367
- parser.add_argument(
1368
- "--presence-penalty",
1369
- type=float,
1370
- default=0.0,
1371
- help="Presence penalty (-2.0-2.0)",
1372
- )
1373
- parser.add_argument(
1374
- "--timeout",
1375
- type=float,
1376
- default=60.0,
1377
- help="API timeout in seconds",
1378
- )
1379
-
1380
- # Output options
1381
- parser.add_argument(
1382
- "--output-file",
1383
- help="Write JSON output to file",
1384
- )
1385
- parser.add_argument(
1386
- "--dry-run",
1387
- action="store_true",
1388
- help="Simulate API call without making request",
1389
- )
1390
- parser.add_argument(
1391
- "--no-progress",
1392
- action="store_true",
1393
- help="Disable progress indicators",
1394
- )
1395
-
1396
- # Other options
1397
- parser.add_argument(
1398
- "--api-key",
1399
- help="OpenAI API key (overrides env var)",
1400
- )
1401
- parser.add_argument(
1402
- "--verbose",
1403
- action="store_true",
1404
- help="Enable verbose output",
1405
- )
1406
- parser.add_argument(
1407
- "--debug-openai-stream",
1408
- action="store_true",
1409
- help="Enable low-level debug output for OpenAI streaming (very verbose)",
1410
- )
1411
- parser.add_argument(
1412
- "--version",
1413
- action="version",
1414
- version=f"%(prog)s {__version__}",
1415
- )
1366
+ def handle_error(e: Exception) -> None:
1367
+ """Handle errors by printing appropriate message and exiting with status code."""
1368
+ if isinstance(e, click.UsageError):
1369
+ # For UsageError, preserve the original message format
1370
+ if hasattr(e, "param") and e.param:
1371
+ # Missing parameter error
1372
+ msg = f"Missing option '--{e.param.name}'"
1373
+ click.echo(msg, err=True)
1374
+ else:
1375
+ # Other usage errors (like conflicting options)
1376
+ click.echo(str(e), err=True)
1377
+ sys.exit(ExitCode.USAGE_ERROR)
1378
+ elif isinstance(e, InvalidJSONError):
1379
+ # Use the original error message if available
1380
+ msg = str(e) if str(e) != "None" else "Invalid JSON"
1381
+ click.secho(msg, fg="red", err=True)
1382
+ sys.exit(ExitCode.DATA_ERROR)
1383
+ elif isinstance(e, FileNotFoundError):
1384
+ # Use the original error message if available
1385
+ msg = str(e) if str(e) != "None" else "File not found"
1386
+ click.secho(msg, fg="red", err=True)
1387
+ sys.exit(ExitCode.SCHEMA_ERROR)
1388
+ elif isinstance(e, TaskTemplateSyntaxError):
1389
+ # Use the original error message if available
1390
+ msg = str(e) if str(e) != "None" else "Template syntax error"
1391
+ click.secho(msg, fg="red", err=True)
1392
+ sys.exit(ExitCode.INTERNAL_ERROR)
1393
+ elif isinstance(e, CLIError):
1394
+ # Use the show method for CLIError and its subclasses
1395
+ e.show()
1396
+ sys.exit(
1397
+ e.exit_code if hasattr(e, "exit_code") else ExitCode.INTERNAL_ERROR
1398
+ )
1399
+ else:
1400
+ click.secho(f"Unexpected error: {str(e)}", fg="red", err=True)
1401
+ sys.exit(ExitCode.INTERNAL_ERROR)
1402
+
1403
+
1404
+ async def stream_structured_output(
1405
+ client: AsyncOpenAI,
1406
+ model: str,
1407
+ system_prompt: str,
1408
+ user_prompt: str,
1409
+ output_schema: Type[BaseModel],
1410
+ output_file: Optional[str] = None,
1411
+ **kwargs: Any,
1412
+ ) -> None:
1413
+ """Stream structured output from OpenAI API.
1416
1414
 
1417
- return parser
1415
+ This function follows the guide's recommendation for a focused async streaming function.
1416
+ It handles the core streaming logic and resource cleanup.
1417
+ """
1418
+ try:
1419
+ async for chunk in async_openai_structured_stream(
1420
+ client=client,
1421
+ model=model,
1422
+ output_schema=output_schema,
1423
+ system_prompt=system_prompt,
1424
+ user_prompt=user_prompt,
1425
+ **kwargs,
1426
+ ):
1427
+ if not chunk:
1428
+ continue
1429
+
1430
+ # Process and output the chunk
1431
+ dumped = chunk.model_dump(mode="json")
1432
+ json_str = json.dumps(dumped, indent=2)
1433
+
1434
+ if output_file:
1435
+ with open(output_file, "a", encoding="utf-8") as f:
1436
+ f.write(json_str)
1437
+ f.write("\n")
1438
+ f.flush() # Ensure immediate flush to file
1439
+ else:
1440
+ # Print directly to stdout with immediate flush
1441
+ print(json_str, flush=True)
1442
+
1443
+ except (
1444
+ StreamInterruptedError,
1445
+ StreamBufferError,
1446
+ StreamParseError,
1447
+ APIResponseError,
1448
+ EmptyResponseError,
1449
+ InvalidResponseFormatError,
1450
+ ) as e:
1451
+ logger.error(f"Stream error: {e}")
1452
+ raise
1453
+ finally:
1454
+ # Always ensure client is properly closed
1455
+ await client.close()
1418
1456
 
1419
1457
 
1420
- async def _main() -> ExitCode:
1421
- """Main CLI function.
1458
+ async def run_cli_async(args: Namespace) -> ExitCode:
1459
+ """Async wrapper for CLI operations.
1422
1460
 
1423
- Returns:
1424
- ExitCode: Exit code indicating success or failure
1461
+ This function prepares everything needed for streaming and then calls
1462
+ the focused streaming function.
1425
1463
  """
1426
1464
  try:
1427
- parser = create_argument_parser()
1428
- args = parser.parse_args()
1429
-
1430
- # Configure logging
1431
- log_level = logging.DEBUG if args.verbose else logging.INFO
1432
- logger.setLevel(log_level)
1433
-
1434
- # Create security manager
1465
+ # Validate and prepare all inputs
1435
1466
  security_manager = validate_security_manager(
1436
1467
  base_dir=args.base_dir,
1437
1468
  allowed_dirs=args.allowed_dir,
1438
- allowed_dirs_file=args.allowed_dirs_file,
1469
+ allowed_dir_file=args.allowed_dir_file,
1439
1470
  )
1440
1471
 
1441
- # Validate task template
1442
- task_template = validate_task_template(args.task)
1443
-
1444
- # Validate schema file
1472
+ task_template = validate_task_template(args.task, args.task_file)
1473
+ logger.debug("Validating schema from %s", args.schema_file)
1445
1474
  schema = validate_schema_file(args.schema_file, args.verbose)
1446
-
1447
- # Create template context
1448
1475
  template_context = create_template_context_from_args(
1449
1476
  args, security_manager
1450
1477
  )
1451
-
1452
- # Create Jinja environment
1453
1478
  env = create_jinja_env()
1454
1479
 
1455
- # Process system prompt
1456
- args.system_prompt = process_system_prompt(
1480
+ # Process system prompt and render task
1481
+ system_prompt = process_system_prompt(
1457
1482
  task_template,
1458
1483
  args.system_prompt,
1484
+ args.system_prompt_file,
1459
1485
  template_context,
1460
1486
  env,
1461
1487
  args.ignore_task_sysprompt,
1462
1488
  )
1463
-
1464
- # Render task template
1465
1489
  rendered_task = render_template(task_template, template_context, env)
1466
- logger.info(rendered_task) # Log the rendered template
1490
+ logger.info("Rendered task template: %s", rendered_task)
1467
1491
 
1468
- # If dry run, exit here
1469
1492
  if args.dry_run:
1470
1493
  logger.info("DRY RUN MODE")
1471
1494
  return ExitCode.SUCCESS
1472
1495
 
1473
- # Load and validate schema
1496
+ # Create output model
1497
+ logger.debug("Creating output model")
1474
1498
  try:
1475
- logger.debug("[_main] Loading schema from %s", args.schema_file)
1476
- schema = validate_schema_file(
1477
- args.schema_file, verbose=args.verbose_schema
1478
- )
1479
- logger.debug("[_main] Creating output model")
1480
1499
  output_model = create_dynamic_model(
1481
1500
  schema,
1482
1501
  base_name="OutputModel",
1483
1502
  show_schema=args.show_model_schema,
1484
1503
  debug_validation=args.debug_validation,
1485
1504
  )
1486
- logger.debug("[_main] Successfully created output model")
1487
- except (SchemaFileError, InvalidJSONError, SchemaValidationError) as e:
1488
- logger.error(str(e))
1489
- return ExitCode.SCHEMA_ERROR
1490
- except ModelCreationError as e:
1491
- logger.error(f"Model creation error: {e}")
1492
- return ExitCode.SCHEMA_ERROR
1493
- except Exception as e:
1494
- logger.error(f"Unexpected error creating model: {e}")
1495
- return ExitCode.SCHEMA_ERROR
1496
-
1497
- # Validate model support
1505
+ logger.debug("Successfully created output model")
1506
+ except (
1507
+ SchemaFileError,
1508
+ InvalidJSONError,
1509
+ SchemaValidationError,
1510
+ ModelCreationError,
1511
+ ) as e:
1512
+ logger.error("Schema error: %s", str(e))
1513
+ raise # Let the error propagate with its context
1514
+
1515
+ # Validate model support and token usage
1498
1516
  try:
1499
1517
  supports_structured_output(args.model)
1500
- except ModelNotSupportedError as e:
1501
- logger.error(str(e))
1502
- return ExitCode.DATA_ERROR
1503
- except ModelVersionError as e:
1504
- logger.error(str(e))
1505
- return ExitCode.DATA_ERROR
1506
-
1507
- # Estimate token usage
1518
+ except (ModelNotSupportedError, ModelVersionError) as e:
1519
+ logger.error("Model validation error: %s", str(e))
1520
+ raise # Let the error propagate with its context
1521
+
1508
1522
  messages = [
1509
- {"role": "system", "content": args.system_prompt},
1523
+ {"role": "system", "content": system_prompt},
1510
1524
  {"role": "user", "content": rendered_task},
1511
1525
  ]
1512
1526
  total_tokens = estimate_tokens_for_chat(messages, args.model)
1513
1527
  context_limit = get_context_window_limit(args.model)
1514
-
1515
1528
  if total_tokens > context_limit:
1516
- logger.error(
1517
- f"Total tokens ({total_tokens}) exceeds model context limit ({context_limit})"
1529
+ msg = f"Total tokens ({total_tokens}) exceeds model context limit ({context_limit})"
1530
+ logger.error(msg)
1531
+ raise CLIError(
1532
+ msg,
1533
+ context={
1534
+ "total_tokens": total_tokens,
1535
+ "context_limit": context_limit,
1536
+ },
1518
1537
  )
1519
- return ExitCode.DATA_ERROR
1520
1538
 
1521
- # Get API key
1539
+ # Get API key and create client
1522
1540
  api_key = args.api_key or os.getenv("OPENAI_API_KEY")
1523
1541
  if not api_key:
1524
- logger.error(
1525
- "No OpenAI API key provided (--api-key or OPENAI_API_KEY env var)"
1526
- )
1527
- return ExitCode.USAGE_ERROR
1542
+ msg = "No OpenAI API key provided (--api-key or OPENAI_API_KEY env var)"
1543
+ logger.error(msg)
1544
+ raise CLIError(msg)
1528
1545
 
1529
- # Create OpenAI client
1530
1546
  client = AsyncOpenAI(api_key=api_key, timeout=args.timeout)
1531
1547
 
1532
- # Create log callback that matches expected signature
1548
+ # Create detailed log callback
1533
1549
  def log_callback(
1534
1550
  level: int, message: str, extra: dict[str, Any]
1535
1551
  ) -> None:
1536
- # Only log if debug_openai_stream is enabled
1537
1552
  if args.debug_openai_stream:
1538
- # Include extra dictionary in the message for both DEBUG and ERROR
1539
- if extra: # Only add if there's actually extra data
1553
+ if extra:
1540
1554
  extra_str = json.dumps(extra, indent=2)
1541
1555
  message = f"{message}\nDetails:\n{extra_str}"
1542
- openai_logger.log(level, message, extra=extra)
1556
+ logger.log(level, message, extra=extra)
1543
1557
 
1544
- # Make API request
1558
+ # Stream the output
1545
1559
  try:
1546
- logger.debug("Creating ProgressContext for API response handling")
1547
- with ProgressContext(
1548
- description="Processing API response",
1549
- level=args.progress_level,
1550
- ) as progress:
1551
- logger.debug("Starting API response stream processing")
1552
- logger.debug("Debug flag status: %s", args.debug_openai_stream)
1553
- logger.debug("OpenAI logger level: %s", openai_logger.level)
1554
- for handler in openai_logger.handlers:
1555
- logger.debug("Handler level: %s", handler.level)
1556
- async for chunk in async_openai_structured_stream(
1557
- client=client,
1558
- model=args.model,
1559
- temperature=args.temperature,
1560
- max_tokens=args.max_tokens,
1561
- top_p=args.top_p,
1562
- frequency_penalty=args.frequency_penalty,
1563
- presence_penalty=args.presence_penalty,
1564
- system_prompt=args.system_prompt,
1565
- user_prompt=rendered_task,
1566
- output_schema=output_model,
1567
- timeout=args.timeout,
1568
- on_log=log_callback,
1569
- ):
1570
- logger.debug("Received API response chunk")
1571
- if not chunk:
1572
- logger.debug("Empty chunk received, skipping")
1573
- continue
1574
-
1575
- # Write output
1576
- try:
1577
- logger.debug("Starting to process output chunk")
1578
- dumped = chunk.model_dump(mode="json")
1579
- logger.debug("Successfully dumped chunk to JSON")
1580
- logger.debug("Dumped chunk: %s", dumped)
1581
- logger.debug(
1582
- "Chunk type: %s, length: %d",
1583
- type(dumped),
1584
- len(json.dumps(dumped)),
1585
- )
1586
-
1587
- if args.output_file:
1588
- logger.debug(
1589
- "Writing to output file: %s", args.output_file
1590
- )
1591
- try:
1592
- with open(
1593
- args.output_file, "a", encoding="utf-8"
1594
- ) as f:
1595
- json_str = json.dumps(dumped, indent=2)
1596
- logger.debug(
1597
- "Writing JSON string of length %d",
1598
- len(json_str),
1599
- )
1600
- f.write(json_str)
1601
- f.write("\n")
1602
- logger.debug("Successfully wrote to file")
1603
- except Exception as e:
1604
- logger.error(
1605
- "Failed to write to output file: %s", e
1606
- )
1607
- else:
1608
- logger.debug(
1609
- "About to call progress.print_output with JSON string"
1610
- )
1611
- json_str = json.dumps(dumped, indent=2)
1612
- logger.debug(
1613
- "JSON string length before print_output: %d",
1614
- len(json_str),
1615
- )
1616
- logger.debug(
1617
- "First 100 chars of JSON string: %s",
1618
- json_str[:100] if json_str else "",
1619
- )
1620
- progress.print_output(json_str)
1621
- logger.debug(
1622
- "Completed print_output call for JSON string"
1623
- )
1624
-
1625
- logger.debug("Starting progress update")
1626
- progress.update()
1627
- logger.debug("Completed progress update")
1628
- except Exception as e:
1629
- logger.error("Failed to process chunk: %s", e)
1630
- logger.error("Chunk: %s", chunk)
1631
- continue
1632
-
1633
- logger.debug("Finished processing API response stream")
1634
-
1635
- except StreamInterruptedError as e:
1636
- logger.error(f"Stream interrupted: {e}")
1637
- return ExitCode.API_ERROR
1638
- except StreamBufferError as e:
1639
- logger.error(f"Stream buffer error: {e}")
1640
- return ExitCode.API_ERROR
1641
- except StreamParseError as e:
1642
- logger.error(f"Stream parse error: {e}")
1643
- return ExitCode.API_ERROR
1644
- except APIResponseError as e:
1645
- logger.error(f"API response error: {e}")
1646
- return ExitCode.API_ERROR
1647
- except EmptyResponseError as e:
1648
- logger.error(f"Empty response error: {e}")
1649
- return ExitCode.API_ERROR
1650
- except InvalidResponseFormatError as e:
1651
- logger.error(f"Invalid response format: {e}")
1652
- return ExitCode.API_ERROR
1560
+ await stream_structured_output(
1561
+ client=client,
1562
+ model=args.model,
1563
+ system_prompt=system_prompt,
1564
+ user_prompt=rendered_task,
1565
+ output_schema=output_model,
1566
+ output_file=args.output_file,
1567
+ temperature=args.temperature,
1568
+ max_tokens=args.max_tokens,
1569
+ top_p=args.top_p,
1570
+ frequency_penalty=args.frequency_penalty,
1571
+ presence_penalty=args.presence_penalty,
1572
+ timeout=args.timeout,
1573
+ on_log=log_callback,
1574
+ )
1575
+ return ExitCode.SUCCESS
1576
+ except (
1577
+ StreamInterruptedError,
1578
+ StreamBufferError,
1579
+ StreamParseError,
1580
+ APIResponseError,
1581
+ EmptyResponseError,
1582
+ InvalidResponseFormatError,
1583
+ ) as e:
1584
+ logger.error("Stream error: %s", str(e))
1585
+ raise # Let stream errors propagate
1653
1586
  except (APIConnectionError, InternalServerError) as e:
1654
- logger.error(f"API connection error: {e}")
1655
- return ExitCode.API_ERROR
1587
+ logger.error("API connection error: %s", str(e))
1588
+ raise APIResponseError(str(e)) # Convert to our error type
1656
1589
  except RateLimitError as e:
1657
- logger.error(f"Rate limit exceeded: {e}")
1658
- return ExitCode.API_ERROR
1659
- except BadRequestError as e:
1660
- logger.error(f"Bad request: {e}")
1661
- return ExitCode.API_ERROR
1662
- except AuthenticationError as e:
1663
- logger.error(f"Authentication failed: {e}")
1664
- return ExitCode.API_ERROR
1665
- except OpenAIClientError as e:
1666
- logger.error(f"OpenAI client error: {e}")
1667
- return ExitCode.API_ERROR
1668
- except Exception as e:
1669
- logger.error(f"Unexpected error: {e}")
1670
- return ExitCode.INTERNAL_ERROR
1671
-
1672
- return ExitCode.SUCCESS
1590
+ logger.error("Rate limit exceeded: %s", str(e))
1591
+ raise APIResponseError(str(e)) # Convert to our error type
1592
+ except (BadRequestError, AuthenticationError, OpenAIClientError) as e:
1593
+ logger.error("API client error: %s", str(e))
1594
+ raise APIResponseError(str(e)) # Convert to our error type
1595
+ finally:
1596
+ await client.close()
1673
1597
 
1674
1598
  except KeyboardInterrupt:
1675
- logger.error("Operation cancelled by user")
1599
+ logger.info("Operation cancelled by user")
1676
1600
  return ExitCode.INTERRUPTED
1677
- except PathSecurityError as e:
1678
- # Only log security errors if they haven't been logged already
1679
- logger.debug(
1680
- "[_main] Caught PathSecurityError: %s (logged=%s)",
1681
- str(e),
1682
- getattr(e, "has_been_logged", False),
1683
- )
1684
- if not getattr(e, "has_been_logged", False):
1685
- logger.error(str(e))
1686
- return ExitCode.SECURITY_ERROR
1687
- except ValueError as e:
1688
- # Get the original cause of the error
1689
- cause = e.__cause__ or e.__context__
1690
- if isinstance(cause, PathSecurityError):
1691
- logger.debug(
1692
- "[_main] Caught wrapped PathSecurityError in ValueError: %s (logged=%s)",
1693
- str(cause),
1694
- getattr(cause, "has_been_logged", False),
1695
- )
1696
- # Only log security errors if they haven't been logged already
1697
- if not getattr(cause, "has_been_logged", False):
1698
- logger.error(str(cause))
1699
- return ExitCode.SECURITY_ERROR
1700
- else:
1701
- logger.debug("[_main] Caught ValueError: %s", str(e))
1702
- logger.error(f"Invalid input: {e}")
1703
- return ExitCode.DATA_ERROR
1704
1601
  except Exception as e:
1705
- # Check if this is a wrapped security error
1706
- if isinstance(e.__cause__, PathSecurityError):
1707
- logger.debug(
1708
- "[_main] Caught wrapped PathSecurityError in Exception: %s (logged=%s)",
1709
- str(e.__cause__),
1710
- getattr(e.__cause__, "has_been_logged", False),
1711
- )
1712
- # Only log security errors if they haven't been logged already
1713
- if not getattr(e.__cause__, "has_been_logged", False):
1714
- logger.error(str(e.__cause__))
1715
- return ExitCode.SECURITY_ERROR
1716
- logger.debug("[_main] Caught unexpected error: %s", str(e))
1717
- logger.error(f"Unexpected error: {e}")
1718
- return ExitCode.INTERNAL_ERROR
1602
+ if isinstance(e, CLIError):
1603
+ raise # Let our custom errors propagate
1604
+ logger.exception("Unexpected error")
1605
+ raise CLIError(str(e), context={"error_type": type(e).__name__})
1606
+
1607
+
1608
+ def create_cli() -> click.Command:
1609
+ """Create the CLI command.
1610
+
1611
+ Returns:
1612
+ click.Command: The CLI command object
1613
+ """
1614
+
1615
+ @create_click_command()
1616
+ def cli(**kwargs: Any) -> None:
1617
+ """CLI entry point for structured OpenAI API calls."""
1618
+ try:
1619
+ args = Namespace(**kwargs)
1620
+
1621
+ # Validate required arguments first
1622
+ if not args.task and not args.task_file:
1623
+ raise click.UsageError(
1624
+ "Must specify either --task or --task-file"
1625
+ )
1626
+ if not args.schema_file:
1627
+ raise click.UsageError("Missing option '--schema-file'")
1628
+ if args.task and args.task_file:
1629
+ raise click.UsageError(
1630
+ "Cannot specify both --task and --task-file"
1631
+ )
1632
+ if args.system_prompt and args.system_prompt_file:
1633
+ raise click.UsageError(
1634
+ "Cannot specify both --system-prompt and --system-prompt-file"
1635
+ )
1636
+
1637
+ # Run the async function synchronously
1638
+ exit_code = asyncio.run(run_cli_async(args))
1639
+
1640
+ if exit_code != ExitCode.SUCCESS:
1641
+ error_msg = f"Command failed with exit code {exit_code}"
1642
+ if hasattr(ExitCode, exit_code.name):
1643
+ error_msg = f"{error_msg} ({exit_code.name})"
1644
+ raise CLIError(error_msg, context={"exit_code": exit_code})
1645
+
1646
+ except click.UsageError:
1647
+ # Let Click handle usage errors directly
1648
+ raise
1649
+ except InvalidJSONError:
1650
+ # Let InvalidJSONError propagate directly
1651
+ raise
1652
+ except CLIError:
1653
+ # Let our custom errors propagate with their context
1654
+ raise
1655
+ except Exception as e:
1656
+ # Convert other exceptions to CLIError
1657
+ logger.exception("Unexpected error")
1658
+ raise CLIError(str(e), context={"error_type": type(e).__name__})
1659
+
1660
+ # The decorated function is a Command, but mypy can't detect this
1661
+ return cast(click.Command, cli)
1719
1662
 
1720
1663
 
1721
1664
  def main() -> None:
1722
- """CLI entry point that handles all errors."""
1723
- try:
1724
- logger.debug("[main] Starting main execution")
1725
- exit_code = asyncio.run(_main())
1726
- sys.exit(exit_code.value)
1727
- except KeyboardInterrupt:
1728
- logger.error("Operation cancelled by user")
1729
- sys.exit(ExitCode.INTERRUPTED.value)
1730
- except PathSecurityError as e:
1731
- # Only log security errors if they haven't been logged already
1732
- logger.debug(
1733
- "[main] Caught PathSecurityError: %s (logged=%s)",
1734
- str(e),
1735
- getattr(e, "has_been_logged", False),
1736
- )
1737
- if not getattr(e, "has_been_logged", False):
1738
- logger.error(str(e))
1739
- sys.exit(ExitCode.SECURITY_ERROR.value)
1740
- except ValueError as e:
1741
- # Get the original cause of the error
1742
- cause = e.__cause__ or e.__context__
1743
- if isinstance(cause, PathSecurityError):
1744
- logger.debug(
1745
- "[main] Caught wrapped PathSecurityError in ValueError: %s (logged=%s)",
1746
- str(cause),
1747
- getattr(cause, "has_been_logged", False),
1748
- )
1749
- # Only log security errors if they haven't been logged already
1750
- if not getattr(cause, "has_been_logged", False):
1751
- logger.error(str(cause))
1752
- sys.exit(ExitCode.SECURITY_ERROR.value)
1753
- else:
1754
- logger.debug("[main] Caught ValueError: %s", str(e))
1755
- logger.error(f"Invalid input: {e}")
1756
- sys.exit(ExitCode.DATA_ERROR.value)
1757
- except Exception as e:
1758
- # Check if this is a wrapped security error
1759
- if isinstance(e.__cause__, PathSecurityError):
1760
- logger.debug(
1761
- "[main] Caught wrapped PathSecurityError in Exception: %s (logged=%s)",
1762
- str(e.__cause__),
1763
- getattr(e.__cause__, "has_been_logged", False),
1764
- )
1765
- # Only log security errors if they haven't been logged already
1766
- if not getattr(e.__cause__, "has_been_logged", False):
1767
- logger.error(str(e.__cause__))
1768
- sys.exit(ExitCode.SECURITY_ERROR.value)
1769
- logger.debug("[main] Caught unexpected error: %s", str(e))
1770
- logger.error(f"Unexpected error: {e}")
1771
- sys.exit(ExitCode.INTERNAL_ERROR.value)
1665
+ """Main entry point for the CLI."""
1666
+ cli = create_cli()
1667
+ cli(standalone_mode=False)
1772
1668
 
1773
1669
 
1774
1670
  # Export public API
@@ -1780,7 +1676,7 @@ __all__ = [
1780
1676
  "parse_json_var",
1781
1677
  "create_dynamic_model",
1782
1678
  "validate_path_mapping",
1783
- "create_argument_parser",
1679
+ "create_cli",
1784
1680
  "main",
1785
1681
  ]
1786
1682