ostruct_cli-0.1.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
ostruct/cli/cli.py ADDED
(new file, 2,033 lines)
"""Command-line interface for making structured OpenAI API calls."""

import argparse
import asyncio
import json
import logging
import os
import sys
from enum import Enum, IntEnum

if sys.version_info >= (3, 11):
    from enum import StrEnum

from datetime import date, datetime, time
from importlib.metadata import version
from pathlib import Path
from typing import (
    Any,
    Dict,
    List,
    Literal,
    Optional,
    Set,
    Tuple,
    Type,
    TypeVar,
    Union,
    cast,
    get_origin,
    overload,
)

import jinja2
import tiktoken
import yaml
from openai import (
    APIConnectionError,
    AsyncOpenAI,
    AuthenticationError,
    BadRequestError,
    InternalServerError,
    RateLimitError,
)
from openai_structured.client import (
    async_openai_structured_stream,
    supports_structured_output,
)
from openai_structured.errors import (
    APIResponseError,
    EmptyResponseError,
    InvalidResponseFormatError,
    ModelNotSupportedError,
    ModelVersionError,
    OpenAIClientError,
    SchemaFileError,
    SchemaValidationError,
    StreamBufferError,
    StreamInterruptedError,
    StreamParseError,
)
from pydantic import (
    AnyUrl,
    BaseModel,
    ConfigDict,
    EmailStr,
    Field,
    ValidationError,
    create_model,
)
from pydantic.fields import FieldInfo as FieldInfoType
from pydantic.functional_validators import BeforeValidator
from pydantic.types import constr
from typing_extensions import TypeAlias

from .errors import (
    DirectoryNotFoundError,
    FieldDefinitionError,
    FileNotFoundError,
    InvalidJSONError,
    ModelCreationError,
    ModelValidationError,
    NestedModelError,
    PathSecurityError,
    TaskTemplateSyntaxError,
    TaskTemplateVariableError,
    VariableError,
    VariableNameError,
    VariableValueError,
)
from .file_utils import FileInfoList, TemplateValue, collect_files
from .path_utils import validate_path_mapping
from .progress import ProgressContext
from .security import SecurityManager
from .template_env import create_jinja_env
from .template_utils import SystemPromptError, render_template

# Set up logging
logger = logging.getLogger(__name__)

# Configure openai_structured logging based on debug flag
openai_logger = logging.getLogger("openai_structured")
openai_logger.setLevel(logging.DEBUG)  # Allow all messages through to handlers
openai_logger.propagate = False  # Prevent propagation to root logger

# Remove any existing handlers
for handler in openai_logger.handlers:
    openai_logger.removeHandler(handler)

# Create a file handler for openai_structured logger that captures all levels
log_dir = os.path.expanduser("~/.ostruct/logs")
os.makedirs(log_dir, exist_ok=True)
openai_file_handler = logging.FileHandler(
    os.path.join(log_dir, "openai_stream.log")
)
openai_file_handler.setLevel(logging.DEBUG)  # Always capture debug in file
openai_file_handler.setFormatter(
    logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
)
openai_logger.addHandler(openai_file_handler)

# Create a file handler for the main logger that captures all levels
ostruct_file_handler = logging.FileHandler(
    os.path.join(log_dir, "ostruct.log")
)
ostruct_file_handler.setLevel(logging.DEBUG)  # Always capture debug in file
ostruct_file_handler.setFormatter(
    logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
)
logger.addHandler(ostruct_file_handler)

# Constants
DEFAULT_SYSTEM_PROMPT = "You are a helpful assistant."
134
+ # Get package version
135
+ try:
136
+ __version__ = version("openai-structured")
137
+ except Exception:
138
+ __version__ = "unknown"


class ExitCode(IntEnum):
    """Exit codes for the CLI following standard Unix conventions.

    Categories:
    - Success (0-1)
    - User Interruption (2-3)
    - Input/Validation (64-69)
    - I/O and File Access (70-79)
    - API and External Services (80-89)
    - Internal Errors (90-99)
    """

    # Success codes
    SUCCESS = 0

    # User interruption
    INTERRUPTED = 2

    # Input/Validation errors (64-69)
    USAGE_ERROR = 64
    DATA_ERROR = 65
    SCHEMA_ERROR = 66
    VALIDATION_ERROR = 67

    # I/O and File Access errors (70-79)
    IO_ERROR = 70
    FILE_NOT_FOUND = 71
    PERMISSION_ERROR = 72
    SECURITY_ERROR = 73

    # API and External Service errors (80-89)
    API_ERROR = 80
    API_TIMEOUT = 81

    # Internal errors (90-99)
    INTERNAL_ERROR = 90
    UNKNOWN_ERROR = 91


# Type aliases
FieldType = (
    Any  # Changed from Type[Any] to allow both concrete types and generics
)
FieldDefinition = Tuple[FieldType, FieldInfoType]
ModelType = TypeVar("ModelType", bound=BaseModel)
ItemType: TypeAlias = Type[BaseModel]
ValueType: TypeAlias = Type[Any]


def is_container_type(tp: Type[Any]) -> bool:
    """Check if a type is a container type (list, dict, etc.)."""
    origin = get_origin(tp)
    return origin in (list, dict)


def _create_field(**kwargs: Any) -> FieldInfoType:
    """Create a Pydantic Field with the given kwargs."""
    field: FieldInfoType = Field(**kwargs)
    return field


def _get_type_with_constraints(
    field_schema: Dict[str, Any], field_name: str, base_name: str
) -> FieldDefinition:
    """Get type with constraints from field schema.

    Args:
        field_schema: Field schema dict
        field_name: Name of the field
        base_name: Base name for nested models

    Returns:
        Tuple of (type, field)
    """
    field_type = field_schema.get("type")
    field_kwargs: Dict[str, Any] = {}

    # Add common field metadata
    if "title" in field_schema:
        field_kwargs["title"] = field_schema["title"]
    if "description" in field_schema:
        field_kwargs["description"] = field_schema["description"]
    if "default" in field_schema:
        field_kwargs["default"] = field_schema["default"]
    if "readOnly" in field_schema:
        field_kwargs["frozen"] = field_schema["readOnly"]

    # Handle array type
    if field_type == "array":
        items_schema = field_schema.get("items", {})
        if not items_schema:
            return (List[Any], Field(**field_kwargs))

        # Create nested model with explicit type annotation
        array_item_model = create_dynamic_model(
            items_schema,
            base_name=f"{base_name}_{field_name}_Item",
            show_schema=False,
            debug_validation=False,
        )
        array_type: Type[List[Any]] = List[array_item_model]  # type: ignore[valid-type]
        return (array_type, Field(**field_kwargs))

    # Handle object type
    if field_type == "object":
        # Create nested model with explicit type annotation
        object_model = create_dynamic_model(
            field_schema,
            base_name=f"{base_name}_{field_name}",
            show_schema=False,
            debug_validation=False,
        )
        return (object_model, Field(**field_kwargs))

    # Handle additionalProperties
    if "additionalProperties" in field_schema and isinstance(
        field_schema["additionalProperties"], dict
    ):
        # Create nested model with explicit type annotation
        dict_value_model = create_dynamic_model(
            field_schema["additionalProperties"],
            base_name=f"{base_name}_{field_name}_Value",
            show_schema=False,
            debug_validation=False,
        )
        dict_type: Type[Dict[str, Any]] = Dict[str, dict_value_model]  # type: ignore[valid-type]
        return (dict_type, Field(**field_kwargs))

    # Handle other types
    if field_type == "string":
        field_type_cls: Type[Any] = str

        # Add string-specific constraints to field_kwargs
        if "pattern" in field_schema:
            field_kwargs["pattern"] = field_schema["pattern"]
        if "minLength" in field_schema:
            field_kwargs["min_length"] = field_schema["minLength"]
        if "maxLength" in field_schema:
            field_kwargs["max_length"] = field_schema["maxLength"]

        # Handle special string formats
        if "format" in field_schema:
            if field_schema["format"] == "date-time":
                field_type_cls = datetime
            elif field_schema["format"] == "date":
                field_type_cls = date
            elif field_schema["format"] == "time":
                field_type_cls = time
            elif field_schema["format"] == "email":
                field_type_cls = EmailStr
            elif field_schema["format"] == "uri":
                field_type_cls = AnyUrl

        return (field_type_cls, Field(**field_kwargs))

    if field_type == "number":
        field_type_cls = float

        # Add number-specific constraints to field_kwargs
        if "minimum" in field_schema:
            field_kwargs["ge"] = field_schema["minimum"]
        if "maximum" in field_schema:
            field_kwargs["le"] = field_schema["maximum"]
        if "exclusiveMinimum" in field_schema:
            field_kwargs["gt"] = field_schema["exclusiveMinimum"]
        if "exclusiveMaximum" in field_schema:
            field_kwargs["lt"] = field_schema["exclusiveMaximum"]
        if "multipleOf" in field_schema:
            field_kwargs["multiple_of"] = field_schema["multipleOf"]

        return (field_type_cls, Field(**field_kwargs))

    if field_type == "integer":
        field_type_cls = int

        # Add integer-specific constraints to field_kwargs
        if "minimum" in field_schema:
            field_kwargs["ge"] = field_schema["minimum"]
        if "maximum" in field_schema:
            field_kwargs["le"] = field_schema["maximum"]
        if "exclusiveMinimum" in field_schema:
            field_kwargs["gt"] = field_schema["exclusiveMinimum"]
        if "exclusiveMaximum" in field_schema:
            field_kwargs["lt"] = field_schema["exclusiveMaximum"]
        if "multipleOf" in field_schema:
            field_kwargs["multiple_of"] = field_schema["multipleOf"]

        return (field_type_cls, Field(**field_kwargs))

    if field_type == "boolean":
        return (bool, Field(**field_kwargs))

    if field_type == "null":
        return (type(None), Field(**field_kwargs))

    # Handle enum
    if "enum" in field_schema:
        enum_type = _create_enum_type(field_schema["enum"], field_name)
        return (cast(Type[Any], enum_type), Field(**field_kwargs))

    # Default to Any for unknown types
    return (Any, Field(**field_kwargs))


T = TypeVar("T")
K = TypeVar("K")
V = TypeVar("V")


def estimate_tokens_for_chat(
    messages: List[Dict[str, str]], model: str
) -> int:
    """Estimate the number of tokens in a chat completion."""
    try:
        encoding = tiktoken.encoding_for_model(model)
    except KeyError:
        # Fall back to cl100k_base for unknown models
        encoding = tiktoken.get_encoding("cl100k_base")

    num_tokens = 0
    for message in messages:
        # Add message overhead
        num_tokens += 4  # every message follows <im_start>{role/name}\n{content}<im_end>\n
        for key, value in message.items():
            num_tokens += len(encoding.encode(str(value)))
            if key == "name":  # if there's a name, the role is omitted
                num_tokens -= 1  # role is always required and always 1 token
    num_tokens += 2  # every reply is primed with <im_start>assistant
    return num_tokens
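
# ---------------------------------------------------------------------------
# [Editor's note] Illustrative sketch, not part of the packaged file: how the
# CLI sizes a prompt before sending it. The message text here is made up.
def _example_token_estimate() -> None:
    messages = [
        {"role": "system", "content": DEFAULT_SYSTEM_PROMPT},
        {"role": "user", "content": "Summarize: the quick brown fox."},
    ]
    total = estimate_tokens_for_chat(messages, "gpt-4o")
    # _main() compares this estimate against get_context_window_limit();
    # validate_token_limits() additionally reserves room for output tokens.
    assert total < get_context_window_limit("gpt-4o")
# ---------------------------------------------------------------------------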


def get_default_token_limit(model: str) -> int:
    """Get the default token limit for a given model.

    Note: These limits are based on current OpenAI model specifications as of
    2024 and may need to be updated if OpenAI changes the models' capabilities.

    Args:
        model: The model name (e.g., 'gpt-4o', 'gpt-4o-mini', 'o1')

    Returns:
        The default token limit for the model
    """
    if "o1" in model:
        return 100_000  # o1 supports up to 100K output tokens
    elif "gpt-4o" in model:
        return 16_384  # gpt-4o and gpt-4o-mini support up to 16K output tokens
    else:
        return 4_096  # default fallback


def get_context_window_limit(model: str) -> int:
    """Get the total context window limit for a given model.

    Note: These limits are based on current OpenAI model specifications as of
    2024 and may need to be updated if OpenAI changes the models' capabilities.

    Args:
        model: The model name (e.g., 'gpt-4o', 'gpt-4o-mini', 'o1')

    Returns:
        The context window limit for the model
    """
    if "o1" in model:
        return 200_000  # o1 supports a 200K total context window
    elif "gpt-4o" in model:
        return 128_000  # gpt-4o and gpt-4o-mini support a 128K context window
    else:
        return 8_192  # default fallback


def validate_token_limits(
    model: str, total_tokens: int, max_token_limit: Optional[int] = None
) -> None:
    """Validate token counts against model limits.

    Args:
        model: The model name
        total_tokens: Total number of tokens in the prompt
        max_token_limit: Optional user-specified token limit

    Raises:
        ValueError: If token limits are exceeded
    """
    context_limit = get_context_window_limit(model)
    output_limit = (
        max_token_limit
        if max_token_limit is not None
        else get_default_token_limit(model)
    )

    # Check if total tokens exceed the context window
    if total_tokens >= context_limit:
        raise ValueError(
            f"Total tokens ({total_tokens:,}) exceed model's context window limit "
            f"of {context_limit:,} tokens"
        )

    # Check if there's enough room for output tokens
    remaining_tokens = context_limit - total_tokens
    if remaining_tokens < output_limit:
        raise ValueError(
            f"Only {remaining_tokens:,} tokens remaining in context window, but "
            f"output may require up to {output_limit:,} tokens"
        )
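
# ---------------------------------------------------------------------------
# [Editor's note] Illustrative sketch, not part of the packaged file: the
# limit check in one place. The numbers are made up.
def _example_limit_check() -> None:
    model = "gpt-4o"  # 128K context window, 16K default output budget
    prompt_tokens = 120_000
    try:
        validate_token_limits(model, prompt_tokens)
    except ValueError as e:
        # 128_000 - 120_000 = 8_000 remaining < 16_384 default output limit
        print(f"Prompt too large: {e}")
# ---------------------------------------------------------------------------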


def process_system_prompt(
    task_template: str,
    system_prompt: Optional[str],
    template_context: Dict[str, Any],
    env: jinja2.Environment,
    ignore_task_sysprompt: bool = False,
) -> str:
    """Process system prompt from various sources.

    Args:
        task_template: The task template string
        system_prompt: Optional system prompt string or file path (with @ prefix)
        template_context: Template context for rendering
        env: Jinja2 environment
        ignore_task_sysprompt: Whether to ignore system prompt in task template

    Returns:
        The final system prompt string

    Raises:
        SystemPromptError: If the system prompt cannot be loaded or rendered
        FileNotFoundError: If a prompt file does not exist
        PathSecurityError: If a prompt file path violates security constraints
    """
    # Default system prompt
    default_prompt = DEFAULT_SYSTEM_PROMPT

    # Try to get system prompt from CLI argument first
    if system_prompt:
        if system_prompt.startswith("@"):
            # Load from file
            path = system_prompt[1:]
            try:
                name, path = validate_path_mapping(f"system_prompt={path}")
                with open(path, "r", encoding="utf-8") as f:
                    system_prompt = f.read().strip()
            except (FileNotFoundError, PathSecurityError) as e:
                raise SystemPromptError(f"Invalid system prompt file: {e}")

        # Render system prompt with template context
        try:
            template = env.from_string(system_prompt)
            return cast(str, template.render(**template_context).strip())
        except jinja2.TemplateError as e:
            raise SystemPromptError(f"Error rendering system prompt: {e}")

    # If not ignoring task template system prompt, try to extract it
    if not ignore_task_sysprompt:
        try:
            # Extract YAML frontmatter
            if task_template.startswith("---\n"):
                end = task_template.find("\n---\n", 4)
                if end != -1:
                    frontmatter = task_template[4:end]
                    try:
                        metadata = yaml.safe_load(frontmatter)
                        if (
                            isinstance(metadata, dict)
                            and "system_prompt" in metadata
                        ):
                            system_prompt = str(metadata["system_prompt"])
                            # Render system prompt with template context
                            try:
                                template = env.from_string(system_prompt)
                                return cast(
                                    str,
                                    template.render(
                                        **template_context
                                    ).strip(),
                                )
                            except jinja2.TemplateError as e:
                                raise SystemPromptError(
                                    f"Error rendering system prompt: {e}"
                                )
                    except yaml.YAMLError as e:
                        raise SystemPromptError(
                            f"Invalid YAML frontmatter: {e}"
                        )

        except Exception as e:
            raise SystemPromptError(
                f"Error extracting system prompt from template: {e}"
            )

    # Fall back to default
    return default_prompt
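
# ---------------------------------------------------------------------------
# [Editor's note] Illustrative sketch, not part of the packaged file: how a
# task template's YAML frontmatter feeds process_system_prompt(). The
# template text and variable names are made up for the example.
def _example_frontmatter_system_prompt() -> str:
    task = (
        "---\n"
        "system_prompt: You are a {{ tone }} reviewer.\n"
        "---\n"
        "Review this code: {{ code }}"
    )
    env = create_jinja_env()
    # With no CLI override, the frontmatter value wins and is itself
    # rendered against the template context.
    return process_system_prompt(
        task_template=task,
        system_prompt=None,
        template_context={"tone": "strict", "code": "print('hi')"},
        env=env,
    )
# ---------------------------------------------------------------------------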


def validate_variable_mapping(
    mapping: str, is_json: bool = False
) -> Tuple[str, Any]:
    """Validate a variable mapping in name=value format."""
    try:
        name, value = mapping.split("=", 1)
        if not name:
            raise VariableNameError(
                f"Empty name in {'JSON ' if is_json else ''}variable mapping"
            )

        if is_json:
            try:
                value = json.loads(value)
            except json.JSONDecodeError as e:
                raise InvalidJSONError(
                    f"Invalid JSON value for variable {name!r}: {value!r}"
                ) from e

        return name, value

    except ValueError as e:
        if "not enough values to unpack" in str(e):
            raise VariableValueError(
                f"Invalid {'JSON ' if is_json else ''}variable mapping "
                f"(expected name=value format): {mapping!r}"
            )
        raise


@overload
def _validate_path_mapping_internal(
    mapping: str,
    is_dir: Literal[True],
    base_dir: Optional[str] = None,
    security_manager: Optional[SecurityManager] = None,
) -> Tuple[str, str]: ...


@overload
def _validate_path_mapping_internal(
    mapping: str,
    is_dir: Literal[False] = False,
    base_dir: Optional[str] = None,
    security_manager: Optional[SecurityManager] = None,
) -> Tuple[str, str]: ...


def _validate_path_mapping_internal(
    mapping: str,
    is_dir: bool = False,
    base_dir: Optional[str] = None,
    security_manager: Optional[SecurityManager] = None,
) -> Tuple[str, str]:
    """Validate a path mapping in the format "name=path".

    Args:
        mapping: The path mapping string (e.g., "myvar=/path/to/file").
        is_dir: Whether the path is expected to be a directory (True) or file (False).
        base_dir: Optional base directory to resolve relative paths against.
        security_manager: Optional security manager to validate paths.

    Returns:
        A (name, path) tuple.

    Raises:
        VariableNameError: If the variable name portion is empty or invalid.
        DirectoryNotFoundError: If is_dir=True and the path is not a directory or doesn't exist.
        FileNotFoundError: If is_dir=False and the path is not a file or doesn't exist.
        PathSecurityError: If the path is inaccessible or outside the allowed directory.
        ValueError: If the format is invalid (missing "=").
        OSError: If there is an underlying OS error (permissions, etc.).
    """
    try:
        if not mapping or "=" not in mapping:
            raise ValueError(
                "Invalid path mapping format. Expected format: name=path"
            )

        name, path = mapping.split("=", 1)
        if not name:
            raise VariableNameError(
                f"Empty name in {'directory' if is_dir else 'file'} mapping"
            )

        if not path:
            raise VariableValueError("Path cannot be empty")

        # Convert to Path object and resolve against base_dir if provided
        path_obj = Path(path)
        if base_dir:
            path_obj = Path(base_dir) / path_obj

        # Resolve the path to catch directory traversal attempts
        try:
            resolved_path = path_obj.resolve()
        except OSError as e:
            raise OSError(f"Failed to resolve path: {e}")

        # Check for directory traversal
        try:
            base_path = (
                Path.cwd() if base_dir is None else Path(base_dir).resolve()
            )
            if not str(resolved_path).startswith(str(base_path)):
                raise PathSecurityError(
                    f"Path {str(path)!r} resolves to {str(resolved_path)!r} which is outside "
                    f"base directory {str(base_path)!r}"
                )
        except OSError as e:
            raise OSError(f"Failed to resolve base path: {e}")

        # Check if path exists
        if not resolved_path.exists():
            if is_dir:
                raise DirectoryNotFoundError(f"Directory not found: {path!r}")
            else:
                raise FileNotFoundError(f"File not found: {path!r}")

        # Check if path is correct type
        if is_dir and not resolved_path.is_dir():
            raise DirectoryNotFoundError(f"Path is not a directory: {path!r}")
        elif not is_dir and not resolved_path.is_file():
            raise FileNotFoundError(f"Path is not a file: {path!r}")

        # Check if path is accessible
        try:
            if is_dir:
                os.listdir(str(resolved_path))
            else:
                with open(str(resolved_path), "r", encoding="utf-8") as f:
                    f.read(1)
        except OSError as e:
            if e.errno == 13:  # Permission denied
                raise PathSecurityError(
                    f"Permission denied accessing path: {path!r}",
                    error_logged=True,
                )
            raise

        if security_manager:
            if not security_manager.is_allowed_file(str(resolved_path)):
                raise PathSecurityError.from_expanded_paths(
                    original_path=str(path),
                    expanded_path=str(resolved_path),
                    base_dir=str(security_manager.base_dir),
                    allowed_dirs=[
                        str(d) for d in security_manager.allowed_dirs
                    ],
                    error_logged=True,
                )

        # Return the original path to maintain relative paths in the output
        return name, path

    except ValueError as e:
        if "not enough values to unpack" in str(e):
            raise VariableValueError(
                f"Invalid {'directory' if is_dir else 'file'} mapping "
                f"(expected name=path format): {mapping!r}"
            )
        raise


def validate_task_template(task: str) -> str:
    """Validate and load a task template.

    Args:
        task: The task template string or path to task template file (with @ prefix)

    Returns:
        The task template string

    Raises:
        TaskTemplateVariableError: If the template file cannot be read or is invalid
        TaskTemplateSyntaxError: If the template has invalid syntax
        FileNotFoundError: If the template file does not exist
        PathSecurityError: If the template file path violates security constraints
    """
    template_content = task

    # Check if task is a file path
    if task.startswith("@"):
        path = task[1:]
        try:
            name, path = validate_path_mapping(f"task={path}")
            with open(path, "r", encoding="utf-8") as f:
                template_content = f.read()
        except (FileNotFoundError, PathSecurityError) as e:
            raise TaskTemplateVariableError(f"Invalid task template file: {e}")

    # Validate template syntax
    try:
        env = jinja2.Environment(undefined=jinja2.StrictUndefined)
        env.parse(template_content)
        return template_content
    except jinja2.TemplateSyntaxError as e:
        raise TaskTemplateSyntaxError(
            f"Invalid task template syntax at line {e.lineno}: {e.message}"
        )


def validate_schema_file(
    path: str,
    verbose: bool = False,
) -> Dict[str, Any]:
    """Validate a JSON schema file.

    Args:
        path: Path to the schema file
        verbose: Whether to enable verbose logging

    Returns:
        The validated schema

    Raises:
        SchemaFileError: When file cannot be read
        InvalidJSONError: When file contains invalid JSON
        SchemaValidationError: When schema is invalid
    """
    if verbose:
        logger.info("Validating schema file: %s", path)

    try:
        with open(path) as f:
            schema = json.load(f)
    except FileNotFoundError:
        raise SchemaFileError(f"Schema file not found: {path}")
    except json.JSONDecodeError as e:
        raise InvalidJSONError(f"Invalid JSON in schema file: {e}")
    except Exception as e:
        raise SchemaFileError(f"Failed to read schema file: {e}")

    # Pre-validation structure checks
    if verbose:
        logger.info("Performing pre-validation structure checks")
        logger.debug("Loaded schema: %s", json.dumps(schema, indent=2))

    if not isinstance(schema, dict):
        if verbose:
            logger.error(
                "Schema is not a dictionary: %s", type(schema).__name__
            )
        raise SchemaValidationError("Schema must be a JSON object")

    # Validate schema structure
    if "schema" in schema:
        if verbose:
            logger.debug("Found schema wrapper, validating inner schema")
        inner_schema = schema["schema"]
        if not isinstance(inner_schema, dict):
            if verbose:
                logger.error(
                    "Inner schema is not a dictionary: %s",
                    type(inner_schema).__name__,
                )
            raise SchemaValidationError("Inner schema must be a JSON object")
        if verbose:
            logger.debug("Inner schema validated successfully")
    else:
        if verbose:
            logger.debug("No schema wrapper found, using schema as-is")

    # Return the full schema including wrapper
    return schema
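
# ---------------------------------------------------------------------------
# [Editor's note] Illustrative sketch, not part of the packaged file: the two
# schema layouts accepted by validate_schema_file() -- a bare JSON Schema
# object, or one wrapped under a top-level "schema" key. The file name is
# made up for the example.
def _example_schema_file(tmp_path: str = "example_schema.json") -> Dict[str, Any]:
    wrapped = {
        "schema": {
            "type": "object",
            "properties": {"summary": {"type": "string"}},
            "required": ["summary"],
        }
    }
    with open(tmp_path, "w", encoding="utf-8") as f:
        json.dump(wrapped, f)
    # Returns the full dict, wrapper included; create_dynamic_model()
    # unwraps the inner schema itself.
    return validate_schema_file(tmp_path)
# ---------------------------------------------------------------------------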


def collect_template_files(
    args: argparse.Namespace,
    security_manager: SecurityManager,
) -> Dict[str, TemplateValue]:
    """Collect files from command line arguments.

    Args:
        args: Parsed command line arguments
        security_manager: Security manager for path validation

    Returns:
        Dictionary mapping variable names to file info objects

    Raises:
        PathSecurityError: If any file paths violate security constraints
        ValueError: If file mappings are invalid or files cannot be accessed
    """
    try:
        result = collect_files(
            file_mappings=args.file,
            pattern_mappings=args.files,
            dir_mappings=args.dir,
            dir_recursive=args.dir_recursive,
            dir_extensions=args.dir_ext.split(",") if args.dir_ext else None,
            security_manager=security_manager,
        )
        return cast(Dict[str, TemplateValue], result)
    except PathSecurityError:
        # Let PathSecurityError propagate without wrapping
        raise
    except (FileNotFoundError, DirectoryNotFoundError) as e:
        # Wrap file-related errors
        raise ValueError(f"File access error: {e}")
    except Exception as e:
        # Check if this is a wrapped security error
        if isinstance(e.__cause__, PathSecurityError):
            raise e.__cause__
        # Wrap unexpected errors
        raise ValueError(f"Error collecting files: {e}")


def collect_simple_variables(args: argparse.Namespace) -> Dict[str, str]:
    """Collect simple string variables from --var arguments.

    Args:
        args: Parsed command line arguments

    Returns:
        Dictionary mapping variable names to string values

    Raises:
        VariableNameError: If a variable name is invalid or duplicate
    """
    variables: Dict[str, str] = {}
    all_names: Set[str] = set()

    if args.var:
        for mapping in args.var:
            try:
                name, value = mapping.split("=", 1)
                if not name.isidentifier():
                    raise VariableNameError(f"Invalid variable name: {name}")
                if name in all_names:
                    raise VariableNameError(f"Duplicate variable name: {name}")
                variables[name] = value
                all_names.add(name)
            except ValueError:
                raise VariableNameError(
                    f"Invalid variable mapping (expected name=value format): {mapping!r}"
                )

    return variables


def collect_json_variables(args: argparse.Namespace) -> Dict[str, Any]:
    """Collect JSON variables from --json-var arguments.

    Args:
        args: Parsed command line arguments

    Returns:
        Dictionary mapping variable names to parsed JSON values

    Raises:
        VariableNameError: If a variable name is invalid or duplicate
        InvalidJSONError: If a JSON value is invalid
    """
    variables: Dict[str, Any] = {}
    all_names: Set[str] = set()

    if args.json_var:
        for mapping in args.json_var:
            try:
                name, json_str = mapping.split("=", 1)
                if not name.isidentifier():
                    raise VariableNameError(f"Invalid variable name: {name}")
                if name in all_names:
                    raise VariableNameError(f"Duplicate variable name: {name}")
                try:
                    value = json.loads(json_str)
                    variables[name] = value
                    all_names.add(name)
                except json.JSONDecodeError as e:
                    raise InvalidJSONError(
                        f"Invalid JSON value for {name}: {str(e)}"
                    )
            except ValueError:
                raise VariableNameError(
                    f"Invalid JSON variable mapping format: {mapping}. Expected name=json"
                )

    return variables


def create_template_context(
    files: Optional[Dict[str, FileInfoList]] = None,
    variables: Optional[Dict[str, str]] = None,
    json_variables: Optional[Dict[str, Any]] = None,
    security_manager: Optional[SecurityManager] = None,
    stdin_content: Optional[str] = None,
) -> Dict[str, Any]:
    """Create template context from direct inputs.

    Args:
        files: Optional dictionary mapping names to FileInfoList objects
        variables: Optional dictionary of simple string variables
        json_variables: Optional dictionary of JSON variables
        security_manager: Optional security manager for path validation
        stdin_content: Optional content to use for stdin

    Returns:
        Template context dictionary

    Raises:
        PathSecurityError: If any file paths violate security constraints
        VariableError: If variable mappings are invalid
    """
    context: Dict[str, Any] = {}

    # Add file variables
    if files:
        for name, file_list in files.items():
            # For single files, extract the first FileInfo object
            if len(file_list) == 1:
                context[name] = file_list[0]
            else:
                context[name] = file_list

    # Add simple variables
    if variables:
        context.update(variables)

    # Add JSON variables
    if json_variables:
        context.update(json_variables)

    # Add stdin if provided
    if stdin_content is not None:
        context["stdin"] = stdin_content

    return context
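
# ---------------------------------------------------------------------------
# [Editor's note] Illustrative sketch, not part of the packaged file:
# building a context without going through argparse. The variable names
# are made up.
def _example_template_context() -> Dict[str, Any]:
    return create_template_context(
        variables={"name": "World"},
        json_variables={"config": {"retries": 3}},
        stdin_content="piped input",
    )
    # -> {"name": "World", "config": {"retries": 3}, "stdin": "piped input"}
# ---------------------------------------------------------------------------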


def create_template_context_from_args(
    args: argparse.Namespace,
    security_manager: SecurityManager,
) -> Dict[str, Any]:
    """Create template context from command line arguments.

    Args:
        args: Parsed command line arguments
        security_manager: Security manager for path validation

    Returns:
        Template context dictionary

    Raises:
        PathSecurityError: If any file paths violate security constraints
        VariableError: If variable mappings are invalid
        ValueError: If file mappings are invalid or files cannot be accessed
    """
    try:
        # Collect files from arguments
        files = None
        if any([args.file, args.files, args.dir]):
            files = collect_files(
                file_mappings=args.file,
                pattern_mappings=args.files,
                dir_mappings=args.dir,
                dir_recursive=args.dir_recursive,
                dir_extensions=(
                    args.dir_ext.split(",") if args.dir_ext else None
                ),
                security_manager=security_manager,
            )

        # Collect simple variables
        try:
            variables = collect_simple_variables(args)
        except VariableNameError as e:
            raise VariableError(str(e))

        # Collect JSON variables
        json_variables = {}
        if args.json_var:
            for mapping in args.json_var:
                try:
                    name, value = mapping.split("=", 1)
                    if not name.isidentifier():
                        raise VariableNameError(
                            f"Invalid variable name: {name}"
                        )
                    try:
                        json_value = json.loads(value)
                    except json.JSONDecodeError as e:
                        raise InvalidJSONError(
                            f"Invalid JSON value for {name} ({value!r}): {str(e)}"
                        )
                    if name in json_variables:
                        raise VariableNameError(
                            f"Duplicate variable name: {name}"
                        )
                    json_variables[name] = json_value
                except ValueError:
                    raise VariableNameError(
                        f"Invalid JSON variable mapping format: {mapping}. Expected name=json"
                    )

        # Get stdin content if available
        stdin_content = None
        try:
            if not sys.stdin.isatty():
                stdin_content = sys.stdin.read()
        except (OSError, IOError):
            # Skip stdin if it can't be read
            pass

        return create_template_context(
            files=files,
            variables=variables,
            json_variables=json_variables,
            security_manager=security_manager,
            stdin_content=stdin_content,
        )

    except PathSecurityError:
        # Let PathSecurityError propagate without wrapping
        raise
    except (FileNotFoundError, DirectoryNotFoundError) as e:
        # Wrap file-related errors
        raise ValueError(f"File access error: {e}")
    except Exception as e:
        # Check if this is a wrapped security error
        if isinstance(e.__cause__, PathSecurityError):
            raise e.__cause__
        # Wrap unexpected errors
        raise ValueError(f"Error collecting files: {e}")


def validate_security_manager(
    base_dir: Optional[str] = None,
    allowed_dirs: Optional[List[str]] = None,
    allowed_dirs_file: Optional[str] = None,
) -> SecurityManager:
    """Create and validate a security manager.

    Args:
        base_dir: Optional base directory to resolve paths against
        allowed_dirs: Optional list of allowed directory paths
        allowed_dirs_file: Optional path to file containing allowed directories

    Returns:
        Configured SecurityManager instance

    Raises:
        FileNotFoundError: If allowed_dirs_file does not exist
        PathSecurityError: If any paths are outside base directory
    """
    # Convert base_dir to string if it's a Path
    base_dir_str = str(base_dir) if base_dir else None
    security_manager = SecurityManager(base_dir_str)

    if allowed_dirs_file:
        security_manager.add_allowed_dirs_from_file(str(allowed_dirs_file))

    if allowed_dirs:
        for allowed_dir in allowed_dirs:
            security_manager.add_allowed_dir(str(allowed_dir))

    return security_manager


def parse_var(var_str: str) -> Tuple[str, str]:
    """Parse a simple variable string in the format 'name=value'.

    Args:
        var_str: Variable string in format 'name=value'

    Returns:
        Tuple of (name, value)

    Raises:
        VariableNameError: If variable name is empty or invalid
        VariableValueError: If variable format is invalid
    """
    try:
        name, value = var_str.split("=", 1)
        if not name:
            raise VariableNameError("Empty name in variable mapping")
        if not name.isidentifier():
            raise VariableNameError(
                f"Invalid variable name: {name}. Must be a valid Python identifier"
            )
        return name, value
    except ValueError as e:
        if "not enough values to unpack" in str(e):
            raise VariableValueError(
                f"Invalid variable mapping (expected name=value format): {var_str!r}"
            )
        raise


def parse_json_var(var_str: str) -> Tuple[str, Any]:
    """Parse a JSON variable string in the format 'name=json_value'.

    Args:
        var_str: Variable string in format 'name=json_value'

    Returns:
        Tuple of (name, parsed_value)

    Raises:
        VariableNameError: If variable name is empty or invalid
        VariableValueError: If variable format is invalid
        InvalidJSONError: If JSON value is invalid
    """
    try:
        name, json_str = var_str.split("=", 1)
        if not name:
            raise VariableNameError("Empty name in JSON variable mapping")
        if not name.isidentifier():
            raise VariableNameError(
                f"Invalid variable name: {name}. Must be a valid Python identifier"
            )

        try:
            value = json.loads(json_str)
        except json.JSONDecodeError as e:
            raise InvalidJSONError(
                f"Invalid JSON value for variable {name!r}: {json_str!r}"
            ) from e

        return name, value

    except ValueError as e:
        if "not enough values to unpack" in str(e):
            raise VariableValueError(
                f"Invalid JSON variable mapping (expected name=json format): {var_str!r}"
            )
        raise
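
# ---------------------------------------------------------------------------
# [Editor's note] Illustrative sketch, not part of the packaged file: the
# two mapping formats side by side. The values are made up.
def _example_parse_vars() -> None:
    assert parse_var("name=World") == ("name", "World")
    assert parse_json_var('config={"retries": 3}') == (
        "config",
        {"retries": 3},
    )
    # Malformed input without an "=" raises VariableValueError.
# ---------------------------------------------------------------------------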


def _create_enum_type(values: List[Any], field_name: str) -> Type[Enum]:
    """Create an enum type from a list of values.

    Args:
        values: List of enum values
        field_name: Name of the field for enum type name

    Returns:
        Created enum type
    """
    # Note: the enum is built with the Enum functional API; creating it via
    # type(name, (Enum,), dict) fails because EnumMeta requires its special
    # namespace, not a plain dict.
    enum_name = f"{field_name.title()}Enum"

    # Determine the value type
    value_types = {type(v) for v in values}

    if len(value_types) > 1:
        # Mixed types, use string representation
        members = {f"VALUE_{i}": str(v) for i, v in enumerate(values)}
        return Enum(enum_name, members, type=str)
    elif value_types == {int}:
        # All integer values
        members = {f"VALUE_{v}": v for v in values}
        return IntEnum(enum_name, members)
    elif value_types == {str}:
        # All string values
        members = {v.upper().replace(" ", "_"): v for v in values}
        if sys.version_info >= (3, 11):
            return StrEnum(enum_name, members)
        else:
            # Python < 3.11: fall back to a plain str-based Enum
            return Enum(enum_name, members, type=str)

    # Default case (e.g., all floats or booleans): treat as string enum
    members = {f"VALUE_{i}": str(v) for i, v in enumerate(values)}
    return Enum(enum_name, members, type=str)
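
# ---------------------------------------------------------------------------
# [Editor's note] Illustrative sketch, not part of the packaged file: what
# the generated enums look like for each branch. Field names are made up.
def _example_enum_generation() -> None:
    color = _create_enum_type(["dark red", "green"], "color")
    assert color["DARK_RED"].value == "dark red"  # string branch

    level = _create_enum_type([1, 2, 3], "level")
    assert level["VALUE_2"].value == 2  # integer branch -> IntEnum

    mixed = _create_enum_type(["on", 1], "state")
    assert mixed["VALUE_1"].value == "1"  # mixed branch stringifies values
# ---------------------------------------------------------------------------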


def create_argument_parser() -> argparse.ArgumentParser:
    """Create argument parser for CLI."""
    parser = argparse.ArgumentParser(
        description="Make structured OpenAI API calls.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )

    # Debug output options
    debug_group = parser.add_argument_group("Debug Output Options")
    debug_group.add_argument(
        "--show-model-schema",
        action="store_true",
        help="Display the generated Pydantic model schema",
    )
    debug_group.add_argument(
        "--debug-validation",
        action="store_true",
        help="Show detailed schema validation debugging information",
    )
    debug_group.add_argument(
        "--verbose-schema",
        action="store_true",
        help="Enable verbose schema debugging output",
    )
    debug_group.add_argument(
        "--progress-level",
        choices=["none", "basic", "detailed"],
        default="basic",
        help="Set the level of progress reporting (default: basic)",
    )

    # Required arguments
    parser.add_argument(
        "--task",
        required=True,
        help="Task template string or @file",
    )

    # File access arguments
    parser.add_argument(
        "--file",
        action="append",
        default=[],
        help="Map file to variable (name=path)",
        metavar="NAME=PATH",
    )
    parser.add_argument(
        "--files",
        action="append",
        default=[],
        help="Map file pattern to variable (name=pattern)",
        metavar="NAME=PATTERN",
    )
    parser.add_argument(
        "--dir",
        action="append",
        default=[],
        help="Map directory to variable (name=path)",
        metavar="NAME=PATH",
    )
    parser.add_argument(
        "--allowed-dir",
        action="append",
        default=[],
        help="Additional allowed directory or @file",
        metavar="PATH",
    )
    parser.add_argument(
        "--base-dir",
        help="Base directory for file access (defaults to current directory)",
        default=os.getcwd(),
    )
    parser.add_argument(
        "--allowed-dirs-file",
        help="File containing list of allowed directories",
    )
    parser.add_argument(
        "--dir-recursive",
        action="store_true",
        help="Process directories recursively",
    )
    parser.add_argument(
        "--dir-ext",
        help="Comma-separated list of file extensions to include in directory processing",
    )

    # Variable arguments
    parser.add_argument(
        "--var",
        action="append",
        default=[],
        help="Pass simple variables (name=value)",
        metavar="NAME=VALUE",
    )
    parser.add_argument(
        "--json-var",
        action="append",
        default=[],
        help="Pass JSON variables (name=json)",
        metavar="NAME=JSON",
    )

    # System prompt options
    parser.add_argument(
        "--system-prompt",
        help=(
            "System prompt for the model (use @file to load from file, "
            "can also be specified in task template YAML frontmatter)"
        ),
        default=DEFAULT_SYSTEM_PROMPT,
    )
    parser.add_argument(
        "--ignore-task-sysprompt",
        action="store_true",
        help="Ignore system prompt from task template YAML frontmatter",
    )

    # Schema validation
    parser.add_argument(
        "--schema",
        dest="schema_file",
        required=True,
        help="JSON schema file for response validation",
    )
    parser.add_argument(
        "--validate-schema",
        action="store_true",
        help="Validate schema and response",
    )

    # Model configuration
    parser.add_argument(
        "--model",
        default="gpt-4o-2024-08-06",
        help="Model to use",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        default=0.0,
        help="Temperature (0.0 to 2.0)",
    )
    parser.add_argument(
        "--max-tokens",
        type=int,
        help="Maximum tokens to generate",
    )
    parser.add_argument(
        "--top-p",
        type=float,
        default=1.0,
        help="Top-p sampling (0.0 to 1.0)",
    )
    parser.add_argument(
        "--frequency-penalty",
        type=float,
        default=0.0,
        help="Frequency penalty (-2.0 to 2.0)",
    )
    parser.add_argument(
        "--presence-penalty",
        type=float,
        default=0.0,
        help="Presence penalty (-2.0 to 2.0)",
    )
    parser.add_argument(
        "--timeout",
        type=float,
        default=60.0,
        help="API timeout in seconds",
    )

    # Output options
    parser.add_argument(
        "--output-file",
        help="Write JSON output to file",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Simulate API call without making request",
    )
    parser.add_argument(
        "--no-progress",
        action="store_true",
        help="Disable progress indicators",
    )

    # Other options
    parser.add_argument(
        "--api-key",
        help="OpenAI API key (overrides env var)",
    )
    parser.add_argument(
        "--verbose",
        action="store_true",
        help="Enable verbose output",
    )
    parser.add_argument(
        "--debug-openai-stream",
        action="store_true",
        help="Enable low-level debug output for OpenAI streaming (very verbose)",
    )
    parser.add_argument(
        "--version",
        action="version",
        version=f"%(prog)s {__version__}",
    )

    return parser
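
# ---------------------------------------------------------------------------
# [Editor's note] Illustrative sketch, not part of the packaged file: a
# minimal argv that satisfies the two required flags. Paths and variable
# names are made up.
def _example_parse_args() -> argparse.Namespace:
    parser = create_argument_parser()
    return parser.parse_args(
        [
            "--task", "Summarize for {{ name }}: {{ stdin }}",
            "--schema", "schema.json",
            "--var", "name=Ops",
            "--dry-run",
        ]
    )
# ---------------------------------------------------------------------------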
1410
+
1411
+
1412
+ async def _main() -> ExitCode:
1413
+ """Main CLI function.
1414
+
1415
+ Returns:
1416
+ ExitCode: Exit code indicating success or failure
1417
+ """
1418
+ try:
1419
+ parser = create_argument_parser()
1420
+ args = parser.parse_args()
1421
+
1422
+ # Configure logging
1423
+ log_level = logging.DEBUG if args.verbose else logging.INFO
1424
+ logger.setLevel(log_level)
1425
+
1426
+ # Create security manager
1427
+ security_manager = validate_security_manager(
1428
+ base_dir=args.base_dir,
1429
+ allowed_dirs=args.allowed_dir,
1430
+ allowed_dirs_file=args.allowed_dirs_file,
1431
+ )
1432
+
1433
+ # Validate task template
1434
+ task_template = validate_task_template(args.task)
1435
+
1436
+ # Validate schema file
1437
+ schema = validate_schema_file(args.schema_file, args.verbose)
1438
+
1439
+ # Create template context
1440
+ template_context = create_template_context_from_args(
1441
+ args, security_manager
1442
+ )
1443
+
1444
+ # Create Jinja environment
1445
+ env = create_jinja_env()
1446
+
1447
+ # Process system prompt
1448
+ args.system_prompt = process_system_prompt(
1449
+ task_template,
1450
+ args.system_prompt,
1451
+ template_context,
1452
+ env,
1453
+ args.ignore_task_sysprompt,
1454
+ )
1455
+
1456
+ # Render task template
1457
+ rendered_task = render_template(task_template, template_context, env)
1458
+ logger.info(rendered_task) # Log the rendered template
1459
+
1460
+ # If dry run, exit here
1461
+ if args.dry_run:
1462
+ logger.info("DRY RUN MODE")
1463
+ return ExitCode.SUCCESS
1464
+
1465
+ # Load and validate schema
1466
+ try:
1467
+ logger.debug("[_main] Loading schema from %s", args.schema_file)
1468
+ schema = validate_schema_file(
1469
+ args.schema_file, verbose=args.verbose_schema
1470
+ )
1471
+ logger.debug("[_main] Creating output model")
1472
+ output_model = create_dynamic_model(
1473
+ schema,
1474
+ base_name="OutputModel",
1475
+ show_schema=args.show_model_schema,
1476
+ debug_validation=args.debug_validation,
1477
+ )
1478
+ logger.debug("[_main] Successfully created output model")
1479
+ except (SchemaFileError, InvalidJSONError, SchemaValidationError) as e:
1480
+ logger.error(str(e))
1481
+ return ExitCode.SCHEMA_ERROR
1482
+ except ModelCreationError as e:
1483
+ logger.error(f"Model creation error: {e}")
1484
+ return ExitCode.SCHEMA_ERROR
1485
+ except Exception as e:
1486
+ logger.error(f"Unexpected error creating model: {e}")
1487
+ return ExitCode.SCHEMA_ERROR
1488
+
1489
+ # Validate model support
1490
+ try:
1491
+ supports_structured_output(args.model)
1492
+ except ModelNotSupportedError as e:
1493
+ logger.error(str(e))
1494
+ return ExitCode.DATA_ERROR
1495
+ except ModelVersionError as e:
1496
+ logger.error(str(e))
1497
+ return ExitCode.DATA_ERROR
1498
+
1499
+ # Estimate token usage
1500
+ messages = [
1501
+ {"role": "system", "content": args.system_prompt},
1502
+ {"role": "user", "content": rendered_task},
1503
+ ]
1504
+ total_tokens = estimate_tokens_for_chat(messages, args.model)
1505
+ context_limit = get_context_window_limit(args.model)
1506
+
1507
+ if total_tokens > context_limit:
1508
+ logger.error(
1509
+ f"Total tokens ({total_tokens}) exceeds model context limit ({context_limit})"
1510
+ )
1511
+ return ExitCode.DATA_ERROR
1512
+
1513
+ # Get API key
1514
+ api_key = args.api_key or os.getenv("OPENAI_API_KEY")
1515
+ if not api_key:
1516
+ logger.error(
1517
+ "No OpenAI API key provided (--api-key or OPENAI_API_KEY env var)"
1518
+ )
1519
+ return ExitCode.USAGE_ERROR
1520
+
1521
+ # Create OpenAI client
1522
+ client = AsyncOpenAI(api_key=api_key, timeout=args.timeout)
1523
+
1524
+ # Create log callback that matches expected signature
1525
+ def log_callback(
1526
+ level: int, message: str, extra: dict[str, Any]
1527
+ ) -> None:
1528
+ # Only log if debug_openai_stream is enabled
1529
+ if args.debug_openai_stream:
1530
+ # Include extra dictionary in the message for both DEBUG and ERROR
1531
+ if extra: # Only add if there's actually extra data
1532
+ extra_str = json.dumps(extra, indent=2)
1533
+ message = f"{message}\nDetails:\n{extra_str}"
1534
+ openai_logger.log(level, message, extra=extra)
1535
+
1536
+ # Make API request
1537
+ try:
1538
+ logger.debug("Creating ProgressContext for API response handling")
1539
+ with ProgressContext(
1540
+ description="Processing API response",
1541
+ level=args.progress_level,
1542
+ ) as progress:
1543
+ logger.debug("Starting API response stream processing")
1544
+ logger.debug("Debug flag status: %s", args.debug_openai_stream)
1545
+ logger.debug("OpenAI logger level: %s", openai_logger.level)
1546
+ for handler in openai_logger.handlers:
1547
+ logger.debug("Handler level: %s", handler.level)
1548
+ async for chunk in async_openai_structured_stream(
1549
+ client=client,
1550
+ model=args.model,
1551
+ temperature=args.temperature,
1552
+ max_tokens=args.max_tokens,
1553
+ top_p=args.top_p,
1554
+ frequency_penalty=args.frequency_penalty,
1555
+ presence_penalty=args.presence_penalty,
1556
+ system_prompt=args.system_prompt,
1557
+ user_prompt=rendered_task,
1558
+ output_schema=output_model,
1559
+ timeout=args.timeout,
1560
+ on_log=log_callback,
1561
+ ):
1562
+ logger.debug("Received API response chunk")
1563
+ if not chunk:
1564
+ logger.debug("Empty chunk received, skipping")
1565
+ continue
1566
+
1567
+ # Write output
1568
+ try:
1569
+ logger.debug("Starting to process output chunk")
1570
+ dumped = chunk.model_dump(mode="json")
1571
+ logger.debug("Successfully dumped chunk to JSON")
1572
+ logger.debug("Dumped chunk: %s", dumped)
1573
+ logger.debug(
1574
+ "Chunk type: %s, length: %d",
1575
+ type(dumped),
1576
+ len(json.dumps(dumped)),
1577
+ )
1578
+
1579
+ if args.output_file:
1580
+ logger.debug(
1581
+ "Writing to output file: %s", args.output_file
1582
+ )
1583
+ try:
1584
+ with open(
1585
+ args.output_file, "a", encoding="utf-8"
1586
+ ) as f:
1587
+ json_str = json.dumps(dumped, indent=2)
1588
+ logger.debug(
1589
+ "Writing JSON string of length %d",
1590
+ len(json_str),
1591
+ )
1592
+ f.write(json_str)
1593
+ f.write("\n")
1594
+ logger.debug("Successfully wrote to file")
1595
+ except Exception as e:
1596
+ logger.error(
1597
+ "Failed to write to output file: %s", e
1598
+ )
1599
+ else:
1600
+ logger.debug(
1601
+ "About to call progress.print_output with JSON string"
1602
+ )
1603
+ json_str = json.dumps(dumped, indent=2)
1604
+ logger.debug(
1605
+ "JSON string length before print_output: %d",
1606
+ len(json_str),
1607
+ )
1608
+ logger.debug(
1609
+ "First 100 chars of JSON string: %s",
1610
+ json_str[:100] if json_str else "",
1611
+ )
1612
+ progress.print_output(json_str)
1613
+ logger.debug(
1614
+ "Completed print_output call for JSON string"
1615
+ )
1616
+
1617
+ logger.debug("Starting progress update")
1618
+ progress.update()
1619
+ logger.debug("Completed progress update")
1620
+ except Exception as e:
1621
+ logger.error("Failed to process chunk: %s", e)
1622
+ logger.error("Chunk: %s", chunk)
1623
+ continue
1624
+
1625
+ logger.debug("Finished processing API response stream")
1626
+
1627
+ except StreamInterruptedError as e:
1628
+ logger.error(f"Stream interrupted: {e}")
1629
+ return ExitCode.API_ERROR
1630
+ except StreamBufferError as e:
1631
+ logger.error(f"Stream buffer error: {e}")
1632
+ return ExitCode.API_ERROR
1633
+ except StreamParseError as e:
1634
+ logger.error(f"Stream parse error: {e}")
1635
+ return ExitCode.API_ERROR
1636
+ except APIResponseError as e:
1637
+ logger.error(f"API response error: {e}")
1638
+ return ExitCode.API_ERROR
1639
+ except EmptyResponseError as e:
1640
+ logger.error(f"Empty response error: {e}")
1641
+ return ExitCode.API_ERROR
1642
+ except InvalidResponseFormatError as e:
1643
+ logger.error(f"Invalid response format: {e}")
1644
+ return ExitCode.API_ERROR
1645
+ except (APIConnectionError, InternalServerError) as e:
1646
+ logger.error(f"API connection error: {e}")
1647
+ return ExitCode.API_ERROR
1648
+ except RateLimitError as e:
1649
+ logger.error(f"Rate limit exceeded: {e}")
1650
+ return ExitCode.API_ERROR
1651
+ except BadRequestError as e:
1652
+ logger.error(f"Bad request: {e}")
1653
+ return ExitCode.API_ERROR
1654
+ except AuthenticationError as e:
1655
+ logger.error(f"Authentication failed: {e}")
1656
+ return ExitCode.API_ERROR
1657
+ except OpenAIClientError as e:
1658
+ logger.error(f"OpenAI client error: {e}")
1659
+ return ExitCode.API_ERROR
1660
+ except Exception as e:
1661
+ logger.error(f"Unexpected error: {e}")
1662
+ return ExitCode.INTERNAL_ERROR
1663
+
1664
+ return ExitCode.SUCCESS
1665
+
1666
+ except KeyboardInterrupt:
1667
+ logger.error("Operation cancelled by user")
1668
+ return ExitCode.INTERRUPTED
1669
+ except PathSecurityError as e:
1670
+ # Only log security errors if they haven't been logged already
1671
+ logger.debug(
1672
+ "[_main] Caught PathSecurityError: %s (logged=%s)",
1673
+ str(e),
1674
+ getattr(e, "has_been_logged", False),
1675
+ )
1676
+ if not getattr(e, "has_been_logged", False):
1677
+ logger.error(str(e))
1678
+ return ExitCode.SECURITY_ERROR
1679
+ except ValueError as e:
1680
+ # Get the original cause of the error
1681
+ cause = e.__cause__ or e.__context__
1682
+ if isinstance(cause, PathSecurityError):
1683
+ logger.debug(
1684
+ "[_main] Caught wrapped PathSecurityError in ValueError: %s (logged=%s)",
1685
+ str(cause),
1686
+ getattr(cause, "has_been_logged", False),
1687
+ )
1688
+ # Only log security errors if they haven't been logged already
1689
+ if not getattr(cause, "has_been_logged", False):
1690
+ logger.error(str(cause))
1691
+ return ExitCode.SECURITY_ERROR
1692
+ else:
1693
+ logger.debug("[_main] Caught ValueError: %s", str(e))
1694
+ logger.error(f"Invalid input: {e}")
1695
+ return ExitCode.DATA_ERROR
1696
+ except Exception as e:
1697
+ # Check if this is a wrapped security error
1698
+ if isinstance(e.__cause__, PathSecurityError):
1699
+ logger.debug(
1700
+ "[_main] Caught wrapped PathSecurityError in Exception: %s (logged=%s)",
1701
+ str(e.__cause__),
1702
+ getattr(e.__cause__, "has_been_logged", False),
1703
+ )
1704
+ # Only log security errors if they haven't been logged already
1705
+ if not getattr(e.__cause__, "has_been_logged", False):
1706
+ logger.error(str(e.__cause__))
1707
+ return ExitCode.SECURITY_ERROR
1708
+ logger.debug("[_main] Caught unexpected error: %s", str(e))
1709
+ logger.error(f"Unexpected error: {e}")
1710
+ return ExitCode.INTERNAL_ERROR
1711
+
1712
+
1713
+ def main() -> None:
1714
+ """CLI entry point that handles all errors."""
1715
+ try:
1716
+ logger.debug("[main] Starting main execution")
1717
+ exit_code = asyncio.run(_main())
1718
+ sys.exit(exit_code.value)
1719
+ except KeyboardInterrupt:
1720
+ logger.error("Operation cancelled by user")
1721
+ sys.exit(ExitCode.INTERRUPTED.value)
1722
+ except PathSecurityError as e:
1723
+ # Only log security errors if they haven't been logged already
1724
+ logger.debug(
1725
+ "[main] Caught PathSecurityError: %s (logged=%s)",
1726
+ str(e),
1727
+ getattr(e, "has_been_logged", False),
1728
+ )
1729
+ if not getattr(e, "has_been_logged", False):
1730
+ logger.error(str(e))
1731
+ sys.exit(ExitCode.SECURITY_ERROR.value)
1732
+ except ValueError as e:
1733
+ # Get the original cause of the error
1734
+ cause = e.__cause__ or e.__context__
1735
+ if isinstance(cause, PathSecurityError):
1736
+ logger.debug(
1737
+ "[main] Caught wrapped PathSecurityError in ValueError: %s (logged=%s)",
1738
+ str(cause),
1739
+ getattr(cause, "has_been_logged", False),
1740
+ )
1741
+ # Only log security errors if they haven't been logged already
1742
+ if not getattr(cause, "has_been_logged", False):
1743
+ logger.error(str(cause))
1744
+ sys.exit(ExitCode.SECURITY_ERROR.value)
1745
+ else:
1746
+ logger.debug("[main] Caught ValueError: %s", str(e))
1747
+ logger.error(f"Invalid input: {e}")
1748
+ sys.exit(ExitCode.DATA_ERROR.value)
1749
+ except Exception as e:
1750
+ # Check if this is a wrapped security error
1751
+ if isinstance(e.__cause__, PathSecurityError):
1752
+ logger.debug(
1753
+ "[main] Caught wrapped PathSecurityError in Exception: %s (logged=%s)",
1754
+ str(e.__cause__),
1755
+ getattr(e.__cause__, "has_been_logged", False),
1756
+ )
1757
+ # Only log security errors if they haven't been logged already
1758
+ if not getattr(e.__cause__, "has_been_logged", False):
1759
+ logger.error(str(e.__cause__))
1760
+ sys.exit(ExitCode.SECURITY_ERROR.value)
1761
+ logger.debug("[main] Caught unexpected error: %s", str(e))
1762
+ logger.error(f"Unexpected error: {e}")
1763
+ sys.exit(ExitCode.INTERNAL_ERROR.value)
1764
+
1765
+
1766
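+ # Illustrative shell usage (hypothetical flag names; see
+ # create_argument_parser for the real options). The process exit status is
+ # the ExitCode value, so scripts can branch on specific failures:
+ #
+ #     ostruct --task task.j2 --schema schema.json || echo "failed: $?"
+
+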
+ # Export public API
1767
+ __all__ = [
1768
+ "ExitCode",
1769
+ "estimate_tokens_for_chat",
1770
+ "get_context_window_limit",
1771
+ "get_default_token_limit",
1772
+ "parse_json_var",
1773
+ "create_dynamic_model",
1774
+ "validate_path_mapping",
1775
+ "create_argument_parser",
1776
+ "main",
1777
+ ]
1778
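+ # Illustrative: a console-script entry point (hypothetical name) such as
+ #     ostruct = "ostruct.cli.cli:main"
+ # in pyproject.toml would invoke main() above.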
+
1779
+
1780
+ def create_dynamic_model(
1781
+ schema: Dict[str, Any],
1782
+ base_name: str = "DynamicModel",
1783
+ show_schema: bool = False,
1784
+ debug_validation: bool = False,
1785
+ ) -> Type[BaseModel]:
1786
+ """Create a Pydantic model from a JSON schema.
1787
+
1788
+ Args:
1789
+         schema: JSON schema dict; may be wrapped in a {"schema": ...} envelope
1790
+ base_name: Base name for the model
1791
+ show_schema: Whether to show the generated schema
1792
+ debug_validation: Whether to enable validation debugging
1793
+
1794
+ Returns:
1795
+ Generated Pydantic model class
1796
+
1797
+ Raises:
1798
+ ModelCreationError: When model creation fails
1799
+ SchemaValidationError: When schema is invalid
1800
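+
+     Example (illustrative sketch, assuming a minimal object schema):
+         >>> schema = {
+         ...     "type": "object",
+         ...     "properties": {"name": {"type": "string"}},
+         ...     "required": ["name"],
+         ... }
+         >>> Person = create_dynamic_model(schema, base_name="Person")
+         >>> Person(name="Ada").name
+         'Ada'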
+ """
1801
+ if debug_validation:
1802
+ logger.info("Creating dynamic model from schema:")
1803
+ logger.info(json.dumps(schema, indent=2))
1804
+
1805
+ try:
1806
+ # Extract required fields
1807
+ required: Set[str] = set(schema.get("required", []))
1808
+
1809
+ # Handle our wrapper format if present
1810
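+         # Wrapper format: {"schema": {...inner JSON Schema...}}; outer keys
+         # other than "schema" are ignored once the inner schema is adopted.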
+ if "schema" in schema:
1811
+ if debug_validation:
1812
+ logger.info("Found schema wrapper, extracting inner schema")
1813
+ logger.info(
1814
+ "Original schema: %s", json.dumps(schema, indent=2)
1815
+ )
1816
+ inner_schema = schema["schema"]
1817
+ if not isinstance(inner_schema, dict):
1818
+ if debug_validation:
1819
+ logger.info(
1820
+ "Inner schema must be a dictionary, got %s",
1821
+ type(inner_schema),
1822
+ )
1823
+ raise SchemaValidationError(
1824
+ "Inner schema must be a dictionary"
1825
+ )
1826
+ if debug_validation:
1827
+ logger.info("Using inner schema:")
1828
+ logger.info(json.dumps(inner_schema, indent=2))
1829
+             schema = inner_schema
+             # Re-extract required fields from the unwrapped schema; the set
+             # computed above came from the wrapper, so the inner schema's
+             # required list would otherwise be silently dropped.
+             required = set(schema.get("required", []))
1830
+
1831
+ # Ensure schema has type field
1832
+ if "type" not in schema:
1833
+ if debug_validation:
1834
+ logger.info("Schema missing type field, assuming object type")
1835
+ schema["type"] = "object"
1836
+
1837
+ # Validate root schema is object type
1838
+ if schema["type"] != "object":
1839
+ if debug_validation:
1840
+ logger.error(
1841
+ "Schema type must be 'object', got %s", schema["type"]
1842
+ )
1843
+ raise SchemaValidationError("Root schema must be of type 'object'")
1844
+
1845
+ # Create model configuration
1846
+ config = ConfigDict(
1847
+ title=schema.get("title", base_name),
1848
+ extra=(
1849
+ "forbid"
1850
+ if schema.get("additionalProperties") is False
1851
+ else "allow"
1852
+ ),
1853
+ validate_default=True,
1854
+ use_enum_values=True,
1855
+ arbitrary_types_allowed=True,
1856
+ json_schema_extra={
1857
+ k: v
1858
+ for k, v in schema.items()
1859
+ if k
1860
+ not in {
1861
+ "type",
1862
+ "properties",
1863
+ "required",
1864
+ "title",
1865
+ "description",
1866
+ "additionalProperties",
1867
+ "readOnly",
1868
+ }
1869
+ },
1870
+ )
1871
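+         # For example, {"additionalProperties": false} yields extra="forbid",
+         # so unexpected keys raise a ValidationError at parse time.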
+
1872
+ if debug_validation:
1873
+ logger.info("Created model configuration:")
1874
+ logger.info(" Title: %s", config.get("title"))
1875
+ logger.info(" Extra: %s", config.get("extra"))
1876
+ logger.info(
1877
+ " Validate Default: %s", config.get("validate_default")
1878
+ )
1879
+ logger.info(" Use Enum Values: %s", config.get("use_enum_values"))
1880
+ logger.info(
1881
+ " Arbitrary Types: %s", config.get("arbitrary_types_allowed")
1882
+ )
1883
+ logger.info(
1884
+ " JSON Schema Extra: %s", config.get("json_schema_extra")
1885
+ )
1886
+
1887
+ # Create field definitions
1888
+ field_definitions: Dict[str, FieldDefinition] = {}
1889
+ properties = schema.get("properties", {})
1890
+
1891
+ for field_name, field_schema in properties.items():
1892
+ try:
1893
+ if debug_validation:
1894
+ logger.info("Processing field %s:", field_name)
1895
+ logger.info(
1896
+ " Schema: %s", json.dumps(field_schema, indent=2)
1897
+ )
1898
+
1899
+ python_type, field = _get_type_with_constraints(
1900
+ field_schema, field_name, base_name
1901
+ )
1902
+
1903
+ # Handle optional fields
1904
+ if field_name not in required:
1905
+ if debug_validation:
1906
+ logger.info(
1907
+ "Field %s is optional, wrapping in Optional",
1908
+ field_name,
1909
+ )
1910
+ field_type = cast(Type[Any], Optional[python_type])
1911
+ else:
1912
+ field_type = python_type
1913
+ if debug_validation:
1914
+ logger.info("Field %s is required", field_name)
1915
+
1916
+ # Create field definition
1917
+ field_definitions[field_name] = (field_type, field)
1918
+
1919
+ if debug_validation:
1920
+ logger.info("Successfully created field definition:")
1921
+ logger.info(" Name: %s", field_name)
1922
+ logger.info(" Type: %s", str(field_type))
1923
+ logger.info(" Required: %s", field_name in required)
1924
+
1925
+ except (FieldDefinitionError, NestedModelError) as e:
1926
+ if debug_validation:
1927
+ logger.error("Error creating field %s:", field_name)
1928
+ logger.error(" Error type: %s", type(e).__name__)
1929
+ logger.error(" Error message: %s", str(e))
1930
+                 raise ModelValidationError(base_name, [str(e)]) from e
1931
+
1932
+ # Create the model with the fields
1933
+ model = create_model(
1934
+ base_name,
1935
+ __config__=config,
1936
+ **{
1937
+ name: (
1938
+ (
1939
+ cast(Type[Any], field_type)
1940
+ if is_container_type(field_type)
1941
+ else field_type
1942
+ ),
1943
+ field,
1944
+ )
1945
+ for name, (field_type, field) in field_definitions.items()
1946
+ },
1947
+ )
1948
+
1949
+         # Honor show_schema, which was previously accepted but never used.
+         if show_schema or debug_validation:
1950
+ logger.info("Successfully created model: %s", model.__name__)
1951
+ logger.info("Model config: %s", dict(model.model_config))
1952
+ logger.info(
1953
+ "Model schema: %s",
1954
+ json.dumps(model.model_json_schema(), indent=2),
1955
+ )
1956
+
1957
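+         # Generating the JSON schema forces Pydantic to resolve every field
+         # type, so malformed definitions surface here rather than at first use.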
+ # Validate the model's JSON schema
1958
+ try:
1959
+ model.model_json_schema()
1960
+ except ValidationError as e:
1961
+ if debug_validation:
1962
+ logger.error("Schema validation failed:")
1963
+ logger.error(" Error type: %s", type(e).__name__)
1964
+ logger.error(" Error message: %s", str(e))
1965
+ if hasattr(e, "errors"):
1966
+ logger.error(" Validation errors:")
1967
+ for error in e.errors():
1968
+ logger.error(" - %s", error)
1969
+ validation_errors = (
1970
+ [str(err) for err in e.errors()]
1971
+ if hasattr(e, "errors")
1972
+ else [str(e)]
1973
+ )
1974
+             raise ModelValidationError(base_name, validation_errors) from e
1975
+
1976
+ return cast(Type[BaseModel], model)
1977
+
1978
+     except (SchemaValidationError, ModelValidationError):
+         # Let the documented validation errors propagate instead of
+         # re-wrapping them as ModelCreationError below.
+         raise
+     except Exception as e:
1979
+ if debug_validation:
1980
+ logger.error("Failed to create model:")
1981
+ logger.error(" Error type: %s", type(e).__name__)
1982
+ logger.error(" Error message: %s", str(e))
1983
+             # __cause__, __context__, and __traceback__ always exist on
+             # exceptions, so check for a non-None value rather than presence.
+             if e.__cause__ is not None:
1984
+                 logger.error(" Caused by: %s", str(e.__cause__))
1985
+             if e.__context__ is not None:
1986
+                 logger.error(" Context: %s", str(e.__context__))
1987
+             if e.__traceback__ is not None:
1988
+ import traceback
1989
+
1990
+ logger.error(
1991
+ " Traceback:\n%s",
1992
+ "".join(traceback.format_tb(e.__traceback__)),
1993
+ )
1994
+         raise ModelCreationError(
1995
+             f"Failed to create model '{base_name}': {e}"
1996
+         ) from e
1997
+
1998
+
1999
+ # Validation functions
2000
+ def _fail(message: str) -> Any:
+     """Raise ValueError so a BeforeValidator rejects invalid input.
+
+     The previous validators returned None on failure, which silently
+     coerced invalid values to None instead of raising a validation error.
+     """
+     raise ValueError(message)
+
+
+ def pattern(regex: str) -> Any:
2001
+     return constr(pattern=regex)
2002
+
2003
+
2004
+ def min_length(length: int) -> Any:
2005
+     return BeforeValidator(
+         lambda v: v if len(str(v)) >= length else _fail(f"length must be >= {length}")
+     )
2006
+
2007
+
2008
+ def max_length(length: int) -> Any:
2009
+     return BeforeValidator(
+         lambda v: v if len(str(v)) <= length else _fail(f"length must be <= {length}")
+     )
2010
+
2011
+
2012
+ def ge(value: Union[int, float]) -> Any:
2013
+     return BeforeValidator(
+         lambda v: v if float(v) >= value else _fail(f"value must be >= {value}")
+     )
2014
+
2015
+
2016
+ def le(value: Union[int, float]) -> Any:
2017
+     return BeforeValidator(
+         lambda v: v if float(v) <= value else _fail(f"value must be <= {value}")
+     )
2018
+
2019
+
2020
+ def gt(value: Union[int, float]) -> Any:
2021
+     return BeforeValidator(
+         lambda v: v if float(v) > value else _fail(f"value must be > {value}")
+     )
2022
+
2023
+
2024
+ def lt(value: Union[int, float]) -> Any:
2025
+     return BeforeValidator(
+         lambda v: v if float(v) < value else _fail(f"value must be < {value}")
+     )
2026
+
2027
+
2028
+ def multiple_of(value: Union[int, float]) -> Any:
2029
+     # Note: float modulo is exact only for integral values; fractional
+     # multiples (e.g. 0.1) are subject to floating-point rounding.
+     return BeforeValidator(
+         lambda v: v if float(v) % value == 0 else _fail(f"value must be a multiple of {value}")
+     )
2030
+
2031
+
2032
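+ # Illustrative sketch (not part of the public API): these helpers return
+ # objects intended for use with typing.Annotated, e.g.
+ #
+ #     from typing import Annotated
+ #     Age = Annotated[int, ge(0), le(130)]
+ #
+ # Used as a model field type, a BeforeValidator that raises ValueError
+ # surfaces as an ordinary Pydantic validation error.
+
+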
+ if __name__ == "__main__":
2033
+ main()