confarg 0.0.1.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,566 @@
1
+ # This Source Code Form is subject to the terms of the Mozilla Public
2
+ # License, v. 2.0. If a copy of the MPL was not distributed with this
3
+ # file, You can obtain one at https://mozilla.org/MPL/2.0/.
4
+
5
+ """Expression resolution for ${...} field references and computations."""
6
+
7
+ from __future__ import annotations
8
+
9
+ import ast
10
+ import copy
11
+ import math
12
+ import operator
13
+ import re
14
+ from collections import deque
15
+ from typing import Any
16
+
17
+ from confarg._errors import (
18
+ CircularReferenceError,
19
+ ExpressionEvalError,
20
+ MissingReferenceError,
21
+ UnsafeExpressionError,
22
+ )
23
+
24
# Matches either an escaped ``$${...}`` (no capture group, so group(1) is None)
# or a real ``${...}`` expression whose inner text is captured in group 1.
# Neither form may contain a ``}`` between the braces.
_EXPR_RE = re.compile(r"\$\$\{[^}]*\}|\$\{([^}]+)\}")

# Free functions an expression may call by bare name. Keys are each function's
# ``__name__``: abs, min, max, round, ceil, floor, str, int, float, bool, len.
_SAFE_FUNCTIONS: dict[str, Any] = {
    fn.__name__: fn
    for fn in (abs, min, max, round, math.ceil, math.floor, str, int, float, bool, len)
}

# Method names an expression may call on a value, e.g. ``name.upper()``.
_SAFE_METHODS: set[str] = {
    # case conversion
    "upper",
    "lower",
    # trimming, splitting and joining
    "strip",
    "split",
    "join",
    # substitution and affix tests
    "replace",
    "startswith",
    "endswith",
}

# Binary operator dispatch: AST operator type -> implementation.
_BINOP_MAP: dict[type, Any] = {
    ast.Add: operator.add,
    ast.Sub: operator.sub,
    ast.Mult: operator.mul,
    ast.Div: operator.truediv,
    ast.FloorDiv: operator.floordiv,
    ast.Mod: operator.mod,
    ast.Pow: operator.pow,
}

# Unary operator dispatch: AST operator type -> implementation.
_UNARYOP_MAP: dict[type, Any] = {
    ast.UAdd: operator.pos,
    ast.USub: operator.neg,
    ast.Not: operator.not_,
}

# Comparison operator dispatch: AST operator type -> implementation.
_CMPOP_MAP: dict[type, Any] = {
    ast.Eq: operator.eq,
    ast.NotEq: operator.ne,
    ast.Lt: operator.lt,
    ast.LtE: operator.le,
    ast.Gt: operator.gt,
    ast.GtE: operator.ge,
}

# AST node types an expression may contain. Any other node type (lambdas,
# comprehensions, keyword arguments, starred arguments, slices, ...) is not in
# this set. The operator entries are taken from the dispatch tables above, so
# every operator the evaluator can dispatch is allowed and nothing else is.
_ALLOWED_NODES: set[type] = {
    # expression structure
    ast.Expression,
    ast.Constant,
    ast.Name,
    ast.Attribute,
    ast.Subscript,
    ast.BinOp,
    ast.UnaryOp,
    ast.Compare,
    ast.BoolOp,
    ast.Call,
    ast.IfExp,
    ast.Load,
    # operators with a dispatch entry
    *_BINOP_MAP,
    *_UNARYOP_MAP,
    *_CMPOP_MAP,
    # short-circuit boolean operators (evaluated inline, not via a table)
    ast.And,
    ast.Or,
}
116
+
117
+
118
def resolve_expressions(
    data: dict[str, Any],
) -> dict[str, Any]:
    """Resolve ${...} expressions in a merged config dict.

    Expression strings are evaluated in dependency order. An expression may
    reference another expression field directly (``${port}``), a container that
    holds expression fields (``${servers[0]}``, ``${db}``), or a value nested
    inside another expression's result (``${b.c}`` where ``b`` is an expression).

    Args:
        data: The merged configuration dict. It is not modified.

    Returns:
        A deep copy of *data* with all ${...} expression strings replaced by their
        values, or *data* itself (the same object) if no expressions are found.

    Raises:
        CircularReferenceError: If expression references form a cycle.
        UnsafeExpressionError: If an expression contains disallowed constructs.
        MissingReferenceError: If an expression references a field that does not exist.
        ExpressionEvalError: If an expression fails at runtime.
    """
    # 1. Scan for expression strings (dotted path -> raw string).
    expr_fields = _scan_expressions(data)
    if not expr_fields:
        return data

    # Resolution writes results back in place, so work on a private copy.
    data = copy.deepcopy(data)

    # 2. Extract references and build the dependency graph between expression fields.
    expr_paths = list(expr_fields)
    deps: dict[str, set[str]] = {
        path: _expression_dependencies(path, _extract_references(raw_str), expr_paths)
        for path, raw_str in expr_fields.items()
    }

    # 3. Topological sort (raises CircularReferenceError on cycles).
    order = _topological_sort(deps)

    # 4. Validate every expression before evaluating any of them.
    for path in order:
        for m in _EXPR_RE.finditer(expr_fields[path]):
            expr_content = m.group(1)
            if expr_content is not None:  # group 1 is None for escaped $${...}
                _validate_ast(expr_content)

    # 5. Resolve in order; later expressions see earlier results through *data*.
    for path in order:
        result = _resolve_single(expr_fields[path], data)
        _set_nested_by_path(data, path, result)

    return data


