ostruct-cli 0.6.0__py3-none-any.whl → 0.6.2__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
ostruct/cli/cli.py CHANGED
@@ -5,7 +5,6 @@ import json
5
5
  import logging
6
6
  import os
7
7
  import sys
8
- from enum import Enum, IntEnum
9
8
  from typing import (
10
9
  Any,
11
10
  AsyncGenerator,
@@ -20,12 +19,11 @@ from typing import (
20
19
  TypeVar,
21
20
  Union,
22
21
  cast,
23
- get_origin,
24
22
  overload,
25
23
  )
26
24
 
27
25
  if sys.version_info >= (3, 11):
28
- from enum import StrEnum
26
+ pass
29
27
 
30
28
  from datetime import date, datetime, time
31
29
  from pathlib import Path
@@ -48,15 +46,7 @@ from openai_structured.errors import (
48
46
  StreamBufferError,
49
47
  )
50
48
  from openai_structured.model_registry import ModelRegistry
51
- from pydantic import (
52
- AnyUrl,
53
- BaseModel,
54
- ConfigDict,
55
- EmailStr,
56
- Field,
57
- ValidationError,
58
- create_model,
59
- )
49
+ from pydantic import AnyUrl, BaseModel, EmailStr, Field
60
50
  from pydantic.fields import FieldInfo as FieldInfoType
61
51
  from pydantic.functional_validators import BeforeValidator
62
52
  from pydantic.types import constr
@@ -69,11 +59,8 @@ from .. import __version__ # noqa: F401 - Used in package metadata
69
59
  from .errors import (
70
60
  CLIError,
71
61
  DirectoryNotFoundError,
72
- FieldDefinitionError,
73
62
  InvalidJSONError,
74
63
  ModelCreationError,
75
- ModelValidationError,
76
- NestedModelError,
77
64
  OstructFileNotFoundError,
78
65
  PathSecurityError,
79
66
  SchemaFileError,
@@ -86,17 +73,87 @@ from .errors import (
86
73
  VariableValueError,
87
74
  )
88
75
  from .file_utils import FileInfoList, collect_files
76
+ from .model_creation import _create_enum_type, create_dynamic_model
89
77
  from .path_utils import validate_path_mapping
90
78
  from .security import SecurityManager
91
79
  from .serialization import LogSerializer
92
80
  from .template_env import create_jinja_env
93
- from .template_utils import SystemPromptError, render_template
81
+ from .template_utils import (
82
+ SystemPromptError,
83
+ render_template,
84
+ validate_json_schema,
85
+ )
94
86
  from .token_utils import estimate_tokens_with_encoding
95
87
 
96
88
  # Constants
97
89
  DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant."
98
90
 
99
91
 
92
+ # Validation functions
93
+ def pattern(regex: str) -> Any:
94
+ return constr(pattern=regex)
95
+
96
+
97
+ def min_length(length: int) -> Any:
98
+ return BeforeValidator(lambda v: v if len(str(v)) >= length else None)
99
+
100
+
101
+ def max_length(length: int) -> Any:
102
+ return BeforeValidator(lambda v: v if len(str(v)) <= length else None)
103
+
104
+
105
+ def ge(value: Union[int, float]) -> Any:
106
+ return BeforeValidator(lambda v: v if float(v) >= value else None)
107
+
108
+
109
+ def le(value: Union[int, float]) -> Any:
110
+ return BeforeValidator(lambda v: v if float(v) <= value else None)
111
+
112
+
113
+ def gt(value: Union[int, float]) -> Any:
114
+ return BeforeValidator(lambda v: v if float(v) > value else None)
115
+
116
+
117
+ def lt(value: Union[int, float]) -> Any:
118
+ return BeforeValidator(lambda v: v if float(v) < value else None)
119
+
120
+
121
+ def multiple_of(value: Union[int, float]) -> Any:
122
+ return BeforeValidator(lambda v: v if float(v) % value == 0 else None)
123
+
124
+
125
+ def create_template_context(
126
+ files: Optional[
127
+ Dict[str, Union[FileInfoList, str, List[str], Dict[str, str]]]
128
+ ] = None,
129
+ variables: Optional[Dict[str, str]] = None,
130
+ json_variables: Optional[Dict[str, Any]] = None,
131
+ security_manager: Optional[SecurityManager] = None,
132
+ stdin_content: Optional[str] = None,
133
+ ) -> Dict[str, Any]:
134
+ """Create template context from files and variables."""
135
+ context: Dict[str, Any] = {}
136
+
137
+ # Add file variables
138
+ if files:
139
+ for name, file_list in files.items():
140
+ context[name] = file_list # Always keep FileInfoList wrapper
141
+
142
+ # Add simple variables
143
+ if variables:
144
+ context.update(variables)
145
+
146
+ # Add JSON variables
147
+ if json_variables:
148
+ context.update(json_variables)
149
+
150
+ # Add stdin if provided
151
+ if stdin_content is not None:
152
+ context["stdin"] = stdin_content
153
+
154
+ return context
155
+
156
+
100
157
  class CLIParams(TypedDict, total=False):
101
158
  """Type-safe CLI parameters."""
102
159
 
@@ -185,12 +242,6 @@ ItemType: TypeAlias = Type[BaseModel]
185
242
  ValueType: TypeAlias = Type[Any]
186
243
 
187
244
 
188
- def is_container_type(tp: Type[Any]) -> bool:
189
- """Check if a type is a container type (list, dict, etc.)."""
190
- origin = get_origin(tp)
191
- return origin in (list, dict)
192
-
193
-
194
245
  def _create_field(**kwargs: Any) -> FieldInfoType:
195
246
  """Create a Pydantic Field with the given kwargs."""
196
247
  field: FieldInfoType = Field(**kwargs)
@@ -784,7 +835,7 @@ def validate_schema_file(
784
835
  logger.error(msg)
785
836
  raise SchemaFileError(msg, schema_path=path)
786
837
  except Exception as e:
787
- if isinstance(e, InvalidJSONError):
838
+ if isinstance(e, (InvalidJSONError, SchemaValidationError)):
788
839
  raise
789
840
  msg = f"Failed to read schema file {path}: {e}"
790
841
  logger.error(msg)
@@ -799,7 +850,13 @@ def validate_schema_file(
799
850
  if not isinstance(schema, dict):
800
851
  msg = f"Schema in {path} must be a JSON object"
801
852
  logger.error(msg)
802
- raise SchemaValidationError(msg, context={"path": path})
853
+ raise SchemaValidationError(
854
+ msg,
855
+ context={
856
+ "validation_type": "schema",
857
+ "schema_path": path,
858
+ },
859
+ )
803
860
 
804
861
  # Validate schema structure
805
862
  if "schema" in schema:
@@ -809,7 +866,13 @@ def validate_schema_file(
809
866
  if not isinstance(inner_schema, dict):
810
867
  msg = f"Inner schema in {path} must be a JSON object"
811
868
  logger.error(msg)
812
- raise SchemaValidationError(msg, context={"path": path})
869
+ raise SchemaValidationError(
870
+ msg,
871
+ context={
872
+ "validation_type": "schema",
873
+ "schema_path": path,
874
+ },
875
+ )
813
876
  if verbose:
814
877
  logger.debug("Inner schema validated successfully")
815
878
  logger.debug(
@@ -824,7 +887,20 @@ def validate_schema_file(
824
887
  if "type" not in schema.get("schema", schema):
825
888
  msg = f"Schema in {path} must specify a type"
826
889
  logger.error(msg)
827
- raise SchemaValidationError(msg, context={"path": path})
890
+ raise SchemaValidationError(
891
+ msg,
892
+ context={
893
+ "validation_type": "schema",
894
+ "schema_path": path,
895
+ },
896
+ )
897
+
898
+ # Validate schema against JSON Schema spec
899
+ try:
900
+ validate_json_schema(schema)
901
+ except SchemaValidationError as e:
902
+ logger.error("Schema validation error: %s", str(e))
903
+ raise # Re-raise to preserve error chain
828
904
 
829
905
  # Return the full schema including wrapper
830
906
  return schema
@@ -877,8 +953,11 @@ def collect_template_files(
877
953
  # Let PathSecurityError propagate without wrapping
878
954
  raise
879
955
  except (FileNotFoundError, DirectoryNotFoundError) as e:
880
- # Wrap file-related errors
881
- raise ValueError(f"File access error: {e}")
956
+ # Convert FileNotFoundError to OstructFileNotFoundError
957
+ if isinstance(e, FileNotFoundError):
958
+ raise OstructFileNotFoundError(str(e))
959
+ # Let DirectoryNotFoundError propagate
960
+ raise
882
961
  except Exception as e:
883
962
  # Don't wrap InvalidJSONError
884
963
  if isinstance(e, InvalidJSONError):
@@ -980,38 +1059,6 @@ def collect_json_variables(args: CLIParams) -> Dict[str, Any]:
980
1059
  return variables
981
1060
 
982
1061
 
983
- def create_template_context(
984
- files: Optional[
985
- Dict[str, Union[FileInfoList, str, List[str], Dict[str, str]]]
986
- ] = None,
987
- variables: Optional[Dict[str, str]] = None,
988
- json_variables: Optional[Dict[str, Any]] = None,
989
- security_manager: Optional[SecurityManager] = None,
990
- stdin_content: Optional[str] = None,
991
- ) -> Dict[str, Any]:
992
- """Create template context from files and variables."""
993
- context: Dict[str, Any] = {}
994
-
995
- # Add file variables
996
- if files:
997
- for name, file_list in files.items():
998
- context[name] = file_list # Always keep FileInfoList wrapper
999
-
1000
- # Add simple variables
1001
- if variables:
1002
- context.update(variables)
1003
-
1004
- # Add JSON variables
1005
- if json_variables:
1006
- context.update(json_variables)
1007
-
1008
- # Add stdin if provided
1009
- if stdin_content is not None:
1010
- context["stdin"] = stdin_content
1011
-
1012
- return context
1013
-
1014
-
1015
1062
  async def create_template_context_from_args(
1016
1063
  args: CLIParams,
1017
1064
  security_manager: SecurityManager,
@@ -1066,8 +1113,11 @@ async def create_template_context_from_args(
1066
1113
  # Let PathSecurityError propagate without wrapping
1067
1114
  raise
1068
1115
  except (FileNotFoundError, DirectoryNotFoundError) as e:
1069
- # Wrap file-related errors
1070
- raise ValueError(f"File access error: {e}")
1116
+ # Convert FileNotFoundError to OstructFileNotFoundError
1117
+ if isinstance(e, FileNotFoundError):
1118
+ raise OstructFileNotFoundError(str(e))
1119
+ # Let DirectoryNotFoundError propagate
1120
+ raise
1071
1121
  except Exception as e:
1072
1122
  # Don't wrap InvalidJSONError
1073
1123
  if isinstance(e, InvalidJSONError):
@@ -1197,41 +1247,6 @@ def parse_json_var(var_str: str) -> Tuple[str, Any]:
1197
1247
  raise
1198
1248
 
1199
1249
 
1200
- def _create_enum_type(values: List[Any], field_name: str) -> Type[Enum]:
1201
- """Create an enum type from a list of values.
1202
-
1203
- Args:
1204
- values: List of enum values
1205
- field_name: Name of the field for enum type name
1206
-
1207
- Returns:
1208
- Created enum type
1209
- """
1210
- # Determine the value type
1211
- value_types = {type(v) for v in values}
1212
-
1213
- if len(value_types) > 1:
1214
- # Mixed types, use string representation
1215
- enum_dict = {f"VALUE_{i}": str(v) for i, v in enumerate(values)}
1216
- return type(f"{field_name.title()}Enum", (str, Enum), enum_dict)
1217
- elif value_types == {int}:
1218
- # All integer values
1219
- enum_dict = {f"VALUE_{v}": v for v in values}
1220
- return type(f"{field_name.title()}Enum", (IntEnum,), enum_dict)
1221
- elif value_types == {str}:
1222
- # All string values
1223
- enum_dict = {v.upper().replace(" ", "_"): v for v in values}
1224
- if sys.version_info >= (3, 11):
1225
- return type(f"{field_name.title()}Enum", (StrEnum,), enum_dict)
1226
- else:
1227
- # Other types, use string representation
1228
- return type(f"{field_name.title()}Enum", (str, Enum), enum_dict)
1229
-
1230
- # Default case: treat as string enum
1231
- enum_dict = {f"VALUE_{i}": str(v) for i, v in enumerate(values)}
1232
- return type(f"{field_name.title()}Enum", (str, Enum), enum_dict)
1233
-
1234
-
1235
1250
  def handle_error(e: Exception) -> None:
1236
1251
  """Handle CLI errors and display appropriate messages.
1237
1252
 
@@ -1239,19 +1254,24 @@ def handle_error(e: Exception) -> None:
1239
1254
  Provides enhanced debug logging for CLI errors.
1240
1255
  """
1241
1256
  # 1. Determine error type and message
1242
- if isinstance(e, click.UsageError):
1257
+ if isinstance(e, SchemaValidationError):
1258
+ msg = str(e) # Already formatted in SchemaValidationError
1259
+ exit_code = e.exit_code
1260
+ elif isinstance(e, ModelCreationError):
1261
+ # Unwrap ModelCreationError that might wrap SchemaValidationError
1262
+ if isinstance(e.__cause__, SchemaValidationError):
1263
+ return handle_error(e.__cause__)
1264
+ msg = f"Model creation error: {str(e)}"
1265
+ exit_code = ExitCode.SCHEMA_ERROR
1266
+ elif isinstance(e, click.UsageError):
1243
1267
  msg = f"Usage error: {str(e)}"
1244
1268
  exit_code = ExitCode.USAGE_ERROR
1245
1269
  elif isinstance(e, SchemaFileError):
1246
- # Preserve specific schema error handling
1247
1270
  msg = str(e) # Use existing __str__ formatting
1248
1271
  exit_code = ExitCode.SCHEMA_ERROR
1249
1272
  elif isinstance(e, (InvalidJSONError, json.JSONDecodeError)):
1250
1273
  msg = f"Invalid JSON error: {str(e)}"
1251
1274
  exit_code = ExitCode.DATA_ERROR
1252
- elif isinstance(e, SchemaValidationError):
1253
- msg = f"Schema validation error: {str(e)}"
1254
- exit_code = ExitCode.VALIDATION_ERROR
1255
1275
  elif isinstance(e, CLIError):
1256
1276
  msg = str(e) # Use existing __str__ formatting
1257
1277
  exit_code = ExitCode(e.exit_code) # Convert int to ExitCode
@@ -1263,7 +1283,7 @@ def handle_error(e: Exception) -> None:
1263
1283
  if isinstance(e, CLIError) and logger.isEnabledFor(logging.DEBUG):
1264
1284
  # Format context fields with lowercase keys and simple values
1265
1285
  context_str = ""
1266
- if hasattr(e, "context"):
1286
+ if hasattr(e, "context") and e.context:
1267
1287
  for key, value in sorted(e.context.items()):
1268
1288
  if key not in {
1269
1289
  "timestamp",
@@ -1271,13 +1291,18 @@ def handle_error(e: Exception) -> None:
1271
1291
  "version",
1272
1292
  "python_version",
1273
1293
  }:
1274
- context_str += f"{key.lower()}: {value}\n"
1294
+ if isinstance(value, dict):
1295
+ context_str += (
1296
+ f"{key.lower()}:\n{json.dumps(value, indent=2)}\n"
1297
+ )
1298
+ else:
1299
+ context_str += f"{key.lower()}: {value}\n"
1275
1300
 
1276
- logger.debug(
1277
- "Error details:\n"
1278
- f"Type: {type(e).__name__}\n"
1279
- f"{context_str.rstrip()}"
1280
- )
1301
+ logger.debug(
1302
+ "Error details:\n"
1303
+ f"Type: {type(e).__name__}\n"
1304
+ f"{context_str.rstrip()}"
1305
+ )
1281
1306
  elif not isinstance(e, click.UsageError):
1282
1307
  logger.error(msg, exc_info=True)
1283
1308
  else:
@@ -1433,7 +1458,7 @@ async def stream_structured_output(
1433
1458
  EmptyResponseError,
1434
1459
  InvalidResponseFormatError,
1435
1460
  ) as e:
1436
- logger.error(f"Stream error: {e}")
1461
+ logger.error("Stream error: %s", str(e))
1437
1462
  raise
1438
1463
  finally:
1439
1464
  # Always ensure client is properly closed
@@ -1481,30 +1506,11 @@ def run(
1481
1506
  ) -> None:
1482
1507
  """Run a structured task with template and schema.
1483
1508
 
1484
- TASK_TEMPLATE is the path to your Jinja2 template file that defines the task.
1485
- SCHEMA_FILE is the path to your JSON schema file that defines the expected output structure.
1486
-
1487
- The command supports various options for file handling, variable definition,
1488
- model configuration, and output control. Use --help to see all available options.
1489
-
1490
- Examples:
1491
- # Basic usage
1492
- ostruct run task.j2 schema.json
1493
-
1494
- # Process multiple files
1495
- ostruct run task.j2 schema.json -f code main.py -f test tests/test_main.py
1496
-
1497
- # Scan directories recursively
1498
- ostruct run task.j2 schema.json -d src ./src -R
1499
-
1500
- # Define variables
1501
- ostruct run task.j2 schema.json -V debug=true -J config='{"env":"prod"}'
1502
-
1503
- # Configure model
1504
- ostruct run task.j2 schema.json -m gpt-4 --temperature 0.7 --max-output-tokens 1000
1505
-
1506
- # Control output
1507
- ostruct run task.j2 schema.json --output-file result.json --verbose
1509
+ Args:
1510
+ ctx: Click context
1511
+ task_template: Path to task template file
1512
+ schema_file: Path to schema file
1513
+ **kwargs: Additional CLI options
1508
1514
  """
1509
1515
  try:
1510
1516
  # Convert Click parameters to typed dict
@@ -1525,25 +1531,33 @@ def run(
1525
1531
  try:
1526
1532
  exit_code = loop.run_until_complete(run_cli_async(params))
1527
1533
  sys.exit(int(exit_code))
1534
+ except SchemaValidationError as e:
1535
+ # Log the error with full context
1536
+ logger.error("Schema validation error: %s", str(e))
1537
+ if e.context:
1538
+ logger.debug(
1539
+ "Error context: %s", json.dumps(e.context, indent=2)
1540
+ )
1541
+ # Re-raise to preserve error chain and exit code
1542
+ raise
1543
+ except (CLIError, InvalidJSONError, SchemaFileError) as e:
1544
+ handle_error(e)
1545
+ sys.exit(
1546
+ e.exit_code
1547
+ if hasattr(e, "exit_code")
1548
+ else ExitCode.INTERNAL_ERROR
1549
+ )
1550
+ except click.UsageError as e:
1551
+ handle_error(e)
1552
+ sys.exit(ExitCode.USAGE_ERROR)
1553
+ except Exception as e:
1554
+ handle_error(e)
1555
+ sys.exit(ExitCode.INTERNAL_ERROR)
1528
1556
  finally:
1529
1557
  loop.close()
1530
-
1531
- except (
1532
- CLIError,
1533
- InvalidJSONError,
1534
- SchemaFileError,
1535
- SchemaValidationError,
1536
- ) as e:
1537
- handle_error(e)
1538
- sys.exit(
1539
- e.exit_code if hasattr(e, "exit_code") else ExitCode.INTERNAL_ERROR
1540
- )
1541
- except click.UsageError as e:
1542
- handle_error(e)
1543
- sys.exit(ExitCode.USAGE_ERROR)
1544
- except Exception as e:
1545
- handle_error(e)
1546
- sys.exit(ExitCode.INTERNAL_ERROR)
1558
+ except KeyboardInterrupt:
1559
+ logger.info("Operation cancelled by user")
1560
+ raise
1547
1561
 
1548
1562
 
1549
1563
  # Remove the old @create_click_command() decorator and cli function definition
@@ -1596,6 +1610,7 @@ async def validate_inputs(
1596
1610
 
1597
1611
  Raises:
1598
1612
  CLIError: For various validation errors
1613
+ SchemaValidationError: When schema is invalid
1599
1614
  """
1600
1615
  logger.debug("=== Input Validation Phase ===")
1601
1616
  security_manager = validate_security_manager(
@@ -1607,10 +1622,22 @@ async def validate_inputs(
1607
1622
  task_template = validate_task_template(
1608
1623
  args.get("task"), args.get("task_file")
1609
1624
  )
1625
+
1626
+ # Load and validate schema
1610
1627
  logger.debug("Validating schema from %s", args["schema_file"])
1611
- schema = validate_schema_file(
1612
- args["schema_file"], args.get("verbose", False)
1613
- )
1628
+ try:
1629
+ schema = validate_schema_file(
1630
+ args["schema_file"], args.get("verbose", False)
1631
+ )
1632
+
1633
+ # Validate schema structure before any model creation
1634
+ validate_json_schema(
1635
+ schema
1636
+ ) # This will raise SchemaValidationError if invalid
1637
+ except SchemaValidationError as e:
1638
+ logger.error("Schema validation error: %s", str(e))
1639
+ raise # Re-raise the SchemaValidationError to preserve the error chain
1640
+
1614
1641
  template_context = await create_template_context_from_args(
1615
1642
  args, security_manager
1616
1643
  )
@@ -1689,6 +1716,7 @@ async def validate_model_and_schema(
1689
1716
  ModelCreationError,
1690
1717
  ) as e:
1691
1718
  logger.error("Schema error: %s", str(e))
1719
+ # Pass through the error without additional wrapping
1692
1720
  raise
1693
1721
 
1694
1722
  if not supports_structured_output(args["model"]):
@@ -1834,19 +1862,21 @@ async def execute_model(
1834
1862
  async def run_cli_async(args: CLIParams) -> ExitCode:
1835
1863
  """Async wrapper for CLI operations.
1836
1864
 
1865
+ Args:
1866
+ args: CLI parameters.
1867
+
1837
1868
  Returns:
1838
- Exit code to return from the CLI
1869
+ Exit code.
1839
1870
 
1840
1871
  Raises:
1841
- CLIError: For various error conditions
1842
- KeyboardInterrupt: When operation is cancelled by user
1872
+ CLIError: For errors during CLI operations.
1843
1873
  """
1844
1874
  try:
1845
1875
  # 0. Model Parameter Validation
1846
1876
  logger.debug("=== Model Parameter Validation ===")
1847
1877
  params = await validate_model_params(args)
1848
1878
 
1849
- # 1. Input Validation Phase
1879
+ # 1. Input Validation Phase (includes schema validation)
1850
1880
  security_manager, task_template, schema, template_context, env = (
1851
1881
  await validate_inputs(args)
1852
1882
  )
@@ -1863,15 +1893,12 @@ async def run_cli_async(args: CLIParams) -> ExitCode:
1863
1893
  )
1864
1894
  )
1865
1895
 
1866
- # 4. Dry Run Output Phase
1896
+ # 4. Dry Run Output Phase - Moved after all validations
1867
1897
  if args.get("dry_run", False):
1868
1898
  logger.info("\n=== Dry Run Summary ===")
1899
+ # Only log success if we got this far (no validation errors)
1869
1900
  logger.info("✓ Template rendered successfully")
1870
1901
  logger.info("✓ Schema validation passed")
1871
- logger.info("✓ Model compatibility validated")
1872
- logger.info(
1873
- f"✓ Token count: {total_tokens}/{registry.get_capabilities(args['model']).context_window}"
1874
- )
1875
1902
 
1876
1903
  if args.get("verbose", False):
1877
1904
  logger.info("\nSystem Prompt:")
@@ -1881,6 +1908,7 @@ async def run_cli_async(args: CLIParams) -> ExitCode:
1881
1908
  logger.info("-" * 40)
1882
1909
  logger.info(user_prompt)
1883
1910
 
1911
+ # Return success only if we got here (no validation errors)
1884
1912
  return ExitCode.SUCCESS
1885
1913
 
1886
1914
  # 5. Execution Phase
@@ -1891,6 +1919,10 @@ async def run_cli_async(args: CLIParams) -> ExitCode:
1891
1919
  except KeyboardInterrupt:
1892
1920
  logger.info("Operation cancelled by user")
1893
1921
  raise
1922
+ except SchemaValidationError as e:
1923
+ # Ensure schema validation errors are properly propagated with the correct exit code
1924
+ logger.error("Schema validation error: %s", str(e))
1925
+ raise # Re-raise the SchemaValidationError to preserve the error chain
1894
1926
  except Exception as e:
1895
1927
  if isinstance(e, CLIError):
1896
1928
  raise # Let our custom errors propagate
@@ -1941,254 +1973,5 @@ __all__ = [
1941
1973
  ]
1942
1974
 
1943
1975
 
1944
- def create_dynamic_model(
1945
- schema: Dict[str, Any],
1946
- base_name: str = "DynamicModel",
1947
- show_schema: bool = False,
1948
- debug_validation: bool = False,
1949
- ) -> Type[BaseModel]:
1950
- """Create a Pydantic model from a JSON schema.
1951
-
1952
- Args:
1953
- schema: JSON schema to create model from
1954
- base_name: Name for the model class
1955
- show_schema: Whether to show the generated model schema
1956
- debug_validation: Whether to show detailed validation errors
1957
-
1958
- Returns:
1959
- Type[BaseModel]: The generated Pydantic model class
1960
-
1961
- Raises:
1962
- ModelValidationError: If the schema is invalid
1963
- SchemaValidationError: If the schema violates OpenAI requirements
1964
- """
1965
- if debug_validation:
1966
- logger.info("Creating dynamic model from schema:")
1967
- logger.info(json.dumps(schema, indent=2))
1968
-
1969
- try:
1970
- # Handle our wrapper format if present
1971
- if "schema" in schema:
1972
- if debug_validation:
1973
- logger.info("Found schema wrapper, extracting inner schema")
1974
- logger.info(
1975
- "Original schema: %s", json.dumps(schema, indent=2)
1976
- )
1977
- inner_schema = schema["schema"]
1978
- if not isinstance(inner_schema, dict):
1979
- if debug_validation:
1980
- logger.info(
1981
- "Inner schema must be a dictionary, got %s",
1982
- type(inner_schema),
1983
- )
1984
- raise SchemaValidationError(
1985
- "Inner schema must be a dictionary"
1986
- )
1987
- if debug_validation:
1988
- logger.info("Using inner schema:")
1989
- logger.info(json.dumps(inner_schema, indent=2))
1990
- schema = inner_schema
1991
-
1992
- # Validate against OpenAI requirements
1993
- from .schema_validation import validate_openai_schema
1994
-
1995
- validate_openai_schema(schema)
1996
-
1997
- # Create model configuration
1998
- config = ConfigDict(
1999
- title=schema.get("title", base_name),
2000
- extra="forbid", # OpenAI requires additionalProperties: false
2001
- validate_default=True,
2002
- use_enum_values=True,
2003
- arbitrary_types_allowed=True,
2004
- json_schema_extra={
2005
- k: v
2006
- for k, v in schema.items()
2007
- if k
2008
- not in {
2009
- "type",
2010
- "properties",
2011
- "required",
2012
- "title",
2013
- "description",
2014
- "additionalProperties",
2015
- "readOnly",
2016
- }
2017
- },
2018
- )
2019
-
2020
- if debug_validation:
2021
- logger.info("Created model configuration:")
2022
- logger.info(" Title: %s", config.get("title"))
2023
- logger.info(" Extra: %s", config.get("extra"))
2024
- logger.info(
2025
- " Validate Default: %s", config.get("validate_default")
2026
- )
2027
- logger.info(" Use Enum Values: %s", config.get("use_enum_values"))
2028
- logger.info(
2029
- " Arbitrary Types: %s", config.get("arbitrary_types_allowed")
2030
- )
2031
- logger.info(
2032
- " JSON Schema Extra: %s", config.get("json_schema_extra")
2033
- )
2034
-
2035
- # Process schema properties into fields
2036
- properties = schema.get("properties", {})
2037
- required = schema.get("required", [])
2038
-
2039
- field_definitions: Dict[str, Tuple[Type[Any], FieldInfoType]] = {}
2040
- for field_name, field_schema in properties.items():
2041
- if debug_validation:
2042
- logger.info("Processing field %s:", field_name)
2043
- logger.info(" Schema: %s", json.dumps(field_schema, indent=2))
2044
-
2045
- try:
2046
- python_type, field = _get_type_with_constraints(
2047
- field_schema, field_name, base_name
2048
- )
2049
-
2050
- # Handle optional fields
2051
- if field_name not in required:
2052
- if debug_validation:
2053
- logger.info(
2054
- "Field %s is optional, wrapping in Optional",
2055
- field_name,
2056
- )
2057
- field_type = cast(Type[Any], Optional[python_type])
2058
- else:
2059
- field_type = python_type
2060
- if debug_validation:
2061
- logger.info("Field %s is required", field_name)
2062
-
2063
- # Create field definition
2064
- field_definitions[field_name] = (field_type, field)
2065
-
2066
- if debug_validation:
2067
- logger.info("Successfully created field definition:")
2068
- logger.info(" Name: %s", field_name)
2069
- logger.info(" Type: %s", str(field_type))
2070
- logger.info(" Required: %s", field_name in required)
2071
-
2072
- except (FieldDefinitionError, NestedModelError) as e:
2073
- if debug_validation:
2074
- logger.error("Error creating field %s:", field_name)
2075
- logger.error(" Error type: %s", type(e).__name__)
2076
- logger.error(" Error message: %s", str(e))
2077
- raise ModelValidationError(base_name, [str(e)])
2078
-
2079
- # Create the model with the fields
2080
- field_defs: Dict[str, Any] = {
2081
- name: (
2082
- (
2083
- cast(Type[Any], field_type)
2084
- if is_container_type(field_type)
2085
- else field_type
2086
- ),
2087
- field,
2088
- )
2089
- for name, (field_type, field) in field_definitions.items()
2090
- }
2091
- model: Type[BaseModel] = create_model(
2092
- base_name, __config__=config, **field_defs
2093
- )
2094
-
2095
- # Set the model config after creation
2096
- model.model_config = config
2097
-
2098
- if debug_validation:
2099
- logger.info("Successfully created model: %s", model.__name__)
2100
- logger.info("Model config: %s", dict(model.model_config))
2101
- logger.info(
2102
- "Model schema: %s",
2103
- json.dumps(model.model_json_schema(), indent=2),
2104
- )
2105
-
2106
- # Validate the model's JSON schema
2107
- try:
2108
- model.model_json_schema()
2109
- except ValidationError as e:
2110
- validation_errors = (
2111
- [str(err) for err in e.errors()]
2112
- if hasattr(e, "errors")
2113
- else [str(e)]
2114
- )
2115
- if debug_validation:
2116
- logger.error("Schema validation failed:")
2117
- logger.error(" Error type: %s", type(e).__name__)
2118
- logger.error(" Error message: %s", str(e))
2119
- raise ModelValidationError(base_name, validation_errors)
2120
-
2121
- return model
2122
-
2123
- except SchemaValidationError as e:
2124
- # Always log basic error info
2125
- logger.error("Schema validation error: %s", str(e))
2126
-
2127
- # Log additional debug info if requested
2128
- if debug_validation:
2129
- logger.error(" Error type: %s", type(e).__name__)
2130
- logger.error(" Error details: %s", str(e))
2131
- # Always raise schema validation errors directly
2132
- raise
2133
-
2134
- except Exception as e:
2135
- # Always log basic error info
2136
- logger.error("Model creation error: %s", str(e))
2137
-
2138
- # Log additional debug info if requested
2139
- if debug_validation:
2140
- logger.error(" Error type: %s", type(e).__name__)
2141
- logger.error(" Error details: %s", str(e))
2142
- if hasattr(e, "__cause__"):
2143
- logger.error(" Caused by: %s", str(e.__cause__))
2144
- if hasattr(e, "__context__"):
2145
- logger.error(" Context: %s", str(e.__context__))
2146
- if hasattr(e, "__traceback__"):
2147
- import traceback
2148
-
2149
- logger.error(
2150
- " Traceback:\n%s",
2151
- "".join(traceback.format_tb(e.__traceback__)),
2152
- )
2153
- # Always wrap other errors as ModelCreationError
2154
- raise ModelCreationError(
2155
- f"Failed to create model {base_name}",
2156
- context={"error": str(e)},
2157
- ) from e
2158
-
2159
-
2160
- # Validation functions
2161
- def pattern(regex: str) -> Any:
2162
- return constr(pattern=regex)
2163
-
2164
-
2165
- def min_length(length: int) -> Any:
2166
- return BeforeValidator(lambda v: v if len(str(v)) >= length else None)
2167
-
2168
-
2169
- def max_length(length: int) -> Any:
2170
- return BeforeValidator(lambda v: v if len(str(v)) <= length else None)
2171
-
2172
-
2173
- def ge(value: Union[int, float]) -> Any:
2174
- return BeforeValidator(lambda v: v if float(v) >= value else None)
2175
-
2176
-
2177
- def le(value: Union[int, float]) -> Any:
2178
- return BeforeValidator(lambda v: v if float(v) <= value else None)
2179
-
2180
-
2181
- def gt(value: Union[int, float]) -> Any:
2182
- return BeforeValidator(lambda v: v if float(v) > value else None)
2183
-
2184
-
2185
- def lt(value: Union[int, float]) -> Any:
2186
- return BeforeValidator(lambda v: v if float(v) < value else None)
2187
-
2188
-
2189
- def multiple_of(value: Union[int, float]) -> Any:
2190
- return BeforeValidator(lambda v: v if float(v) % value == 0 else None)
2191
-
2192
-
2193
1976
  if __name__ == "__main__":
2194
1977
  main()