pg-sui 0.2.3__py3-none-any.whl → 1.6.16a3__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.
Files changed (128)
  1. pg_sui-1.6.16a3.dist-info/METADATA +292 -0
  2. pg_sui-1.6.16a3.dist-info/RECORD +81 -0
  3. {pg_sui-0.2.3.dist-info → pg_sui-1.6.16a3.dist-info}/WHEEL +1 -1
  4. pg_sui-1.6.16a3.dist-info/entry_points.txt +4 -0
  5. {pg_sui-0.2.3.dist-info → pg_sui-1.6.16a3.dist-info/licenses}/LICENSE +0 -0
  6. pg_sui-1.6.16a3.dist-info/top_level.txt +1 -0
  7. pgsui/__init__.py +35 -54
  8. pgsui/_version.py +34 -0
  9. pgsui/cli.py +922 -0
  10. pgsui/data_processing/__init__.py +0 -0
  11. pgsui/data_processing/config.py +565 -0
  12. pgsui/data_processing/containers.py +1436 -0
  13. pgsui/data_processing/transformers.py +557 -907
  14. pgsui/{example_data/trees → electron/app}/__init__.py +0 -0
  15. pgsui/electron/app/__main__.py +5 -0
  16. pgsui/electron/app/extra-resources/.gitkeep +1 -0
  17. pgsui/electron/app/icons/icons/1024x1024.png +0 -0
  18. pgsui/electron/app/icons/icons/128x128.png +0 -0
  19. pgsui/electron/app/icons/icons/16x16.png +0 -0
  20. pgsui/electron/app/icons/icons/24x24.png +0 -0
  21. pgsui/electron/app/icons/icons/256x256.png +0 -0
  22. pgsui/electron/app/icons/icons/32x32.png +0 -0
  23. pgsui/electron/app/icons/icons/48x48.png +0 -0
  24. pgsui/electron/app/icons/icons/512x512.png +0 -0
  25. pgsui/electron/app/icons/icons/64x64.png +0 -0
  26. pgsui/electron/app/icons/icons/icon.icns +0 -0
  27. pgsui/electron/app/icons/icons/icon.ico +0 -0
  28. pgsui/electron/app/main.js +227 -0
  29. pgsui/electron/app/package-lock.json +6894 -0
  30. pgsui/electron/app/package.json +51 -0
  31. pgsui/electron/app/preload.js +15 -0
  32. pgsui/electron/app/server.py +157 -0
  33. pgsui/electron/app/ui/logo.png +0 -0
  34. pgsui/electron/app/ui/renderer.js +131 -0
  35. pgsui/electron/app/ui/styles.css +59 -0
  36. pgsui/electron/app/ui/ui_shim.js +72 -0
  37. pgsui/electron/bootstrap.py +43 -0
  38. pgsui/electron/launch.py +57 -0
  39. pgsui/electron/package.json +14 -0
  40. pgsui/example_data/__init__.py +0 -0
  41. pgsui/example_data/phylip_files/__init__.py +0 -0
  42. pgsui/example_data/phylip_files/test.phy +0 -0
  43. pgsui/example_data/popmaps/__init__.py +0 -0
  44. pgsui/example_data/popmaps/{test.popmap → phylogen_nomx.popmap} +185 -99
  45. pgsui/example_data/structure_files/__init__.py +0 -0
  46. pgsui/example_data/structure_files/test.pops.2row.allsites.str +0 -0
  47. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz +0 -0
  48. pgsui/example_data/vcf_files/phylogen_subset14K.vcf.gz.tbi +0 -0
  49. pgsui/impute/__init__.py +0 -0
  50. pgsui/impute/deterministic/imputers/allele_freq.py +725 -0
  51. pgsui/impute/deterministic/imputers/mode.py +844 -0
  52. pgsui/impute/deterministic/imputers/nmf.py +221 -0
  53. pgsui/impute/deterministic/imputers/phylo.py +973 -0
  54. pgsui/impute/deterministic/imputers/ref_allele.py +669 -0
  55. pgsui/impute/supervised/__init__.py +0 -0
  56. pgsui/impute/supervised/base.py +343 -0
  57. pgsui/impute/{unsupervised/models/in_development → supervised/imputers}/__init__.py +0 -0
  58. pgsui/impute/supervised/imputers/hist_gradient_boosting.py +317 -0
  59. pgsui/impute/supervised/imputers/random_forest.py +291 -0
  60. pgsui/impute/unsupervised/__init__.py +0 -0
  61. pgsui/impute/unsupervised/base.py +1121 -0
  62. pgsui/impute/unsupervised/callbacks.py +92 -262
  63. {simulation → pgsui/impute/unsupervised/imputers}/__init__.py +0 -0
  64. pgsui/impute/unsupervised/imputers/autoencoder.py +1361 -0
  65. pgsui/impute/unsupervised/imputers/nlpca.py +1666 -0
  66. pgsui/impute/unsupervised/imputers/ubp.py +1660 -0
  67. pgsui/impute/unsupervised/imputers/vae.py +1316 -0
  68. pgsui/impute/unsupervised/loss_functions.py +261 -0
  69. pgsui/impute/unsupervised/models/__init__.py +0 -0
  70. pgsui/impute/unsupervised/models/autoencoder_model.py +215 -567
  71. pgsui/impute/unsupervised/models/nlpca_model.py +155 -394
  72. pgsui/impute/unsupervised/models/ubp_model.py +180 -1106
  73. pgsui/impute/unsupervised/models/vae_model.py +269 -630
  74. pgsui/impute/unsupervised/nn_scorers.py +255 -0
  75. pgsui/utils/__init__.py +0 -0
  76. pgsui/utils/classification_viz.py +608 -0
  77. pgsui/utils/logging_utils.py +22 -0
  78. pgsui/utils/misc.py +35 -480
  79. pgsui/utils/plotting.py +996 -829
  80. pgsui/utils/pretty_metrics.py +290 -0
  81. pgsui/utils/scorers.py +213 -666
  82. pg_sui-0.2.3.dist-info/METADATA +0 -322
  83. pg_sui-0.2.3.dist-info/RECORD +0 -75
  84. pg_sui-0.2.3.dist-info/top_level.txt +0 -3
  85. pgsui/example_data/phylip_files/test_n10.phy +0 -118
  86. pgsui/example_data/phylip_files/test_n100.phy +0 -118
  87. pgsui/example_data/phylip_files/test_n2.phy +0 -118
  88. pgsui/example_data/phylip_files/test_n500.phy +0 -118
  89. pgsui/example_data/structure_files/test.nopops.1row.10sites.str +0 -117
  90. pgsui/example_data/structure_files/test.nopops.2row.100sites.str +0 -234
  91. pgsui/example_data/structure_files/test.nopops.2row.10sites.str +0 -234
  92. pgsui/example_data/structure_files/test.nopops.2row.30sites.str +0 -234
  93. pgsui/example_data/structure_files/test.nopops.2row.allsites.str +0 -234
  94. pgsui/example_data/structure_files/test.pops.1row.10sites.str +0 -117
  95. pgsui/example_data/structure_files/test.pops.2row.10sites.str +0 -234
  96. pgsui/example_data/trees/test.iqtree +0 -376
  97. pgsui/example_data/trees/test.qmat +0 -5
  98. pgsui/example_data/trees/test.rate +0 -2033
  99. pgsui/example_data/trees/test.tre +0 -1
  100. pgsui/example_data/trees/test_n10.rate +0 -19
  101. pgsui/example_data/trees/test_n100.rate +0 -109
  102. pgsui/example_data/trees/test_n500.rate +0 -509
  103. pgsui/example_data/trees/test_siterates.txt +0 -2024
  104. pgsui/example_data/trees/test_siterates_n10.txt +0 -10
  105. pgsui/example_data/trees/test_siterates_n100.txt +0 -100
  106. pgsui/example_data/trees/test_siterates_n500.txt +0 -500
  107. pgsui/example_data/vcf_files/test.vcf +0 -244
  108. pgsui/example_data/vcf_files/test.vcf.gz +0 -0
  109. pgsui/example_data/vcf_files/test.vcf.gz.tbi +0 -0
  110. pgsui/impute/estimators.py +0 -1268
  111. pgsui/impute/impute.py +0 -1463
  112. pgsui/impute/simple_imputers.py +0 -1431
  113. pgsui/impute/supervised/iterative_imputer_fixedparams.py +0 -782
  114. pgsui/impute/supervised/iterative_imputer_gridsearch.py +0 -1024
  115. pgsui/impute/unsupervised/keras_classifiers.py +0 -697
  116. pgsui/impute/unsupervised/models/in_development/cnn_model.py +0 -486
  117. pgsui/impute/unsupervised/neural_network_imputers.py +0 -1440
  118. pgsui/impute/unsupervised/neural_network_methods.py +0 -1395
  119. pgsui/pg_sui.py +0 -261
  120. pgsui/utils/sequence_tools.py +0 -407
  121. simulation/sim_benchmarks.py +0 -333
  122. simulation/sim_treeparams.py +0 -475
  123. test/__init__.py +0 -0
  124. test/pg_sui_simtest.py +0 -215
  125. test/pg_sui_testing.py +0 -523
  126. test/test.py +0 -151
  127. test/test_pgsui.py +0 -374
  128. test/test_tkc.py +0 -185
pgsui/data_processing/config.py
@@ -0,0 +1,565 @@
+ from __future__ import annotations
+
+ import copy
+ import dataclasses
+ import logging
+ import os
+ import typing as t
+ from dataclasses import MISSING, asdict, fields, is_dataclass
+ from typing import Any, Dict, Literal, Type
+
+ import yaml
+
+ T = t.TypeVar("T")
+
+
+ """
+ Config utilities for PG-SUI.
+
+ We keep nested configs as dataclasses at all times.
+
+ Public API:
+ - load_yaml_to_dataclass
+ - apply_dot_overrides
+ - dataclass_to_yaml
+ - save_dataclass_yaml
+ """
+
+
+ # ---------------- Env var interpolation ----------------
+ def _interpolate_env(s: str) -> str:
+     """Interpolate env vars in a string.
+
+     Syntax: ${VAR} or ${VAR:default}
+
+     Args:
+         s (str): Input string possibly containing env var patterns.
+
+     Returns:
+         str: The string with env vars interpolated.
+     """
+     out, i = [], 0
+     while i < len(s):
+         if s[i : i + 2] == "${":
+             j = s.find("}", i + 2)
+             if j == -1:
+                 out.append(s[i:])
+                 break
+             token = s[i + 2 : j]
+             if ":" in token:
+                 var, default = token.split(":", 1)
+                 out.append(os.getenv(var, default))
+             else:
+                 out.append(os.getenv(token, ""))
+             i = j + 1
+         else:
+             out.append(s[i])
+             i += 1
+     return "".join(out)
+
+
+ def _walk_env(obj: Any) -> Any:
+     """Recursively interpolate env vars in strings within a nested structure.
+
+     This function traverses the input object and applies environment variable interpolation to any strings it encounters.
+
+     Args:
+         obj (Any): The input object, which can be a string, dict, list, or other types.
+
+     Returns:
+         Any: The object with environment variables interpolated in strings.
+     """
+     if isinstance(obj, str):
+         return _interpolate_env(obj)
+     if isinstance(obj, dict):
+         return {k: _walk_env(v) for k, v in obj.items()}
+     if isinstance(obj, list):
+         return [_walk_env(v) for v in obj]
+     return obj
+
+
+ # ---------------- YAML helpers ----------------
+ def dataclass_to_yaml(dc: t.Any) -> str:
+     """Convert a dataclass instance to a YAML string.
+
+     This function uses the `asdict` function from the `dataclasses` module to convert the dataclass instance into a dictionary, which is then serialized to a YAML string using the `yaml` module.
+
+     Args:
+         dc (t.Any): A dataclass instance.
+
+     Returns:
+         str: The YAML representation of the dataclass.
+
+     Raises:
+         TypeError: If `dc` is not a dataclass instance.
+     """
+     if not is_dataclass(dc) or isinstance(dc, type):
+         raise TypeError("dataclass_to_yaml expects a dataclass instance.")
+     return yaml.safe_dump(asdict(dc), sort_keys=False)
+
+
+ def save_dataclass_yaml(dc: t.Any, path: str) -> None:
+     """Save a dataclass instance as a YAML file.
+
+     This function uses the `dataclass_to_yaml` function to convert the dataclass instance into a YAML string, which is then written to a file.
+
+     Args:
+         dc (T): A dataclass instance.
+         path (str): Path to save the YAML file.
+
+     Raises:
+         TypeError: If `dc` is not a dataclass instance.
+     """
+     if not is_dataclass(dc) or isinstance(dc, type):
+         raise TypeError("save_dataclass_yaml expects a dataclass instance.")
+
+     with open(path, "w", encoding="utf-8") as f:
+         f.write(dataclass_to_yaml(dc))
+
+
+ def _merge_into_dataclass(inst: Any, payload: Dict[str, Any], path: str = "") -> Any:
+     """Recursively merge a nested dict into a dataclass instance in place.
+
+     This function updates the fields of the dataclass instance with values from the nested mapping. It raises errors for unknown keys and ensures that nested dataclasses are merged recursively.
+
+     Args:
+         inst (Any): A dataclass instance to update.
+         payload (Dict[str, Any]): A nested mapping to merge into `inst`.
+         path (str): Internal use only; tracks the current path for error messages.
+
+     Returns:
+         Any: The updated dataclass instance (same as `inst`).
+
+     Raises:
+         TypeError: If `inst` is not a dataclass.
+         KeyError: If `payload` contains keys not present in `inst`.
+     """
+     if not is_dataclass(inst):
+         raise TypeError(
+             f"_merge_into_dataclass expects a dataclass at '{path or '<root>'}'"
+         )
+
+     fld_map = {f.name: f for f in fields(inst)}
+     for k, v in payload.items():
+         if k not in fld_map:
+             full = f"{path + '.' if path else ''}{k}"
+             raise KeyError(f"Unknown key '{full}'")
+         cur = getattr(inst, k)
+         if is_dataclass(cur) and isinstance(v, dict):
+             _merge_into_dataclass(cur, v, path=(f"{path}.{k}" if path else k))
+         else:
+             setattr(inst, k, v)
+     return inst
+
+
+ def load_yaml_to_dataclass(
+     path: str,
+     dc_type: Type[T],
+     *,
+     base: T | None = None,
+     overlays: Dict[str, Any] | None = None,
+     yaml_preset_behavior: Literal["ignore", "error"] = "ignore",
+ ) -> T:
+     """Load a YAML file and merge into a dataclass instance with strict precedence.
+
+     This function is designed for the new argument hierarchy: defaults < CLI preset (build `base` from it) < YAML file < CLI args/--set
+
+     Notes:
+         - `preset` is **CLI-only**. If the YAML contains `preset`, it will be ignored (default) or cause an error depending on `yaml_preset_behavior`.
+         - Pass a `base` instance that is already constructed from the CLI-selected preset (e.g., `NLPCAConfig.from_preset(args.preset)`), and this function will overlay the YAML on top of it. Any additional `overlays` (a nested dict) are applied last.
+
+     Args:
+         path (str): Path to the YAML file.
+         dc_type (Type[T]): Dataclass type to construct if `base` is not provided.
+         base (T | None): A preconstructed dataclass instance to start from
+             (typically built from the CLI preset). If provided, it takes precedence
+             over any other starting point.
+         overlays (Dict[str, Any] | None): A nested mapping to apply **after** the
+             YAML (e.g., derived CLI flags). These win over YAML values.
+         yaml_preset_behavior (Literal["ignore","error"]): What to do if the YAML
+             contains a `preset` key. Default: "ignore".
+
+     Returns:
+         Type[T]: The merged dataclass instance.
+
+     Raises:
+         TypeError: If `base` is not a dataclass, or YAML root isn't a mapping,
+             or `overlays` isn't a mapping when provided.
+         ValueError: If `yaml_preset_behavior="error"` and YAML contains `preset`.
+         KeyError: If any override path is invalid.
+     """
+     with open(path, "r", encoding="utf-8") as f:
+         raw = yaml.safe_load(f) or {}
+     raw = _walk_env(raw)
+
+     # Enforce: 'preset' is CLI-only.
+     if isinstance(raw, dict) and "preset" in raw:
+         preset_in_yaml = raw.get("preset")
+         if yaml_preset_behavior == "error":
+             raise ValueError(
+                 f"YAML contains 'preset: {preset_in_yaml}'. "
+                 "The preset must be selected via the command line only."
+             )
+         # ignore (default): drop it and continue
+         logging.warning(
+             "Ignoring 'preset' in YAML (%r). Preset selection is CLI-only.",
+             preset_in_yaml,
+         )
+         raw.pop("preset", None)
+
+     # Start from `base` if given; else construct a fresh instance of dc_type.
+     if base is not None:
+         if not is_dataclass(base):
+             raise TypeError("`base` must be a dataclass instance.")
+         cfg = copy.deepcopy(base)
+     else:
+         cfg = dc_type()  # defaults
+
+     if not isinstance(raw, dict):
+         raise TypeError(f"{path} did not parse as a mapping.")
+
+     # YAML overlays the starting config
+     _merge_into_dataclass(cfg, raw)
+
+     # Optional final overlays (e.g., mapped CLI flags / --set already parsed)
+     if overlays:
+         if not isinstance(overlays, dict):
+             raise TypeError("`overlays` must be a nested dict.")
+         _merge_into_dataclass(cfg, overlays)
+
+     return cfg
+
+
+ def _is_dataclass_type(tp: t.Any) -> bool:
+     """Return True if tp is a dataclass type (not instance).
+
+     This function checks if the given type is a dataclass type by verifying its properties and using the `is_dataclass` function from the `dataclasses` module.
+
+     Args:
+         tp (t.Any): A type to check.
+
+     Returns:
+         bool: True if `tp` is a dataclass type, False otherwise.
+     """
+     try:
+         return isinstance(tp, type) and dataclasses.is_dataclass(tp)
+     except Exception:
+         return False
+
+
+ def _unwrap_optional(tp: t.Any) -> t.Any:
+     """If Optional[T] or Union[T, None], return T; else tp.
+
+     This function checks if the given type is an Optional or a Union that includes None, and if so, it returns the non-None type. Otherwise, it returns the original type.
+
+     Args:
+         tp (t.Any): A type annotation.
+
+     Returns:
+         t.Any: The unwrapped type, or the original type if not applicable.
+     """
+     origin = t.get_origin(tp)
+     if origin is t.Union:
+         args = [a for a in t.get_args(tp) if a is not type(None)]
+         return args[0] if len(args) == 1 else tp
+     return tp
+
+
+ def _expected_field_type(dc_type: type, name: str) -> t.Any:
+     """Fetch the annotated type of field `name` on dataclass type `dc_type`.
+
+     This function retrieves the type annotation for a specific field in a dataclass. If the field is not found, it raises a KeyError.
+
+     Args:
+         dc_type (type): A dataclass type.
+         name (str): The field name to look up.
+
+     Returns:
+         t.Any: The annotated type of the field.
+
+     Raises:
+         KeyError: If the field is unknown.
+     """
+     for f in fields(dc_type):
+         if f.name == name:
+             hint = f.type
+             if isinstance(hint, str):
+                 try:
+                     resolved = t.get_type_hints(dc_type).get(name, hint)
+                     hint = resolved
+                 except Exception:
+                     pass
+             return hint
+     raise KeyError(f"Unknown config key: '{name}' on {dc_type.__name__}")
+
+
+ def _instantiate_field(dc_type: t.Any, name: str):
+     """Create a default instance for nested dataclass field `name`.
+
+     Attempts to use default_factory, then default, then type constructor. If none are available, raises KeyError.
+
+     Args:
+         dc_type (Any): A dataclass type.
+         name (str): The field name to instantiate.
+
+     Returns:
+         Any: An instance of the field's type.
+
+     Raises:
+         KeyError: If the field is unknown or cannot be instantiated.
+     """
+     for f in fields(dc_type):
+         if f.name == name:
+             # Prefer default_factory → default → type()
+             if f.default_factory is not MISSING:  # type: ignore[attr-defined]
+                 return f.default_factory()
+             if f.default is not MISSING:
+                 val = f.default
+                 # If default is None but type is dataclass, construct it:
+                 tp = _unwrap_optional(f.type)
+                 if val is None and _is_dataclass_type(tp):
+                     return tp()
+                 return val
+             # No default supplied; if it's a dataclass type, construct it.
+             tp = _unwrap_optional(f.type)
+             if _is_dataclass_type(tp):
+                 return tp()
+             # Otherwise we cannot guess safely:
+             raise KeyError(
+                 f"Cannot create default for '{name}' on {dc_type.__name__}; "
+                 "no default/default_factory and not a dataclass field."
+             )
+     raise KeyError(f"Unknown config key: '{name}' on {dc_type.__name__}")
+
+
+ def _merge_mapping_into_dataclass(
+     instance: T, payload: dict, *, path: str = "<root>"
+ ) -> T:
+     """Recursively merge a dict into a dataclass instance (strict on keys).
+
+     This function updates the fields of a dataclass instance with values from a nested mapping (dict). It ensures that all keys in the mapping correspond to fields in the dataclass, and it handles nested dataclass fields as well.
+
+     Args:
+         instance (T): A dataclass instance to update.
+         payload (dict): A nested mapping to merge into `instance`.
+         path (str): Internal use only; tracks the current path for error messages.
+
+     Returns:
+         Type[T]: The updated dataclass instance (same as `instance`).
+
+     Raises:
+         TypeError: If `instance` is not a dataclass.
+         KeyError: If `payload` contains keys not present in `instance`.
+     """
+     if not is_dataclass(instance):
+         raise TypeError(f"Expected dataclass at {path}, got {type(instance)}")
+
+     dc_type = type(instance)
+     for k, v in payload.items():
+         # Ensure field exists
+         exp_type = _expected_field_type(dc_type, k)
+         exp_core = _unwrap_optional(exp_type)
+
+         cur = getattr(instance, k, MISSING)
+         if cur is MISSING:
+             raise KeyError(f"Unknown config key: '{path}.{k}'")
+
+         if _is_dataclass_type(exp_core) and isinstance(v, dict):
+             # Ensure we have a dataclass instance to merge into
+             if cur is None or not is_dataclass(cur):
+                 cur = _instantiate_field(dc_type, k)
+                 setattr(instance, k, cur)
+             merged = _merge_mapping_into_dataclass(cur, v, path=f"{path}.{k}")
+             setattr(instance, k, merged)
+         else:
+             setattr(
+                 instance,
+                 k,
+                 _coerce_value(v, exp_core, f"{path}.{k}", current=cur),
+             )
+     return instance
+
+
+ def _coerce_value(value: t.Any, tp: t.Any, where: str, *, current: t.Any = MISSING):
+     """Lightweight coercion for common primitives and Literals.
+
+     This function attempts to coerce a value into a target type, handling common cases like basic primitives (int, float, bool, str) and Literal types. If coercion is not applicable or fails, it returns the original value.
+
+     Args:
+         value (t.Any): The input value to coerce.
+         tp (t.Any): The target type annotation.
+         where (str): Context string for error messages.
+
+     Returns:
+         t.Any: The coerced value, or the original if no coercion was applied.
+
+     Raises:
+         ValueError: If the value is not valid for a Literal type.
+         TypeError: If the value cannot be coerced to the target type.
+     """
+     origin = t.get_origin(tp)
+     args = t.get_args(tp)
+
+     if tp in {t.Any, object, None}:
+         if current is not MISSING and current is not None:
+             infer_type = type(current)
+             if isinstance(current, bool):
+                 tp = bool
+             elif isinstance(current, int) and not isinstance(current, bool):
+                 tp = int
+             elif isinstance(current, float):
+                 tp = float
+             elif isinstance(current, str):
+                 tp = str
+             else:
+                 tp = infer_type
+
+     # Literal[...] → restrict values
+     if origin is t.Literal:
+         allowed = set(args)
+         if value not in allowed:
+             raise ValueError(
+                 f"Invalid value for {where}. Expected one of {sorted(allowed)}, got {value!r}."
+             )
+         return value
+
+     # Basic primitives coercion
+     if tp in (int, float, bool, str):
+         if tp is bool:
+             if isinstance(value, str):
+                 v = value.strip().lower()
+                 truthy = {"true", "1", "yes", "on"}
+                 falsy = {"false", "0", "no", "off"}
+                 if v in truthy:
+                     return True
+                 if v in falsy:
+                     return False
+                 if v == "" and current is not MISSING:
+                     return bool(current)
+             return bool(value)
+
+         if isinstance(value, str):
+             stripped = value.strip()
+             if stripped == "":
+                 return current if current is not MISSING else value
+             try:
+                 return tp(stripped)
+             except Exception:
+                 return value
+
+         try:
+             return tp(value)
+         except Exception:
+             return value
+
+     # Dataclasses or other complex types → trust caller
+     return value
+
+
+ def apply_dot_overrides(
+     dc: t.Any,
+     overrides: dict[str, t.Any] | None,
+     *,
+     root_cls: type | None = None,
+     create_missing: bool = False,
+     registry: dict[str, type] | None = None,
+ ) -> t.Any:
+     """Apply overrides like {'io.prefix': '...', 'train.batch_size': 64} to any Config dataclass.
+
+     This function updates the fields of a dataclass instance with values from a nested mapping (dict). It ensures that all keys in the mapping correspond to fields in the dataclass, and it handles nested dataclass fields as well.
+
+     Args:
+         dc (t.Any): A dataclass instance (or a dict that can be up-cast).
+         overrides (dict[str, t.Any] | None): Mapping of dot-key paths to values.
+         root_cls (type | None): Optional dataclass type to up-cast a root dict into (if `dc` is a dict).
+         create_missing (bool): If True, instantiate missing intermediate dataclass nodes when the schema defines them.
+         registry (dict[str, type] | None): Optional mapping from top-level segment → dataclass type to assist up-casting.
+
+     Returns:
+         t.Any: The updated dataclass instance (same object identity is not guaranteed; a deep copy is made).
+
+     Notes:
+         - No hard-coding of NLPCAConfig. Pass `root_cls=NLPCAConfig` (or UBPConfig, etc.) when starting from a dict.
+         - Dict payloads encountered at intermediate nodes are merged into the expected dataclass type using schema introspection.
+         - Enforces unknown-key errors to keep configs honest.
+
+     Raises:
+         TypeError: If `dc` is not a dataclass or dict (for up-cast).
+         KeyError: If any override path is invalid.
+     """
+     if not overrides:
+         return dc
+
+     # Root up-cast if needed
+     if not is_dataclass(dc):
+         if isinstance(dc, dict):
+             if root_cls is None:
+                 raise TypeError(
+                     "Root payload is a dict. Provide `root_cls` to up-cast it into the desired *Config dataclass."
+                 )
+             base = root_cls()
+             dc = _merge_mapping_into_dataclass(base, dc)
+         else:
+             raise TypeError(
+                 "apply_dot_overrides expects a dataclass instance or a dict for up-cast."
+             )
+
+     updated = copy.deepcopy(dc)
+
+     for dotkey, value in overrides.items():
+         parts = dotkey.split(".")
+         node = updated
+         node_type = type(node)
+
+         # Descend to parent
+         for idx, seg in enumerate(parts[:-1]):
+             if not is_dataclass(node):
+                 parent_path = ".".join(parts[:idx]) or "<root>"
+                 raise KeyError(
+                     f"Target '{parent_path}' is not a dataclass in the override path; cannot descend into non-dataclass objects."
+                 )
+
+             # Validate field existence and fetch expected type
+             exp_type = _expected_field_type(node_type, seg)
+             exp_core = _unwrap_optional(exp_type)
+
+             # Materialize or up-cast if needed
+             child = getattr(node, seg, MISSING)
+             if child is MISSING:
+                 raise KeyError(f"Unknown config key: '{'.'.join(parts[:idx+1])}'")
+
+             if isinstance(child, dict) and _is_dataclass_type(exp_core):
+                 # Up-cast dict → dataclass of the expected type
+                 child = _merge_mapping_into_dataclass(
+                     exp_core(), child, path=".".join(parts[: idx + 1])
+                 )
+                 setattr(node, seg, child)
+
+             if child is None and create_missing and _is_dataclass_type(exp_core):
+                 child = exp_core()
+                 setattr(node, seg, child)
+
+             node = getattr(node, seg)
+             node_type = type(node)
+
+         # Assign leaf with light coercion
+         if not is_dataclass(node):
+             parent_path = ".".join(parts[:-1]) or "<root>"
+             raise KeyError(
+                 f"Target '{parent_path}' is not a dataclass in the override path; cannot set '{parts[-1]}'."
+             )
+
+         leaf = parts[-1]
+
+         # Check field exists and coerce to its annotated type
+         exp_type = _expected_field_type(type(node), leaf)
+         exp_core = _unwrap_optional(exp_type)
+
+         if not hasattr(node, leaf):
+             raise KeyError(f"Unknown config key: '{dotkey}'")
+
+         current = getattr(node, leaf, MISSING)
+         coerced = _coerce_value(value, exp_core, dotkey, current=current)
+         setattr(node, leaf, coerced)
+
+     return updated
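
For orientation, the sketch below shows how the public helpers added in pgsui/data_processing/config.py (load_yaml_to_dataclass, apply_dot_overrides, save_dataclass_yaml) might be used together. The DemoConfig, IOConfig, and TrainConfig dataclasses and the file names are invented for illustration; PG-SUI's real *Config dataclasses define their own fields, so treat this as a minimal sketch of the intended precedence (defaults < base/preset < YAML < dot-key overrides), not as the package's documented API.

from dataclasses import dataclass, field

from pgsui.data_processing.config import (
    apply_dot_overrides,
    load_yaml_to_dataclass,
    save_dataclass_yaml,
)


# Hypothetical config schema for illustration only.
@dataclass
class IOConfig:
    prefix: str = "pgsui_output"
    verbose: bool = False


@dataclass
class TrainConfig:
    batch_size: int = 32
    learning_rate: float = 1e-3


@dataclass
class DemoConfig:
    io: IOConfig = field(default_factory=IOConfig)
    train: TrainConfig = field(default_factory=TrainConfig)


# ${VAR:default} falls back to the default when the env var is unset.
with open("demo.yaml", "w", encoding="utf-8") as f:
    f.write('io:\n  prefix: "${RUN_PREFIX:run1}"\ntrain:\n  batch_size: 64\n')

# YAML values overlay the dataclass defaults.
cfg = load_yaml_to_dataclass("demo.yaml", DemoConfig)

# Dot-key overrides (e.g., parsed from a CLI --set flag) win over the YAML;
# string values are coerced to the annotated field types.
cfg = apply_dot_overrides(cfg, {"train.learning_rate": "0.01", "io.verbose": "true"})

assert cfg.io.prefix == "run1"  # assuming RUN_PREFIX is unset
assert cfg.train.batch_size == 64
assert cfg.train.learning_rate == 0.01
assert cfg.io.verbose is True
save_dataclass_yaml(cfg, "demo_resolved.yaml")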