def _expression_dependencies(path: str, refs: set[str], expr_paths: list[str]) -> set[str]:
    """Return the expression paths that must be resolved before the field at *path*.

    A reference ``ref`` makes the field depend on expression field ``other`` when:

    * ``other == ref``: the field references that expression directly;
    * ``ref`` lies inside ``other`` (``${b.c}`` where ``b`` is an expression): the
      referenced value only exists once ``other`` has been resolved;
    * ``other`` lies inside ``ref`` (``${servers[0]}``, ``${db}``): the referenced
      container must not still hold raw ``${...}`` strings when it is read.

    The last rule is skipped when ``ref`` also contains *path* itself (e.g.
    ``items.0 = "${len(items)}"``). Such a container can never be fully resolved
    before the field inside it, so treating it as a dependency would only
    produce a spurious cycle.

    Args:
        path: Dotted path of the expression field whose dependencies are wanted.
        refs: Dotted references extracted from that field's expression string.
        expr_paths: Dotted paths of every expression field.

    Returns:
        Subset of *expr_paths* that *path* depends on.
    """
    needed: set[str] = set()
    for ref in refs:
        ref_prefix = ref + "."
        # Does the referenced subtree contain the referencing field itself?
        contains_self = path.startswith(ref_prefix)
        for other in expr_paths:
            if other == ref or ref.startswith(other + "."):
                needed.add(other)
            elif not contains_self and other.startswith(ref_prefix):
                needed.add(other)
    return needed
169
+
170
+
171
def _scan_expressions(
    data: dict[str, Any],
    prefix: str = "",
) -> dict[str, str]:
    """Find every string in *data* that contains a ``${...}`` or ``$${...}`` marker.

    Args:
        data: Mapping to walk; nested dicts and lists are descended into.
        prefix: Dotted path of *data* itself; empty for the root mapping.

    Returns:
        Dict mapping dotted paths (list items contribute their index) to the raw strings.
    """
    found: dict[str, str] = {}
    for key, value in data.items():
        if prefix:
            child_path = f"{prefix}.{key}"
        else:
            # At the root the key itself is the path.
            child_path = key
        _collect_expressions(value, child_path, found)
    return found
185
+
186
+
187
+ def _collect_expressions(value: Any, path: str, out: dict[str, str]) -> None:
188
+ """Recursively collect expression strings from a value into *out*."""
189
+ if isinstance(value, dict):
190
+ for k, v in value.items():
191
+ _collect_expressions(v, f"{path}.{k}", out)
192
+ elif isinstance(value, list):
193
+ for i, item in enumerate(value):
194
+ _collect_expressions(item, f"{path}.{i}", out)
195
+ elif isinstance(value, str) and _EXPR_RE.search(value):
196
+ out[path] = value
197
+
198
+
199
def _extract_references(expr_str: str) -> set[str]:
    """Collect the dotted field paths referenced by the ``${...}`` parts of *expr_str*.

    Escaped ``$${...}`` parts are skipped. So is any ``${...}`` body that is not
    valid Python syntax; validation (``_validate_ast``) rejects those separately.

    Returns:
        Set of dotted paths (e.g. ``{"db.host", "db.port"}``).
    """
    bodies = [m.group(1) for m in _EXPR_RE.finditer(expr_str) if m.group(1) is not None]
    references: set[str] = set()
    for body in bodies:
        try:
            tree = ast.parse(body, mode="eval")
        except SyntaxError:
            continue
        _collect_names(tree, references)
    return references
