pg-sui 1.0.2.1__py3-none-any.whl → 1.6.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of pg-sui might be problematic.

Files changed (112)
  1. {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info}/METADATA +51 -70
  2. pg_sui-1.6.8.dist-info/RECORD +78 -0
  3. {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info}/WHEEL +1 -1
  4. pg_sui-1.6.8.dist-info/entry_points.txt +4 -0
  5. pg_sui-1.6.8.dist-info/top_level.txt +1 -0
  6. pgsui/__init__.py +35 -54
  7. pgsui/_version.py +34 -0
  8. pgsui/cli.py +635 -0
  9. pgsui/data_processing/config.py +576 -0
  10. pgsui/data_processing/containers.py +1782 -0
  11. pgsui/data_processing/transformers.py +121 -1103
  12. pgsui/electron/app/__main__.py +5 -0
  13. pgsui/electron/app/icons/icons/1024x1024.png +0 -0
  14. pgsui/electron/app/icons/icons/128x128.png +0 -0
  15. pgsui/electron/app/icons/icons/16x16.png +0 -0
  16. pgsui/electron/app/icons/icons/24x24.png +0 -0
  17. pgsui/electron/app/icons/icons/256x256.png +0 -0
  18. pgsui/electron/app/icons/icons/32x32.png +0 -0
  19. pgsui/electron/app/icons/icons/48x48.png +0 -0
  20. pgsui/electron/app/icons/icons/512x512.png +0 -0
  21. pgsui/electron/app/icons/icons/64x64.png +0 -0
  22. pgsui/electron/app/icons/icons/icon.icns +0 -0
  23. pgsui/electron/app/icons/icons/icon.ico +0 -0
  24. pgsui/electron/app/main.js +189 -0
  25. pgsui/electron/app/package-lock.json +6893 -0
  26. pgsui/electron/app/package.json +50 -0
  27. pgsui/electron/app/preload.js +15 -0
  28. pgsui/electron/app/server.py +146 -0
  29. pgsui/electron/app/ui/logo.png +0 -0
  30. pgsui/electron/app/ui/renderer.js +130 -0
  31. pgsui/electron/app/ui/styles.css +59 -0
  32. pgsui/electron/app/ui/ui_shim.js +72 -0
  33. pgsui/electron/bootstrap.py +43 -0
  34. pgsui/electron/launch.py +59 -0
  35. pgsui/electron/package.json +14 -0
  36. pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
  37. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
  38. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
  39. pgsui/impute/deterministic/imputers/allele_freq.py +691 -0
  40. pgsui/impute/deterministic/imputers/mode.py +679 -0
  41. pgsui/impute/deterministic/imputers/nmf.py +221 -0
  42. pgsui/impute/deterministic/imputers/phylo.py +971 -0
  43. pgsui/impute/deterministic/imputers/ref_allele.py +530 -0
  44. pgsui/impute/supervised/base.py +339 -0
  45. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +293 -0
  46. pgsui/impute/supervised/imputers/random_forest.py +287 -0
  47. pgsui/impute/unsupervised/base.py +924 -0
  48. pgsui/impute/unsupervised/callbacks.py +89 -263
  49. pgsui/impute/unsupervised/imputers/autoencoder.py +972 -0
  50. pgsui/impute/unsupervised/imputers/nlpca.py +1264 -0
  51. pgsui/impute/unsupervised/imputers/ubp.py +1288 -0
  52. pgsui/impute/unsupervised/imputers/vae.py +957 -0
  53. pgsui/impute/unsupervised/loss_functions.py +158 -0
  54. pgsui/impute/unsupervised/models/autoencoder_model.py +208 -558
  55. pgsui/impute/unsupervised/models/nlpca_model.py +149 -468
  56. pgsui/impute/unsupervised/models/ubp_model.py +198 -1317
  57. pgsui/impute/unsupervised/models/vae_model.py +259 -618
  58. pgsui/impute/unsupervised/nn_scorers.py +215 -0
  59. pgsui/utils/classification_viz.py +591 -0
  60. pgsui/utils/misc.py +35 -480
  61. pgsui/utils/plotting.py +514 -824
  62. pgsui/utils/scorers.py +212 -438
  63. pg_sui-1.0.2.1.dist-info/RECORD +0 -75
  64. pg_sui-1.0.2.1.dist-info/top_level.txt +0 -3
  65. pgsui/example_data/phylip_files/test_n10.phy +0 -118
  66. pgsui/example_data/phylip_files/test_n100.phy +0 -118
  67. pgsui/example_data/phylip_files/test_n2.phy +0 -118
  68. pgsui/example_data/phylip_files/test_n500.phy +0 -118
  69. pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
  70. pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
  71. pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
  72. pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
  73. pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
  74. pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
  75. pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
  76. pgsui/example_data/trees/test.iqtree +0 -376
  77. pgsui/example_data/trees/test.qmat +0 -5
  78. pgsui/example_data/trees/test.rate +0 -2033
  79. pgsui/example_data/trees/test.tre +0 -1
  80. pgsui/example_data/trees/test_n10.rate +0 -19
  81. pgsui/example_data/trees/test_n100.rate +0 -109
  82. pgsui/example_data/trees/test_n500.rate +0 -509
  83. pgsui/example_data/trees/test_siterates.txt +0 -2024
  84. pgsui/example_data/trees/test_siterates_n10.txt +0 -10
  85. pgsui/example_data/trees/test_siterates_n100.txt +0 -100
  86. pgsui/example_data/trees/test_siterates_n500.txt +0 -500
  87. pgsui/example_data/vcf_files/test.vcf +0 -244
  88. pgsui/example_data/vcf_files/test.vcf.gz +0 -0
  89. pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
  90. pgsui/impute/estimators.py +0 -735
  91. pgsui/impute/impute.py +0 -1486
  92. pgsui/impute/simple_imputers.py +0 -1439
  93. pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -785
  94. pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1027
  95. pgsui/impute/unsupervised/keras_classifiers.py +0 -702
  96. pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
  97. pgsui/impute/unsupervised/neural_network_imputers.py +0 -1424
  98. pgsui/impute/unsupervised/neural_network_methods.py +0 -1549
  99. pgsui/pg_sui.py +0 -261
  100. pgsui/utils/sequence_tools.py +0 -407
  101. simulation/sim_benchmarks.py +0 -333
  102. simulation/sim_treeparams.py +0 -475
  103. test/__init__.py +0 -0
  104. test/pg_sui_simtest.py +0 -215
  105. test/pg_sui_testing.py +0 -523
  106. test/test.py +0 -297
  107. test/test_pgsui.py +0 -374
  108. test/test_tkc.py +0 -214
  109. {pg_sui-1.0.2.1.dist-info → pg_sui-1.6.8.dist-info/licenses}/LICENSE +0 -0
  110. /pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
  111. /pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
  112. {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
pgsui/data_processing/config.py (new file)
@@ -0,0 +1,576 @@
+ from __future__ import annotations
+
+ import copy
+ import dataclasses
+ import logging
+ import os
+ import typing as t
+ from dataclasses import MISSING, asdict, fields
+ from dataclasses import is_dataclass
+ from dataclasses import is_dataclass as _is_dc
+ from typing import Any, Callable, Dict, Literal, Type, TypeVar
+
+ import yaml
+
+ # type variable for dataclass types
+ T = t.TypeVar("T")
+ T = TypeVar("T")
+
+ """
+ Config utilities for PG-SUI.
+
+ We keep nested configs as dataclasses at all times.
+
+ Public API:
+ - load_yaml_to_dataclass
+ - apply_dot_overrides
+ - dataclass_to_yaml
+ - save_dataclass_yaml
+ """
+
+
+ # ---------------- Env var interpolation ----------------
+ def _interpolate_env(s: str) -> str:
+     """Interpolate env vars in a string.
+
+     Syntax: ${VAR} or ${VAR:default}
+
+     Args:
+         s (str): Input string possibly containing env var patterns.
+
+     Returns:
+         str: The string with env vars interpolated.
+     """
+     out, i = [], 0
+     while i < len(s):
+         if s[i : i + 2] == "${":
+             j = s.find("}", i + 2)
+             if j == -1:
+                 out.append(s[i:])
+                 break
+             token = s[i + 2 : j]
+             if ":" in token:
+                 var, default = token.split(":", 1)
+                 out.append(os.getenv(var, default))
+             else:
+                 out.append(os.getenv(token, ""))
+             i = j + 1
+         else:
+             out.append(s[i])
+             i += 1
+     return "".join(out)
+
+
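For reference, the `${VAR}` / `${VAR:default}` syntax accepted by `_interpolate_env` behaves as in the short sketch below (the environment variable names are illustrative; an unset variable with no default collapses to an empty string):

    import os
    from pgsui.data_processing.config import _interpolate_env  # module-private helper, imported only for illustration

    os.environ["PGSUI_OUT"] = "/scratch/pgsui"              # assumed for this example
    _interpolate_env("${PGSUI_OUT}/run_${RUN_ID:demo}")     # -> "/scratch/pgsui/run_demo" when RUN_ID is unset
    _interpolate_env("prefix_${MISSING_VAR}")               # -> "prefix_"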
+ def _walk_env(obj: Any) -> Any:
+     """Recursively interpolate env vars in strings within a nested structure.
+
+     This function traverses the input object and applies environment variable interpolation to any strings it encounters.
+
+     Args:
+         obj (Any): The input object, which can be a string, dict, list, or other types.
+
+     Returns:
+         Any: The object with environment variables interpolated in strings.
+     """
+     if isinstance(obj, str):
+         return _interpolate_env(obj)
+     if isinstance(obj, dict):
+         return {k: _walk_env(v) for k, v in obj.items()}
+     if isinstance(obj, list):
+         return [_walk_env(v) for v in obj]
+     return obj
+
+
+ # ---------------- YAML helpers ----------------
+ def dataclass_to_yaml(dc: T) -> str:
+     """Convert a dataclass instance to a YAML string.
+
+     This function uses the `asdict` function from the `dataclasses` module to convert the dataclass instance into a dictionary, which is then serialized to a YAML string using the `yaml` module.
+
+     Args:
+         dc (T): A dataclass instance.
+
+     Returns:
+         str: The YAML representation of the dataclass.
+
+     Raises:
+         TypeError: If `dc` is not a dataclass instance.
+     """
+     if not is_dataclass(dc):
+         raise TypeError("dataclass_to_yaml expects a dataclass instance.")
+     return yaml.safe_dump(asdict(dc), sort_keys=False)
+
+
+ def save_dataclass_yaml(dc: T, path: str) -> None:
+     """Save a dataclass instance as a YAML file.
+
+     This function uses the `dataclass_to_yaml` function to convert the dataclass instance into a YAML string, which is then written to a file.
+
+     Args:
+         dc (T): A dataclass instance.
+         path (str): Path to save the YAML file.
+
+     Raises:
+         TypeError: If `dc` is not a dataclass instance.
+     """
+     if not is_dataclass(dc):
+         raise TypeError(
+             "save_dataclass_yaml expects a dataclass or dataclass instance."
+         )
+
+     with open(path, "w", encoding="utf-8") as f:
+         f.write(dataclass_to_yaml(dc))
+
+
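A minimal round-trip sketch for the two YAML helpers above, using a hypothetical dataclass (the class and field names are illustrative, not part of pg-sui):

    from dataclasses import dataclass
    from pgsui.data_processing.config import dataclass_to_yaml, save_dataclass_yaml

    @dataclass
    class DemoIOConfig:      # hypothetical config section
        prefix: str = "run1"
        n_jobs: int = 1

    print(dataclass_to_yaml(DemoIOConfig()))             # "prefix: run1\nn_jobs: 1\n"
    save_dataclass_yaml(DemoIOConfig(), "demo_io.yaml")  # same text written to disk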
+ def _merge_into_dataclass(inst: Any, payload: Dict[str, Any], path: str = "") -> Any:
+     """Recursively merge a nested dict into a dataclass instance in place.
+
+     This function updates the fields of the dataclass instance with values from the nested mapping. It raises errors for unknown keys and ensures that nested dataclasses are merged recursively.
+
+     Args:
+         inst (Any): A dataclass instance to update.
+         payload (Dict[str, Any]): A nested mapping to merge into `inst`.
+         path (str): Internal use only; tracks the current path for error messages.
+
+     Returns:
+         Any: The updated dataclass instance (same as `inst`).
+
+     Raises:
+         TypeError: If `inst` is not a dataclass.
+         KeyError: If `payload` contains keys not present in `inst`.
+     """
+     if not _is_dc(inst):
+         raise TypeError(
+             f"_merge_into_dataclass expects a dataclass at '{path or '<root>'}'"
+         )
+
+     fld_map = {f.name: f for f in fields(inst)}
+     for k, v in payload.items():
+         if k not in fld_map:
+             full = f"{path + '.' if path else ''}{k}"
+             raise KeyError(f"Unknown key '{full}'")
+         cur = getattr(inst, k)
+         if _is_dc(cur) and isinstance(v, dict):
+             _merge_into_dataclass(cur, v, path=(f"{path}.{k}" if path else k))
+         else:
+             setattr(inst, k, v)
+     return inst
+
+
+ def load_yaml_to_dataclass(
+     path: str,
+     dc_type: Type[T],
+     *,
+     base: T | None = None,
+     overlays: Dict[str, Any] | None = None,
+     preset_builder: Callable[[str], T] | None = None,
+     yaml_preset_behavior: Literal["ignore", "error"] = "ignore",
+ ) -> T:
+     """Load a YAML file and merge into a dataclass instance with strict precedence.
+
+     This function is designed for the new argument hierarchy: defaults < CLI preset (build `base` from it) < YAML file < CLI args/--set
+
+     Notes:
+         - `preset` is **CLI-only**. If the YAML contains `preset`, it will be ignored (default) or cause an error depending on `yaml_preset_behavior`.
+         - Pass a `base` instance that is already constructed from the CLI-selected preset (e.g., `NLPCAConfig.from_preset(args.preset)`), and this function will overlay the YAML on top of it. Any additional `overlays` (a nested dict) are applied last.
+
+     Args:
+         path (str): Path to the YAML file.
+         dc_type (Type[T]): Dataclass type to construct if `base` is not provided.
+         base (T | None): A preconstructed dataclass instance to start from
+             (typically built from the CLI preset). If provided, it takes precedence
+             over any other starting point.
+         overlays (Dict[str, Any] | None): A nested mapping to apply **after** the
+             YAML (e.g., derived CLI flags). These win over YAML values.
+         preset_builder (Callable[[str], T] | None): Retained for backward
+             compatibility. Not used when enforcing CLI-only presets.
+         yaml_preset_behavior (Literal["ignore","error"]): What to do if the YAML
+             contains a `preset` key. Default: "ignore".
+
+     Returns:
+         T: The merged dataclass instance.
+
+     Raises:
+         TypeError: If `base` is not a dataclass, or YAML root isn't a mapping,
+             or `overlays` isn't a mapping when provided.
+         ValueError: If `yaml_preset_behavior="error"` and YAML contains `preset`.
+         KeyError: If any override path is invalid.
+     """
+     with open(path, "r", encoding="utf-8") as f:
+         raw = yaml.safe_load(f) or {}
+     raw = _walk_env(raw)
+
+     # Enforce: 'preset' is CLI-only.
+     if isinstance(raw, dict) and "preset" in raw:
+         preset_in_yaml = raw.get("preset")
+         if yaml_preset_behavior == "error":
+             raise ValueError(
+                 f"YAML contains 'preset: {preset_in_yaml}'. "
+                 "The preset must be selected via the command line only."
+             )
+         # ignore (default): drop it and continue
+         logging.warning(
+             "Ignoring 'preset' in YAML (%r). Preset selection is CLI-only.",
+             preset_in_yaml,
+         )
+         raw.pop("preset", None)
+
+     # Start from `base` if given; else construct a fresh instance of dc_type.
+     if base is not None:
+         if not is_dataclass(base):
+             raise TypeError("`base` must be a dataclass instance.")
+         cfg = copy.deepcopy(base)
+     else:
+         # Do NOT call preset_builder here; presets are CLI-only.
+         cfg = dc_type()  # defaults
+
+     if not isinstance(raw, dict):
+         raise TypeError(f"{path} did not parse as a mapping.")
+
+     # YAML overlays the starting config
+     _merge_into_dataclass(cfg, raw)
+
+     # Optional final overlays (e.g., mapped CLI flags / --set already parsed)
+     if overlays:
+         if not isinstance(overlays, dict):
+             raise TypeError("`overlays` must be a nested dict.")
+         _merge_into_dataclass(cfg, overlays)
+
+     return cfg
+
+
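A sketch of the precedence the docstring describes (defaults < preset-built `base` < YAML file < `overlays`), reusing the hypothetical `DemoIOConfig` from the previous sketch:

    from dataclasses import dataclass, field
    from pgsui.data_processing.config import load_yaml_to_dataclass

    @dataclass
    class DemoTrainConfig:   # hypothetical config section
        batch_size: int = 32
        lr: float = 1e-3

    @dataclass
    class DemoConfig:        # hypothetical root config
        io: DemoIOConfig = field(default_factory=DemoIOConfig)
        train: DemoTrainConfig = field(default_factory=DemoTrainConfig)

    with open("demo.yaml", "w", encoding="utf-8") as f:
        f.write("io:\n  prefix: ${RUN_NAME:demo}\n")    # env interpolation also applies to YAML values

    cfg = load_yaml_to_dataclass(
        "demo.yaml",
        DemoConfig,
        base=DemoConfig(train=DemoTrainConfig(batch_size=64)),  # e.g. built from a CLI preset
        overlays={"train": {"lr": 5e-4}},                        # e.g. parsed CLI flags; applied last
    )
    # cfg.io.prefix == "demo" (RUN_NAME unset), cfg.train.batch_size == 64, cfg.train.lr == 5e-4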
+ def _is_dataclass_type(tp: t.Any) -> bool:
+     """Return True if tp is a dataclass type (not instance).
+
+     This function checks if the given type is a dataclass type by verifying its properties and using the `is_dataclass` function from the `dataclasses` module.
+
+     Args:
+         tp (t.Any): A type to check.
+
+     Returns:
+         bool: True if `tp` is a dataclass type, False otherwise.
+     """
+     try:
+         return isinstance(tp, type) and dataclasses.is_dataclass(tp)
+     except Exception:
+         return False
+
+
+ def _unwrap_optional(tp: t.Any) -> t.Any:
+     """If Optional[T] or Union[T, None], return T; else tp.
+
+     This function checks if the given type is an Optional or a Union that includes None, and if so, it returns the non-None type. Otherwise, it returns the original type.
+
+     Args:
+         tp (t.Any): A type annotation.
+
+     Returns:
+         t.Any: The unwrapped type, or the original type if not applicable.
+     """
+     origin = t.get_origin(tp)
+     if origin is t.Union:
+         args = [a for a in t.get_args(tp) if a is not type(None)]
+         return args[0] if len(args) == 1 else tp
+     return tp
+
+
+ def _expected_field_type(dc_type: type, name: str) -> t.Any:
+     """Fetch the annotated type of field `name` on dataclass type `dc_type`.
+
+     This function retrieves the type annotation for a specific field in a dataclass. If the field is not found, it raises a KeyError.
+
+     Args:
+         dc_type (type): A dataclass type.
+         name (str): The field name to look up.
+
+     Returns:
+         t.Any: The annotated type of the field.
+
+     Raises:
+         KeyError: If the field is unknown.
+     """
+     for f in fields(dc_type):
+         if f.name == name:
+             hint = f.type
+             if isinstance(hint, str):
+                 try:
+                     resolved = t.get_type_hints(dc_type).get(name, hint)
+                     hint = resolved
+                 except Exception:
+                     pass
+             return hint
+     raise KeyError(f"Unknown config key: '{name}' on {dc_type.__name__}")
+
+
+ def _instantiate_field(dc_type: type, name: str):
+     """Create a default instance for nested dataclass field `name`.
+
+     Attempts to use default_factory, then default, then type constructor. If none are available, raises KeyError.
+
+     Args:
+         dc_type (type): A dataclass type.
+         name (str): The field name to instantiate.
+
+     Returns:
+         Any: An instance of the field's type.
+
+     Raises:
+         KeyError: If the field is unknown or cannot be instantiated.
+     """
+     for f in fields(dc_type):
+         if f.name == name:
+             # Prefer default_factory → default → type()
+             if f.default_factory is not MISSING:  # type: ignore[attr-defined]
+                 return f.default_factory()
+             if f.default is not MISSING:
+                 val = f.default
+                 # If default is None but type is dataclass, construct it:
+                 tp = _unwrap_optional(f.type)
+                 if val is None and _is_dataclass_type(tp):
+                     return tp()
+                 return val
+             # No default supplied; if it's a dataclass type, construct it.
+             tp = _unwrap_optional(f.type)
+             if _is_dataclass_type(tp):
+                 return tp()
+             # Otherwise we cannot guess safely:
+             raise KeyError(
+                 f"Cannot create default for '{name}' on {dc_type.__name__}; "
+                 "no default/default_factory and not a dataclass field."
+             )
+     raise KeyError(f"Unknown config key: '{name}' on {dc_type.__name__}'")
+
+
+ def _merge_mapping_into_dataclass(
+     instance: T, payload: dict, *, path: str = "<root>"
+ ) -> T:
+     """Recursively merge a dict into a dataclass instance (strict on keys).
+
+     This function updates the fields of a dataclass instance with values from a nested mapping (dict). It ensures that all keys in the mapping correspond to fields in the dataclass, and it handles nested dataclass fields as well.
+
+     Args:
+         instance (T): A dataclass instance to update.
+         payload (dict): A nested mapping to merge into `instance`.
+         path (str): Internal use only; tracks the current path for error messages.
+
+     Returns:
+         T: The updated dataclass instance (same as `instance`).
+
+     Raises:
+         TypeError: If `instance` is not a dataclass.
+         KeyError: If `payload` contains keys not present in `instance`.
+     """
+     if not is_dataclass(instance):
+         raise TypeError(f"Expected dataclass at {path}, got {type(instance)}")
+
+     dc_type = type(instance)
+     for k, v in payload.items():
+         # Ensure field exists
+         exp_type = _expected_field_type(dc_type, k)
+         exp_core = _unwrap_optional(exp_type)
+
+         cur = getattr(instance, k, MISSING)
+         if cur is MISSING:
+             raise KeyError(f"Unknown config key: '{path}.{k}'")
+
+         if _is_dataclass_type(exp_core) and isinstance(v, dict):
+             # Ensure we have a dataclass instance to merge into
+             if cur is None or not is_dataclass(cur):
+                 cur = _instantiate_field(dc_type, k)
+                 setattr(instance, k, cur)
+             merged = _merge_mapping_into_dataclass(cur, v, path=f"{path}.{k}")
+             setattr(instance, k, merged)
+         else:
+             setattr(
+                 instance,
+                 k,
+                 _coerce_value(v, exp_core, f"{path}.{k}", current=cur),
+             )
+     return instance
+
+
+ def _coerce_value(
+     value: t.Any, tp: t.Any, where: str, *, current: t.Any = MISSING
+ ):
+     """Lightweight coercion for common primitives and Literals.
+
+     This function attempts to coerce a value into a target type, handling common cases like basic primitives (int, float, bool, str) and Literal types. If coercion is not applicable or fails, it returns the original value.
+
+     Args:
+         value (t.Any): The input value to coerce.
+         tp (t.Any): The target type annotation.
+         where (str): Context string for error messages.
+
+     Returns:
+         t.Any: The coerced value, or the original if no coercion was applied.
+
+     Raises:
+         ValueError: If the value is not valid for a Literal type.
+         TypeError: If the value cannot be coerced to the target type.
+     """
+     origin = t.get_origin(tp)
+     args = t.get_args(tp)
+
+     if tp in {t.Any, object, None}:
+         if current is not MISSING and current is not None:
+             infer_type = type(current)
+             if isinstance(current, bool):
+                 tp = bool
+             elif isinstance(current, int) and not isinstance(current, bool):
+                 tp = int
+             elif isinstance(current, float):
+                 tp = float
+             elif isinstance(current, str):
+                 tp = str
+             else:
+                 tp = infer_type
+
+     # Literal[...] → restrict values
+     if origin is t.Literal:
+         allowed = set(args)
+         if value not in allowed:
+             raise ValueError(
+                 f"Invalid value for {where}. Expected one of {sorted(allowed)}, got {value!r}."
+             )
+         return value
+
+     # Basic primitives coercion
+     if tp in (int, float, bool, str):
+         if tp is bool:
+             if isinstance(value, str):
+                 v = value.strip().lower()
+                 truthy = {"true", "1", "yes", "on"}
+                 falsy = {"false", "0", "no", "off"}
+                 if v in truthy:
+                     return True
+                 if v in falsy:
+                     return False
+                 if v == "" and current is not MISSING:
+                     return bool(current)
+             return bool(value)
+
+         if isinstance(value, str):
+             stripped = value.strip()
+             if stripped == "":
+                 return current if current is not MISSING else value
+             try:
+                 return tp(stripped)
+             except Exception:
+                 return value
+
+         try:
+             return tp(value)
+         except Exception:
+             return value
+
+     # Dataclasses or other complex types → trust caller
+     return value
+
+
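The coercion helper above is what lets string-valued overrides (for example, values typed on a command line) land as the field's annotated type; a few illustrative calls, with arbitrary dot-key context strings:

    from typing import Literal
    from pgsui.data_processing.config import _coerce_value  # module-private helper, imported only for illustration

    _coerce_value("64", int, "train.batch_size")                         # -> 64
    _coerce_value("yes", bool, "train.early_stopping")                   # -> True
    _coerce_value("adam", Literal["adam", "sgd"], "train.optimizer")     # -> "adam"
    _coerce_value("rmsprop", Literal["adam", "sgd"], "train.optimizer")  # raises ValueError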
+ def apply_dot_overrides(
+     dc: t.Any,
+     overrides: dict[str, t.Any] | None,
+     *,
+     root_cls: type | None = None,
+     create_missing: bool = False,
+     registry: dict[str, type] | None = None,
+ ) -> t.Any:
+     """Apply overrides like {'io.prefix': '...', 'train.batch_size': 64} to any \*Config dataclass.
+
+     This function updates the fields of a dataclass instance with values from a nested mapping (dict). It ensures that all keys in the mapping correspond to fields in the dataclass, and it handles nested dataclass fields as well.
+
+     Args:
+         dc (t.Any): A dataclass instance (or a dict that can be up-cast).
+         overrides (dict[str, t.Any] | None): Mapping of dot-key paths to values.
+         root_cls (type | None): Optional dataclass type to up-cast a root dict into (if `dc` is a dict).
+         create_missing (bool): If True, instantiate missing intermediate dataclass nodes when the schema defines them.
+         registry (dict[str, type] | None): Optional mapping from top-level segment → dataclass type to assist up-casting.
+
+     Returns:
+         t.Any: The updated dataclass instance (same object identity is not guaranteed; a deep copy is made).
+
+     Notes:
+         - No hard-coding of NLPCAConfig. Pass `root_cls=NLPCAConfig` (or UBPConfig, etc.) when starting from a dict.
+         - Dict payloads encountered at intermediate nodes are merged into the expected dataclass type using schema introspection.
+         - Enforces unknown-key errors to keep configs honest.
+
+     Raises:
+         TypeError: If `dc` is not a dataclass or dict (for up-cast).
+         KeyError: If any override path is invalid.
+     """
+     if not overrides:
+         return dc
+
+     # Root up-cast if needed
+     if not is_dataclass(dc):
+         if isinstance(dc, dict):
+             if root_cls is None:
+                 raise TypeError(
+                     "Root payload is a dict. Provide `root_cls` to up-cast it into the desired *Config dataclass."
+                 )
+             base = root_cls()
+             dc = _merge_mapping_into_dataclass(base, dc)
+         else:
+             raise TypeError(
+                 "apply_dot_overrides expects a dataclass instance or a dict for up-cast."
+             )
+
+     updated = copy.deepcopy(dc)
+
+     for dotkey, value in overrides.items():
+         parts = dotkey.split(".")
+         node = updated
+         node_type = type(node)
+
+         # Descend to parent
+         for idx, seg in enumerate(parts[:-1]):
+             if not is_dataclass(node):
+                 parent_path = ".".join(parts[:idx]) or "<root>"
+                 raise KeyError(
+                     f"Target '{parent_path}' is not a dataclass in the override path; cannot descend into non-dataclass objects."
+                 )
+
+             # Validate field existence and fetch expected type
+             exp_type = _expected_field_type(node_type, seg)
+             exp_core = _unwrap_optional(exp_type)
+
+             # Materialize or up-cast if needed
+             child = getattr(node, seg, MISSING)
+             if child is MISSING:
+                 raise KeyError(f"Unknown config key: '{'.'.join(parts[:idx+1])}'")
+
+             if isinstance(child, dict) and _is_dataclass_type(exp_core):
+                 # Up-cast dict → dataclass of the expected type
+                 child = _merge_mapping_into_dataclass(
+                     exp_core(), child, path=".".join(parts[: idx + 1])
+                 )
+                 setattr(node, seg, child)
+
+             if child is None and create_missing and _is_dataclass_type(exp_core):
+                 child = exp_core()
+                 setattr(node, seg, child)
+
+             node = getattr(node, seg)
+             node_type = type(node)
+
+         # Assign leaf with light coercion
+         if not is_dataclass(node):
+             parent_path = ".".join(parts[:-1]) or "<root>"
+             raise KeyError(
+                 f"Target '{parent_path}' is not a dataclass in the override path; cannot set '{parts[-1]}'."
+             )
+
+         leaf = parts[-1]
+
+         # Check field exists and coerce to its annotated type
+         exp_type = _expected_field_type(type(node), leaf)
+         exp_core = _unwrap_optional(exp_type)
+
+         if not hasattr(node, leaf):
+             raise KeyError(f"Unknown config key: '{dotkey}'")
+
+         current = getattr(node, leaf, MISSING)
+         coerced = _coerce_value(value, exp_core, dotkey, current=current)
+         setattr(node, leaf, coerced)
+
+     return updated
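Finally, a dot-key override sketch against the hypothetical `DemoConfig` defined earlier; in pg-sui the root would be one of the real *Config dataclasses (for example `root_cls=NLPCAConfig` when starting from a plain dict, as the docstring notes):

    from pgsui.data_processing.config import apply_dot_overrides

    cfg = apply_dot_overrides(
        DemoConfig(),
        {"io.prefix": "exp2", "train.batch_size": "128"},  # CLI-style string values are coerced
    )
    # cfg.io.prefix == "exp2", cfg.train.batch_size == 128 (coerced to the annotated int)
    apply_dot_overrides(DemoConfig(), {"train.missing_key": 1})  # raises KeyError (unknown config key)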