ostruct-cli 0.3.0__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35) hide show
  1. ostruct/cli/base_errors.py +183 -0
  2. ostruct/cli/cli.py +830 -585
  3. ostruct/cli/click_options.py +338 -211
  4. ostruct/cli/errors.py +214 -227
  5. ostruct/cli/exit_codes.py +18 -0
  6. ostruct/cli/file_info.py +126 -69
  7. ostruct/cli/file_list.py +191 -72
  8. ostruct/cli/file_utils.py +132 -97
  9. ostruct/cli/path_utils.py +86 -77
  10. ostruct/cli/security/__init__.py +32 -0
  11. ostruct/cli/security/allowed_checker.py +55 -0
  12. ostruct/cli/security/base.py +46 -0
  13. ostruct/cli/security/case_manager.py +75 -0
  14. ostruct/cli/security/errors.py +164 -0
  15. ostruct/cli/security/normalization.py +161 -0
  16. ostruct/cli/security/safe_joiner.py +211 -0
  17. ostruct/cli/security/security_manager.py +366 -0
  18. ostruct/cli/security/symlink_resolver.py +483 -0
  19. ostruct/cli/security/types.py +108 -0
  20. ostruct/cli/security/windows_paths.py +404 -0
  21. ostruct/cli/serialization.py +25 -0
  22. ostruct/cli/template_filters.py +13 -8
  23. ostruct/cli/template_rendering.py +46 -22
  24. ostruct/cli/template_utils.py +12 -4
  25. ostruct/cli/template_validation.py +26 -8
  26. ostruct/cli/token_utils.py +43 -0
  27. ostruct/cli/validators.py +109 -0
  28. {ostruct_cli-0.3.0.dist-info → ostruct_cli-0.5.0.dist-info}/METADATA +64 -24
  29. ostruct_cli-0.5.0.dist-info/RECORD +42 -0
  30. {ostruct_cli-0.3.0.dist-info → ostruct_cli-0.5.0.dist-info}/WHEEL +1 -1
  31. ostruct/cli/security.py +0 -964
  32. ostruct/cli/security_types.py +0 -46
  33. ostruct_cli-0.3.0.dist-info/RECORD +0 -28
  34. {ostruct_cli-0.3.0.dist-info → ostruct_cli-0.5.0.dist-info}/LICENSE +0 -0
  35. {ostruct_cli-0.3.0.dist-info → ostruct_cli-0.5.0.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,404 @@
1
+ """Windows path handling and validation.
2
+
3
+ This module provides functions for handling Windows-specific path features:
4
+ - Device paths (r"\\\\?\\", r"\\\\.")
5
+ - Drive-relative paths (C:folder)
6
+ - Reserved names (CON, PRN, etc.)
7
+ - UNC paths (r"\\\\server\\share")
8
+ - Alternate Data Streams (file.txt:stream)
9
+
10
+ Security Design Choices:
11
+ 1. Device Paths:
12
+ - Explicitly blocked for security
13
+ - No support for extended-length paths
14
+ - No direct device access allowed
15
+
16
+ 2. Drive Paths:
17
+ - Drive-relative paths must include separator
18
+ - Drive absolute paths are allowed
19
+ - Drive letters must be A-Z (case insensitive)
20
+
21
+ 3. Reserved Names:
22
+ - All Windows reserved names blocked
23
+ - Case-insensitive matching
24
+ - Blocked with or without extensions
25
+
26
+ 4. UNC Paths:
27
+ - Rejected by validate_windows_path, whether complete (server and share) or incomplete
28
+ - No device paths in UNC format
29
+ - Normalized to forward slashes
30
+
31
+ 5. Alternate Data Streams:
32
+ - All ADS access is blocked
33
+ - No exceptions for Zone.Identifier
34
+ - Blocks both read and write
35
+
36
+ Known Limitations:
37
+ 1. Path Length:
38
+ - No extended-length path support
39
+ - Standard Windows MAX_PATH limits
40
+ - No workarounds for long paths
41
+
42
+ 2. Network:
43
+ - No special handling for DFS
44
+ - No support for administrative shares
45
+ - Basic UNC validation only
46
+
47
+ 3. Security:
48
+ - Some rare path formats may bypass checks
49
+ - Complex NTFS features not handled
50
+ - Limited reparse point support
51
+ """
52
+
53
+ import logging
54
+ import os
55
+ import re
56
+ from pathlib import Path, WindowsPath
57
+ from typing import Optional, Union
58
+
59
+ from .errors import PathSecurityError, SecurityErrorReasons
60
+
61
logger = logging.getLogger(__name__)

# --- Length limits --------------------------------------------------------
# Classic Win32 MAX_PATH and the extended-length path limit.
MAX_PATH = 260
EXTENDED_MAX_PATH = 32767

# --- Compiled patterns for Windows path features --------------------------
# Unless noted otherwise, each pattern accepts both "\" and "/" separators.

# Device-namespace prefixes "\\?\" and "\\.\" (either slash style). The
# negative lookahead excludes the "\\?\UNC\" form.
_WINDOWS_DEVICE_PATH = re.compile(
    r"^(?:\\\\|//)[?.](?:\\|/)(?!UNC(?:\\|/))",
    flags=re.IGNORECASE,
)

# Drive-relative forms such as "C:folder", "\C:folder" or "/C:folder";
# a drive-absolute "C:\folder" is not matched. Note that it also fires on
# any component (after a separator) that starts with "<letter>:" not
# followed by a separator.
_WINDOWS_DRIVE_RELATIVE = re.compile(
    r"(?:^|[/\\])[A-Za-z]:(?![/\\])|"
    r"^/[A-Za-z]:(?![/\\])"
)

# Reserved DOS device names, optionally followed by an extension
# (e.g. "CON", "nul.txt", "com1.log"). Intended to be matched
# case-insensitively against a single path component.
_WINDOWS_RESERVED_NAMES = re.compile(
    r"^(CON|PRN|AUX|NUL|COM[1-9]|LPT[1-9])"
    r"(\.[^\\/:*?\"<>|]*)?$",
    flags=re.IGNORECASE,
)

# Complete UNC paths: \\server\share[\...] or //server/share[/...].
# A server name starting with "?" or "." is excluded (those are device paths).
_WINDOWS_UNC = re.compile(
    r"^\\\\[^?.\\/][^\\/]*\\[^\\/]+(?:\\.*)?|"
    r"^//[^?./][^/]*/[^/]+(?:/.*)?$"
)

# UNC prefixes with at most one component after the server name:
# \\server, \\server\name, //server, //server/name.
_WINDOWS_INCOMPLETE_UNC = re.compile(
    r"^\\\\[^?.\\/][^\\/]*(?:\\[^\\/]+)?$|"
    r"^//[^?./][^/]*(?:/[^/]+)?$"
)

# Alternate Data Stream suffixes such as "file:stream" or
# "file:stream:$DATA", including the "Zone.Identifier" stream.
_WINDOWS_ADS = re.compile(
    r":[^/\\<>:\"|?*]+$|"
    r":Zone\.Identifier$|"
    r":[^/\\<>:\"|?*]+:[^/\\]+$"
)

# Characters rejected inside a path component: <>"|?*, ASCII control
# characters, and ":" anywhere except directly after a leading drive letter.
_WINDOWS_INVALID_CHARS = re.compile(
    r'[<>"|?*]|'
    r"(?<!^[A-Za-z]):|"
    r"[\x00-\x1F]"
)

# One or more trailing dots and/or spaces.
_WINDOWS_TRAILING = re.compile(r"[. ]+$")
107
+
108
+
109
def is_windows_path(path: Union[str, Path]) -> bool:
    r"""Check if path uses Windows-specific features.

    Returns True when *path* contains any of:

    - a device path prefix (``\\?\`` or ``\\.\``) in either slash style
    - a drive-relative form such as ``C:folder``
    - a UNC path such as ``\\server\share``
    - an Alternate Data Stream suffix such as ``file.txt:stream``
    - a reserved device name (``CON``, ``PRN``, ``COM1`` ...) as its final
      component, with or without an extension (case-insensitive)

    The final component is extracted with Windows splitting rules
    (``ntpath``) on every host OS. Backslash-separated input such as
    ``dir\CON`` is therefore recognized on POSIX as well as on Windows.

    Args:
        path: Path to inspect.

    Returns:
        True if a Windows-specific feature was detected, otherwise False.
    """
    import ntpath  # Windows separator/drive rules independent of the host OS

    path_str = str(path)

    # Device paths are checked first, on both the raw string and a
    # forward-slash copy, before any other processing.
    normalized_path = path_str.replace("\\", "/")
    if _WINDOWS_DEVICE_PATH.match(path_str) or _WINDOWS_DEVICE_PATH.match(
        normalized_path
    ):
        logger.debug("Windows device path detected: %r", path_str)
        return True

    # os.path.basename splits only on "/" on POSIX hosts, which would hide a
    # reserved name in "dir\\CON" from the check below; ntpath splits on both.
    basename = ntpath.basename(path_str)
    is_drive_relative = bool(_WINDOWS_DRIVE_RELATIVE.search(path_str))
    is_unc = bool(_WINDOWS_UNC.search(path_str))
    is_ads = bool(_WINDOWS_ADS.search(path_str))
    is_reserved = bool(_WINDOWS_RESERVED_NAMES.match(basename))

    logger.debug(
        "Windows path check for %r: drive_relative=%s, unc=%s, ads=%s, reserved=%s",
        path_str,
        is_drive_relative,
        is_unc,
        is_ads,
        is_reserved,
    )

    return is_drive_relative or is_unc or is_ads or is_reserved
145
+
146
+
147
def normalize_windows_path(path: Union[str, Path]) -> Path:
    r"""Normalize a path using Windows-specific rules.

    Steps:

    1. Convert every backslash to a forward slash and collapse runs of
       slashes. A leading ``//`` (UNC prefix) is kept.
    2. On Windows (``os.name == "nt"``): switch back to backslashes, build a
       ``WindowsPath`` and ``resolve()`` it to an absolute path. The result
       is rejected if it carries the extended-length ``\\?\`` prefix or is
       longer than ``MAX_PATH`` characters.
    3. Elsewhere: apply ``os.path.normpath`` (drops ``.`` segments and folds
       ``..`` where possible) and return a plain ``Path``.

    Letter case is left unchanged.

    Args:
        path: Path to normalize.

    Returns:
        The normalized path.

    Raises:
        PathSecurityError: If the resolved path is too long or extended-length,
            or if normalization fails for any other reason. In the second
            case the underlying exception is chained as ``__cause__``.
    """
    try:
        logger.debug("Normalizing Windows path: %r", path)

        # Forward slashes first so both separator styles are handled alike.
        path_str = str(path).replace("\\", "/")
        # Collapse repeated slashes. The lookbehind rejects a match starting
        # at position 0, so a leading "//" (UNC prefix) survives.
        path_str = re.sub(r"(?<!^)//+", "/", path_str)
        logger.debug("Normalized slashes: %r", path_str)

        # Use regular Path on non-Windows systems
        path_cls = WindowsPath if os.name == "nt" else Path

        if os.name == "nt":
            # Back to native separators. A leading "//" becomes a leading
            # double backslash and so stays a UNC prefix. A single leading
            # separator must stay single: doubling it would turn a
            # root-relative "\foo" into the network path "\\foo".
            path_str = path_str.replace("/", "\\")

        normalized = path_cls(path_str)
        logger.debug(
            "Created path object: %r (class=%s)", normalized, path_cls.__name__
        )

        if os.name == "nt":
            resolved = normalized.resolve()
            resolved_str = str(resolved)
            # Reject results that resolve() expressed as extended-length
            # paths, or that exceed the classic MAX_PATH limit.
            if resolved_str.startswith("\\\\?\\") or len(resolved_str) > MAX_PATH:
                raise PathSecurityError(
                    f"Path would exceed maximum length of {MAX_PATH} characters after resolution",
                    path=str(path),
                    context={"reason": SecurityErrorReasons.NORMALIZATION_ERROR},
                )
            normalized = resolved
            logger.debug("Resolved on Windows: %r", normalized)
        else:
            # On non-Windows, just normalize the path
            normalized = Path(os.path.normpath(path_str))
            logger.debug("Normalized on non-Windows: %r", normalized)

        return normalized
    except PathSecurityError:
        raise
    except Exception as e:
        logger.error(
            "Failed to normalize Windows path %r: %s", path, e, exc_info=True
        )
        raise PathSecurityError(
            f"Failed to normalize Windows path: {e}",
            path=str(path),
            context={"reason": SecurityErrorReasons.NORMALIZATION_ERROR},
        ) from e
226
+
227
+
228
def validate_windows_path(path: Union[str, Path]) -> Optional[str]:
    r"""Validate a path for Windows-specific security issues.

    Checks run in this order; the first failure's message is returned:

    1. Device paths (``\\?\``, ``\\.\``) on the raw input, either slash style.
    2. Incomplete UNC paths (``\\server`` or ``\\server\name`` with nothing
       after it) on the raw input.
    3. Drive-relative paths (``C:folder``) on the raw input.
    4. Normalization via ``normalize_windows_path``; its error message is
       returned if it fails. Device paths are checked again afterwards.
    5. Total length above ``MAX_PATH``.
    6. Drive-relative, complete UNC and Alternate Data Stream patterns on the
       normalized path.
    7. Per path component: reserved names (CON, PRN, ...), invalid
       characters, and trailing dots or spaces.

    Args:
        path: Path to validate.

    Returns:
        A human-readable error message, or None if the path is valid.
    """
    logger.debug("Validating Windows path: %r", path)

    path_str = str(path)
    # Forward-slash copy so device prefixes are detected in either style.
    normalized_path = path_str.replace("\\", "/")

    # 1. Device paths, before any processing.
    if _WINDOWS_DEVICE_PATH.match(path_str) or _WINDOWS_DEVICE_PATH.match(
        normalized_path
    ):
        logger.debug("Device path detected in original path: %r", path_str)
        return "Device paths not allowed"

    # 2. Incomplete UNC paths, before normalization.
    if _WINDOWS_INCOMPLETE_UNC.search(path_str):
        logger.debug("Incomplete UNC path detected: %r", path_str)
        return "Incomplete UNC path"

    # 3. Drive-relative paths, before normalization. On Windows,
    #    normalize_windows_path() calls resolve(), which turns "C:folder" into
    #    an absolute "C:\...\folder". A check on the normalized form alone
    #    would therefore never see the drive-relative input.
    if _WINDOWS_DRIVE_RELATIVE.search(path_str):
        logger.debug("Drive-relative path detected: %r", path_str)
        return "Drive-relative paths must include separator"

    # 4. Normalize for the remaining checks.
    try:
        normalized_str = str(normalize_windows_path(path))
        logger.debug("Normalized path: %r", normalized_str)

        # Check for device paths again after normalization
        if _WINDOWS_DEVICE_PATH.match(normalized_str):
            logger.debug(
                "Device path detected after normalization: %r", normalized_str
            )
            return "Device paths not allowed"

    except PathSecurityError as e:
        logger.debug("Path normalization failed: %s", e)
        return str(e)

    # 5. Path length.
    if len(normalized_str) > MAX_PATH:
        msg = f"Path exceeds maximum length of {MAX_PATH} characters"
        logger.debug("Path too long: %s", msg)
        return msg

    # 6. Whole-path patterns on the normalized form.
    if _WINDOWS_DRIVE_RELATIVE.search(normalized_str):
        logger.debug("Drive-relative path detected: %r", normalized_str)
        return "Drive-relative paths must include separator"

    if _WINDOWS_UNC.search(normalized_str):
        logger.debug("UNC path detected: %r", normalized_str)
        return "UNC paths not allowed"

    if _WINDOWS_ADS.search(normalized_str):
        logger.debug("Alternate Data Stream detected: %r", normalized_str)
        return "Alternate Data Streams not allowed"

    # 7. Per-component checks.
    try:
        parts = (
            Path(normalized_str).parts
            if os.name != "nt"
            else WindowsPath(normalized_str).parts
        )
        logger.debug("Path components: %r", parts)

        for part in parts:
            if _WINDOWS_RESERVED_NAMES.match(part):
                logger.debug("Reserved name detected: %r", part)
                return "Windows reserved names not allowed"

            if _WINDOWS_INVALID_CHARS.search(part):
                msg = f"Invalid characters in path component '{part}'"
                logger.debug("Invalid characters: %s", msg)
                return msg

            if _WINDOWS_TRAILING.search(part):
                msg = f"Trailing dots or spaces not allowed in '{part}'"
                logger.debug("Trailing dots/spaces: %s", msg)
                return msg
    except Exception as e:
        logger.error("Failed to check path components: %s", e, exc_info=True)
        return f"Failed to validate path components: {e}"

    logger.debug("Path validation successful: %r", normalized_str)
    return None
335
+
336
+
337
def resolve_windows_symlink(path: Path) -> Optional[Path]:
    """Resolve a Windows symbolic link, refusing other reparse points.

    Only regular symbolic links are followed. Mount points, junctions and any
    other path carrying FILE_ATTRIBUTE_REPARSE_POINT are deliberately not
    resolved; a warning is logged instead, because they can bypass directory
    restrictions. Callers that need them must implement their own security
    checks.

    Args:
        path: The path to inspect.

    Returns:
        For a symbolic link, the link target exactly as reported by
        ``os.readlink``. A relative target is relative to the link's own
        directory and is not joined with it here. In every other case,
        None: not running on Windows (``os.name != "nt"``), not a symlink,
        a refused reparse point, or an ``OSError``.

    Note:
        Some reparse-point operations may require elevated privileges on
        Windows.
    """
    if os.name != "nt":
        return None

    try:
        # Regular symlink: hand back its raw target.
        if path.is_symlink():
            target = Path(os.readlink(path))
            logger.debug("Resolved symlink %r to %r", path, target)
            return target

        # Mount points / junctions are refused. Before Python 3.12,
        # WindowsPath.is_mount() raised NotImplementedError instead of
        # answering. That is not an OSError, so it must be caught here or it
        # escapes the function.
        if hasattr(path, "is_mount"):
            try:
                is_mount = path.is_mount()
            except NotImplementedError:
                is_mount = False
            if is_mount:
                logger.warning(
                    "Path %r is a mount point/junction - not resolving for security",
                    path,
                )
                return None

        # Any other reparse point: probe the file attributes via the Win32 API.
        try:
            import ctypes

            attrs = ctypes.windll.kernel32.GetFileAttributesW(str(path))  # type: ignore[attr-defined]
            # -1 presumably is INVALID_FILE_ATTRIBUTES (0xFFFFFFFF) seen
            # through ctypes' default signed int return type;
            # 0x400 is FILE_ATTRIBUTE_REPARSE_POINT.
            if attrs != -1 and attrs & 0x400:
                logger.warning(
                    "Path %r is a reparse point - not resolving for security",
                    path,
                )
                return None
        except Exception as e:
            # Best effort: if the attributes cannot be read, treat the path
            # as an ordinary (non-reparse) entry.
            logger.debug("Could not read reparse attributes for %r: %s", path, e)

        return None

    except OSError as e:
        logger.debug("Failed to resolve Windows symlink %r: %s", path, e)
        return None
@@ -0,0 +1,25 @@
1
+ """Serialization utilities for CLI logging."""
2
+
3
+ import json
4
+ from typing import Any, Dict
5
+
6
+
7
class LogSerializer:
    """Helpers for rendering structured log context as text."""

    @staticmethod
    def serialize_log_extra(extra: Dict[str, Any]) -> str:
        """Render *extra* for inclusion in a log message.

        The mapping is first dumped as JSON indented by two spaces; values
        JSON cannot encode natively are converted through ``str``. If the
        dump still fails (for example with tuple keys or a circular
        reference), ``str(extra)`` is returned instead.

        Args:
            extra: Dictionary of extra log data.

        Returns:
            Pretty-printed JSON, or the plain ``str`` form as a fallback.
        """
        try:
            rendered = json.dumps(extra, indent=2, default=str)
        except Exception:
            # Any encoding failure degrades to the plain string form.
            rendered = str(extra)
        return rendered
@@ -10,7 +10,7 @@ from collections import Counter
10
10
  from typing import Any, Dict, List, Optional, Sequence, TypeVar, Union
11
11
 
12
12
  import tiktoken
13
- from jinja2 import Environment
13
+ from jinja2 import Environment, pass_context
14
14
  from pygments import highlight
15
15
  from pygments.formatters import HtmlFormatter, NullFormatter, TerminalFormatter
16
16
  from pygments.lexers import TextLexer, get_lexer_by_name, guess_lexer
@@ -178,10 +178,12 @@ def format_error(e: Exception) -> str:
178
178
  return f"{type(e).__name__}: {str(e)}"
179
179
 
180
180
 
181
- def estimate_tokens(text: str) -> int:
181
+ @pass_context
182
+ def estimate_tokens(context: Any, text: str) -> int:
182
183
  """Estimate number of tokens in text."""
183
184
  try:
184
- encoding = tiktoken.encoding_for_model("gpt-4")
185
+ # Use o200k_base encoding for token estimation
186
+ encoding = tiktoken.get_encoding("o200k_base")
185
187
  return len(encoding.encode(str(text)))
186
188
  except Exception as e:
187
189
  logger.warning(f"Failed to estimate tokens: {e}")
@@ -544,9 +546,9 @@ def format_code(
544
546
  """Format code with syntax highlighting.
545
547
 
546
548
  Args:
547
- text (str): The code text to format
548
- output_format (str): The output format ('terminal', 'html', or 'plain')
549
- language (str): The programming language for syntax highlighting
549
+ text: The code text to format
550
+ output_format: The output format ('terminal', 'html', or 'plain')
551
+ language: The programming language for syntax highlighting
550
552
 
551
553
  Returns:
552
554
  str: Formatted code string
@@ -580,10 +582,13 @@ def format_code(
580
582
  else: # plain
581
583
  formatter = NullFormatter[str]()
582
584
 
583
- return highlight(text, lexer, formatter)
585
+ result = highlight(text, lexer, formatter)
586
+ if isinstance(result, bytes):
587
+ return result.decode("utf-8")
588
+ return str(result)
584
589
  except Exception as e:
585
590
  logger.error(f"Error formatting code: {e}")
586
- return text
591
+ return str(text)
587
592
 
588
593
 
589
594
  def register_template_filters(env: Environment) -> None:
@@ -61,8 +61,9 @@ from typing import Any, Dict, List, Optional, Union
61
61
  import jinja2
62
62
  from jinja2 import Environment
63
63
 
64
- from .errors import TemplateValidationError
64
+ from .errors import TaskTemplateVariableError, TemplateValidationError
65
65
  from .file_utils import FileInfo
66
+ from .progress import ProgressContext
66
67
  from .template_env import create_jinja_env
67
68
  from .template_schema import DotDict, StdinProxy
68
69
 
@@ -92,23 +93,23 @@ TemplateContextValue = Union[
92
93
  def render_template(
93
94
  template_str: str,
94
95
  context: Dict[str, Any],
95
- jinja_env: Optional[Environment] = None,
96
- progress_enabled: bool = True,
96
+ env: Optional[Environment] = None,
97
+ progress: Optional[ProgressContext] = None,
97
98
  ) -> str:
98
- """Render a task template with the given context.
99
+ """Render a template with the given context.
99
100
 
100
101
  Args:
101
- template_str: Task template string or path to task template file
102
- context: Task template variables
103
- jinja_env: Optional Jinja2 environment to use
104
- progress_enabled: Whether to show progress indicators
102
+ template_str: Template string to render
103
+ context: Context dictionary for template variables
104
+ env: Optional Jinja2 environment to use
105
+ progress: Optional progress bar to update
105
106
 
106
107
  Returns:
107
- Rendered task template string
108
+ str: The rendered template string
108
109
 
109
110
  Raises:
110
- TemplateValidationError: If task template cannot be loaded or rendered. The original error
111
- will be chained using `from` for proper error context.
111
+ TaskTemplateVariableError: If template variables are undefined
112
+ TemplateValidationError: If template rendering fails for other reasons
112
113
  """
113
114
  from .progress import ( # Import here to avoid circular dependency
114
115
  ProgressContext,
@@ -116,16 +117,14 @@ def render_template(
116
117
 
117
118
  with ProgressContext(
118
119
  description="Rendering task template",
119
- level="basic" if progress_enabled else "none",
120
+ level="basic" if progress else "none",
120
121
  ) as progress:
121
122
  try:
122
123
  if progress:
123
124
  progress.update(1) # Update progress for setup
124
125
 
125
- if jinja_env is None:
126
- jinja_env = create_jinja_env(
127
- loader=jinja2.FileSystemLoader(".")
128
- )
126
+ if env is None:
127
+ env = create_jinja_env(loader=jinja2.FileSystemLoader("."))
129
128
 
130
129
  logger.debug("=== Raw Input ===")
131
130
  logger.debug(
@@ -154,7 +153,7 @@ def render_template(
154
153
  # Wrap JSON variables in DotDict and handle special cases
155
154
  wrapped_context: Dict[str, TemplateContextValue] = {}
156
155
  for key, value in context.items():
157
- if isinstance(value, dict):
156
+ if isinstance(value, dict) and not isinstance(value, DotDict):
158
157
  wrapped_context[key] = DotDict(value)
159
158
  else:
160
159
  wrapped_context[key] = value
@@ -188,7 +187,7 @@ def render_template(
188
187
  f"Task template file not found: {template_str}"
189
188
  )
190
189
  try:
191
- template = jinja_env.get_template(template_str)
190
+ template = env.get_template(template_str)
192
191
  except jinja2.TemplateNotFound as e:
193
192
  raise TemplateValidationError(
194
193
  f"Task template file not found: {e.name}"
@@ -199,10 +198,10 @@ def render_template(
199
198
  template_str,
200
199
  )
201
200
  try:
202
- template = jinja_env.from_string(template_str)
201
+ template = env.from_string(template_str)
203
202
 
204
203
  # Add debug log for loop rendering
205
- def debug_file_render(f: FileInfo) -> str:
204
+ def debug_file_render(f: FileInfo) -> Any:
206
205
  logger.info("Rendering file: %s", f.path)
207
206
  return ""
208
207
 
@@ -292,6 +291,10 @@ def render_template(
292
291
  " %s: %s (%r)", key, type(value).__name__, value
293
292
  )
294
293
  result = template.render(**wrapped_context)
294
+ if not isinstance(result, str):
295
+ raise TemplateValidationError(
296
+ f"Template rendered to non-string type: {type(result)}"
297
+ )
295
298
  logger.info(
296
299
  "Template render result (first 100 chars): %r",
297
300
  result[:100],
@@ -302,7 +305,17 @@ def render_template(
302
305
  )
303
306
  if progress:
304
307
  progress.update(1)
305
- return result # type: ignore[no-any-return]
308
+ return result
309
+ except jinja2.UndefinedError as e:
310
+ # Extract variable name from error message
311
+ var_name = str(e).split("'")[1]
312
+ error_msg = (
313
+ f"Missing required template variable: {var_name}\n"
314
+ f"Available variables: {', '.join(sorted(context.keys()))}\n"
315
+ "To fix this, please provide the variable using:\n"
316
+ f" -V {var_name}='value'"
317
+ )
318
+ raise TaskTemplateVariableError(error_msg) from e
306
319
  except (jinja2.TemplateError, Exception) as e:
307
320
  logger.error("Template rendering failed: %s", str(e))
308
321
  raise TemplateValidationError(
@@ -336,4 +349,15 @@ def render_template_file(
336
349
  """
337
350
  with open(template_path, "r", encoding="utf-8") as f:
338
351
  template_str = f.read()
339
- return render_template(template_str, context, jinja_env, progress_enabled)
352
+
353
+ # Create a progress context if enabled
354
+ progress = (
355
+ ProgressContext(
356
+ description="Rendering template file",
357
+ level="basic" if progress_enabled else "none",
358
+ )
359
+ if progress_enabled
360
+ else None
361
+ )
362
+
363
+ return render_template(template_str, context, jinja_env, progress)
@@ -9,10 +9,11 @@ It re-exports the public APIs from specialized modules:
9
9
  - template_io: File I/O operations and metadata extraction
10
10
  """
11
11
 
12
+ import logging
12
13
  from typing import Any, Dict, List, Optional, Set
13
14
 
14
15
  import jsonschema
15
- from jinja2 import Environment, meta
16
+ from jinja2 import Environment, TemplateSyntaxError, meta
16
17
  from jinja2.nodes import Node
17
18
 
18
19
  from .errors import (
@@ -35,6 +36,8 @@ from .template_schema import (
35
36
  )
36
37
  from .template_validation import SafeUndefined, validate_template_placeholders
37
38
 
39
+ logger = logging.getLogger(__name__)
40
+
38
41
 
39
42
  # Custom error classes
40
43
  class TemplateMetadataError(TaskTemplateError):
@@ -118,8 +121,13 @@ def find_all_template_variables(
118
121
  if env is None:
119
122
  env = Environment(undefined=SafeUndefined)
120
123
 
121
- # Parse template
122
- ast = env.parse(template)
124
+ try:
125
+ ast = env.parse(template)
126
+ except TemplateSyntaxError as e:
127
+ logger.error("Failed to parse template: %s", str(e))
128
+ return set() # Return empty set on parse error
129
+
130
+ variables: Set[str] = set()
123
131
 
124
132
  # Find all variables in this template
125
133
  variables = meta.find_undeclared_variables(ast)
@@ -252,7 +260,7 @@ def find_all_template_variables(
252
260
  visit_nodes(node.else_)
253
261
 
254
262
  visit_nodes(ast.body)
255
- return variables # type: ignore[no-any-return]
263
+ return variables
256
264
 
257
265
 
258
266
  __all__ = [