216
+
217
+
218
+ def _collect_names(node: ast.AST, refs: set[str]) -> None:
219
+ """Collect Name nodes and dotted Attribute chains from AST as field references."""
220
+ if isinstance(node, ast.Name):
221
+ if node.id not in _SAFE_FUNCTIONS:
222
+ refs.add(node.id)
223
+ return
224
+ if isinstance(node, ast.Attribute):
225
+ parts = _attribute_chain(node)
226
+ if parts is not None:
227
+ # Check if the first part is a safe function (not a ref)
228
+ if parts[0] not in _SAFE_FUNCTIONS:
229
+ # Check if this is a method call — the attribute itself might be a method
230
+ # We add the full dotted path as a potential reference
231
+ refs.add(".".join(parts))
232
+ else:
233
+ # Non-name base, recurse into children
234
+ for child in ast.iter_child_nodes(node):
235
+ _collect_names(child, refs)
236
+ return
237
+ if isinstance(node, ast.Call):
238
+ # For method calls like name.upper(), don't add "name.upper" as ref
239
+ # Instead, check if it's a method call and only add the object
240
+ if isinstance(node.func, ast.Attribute):
241
+ parts = _attribute_chain(node.func)
242
+ if parts is not None and len(parts) >= 2:
243
+ method_name = parts[-1]
244
+ if method_name in _SAFE_METHODS:
245
+ # The object part is the reference
246
+ obj_path = ".".join(parts[:-1])
247
+ if obj_path not in _SAFE_FUNCTIONS:
248
+ refs.add(obj_path)
249
+ # Also collect refs from arguments
250
+ for arg in node.args:
251
+ _collect_names(arg, refs)
252
+ for kw in node.keywords:
253
+ _collect_names(kw.value, refs)
254
+ return
255
+ # Fall through to collect from func base
256
+ _collect_names(node.func, refs)
257
+ elif isinstance(node.func, ast.Name):
258
+ # Free function call — don't add function name, but add args
259
+ pass
260
+ else:
261
+ _collect_names(node.func, refs)
262
+ for arg in node.args:
263
+ _collect_names(arg, refs)
264
+ for kw in node.keywords:
265
+ _collect_names(kw.value, refs)
266
+ return
267
+ # Recurse into all child nodes
268
+ for child in ast.iter_child_nodes(node):
269
+ _collect_names(child, refs)
270
+
271
+
272
+ def _attribute_chain(node: ast.Attribute | ast.Subscript) -> list[str] | None:
273
+ """Extract dotted name chain from nested Attribute/Subscript nodes.
274
+
275
+ Returns e.g. ["db", "host"] for ``db.host``, ["servers", "0", "host"] for
276
+ ``servers[0].host``, or None if base is not a Name.
277
+ """
278
+ parts: list[str] = []
279
+ if isinstance(node, ast.Attribute):
280
+ parts.append(node.attr)
281
+ current: ast.AST = node.value
282
+ else:
283
+ # Subscript at top (shouldn't be called directly, but handle it)
284
+ if isinstance(node.slice, ast.Constant) and isinstance(node.slice.value, int):
285
+ parts.append(str(node.slice.value))
286
+ else:
287
+ return None
288
+ current = node.value
289
+ while True:
290
+ if isinstance(current, ast.Attribute):
291
+ parts.append(current.attr)
292
+ current = current.value
293
+ elif isinstance(current, ast.Subscript):
294
+ if isinstance(current.slice, ast.Constant) and isinstance(current.slice.value, int):
295
+ parts.append(str(current.slice.value))
296
+ current = current.value
297
+ else:
298
+ return None
299
+ else:
300
+ break
301
+ if isinstance(current, ast.Name):
302
+ parts.append(current.id)
303
+ parts.reverse()
304
+ return parts
305
+ return None
306
+
307
+
308
def _topological_sort(deps: dict[str, set[str]]) -> list[str]:
    """Order the keys of *deps* so every node comes after the nodes it depends on.

    Kahn's algorithm. Dependencies that are not themselves keys of *deps* are
    ignored. Nodes that become ready together are emitted in the insertion order
    of *deps*, so the result is deterministic.

    Args:
        deps: Mapping of node -> set of nodes it depends on.

    Returns:
        The nodes in dependency order (empty list for empty input).

    Raises:
        CircularReferenceError: If some nodes can never become ready (a cycle,
            including a node that depends on itself).
    """
    if not deps:
        return []

    # Reverse edges (dependency -> nodes waiting on it) let each pop touch only
    # its own dependents, instead of rescanning the whole graph every time.
    dependents: dict[str, list[str]] = {node: [] for node in deps}
    in_degree: dict[str, int] = dict.fromkeys(deps, 0)
    for node, node_deps in deps.items():
        for dep in node_deps:
            if dep in deps:
                dependents[dep].append(node)
                in_degree[node] += 1

    queue: deque[str] = deque(node for node, degree in in_degree.items() if degree == 0)
    order: list[str] = []
    while queue:
        node = queue.popleft()
        order.append(node)
        for other in dependents[node]:
            in_degree[other] -= 1
            if in_degree[other] == 0:
                queue.append(other)

    if len(order) != len(deps):
        remaining = set(deps) - set(order)
        raise CircularReferenceError(f"Circular reference detected among: {', '.join(sorted(remaining))}")

    return order
340
+
341
+
342
+ def _validate_ast(expr_str: str) -> None:
343
+ """Parse expression and validate AST contains only allowed nodes.
344
+
345
+ Raises UnsafeExpressionError for disallowed constructs.
346
+ """
347
+ try:
348
+ tree = ast.parse(expr_str, mode="eval")
349
+ except SyntaxError as exc:
350
+ raise UnsafeExpressionError(f"Invalid expression syntax: {expr_str!r}") from exc
351
+
352
+ for node in ast.walk(tree):
353
+ if type(node) not in _ALLOWED_NODES:
354
+ raise UnsafeExpressionError(f"Disallowed construct in expression: {type(node).__name__}")
355
+ # Check for dunder attribute access
356
+ if isinstance(node, ast.Attribute) and node.attr.startswith("__"):
357
+ raise UnsafeExpressionError(f"Access to dunder attribute '{node.attr}' is not allowed")
358
+ # Check function calls are whitelisted
359
+ if isinstance(node, ast.Call):
360
+ _validate_call(node)
361
+
362
+
363
+ def _validate_call(node: ast.Call) -> None:
364
+ """Validate that a Call node targets a whitelisted function/method."""
365
+ if isinstance(node.func, ast.Name):
366
+ if node.func.id not in _SAFE_FUNCTIONS:
367
+ raise UnsafeExpressionError(f"Function '{node.func.id}' is not allowed")
368
+ elif isinstance(node.func, ast.Attribute):
369
+ if node.func.attr not in _SAFE_METHODS and node.func.attr not in _SAFE_FUNCTIONS:
370
+ raise UnsafeExpressionError(f"Method '{node.func.attr}' is not allowed")
371
+ else:
372
+ raise UnsafeExpressionError("Indirect function calls are not allowed")
373
+
374
+
375
def _evaluate_ast(node: ast.AST, namespace: dict[str, Any]) -> Any:
    """Recursively evaluate AST node against namespace.

    Bare names and attribute/integer-subscript chains are looked up in
    *namespace* as dotted config paths. Whitelisted function names resolve to
    the functions in ``_SAFE_FUNCTIONS``.

    Args:
        node: A node from ``ast.parse(..., mode="eval")``. This function does not
            re-check the allowlist; presumably ``_validate_ast`` has already run.
        namespace: The config dict that expressions are resolved against.

    Returns:
        The computed value.

    Raises:
        MissingReferenceError: If a referenced field does not exist. This includes
            a dotted reference such as ``db.nope`` whose base exists but has
            neither that key nor that attribute.
        ExpressionEvalError: If a binary operator or call fails, or the node/operator
            type is unsupported.
        Exception: Other failures (subscripts, comparisons, unary operators,
            attribute access on non-path bases) propagate unwrapped;
            ``_resolve_single`` wraps them in ExpressionEvalError.
    """
    if isinstance(node, ast.Expression):
        return _evaluate_ast(node.body, namespace)

    if isinstance(node, ast.Constant):
        return node.value

    if isinstance(node, ast.Name):
        if node.id in _SAFE_FUNCTIONS:
            return _SAFE_FUNCTIONS[node.id]
        return _get_nested(namespace, node.id)

    if isinstance(node, ast.Attribute):
        path_error: MissingReferenceError | None = None
        parts = _attribute_chain(node)
        if parts is not None:
            # Try the whole chain as a dotted config path first (``db.host``).
            try:
                return _get_nested(namespace, ".".join(parts))
            except MissingReferenceError as exc:
                path_error = exc
        # Fall back to attribute access on the evaluated base (``name.upper``).
        value = _evaluate_ast(node.value, namespace)
        try:
            return getattr(value, node.attr)
        except AttributeError:
            # Neither a config path nor an attribute: report the missing field
            # instead of an opaque "object has no attribute" error.
            if path_error is not None:
                raise path_error from None
            raise

    if isinstance(node, ast.Subscript):
        value = _evaluate_ast(node.value, namespace)
        index = _evaluate_ast(node.slice, namespace)
        return value[index]

    if isinstance(node, ast.BinOp):
        left = _evaluate_ast(node.left, namespace)
        right = _evaluate_ast(node.right, namespace)
        op_func = _BINOP_MAP.get(type(node.op))
        if op_func is None:
            raise ExpressionEvalError(f"Unsupported binary operator: {type(node.op).__name__}")
        try:
            return op_func(left, right)
        except Exception as exc:
            raise ExpressionEvalError(str(exc)) from exc

    if isinstance(node, ast.UnaryOp):
        operand = _evaluate_ast(node.operand, namespace)
        op_func = _UNARYOP_MAP.get(type(node.op))
        if op_func is None:
            raise ExpressionEvalError(f"Unsupported unary operator: {type(node.op).__name__}")
        return op_func(operand)

    if isinstance(node, ast.Compare):
        # Chained comparisons short-circuit like Python's: ``a < b < c`` stops at
        # the first false link without evaluating the remaining operands.
        left = _evaluate_ast(node.left, namespace)
        for op, comparator in zip(node.ops, node.comparators, strict=False):
            right = _evaluate_ast(comparator, namespace)
            op_func = _CMPOP_MAP.get(type(op))
            if op_func is None:
                raise ExpressionEvalError(f"Unsupported comparison: {type(op).__name__}")
            if not op_func(left, right):
                return False
            left = right
        return True

    if isinstance(node, ast.BoolOp):
        # Return the deciding operand (not a bool), matching Python's and/or.
        if isinstance(node.op, ast.And):
            result: Any = True
            for value in node.values:
                result = _evaluate_ast(value, namespace)
                if not result:
                    return result
            return result
        else:  # ast.Or
            result = False
            for value in node.values:
                result = _evaluate_ast(value, namespace)
                if result:
                    return result
            return result

    if isinstance(node, ast.IfExp):
        test = _evaluate_ast(node.test, namespace)
        if test:
            return _evaluate_ast(node.body, namespace)
        return _evaluate_ast(node.orelse, namespace)

    if isinstance(node, ast.Call):
        func = _evaluate_ast(node.func, namespace)
        args = [_evaluate_ast(a, namespace) for a in node.args]
        kwargs = {kw.arg: _evaluate_ast(kw.value, namespace) for kw in node.keywords}
        try:
            return func(*args, **kwargs)
        except Exception as exc:
            raise ExpressionEvalError(str(exc)) from exc

    raise ExpressionEvalError(f"Cannot evaluate node type: {type(node).__name__}")  # pragma: no cover
469
+
470
+
471
def _resolve_single(expr_str: str, namespace: dict[str, Any]) -> Any:
    """Resolve one config string containing ``${...}`` and/or ``$${...}`` markers.

    Handles three cases:

    1. The whole string is exactly one ``${expr}``, with no surrounding text or
       whitespace. The expression's value is returned with its own type.
    2. Interpolation: every ``${expr}`` is replaced by ``str(value)``. The result
       is always a ``str``.
    3. Escapes: every ``$${...}`` becomes the literal text ``${...}``.

    Raises:
        MissingReferenceError, UnsafeExpressionError, ExpressionEvalError:
            Propagated unchanged from evaluation. Any other exception is wrapped
            in ExpressionEvalError naming the failing expression.
    """
    # fullmatch on the raw string already rejects surrounding whitespace.
    whole = re.fullmatch(r"\$\{([^}]+)\}", expr_str)
    if whole is not None:
        return _evaluate_fragment(whole.group(1), expr_str, namespace)

    def _substitute(match: re.Match[str]) -> str:
        inner = match.group(1)
        if inner is None:
            # Escaped "$${foo}" -> literal "${foo}": drop one dollar sign.
            return match.group(0)[1:]
        return str(_evaluate_fragment(inner, match.group(0), namespace))

    return _EXPR_RE.sub(_substitute, expr_str)


