ostruct-cli 0.4.0__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ostruct/cli/cli.py CHANGED
@@ -5,10 +5,10 @@ import json
5
5
  import logging
6
6
  import os
7
7
  import sys
8
- from dataclasses import dataclass
9
8
  from enum import Enum, IntEnum
10
9
  from typing import (
11
10
  Any,
11
+ AsyncGenerator,
12
12
  Dict,
13
13
  List,
14
14
  Literal,
@@ -16,6 +16,7 @@ from typing import (
16
16
  Set,
17
17
  Tuple,
18
18
  Type,
19
+ TypedDict,
19
20
  TypeVar,
20
21
  Union,
21
22
  cast,
@@ -31,20 +32,10 @@ from pathlib import Path
31
32
 
32
33
  import click
33
34
  import jinja2
34
- import tiktoken
35
35
  import yaml
36
- from openai import (
37
- APIConnectionError,
38
- AsyncOpenAI,
39
- AuthenticationError,
40
- BadRequestError,
41
- InternalServerError,
42
- RateLimitError,
43
- )
36
+ from openai import AsyncOpenAI
44
37
  from openai_structured.client import (
45
38
  async_openai_structured_stream,
46
- get_context_window_limit,
47
- get_default_token_limit,
48
39
  supports_structured_output,
49
40
  )
50
41
  from openai_structured.errors import (
@@ -54,12 +45,9 @@ from openai_structured.errors import (
54
45
  ModelNotSupportedError,
55
46
  ModelVersionError,
56
47
  OpenAIClientError,
57
- SchemaFileError,
58
- SchemaValidationError,
59
48
  StreamBufferError,
60
- StreamInterruptedError,
61
- StreamParseError,
62
49
  )
50
+ from openai_structured.model_registry import ModelRegistry
63
51
  from pydantic import (
64
52
  AnyUrl,
65
53
  BaseModel,
@@ -74,61 +62,60 @@ from pydantic.functional_validators import BeforeValidator
74
62
  from pydantic.types import constr
75
63
  from typing_extensions import TypeAlias
76
64
 
77
- from ostruct.cli.click_options import create_click_command
65
+ from ostruct.cli.click_options import all_options
66
+ from ostruct.cli.exit_codes import ExitCode
78
67
 
79
68
  from .. import __version__ # noqa: F401 - Used in package metadata
80
69
  from .errors import (
81
70
  CLIError,
82
71
  DirectoryNotFoundError,
83
72
  FieldDefinitionError,
84
- FileNotFoundError,
85
73
  InvalidJSONError,
86
74
  ModelCreationError,
87
75
  ModelValidationError,
88
76
  NestedModelError,
77
+ OstructFileNotFoundError,
89
78
  PathSecurityError,
79
+ SchemaFileError,
80
+ SchemaValidationError,
81
+ StreamInterruptedError,
82
+ StreamParseError,
90
83
  TaskTemplateSyntaxError,
91
84
  TaskTemplateVariableError,
92
- VariableError,
93
85
  VariableNameError,
94
86
  VariableValueError,
95
87
  )
96
- from .file_utils import FileInfoList, TemplateValue, collect_files
88
+ from .file_utils import FileInfoList, collect_files
97
89
  from .path_utils import validate_path_mapping
98
90
  from .security import SecurityManager
91
+ from .serialization import LogSerializer
99
92
  from .template_env import create_jinja_env
100
93
  from .template_utils import SystemPromptError, render_template
94
+ from .token_utils import estimate_tokens_with_encoding
101
95
 
102
96
  # Constants
103
97
  DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant."
104
98
 
105
99
 
106
- @dataclass
107
- class Namespace:
108
- """Compatibility class to mimic argparse.Namespace for existing code."""
100
+ class CLIParams(TypedDict, total=False):
101
+ """Type-safe CLI parameters."""
109
102
 
110
- task: Optional[str]
111
- task_file: Optional[str]
112
- file: List[str]
113
- files: List[str]
114
- dir: List[str]
115
- allowed_dir: List[str]
103
+ files: List[
104
+ Tuple[str, str]
105
+ ] # List of (name, path) tuples from Click's nargs=2
106
+ dir: List[
107
+ Tuple[str, str]
108
+ ] # List of (name, dir) tuples from Click's nargs=2
109
+ allowed_dirs: List[str]
116
110
  base_dir: str
117
111
  allowed_dir_file: Optional[str]
118
- dir_recursive: bool
119
- dir_ext: Optional[str]
112
+ recursive: bool
120
113
  var: List[str]
121
114
  json_var: List[str]
122
115
  system_prompt: Optional[str]
123
116
  system_prompt_file: Optional[str]
124
117
  ignore_task_sysprompt: bool
125
- schema_file: str
126
118
  model: str
127
- temperature: float
128
- max_tokens: Optional[int]
129
- top_p: float
130
- frequency_penalty: float
131
- presence_penalty: float
132
119
  timeout: float
133
120
  output_file: Optional[str]
134
121
  dry_run: bool
@@ -138,7 +125,16 @@ class Namespace:
138
125
  debug_openai_stream: bool
139
126
  show_model_schema: bool
140
127
  debug_validation: bool
141
- progress_level: str = "basic" # Default to 'basic' if not specified
128
+ temperature: Optional[float]
129
+ max_output_tokens: Optional[int]
130
+ top_p: Optional[float]
131
+ frequency_penalty: Optional[float]
132
+ presence_penalty: Optional[float]
133
+ reasoning_effort: Optional[str]
134
+ progress_level: str
135
+ task_file: Optional[str]
136
+ task: Optional[str]
137
+ schema_file: str
142
138
 
143
139
 
144
140
  # Set up logging
@@ -176,45 +172,6 @@ ostruct_file_handler.setFormatter(
176
172
  logger.addHandler(ostruct_file_handler)
177
173
 
178
174
 
179
- class ExitCode(IntEnum):
180
- """Exit codes for the CLI following standard Unix conventions.
181
-
182
- Categories:
183
- - Success (0-1)
184
- - User Interruption (2-3)
185
- - Input/Validation (64-69)
186
- - I/O and File Access (70-79)
187
- - API and External Services (80-89)
188
- - Internal Errors (90-99)
189
- """
190
-
191
- # Success codes
192
- SUCCESS = 0
193
-
194
- # User interruption
195
- INTERRUPTED = 2
196
-
197
- # Input/Validation errors (64-69)
198
- USAGE_ERROR = 64
199
- DATA_ERROR = 65
200
- SCHEMA_ERROR = 66
201
- VALIDATION_ERROR = 67
202
-
203
- # I/O and File Access errors (70-79)
204
- IO_ERROR = 70
205
- FILE_NOT_FOUND = 71
206
- PERMISSION_ERROR = 72
207
- SECURITY_ERROR = 73
208
-
209
- # API and External Service errors (80-89)
210
- API_ERROR = 80
211
- API_TIMEOUT = 81
212
-
213
- # Internal errors (90-99)
214
- INTERNAL_ERROR = 90
215
- UNKNOWN_ERROR = 91
216
-
217
-
218
175
  # Type aliases
219
176
  FieldType = (
220
177
  Any # Changed from Type[Any] to allow both concrete types and generics
@@ -281,7 +238,7 @@ def _get_type_with_constraints(
281
238
  show_schema=False,
282
239
  debug_validation=False,
283
240
  )
284
- array_type: Type[List[Any]] = List[array_item_model] # type: ignore[valid-type]
241
+ array_type: Type[List[Any]] = List[array_item_model] # type: ignore
285
242
  return (array_type, Field(**field_kwargs))
286
243
 
287
244
  # For non-object items, use the type directly
@@ -403,64 +360,17 @@ K = TypeVar("K")
403
360
  V = TypeVar("V")
404
361
 
405
362
 
406
- def estimate_tokens_for_chat(
407
- messages: List[Dict[str, str]],
408
- model: str,
409
- encoder: Any = None,
410
- ) -> int:
411
- """Estimate the number of tokens in a chat completion.
412
-
413
- Args:
414
- messages: List of chat messages
415
- model: Model name
416
- encoder: Optional tiktoken encoder for testing. If provided, only uses encoder.encode() results.
417
- """
418
- if encoder is None:
419
- try:
420
- # Try to get the encoding for the specific model
421
- encoder = tiktoken.get_encoding("o200k_base")
422
- except KeyError:
423
- # Fall back to cl100k_base for unknown models
424
- encoder = tiktoken.get_encoding("cl100k_base")
425
-
426
- # Use standard token counting logic for real tiktoken encoders
427
- num_tokens = 0
428
- for message in messages:
429
- # Add message overhead
430
- num_tokens += 4 # every message follows <im_start>{role/name}\n{content}<im_end>\n
431
- for key, value in message.items():
432
- num_tokens += len(encoder.encode(str(value)))
433
- if key == "name": # if there's a name, the role is omitted
434
- num_tokens -= 1 # role is omitted
435
- num_tokens += 2 # every reply is primed with <im_start>assistant
436
- return num_tokens
437
- else:
438
- # For mock encoders in tests, just return the length of encoded content
439
- num_tokens = 0
440
- for message in messages:
441
- for value in message.values():
442
- num_tokens += len(encoder.encode(str(value)))
443
- return num_tokens
444
-
445
-
446
363
  def validate_token_limits(
447
364
  model: str, total_tokens: int, max_token_limit: Optional[int] = None
448
365
  ) -> None:
449
- """Validate token counts against model limits.
450
-
451
- Args:
452
- model: The model name
453
- total_tokens: Total number of tokens in the prompt
454
- max_token_limit: Optional user-specified token limit
455
-
456
- Raises:
457
- ValueError: If token limits are exceeded
458
- """
459
- context_limit = get_context_window_limit(model)
366
+ """Validate token counts against model limits."""
367
+ registry = ModelRegistry()
368
+ capabilities = registry.get_capabilities(model)
369
+ context_limit = capabilities.context_window
460
370
  output_limit = (
461
371
  max_token_limit
462
372
  if max_token_limit is not None
463
- else get_default_token_limit(model)
373
+ else capabilities.max_output_tokens
464
374
  )
465
375
 
466
376
  # Check if total tokens exceed context window
@@ -522,8 +432,12 @@ def process_system_prompt(
522
432
  )
523
433
  with open(path, "r", encoding="utf-8") as f:
524
434
  system_prompt = f.read().strip()
525
- except (FileNotFoundError, PathSecurityError) as e:
526
- raise SystemPromptError(f"Invalid system prompt file: {e}")
435
+ except OstructFileNotFoundError as e:
436
+ raise SystemPromptError(
437
+ f"Failed to load system prompt file: {e}"
438
+ ) from e
439
+ except PathSecurityError as e:
440
+ raise SystemPromptError(f"Invalid system prompt file: {e}") from e
527
441
 
528
442
  if system_prompt is not None:
529
443
  # Render system prompt with template context
@@ -591,7 +505,8 @@ def validate_variable_mapping(
591
505
  value = json.loads(value)
592
506
  except json.JSONDecodeError as e:
593
507
  raise InvalidJSONError(
594
- f"Invalid JSON value for variable {name!r}: {value!r}"
508
+ f"Invalid JSON value for variable {name!r}: {value!r}",
509
+ context={"variable_name": name},
595
510
  ) from e
596
511
 
597
512
  return name, value
@@ -787,11 +702,20 @@ def validate_task_template(
787
702
  template_content: str
788
703
  if task_file is not None:
789
704
  try:
790
- name, path = validate_path_mapping(f"task={task_file}")
791
- with open(path, "r", encoding="utf-8") as f:
705
+ with open(task_file, "r", encoding="utf-8") as f:
792
706
  template_content = f.read()
793
- except (FileNotFoundError, PathSecurityError) as e:
794
- raise TaskTemplateVariableError(str(e))
707
+ except FileNotFoundError:
708
+ raise TaskTemplateVariableError(
709
+ f"Task template file not found: {task_file}"
710
+ )
711
+ except PermissionError:
712
+ raise TaskTemplateVariableError(
713
+ f"Permission denied reading task template file: {task_file}"
714
+ )
715
+ except Exception as e:
716
+ raise TaskTemplateVariableError(
717
+ f"Error reading task template file: {e}"
718
+ )
795
719
  else:
796
720
  template_content = task # type: ignore # We know task is str here due to the checks above
797
721
 
@@ -809,10 +733,10 @@ def validate_schema_file(
809
733
  path: str,
810
734
  verbose: bool = False,
811
735
  ) -> Dict[str, Any]:
812
- """Validate a JSON schema file.
736
+ """Validate and load a JSON schema file.
813
737
 
814
738
  Args:
815
- path: Path to the schema file
739
+ path: Path to schema file
816
740
  verbose: Whether to enable verbose logging
817
741
 
818
742
  Returns:
@@ -827,14 +751,42 @@ def validate_schema_file(
827
751
  logger.info("Validating schema file: %s", path)
828
752
 
829
753
  try:
830
- with open(path) as f:
831
- schema = json.load(f)
754
+ logger.debug("Opening schema file: %s", path)
755
+ with open(path, "r", encoding="utf-8") as f:
756
+ logger.debug("Loading JSON from schema file")
757
+ try:
758
+ schema = json.load(f)
759
+ logger.debug(
760
+ "Successfully loaded JSON: %s",
761
+ json.dumps(schema, indent=2),
762
+ )
763
+ except json.JSONDecodeError as e:
764
+ logger.error("JSON decode error in %s: %s", path, str(e))
765
+ logger.debug(
766
+ "Error details - line: %d, col: %d, msg: %s",
767
+ e.lineno,
768
+ e.colno,
769
+ e.msg,
770
+ )
771
+ raise InvalidJSONError(
772
+ f"Invalid JSON in schema file {path}: {e}",
773
+ context={"schema_path": path},
774
+ ) from e
832
775
  except FileNotFoundError:
833
- raise SchemaFileError(f"Schema file not found: {path}")
834
- except json.JSONDecodeError as e:
835
- raise InvalidJSONError(f"Invalid JSON in schema file: {e}")
776
+ msg = f"Schema file not found: {path}"
777
+ logger.error(msg)
778
+ raise SchemaFileError(msg, schema_path=path)
779
+ except PermissionError:
780
+ msg = f"Permission denied reading schema file: {path}"
781
+ logger.error(msg)
782
+ raise SchemaFileError(msg, schema_path=path)
836
783
  except Exception as e:
837
- raise SchemaFileError(f"Failed to read schema file: {e}")
784
+ if isinstance(e, InvalidJSONError):
785
+ raise
786
+ msg = f"Failed to read schema file {path}: {e}"
787
+ logger.error(msg)
788
+ logger.debug("Unexpected error details: %s", str(e))
789
+ raise SchemaFileError(msg, schema_path=path) from e
838
790
 
839
791
  # Pre-validation structure checks
840
792
  if verbose:
@@ -842,11 +794,9 @@ def validate_schema_file(
842
794
  logger.debug("Loaded schema: %s", json.dumps(schema, indent=2))
843
795
 
844
796
  if not isinstance(schema, dict):
845
- if verbose:
846
- logger.error(
847
- "Schema is not a dictionary: %s", type(schema).__name__
848
- )
849
- raise SchemaValidationError("Schema must be a JSON object")
797
+ msg = f"Schema in {path} must be a JSON object"
798
+ logger.error(msg)
799
+ raise SchemaValidationError(msg, schema_path=path)
850
800
 
851
801
  # Validate schema structure
852
802
  if "schema" in schema:
@@ -854,30 +804,37 @@ def validate_schema_file(
854
804
  logger.debug("Found schema wrapper, validating inner schema")
855
805
  inner_schema = schema["schema"]
856
806
  if not isinstance(inner_schema, dict):
857
- if verbose:
858
- logger.error(
859
- "Inner schema is not a dictionary: %s",
860
- type(inner_schema).__name__,
861
- )
862
- raise SchemaValidationError("Inner schema must be a JSON object")
807
+ msg = f"Inner schema in {path} must be a JSON object"
808
+ logger.error(msg)
809
+ raise SchemaValidationError(msg, schema_path=path)
863
810
  if verbose:
864
811
  logger.debug("Inner schema validated successfully")
812
+ logger.debug(
813
+ "Inner schema: %s", json.dumps(inner_schema, indent=2)
814
+ )
865
815
  else:
866
816
  if verbose:
867
817
  logger.debug("No schema wrapper found, using schema as-is")
818
+ logger.debug("Schema: %s", json.dumps(schema, indent=2))
819
+
820
+ # Additional schema validation
821
+ if "type" not in schema.get("schema", schema):
822
+ msg = f"Schema in {path} must specify a type"
823
+ logger.error(msg)
824
+ raise SchemaValidationError(msg, schema_path=path)
868
825
 
869
826
  # Return the full schema including wrapper
870
827
  return schema
871
828
 
872
829
 
873
830
  def collect_template_files(
874
- args: Namespace,
831
+ args: CLIParams,
875
832
  security_manager: SecurityManager,
876
- ) -> Dict[str, TemplateValue]:
833
+ ) -> Dict[str, Union[FileInfoList, str, List[str], Dict[str, str]]]:
877
834
  """Collect files from command line arguments.
878
835
 
879
836
  Args:
880
- args: Parsed command line arguments
837
+ args: Command line arguments
881
838
  security_manager: Security manager for path validation
882
839
 
883
840
  Returns:
@@ -888,15 +845,29 @@ def collect_template_files(
888
845
  ValueError: If file mappings are invalid or files cannot be accessed
889
846
  """
890
847
  try:
891
- result = collect_files(
892
- file_mappings=args.file,
893
- pattern_mappings=args.files,
894
- dir_mappings=args.dir,
895
- dir_recursive=args.dir_recursive,
896
- dir_extensions=args.dir_ext.split(",") if args.dir_ext else None,
848
+ # Get files and directories from args - they are already tuples from Click's nargs=2
849
+ files = list(
850
+ args.get("files", [])
851
+ ) # List of (name, path) tuples from Click
852
+ dirs = args.get("dir", []) # List of (name, dir) tuples from Click
853
+
854
+ # Collect files from directories
855
+ dir_files = collect_files(
856
+ file_mappings=cast(
857
+ List[Tuple[str, Union[str, Path]]], files
858
+ ), # Cast to correct type
859
+ dir_mappings=cast(
860
+ List[Tuple[str, Union[str, Path]]], dirs
861
+ ), # Cast to correct type
862
+ dir_recursive=args.get("recursive", False),
897
863
  security_manager=security_manager,
898
864
  )
899
- return cast(Dict[str, TemplateValue], result)
865
+
866
+ # Combine results
867
+ return cast(
868
+ Dict[str, Union[FileInfoList, str, List[str], Dict[str, str]]],
869
+ dir_files,
870
+ )
900
871
  except PathSecurityError:
901
872
  # Let PathSecurityError propagate without wrapping
902
873
  raise
@@ -914,11 +885,11 @@ def collect_template_files(
914
885
  raise ValueError(f"Error collecting files: {e}")
915
886
 
916
887
 
917
- def collect_simple_variables(args: Namespace) -> Dict[str, str]:
888
+ def collect_simple_variables(args: CLIParams) -> Dict[str, str]:
918
889
  """Collect simple string variables from --var arguments.
919
890
 
920
891
  Args:
921
- args: Parsed command line arguments
892
+ args: Command line arguments
922
893
 
923
894
  Returns:
924
895
  Dictionary mapping variable names to string values
@@ -929,10 +900,15 @@ def collect_simple_variables(args: Namespace) -> Dict[str, str]:
929
900
  variables: Dict[str, str] = {}
930
901
  all_names: Set[str] = set()
931
902
 
932
- if args.var:
933
- for mapping in args.var:
903
+ if args.get("var"):
904
+ for mapping in args["var"]:
934
905
  try:
935
- name, value = mapping.split("=", 1)
906
+ # Handle both tuple format and string format
907
+ if isinstance(mapping, tuple):
908
+ name, value = mapping
909
+ else:
910
+ name, value = mapping.split("=", 1)
911
+
936
912
  if not name.isidentifier():
937
913
  raise VariableNameError(f"Invalid variable name: {name}")
938
914
  if name in all_names:
@@ -947,11 +923,11 @@ def collect_simple_variables(args: Namespace) -> Dict[str, str]:
947
923
  return variables
948
924
 
949
925
 
950
- def collect_json_variables(args: Namespace) -> Dict[str, Any]:
926
+ def collect_json_variables(args: CLIParams) -> Dict[str, Any]:
951
927
  """Collect JSON variables from --json-var arguments.
952
928
 
953
929
  Args:
954
- args: Parsed command line arguments
930
+ args: Command line arguments
955
931
 
956
932
  Returns:
957
933
  Dictionary mapping variable names to parsed JSON values
@@ -963,32 +939,46 @@ def collect_json_variables(args: Namespace) -> Dict[str, Any]:
963
939
  variables: Dict[str, Any] = {}
964
940
  all_names: Set[str] = set()
965
941
 
966
- if args.json_var:
967
- for mapping in args.json_var:
942
+ if args.get("json_var"):
943
+ for mapping in args["json_var"]:
968
944
  try:
969
- name, json_str = mapping.split("=", 1)
945
+ # Handle both tuple format and string format
946
+ if isinstance(mapping, tuple):
947
+ name, value = (
948
+ mapping # Value is already parsed by Click validator
949
+ )
950
+ else:
951
+ try:
952
+ name, json_str = mapping.split("=", 1)
953
+ except ValueError:
954
+ raise VariableNameError(
955
+ f"Invalid JSON variable mapping format: {mapping}. Expected name=json"
956
+ )
957
+ try:
958
+ value = json.loads(json_str)
959
+ except json.JSONDecodeError as e:
960
+ raise InvalidJSONError(
961
+ f"Invalid JSON value for variable '{name}': {json_str}",
962
+ context={"variable_name": name},
963
+ ) from e
964
+
970
965
  if not name.isidentifier():
971
966
  raise VariableNameError(f"Invalid variable name: {name}")
972
967
  if name in all_names:
973
968
  raise VariableNameError(f"Duplicate variable name: {name}")
974
- try:
975
- value = json.loads(json_str)
976
- variables[name] = value
977
- all_names.add(name)
978
- except json.JSONDecodeError as e:
979
- raise InvalidJSONError(
980
- f"Error parsing JSON for variable '{name}': {str(e)}. Input was: {json_str}"
981
- )
982
- except ValueError:
983
- raise VariableNameError(
984
- f"Invalid JSON variable mapping format: {mapping}. Expected name=json"
985
- )
969
+
970
+ variables[name] = value
971
+ all_names.add(name)
972
+ except (VariableNameError, InvalidJSONError):
973
+ raise
986
974
 
987
975
  return variables
988
976
 
989
977
 
990
978
  def create_template_context(
991
- files: Optional[Dict[str, FileInfoList]] = None,
979
+ files: Optional[
980
+ Dict[str, Union[FileInfoList, str, List[str], Dict[str, str]]]
981
+ ] = None,
992
982
  variables: Optional[Dict[str, str]] = None,
993
983
  json_variables: Optional[Dict[str, Any]] = None,
994
984
  security_manager: Optional[SecurityManager] = None,
@@ -1032,14 +1022,14 @@ def create_template_context(
1032
1022
  return context
1033
1023
 
1034
1024
 
1035
- def create_template_context_from_args(
1036
- args: "Namespace",
1025
+ async def create_template_context_from_args(
1026
+ args: CLIParams,
1037
1027
  security_manager: SecurityManager,
1038
1028
  ) -> Dict[str, Any]:
1039
1029
  """Create template context from command line arguments.
1040
1030
 
1041
1031
  Args:
1042
- args: Parsed command line arguments
1032
+ args: Command line arguments
1043
1033
  security_manager: Security manager for path validation
1044
1034
 
1045
1035
  Returns:
@@ -1052,50 +1042,13 @@ def create_template_context_from_args(
1052
1042
  """
1053
1043
  try:
1054
1044
  # Collect files from arguments
1055
- files = None
1056
- if any([args.file, args.files, args.dir]):
1057
- files = collect_files(
1058
- file_mappings=args.file,
1059
- pattern_mappings=args.files,
1060
- dir_mappings=args.dir,
1061
- dir_recursive=args.dir_recursive,
1062
- dir_extensions=(
1063
- args.dir_ext.split(",") if args.dir_ext else None
1064
- ),
1065
- security_manager=security_manager,
1066
- )
1045
+ files = collect_template_files(args, security_manager)
1067
1046
 
1068
1047
  # Collect simple variables
1069
- try:
1070
- variables = collect_simple_variables(args)
1071
- except VariableNameError as e:
1072
- raise VariableError(str(e))
1048
+ variables = collect_simple_variables(args)
1073
1049
 
1074
1050
  # Collect JSON variables
1075
- json_variables = {}
1076
- if args.json_var:
1077
- for mapping in args.json_var:
1078
- try:
1079
- name, value = mapping.split("=", 1)
1080
- if not name.isidentifier():
1081
- raise VariableNameError(
1082
- f"Invalid variable name: {name}"
1083
- )
1084
- try:
1085
- json_value = json.loads(value)
1086
- except json.JSONDecodeError as e:
1087
- raise InvalidJSONError(
1088
- f"Error parsing JSON for variable '{name}': {str(e)}. Input was: {value}"
1089
- )
1090
- if name in json_variables:
1091
- raise VariableNameError(
1092
- f"Duplicate variable name: {name}"
1093
- )
1094
- json_variables[name] = json_value
1095
- except ValueError:
1096
- raise VariableNameError(
1097
- f"Invalid JSON variable mapping format: {mapping}. Expected name=json"
1098
- )
1051
+ json_variables = collect_json_variables(args)
1099
1052
 
1100
1053
  # Get stdin content if available
1101
1054
  stdin_content = None
@@ -1106,7 +1059,7 @@ def create_template_context_from_args(
1106
1059
  # Skip stdin if it can't be read
1107
1060
  pass
1108
1061
 
1109
- return create_template_context(
1062
+ context = create_template_context(
1110
1063
  files=files,
1111
1064
  variables=variables,
1112
1065
  json_variables=json_variables,
@@ -1114,6 +1067,11 @@ def create_template_context_from_args(
1114
1067
  stdin_content=stdin_content,
1115
1068
  )
1116
1069
 
1070
+ # Add current model to context
1071
+ context["current_model"] = args["model"]
1072
+
1073
+ return context
1074
+
1117
1075
  except PathSecurityError:
1118
1076
  # Let PathSecurityError propagate without wrapping
1119
1077
  raise
@@ -1235,7 +1193,8 @@ def parse_json_var(var_str: str) -> Tuple[str, Any]:
1235
1193
  value = json.loads(json_str)
1236
1194
  except json.JSONDecodeError as e:
1237
1195
  raise InvalidJSONError(
1238
- f"Error parsing JSON for variable '{name}': {str(e)}. Input was: {json_str}"
1196
+ f"Error parsing JSON for variable '{name}': {str(e)}. Input was: {json_str}",
1197
+ context={"variable_name": name},
1239
1198
  )
1240
1199
 
1241
1200
  return name, value
@@ -1284,41 +1243,96 @@ def _create_enum_type(values: List[Any], field_name: str) -> Type[Enum]:
1284
1243
 
1285
1244
 
1286
1245
  def handle_error(e: Exception) -> None:
1287
- """Handle errors by printing appropriate message and exiting with status code."""
1246
+ """Handle CLI errors and display appropriate messages.
1247
+
1248
+ Maintains specific error type handling while reducing duplication.
1249
+ Provides enhanced debug logging for CLI errors.
1250
+ """
1251
+ # 1. Determine error type and message
1288
1252
  if isinstance(e, click.UsageError):
1289
- # For UsageError, preserve the original message format
1290
- if hasattr(e, "param") and e.param:
1291
- # Missing parameter error
1292
- msg = f"Missing option '--{e.param.name}'"
1293
- click.echo(msg, err=True)
1294
- else:
1295
- # Other usage errors (like conflicting options)
1296
- click.echo(str(e), err=True)
1297
- sys.exit(ExitCode.USAGE_ERROR)
1298
- elif isinstance(e, InvalidJSONError):
1299
- # Use the original error message if available
1300
- msg = str(e) if str(e) != "None" else "Invalid JSON"
1301
- click.secho(msg, fg="red", err=True)
1302
- sys.exit(ExitCode.DATA_ERROR)
1303
- elif isinstance(e, FileNotFoundError):
1304
- # Use the original error message if available
1305
- msg = str(e) if str(e) != "None" else "File not found"
1306
- click.secho(msg, fg="red", err=True)
1307
- sys.exit(ExitCode.SCHEMA_ERROR)
1308
- elif isinstance(e, TaskTemplateSyntaxError):
1309
- # Use the original error message if available
1310
- msg = str(e) if str(e) != "None" else "Template syntax error"
1311
- click.secho(msg, fg="red", err=True)
1312
- sys.exit(ExitCode.INTERNAL_ERROR)
1253
+ msg = f"Usage error: {str(e)}"
1254
+ exit_code = ExitCode.USAGE_ERROR
1255
+ elif isinstance(e, SchemaFileError):
1256
+ # Preserve specific schema error handling
1257
+ msg = str(e) # Use existing __str__ formatting
1258
+ exit_code = ExitCode.SCHEMA_ERROR
1259
+ elif isinstance(e, (InvalidJSONError, json.JSONDecodeError)):
1260
+ msg = f"Invalid JSON error: {str(e)}"
1261
+ exit_code = ExitCode.DATA_ERROR
1262
+ elif isinstance(e, SchemaValidationError):
1263
+ msg = f"Schema validation error: {str(e)}"
1264
+ exit_code = ExitCode.VALIDATION_ERROR
1313
1265
  elif isinstance(e, CLIError):
1314
- # Use the show method for CLIError and its subclasses
1315
- e.show()
1316
- sys.exit(
1317
- e.exit_code if hasattr(e, "exit_code") else ExitCode.INTERNAL_ERROR
1266
+ msg = str(e) # Use existing __str__ formatting
1267
+ exit_code = ExitCode(e.exit_code) # Convert int to ExitCode
1268
+ else:
1269
+ msg = f"Unexpected error: {str(e)}"
1270
+ exit_code = ExitCode.INTERNAL_ERROR
1271
+
1272
+ # 2. Debug logging
1273
+ if isinstance(e, CLIError) and logger.isEnabledFor(logging.DEBUG):
1274
+ # Format context fields with lowercase keys and simple values
1275
+ context_str = ""
1276
+ if hasattr(e, "context"):
1277
+ for key, value in sorted(e.context.items()):
1278
+ if key not in {
1279
+ "timestamp",
1280
+ "host",
1281
+ "version",
1282
+ "python_version",
1283
+ }:
1284
+ context_str += f"{key.lower()}: {value}\n"
1285
+
1286
+ logger.debug(
1287
+ "Error details:\n"
1288
+ f"Type: {type(e).__name__}\n"
1289
+ f"{context_str.rstrip()}"
1318
1290
  )
1291
+ elif not isinstance(e, click.UsageError):
1292
+ logger.error(msg, exc_info=True)
1319
1293
  else:
1320
- click.secho(f"Unexpected error: {str(e)}", fg="red", err=True)
1321
- sys.exit(ExitCode.INTERNAL_ERROR)
1294
+ logger.error(msg)
1295
+
1296
+ # 3. User output
1297
+ click.secho(msg, fg="red", err=True)
1298
+ sys.exit(exit_code)
1299
+
1300
+
1301
+ def validate_model_parameters(model: str, params: Dict[str, Any]) -> None:
1302
+ """Validate model parameters against model capabilities.
1303
+
1304
+ Args:
1305
+ model: The model name to validate parameters for
1306
+ params: Dictionary of parameter names and values to validate
1307
+
1308
+ Raises:
1309
+ CLIError: If any parameters are not supported by the model
1310
+ """
1311
+ try:
1312
+ capabilities = ModelRegistry().get_capabilities(model)
1313
+ for param_name, value in params.items():
1314
+ try:
1315
+ capabilities.validate_parameter(param_name, value)
1316
+ except OpenAIClientError as e:
1317
+ logger.error(
1318
+ "Validation failed for model %s: %s", model, str(e)
1319
+ )
1320
+ raise CLIError(
1321
+ str(e),
1322
+ exit_code=ExitCode.VALIDATION_ERROR,
1323
+ context={
1324
+ "model": model,
1325
+ "param": param_name,
1326
+ "value": value,
1327
+ },
1328
+ )
1329
+ except (ModelNotSupportedError, ModelVersionError) as e:
1330
+ logger.error("Model validation failed: %s", str(e))
1331
+ raise CLIError(
1332
+ str(e),
1333
+ exit_code=ExitCode.VALIDATION_ERROR,
1334
+ context={"model": model},
1335
+ )
1322
1336
 
1323
1337
 
1324
1338
  async def stream_structured_output(
@@ -1329,85 +1343,75 @@ async def stream_structured_output(
1329
1343
  output_schema: Type[BaseModel],
1330
1344
  output_file: Optional[str] = None,
1331
1345
  **kwargs: Any,
1332
- ) -> None:
1346
+ ) -> AsyncGenerator[BaseModel, None]:
1333
1347
  """Stream structured output from OpenAI API.
1334
1348
 
1335
1349
  This function follows the guide's recommendation for a focused async streaming function.
1336
1350
  It handles the core streaming logic and resource cleanup.
1337
- """
1338
- try:
1339
- # Base models that don't support streaming
1340
- non_streaming_models = {"o1", "o3"}
1341
1351
 
1342
- # Check if model supports streaming
1343
- # o3-mini and o3-mini-high support streaming, base o3 does not
1344
- use_streaming = model not in non_streaming_models and (
1345
- not model.startswith("o3") or model.startswith("o3-mini")
1346
- )
1352
+ Args:
1353
+ client: The OpenAI client to use
1354
+ model: The model to use
1355
+ system_prompt: The system prompt to use
1356
+ user_prompt: The user prompt to use
1357
+ output_schema: The Pydantic model to validate responses against
1358
+ output_file: Optional file to write output to
1359
+ **kwargs: Additional parameters to pass to the API
1347
1360
 
1348
- # All o1 and o3 models (base and variants) have fixed settings
1349
- stream_kwargs = {}
1350
- if not (model.startswith("o1") or model.startswith("o3")):
1351
- stream_kwargs = kwargs
1352
-
1353
- if use_streaming:
1354
- async for chunk in async_openai_structured_stream(
1355
- client=client,
1356
- model=model,
1357
- output_schema=output_schema,
1358
- system_prompt=system_prompt,
1359
- user_prompt=user_prompt,
1360
- **stream_kwargs,
1361
- ):
1362
- if not chunk:
1363
- continue
1364
-
1365
- # Process and output the chunk
1366
- dumped = chunk.model_dump(mode="json")
1367
- json_str = json.dumps(dumped, indent=2)
1368
-
1369
- if output_file:
1370
- with open(output_file, "a", encoding="utf-8") as f:
1371
- f.write(json_str)
1372
- f.write("\n")
1373
- f.flush() # Ensure immediate flush to file
1374
- else:
1375
- # Print directly to stdout with immediate flush
1376
- print(json_str, flush=True)
1377
- else:
1378
- # For non-streaming models, use regular completion
1379
- response = await client.chat.completions.create(
1380
- model=model,
1381
- messages=[
1382
- {"role": "system", "content": system_prompt},
1383
- {"role": "user", "content": user_prompt},
1384
- ],
1385
- stream=False,
1386
- **stream_kwargs,
1361
+ Returns:
1362
+ An async generator yielding validated model instances
1363
+
1364
+ Raises:
1365
+ ValueError: If the model does not support structured output or parameters are invalid
1366
+ StreamInterruptedError: If the stream is interrupted
1367
+ APIResponseError: If there is an API error
1368
+ """
1369
+ try:
1370
+ # Check if model supports structured output using openai_structured's function
1371
+ if not supports_structured_output(model):
1372
+ raise ValueError(
1373
+ f"Model {model} does not support structured output with json_schema response format. "
1374
+ "Please use a model that supports structured output."
1387
1375
  )
1388
1376
 
1389
- # Process the single response
1390
- content = response.choices[0].message.content
1391
- if content:
1392
- try:
1393
- # Parse and validate against schema
1394
- result = output_schema.model_validate_json(content)
1395
- json_str = json.dumps(
1396
- result.model_dump(mode="json"), indent=2
1397
- )
1377
+ # Extract non-model parameters
1378
+ on_log = kwargs.pop("on_log", None)
1398
1379
 
1399
- if output_file:
1400
- with open(output_file, "w", encoding="utf-8") as f:
1401
- f.write(json_str)
1402
- f.write("\n")
1403
- else:
1404
- print(json_str, flush=True)
1405
- except ValidationError as e:
1406
- raise InvalidResponseFormatError(
1407
- f"Response validation failed: {e}"
1408
- )
1380
+ # Handle model-specific parameters
1381
+ stream_kwargs = {}
1382
+ registry = ModelRegistry()
1383
+ capabilities = registry.get_capabilities(model)
1384
+
1385
+ # Validate and include supported parameters
1386
+ for param_name, value in kwargs.items():
1387
+ if param_name in capabilities.supported_parameters:
1388
+ # Validate the parameter value
1389
+ capabilities.validate_parameter(param_name, value)
1390
+ stream_kwargs[param_name] = value
1409
1391
  else:
1410
- raise EmptyResponseError("Model returned empty response")
1392
+ logger.warning(
1393
+ f"Parameter {param_name} is not supported by model {model} and will be ignored"
1394
+ )
1395
+
1396
+ # Log the API request details
1397
+ logger.debug("Making OpenAI API request with:")
1398
+ logger.debug("Model: %s", model)
1399
+ logger.debug("System prompt: %s", system_prompt)
1400
+ logger.debug("User prompt: %s", user_prompt)
1401
+ logger.debug("Parameters: %s", json.dumps(stream_kwargs, indent=2))
1402
+ logger.debug("Schema: %s", output_schema.model_json_schema())
1403
+
1404
+ # Use the async generator from openai_structured directly
1405
+ async for chunk in async_openai_structured_stream(
1406
+ client=client,
1407
+ model=model,
1408
+ system_prompt=system_prompt,
1409
+ user_prompt=user_prompt,
1410
+ output_schema=output_schema,
1411
+ on_log=on_log, # Pass non-model parameters directly to the function
1412
+ **stream_kwargs, # Pass only validated model parameters
1413
+ ):
1414
+ yield chunk
1411
1415
 
1412
1416
  except (
1413
1417
  StreamInterruptedError,
@@ -1424,149 +1428,458 @@ async def stream_structured_output(
1424
1428
  await client.close()
1425
1429
 
1426
1430
 
1427
- async def run_cli_async(args: Namespace) -> ExitCode:
1428
- """Async wrapper for CLI operations.
1431
+ @click.group()
1432
+ @click.version_option(version=__version__)
1433
+ def cli() -> None:
1434
+ """ostruct CLI - Make structured OpenAI API calls.
1435
+
1436
+ ostruct allows you to invoke OpenAI Structured Output to produce structured JSON
1437
+ output using templates and JSON schemas. It provides support for file handling, variable
1438
+ substitution, and output validation.
1439
+
1440
+ For detailed documentation, visit: https://ostruct.readthedocs.io
1441
+
1442
+ Examples:
1429
1443
 
1430
- This function prepares everything needed for streaming and then calls
1431
- the focused streaming function.
1444
+ # Basic usage with a template and schema
1445
+
1446
+ ostruct run task.j2 schema.json -V name=value
1447
+
1448
+ # Process files with recursive directory scanning
1449
+
1450
+ ostruct run template.j2 schema.json -f code main.py -d src ./src -R
1451
+
1452
+ # Use JSON variables and custom model parameters
1453
+
1454
+ ostruct run task.j2 schema.json -J config='{"env":"prod"}' -m o3-mini
1455
+ """
1456
+ pass
1457
+
1458
+
1459
+ @cli.command()
1460
+ @click.argument("task_template", type=click.Path(exists=True))
1461
+ @click.argument("schema_file", type=click.Path(exists=True))
1462
+ @all_options
1463
+ @click.pass_context
1464
+ def run(
1465
+ ctx: click.Context,
1466
+ task_template: str,
1467
+ schema_file: str,
1468
+ **kwargs: Any,
1469
+ ) -> None:
1470
+ """Run a structured task with template and schema.
1471
+
1472
+ TASK_TEMPLATE is the path to your Jinja2 template file that defines the task.
1473
+ SCHEMA_FILE is the path to your JSON schema file that defines the expected output structure.
1474
+
1475
+ The command supports various options for file handling, variable definition,
1476
+ model configuration, and output control. Use --help to see all available options.
1477
+
1478
+ Examples:
1479
+ # Basic usage
1480
+ ostruct run task.j2 schema.json
1481
+
1482
+ # Process multiple files
1483
+ ostruct run task.j2 schema.json -f code main.py -f test tests/test_main.py
1484
+
1485
+ # Scan directories recursively
1486
+ ostruct run task.j2 schema.json -d src ./src -R
1487
+
1488
+ # Define variables
1489
+ ostruct run task.j2 schema.json -V debug=true -J config='{"env":"prod"}'
1490
+
1491
+ # Configure model
1492
+ ostruct run task.j2 schema.json -m gpt-4 --temperature 0.7 --max-output-tokens 1000
1493
+
1494
+ # Control output
1495
+ ostruct run task.j2 schema.json --output-file result.json --verbose
1432
1496
  """
1433
1497
  try:
1434
- # Validate and prepare all inputs
1435
- security_manager = validate_security_manager(
1436
- base_dir=args.base_dir,
1437
- allowed_dirs=args.allowed_dir,
1438
- allowed_dir_file=args.allowed_dir_file,
1498
+ # Convert Click parameters to typed dict
1499
+ params: CLIParams = {
1500
+ "task_file": task_template,
1501
+ "task": None,
1502
+ "schema_file": schema_file,
1503
+ }
1504
+ # Add only valid keys from kwargs
1505
+ valid_keys = set(CLIParams.__annotations__.keys())
1506
+ for k, v in kwargs.items():
1507
+ if k in valid_keys:
1508
+ params[k] = v # type: ignore[literal-required]
1509
+
1510
+ # Run the async function synchronously
1511
+ loop = asyncio.new_event_loop()
1512
+ asyncio.set_event_loop(loop)
1513
+ try:
1514
+ exit_code = loop.run_until_complete(run_cli_async(params))
1515
+ sys.exit(int(exit_code))
1516
+ finally:
1517
+ loop.close()
1518
+
1519
+ except (
1520
+ CLIError,
1521
+ InvalidJSONError,
1522
+ SchemaFileError,
1523
+ SchemaValidationError,
1524
+ ) as e:
1525
+ handle_error(e)
1526
+ sys.exit(
1527
+ e.exit_code if hasattr(e, "exit_code") else ExitCode.INTERNAL_ERROR
1439
1528
  )
1529
+ except click.UsageError as e:
1530
+ handle_error(e)
1531
+ sys.exit(ExitCode.USAGE_ERROR)
1532
+ except Exception as e:
1533
+ handle_error(e)
1534
+ sys.exit(ExitCode.INTERNAL_ERROR)
1535
+
1536
+
1537
+ # Remove the old @create_click_command() decorator and cli function definition
1538
+ # Keep all the other functions and code below this point
1440
1539
 
1441
- task_template = validate_task_template(args.task, args.task_file)
1442
- logger.debug("Validating schema from %s", args.schema_file)
1443
- schema = validate_schema_file(args.schema_file, args.verbose)
1444
- template_context = create_template_context_from_args(
1445
- args, security_manager
1540
+
1541
+ async def validate_model_params(args: CLIParams) -> Dict[str, Any]:
1542
+ """Validate model parameters and return a dictionary of valid parameters.
1543
+
1544
+ Args:
1545
+ args: Command line arguments
1546
+
1547
+ Returns:
1548
+ Dictionary of validated model parameters
1549
+
1550
+ Raises:
1551
+ CLIError: If model parameters are invalid
1552
+ """
1553
+ params = {
1554
+ "temperature": args.get("temperature"),
1555
+ "max_output_tokens": args.get("max_output_tokens"),
1556
+ "top_p": args.get("top_p"),
1557
+ "frequency_penalty": args.get("frequency_penalty"),
1558
+ "presence_penalty": args.get("presence_penalty"),
1559
+ "reasoning_effort": args.get("reasoning_effort"),
1560
+ }
1561
+ # Remove None values
1562
+ params = {k: v for k, v in params.items() if v is not None}
1563
+ validate_model_parameters(args["model"], params)
1564
+ return params
1565
+
1566
+
1567
+ async def validate_inputs(
1568
+ args: CLIParams,
1569
+ ) -> Tuple[
1570
+ SecurityManager, str, Dict[str, Any], Dict[str, Any], jinja2.Environment
1571
+ ]:
1572
+ """Validate all input parameters and return validated components.
1573
+
1574
+ Args:
1575
+ args: Command line arguments
1576
+
1577
+ Returns:
1578
+ Tuple containing:
1579
+ - SecurityManager instance
1580
+ - Task template string
1581
+ - Schema dictionary
1582
+ - Template context dictionary
1583
+ - Jinja2 environment
1584
+
1585
+ Raises:
1586
+ CLIError: For various validation errors
1587
+ """
1588
+ logger.debug("=== Input Validation Phase ===")
1589
+ security_manager = validate_security_manager(
1590
+ base_dir=args.get("base_dir"),
1591
+ allowed_dirs=args.get("allowed_dirs"),
1592
+ allowed_dir_file=args.get("allowed_dir_file"),
1593
+ )
1594
+
1595
+ task_template = validate_task_template(
1596
+ args.get("task"), args.get("task_file")
1597
+ )
1598
+ logger.debug("Validating schema from %s", args["schema_file"])
1599
+ schema = validate_schema_file(
1600
+ args["schema_file"], args.get("verbose", False)
1601
+ )
1602
+ template_context = await create_template_context_from_args(
1603
+ args, security_manager
1604
+ )
1605
+ env = create_jinja_env()
1606
+
1607
+ return security_manager, task_template, schema, template_context, env
1608
+
1609
+
1610
+ async def process_templates(
1611
+ args: CLIParams,
1612
+ task_template: str,
1613
+ template_context: Dict[str, Any],
1614
+ env: jinja2.Environment,
1615
+ ) -> Tuple[str, str]:
1616
+ """Process system prompt and user prompt templates.
1617
+
1618
+ Args:
1619
+ args: Command line arguments
1620
+ task_template: Validated task template
1621
+ template_context: Template context dictionary
1622
+ env: Jinja2 environment
1623
+
1624
+ Returns:
1625
+ Tuple of (system_prompt, user_prompt)
1626
+
1627
+ Raises:
1628
+ CLIError: For template processing errors
1629
+ """
1630
+ logger.debug("=== Template Processing Phase ===")
1631
+ system_prompt = process_system_prompt(
1632
+ task_template,
1633
+ args.get("system_prompt"),
1634
+ args.get("system_prompt_file"),
1635
+ template_context,
1636
+ env,
1637
+ args.get("ignore_task_sysprompt", False),
1638
+ )
1639
+ user_prompt = render_template(task_template, template_context, env)
1640
+ return system_prompt, user_prompt
1641
+
1642
+
1643
+ async def validate_model_and_schema(
1644
+ args: CLIParams,
1645
+ schema: Dict[str, Any],
1646
+ system_prompt: str,
1647
+ user_prompt: str,
1648
+ ) -> Tuple[Type[BaseModel], List[Dict[str, str]], int, ModelRegistry]:
1649
+ """Validate model compatibility and schema, and check token limits.
1650
+
1651
+ Args:
1652
+ args: Command line arguments
1653
+ schema: Schema dictionary
1654
+ system_prompt: Processed system prompt
1655
+ user_prompt: Processed user prompt
1656
+
1657
+ Returns:
1658
+ Tuple of (output_model, messages, total_tokens, registry)
1659
+
1660
+ Raises:
1661
+ CLIError: For validation errors
1662
+ ModelCreationError: When model creation fails
1663
+ SchemaValidationError: When schema is invalid
1664
+ """
1665
+ logger.debug("=== Model & Schema Validation Phase ===")
1666
+ try:
1667
+ output_model = create_dynamic_model(
1668
+ schema,
1669
+ show_schema=args.get("show_model_schema", False),
1670
+ debug_validation=args.get("debug_validation", False),
1446
1671
  )
1447
- env = create_jinja_env()
1448
-
1449
- # Process system prompt and render task
1450
- system_prompt = process_system_prompt(
1451
- task_template,
1452
- args.system_prompt,
1453
- args.system_prompt_file,
1454
- template_context,
1455
- env,
1456
- args.ignore_task_sysprompt,
1672
+ logger.debug("Successfully created output model")
1673
+ except (
1674
+ SchemaFileError,
1675
+ InvalidJSONError,
1676
+ SchemaValidationError,
1677
+ ModelCreationError,
1678
+ ) as e:
1679
+ logger.error("Schema error: %s", str(e))
1680
+ raise
1681
+
1682
+ if not supports_structured_output(args["model"]):
1683
+ msg = f"Model {args['model']} does not support structured output"
1684
+ logger.error(msg)
1685
+ raise ModelNotSupportedError(msg)
1686
+
1687
+ messages = [
1688
+ {"role": "system", "content": system_prompt},
1689
+ {"role": "user", "content": user_prompt},
1690
+ ]
1691
+
1692
+ total_tokens = estimate_tokens_with_encoding(messages, args["model"])
1693
+ registry = ModelRegistry()
1694
+ capabilities = registry.get_capabilities(args["model"])
1695
+ context_limit = capabilities.context_window
1696
+
1697
+ if total_tokens > context_limit:
1698
+ msg = f"Total tokens ({total_tokens}) exceeds model context limit ({context_limit})"
1699
+ logger.error(msg)
1700
+ raise CLIError(
1701
+ msg,
1702
+ context={
1703
+ "total_tokens": total_tokens,
1704
+ "context_limit": context_limit,
1705
+ },
1457
1706
  )
1458
- rendered_task = render_template(task_template, template_context, env)
1459
- logger.info("Rendered task template: %s", rendered_task)
1460
1707
 
1461
- if args.dry_run:
1462
- logger.info("DRY RUN MODE")
1463
- return ExitCode.SUCCESS
1708
+ return output_model, messages, total_tokens, registry
1464
1709
 
1465
- # Create output model
1466
- logger.debug("Creating output model")
1467
- try:
1468
- output_model = create_dynamic_model(
1469
- schema,
1470
- base_name="OutputModel",
1471
- show_schema=args.show_model_schema,
1472
- debug_validation=args.debug_validation,
1473
- )
1474
- logger.debug("Successfully created output model")
1475
- except (
1476
- SchemaFileError,
1477
- InvalidJSONError,
1478
- SchemaValidationError,
1479
- ModelCreationError,
1480
- ) as e:
1481
- logger.error("Schema error: %s", str(e))
1482
- raise # Let the error propagate with its context
1483
-
1484
- # Validate model support and token usage
1485
- try:
1486
- supports_structured_output(args.model)
1487
- except (ModelNotSupportedError, ModelVersionError) as e:
1488
- logger.error("Model validation error: %s", str(e))
1489
- raise # Let the error propagate with its context
1490
-
1491
- messages = [
1492
- {"role": "system", "content": system_prompt},
1493
- {"role": "user", "content": rendered_task},
1494
- ]
1495
- total_tokens = estimate_tokens_for_chat(messages, args.model)
1496
- context_limit = get_context_window_limit(args.model)
1497
- if total_tokens > context_limit:
1498
- msg = f"Total tokens ({total_tokens}) exceeds model context limit ({context_limit})"
1499
- logger.error(msg)
1500
- raise CLIError(
1501
- msg,
1502
- context={
1503
- "total_tokens": total_tokens,
1504
- "context_limit": context_limit,
1505
- },
1506
- )
1507
1710
 
1508
- # Get API key and create client
1509
- api_key = args.api_key or os.getenv("OPENAI_API_KEY")
1510
- if not api_key:
1511
- msg = "No OpenAI API key provided (--api-key or OPENAI_API_KEY env var)"
1512
- logger.error(msg)
1513
- raise CLIError(msg)
1711
+ async def execute_model(
1712
+ args: CLIParams,
1713
+ params: Dict[str, Any],
1714
+ output_model: Type[BaseModel],
1715
+ system_prompt: str,
1716
+ user_prompt: str,
1717
+ ) -> ExitCode:
1718
+ """Execute the model and handle the response.
1514
1719
 
1515
- client = AsyncOpenAI(api_key=api_key, timeout=args.timeout)
1720
+ Args:
1721
+ args: Command line arguments
1722
+ params: Validated model parameters
1723
+ output_model: Generated Pydantic model
1724
+ system_prompt: Processed system prompt
1725
+ user_prompt: Processed user prompt
1516
1726
 
1517
- # Create detailed log callback
1518
- def log_callback(
1519
- level: int, message: str, extra: dict[str, Any]
1520
- ) -> None:
1521
- if args.debug_openai_stream:
1522
- if extra:
1523
- extra_str = json.dumps(extra, indent=2)
1524
- message = f"{message}\nDetails:\n{extra_str}"
1525
- logger.log(level, message, extra=extra)
1727
+ Returns:
1728
+ Exit code indicating success or failure
1526
1729
 
1527
- # Stream the output
1528
- try:
1529
- await stream_structured_output(
1530
- client=client,
1531
- model=args.model,
1532
- system_prompt=system_prompt,
1533
- user_prompt=rendered_task,
1534
- output_schema=output_model,
1535
- output_file=args.output_file,
1536
- temperature=args.temperature,
1537
- max_tokens=args.max_tokens,
1538
- top_p=args.top_p,
1539
- frequency_penalty=args.frequency_penalty,
1540
- presence_penalty=args.presence_penalty,
1541
- timeout=args.timeout,
1542
- on_log=log_callback,
1730
+ Raises:
1731
+ CLIError: For execution errors
1732
+ """
1733
+ logger.debug("=== Execution Phase ===")
1734
+ api_key = args.get("api_key") or os.getenv("OPENAI_API_KEY")
1735
+ if not api_key:
1736
+ msg = "No API key provided. Set OPENAI_API_KEY environment variable or use --api-key"
1737
+ logger.error(msg)
1738
+ raise CLIError(msg, exit_code=ExitCode.API_ERROR)
1739
+
1740
+ client = AsyncOpenAI(api_key=api_key, timeout=args.get("timeout", 60.0))
1741
+
1742
+ # Create detailed log callback
1743
+ def log_callback(level: int, message: str, extra: dict[str, Any]) -> None:
1744
+ if args.get("debug_openai_stream", False):
1745
+ if extra:
1746
+ extra_str = LogSerializer.serialize_log_extra(extra)
1747
+ if extra_str:
1748
+ logger.debug("%s\nExtra:\n%s", message, extra_str)
1749
+ else:
1750
+ logger.debug("%s\nExtra: Failed to serialize", message)
1751
+ else:
1752
+ logger.debug(message)
1753
+
1754
+ try:
1755
+ # Create output buffer
1756
+ output_buffer = []
1757
+
1758
+ # Stream the response
1759
+ async for response in stream_structured_output(
1760
+ client=client,
1761
+ model=args["model"],
1762
+ system_prompt=system_prompt,
1763
+ user_prompt=user_prompt,
1764
+ output_schema=output_model,
1765
+ output_file=args.get("output_file"),
1766
+ **params, # Only pass validated model parameters
1767
+ on_log=log_callback, # Pass logging callback separately
1768
+ ):
1769
+ output_buffer.append(response)
1770
+
1771
+ # Handle final output
1772
+ output_file = args.get("output_file")
1773
+ if output_file:
1774
+ with open(output_file, "w") as f:
1775
+ if len(output_buffer) == 1:
1776
+ f.write(output_buffer[0].model_dump_json(indent=2))
1777
+ else:
1778
+ # Build complete JSON array as a single string
1779
+ json_output = "[\n"
1780
+ for i, response in enumerate(output_buffer):
1781
+ if i > 0:
1782
+ json_output += ",\n"
1783
+ json_output += " " + response.model_dump_json(
1784
+ indent=2
1785
+ ).replace("\n", "\n ")
1786
+ json_output += "\n]"
1787
+ f.write(json_output)
1788
+ else:
1789
+ # Write to stdout when no output file is specified
1790
+ if len(output_buffer) == 1:
1791
+ print(output_buffer[0].model_dump_json(indent=2))
1792
+ else:
1793
+ # Build complete JSON array as a single string
1794
+ json_output = "[\n"
1795
+ for i, response in enumerate(output_buffer):
1796
+ if i > 0:
1797
+ json_output += ",\n"
1798
+ json_output += " " + response.model_dump_json(
1799
+ indent=2
1800
+ ).replace("\n", "\n ")
1801
+ json_output += "\n]"
1802
+ print(json_output)
1803
+
1804
+ return ExitCode.SUCCESS
1805
+
1806
+ except (
1807
+ StreamInterruptedError,
1808
+ StreamBufferError,
1809
+ StreamParseError,
1810
+ APIResponseError,
1811
+ EmptyResponseError,
1812
+ InvalidResponseFormatError,
1813
+ ) as e:
1814
+ logger.error("Stream error: %s", str(e))
1815
+ raise CLIError(str(e), exit_code=ExitCode.API_ERROR)
1816
+ except Exception as e:
1817
+ logger.exception("Unexpected error during streaming")
1818
+ raise CLIError(str(e), exit_code=ExitCode.UNKNOWN_ERROR)
1819
+ finally:
1820
+ await client.close()
1821
+
1822
+
1823
+ async def run_cli_async(args: CLIParams) -> ExitCode:
1824
+ """Async wrapper for CLI operations.
1825
+
1826
+ Returns:
1827
+ Exit code to return from the CLI
1828
+
1829
+ Raises:
1830
+ CLIError: For various error conditions
1831
+ KeyboardInterrupt: When operation is cancelled by user
1832
+ """
1833
+ try:
1834
+ # 0. Model Parameter Validation
1835
+ logger.debug("=== Model Parameter Validation ===")
1836
+ params = await validate_model_params(args)
1837
+
1838
+ # 1. Input Validation Phase
1839
+ security_manager, task_template, schema, template_context, env = (
1840
+ await validate_inputs(args)
1841
+ )
1842
+
1843
+ # 2. Template Processing Phase
1844
+ system_prompt, user_prompt = await process_templates(
1845
+ args, task_template, template_context, env
1846
+ )
1847
+
1848
+ # 3. Model & Schema Validation Phase
1849
+ output_model, messages, total_tokens, registry = (
1850
+ await validate_model_and_schema(
1851
+ args, schema, system_prompt, user_prompt
1852
+ )
1853
+ )
1854
+
1855
+ # 4. Dry Run Output Phase
1856
+ if args.get("dry_run", False):
1857
+ logger.info("\n=== Dry Run Summary ===")
1858
+ logger.info("✓ Template rendered successfully")
1859
+ logger.info("✓ Schema validation passed")
1860
+ logger.info("✓ Model compatibility validated")
1861
+ logger.info(
1862
+ f"✓ Token count: {total_tokens}/{registry.get_capabilities(args['model']).context_window}"
1543
1863
  )
1864
+
1865
+ if args.get("verbose", False):
1866
+ logger.info("\nSystem Prompt:")
1867
+ logger.info("-" * 40)
1868
+ logger.info(system_prompt)
1869
+ logger.info("\nRendered Template:")
1870
+ logger.info("-" * 40)
1871
+ logger.info(user_prompt)
1872
+
1544
1873
  return ExitCode.SUCCESS
1545
- except (
1546
- StreamInterruptedError,
1547
- StreamBufferError,
1548
- StreamParseError,
1549
- APIResponseError,
1550
- EmptyResponseError,
1551
- InvalidResponseFormatError,
1552
- ) as e:
1553
- logger.error("Stream error: %s", str(e))
1554
- raise # Let stream errors propagate
1555
- except (APIConnectionError, InternalServerError) as e:
1556
- logger.error("API connection error: %s", str(e))
1557
- raise APIResponseError(str(e)) # Convert to our error type
1558
- except RateLimitError as e:
1559
- logger.error("Rate limit exceeded: %s", str(e))
1560
- raise APIResponseError(str(e)) # Convert to our error type
1561
- except (BadRequestError, AuthenticationError, OpenAIClientError) as e:
1562
- logger.error("API client error: %s", str(e))
1563
- raise APIResponseError(str(e)) # Convert to our error type
1564
- finally:
1565
- await client.close()
1874
+
1875
+ # 5. Execution Phase
1876
+ return await execute_model(
1877
+ args, params, output_model, system_prompt, user_prompt
1878
+ )
1566
1879
 
1567
1880
  except KeyboardInterrupt:
1568
1881
  logger.info("Operation cancelled by user")
1569
- return ExitCode.INTERRUPTED
1882
+ raise
1570
1883
  except Exception as e:
1571
1884
  if isinstance(e, CLIError):
1572
1885
  raise # Let our custom errors propagate
@@ -1580,65 +1893,35 @@ def create_cli() -> click.Command:
1580
1893
  Returns:
1581
1894
  click.Command: The CLI command object
1582
1895
  """
1583
-
1584
- @create_click_command()
1585
- def cli(**kwargs: Any) -> None:
1586
- """CLI entry point for structured OpenAI API calls."""
1587
- try:
1588
- args = Namespace(**kwargs)
1589
-
1590
- # Validate required arguments first
1591
- if not args.task and not args.task_file:
1592
- raise click.UsageError(
1593
- "Must specify either --task or --task-file"
1594
- )
1595
- if not args.schema_file:
1596
- raise click.UsageError("Missing option '--schema-file'")
1597
- if args.task and args.task_file:
1598
- raise click.UsageError(
1599
- "Cannot specify both --task and --task-file"
1600
- )
1601
- if args.system_prompt and args.system_prompt_file:
1602
- raise click.UsageError(
1603
- "Cannot specify both --system-prompt and --system-prompt-file"
1604
- )
1605
-
1606
- # Run the async function synchronously
1607
- exit_code = asyncio.run(run_cli_async(args))
1608
-
1609
- if exit_code != ExitCode.SUCCESS:
1610
- error_msg = f"Command failed with exit code {exit_code}"
1611
- if hasattr(ExitCode, exit_code.name):
1612
- error_msg = f"{error_msg} ({exit_code.name})"
1613
- raise CLIError(error_msg, context={"exit_code": exit_code})
1614
-
1615
- except click.UsageError:
1616
- # Let Click handle usage errors directly
1617
- raise
1618
- except InvalidJSONError:
1619
- # Let InvalidJSONError propagate directly
1620
- raise
1621
- except CLIError:
1622
- # Let our custom errors propagate with their context
1623
- raise
1624
- except Exception as e:
1625
- # Convert other exceptions to CLIError
1626
- logger.exception("Unexpected error")
1627
- raise CLIError(str(e), context={"error_type": type(e).__name__})
1628
-
1629
- return cli
1896
+ return cli # The decorator already returns a Command
1630
1897
 
1631
1898
 
1632
1899
  def main() -> None:
1633
1900
  """Main entry point for the CLI."""
1634
- cli = create_cli()
1635
- cli(standalone_mode=False)
1901
+ try:
1902
+ cli(standalone_mode=False)
1903
+ except (
1904
+ CLIError,
1905
+ InvalidJSONError,
1906
+ SchemaFileError,
1907
+ SchemaValidationError,
1908
+ ) as e:
1909
+ handle_error(e)
1910
+ sys.exit(
1911
+ e.exit_code if hasattr(e, "exit_code") else ExitCode.INTERNAL_ERROR
1912
+ )
1913
+ except click.UsageError as e:
1914
+ handle_error(e)
1915
+ sys.exit(ExitCode.USAGE_ERROR)
1916
+ except Exception as e:
1917
+ handle_error(e)
1918
+ sys.exit(ExitCode.INTERNAL_ERROR)
1636
1919
 
1637
1920
 
1638
1921
  # Export public API
1639
1922
  __all__ = [
1640
1923
  "ExitCode",
1641
- "estimate_tokens_for_chat",
1924
+ "estimate_tokens_with_encoding",
1642
1925
  "parse_json_var",
1643
1926
  "create_dynamic_model",
1644
1927
  "validate_path_mapping",
@@ -1656,17 +1939,16 @@ def create_dynamic_model(
1656
1939
  """Create a Pydantic model from a JSON schema.
1657
1940
 
1658
1941
  Args:
1659
- schema: JSON schema dict, can be wrapped in {"schema": ...} format
1660
- base_name: Base name for the model
1661
- show_schema: Whether to show the generated schema
1662
- debug_validation: Whether to enable validation debugging
1942
+ schema: JSON schema to create model from
1943
+ base_name: Name for the model class
1944
+ show_schema: Whether to show the generated model schema
1945
+ debug_validation: Whether to show detailed validation errors
1663
1946
 
1664
1947
  Returns:
1665
- Generated Pydantic model class
1948
+ Type[BaseModel]: The generated Pydantic model class
1666
1949
 
1667
1950
  Raises:
1668
- ModelCreationError: When model creation fails
1669
- SchemaValidationError: When schema is invalid
1951
+ ModelValidationError: If the schema is invalid
1670
1952
  """
1671
1953
  if debug_validation:
1672
1954
  logger.info("Creating dynamic model from schema:")
@@ -1758,18 +2040,17 @@ def create_dynamic_model(
1758
2040
  " JSON Schema Extra: %s", config.get("json_schema_extra")
1759
2041
  )
1760
2042
 
1761
- # Create field definitions
1762
- field_definitions: Dict[str, FieldDefinition] = {}
2043
+ # Process schema properties into fields
1763
2044
  properties = schema.get("properties", {})
2045
+ required = schema.get("required", [])
1764
2046
 
2047
+ field_definitions: Dict[str, Tuple[Type[Any], FieldInfoType]] = {}
1765
2048
  for field_name, field_schema in properties.items():
1766
- try:
1767
- if debug_validation:
1768
- logger.info("Processing field %s:", field_name)
1769
- logger.info(
1770
- " Schema: %s", json.dumps(field_schema, indent=2)
1771
- )
2049
+ if debug_validation:
2050
+ logger.info("Processing field %s:", field_name)
2051
+ logger.info(" Schema: %s", json.dumps(field_schema, indent=2))
1772
2052
 
2053
+ try:
1773
2054
  python_type, field = _get_type_with_constraints(
1774
2055
  field_schema, field_name, base_name
1775
2056
  )
@@ -1804,22 +2085,24 @@ def create_dynamic_model(
1804
2085
  raise ModelValidationError(base_name, [str(e)])
1805
2086
 
1806
2087
  # Create the model with the fields
1807
- model = create_model(
1808
- base_name,
1809
- __config__=config,
1810
- **{
1811
- name: (
1812
- (
1813
- cast(Type[Any], field_type)
1814
- if is_container_type(field_type)
1815
- else field_type
1816
- ),
1817
- field,
1818
- )
1819
- for name, (field_type, field) in field_definitions.items()
1820
- },
2088
+ field_defs: Dict[str, Any] = {
2089
+ name: (
2090
+ (
2091
+ cast(Type[Any], field_type)
2092
+ if is_container_type(field_type)
2093
+ else field_type
2094
+ ),
2095
+ field,
2096
+ )
2097
+ for name, (field_type, field) in field_definitions.items()
2098
+ }
2099
+ model: Type[BaseModel] = create_model(
2100
+ base_name, __config__=config, **field_defs
1821
2101
  )
1822
2102
 
2103
+ # Set the model config after creation
2104
+ model.model_config = config
2105
+
1823
2106
  if debug_validation:
1824
2107
  logger.info("Successfully created model: %s", model.__name__)
1825
2108
  logger.info("Model config: %s", dict(model.model_config))
@@ -1836,10 +2119,6 @@ def create_dynamic_model(
1836
2119
  logger.error("Schema validation failed:")
1837
2120
  logger.error(" Error type: %s", type(e).__name__)
1838
2121
  logger.error(" Error message: %s", str(e))
1839
- if hasattr(e, "errors"):
1840
- logger.error(" Validation errors:")
1841
- for error in e.errors():
1842
- logger.error(" - %s", error)
1843
2122
  validation_errors = (
1844
2123
  [str(err) for err in e.errors()]
1845
2124
  if hasattr(e, "errors")
@@ -1847,7 +2126,7 @@ def create_dynamic_model(
1847
2126
  )
1848
2127
  raise ModelValidationError(base_name, validation_errors)
1849
2128
 
1850
- return cast(Type[BaseModel], model)
2129
+ return model
1851
2130
 
1852
2131
  except Exception as e:
1853
2132
  if debug_validation: