confit 0.5.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
confit/__init__.py ADDED
@@ -0,0 +1,15 @@
1
+ from .cli import Cli
2
+ from .config import Config
3
+ from .registry import (
4
+ validate_arguments,
5
+ Registry,
6
+ get_default_registry,
7
+ set_default_registry,
8
+ RegistryCollection,
9
+ VisibleDeprecationWarning,
10
+ )
11
+ from .autoreload import autoreload_plugin
12
+
13
__version__ = "0.5.6"

# Import-time side effect: hook confit into IPython's autoreload extension so
# that functions wrapped by confit can be updated on reload. This is a no-op
# when `IPython.extensions.autoreload` cannot be imported.
autoreload_plugin()
confit/autoreload.py ADDED
@@ -0,0 +1,40 @@
1
+ """
2
+ Plugin to help IPython's autoreload magic reload functions wrapped with confit.
3
+ """
4
+
5
+ import types
6
+
7
+
8
def check_wrapped(a, b):
    """
    Tell whether the pair ``(a, b)`` should be handled by
    :func:`update_wrapped_function`.

    Parameters
    ----------
    a: Any
        The previously loaded object.
    b: Any
        The freshly reloaded object.

    Returns
    -------
    bool
        True when ``b`` is a plain Python function and ``a`` has a
        ``__wrapped__`` attribute that is itself a plain Python function.
    """
    if not isinstance(b, types.FunctionType):
        return False
    inner = getattr(a, "__wrapped__", None)
    return isinstance(inner, types.FunctionType)
14
+
15
+
16
def update_wrapped_function(old, new):
    """
    Upgrade the code object of the function wrapped by ``old``.

    Every attribute listed in IPython autoreload's ``func_attrs`` is copied
    from ``new`` (or from ``new.__wrapped__`` when ``new`` is itself a
    wrapper) onto ``old.__wrapped__``. The function object held by the
    confit wrapper is therefore updated in place.

    Parameters
    ----------
    old: Any
        The previously loaded object; must expose a ``__wrapped__`` function.
    new: Any
        The freshly reloaded function, or a wrapper around one.
    """
    # Imported lazily: this rule is only registered by `autoreload_plugin`
    # when IPython's autoreload extension is importable.
    from IPython.extensions.autoreload import func_attrs

    new_func = getattr(new, "__wrapped__", new)
    for name in func_attrs:
        try:
            setattr(old.__wrapped__, name, getattr(new_func, name))
        except (AttributeError, TypeError):
            # Best effort: attributes that cannot be read from `new_func` or
            # assigned on the wrapped function are skipped.
            pass
26
+
27
+
28
def autoreload_plugin():
    """
    Register confit's wrapped-function rule with IPython's autoreload.

    The rule ``(check_wrapped, update_wrapped_function)`` is inserted at index
    0 of ``UPDATE_RULES``, presumably so that it is tried before IPython's
    built-in rules. Nothing happens when IPython's autoreload extension
    cannot be imported. Calling this function more than once does not
    register the rule twice.
    """
    try:
        from IPython.extensions.autoreload import UPDATE_RULES
    except ImportError:
        return

    # Pass the predicate directly (no lambda wrapper) so the rule tuple is the
    # same on every call. Tuples compare element-wise and functions compare
    # by identity, so the membership test detects a previous registration.
    rule = (check_wrapped, update_wrapped_function)
    if rule in UPDATE_RULES:
        return
    UPDATE_RULES.insert(0, rule)
confit/cli.py ADDED
@@ -0,0 +1,176 @@
1
+ import inspect
2
+ import sys
3
+ from pathlib import Path
4
+ from typing import Any, Callable, Dict, List, Optional, Type, Union
5
+
6
+ from typer import Context, Typer, colors, secho
7
+ from typer.core import TyperCommand
8
+ from typer.models import CommandFunctionType, Default
9
+
10
+ from .config import Config, merge_from_disk
11
+ from .errors import ConfitValidationError, LegacyValidationError, patch_errors
12
+ from .registry import validate_arguments
13
+ from .utils.random import set_seed
14
+ from .utils.settings import is_debug
15
+ from .utils.xjson import loads
16
+
17
+
18
def parse_overrides(args: List[str]) -> Dict[str, Any]:
    """
    Parse the overrides from the command line into a dictionary
    of key-value pairs.

    Supported forms are ``--key=value``, ``--key value`` and a bare
    ``--flag``, which is parsed as ``"true"``. Dashes in keys are converted
    to underscores, and every value is decoded with ``loads``.

    Parameters
    ----------
    args: List[str]
        The arguments to parse. The list itself is not modified.

    Returns
    -------
    Dict[str, Any]
        The parsed overrides as a dictionary

    Raises
    ------
    SystemExit
        If an argument does not start with ``--``. Shorthands such as ``-k``
        and stray positional values are not supported.
    """
    from collections import deque

    result = {}
    # Work on a local queue: the caller's list (e.g. `ctx.args`) is left
    # untouched, and popping from the front is O(1).
    pending = deque(args)
    while pending:
        opt = pending.popleft()
        if not opt.startswith("--"):
            secho(
                f"Invalid config override '{opt}': doesn't support shorthands",
                fg=colors.RED,
            )
            sys.exit(1)
        # Strip only the leading "--". Any "--" inside the key or the value
        # (e.g. "--name=foo--bar") must be preserved.
        opt = opt[2:]
        if "=" in opt:  # --opt=value
            opt, value = opt.split("=", 1)
        elif not pending or pending[0].startswith("--"):  # flag with no value
            value = "true"
        else:  # --opt value
            value = pending.popleft()
        result[opt.replace("-", "_")] = loads(value)
    return result
52
+
53
+
54
class Cli(Typer):
    """
    Custom Typer object that:

    - validates a command parameters before executing it
    - accepts a configuration file describing the parameters
    - automatically instantiates parameters given a dictionary when type hinted
    """

    def command(  # noqa
        self,
        name,
        *,
        cls: Optional[Type[TyperCommand]] = None,
        context_settings: Optional[Dict[Any, Any]] = None,
        help: Optional[str] = None,
        epilog: Optional[str] = None,
        short_help: Optional[str] = None,
        options_metavar: str = "[OPTIONS]",
        add_help_option: bool = True,
        no_args_is_help: bool = False,
        hidden: bool = False,
        deprecated: bool = False,
        # Rich settings
        rich_help_panel: Union[str, None] = Default(None),
        registry: Any = None,
    ) -> Callable[[CommandFunctionType], CommandFunctionType]:
        """
        Decorator registering a function as a config-driven CLI command.

        The generated command accepts one or more ``--config`` files plus
        arbitrary ``--key value`` / ``--section.key=value`` overrides. The
        merged config is resolved, and the ``name`` section is passed as
        keyword arguments to the validated function.

        Parameters
        ----------
        name:
            Name of the command. It is also used as the config section that
            holds the function's arguments.
        cls, context_settings, help, epilog, short_help, options_metavar,
        add_help_option, no_args_is_help, hidden, deprecated, rich_help_panel
            Forwarded to ``Typer.command``. ``context_settings`` is extended
            with ``ignore_unknown_options`` and ``allow_extra_args`` so that
            overrides reach ``ctx.args``.
        registry: Any
            Passed to ``Config.resolve``.

        Returns
        -------
        Callable[[CommandFunctionType], CommandFunctionType]
            A decorator that registers the function as a Typer command and
            returns ``validate_arguments(fn)``.
        """
        typer_command = super().command(
            name=name,
            cls=cls,
            help=help,
            epilog=epilog,
            short_help=short_help,
            options_metavar=options_metavar,
            add_help_option=add_help_option,
            no_args_is_help=no_args_is_help,
            hidden=hidden,
            deprecated=deprecated,
            rich_help_panel=rich_help_panel,
            context_settings={
                **(context_settings or {}),
                "ignore_unknown_options": True,
                "allow_extra_args": True,
            },
        )

        def wrapper(fn):
            validated = validate_arguments(fn)

            @typer_command
            def command(ctx: Context, config: Optional[List[Path]] = None):
                config_path = config

                # Re-checked on every call rather than cached at decoration
                # time (presumably so a signature change picked up by IPython
                # autoreload is honored).
                has_meta = _fn_has_meta(fn)
                if config_path:
                    config, _ = merge_from_disk(config_path)
                else:
                    config = Config({name: {}})
                model_fields = (
                    validated.model.model_fields
                    if hasattr(validated.model, "model_fields")  # pydantic v2 API
                    else validated.model.__fields__  # pydantic v1 API
                )
                # Route each CLI override into the config:
                # - "--key" targets the command's own section (`name`);
                # - "--field.sub", where `field` is a parameter of the function
                #   and not a config section, also targets the command section;
                # - any other dotted key must start with an existing section.
                for k, v in parse_overrides(ctx.args).items():
                    if "." not in k:
                        parts = (name, k)
                    else:
                        parts = k.split(".")
                        if parts[0] in model_fields and parts[0] not in config:
                            parts = (name, *parts)
                    current = config
                    if parts[0] not in current:
                        raise Exception(
                            f"{k} does not match any existing section in config"
                        )
                    for part in parts[:-1]:
                        current = current.setdefault(part, Config())
                    current[parts[-1]] = v
                try:
                    resolved_config = config.resolve(registry=registry)
                    section = resolved_config.get(name, {})
                    default_seed = model_fields.get("seed")
                    if default_seed is not None:
                        default_seed = default_seed.get_default()
                    # Read the seed from the *resolved* section, so that a seed
                    # whose value needs resolving is used in its final form.
                    seed = section.get("seed", default_seed)
                    if seed is not None:
                        set_seed(seed)
                    if has_meta:
                        config_meta = dict(
                            config_path=config_path,
                            resolved_config=resolved_config,
                            unresolved_config=config,
                        )
                        return validated(**section, config_meta=config_meta)
                    return validated(**section)
                except (LegacyValidationError, ConfitValidationError) as e:
                    # Presumably prefixes error locations with the section name.
                    e.raw_errors = patch_errors(
                        e.raw_errors,
                        (name,),
                    )
                    if is_debug() or e.__cause__ is not None:
                        raise e
                    # Import the submodule explicitly: importing a package does
                    # not by itself import its submodules, so `rich.console`
                    # may be unbound after a bare `import rich`.
                    try:
                        from rich.console import Console
                    except ImportError:  # pragma: no cover
                        print("Validation error:", file=sys.stderr, end=" ")
                        print(str(e), file=sys.stderr)
                    else:
                        console = Console(stderr=True)
                        console.print("Validation error:", style="red", end=" ")
                        console.print(str(e))
                    sys.exit(1)

            return validated

        return wrapper
173
+
174
+
175
+ def _fn_has_meta(fn):
176
+ return "config_meta" in inspect.signature(fn).parameters