def _evaluate_fragment(source: str, label: str, namespace: dict[str, Any]) -> Any:
    """Parse and evaluate one expression body, wrapping unexpected errors.

    *label* is the text quoted in a wrapped error message: the full config string
    for a pure expression, or the ``${...}`` match for interpolation.
    """
    tree = ast.parse(source, mode="eval")
    try:
        return _evaluate_ast(tree, namespace)
    except (MissingReferenceError, UnsafeExpressionError, ExpressionEvalError):
        raise
    except Exception as exc:
        raise ExpressionEvalError(f"Error in expression {label!r}: {exc}") from exc
524
+
525
+
526
+ def _get_nested(data: dict[str, Any], path: str) -> Any:
527
+ """Retrieve value from nested dict/list by dotted path."""
528
+ parts = path.split(".")
529
+ current: Any = data
530
+ for part in parts:
531
+ if isinstance(current, dict):
532
+ if part not in current:
533
+ raise MissingReferenceError.field_not_found(path)
534
+ current = current[part]
535
+ elif isinstance(current, list | tuple):
536
+ try:
537
+ idx = int(part)
538
+ except ValueError:
539
+ raise MissingReferenceError.field_not_found(path, f"'{part}' is not a valid index") from None
540
+ try:
541
+ current = current[idx]
542
+ except IndexError:
543
+ raise MissingReferenceError.field_not_found(path, f"index {idx} out of range") from None
544
+ else:
545
+ raise MissingReferenceError.field_not_found(path, f"cannot traverse into {type(current).__name__}")
546
+ return current
547
+
548
+
549
+ def _set_nested_by_path(data: dict[str, Any], path: str, value: Any) -> None:
550
+ """Set value in nested dict by dotted path."""
551
+ parts = path.split(".")
552
+ current: Any = data
553
+ for part in parts[:-1]:
554
+ if isinstance(current, dict):
555
+ current = current[part]
556
+ elif isinstance(current, list | tuple):
557
+ current = current[int(part)]
558
+ else:
559
+ raise MissingReferenceError(f"Cannot set path '{path}': cannot traverse into {type(current).__name__}")
560
+ last = parts[-1]
561
+ if isinstance(current, dict):
562
+ current[last] = value
563
+ elif isinstance(current, list):
564
+ current[int(last)] = value
565
+ else:
566
+ raise MissingReferenceError(f"Cannot set path '{path}'")
@@ -0,0 +1,44 @@
1
+ # This Source Code Form is subject to the terms of the Mozilla Public
2
+ # License, v. 2.0. If a copy of the MPL was not distributed with this
3
+ # file, You can obtain one at https://mozilla.org/MPL/2.0/.
4
+
5
+ """confarg.typedload — type-aware construction of Python dataclasses from raw dicts.
6
+
7
+ Builds typed instances from plain dicts, with union disambiguation (tag-based,
8
+ structural, or leaf-coercion), nested dataclass support, and collection handling.
9
+ Also exposes leaf-value coercion for scalar types.
10
+
11
+ Typical use::
12
+
13
+ from dataclasses import dataclass
14
+ from confarg.typedload import construct, coerce
15
+
16
+
17
+ @dataclass
18
+ class Server:
19
+ host: str
20
+ port: int
21
+
22
+
23
+ srv = construct(Server, {"host": "localhost", "port": "8080"})
24
+ # srv == Server(host="localhost", port=8080)
25
+
26
+ val = coerce(int, "42")
27
+ # val == 42
28
+ """
29
+
30
+ from confarg._errors import (
31
+ AmbiguousUnionError,
32
+ MissingFieldError,
33
+ TypeCoercionError,
34
+ )
35
+ from confarg.typedload._coerce import _coerce_leaf as coerce
36
+ from confarg.typedload._construct import construct
37
+
38
__all__ = [
    # Error types re-exported from confarg._errors
    "AmbiguousUnionError",
    "MissingFieldError",
    "TypeCoercionError",
    # Entry points: scalar coercion and typed construction
    "coerce",
    "construct",
]