pg-sui 0.2.3__py3-none-any.whl → 1.6.14.dev9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {pg_sui-0.2.3.dist-info → pg_sui-1.6.14.dev9.dist-info}/METADATA +99 -77
- pg_sui-1.6.14.dev9.dist-info/RECORD +81 -0
- {pg_sui-0.2.3.dist-info → pg_sui-1.6.14.dev9.dist-info}/WHEEL +1 -1
- pg_sui-1.6.14.dev9.dist-info/entry_points.txt +4 -0
- {pg_sui-0.2.3.dist-info → pg_sui-1.6.14.dev9.dist-info/licenses}/LICENSE +0 -0
- pg_sui-1.6.14.dev9.dist-info/top_level.txt +1 -0
- pgsui/__init__.py +35 -54
- pgsui/_version.py +34 -0
- pgsui/cli.py +909 -0
- pgsui/data_processing/__init__.py +0 -0
- pgsui/data_processing/config.py +565 -0
- pgsui/data_processing/containers.py +1424 -0
- pgsui/data_processing/transformers.py +557 -907
- pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
- pgsui/electron/app/__main__.py +5 -0
- pgsui/electron/app/extra-resources/.gitkeep +1 -0
- pgsui/electron/app/icons/icons/1024x1024.png +0 -0
- pgsui/electron/app/icons/icons/128x128.png +0 -0
- pgsui/electron/app/icons/icons/16x16.png +0 -0
- pgsui/electron/app/icons/icons/24x24.png +0 -0
- pgsui/electron/app/icons/icons/256x256.png +0 -0
- pgsui/electron/app/icons/icons/32x32.png +0 -0
- pgsui/electron/app/icons/icons/48x48.png +0 -0
- pgsui/electron/app/icons/icons/512x512.png +0 -0
- pgsui/electron/app/icons/icons/64x64.png +0 -0
- pgsui/electron/app/icons/icons/icon.icns +0 -0
- pgsui/electron/app/icons/icons/icon.ico +0 -0
- pgsui/electron/app/main.js +227 -0
- pgsui/electron/app/package-lock.json +6894 -0
- pgsui/electron/app/package.json +51 -0
- pgsui/electron/app/preload.js +15 -0
- pgsui/electron/app/server.py +157 -0
- pgsui/electron/app/ui/logo.png +0 -0
- pgsui/electron/app/ui/renderer.js +131 -0
- pgsui/electron/app/ui/styles.css +59 -0
- pgsui/electron/app/ui/ui_shim.js +72 -0
- pgsui/electron/bootstrap.py +43 -0
- pgsui/electron/launch.py +57 -0
- pgsui/electron/package.json +14 -0
- pgsui/example_data/__init__.py +0 -0
- pgsui/example_data/phylip_files/__init__.py +0 -0
- pgsui/example_data/phylip_files/test.phy +0 -0
- pgsui/example_data/popmaps/__init__.py +0 -0
- pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
- pgsui/example_data/structure_files/__init__.py +0 -0
- pgsui/example_data/structure_files/test.pops.2row.allsites.str +0 -0
- pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
- pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
- pgsui/impute/__init__.py +0 -0
- pgsui/impute/deterministic/imputers/allele_freq.py +725 -0
- pgsui/impute/deterministic/imputers/mode.py +844 -0
- pgsui/impute/deterministic/imputers/nmf.py +221 -0
- pgsui/impute/deterministic/imputers/phylo.py +973 -0
- pgsui/impute/deterministic/imputers/ref_allele.py +669 -0
- pgsui/impute/supervised/__init__.py +0 -0
- pgsui/impute/supervised/base.py +343 -0
- pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
- pgsui/impute/supervised/imputers/hist_gradient_boosting.py +317 -0
- pgsui/impute/supervised/imputers/random_forest.py +291 -0
- pgsui/impute/unsupervised/__init__.py +0 -0
- pgsui/impute/unsupervised/base.py +1118 -0
- pgsui/impute/unsupervised/callbacks.py +92 -262
- {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
- pgsui/impute/unsupervised/imputers/autoencoder.py +1285 -0
- pgsui/impute/unsupervised/imputers/nlpca.py +1554 -0
- pgsui/impute/unsupervised/imputers/ubp.py +1575 -0
- pgsui/impute/unsupervised/imputers/vae.py +1228 -0
- pgsui/impute/unsupervised/loss_functions.py +261 -0
- pgsui/impute/unsupervised/models/__init__.py +0 -0
- pgsui/impute/unsupervised/models/autoencoder_model.py +215 -567
- pgsui/impute/unsupervised/models/nlpca_model.py +155 -394
- pgsui/impute/unsupervised/models/ubp_model.py +180 -1106
- pgsui/impute/unsupervised/models/vae_model.py +269 -630
- pgsui/impute/unsupervised/nn_scorers.py +255 -0
- pgsui/utils/__init__.py +0 -0
- pgsui/utils/classification_viz.py +608 -0
- pgsui/utils/logging_utils.py +22 -0
- pgsui/utils/misc.py +35 -480
- pgsui/utils/plotting.py +996 -829
- pgsui/utils/pretty_metrics.py +290 -0
- pgsui/utils/scorers.py +213 -666
- pg_sui-0.2.3.dist-info/RECORD +0 -75
- pg_sui-0.2.3.dist-info/top_level.txt +0 -3
- pgsui/example_data/phylip_files/test_n10.phy +0 -118
- pgsui/example_data/phylip_files/test_n100.phy +0 -118
- pgsui/example_data/phylip_files/test_n2.phy +0 -118
- pgsui/example_data/phylip_files/test_n500.phy +0 -118
- pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
- pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
- pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
- pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
- pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
- pgsui/example_data/trees/test.iqtree +0 -376
- pgsui/example_data/trees/test.qmat +0 -5
- pgsui/example_data/trees/test.rate +0 -2033
- pgsui/example_data/trees/test.tre +0 -1
- pgsui/example_data/trees/test_n10.rate +0 -19
- pgsui/example_data/trees/test_n100.rate +0 -109
- pgsui/example_data/trees/test_n500.rate +0 -509
- pgsui/example_data/trees/test_siterates.txt +0 -2024
- pgsui/example_data/trees/test_siterates_n10.txt +0 -10
- pgsui/example_data/trees/test_siterates_n100.txt +0 -100
- pgsui/example_data/trees/test_siterates_n500.txt +0 -500
- pgsui/example_data/vcf_files/test.vcf +0 -244
- pgsui/example_data/vcf_files/test.vcf.gz +0 -0
- pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
- pgsui/impute/estimators.py +0 -1268
- pgsui/impute/impute.py +0 -1463
- pgsui/impute/simple_imputers.py +0 -1431
- pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -782
- pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1024
- pgsui/impute/unsupervised/keras_classifiers.py +0 -697
- pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
- pgsui/impute/unsupervised/neural_network_imputers.py +0 -1440
- pgsui/impute/unsupervised/neural_network_methods.py +0 -1395
- pgsui/pg_sui.py +0 -261
- pgsui/utils/sequence_tools.py +0 -407
- simulation/sim_benchmarks.py +0 -333
- simulation/sim_treeparams.py +0 -475
- test/__init__.py +0 -0
- test/pg_sui_simtest.py +0 -215
- test/pg_sui_testing.py +0 -523
- test/test.py +0 -151
- test/test_pgsui.py +0 -374
- test/test_tkc.py +0 -185
|
File without changes
|
|
@@ -0,0 +1,565 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import copy
|
|
4
|
+
import dataclasses
|
|
5
|
+
import logging
|
|
6
|
+
import os
|
|
7
|
+
import typing as t
|
|
8
|
+
from dataclasses import MISSING, asdict, fields, is_dataclass
|
|
9
|
+
from typing import Any, Dict, Literal, Type
|
|
10
|
+
|
|
11
|
+
import yaml
|
|
12
|
+
|
|
13
|
+
T = t.TypeVar("T")
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
"""
|
|
17
|
+
Config utilities for PG-SUI.
|
|
18
|
+
|
|
19
|
+
We keep nested configs as dataclasses at all times.
|
|
20
|
+
|
|
21
|
+
Public API:
|
|
22
|
+
- load_yaml_to_dataclass
|
|
23
|
+
- apply_dot_overrides
|
|
24
|
+
- dataclass_to_yaml
|
|
25
|
+
- save_dataclass_yaml
|
|
26
|
+
"""
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
# ---------------- Env var interpolation ----------------
|
|
30
|
+
def _interpolate_env(s: str) -> str:
|
|
31
|
+
"""Interpolate env vars in a string.
|
|
32
|
+
|
|
33
|
+
Syntax: ${VAR} or ${VAR:default}
|
|
34
|
+
|
|
35
|
+
Args:
|
|
36
|
+
s (str): Input string possibly containing env var patterns.
|
|
37
|
+
|
|
38
|
+
Returns:
|
|
39
|
+
str: The string with env vars interpolated.
|
|
40
|
+
"""
|
|
41
|
+
out, i = [], 0
|
|
42
|
+
while i < len(s):
|
|
43
|
+
if s[i : i + 2] == "${":
|
|
44
|
+
j = s.find("}", i + 2)
|
|
45
|
+
if j == -1:
|
|
46
|
+
out.append(s[i:])
|
|
47
|
+
break
|
|
48
|
+
token = s[i + 2 : j]
|
|
49
|
+
if ":" in token:
|
|
50
|
+
var, default = token.split(":", 1)
|
|
51
|
+
out.append(os.getenv(var, default))
|
|
52
|
+
else:
|
|
53
|
+
out.append(os.getenv(token, ""))
|
|
54
|
+
i = j + 1
|
|
55
|
+
else:
|
|
56
|
+
out.append(s[i])
|
|
57
|
+
i += 1
|
|
58
|
+
return "".join(out)
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
def _walk_env(obj: Any) -> Any:
|
|
62
|
+
"""Recursively interpolate env vars in strings within a nested structure.
|
|
63
|
+
|
|
64
|
+
This function traverses the input object and applies environment variable interpolation to any strings it encounters.
|
|
65
|
+
|
|
66
|
+
Args:
|
|
67
|
+
obj (Any): The input object, which can be a string, dict, list, or other types.
|
|
68
|
+
|
|
69
|
+
Returns:
|
|
70
|
+
Any: The object with environment variables interpolated in strings.
|
|
71
|
+
"""
|
|
72
|
+
if isinstance(obj, str):
|
|
73
|
+
return _interpolate_env(obj)
|
|
74
|
+
if isinstance(obj, dict):
|
|
75
|
+
return {k: _walk_env(v) for k, v in obj.items()}
|
|
76
|
+
if isinstance(obj, list):
|
|
77
|
+
return [_walk_env(v) for v in obj]
|
|
78
|
+
return obj
|
|
79
|
+
|
|
80
|
+
|
|
81
|
+
# ---------------- YAML helpers ----------------
|
|
82
|
+
def dataclass_to_yaml(dc: t.Any) -> str:
    """Serialize a dataclass instance to a YAML document string.

    The instance is first converted to a plain nested dict with
    :func:`dataclasses.asdict`, then dumped via ``yaml.safe_dump`` with
    field declaration order preserved (``sort_keys=False``).

    Args:
        dc (t.Any): A dataclass instance.

    Returns:
        str: The YAML representation of the dataclass.

    Raises:
        TypeError: If `dc` is a dataclass *type* or not a dataclass at all.
    """
    is_instance = is_dataclass(dc) and not isinstance(dc, type)
    if not is_instance:
        raise TypeError("dataclass_to_yaml expects a dataclass instance.")
    return yaml.safe_dump(asdict(dc), sort_keys=False)
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
def save_dataclass_yaml(dc: t.Any, path: str) -> None:
    """Write a dataclass instance to `path` as a UTF-8 YAML file.

    Serialization is delegated to :func:`dataclass_to_yaml`.

    Args:
        dc (t.Any): A dataclass instance.
        path (str): Destination file path for the YAML document.

    Raises:
        TypeError: If `dc` is a dataclass *type* or not a dataclass at all.
    """
    if isinstance(dc, type) or not is_dataclass(dc):
        raise TypeError("save_dataclass_yaml expects a dataclass instance.")

    text = dataclass_to_yaml(dc)
    with open(path, "w", encoding="utf-8") as handle:
        handle.write(text)
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def _merge_into_dataclass(inst: Any, payload: Dict[str, Any], path: str = "") -> Any:
|
|
121
|
+
"""Recursively merge a nested dict into a dataclass instance in place.
|
|
122
|
+
|
|
123
|
+
This function updates the fields of the dataclass instance with values from the nested mapping. It raises errors for unknown keys and ensures that nested dataclasses are merged recursively.
|
|
124
|
+
|
|
125
|
+
Args:
|
|
126
|
+
inst (Any): A dataclass instance to update.
|
|
127
|
+
payload (Dict[str, Any]): A nested mapping to merge into `inst`.
|
|
128
|
+
path (str): Internal use only; tracks the current path for error messages.
|
|
129
|
+
|
|
130
|
+
Returns:
|
|
131
|
+
Any: The updated dataclass instance (same as `inst`).
|
|
132
|
+
|
|
133
|
+
Raises:
|
|
134
|
+
TypeError: If `inst` is not a dataclass.
|
|
135
|
+
KeyError: If `payload` contains keys not present in `inst`.
|
|
136
|
+
"""
|
|
137
|
+
if not is_dataclass(inst):
|
|
138
|
+
raise TypeError(
|
|
139
|
+
f"_merge_into_dataclass expects a dataclass at '{path or '<root>'}'"
|
|
140
|
+
)
|
|
141
|
+
|
|
142
|
+
fld_map = {f.name: f for f in fields(inst)}
|
|
143
|
+
for k, v in payload.items():
|
|
144
|
+
if k not in fld_map:
|
|
145
|
+
full = f"{path + '.' if path else ''}{k}"
|
|
146
|
+
raise KeyError(f"Unknown key '{full}'")
|
|
147
|
+
cur = getattr(inst, k)
|
|
148
|
+
if is_dataclass(cur) and isinstance(v, dict):
|
|
149
|
+
_merge_into_dataclass(cur, v, path=(f"{path}.{k}" if path else k))
|
|
150
|
+
else:
|
|
151
|
+
setattr(inst, k, v)
|
|
152
|
+
return inst
|
|
153
|
+
|
|
154
|
+
|
|
155
|
+
def load_yaml_to_dataclass(
    path: str,
    dc_type: Type[T],
    *,
    base: T | None = None,
    overlays: Dict[str, Any] | None = None,
    yaml_preset_behavior: Literal["ignore", "error"] = "ignore",
) -> T:
    """Load a YAML file and merge into a dataclass instance with strict precedence.

    Precedence (lowest to highest): dataclass defaults < `base` (typically
    built from the CLI-selected preset) < YAML file contents < `overlays`.

    Notes:
        - `preset` is **CLI-only**. If the YAML contains `preset`, it will be
          ignored with a warning (default) or raise, per `yaml_preset_behavior`.
        - Environment variables in YAML string values (``${VAR}`` /
          ``${VAR:default}``) are interpolated before merging.

    Args:
        path (str): Path to the YAML file.
        dc_type (Type[T]): Dataclass type to construct if `base` is not provided.
        base (T | None): A preconstructed dataclass instance to start from
            (typically built from the CLI preset). A deep copy is merged into,
            so the caller's object is never mutated.
        overlays (Dict[str, Any] | None): A nested mapping applied **after**
            the YAML (e.g., derived CLI flags). These win over YAML values.
        yaml_preset_behavior (Literal["ignore","error"]): What to do if the
            YAML contains a `preset` key. Default: "ignore".

    Returns:
        T: The merged dataclass instance.

    Raises:
        TypeError: If `base` is not a dataclass, or YAML root isn't a mapping,
            or `overlays` isn't a mapping when provided.
        ValueError: If `yaml_preset_behavior="error"` and YAML contains `preset`.
        KeyError: If the YAML/overlays contain keys unknown to the dataclass.
    """
    # Empty file parses to None; treat that as an empty mapping.
    with open(path, "r", encoding="utf-8") as f:
        raw = yaml.safe_load(f) or {}
    # Interpolate ${VAR}/${VAR:default} in all string values before merging.
    raw = _walk_env(raw)

    # Enforce: 'preset' is CLI-only.
    if isinstance(raw, dict) and "preset" in raw:
        preset_in_yaml = raw.get("preset")
        if yaml_preset_behavior == "error":
            raise ValueError(
                f"YAML contains 'preset: {preset_in_yaml}'. "
                "The preset must be selected via the command line only."
            )
        # ignore (default): drop it and continue
        logging.warning(
            "Ignoring 'preset' in YAML (%r). Preset selection is CLI-only.",
            preset_in_yaml,
        )
        raw.pop("preset", None)

    # Start from `base` if given; else construct a fresh instance of dc_type.
    if base is not None:
        if not is_dataclass(base):
            raise TypeError("`base` must be a dataclass instance.")
        # Deep copy so the caller's preset instance is never mutated.
        cfg = copy.deepcopy(base)
    else:
        cfg = dc_type()  # defaults

    if not isinstance(raw, dict):
        raise TypeError(f"{path} did not parse as a mapping.")

    # YAML overlays the starting config
    _merge_into_dataclass(cfg, raw)

    # Optional final overlays (e.g., mapped CLI flags / --set already parsed)
    if overlays:
        if not isinstance(overlays, dict):
            raise TypeError("`overlays` must be a nested dict.")
        _merge_into_dataclass(cfg, overlays)

    return cfg
|
|
231
|
+
|
|
232
|
+
|
|
233
|
+
def _is_dataclass_type(tp: t.Any) -> bool:
|
|
234
|
+
"""Return True if tp is a dataclass type (not instance).
|
|
235
|
+
|
|
236
|
+
This function checks if the given type is a dataclass type by verifying its properties and using the `is_dataclass` function from the `dataclasses` module.
|
|
237
|
+
|
|
238
|
+
Args:
|
|
239
|
+
tp (t.Any): A type to check.
|
|
240
|
+
|
|
241
|
+
Returns:
|
|
242
|
+
bool: True if `tp` is a dataclass type, False otherwise.
|
|
243
|
+
"""
|
|
244
|
+
try:
|
|
245
|
+
return isinstance(tp, type) and dataclasses.is_dataclass(tp)
|
|
246
|
+
except Exception:
|
|
247
|
+
return False
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
def _unwrap_optional(tp: t.Any) -> t.Any:
|
|
251
|
+
"""If Optional[T] or Union[T, None], return T; else tp.
|
|
252
|
+
|
|
253
|
+
This function checks if the given type is an Optional or a Union that includes None, and if so, it returns the non-None type. Otherwise, it returns the original type.
|
|
254
|
+
|
|
255
|
+
Args:
|
|
256
|
+
tp (t.Any): A type annotation.
|
|
257
|
+
|
|
258
|
+
Returns:
|
|
259
|
+
t.Any: The unwrapped type, or the original type if not applicable.
|
|
260
|
+
"""
|
|
261
|
+
origin = t.get_origin(tp)
|
|
262
|
+
if origin is t.Union:
|
|
263
|
+
args = [a for a in t.get_args(tp) if a is not type(None)]
|
|
264
|
+
return args[0] if len(args) == 1 else tp
|
|
265
|
+
return tp
|
|
266
|
+
|
|
267
|
+
|
|
268
|
+
def _expected_field_type(dc_type: type, name: str) -> t.Any:
|
|
269
|
+
"""Fetch the annotated type of field `name` on dataclass type `dc_type`.
|
|
270
|
+
|
|
271
|
+
This function retrieves the type annotation for a specific field in a dataclass. If the field is not found, it raises a KeyError.
|
|
272
|
+
|
|
273
|
+
Args:
|
|
274
|
+
dc_type (type): A dataclass type.
|
|
275
|
+
name (str): The field name to look up.
|
|
276
|
+
|
|
277
|
+
Returns:
|
|
278
|
+
t.Any: The annotated type of the field.
|
|
279
|
+
|
|
280
|
+
Raises:
|
|
281
|
+
KeyError: If the field is unknown.
|
|
282
|
+
"""
|
|
283
|
+
for f in fields(dc_type):
|
|
284
|
+
if f.name == name:
|
|
285
|
+
hint = f.type
|
|
286
|
+
if isinstance(hint, str):
|
|
287
|
+
try:
|
|
288
|
+
resolved = t.get_type_hints(dc_type).get(name, hint)
|
|
289
|
+
hint = resolved
|
|
290
|
+
except Exception:
|
|
291
|
+
pass
|
|
292
|
+
return hint
|
|
293
|
+
raise KeyError(f"Unknown config key: '{name}' on {dc_type.__name__}")
|
|
294
|
+
|
|
295
|
+
|
|
296
|
+
def _instantiate_field(dc_type: t.Any, name: str):
|
|
297
|
+
"""Create a default instance for nested dataclass field `name`.
|
|
298
|
+
|
|
299
|
+
Attempts to use default_factory, then default, then type constructor. If none are available, raises KeyError.
|
|
300
|
+
|
|
301
|
+
Args:
|
|
302
|
+
dc_type (Any): A dataclass type.
|
|
303
|
+
name (str): The field name to instantiate.
|
|
304
|
+
|
|
305
|
+
Returns:
|
|
306
|
+
Any: An instance of the field's type.
|
|
307
|
+
|
|
308
|
+
Raises:
|
|
309
|
+
KeyError: If the field is unknown or cannot be instantiated.
|
|
310
|
+
"""
|
|
311
|
+
for f in fields(dc_type):
|
|
312
|
+
if f.name == name:
|
|
313
|
+
# Prefer default_factory → default → type()
|
|
314
|
+
if f.default_factory is not MISSING: # type: ignore[attr-defined]
|
|
315
|
+
return f.default_factory()
|
|
316
|
+
if f.default is not MISSING:
|
|
317
|
+
val = f.default
|
|
318
|
+
# If default is None but type is dataclass, construct it:
|
|
319
|
+
tp = _unwrap_optional(f.type)
|
|
320
|
+
if val is None and _is_dataclass_type(tp):
|
|
321
|
+
return tp()
|
|
322
|
+
return val
|
|
323
|
+
# No default supplied; if it's a dataclass type, construct it.
|
|
324
|
+
tp = _unwrap_optional(f.type)
|
|
325
|
+
if _is_dataclass_type(tp):
|
|
326
|
+
return tp()
|
|
327
|
+
# Otherwise we cannot guess safely:
|
|
328
|
+
raise KeyError(
|
|
329
|
+
f"Cannot create default for '{name}' on {dc_type.__name__}; "
|
|
330
|
+
"no default/default_factory and not a dataclass field."
|
|
331
|
+
)
|
|
332
|
+
raise KeyError(f"Unknown config key: '{name}' on {dc_type.__name__}'")
|
|
333
|
+
|
|
334
|
+
|
|
335
|
+
def _merge_mapping_into_dataclass(
    instance: T, payload: dict, *, path: str = "<root>"
) -> T:
    """Recursively merge a dict into a dataclass instance (strict on keys).

    Unlike `_merge_into_dataclass`, this variant consults the field's type
    annotation: nested dataclass fields are instantiated on demand (when
    currently None or not a dataclass) before recursing, and leaf values are
    passed through `_coerce_value` for light primitive/Literal coercion.

    Args:
        instance (T): A dataclass instance to update in place.
        payload (dict): A nested mapping to merge into `instance`.
        path (str): Internal use only; tracks the current path for error messages.

    Returns:
        T: The updated dataclass instance (same object as `instance`).

    Raises:
        TypeError: If `instance` is not a dataclass.
        KeyError: If `payload` contains keys not present in `instance`
            (raised by `_expected_field_type` or the attribute check below).
    """
    if not is_dataclass(instance):
        raise TypeError(f"Expected dataclass at {path}, got {type(instance)}")

    dc_type = type(instance)
    for k, v in payload.items():
        # Ensure field exists (raises KeyError for unknown keys).
        exp_type = _expected_field_type(dc_type, k)
        exp_core = _unwrap_optional(exp_type)

        cur = getattr(instance, k, MISSING)
        if cur is MISSING:
            # Declared field with no attribute value (e.g. uninitialized slot).
            raise KeyError(f"Unknown config key: '{path}.{k}'")

        if _is_dataclass_type(exp_core) and isinstance(v, dict):
            # Ensure we have a dataclass instance to merge into
            if cur is None or not is_dataclass(cur):
                cur = _instantiate_field(dc_type, k)
                setattr(instance, k, cur)
            merged = _merge_mapping_into_dataclass(cur, v, path=f"{path}.{k}")
            setattr(instance, k, merged)
        else:
            # Leaf assignment with light coercion toward the annotated type.
            setattr(
                instance,
                k,
                _coerce_value(v, exp_core, f"{path}.{k}", current=cur),
            )
    return instance
|
|
381
|
+
|
|
382
|
+
|
|
383
|
+
def _coerce_value(value: t.Any, tp: t.Any, where: str, *, current: t.Any = MISSING):
    """Lightweight coercion for common primitives and Literals.

    Attempts to coerce `value` toward the target annotation `tp`. Handles
    bool-from-string ("true"/"1"/"yes"/"on" and friends), int/float/str
    construction, and Literal membership validation. When the annotation is
    `Any`/`object`/None, the target type is inferred from the field's
    `current` value. Anything that cannot be coerced is returned unchanged
    rather than raising (best-effort by design).

    Args:
        value (t.Any): The input value to coerce.
        tp (t.Any): The target type annotation.
        where (str): Context string (dot-path) for error messages.
        current (t.Any): The field's current value; used to infer the target
            type for un-annotated fields and as a fallback for empty strings.

    Returns:
        t.Any: The coerced value, or the original if no coercion was applied.

    Raises:
        ValueError: If the value is not a member of a Literal[...] annotation.
    """
    # Captured from the ORIGINAL annotation — the Literal check below relies
    # on these even if `tp` is re-inferred from `current` just after.
    origin = t.get_origin(tp)
    args = t.get_args(tp)

    # Un-annotated / wildcard target: infer a concrete type from `current`
    # so YAML/CLI strings still coerce sensibly.
    if tp in {t.Any, object, None}:
        if current is not MISSING and current is not None:
            infer_type = type(current)
            # bool is checked before int because bool subclasses int.
            if isinstance(current, bool):
                tp = bool
            elif isinstance(current, int) and not isinstance(current, bool):
                tp = int
            elif isinstance(current, float):
                tp = float
            elif isinstance(current, str):
                tp = str
            else:
                tp = infer_type

    # Literal[...] → restrict values
    if origin is t.Literal:
        allowed = set(args)
        if value not in allowed:
            raise ValueError(
                f"Invalid value for {where}. Expected one of {sorted(allowed)}, got {value!r}."
            )
        return value

    # Basic primitives coercion
    if tp in (int, float, bool, str):
        if tp is bool:
            if isinstance(value, str):
                v = value.strip().lower()
                truthy = {"true", "1", "yes", "on"}
                falsy = {"false", "0", "no", "off"}
                if v in truthy:
                    return True
                if v in falsy:
                    return False
                # Empty string: keep the current value's truthiness.
                if v == "" and current is not MISSING:
                    return bool(current)
            return bool(value)

        if isinstance(value, str):
            stripped = value.strip()
            if stripped == "":
                # Empty string means "leave unchanged" when a current exists.
                return current if current is not MISSING else value
            try:
                return tp(stripped)
            except Exception:
                # Unparseable: hand back the raw string rather than raise.
                return value

        try:
            return tp(value)
        except Exception:
            return value

    # Dataclasses or other complex types → trust caller
    return value
|
|
457
|
+
|
|
458
|
+
|
|
459
|
+
def apply_dot_overrides(
    dc: t.Any,
    overrides: dict[str, t.Any] | None,
    *,
    root_cls: type | None = None,
    create_missing: bool = False,
    registry: dict[str, type] | None = None,
) -> t.Any:
    """Apply overrides like {'io.prefix': '...', 'train.batch_size': 64} to any Config dataclass.

    Each dot-key path is walked from the root; intermediate nodes must be
    (or be up-castable to) dataclasses, and the leaf value is assigned with
    light type coercion via `_coerce_value`.

    Args:
        dc (t.Any): A dataclass instance (or a dict that can be up-cast).
        overrides (dict[str, t.Any] | None): Mapping of dot-key paths to values.
        root_cls (type | None): Optional dataclass type to up-cast a root dict into (if `dc` is a dict).
        create_missing (bool): If True, instantiate missing intermediate dataclass nodes when the schema defines them.
        registry (dict[str, type] | None): Accepted for API compatibility but
            currently unused in this implementation — NOTE(review): confirm
            whether callers rely on it before removing.

    Returns:
        t.Any: The updated dataclass instance. A deep copy is made, so the
        input `dc` is never mutated and object identity is not preserved.

    Notes:
        - No hard-coding of NLPCAConfig. Pass `root_cls=NLPCAConfig` (or UBPConfig, etc.) when starting from a dict.
        - Dict payloads encountered at intermediate nodes are merged into the expected dataclass type using schema introspection.
        - Enforces unknown-key errors to keep configs honest.

    Raises:
        TypeError: If `dc` is not a dataclass or dict (for up-cast).
        KeyError: If any override path is invalid.
    """
    if not overrides:
        return dc

    # Root up-cast if needed
    if not is_dataclass(dc):
        if isinstance(dc, dict):
            if root_cls is None:
                raise TypeError(
                    "Root payload is a dict. Provide `root_cls` to up-cast it into the desired *Config dataclass."
                )
            base = root_cls()
            dc = _merge_mapping_into_dataclass(base, dc)
        else:
            raise TypeError(
                "apply_dot_overrides expects a dataclass instance or a dict for up-cast."
            )

    # Work on a copy so callers keep their original config untouched.
    updated = copy.deepcopy(dc)

    for dotkey, value in overrides.items():
        parts = dotkey.split(".")
        node = updated
        node_type = type(node)

        # Descend to parent
        for idx, seg in enumerate(parts[:-1]):
            if not is_dataclass(node):
                parent_path = ".".join(parts[:idx]) or "<root>"
                raise KeyError(
                    f"Target '{parent_path}' is not a dataclass in the override path; cannot descend into non-dataclass objects."
                )

            # Validate field existence and fetch expected type
            exp_type = _expected_field_type(node_type, seg)
            exp_core = _unwrap_optional(exp_type)

            # Materialize or up-cast if needed
            child = getattr(node, seg, MISSING)
            if child is MISSING:
                raise KeyError(f"Unknown config key: '{'.'.join(parts[:idx+1])}'")

            if isinstance(child, dict) and _is_dataclass_type(exp_core):
                # Up-cast dict → dataclass of the expected type
                child = _merge_mapping_into_dataclass(
                    exp_core(), child, path=".".join(parts[: idx + 1])
                )
                setattr(node, seg, child)

            # Optionally create a missing (None) intermediate dataclass node.
            if child is None and create_missing and _is_dataclass_type(exp_core):
                child = exp_core()
                setattr(node, seg, child)

            node = getattr(node, seg)
            node_type = type(node)

        # Assign leaf with light coercion
        if not is_dataclass(node):
            parent_path = ".".join(parts[:-1]) or "<root>"
            raise KeyError(
                f"Target '{parent_path}' is not a dataclass in the override path; cannot set '{parts[-1]}'."
            )

        leaf = parts[-1]

        # Check field exists and coerce to its annotated type
        exp_type = _expected_field_type(type(node), leaf)
        exp_core = _unwrap_optional(exp_type)

        if not hasattr(node, leaf):
            raise KeyError(f"Unknown config key: '{dotkey}'")

        current = getattr(node, leaf, MISSING)
        coerced = _coerce_value(value, exp_core, dotkey, current=current)
        setattr(node, leaf, coerced)

    return updated
|