ostruct-cli 0.3.0__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ostruct/cli/base_errors.py +183 -0
- ostruct/cli/cli.py +830 -585
- ostruct/cli/click_options.py +338 -211
- ostruct/cli/errors.py +214 -227
- ostruct/cli/exit_codes.py +18 -0
- ostruct/cli/file_info.py +126 -69
- ostruct/cli/file_list.py +191 -72
- ostruct/cli/file_utils.py +132 -97
- ostruct/cli/path_utils.py +86 -77
- ostruct/cli/security/__init__.py +32 -0
- ostruct/cli/security/allowed_checker.py +55 -0
- ostruct/cli/security/base.py +46 -0
- ostruct/cli/security/case_manager.py +75 -0
- ostruct/cli/security/errors.py +164 -0
- ostruct/cli/security/normalization.py +161 -0
- ostruct/cli/security/safe_joiner.py +211 -0
- ostruct/cli/security/security_manager.py +366 -0
- ostruct/cli/security/symlink_resolver.py +483 -0
- ostruct/cli/security/types.py +108 -0
- ostruct/cli/security/windows_paths.py +404 -0
- ostruct/cli/serialization.py +25 -0
- ostruct/cli/template_filters.py +13 -8
- ostruct/cli/template_rendering.py +46 -22
- ostruct/cli/template_utils.py +12 -4
- ostruct/cli/template_validation.py +26 -8
- ostruct/cli/token_utils.py +43 -0
- ostruct/cli/validators.py +109 -0
- {ostruct_cli-0.3.0.dist-info → ostruct_cli-0.5.0.dist-info}/METADATA +64 -24
- ostruct_cli-0.5.0.dist-info/RECORD +42 -0
- {ostruct_cli-0.3.0.dist-info → ostruct_cli-0.5.0.dist-info}/WHEEL +1 -1
- ostruct/cli/security.py +0 -964
- ostruct/cli/security_types.py +0 -46
- ostruct_cli-0.3.0.dist-info/RECORD +0 -28
- {ostruct_cli-0.3.0.dist-info → ostruct_cli-0.5.0.dist-info}/LICENSE +0 -0
- {ostruct_cli-0.3.0.dist-info → ostruct_cli-0.5.0.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,404 @@
|
|
1
|
+
"""Windows path handling and validation.
|
2
|
+
|
3
|
+
This module provides functions for handling Windows-specific path features:
|
4
|
+
- Device paths (r"\\\\?\\", r"\\\\.")
|
5
|
+
- Drive-relative paths (C:folder)
|
6
|
+
- Reserved names (CON, PRN, etc.)
|
7
|
+
- UNC paths (r"\\\\server\\share")
|
8
|
+
- Alternate Data Streams (file.txt:stream)
|
9
|
+
|
10
|
+
Security Design Choices:
|
11
|
+
1. Device Paths:
|
12
|
+
- Explicitly blocked for security
|
13
|
+
- No support for extended-length paths
|
14
|
+
- No direct device access allowed
|
15
|
+
|
16
|
+
2. Drive Paths:
|
17
|
+
- Drive-relative paths (e.g. C:folder) are rejected; the drive letter must be followed by a separator
|
18
|
+
- Drive absolute paths are allowed
|
19
|
+
- Drive letters must be A-Z (case insensitive)
|
20
|
+
|
21
|
+
3. Reserved Names:
|
22
|
+
- All Windows reserved names blocked
|
23
|
+
- Case-insensitive matching
|
24
|
+
- Blocked with or without extensions
|
25
|
+
|
26
|
+
4. UNC Paths:
|
27
|
+
- Must be complete (server and share)
|
28
|
+
- No device paths in UNC format
|
29
|
+
- Normalized to forward slashes
|
30
|
+
|
31
|
+
5. Alternate Data Streams:
|
32
|
+
- All ADS access is blocked
|
33
|
+
- No exceptions for Zone.Identifier
|
34
|
+
- Blocks both read and write
|
35
|
+
|
36
|
+
Known Limitations:
|
37
|
+
1. Path Length:
|
38
|
+
- No extended-length path support
|
39
|
+
- Standard Windows MAX_PATH limits
|
40
|
+
- No workarounds for long paths
|
41
|
+
|
42
|
+
2. Network:
|
43
|
+
- No special handling for DFS
|
44
|
+
- No support for administrative shares
|
45
|
+
- Basic UNC validation only
|
46
|
+
|
47
|
+
3. Security:
|
48
|
+
- Some rare path formats may bypass checks
|
49
|
+
- Complex NTFS features not handled
|
50
|
+
- Limited reparse point support
|
51
|
+
"""
|
52
|
+
|
53
|
+
import logging
|
54
|
+
import os
|
55
|
+
import re
|
56
|
+
from pathlib import Path, WindowsPath
|
57
|
+
from typing import Optional, Union
|
58
|
+
|
59
|
+
from .errors import PathSecurityError, SecurityErrorReasons
|
60
|
+
|
61
|
+
logger = logging.getLogger(__name__)
|
62
|
+
|
63
|
+
# Windows path length limits
# MAX_PATH is the limit the validation functions in this module compare against.
# EXTENDED_MAX_PATH is declared here but not referenced by the functions in this module.
MAX_PATH = 260
EXTENDED_MAX_PATH = 32767

# Regex patterns for Windows path features

# Device namespace prefix: two leading slashes (backslash or forward), then "?"
# or ".", then one slash. The negative lookahead leaves a following "UNC" + slash
# unmatched, so that form is not reported as a device path by this pattern.
_WINDOWS_DEVICE_PATH = re.compile(
    r"^(?:\\\\|//)[?.](?:\\|/)(?!UNC(?:\\|/))",  # Match device paths but exclude UNC
    flags=re.IGNORECASE,
)

# Drive letter plus colon that is NOT followed by a separator, either at the
# start of the string or directly after a separator.
_WINDOWS_DRIVE_RELATIVE = re.compile(
    r"(?:^|[/\\])[A-Za-z]:(?![/\\])|"  # C:folder or \C:folder but not C:\folder
    r"^/[A-Za-z]:(?![/\\])"  # /C:folder variants
)

# Anchored at both ends, so it is meant to be applied to a single path
# component; matching is case-insensitive.
_WINDOWS_RESERVED_NAMES = re.compile(
    r"^(CON|PRN|AUX|NUL|COM[1-9]|LPT[1-9])"  # Base names
    r"(\.[^\\/:*?\"<>|]*)?$",  # Optional extension
    re.IGNORECASE,
)

# Complete UNC path: two leading slashes of one kind, a server name that does
# not start with "?" or ".", a separator of the same kind, and a share name.
# Each alternative handles one slash style only; mixed styles do not match.
_WINDOWS_UNC = re.compile(
    r"^\\\\[^?.\\/][^\\/]*\\[^\\/]+(?:\\.*)?|"  # \\server\share[\anything]
    r"^//[^?./][^/]*/[^/]+(?:/.*)?$"  # //server/share[/anything]
)

# UNC prefix followed by a server name and at most one more segment, with
# nothing after it. Note that a bare "\\server\share" also satisfies this pattern.
_WINDOWS_INCOMPLETE_UNC = re.compile(
    r"^\\\\[^?.\\/][^\\/]*(?:\\[^\\/]+)?$|"  # \\server or \\server\incomplete
    r"^//[^?./][^/]*(?:/[^/]+)?$"  # //server or //server/incomplete variants
)

# Colon followed by a stream name running to the end of the string.
_WINDOWS_ADS = re.compile(
    r":[^/\\<>:\"|?*]+$|"  # Basic ADS
    r":Zone\.Identifier$|"  # Zone.Identifier
    r":[^/\\<>:\"|?*]+:[^/\\]+$"  # Multiple stream segments
)

# Characters rejected inside a path component. The lookbehind exempts a colon
# that directly follows a single letter at the very start of the string.
_WINDOWS_INVALID_CHARS = re.compile(
    r'[<>"|?*]|'  # Standard invalid chars except colon
    r"(?<!^[A-Za-z]):|"  # Colon except after drive letter at start
    r"[\x00-\x1F]"  # Control chars
)

_WINDOWS_TRAILING = re.compile(r"[. ]+$")  # Trailing dots/spaces
|
107
|
+
|
108
|
+
|
109
|
+
def is_windows_path(path: Union[str, Path]) -> bool:
    """Check if path uses Windows-specific features.

    The check is purely textual: the path is never resolved or accessed, and
    backslashes are treated as separators regardless of the host OS.

    Args:
        path: Path (or string) to inspect.

    Returns:
        True if the path is a device path, a drive-relative path, a UNC path,
        ends in an Alternate Data Stream suffix, or its final component is a
        reserved device name; False otherwise.

    Security Note:
        - Detects device paths (r"\\?\\" and r"\\.\\") in both slash formats
        - Case insensitive to handle drive letters
    """
    # Local import: ntpath applies Windows splitting rules on every platform,
    # whereas os.path.basename ignores backslashes on POSIX hosts.
    import ntpath

    path_str = str(path)

    # Normalize slashes for consistent matching
    normalized_path = path_str.replace("\\", "/")

    # Check for device paths first before any processing
    if _WINDOWS_DEVICE_PATH.match(path_str) or _WINDOWS_DEVICE_PATH.match(
        normalized_path
    ):
        logger.debug("Windows device path detected: %r", path_str)
        return True

    # Final component under Windows rules, so "dir\\CON" yields "CON" everywhere
    basename = ntpath.basename(path_str)
    is_drive_relative = bool(_WINDOWS_DRIVE_RELATIVE.search(path_str))
    # Each UNC alternative handles one slash style; testing the normalized
    # form as well catches mixed-slash spellings such as "\\\\server/share".
    is_unc = bool(
        _WINDOWS_UNC.search(path_str) or _WINDOWS_UNC.search(normalized_path)
    )
    is_ads = bool(_WINDOWS_ADS.search(path_str))
    is_reserved = bool(_WINDOWS_RESERVED_NAMES.match(basename))

    logger.debug(
        "Windows path check for %r: drive_relative=%s, unc=%s, ads=%s, reserved=%s",
        path_str,
        is_drive_relative,
        is_unc,
        is_ads,
        is_reserved,
    )

    return bool(is_drive_relative or is_unc or is_ads or is_reserved)
|
145
|
+
|
146
|
+
|
147
|
+
def normalize_windows_path(path: Union[str, Path]) -> Path:
    """Normalize a path using Windows-specific rules.

    This function:
    1. Converts every backslash to a forward slash
    2. Collapses runs of slashes while keeping a leading double slash intact
    3. On Windows, resolves the result with ``resolve()`` and rejects it when
       the resolved string starts with the extended-length prefix or is
       longer than MAX_PATH
    4. On other platforms, applies ``os.path.normpath`` only: the result is
       lexically normalized, not made absolute, and the filesystem is not
       consulted

    Args:
        path: Path to normalize

    Returns:
        Normalized Path

    Raises:
        PathSecurityError: If path cannot be normalized, or (on Windows) if
            the resolved path is too long. Unexpected errors are wrapped and
            chained as the cause.
    """
    try:
        logger.debug("Normalizing Windows path: %r", path)

        # Convert to string and normalize all slashes to forward slashes first
        path_str = str(path)
        # Replace all backslashes with forward slashes for consistent handling
        path_str = path_str.replace("\\", "/")
        # Collapse multiple slashes to single slash, except for UNC prefixes
        path_str = re.sub(r"(?<!^)//+", "/", path_str)
        logger.debug("Normalized slashes: %r", path_str)

        # Use regular Path on non-Windows systems
        path_cls = WindowsPath if os.name == "nt" else Path

        # Convert back to backslashes for Windows path handling
        if os.name == "nt":
            path_str = path_str.replace("/", "\\")
            # Preserve UNC path double backslashes
            # NOTE(review): the collapse above already keeps a leading "//",
            # so this branch only fires for a single leading backslash and
            # turns "\foo" into "\\foo" - confirm that is intended.
            if path_str.startswith("\\") and not path_str.startswith("\\\\"):
                path_str = "\\" + path_str

        normalized = path_cls(path_str)
        logger.debug(
            "Created path object: %r (class=%s)", normalized, path_cls.__name__
        )

        if os.name == "nt":
            # Check if resolve() would exceed MAX_PATH
            resolved = normalized.resolve()
            resolved_str = str(resolved)
            # If resolve() added \\?\ prefix or path is too long, reject it
            if (
                resolved_str.startswith("\\\\?\\")
                or len(resolved_str) > MAX_PATH
            ):
                raise PathSecurityError(
                    f"Path would exceed maximum length of {MAX_PATH} characters after resolution",
                    path=str(path),
                    context={
                        "reason": SecurityErrorReasons.NORMALIZATION_ERROR
                    },
                )
            normalized = resolved
            logger.debug("Resolved on Windows: %r", normalized)
        else:
            # On non-Windows, just normalize the path
            normalized = Path(os.path.normpath(path_str))
            logger.debug("Normalized on non-Windows: %r", normalized)

        return normalized
    except PathSecurityError:
        # Already carries the right reason; do not wrap it a second time
        raise
    except Exception as e:
        logger.error(
            "Failed to normalize Windows path %r: %s", path, e, exc_info=True
        )
        # Chain the original exception so the root cause stays in the traceback
        raise PathSecurityError(
            f"Failed to normalize Windows path: {e}",
            path=str(path),
            context={"reason": SecurityErrorReasons.NORMALIZATION_ERROR},
        ) from e
|
226
|
+
|
227
|
+
|
228
|
+
def validate_windows_path(path: Union[str, Path]) -> Optional[str]:
    """Validate a path for Windows-specific security issues.

    Performs checks in order:
    1. Device paths (blocked), on the raw and slash-normalized string
    2. Incomplete UNC paths, on the raw and slash-normalized string
    3. Path normalization via normalize_windows_path
    4. Length, drive-relative, UNC and Alternate Data Stream checks on the
       normalized string
    5. Per-component checks (reserved names, invalid characters, trailing
       dots or spaces)

    Returns an error message if the path:
    - Uses device paths (r"\\\\?\\", r"\\\\.")
    - Uses drive-relative paths (C:folder)
    - Contains reserved names (CON, PRN, etc.)
    - Uses UNC paths (r"\\\\server\\share")
    - Contains Alternate Data Streams (file.txt:stream)
    - Exceeds maximum path length
    - Contains invalid characters
    - Has trailing dots or spaces

    Args:
        path: Path (or string) to validate.

    Returns:
        A human-readable error message for the first failed check, or None if
        the path is valid. This function reports problems through its return
        value; it does not raise for invalid paths.
    """
    logger.debug("Validating Windows path: %r", path)

    # Initial checks on raw path string
    path_str = str(path)

    # Normalize slashes for consistent matching
    normalized_path = path_str.replace("\\", "/")

    # Check for device paths before any processing
    if _WINDOWS_DEVICE_PATH.match(path_str) or _WINDOWS_DEVICE_PATH.match(
        normalized_path
    ):
        logger.debug("Device path detected in original path: %r", path_str)
        return "Device paths not allowed"

    # Check for incomplete UNC paths before normalization. The pattern only
    # understands one slash style per alternative, so the slash-normalized
    # form is tested too; otherwise mixed spellings such as "\\/server"
    # would not be recognised here.
    if _WINDOWS_INCOMPLETE_UNC.search(
        path_str
    ) or _WINDOWS_INCOMPLETE_UNC.search(normalized_path):
        logger.debug("Incomplete UNC path detected: %r", path_str)
        return "Incomplete UNC path"

    # Then normalize the path for other checks
    try:
        normalized_str = str(normalize_windows_path(path))
        logger.debug("Normalized path: %r", normalized_str)

        # Check for device paths again after normalization
        if _WINDOWS_DEVICE_PATH.match(normalized_str):
            logger.debug(
                "Device path detected after normalization: %r", normalized_str
            )
            return "Device paths not allowed"

    except PathSecurityError as e:
        logger.debug("Path normalization failed: %s", e)
        return str(e)

    # Check path length
    if len(normalized_str) > MAX_PATH:
        msg = f"Path exceeds maximum length of {MAX_PATH} characters"
        logger.debug("Path too long: %s", msg)
        return msg

    if _WINDOWS_DRIVE_RELATIVE.search(normalized_str):
        logger.debug("Drive-relative path detected: %r", normalized_str)
        return "Drive-relative paths must include separator"

    # Check for complete UNC paths
    if _WINDOWS_UNC.search(normalized_str):
        logger.debug("UNC path detected: %r", normalized_str)
        return "UNC paths not allowed"

    if _WINDOWS_ADS.search(normalized_str):
        logger.debug("Alternate Data Stream detected: %r", normalized_str)
        return "Alternate Data Streams not allowed"

    # Check each path component
    try:
        parts = (
            Path(normalized_str).parts
            if os.name != "nt"
            else WindowsPath(normalized_str).parts
        )
        logger.debug("Path components: %r", parts)

        for part in parts:
            # Check for reserved names
            if _WINDOWS_RESERVED_NAMES.match(part):
                logger.debug("Reserved name detected: %r", part)
                return "Windows reserved names not allowed"

            # Check for invalid characters
            if _WINDOWS_INVALID_CHARS.search(part):
                msg = f"Invalid characters in path component '{part}'"
                logger.debug("Invalid characters: %s", msg)
                return msg

            # Check for trailing dots/spaces
            if _WINDOWS_TRAILING.search(part):
                msg = f"Trailing dots or spaces not allowed in '{part}'"
                logger.debug("Trailing dots/spaces: %s", msg)
                return msg
    except Exception as e:
        # Any failure while splitting the path is reported as a validation
        # error, so an unparseable path is never treated as valid.
        logger.error("Failed to check path components: %s", e, exc_info=True)
        return f"Failed to validate path components: {e}"

    logger.debug("Path validation successful: %r", normalized_str)
    return None
|
335
|
+
|
336
|
+
|
337
|
+
def resolve_windows_symlink(path: Path) -> Optional[Path]:
    """Resolve a Windows symlink, refusing other kinds of reparse point.

    Only regular symbolic links are resolved. Mount points/junctions and any
    other reparse point are detected so that a warning can be logged, but
    they are never followed.

    Args:
        path: The path to resolve.

    Returns:
        The link target exactly as reported by ``os.readlink`` (it is not
        resolved further) if ``path`` is a symlink. None when not running on
        Windows, when ``path`` is not a symlink, or when an OSError occurs.

    Note:
        This function requires Windows and elevated privileges for some
        reparse point operations.

    Security Note:
        By default, this function only handles regular symlinks.
        For security reasons, other reparse points (junctions, mount points)
        are not resolved by default as they can bypass directory restrictions.
        If you need to handle these, implement proper security checks in the
        calling code.
    """
    if os.name != "nt":
        return None

    try:
        # Try to resolve as a regular symlink first
        if path.is_symlink():
            target = Path(os.readlink(path))
            logger.debug("Resolved symlink %r to %r", path, target)
            return target

        # Check if it's a reparse point but not a symlink
        # This requires using Windows APIs, so we just warn about it
        if hasattr(path, "is_mount") and path.is_mount():
            logger.warning(
                "Path %r is a mount point/junction - not resolving for security",
                path,
            )
            return None

        # For any other reparse points, log a warning
        try:
            import ctypes

            attrs = ctypes.windll.kernel32.GetFileAttributesW(str(path))  # type: ignore[attr-defined]
            is_reparse = bool(
                attrs != -1 and attrs & 0x400
            )  # FILE_ATTRIBUTE_REPARSE_POINT
            if is_reparse:
                logger.warning(
                    "Path %r is a reparse point - not resolving for security",
                    path,
                )
                return None
        except Exception as e:
            # The probe is best-effort and only feeds the warning above, so a
            # failure must not abort the call - but record it rather than
            # discarding it silently.
            logger.debug(
                "Could not query reparse attributes for %r: %s", path, e
            )

        return None

    except OSError as e:
        logger.debug("Failed to resolve Windows symlink %r: %s", path, e)
        return None
|
@@ -0,0 +1,25 @@
|
|
1
|
+
"""Serialization utilities for CLI logging."""
|
2
|
+
|
3
|
+
import json
|
4
|
+
from typing import Any, Dict
|
5
|
+
|
6
|
+
|
7
|
+
class LogSerializer:
    """Helpers that turn structured log payloads into printable text."""

    @staticmethod
    def serialize_log_extra(extra: Dict[str, Any]) -> str:
        """Render a log record's extra data as text.

        Args:
            extra: Dictionary of extra log data

        Returns:
            Indented JSON when the data can be encoded (values JSON does not
            know are passed through ``str``); otherwise ``str(extra)``.
        """
        try:
            rendered = json.dumps(extra, default=str, indent=2)
        except Exception:
            # JSON encoding failed (e.g. unsupported key type): use the plain repr
            rendered = str(extra)
        return rendered
|
ostruct/cli/template_filters.py
CHANGED
@@ -10,7 +10,7 @@ from collections import Counter
|
|
10
10
|
from typing import Any, Dict, List, Optional, Sequence, TypeVar, Union
|
11
11
|
|
12
12
|
import tiktoken
|
13
|
-
from jinja2 import Environment
|
13
|
+
from jinja2 import Environment, pass_context
|
14
14
|
from pygments import highlight
|
15
15
|
from pygments.formatters import HtmlFormatter, NullFormatter, TerminalFormatter
|
16
16
|
from pygments.lexers import TextLexer, get_lexer_by_name, guess_lexer
|
@@ -178,10 +178,12 @@ def format_error(e: Exception) -> str:
|
|
178
178
|
return f"{type(e).__name__}: {str(e)}"
|
179
179
|
|
180
180
|
|
181
|
-
|
181
|
+
@pass_context
|
182
|
+
def estimate_tokens(context: Any, text: str) -> int:
|
182
183
|
"""Estimate number of tokens in text."""
|
183
184
|
try:
|
184
|
-
encoding
|
185
|
+
# Use o200k_base encoding for token estimation
|
186
|
+
encoding = tiktoken.get_encoding("o200k_base")
|
185
187
|
return len(encoding.encode(str(text)))
|
186
188
|
except Exception as e:
|
187
189
|
logger.warning(f"Failed to estimate tokens: {e}")
|
@@ -544,9 +546,9 @@ def format_code(
|
|
544
546
|
"""Format code with syntax highlighting.
|
545
547
|
|
546
548
|
Args:
|
547
|
-
text
|
548
|
-
output_format
|
549
|
-
language
|
549
|
+
text: The code text to format
|
550
|
+
output_format: The output format ('terminal', 'html', or 'plain')
|
551
|
+
language: The programming language for syntax highlighting
|
550
552
|
|
551
553
|
Returns:
|
552
554
|
str: Formatted code string
|
@@ -580,10 +582,13 @@ def format_code(
|
|
580
582
|
else: # plain
|
581
583
|
formatter = NullFormatter[str]()
|
582
584
|
|
583
|
-
|
585
|
+
result = highlight(text, lexer, formatter)
|
586
|
+
if isinstance(result, bytes):
|
587
|
+
return result.decode("utf-8")
|
588
|
+
return str(result)
|
584
589
|
except Exception as e:
|
585
590
|
logger.error(f"Error formatting code: {e}")
|
586
|
-
return text
|
591
|
+
return str(text)
|
587
592
|
|
588
593
|
|
589
594
|
def register_template_filters(env: Environment) -> None:
|
@@ -61,8 +61,9 @@ from typing import Any, Dict, List, Optional, Union
|
|
61
61
|
import jinja2
|
62
62
|
from jinja2 import Environment
|
63
63
|
|
64
|
-
from .errors import TemplateValidationError
|
64
|
+
from .errors import TaskTemplateVariableError, TemplateValidationError
|
65
65
|
from .file_utils import FileInfo
|
66
|
+
from .progress import ProgressContext
|
66
67
|
from .template_env import create_jinja_env
|
67
68
|
from .template_schema import DotDict, StdinProxy
|
68
69
|
|
@@ -92,23 +93,23 @@ TemplateContextValue = Union[
|
|
92
93
|
def render_template(
|
93
94
|
template_str: str,
|
94
95
|
context: Dict[str, Any],
|
95
|
-
|
96
|
-
|
96
|
+
env: Optional[Environment] = None,
|
97
|
+
progress: Optional[ProgressContext] = None,
|
97
98
|
) -> str:
|
98
|
-
"""Render a
|
99
|
+
"""Render a template with the given context.
|
99
100
|
|
100
101
|
Args:
|
101
|
-
template_str:
|
102
|
-
context:
|
103
|
-
|
104
|
-
|
102
|
+
template_str: Template string to render
|
103
|
+
context: Context dictionary for template variables
|
104
|
+
env: Optional Jinja2 environment to use
|
105
|
+
progress: Optional progress bar to update
|
105
106
|
|
106
107
|
Returns:
|
107
|
-
|
108
|
+
str: The rendered template string
|
108
109
|
|
109
110
|
Raises:
|
110
|
-
|
111
|
-
|
111
|
+
TaskTemplateVariableError: If template variables are undefined
|
112
|
+
TemplateValidationError: If template rendering fails for other reasons
|
112
113
|
"""
|
113
114
|
from .progress import ( # Import here to avoid circular dependency
|
114
115
|
ProgressContext,
|
@@ -116,16 +117,14 @@ def render_template(
|
|
116
117
|
|
117
118
|
with ProgressContext(
|
118
119
|
description="Rendering task template",
|
119
|
-
level="basic" if
|
120
|
+
level="basic" if progress else "none",
|
120
121
|
) as progress:
|
121
122
|
try:
|
122
123
|
if progress:
|
123
124
|
progress.update(1) # Update progress for setup
|
124
125
|
|
125
|
-
if
|
126
|
-
|
127
|
-
loader=jinja2.FileSystemLoader(".")
|
128
|
-
)
|
126
|
+
if env is None:
|
127
|
+
env = create_jinja_env(loader=jinja2.FileSystemLoader("."))
|
129
128
|
|
130
129
|
logger.debug("=== Raw Input ===")
|
131
130
|
logger.debug(
|
@@ -154,7 +153,7 @@ def render_template(
|
|
154
153
|
# Wrap JSON variables in DotDict and handle special cases
|
155
154
|
wrapped_context: Dict[str, TemplateContextValue] = {}
|
156
155
|
for key, value in context.items():
|
157
|
-
if isinstance(value, dict):
|
156
|
+
if isinstance(value, dict) and not isinstance(value, DotDict):
|
158
157
|
wrapped_context[key] = DotDict(value)
|
159
158
|
else:
|
160
159
|
wrapped_context[key] = value
|
@@ -188,7 +187,7 @@ def render_template(
|
|
188
187
|
f"Task template file not found: {template_str}"
|
189
188
|
)
|
190
189
|
try:
|
191
|
-
template =
|
190
|
+
template = env.get_template(template_str)
|
192
191
|
except jinja2.TemplateNotFound as e:
|
193
192
|
raise TemplateValidationError(
|
194
193
|
f"Task template file not found: {e.name}"
|
@@ -199,10 +198,10 @@ def render_template(
|
|
199
198
|
template_str,
|
200
199
|
)
|
201
200
|
try:
|
202
|
-
template =
|
201
|
+
template = env.from_string(template_str)
|
203
202
|
|
204
203
|
# Add debug log for loop rendering
|
205
|
-
def debug_file_render(f: FileInfo) ->
|
204
|
+
def debug_file_render(f: FileInfo) -> Any:
|
206
205
|
logger.info("Rendering file: %s", f.path)
|
207
206
|
return ""
|
208
207
|
|
@@ -292,6 +291,10 @@ def render_template(
|
|
292
291
|
" %s: %s (%r)", key, type(value).__name__, value
|
293
292
|
)
|
294
293
|
result = template.render(**wrapped_context)
|
294
|
+
if not isinstance(result, str):
|
295
|
+
raise TemplateValidationError(
|
296
|
+
f"Template rendered to non-string type: {type(result)}"
|
297
|
+
)
|
295
298
|
logger.info(
|
296
299
|
"Template render result (first 100 chars): %r",
|
297
300
|
result[:100],
|
@@ -302,7 +305,17 @@ def render_template(
|
|
302
305
|
)
|
303
306
|
if progress:
|
304
307
|
progress.update(1)
|
305
|
-
return result
|
308
|
+
return result
|
309
|
+
except jinja2.UndefinedError as e:
|
310
|
+
# Extract variable name from error message
|
311
|
+
var_name = str(e).split("'")[1]
|
312
|
+
error_msg = (
|
313
|
+
f"Missing required template variable: {var_name}\n"
|
314
|
+
f"Available variables: {', '.join(sorted(context.keys()))}\n"
|
315
|
+
"To fix this, please provide the variable using:\n"
|
316
|
+
f" -V {var_name}='value'"
|
317
|
+
)
|
318
|
+
raise TaskTemplateVariableError(error_msg) from e
|
306
319
|
except (jinja2.TemplateError, Exception) as e:
|
307
320
|
logger.error("Template rendering failed: %s", str(e))
|
308
321
|
raise TemplateValidationError(
|
@@ -336,4 +349,15 @@ def render_template_file(
|
|
336
349
|
"""
|
337
350
|
with open(template_path, "r", encoding="utf-8") as f:
|
338
351
|
template_str = f.read()
|
339
|
-
|
352
|
+
|
353
|
+
# Create a progress context if enabled
|
354
|
+
progress = (
|
355
|
+
ProgressContext(
|
356
|
+
description="Rendering template file",
|
357
|
+
level="basic" if progress_enabled else "none",
|
358
|
+
)
|
359
|
+
if progress_enabled
|
360
|
+
else None
|
361
|
+
)
|
362
|
+
|
363
|
+
return render_template(template_str, context, jinja_env, progress)
|
ostruct/cli/template_utils.py
CHANGED
@@ -9,10 +9,11 @@ It re-exports the public APIs from specialized modules:
|
|
9
9
|
- template_io: File I/O operations and metadata extraction
|
10
10
|
"""
|
11
11
|
|
12
|
+
import logging
|
12
13
|
from typing import Any, Dict, List, Optional, Set
|
13
14
|
|
14
15
|
import jsonschema
|
15
|
-
from jinja2 import Environment, meta
|
16
|
+
from jinja2 import Environment, TemplateSyntaxError, meta
|
16
17
|
from jinja2.nodes import Node
|
17
18
|
|
18
19
|
from .errors import (
|
@@ -35,6 +36,8 @@ from .template_schema import (
|
|
35
36
|
)
|
36
37
|
from .template_validation import SafeUndefined, validate_template_placeholders
|
37
38
|
|
39
|
+
logger = logging.getLogger(__name__)
|
40
|
+
|
38
41
|
|
39
42
|
# Custom error classes
|
40
43
|
class TemplateMetadataError(TaskTemplateError):
|
@@ -118,8 +121,13 @@ def find_all_template_variables(
|
|
118
121
|
if env is None:
|
119
122
|
env = Environment(undefined=SafeUndefined)
|
120
123
|
|
121
|
-
|
122
|
-
|
124
|
+
try:
|
125
|
+
ast = env.parse(template)
|
126
|
+
except TemplateSyntaxError as e:
|
127
|
+
logger.error("Failed to parse template: %s", str(e))
|
128
|
+
return set() # Return empty set on parse error
|
129
|
+
|
130
|
+
variables: Set[str] = set()
|
123
131
|
|
124
132
|
# Find all variables in this template
|
125
133
|
variables = meta.find_undeclared_variables(ast)
|
@@ -252,7 +260,7 @@ def find_all_template_variables(
|
|
252
260
|
visit_nodes(node.else_)
|
253
261
|
|
254
262
|
visit_nodes(ast.body)
|
255
|
-
return variables
|
263
|
+
return variables
|
256
264
|
|
257
265
|
|
258
266
|
__all__ = [
|