smftools 0.1.7__py3-none-any.whl → 0.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (120)
  1. smftools/__init__.py +9 -4
  2. smftools/_version.py +1 -1
  3. smftools/cli.py +184 -0
  4. smftools/config/__init__.py +1 -0
  5. smftools/config/conversion.yaml +33 -0
  6. smftools/config/deaminase.yaml +56 -0
  7. smftools/config/default.yaml +253 -0
  8. smftools/config/direct.yaml +17 -0
  9. smftools/config/experiment_config.py +1191 -0
  10. smftools/hmm/HMM.py +1576 -0
  11. smftools/hmm/__init__.py +20 -0
  12. smftools/{tools → hmm}/apply_hmm_batched.py +8 -7
  13. smftools/hmm/call_hmm_peaks.py +106 -0
  14. smftools/{tools → hmm}/display_hmm.py +3 -3
  15. smftools/{tools → hmm}/nucleosome_hmm_refinement.py +2 -2
  16. smftools/{tools → hmm}/train_hmm.py +1 -1
  17. smftools/informatics/__init__.py +0 -2
  18. smftools/informatics/archived/deaminase_smf.py +132 -0
  19. smftools/informatics/fast5_to_pod5.py +4 -1
  20. smftools/informatics/helpers/__init__.py +3 -4
  21. smftools/informatics/helpers/align_and_sort_BAM.py +34 -7
  22. smftools/informatics/helpers/aligned_BAM_to_bed.py +35 -24
  23. smftools/informatics/helpers/binarize_converted_base_identities.py +116 -23
  24. smftools/informatics/helpers/concatenate_fastqs_to_bam.py +365 -42
  25. smftools/informatics/helpers/converted_BAM_to_adata_II.py +165 -29
  26. smftools/informatics/helpers/discover_input_files.py +100 -0
  27. smftools/informatics/helpers/extract_base_identities.py +29 -3
  28. smftools/informatics/helpers/extract_read_features_from_bam.py +4 -2
  29. smftools/informatics/helpers/find_conversion_sites.py +5 -4
  30. smftools/informatics/helpers/modkit_extract_to_adata.py +6 -3
  31. smftools/informatics/helpers/plot_bed_histograms.py +269 -0
  32. smftools/informatics/helpers/separate_bam_by_bc.py +2 -2
  33. smftools/informatics/helpers/split_and_index_BAM.py +1 -5
  34. smftools/load_adata.py +1346 -0
  35. smftools/machine_learning/__init__.py +12 -0
  36. smftools/machine_learning/data/__init__.py +2 -0
  37. smftools/machine_learning/data/anndata_data_module.py +234 -0
  38. smftools/machine_learning/evaluation/__init__.py +2 -0
  39. smftools/machine_learning/evaluation/eval_utils.py +31 -0
  40. smftools/machine_learning/evaluation/evaluators.py +223 -0
  41. smftools/machine_learning/inference/__init__.py +3 -0
  42. smftools/machine_learning/inference/inference_utils.py +27 -0
  43. smftools/machine_learning/inference/lightning_inference.py +68 -0
  44. smftools/machine_learning/inference/sklearn_inference.py +55 -0
  45. smftools/machine_learning/inference/sliding_window_inference.py +114 -0
  46. smftools/machine_learning/models/base.py +295 -0
  47. smftools/machine_learning/models/cnn.py +138 -0
  48. smftools/machine_learning/models/lightning_base.py +345 -0
  49. smftools/machine_learning/models/mlp.py +26 -0
  50. smftools/{tools → machine_learning}/models/positional.py +3 -2
  51. smftools/{tools → machine_learning}/models/rnn.py +2 -1
  52. smftools/machine_learning/models/sklearn_models.py +273 -0
  53. smftools/machine_learning/models/transformer.py +303 -0
  54. smftools/machine_learning/training/__init__.py +2 -0
  55. smftools/machine_learning/training/train_lightning_model.py +135 -0
  56. smftools/machine_learning/training/train_sklearn_model.py +114 -0
  57. smftools/plotting/__init__.py +4 -1
  58. smftools/plotting/autocorrelation_plotting.py +611 -0
  59. smftools/plotting/general_plotting.py +566 -89
  60. smftools/plotting/hmm_plotting.py +260 -0
  61. smftools/plotting/qc_plotting.py +270 -0
  62. smftools/preprocessing/__init__.py +13 -8
  63. smftools/preprocessing/add_read_length_and_mapping_qc.py +129 -0
  64. smftools/preprocessing/append_base_context.py +122 -0
  65. smftools/preprocessing/append_binary_layer_by_base_context.py +143 -0
  66. smftools/preprocessing/calculate_complexity_II.py +248 -0
  67. smftools/preprocessing/calculate_coverage.py +10 -1
  68. smftools/preprocessing/calculate_read_modification_stats.py +101 -0
  69. smftools/preprocessing/clean_NaN.py +17 -1
  70. smftools/preprocessing/filter_reads_on_length_quality_mapping.py +158 -0
  71. smftools/preprocessing/filter_reads_on_modification_thresholds.py +352 -0
  72. smftools/preprocessing/flag_duplicate_reads.py +1326 -124
  73. smftools/preprocessing/invert_adata.py +12 -5
  74. smftools/preprocessing/load_sample_sheet.py +19 -4
  75. smftools/readwrite.py +849 -43
  76. smftools/tools/__init__.py +3 -32
  77. smftools/tools/calculate_umap.py +5 -5
  78. smftools/tools/general_tools.py +3 -3
  79. smftools/tools/position_stats.py +468 -106
  80. smftools/tools/read_stats.py +115 -1
  81. smftools/tools/spatial_autocorrelation.py +562 -0
  82. {smftools-0.1.7.dist-info → smftools-0.2.1.dist-info}/METADATA +5 -1
  83. smftools-0.2.1.dist-info/RECORD +161 -0
  84. smftools-0.2.1.dist-info/entry_points.txt +2 -0
  85. smftools/informatics/helpers/LoadExperimentConfig.py +0 -75
  86. smftools/informatics/helpers/plot_read_length_and_coverage_histograms.py +0 -53
  87. smftools/informatics/load_adata.py +0 -182
  88. smftools/preprocessing/append_C_context.py +0 -82
  89. smftools/preprocessing/calculate_converted_read_methylation_stats.py +0 -94
  90. smftools/preprocessing/filter_converted_reads_on_methylation.py +0 -44
  91. smftools/preprocessing/filter_reads_on_length.py +0 -51
  92. smftools/tools/call_hmm_peaks.py +0 -105
  93. smftools/tools/data/__init__.py +0 -2
  94. smftools/tools/data/anndata_data_module.py +0 -90
  95. smftools/tools/evaluation/__init__.py +0 -0
  96. smftools/tools/inference/__init__.py +0 -1
  97. smftools/tools/inference/lightning_inference.py +0 -41
  98. smftools/tools/models/base.py +0 -14
  99. smftools/tools/models/cnn.py +0 -34
  100. smftools/tools/models/lightning_base.py +0 -41
  101. smftools/tools/models/mlp.py +0 -17
  102. smftools/tools/models/sklearn_models.py +0 -40
  103. smftools/tools/models/transformer.py +0 -133
  104. smftools/tools/training/__init__.py +0 -1
  105. smftools/tools/training/train_lightning_model.py +0 -47
  106. smftools-0.1.7.dist-info/RECORD +0 -136
  107. /smftools/{tools → hmm}/calculate_distances.py +0 -0
  108. /smftools/{tools → hmm}/hmm_readwrite.py +0 -0
  109. /smftools/informatics/{conversion_smf.py → archived/conversion_smf.py} +0 -0
  110. /smftools/informatics/{direct_smf.py → archived/direct_smf.py} +0 -0
  111. /smftools/{tools → machine_learning}/data/preprocessing.py +0 -0
  112. /smftools/{tools → machine_learning}/models/__init__.py +0 -0
  113. /smftools/{tools → machine_learning}/models/wrappers.py +0 -0
  114. /smftools/{tools → machine_learning}/utils/__init__.py +0 -0
  115. /smftools/{tools → machine_learning}/utils/device.py +0 -0
  116. /smftools/{tools → machine_learning}/utils/grl.py +0 -0
  117. /smftools/tools/{apply_hmm.py → archived/apply_hmm.py} +0 -0
  118. /smftools/tools/{classifiers.py → archived/classifiers.py} +0 -0
  119. {smftools-0.1.7.dist-info → smftools-0.2.1.dist-info}/WHEEL +0 -0
  120. {smftools-0.1.7.dist-info → smftools-0.2.1.dist-info}/licenses/LICENSE +0 -0
smftools/config/experiment_config.py ADDED
@@ -0,0 +1,1191 @@
+ # experiment_config.py
+ from __future__ import annotations
+ import ast
+ import json
+ import warnings
+ from dataclasses import dataclass, field, asdict
+ from pathlib import Path
+ from typing import Any, Dict, List, Optional, Tuple, Union, IO, Sequence
+
+ # Optional dependency for YAML handling
+ try:
+     import yaml
+ except Exception:
+     yaml = None
+
+ import pandas as pd
+ import numpy as np
+
+
+ # -------------------------
+ # Utility parsing functions
+ # -------------------------
+ def _parse_bool(v: Any) -> bool:
+     if isinstance(v, bool):
+         return v
+     if v is None:
+         return False
+     s = str(v).strip().lower()
+     if s in ("1", "true", "t", "yes", "y", "on"):
+         return True
+     if s in ("0", "false", "f", "no", "n", "off", ""):
+         return False
+     try:
+         return float(s) != 0.0
+     except Exception:
+         return False
+
+
+ def _parse_list(v: Any) -> List:
+     if v is None:
+         return []
+     if isinstance(v, (list, tuple)):
+         return list(v)
+     s = str(v).strip()
+     if s == "" or s.lower() == "none":
+         return []
+     # try JSON
+     try:
+         parsed = json.loads(s)
+         if isinstance(parsed, list):
+             return parsed
+     except Exception:
+         pass
+     # try python literal eval
+     try:
+         lit = ast.literal_eval(s)
+         if isinstance(lit, (list, tuple)):
+             return list(lit)
+     except Exception:
+         pass
+     # fallback comma separated
+     s2 = s.strip("[]() ")
+     parts = [p.strip() for p in s2.split(",") if p.strip() != ""]
+     return parts
+
+
+ def _parse_numeric(v: Any, fallback: Any = None) -> Any:
+     if v is None:
+         return fallback
+     if isinstance(v, (int, float)):
+         return v
+     s = str(v).strip()
+     if s == "" or s.lower() == "none":
+         return fallback
+     try:
+         return int(s)
+     except Exception:
+         try:
+             return float(s)
+         except Exception:
+             return fallback
+
+ def _try_json_or_literal(s: Any) -> Any:
+     """Try parse JSON or python literal; otherwise return original string."""
+     if s is None:
+         return None
+     if not isinstance(s, str):
+         return s
+     s0 = s.strip()
+     if s0 == "":
+         return None
+     # try json
+     try:
+         return json.loads(s0)
+     except Exception:
+         pass
+     # try python literal
+     try:
+         return ast.literal_eval(s0)
+     except Exception:
+         pass
+     return s
+
+
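A quick sketch of how these coercion helpers behave on the loose strings a config CSV typically carries (inputs below are hypothetical; the results follow from the fallback chains above):

    _parse_bool("Yes")                # True, via the truthy-string table
    _parse_bool("2")                  # True, via the non-zero float fallback
    _parse_list("[1, 2, 3]")          # [1, 2, 3], via the JSON branch
    _parse_list("top, bottom")        # ['top', 'bottom'], via the comma fallback
    _parse_numeric("0.7")             # 0.7, int() fails so float() is tried
    _try_json_or_literal("{'a': 1}")  # {'a': 1}, via ast.literal_eval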
+ def resolve_aligner_args(
+     merged: dict,
+     default_by_aligner: Optional[Dict[str, List[str]]] = None,
+     aligner_synonyms: Optional[Dict[str, str]] = None,
+ ) -> List[str]:
+     """
+     Resolve merged['aligner_args'] into a concrete list for the chosen aligner and sequencer.
+
+     Behavior (search order):
+       1. If aligner_args is a dict, try keys in this order (case-insensitive):
+          a) "<aligner>@<sequencer>" (top-level combined key)
+          b) aligner -> (if dict) sequencer (nested) -> 'default' fallback
+          c) aligner -> (if list) use that list
+          d) top-level 'default' key in aligner_args dict
+       2. If aligner_args is a list -> return it (applies to any aligner/sequencer).
+       3. If aligner_args is a string -> try parse JSON/literal or return single-element list.
+       4. Otherwise fall back to builtin defaults per aligner.
+     """
+     # builtin defaults (aligner -> args)
+     builtin_defaults = {
+         "minimap2": ['-a', '-x', 'map-ont', '--MD', '-Y', '-y', '-N', '5', '--secondary=no'],
+         "dorado": ['--mm2-opts', '-N', '5'],
+     }
+     if default_by_aligner is None:
+         default_by_aligner = builtin_defaults
+
+     # synonyms mapping
+     synonyms = {"mm2": "minimap2", "minimap": "minimap2", "minimap-2": "minimap2"}
+     if aligner_synonyms:
+         synonyms.update(aligner_synonyms)
+
+     # canonicalize requested aligner and sequencer
+     raw_aligner = merged.get("aligner", "minimap2") or "minimap2"
+     raw_sequencer = merged.get("sequencer", None)  # e.g. 'ont', 'pacbio', 'illumina'
+     key_align = str(raw_aligner).strip().lower()
+     key_seq = None if raw_sequencer is None else str(raw_sequencer).strip().lower()
+     if key_align in synonyms:
+         key_align = synonyms[key_align]
+
+     raw = merged.get("aligner_args", None)
+
+     # helper to coerce a candidate to list[str]
+     def _coerce_to_list(val):
+         if isinstance(val, (list, tuple)):
+             return [str(x) for x in val]
+         if isinstance(val, str):
+             parsed = _try_json_or_literal(val)
+             if isinstance(parsed, (list, tuple)):
+                 return [str(x) for x in parsed]
+             return [str(parsed)]
+         if val is None:
+             return None
+         return [str(val)]
+
+     # If dict, do layered lookups
+     if isinstance(raw, dict):
+         # case-insensitive dict
+         top_map = {str(k).lower(): v for k, v in raw.items()}
+
+         # 1) try combined top-level key "aligner@sequencer"
+         if key_seq:
+             combined_key = f"{key_align}@{key_seq}"
+             if combined_key in top_map:
+                 res = _coerce_to_list(top_map[combined_key])
+                 if res:
+                     return res
+
+         # 2) try aligner key
+         if key_align in top_map:
+             val = top_map[key_align]
+             # if nested dict: try sequencer key then 'default'
+             if isinstance(val, dict):
+                 submap = {str(k).lower(): v for k, v in val.items()}
+                 if key_seq and key_seq in submap:
+                     res = _coerce_to_list(submap[key_seq])
+                     if res:
+                         return res
+                 if "default" in submap:
+                     res = _coerce_to_list(submap["default"])
+                     if res:
+                         return res
+                 # nothing matched inside aligner->dict; fall back to top-level aligner (no sequencer)
+             else:
+                 # aligner maps to list/str: use it
+                 res = _coerce_to_list(val)
+                 if res:
+                     return res
+
+         # 3) try top-level 'default' key inside aligner_args mapping
+         if "default" in top_map:
+             res = _coerce_to_list(top_map["default"])
+             if res:
+                 return res
+
+         # 4) last top-level attempt: any key equal to aligner synonyms etc (already handled)
+         # fallthrough to builtin
+     # If user provided a concrete list -> use it
+     if isinstance(raw, (list, tuple)):
+         return [str(x) for x in raw]
+
+     # If scalar string, attempt to parse
+     if isinstance(raw, str):
+         parsed = _try_json_or_literal(raw)
+         if isinstance(parsed, (list, tuple)):
+             return [str(x) for x in parsed]
+         return [str(parsed)]
+
+     # Nothing found -> fallback builtin default
+     return list(default_by_aligner.get(key_align, []))
+
+
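To make the layered lookup concrete, here is a minimal sketch with a made-up merged dict; the combined "<aligner>@<sequencer>" key is found first, so the nested and default entries are never consulted:

    merged = {
        "aligner": "mm2",        # synonym, canonicalized to "minimap2"
        "sequencer": "ont",
        "aligner_args": {
            "minimap2@ont": ["-a", "-x", "map-ont"],  # 1) combined key wins
            "minimap2": {"default": ["-a"]},          # 2) tried next
            "default": ["--fallback"],                # 3) dict-level last resort
        },
    }
    resolve_aligner_args(merged)  # -> ['-a', '-x', 'map-ont']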
+ # HMM default params and helper functions
+ def normalize_hmm_feature_sets(raw: Any) -> Dict[str, dict]:
+     """
+     Normalize user-provided `hmm_feature_sets` into canonical structure:
+       { group_name: {"features": {label: (lo, hi), ...}, "state": "<Modified|Non-Modified>"} }
+     Accepts dict, JSON/string, None. Returns {} for empty input.
+     """
+     if raw is None:
+         return {}
+     parsed = raw
+     if isinstance(raw, str):
+         parsed = _try_json_or_literal(raw)
+     if not isinstance(parsed, dict):
+         return {}
+
+     def _coerce_bound(x):
+         if x is None:
+             return None
+         if isinstance(x, (int, float)):
+             return float(x)
+         s = str(x).strip().lower()
+         if s in ("inf", "infty", "infinite"):
+             return np.inf
+         if s in ("none", ""):
+             return None
+         try:
+             return float(x)
+         except Exception:
+             return None
+
+     def _coerce_feature_map(feats):
+         out = {}
+         if not isinstance(feats, dict):
+             return out
+         for fname, rng in feats.items():
+             if rng is None:
+                 out[fname] = (0.0, np.inf)
+                 continue
+             if isinstance(rng, (list, tuple)) and len(rng) >= 2:
+                 lo = _coerce_bound(rng[0]) or 0.0
+                 hi = _coerce_bound(rng[1])
+                 if hi is None:
+                     hi = np.inf
+                 out[fname] = (float(lo), float(hi) if not np.isinf(hi) else np.inf)
+             else:
+                 # scalar -> treat as upper bound
+                 val = _coerce_bound(rng)
+                 out[fname] = (0.0, float(val) if val is not None else np.inf)
+         return out
+
+     canonical = {}
+     for grp, info in parsed.items():
+         if not isinstance(info, dict):
+             feats = _coerce_feature_map(info)
+             canonical[grp] = {"features": feats, "state": "Modified"}
+             continue
+         feats = _coerce_feature_map(info.get("features", info.get("ranges", {})))
+         state = info.get("state", info.get("label", "Modified"))
+         canonical[grp] = {"features": feats, "state": state}
+     return canonical
+
+
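For instance, with a hypothetical feature set, scalar bounds become (0, upper) and "inf" strings become np.inf:

    raw = {
        "footprint": {"features": {"small": [30, 80], "large": [80, "inf"]},
                      "state": "Non-Modified"},
        "accessible": {"features": {"patch": 200}},  # scalar -> (0.0, 200.0)
    }
    normalize_hmm_feature_sets(raw)
    # {'footprint': {'features': {'small': (30.0, 80.0), 'large': (80.0, inf)},
    #                'state': 'Non-Modified'},
    #  'accessible': {'features': {'patch': (0.0, 200.0)}, 'state': 'Modified'}}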
+ # -------------------------
+ # LoadExperimentConfig
+ # -------------------------
+ class LoadExperimentConfig:
+     """
+     Load an experiment CSV (or DataFrame / file-like) into a typed var_dict.
+
+     CSV expected columns: 'variable', 'value', optional 'type'.
+     If 'type' missing, the loader will infer type.
+
+     Example
+     -------
+     loader = LoadExperimentConfig("experiment_config.csv")
+     var_dict = loader.var_dict
+     """
+
+     def __init__(self, experiment_config: Union[str, Path, IO, pd.DataFrame]):
+         self.source = experiment_config
+         self.df = self._load_df(experiment_config)
+         self.var_dict = self._parse_df(self.df)
+
+     @staticmethod
+     def _load_df(source: Union[str, Path, IO, pd.DataFrame]) -> pd.DataFrame:
+         """Load a pandas DataFrame from path, file-like, or accept if already DataFrame."""
+         if isinstance(source, pd.DataFrame):
+             df = source.copy()
+         else:
+             if isinstance(source, (str, Path)):
+                 p = Path(source)
+                 if not p.exists():
+                     raise FileNotFoundError(f"Config file not found: {source}")
+                 df = pd.read_csv(p, dtype=str, keep_default_na=False, na_values=[""])
+             else:
+                 # file-like
+                 df = pd.read_csv(source, dtype=str, keep_default_na=False, na_values=[""])
+         # normalize column names
+         df.columns = [c.strip() for c in df.columns]
+         if 'variable' not in df.columns:
+             raise ValueError("Config CSV must contain a 'variable' column.")
+         if 'value' not in df.columns:
+             df['value'] = ''
+         if 'type' not in df.columns:
+             df['type'] = ''
+         return df
+
+     @staticmethod
+     def _parse_value_as_type(value_str: Optional[str], dtype_hint: Optional[str]) -> Any:
+         """
+         Parse a single value string into a Python object guided by dtype_hint (or infer).
+         Supports int, float, bool, list, JSON, Python literal, or string.
+         """
+         if value_str is None:
+             return None
+         v = str(value_str).strip()
+         if v == "" or v.lower() == "none":
+             return None
+
+         hint = "" if dtype_hint is None else str(dtype_hint).strip().lower()
+
+         def parse_bool(s: str):
+             s2 = s.strip().lower()
+             if s2 in ('1', 'true', 't', 'yes', 'y', 'on'):
+                 return True
+             if s2 in ('0', 'false', 'f', 'no', 'n', 'off'):
+                 return False
+             raise ValueError(f"Cannot parse boolean from '{s}'")
+
+         def parse_list_like(s: str):
+             # try JSON first
+             try:
+                 val = json.loads(s)
+                 if isinstance(val, list):
+                     return val
+             except Exception:
+                 pass
+             # try python literal
+             try:
+                 val = ast.literal_eval(s)
+                 if isinstance(val, (list, tuple)):
+                     return list(val)
+             except Exception:
+                 pass
+             # fallback split
+             parts = [p.strip() for p in s.strip("()[] ").split(',') if p.strip() != ""]
+             return parts
+
+         if hint in ('int', 'integer'):
+             return int(v)
+         if hint in ('float', 'double'):
+             return float(v)
+         if hint in ('bool', 'boolean'):
+             return parse_bool(v)
+         if hint in ('list', 'array'):
+             return parse_list_like(v)
+         if hint in ('string', 'str'):
+             return v
+
+         # infer
+         try:
+             return int(v)
+         except Exception:
+             pass
+         try:
+             return float(v)
+         except Exception:
+             pass
+         try:
+             return parse_bool(v)
+         except Exception:
+             pass
+         try:
+             j = json.loads(v)
+             return j
+         except Exception:
+             pass
+         try:
+             lit = ast.literal_eval(v)
+             return lit
+         except Exception:
+             pass
+         if (',' in v) and (not any(ch in v for ch in '{}[]()')):
+             return [p.strip() for p in v.split(',') if p.strip() != ""]
+         return v
+
+     def _parse_df(self, df: pd.DataFrame) -> Dict[str, Any]:
+         parsed: Dict[str, Any] = {}
+         for idx, row in df.iterrows():
+             name = str(row['variable']).strip()
+             if name == "":
+                 continue
+             raw_val = row.get('value', "")
+             raw_type = row.get('type', "")
+             if pd.isna(raw_val) or str(raw_val).strip() == "":
+                 raw_val = None
+             try:
+                 parsed_val = self._parse_value_as_type(raw_val, raw_type)
+             except Exception as e:
+                 warnings.warn(f"Failed to parse config variable '{name}' (row {idx}): {e}. Storing raw value.")
+                 parsed_val = None if raw_val is None else raw_val
+             if name in parsed:
+                 warnings.warn(f"Duplicate config variable '{name}' encountered (row {idx}). Overwriting previous value.")
+             parsed[name] = parsed_val
+         return parsed
+
+     def to_dataframe(self) -> pd.DataFrame:
+         """Return parsed config as a pandas DataFrame (variable, value)."""
+         rows = []
+         for k, v in self.var_dict.items():
+             rows.append({'variable': k, 'value': v})
+         return pd.DataFrame(rows)
+
+
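A sketch of the expected CSV shape and the typed dict the loader produces (the column values below are illustrative):

    import io

    csv_text = """variable,value,type
    smf_modality,conversion,string
    threads,8,int
    strands,"top,bottom",list
    make_bigwigs,true,bool
    """
    loader = LoadExperimentConfig(io.StringIO(csv_text))
    loader.var_dict
    # {'smf_modality': 'conversion', 'threads': 8,
    #  'strands': ['top', 'bottom'], 'make_bigwigs': True}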
+ # -------------------------
+ # deep merge & defaults loader (with inheritance)
+ # -------------------------
+ def deep_merge(a: Dict[str, Any], b: Dict[str, Any]) -> Dict[str, Any]:
+     """
+     Recursively merge two dicts: returns new dict = a merged with b, where b overrides.
+     If both values are dicts -> merge recursively; else b replaces a.
+     """
+     out = dict(a or {})
+     for k, v in (b or {}).items():
+         if k in out and isinstance(out[k], dict) and isinstance(v, dict):
+             out[k] = deep_merge(out[k], v)
+         else:
+             out[k] = v
+     return out
+
+
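For example, nested mappings merge key-by-key while anything else on the right-hand side simply replaces:

    a = {"hmm": {"n_states": 2, "eps": 1e-8}, "aligner": "minimap2"}
    b = {"hmm": {"n_states": 3}, "aligner": "dorado"}
    deep_merge(a, b)
    # {'hmm': {'n_states': 3, 'eps': 1e-08}, 'aligner': 'dorado'}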
+ def _load_defaults_file(path: Path) -> Dict[str, Any]:
+     if not path.exists():
+         return {}
+     text = path.read_text(encoding="utf8")
+     suffix = path.suffix.lower()
+     if suffix in (".yaml", ".yml"):
+         if yaml is None:
+             raise RuntimeError("PyYAML required to load YAML defaults (pip install pyyaml).")
+         return yaml.safe_load(text) or {}
+     elif suffix == ".json":
+         return json.loads(text or "{}")
+     else:
+         # try json then yaml if available
+         try:
+             return json.loads(text)
+         except Exception:
+             if yaml is not None:
+                 return yaml.safe_load(text) or {}
+             raise RuntimeError(f"Unknown defaults file type for {path}; provide JSON or YAML.")
+
+
+ def load_defaults_with_inheritance(
+     defaults_dir: Union[str, Path],
+     modality: Optional[str],
+     *,
+     default_basename: str = "default",
+     allowed_exts: Tuple[str, ...] = (".yaml", ".yml", ".json"),
+     debug: bool = False,
+ ) -> Tuple[Dict[str, Any], List[str]]:
+     """
+     Strict loader: only loads default + modality + any explicit 'extends' chain.
+
+     - defaults_dir: directory containing defaults files.
+     - modality: name of modality (e.g. "GpC"). We look for <modality>.<ext> in defaults_dir.
+     - default_basename: name of fallback default file (without extension).
+     - allowed_exts: allowed extensions to try.
+     - debug: if True, prints what was loaded.
+
+     Returns (merged_defaults_dict, load_order_list) where load_order_list are resolved file paths read.
+     """
+     pdir = Path(defaults_dir) if defaults_dir is not None else None
+     if pdir is None or not pdir.exists():
+         return {}, []
+
+     # Resolve a "name" to a file in defaults_dir.
+     # Only treat `name` as an explicit path if it contains a path separator or is absolute.
+     def resolve_name_to_path(name: str) -> Optional[Path]:
+         n = str(name).strip()
+         if n == "":
+             return None
+         cand = Path(n)
+         # If user provided a path-like string (contains slash/backslash or absolute), allow it
+         if cand.is_absolute() or ("/" in n) or ("\\" in n):
+             if cand.exists() and cand.suffix.lower() in allowed_exts:
+                 return cand.resolve()
+             return None
+         # Otherwise only look inside defaults_dir for name + ext (do NOT treat bare name as arbitrary file)
+         for ext in allowed_exts:
+             p = pdir / f"{n}{ext}"
+             if p.exists():
+                 return p.resolve()
+         return None
+
+     visited = set()
+     load_order: List[str] = []
+
+     def _rec_load(name_or_path: Union[str, Path]) -> Dict[str, Any]:
+         # Resolve to a file path (strict)
+         if isinstance(name_or_path, Path):
+             p = name_or_path
+         else:
+             p = resolve_name_to_path(str(name_or_path))
+             if p is None:
+                 if debug:
+                     print(f"[defaults loader] resolve failed for '{name_or_path}'")
+                 return {}
+         p = Path(p).resolve()
+         p_str = str(p)
+         if p_str in visited:
+             if debug:
+                 print(f"[defaults loader] already visited {p_str} (skipping to avoid cycle)")
+             return {}
+         visited.add(p_str)
+
+         data = _load_defaults_file(p)  # parse the JSON/YAML defaults file
+         if not isinstance(data, dict):
+             if debug:
+                 print(f"[defaults loader] file {p_str} did not produce a dict -> ignoring")
+             data = {}
+
+         # Extract any extends/inherits keys (string or list). They reference other named default files.
+         bases = []
+         for key in ("extends", "inherits", "base"):
+             if key in data:
+                 b = data.pop(key)
+                 if isinstance(b, (list, tuple)):
+                     bases = list(b)
+                 elif isinstance(b, str):
+                     bases = [b]
+                 break
+
+         merged = {}
+         # Load bases first (in order); bases are resolved relative to defaults_dir unless given as path
+         for base_name in bases:
+             base_defaults = _rec_load(base_name)
+             merged = deep_merge(merged, base_defaults)
+
+         # Then merge this file's data (this file overrides its bases)
+         merged = deep_merge(merged, data)
+         load_order.append(p_str)
+         if debug:
+             print(f"[defaults loader] loaded {p_str}")
+         return merged
+
+     merged_defaults = {}
+     # Load default.* first if present
+     def_path = resolve_name_to_path(default_basename)
+     if def_path is not None:
+         merged_defaults = deep_merge(merged_defaults, _rec_load(def_path))
+
+     # Load modality.* if present (modality overrides default)
+     if modality:
+         mod_path = resolve_name_to_path(modality)
+         if mod_path is not None:
+             merged_defaults = deep_merge(merged_defaults, _rec_load(mod_path))
+         else:
+             if debug:
+                 print(f"[defaults loader] no modality file found for '{modality}' in {pdir}")
+
+     if debug:
+         print("[defaults loader] final load order:", load_order)
+     return merged_defaults, load_order
+
+
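A sketch of the inheritance chain, assuming the packaged config directory holds a default.yaml and a conversion.yaml whose contents are as commented (the values here are hypothetical):

    # default.yaml:     threads: 4          plus  hmm: {n_states: 2}
    # conversion.yaml:  extends: default    plus  hmm: {n_states: 3}
    defaults, order = load_defaults_with_inheritance("smftools/config", "conversion")
    # defaults == {'threads': 4, 'hmm': {'n_states': 3}}  (modality overrides base)
    # order lists default.yaml before conversion.yaml; the visited set keeps the
    # 'extends: default' reference from loading default.yaml a second time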
+ # -------------------------
+ # ExperimentConfig dataclass
+ # -------------------------
+ @dataclass
+ class ExperimentConfig:
+     # Compute
+     threads: Optional[int] = None
+     device: str = "auto"
+
+     # General I/O
+     input_data_path: Optional[str] = None
+     output_directory: Optional[str] = None
+     fasta: Optional[str] = None
+     bam_suffix: str = ".bam"
+     recursive_input_search: bool = True
+     split_dir: str = "demultiplexed_BAMs"
+     strands: List[str] = field(default_factory=lambda: ["bottom", "top"])
+     conversions: List[str] = field(default_factory=lambda: ["unconverted"])
+     fasta_regions_of_interest: Optional[str] = None
+     sample_sheet_path: Optional[str] = None
+     sample_sheet_mapping_column: Optional[str] = 'Barcode'
+     experiment_name: Optional[str] = None
+     input_already_demuxed: bool = False
+
+     # FASTQ input specific
+     fastq_barcode_map: Optional[Dict[str, str]] = None
+     fastq_auto_pairing: bool = True
+
+     # Conversion/Deamination file handling
+     delete_intermediate_hdfs: bool = True
+
+     # Direct SMF specific params for initial AnnData loading
+     batch_size: int = 4
+     skip_unclassified: bool = True
+     delete_batch_hdfs: bool = True
+
+     # Sequencing modality and general experiment params
+     smf_modality: Optional[str] = None
+     sequencer: Optional[str] = None
+
+     # Enzyme / mod targets
+     mod_target_bases: List[str] = field(default_factory=lambda: ["GpC", "CpG"])
+     enzyme_target_bases: List[str] = field(default_factory=lambda: ["GpC"])
+
+     # Conversion/deamination
+     conversion_types: List[str] = field(default_factory=lambda: ["5mC"])
+
+     # Nanopore specific for basecalling and demultiplexing
+     model_dir: Optional[str] = None
+     barcode_kit: Optional[str] = None
+     model: str = "hac"
+     barcode_both_ends: bool = False
+     trim: bool = False
+     # General basecalling params
+     filter_threshold: float = 0.8
+     # Modified basecalling specific params
+     m6A_threshold: float = 0.7
+     m5C_threshold: float = 0.7
+     hm5C_threshold: float = 0.7
+     thresholds: List[float] = field(default_factory=list)
+     mod_list: List[str] = field(default_factory=lambda: ["5mC_5hmC", "6mA"])
+
+     # Alignment params
+     # Min fraction of a sample's reads mapping to a reference for that
+     # reference to be included in the anndata
+     mapping_threshold: float = 0.01
+     aligner: str = "minimap2"
+     aligner_args: Optional[List[str]] = None
+     make_bigwigs: bool = False
+
+     # Anndata structure
+     reference_column: Optional[str] = 'Reference_strand'
+     sample_column: Optional[str] = 'Barcode'
+
+     # General Plotting
+     sample_name_col_for_plotting: Optional[str] = 'Barcode'
+     rows_per_qc_histogram_grid: int = 12
+
+     # Preprocessing - Read length and quality filter params
+     read_coord_filter: Optional[Sequence[float]] = field(default_factory=lambda: [None, None])
+     read_len_filter_thresholds: Optional[Sequence[float]] = field(default_factory=lambda: [200, None])
+     read_len_to_ref_ratio_filter_thresholds: Optional[Sequence[float]] = field(default_factory=lambda: [0.4, 1.1])
+     read_quality_filter_thresholds: Optional[Sequence[float]] = field(default_factory=lambda: [20, None])
+     read_mapping_quality_filter_thresholds: Optional[Sequence[float]] = field(default_factory=lambda: [None, None])
+
+     # Preprocessing - Read modification filter params
+     read_mod_filtering_gpc_thresholds: List[float] = field(default_factory=lambda: [0.025, 0.975])
+     read_mod_filtering_cpg_thresholds: List[float] = field(default_factory=lambda: [0.00, 1])
+     read_mod_filtering_any_c_thresholds: List[float] = field(default_factory=lambda: [0.025, 0.975])
+     read_mod_filtering_a_thresholds: List[float] = field(default_factory=lambda: [0.025, 0.975])
+     read_mod_filtering_use_other_c_as_background: bool = True
+     min_valid_fraction_positions_in_read_vs_ref: float = 0.2
+
+     # Preprocessing - Duplicate detection params
+     duplicate_detection_site_types: List[str] = field(default_factory=lambda: ['GpC', 'CpG', 'ambiguous_GpC_CpG'])
+     duplicate_detection_distance_threshold: float = 0.07
+     hamming_vs_metric_keys: List[str] = field(default_factory=lambda: ['Fraction_any_C_site_modified'])
+     duplicate_detection_keep_best_metric: str = 'read_quality'
+     duplicate_detection_window_size_for_hamming_neighbors: int = 50
+     duplicate_detection_min_overlapping_positions: int = 20
+     duplicate_detection_do_hierarchical: bool = True
+     duplicate_detection_hierarchical_linkage: str = "average"
+     duplicate_detection_do_pca: bool = False
+
+     # Preprocessing - Complexity analysis params
+
+     # Basic Analysis - Clustermap params
+     layer_for_clustermap_plotting: Optional[str] = 'nan0_0minus1'
+
+     # Basic Analysis - UMAP/Leiden params
+     layer_for_umap_plotting: Optional[str] = 'nan_half'
+     umap_layers_to_plot: List[str] = field(default_factory=lambda: ["mapped_length", "Raw_modification_signal"])
+
+     # Basic Analysis - Spatial Autocorrelation params
+     rows_per_qc_autocorr_grid: int = 12
+     autocorr_rolling_window_size: int = 25
+     autocorr_max_lag: int = 800
+     autocorr_site_types: List[str] = field(default_factory=lambda: ['GpC', 'CpG', 'any_C'])
+
+     # Basic Analysis - Correlation Matrix params
+     correlation_matrix_types: List[str] = field(default_factory=lambda: ["pearson", "binary_covariance"])
+     correlation_matrix_cmaps: List[str] = field(default_factory=lambda: ["seismic", "viridis"])
+     correlation_matrix_site_types: List[str] = field(default_factory=lambda: ["GpC_site"])
+
+     # HMM params
+     hmm_n_states: int = 2
+     hmm_init_emission_probs: List[list] = field(default_factory=lambda: [[0.8, 0.2], [0.2, 0.8]])
+     hmm_init_transition_probs: List[list] = field(default_factory=lambda: [[0.9, 0.1], [0.1, 0.9]])
+     hmm_init_start_probs: List[float] = field(default_factory=lambda: [0.5, 0.5])
+     hmm_eps: float = 1e-8
+     hmm_dtype: str = "float64"
+     hmm_annotation_threshold: float = 0.5
+     hmm_batch_size: int = 1024
+     hmm_use_viterbi: bool = False
+     hmm_device: Optional[str] = None
+     hmm_methbases: Optional[List[str]] = None  # if None, HMM.annotate_adata will fall back to mod_target_bases
+     footprints: Optional[bool] = True
+     accessible_patches: Optional[bool] = True
+     cpg: Optional[bool] = False
+     hmm_feature_sets: Dict[str, Any] = field(default_factory=dict)
+     hmm_merge_layer_features: Optional[List[Tuple]] = field(default_factory=lambda: [(None, 80)])
+
+     # Pipeline control flow - preprocessing and QC
+     force_redo_preprocessing: bool = False
+     force_reload_sample_sheet: bool = True
+     bypass_add_read_length_and_mapping_qc: bool = False
+     force_redo_add_read_length_and_mapping_qc: bool = False
+     bypass_clean_nan: bool = False
+     force_redo_clean_nan: bool = False
+     bypass_append_base_context: bool = False
+     force_redo_append_base_context: bool = False
+     invert_adata: bool = False
+     bypass_append_binary_layer_by_base_context: bool = False
+     force_redo_append_binary_layer_by_base_context: bool = False
+     bypass_calculate_read_modification_stats: bool = False
+     force_redo_calculate_read_modification_stats: bool = False
+     bypass_filter_reads_on_modification_thresholds: bool = False
+     force_redo_filter_reads_on_modification_thresholds: bool = False
+     bypass_flag_duplicate_reads: bool = False
+     force_redo_flag_duplicate_reads: bool = False
+     bypass_complexity_analysis: bool = False
+     force_redo_complexity_analysis: bool = False
+
+     # Pipeline control flow - Basic Analyses
+     force_redo_basic_analyses: bool = False
+     bypass_basic_clustermaps: bool = False
+     force_redo_basic_clustermaps: bool = False
+     bypass_basic_umap: bool = False
+     force_redo_basic_umap: bool = False
+     bypass_spatial_autocorr_calculations: bool = False
+     force_redo_spatial_autocorr_calculations: bool = False
+     bypass_spatial_autocorr_plotting: bool = False
+     force_redo_spatial_autocorr_plotting: bool = False
+     bypass_matrix_corr_calculations: bool = False
+     force_redo_matrix_corr_calculations: bool = False
+     bypass_matrix_corr_plotting: bool = False
+     force_redo_matrix_corr_plotting: bool = False
+
+     # Pipeline control flow - HMM Analyses
+     bypass_hmm_fit: bool = False
+     force_redo_hmm_fit: bool = False
+     bypass_hmm_apply: bool = False
+     force_redo_hmm_apply: bool = False
+
+     # metadata
+     config_source: Optional[str] = None
+
+     # -------------------------
+     # Construction helpers
+     # -------------------------
+     @classmethod
+     def from_var_dict(
+         cls,
+         var_dict: Optional[Dict[str, Any]],
+         date_str: Optional[str] = None,
+         config_source: Optional[str] = None,
+         defaults_dir: Optional[Union[str, Path]] = None,
+         defaults_map: Optional[Dict[str, Dict[str, Any]]] = None,
+         merge_with_defaults: bool = True,
+         override_with_csv: bool = True,
+         allow_csv_extends: bool = True,
+         allow_null_override: bool = False,
+     ) -> Tuple["ExperimentConfig", Dict[str, Any]]:
+         """
+         Create ExperimentConfig from a raw var_dict (as produced by LoadExperimentConfig).
+         Returns (instance, report) where report contains modality/defaults/merged info.
+
+         merge_with_defaults: load defaults from defaults_dir or defaults_map.
+         override_with_csv: CSV values override defaults; if False defaults take precedence.
+         allow_csv_extends: allow the CSV to include 'extends' to pull in extra defaults files.
+         allow_null_override: if False, CSV keys with value None will NOT override defaults (keeps defaults).
+         """
+         var_dict = var_dict or {}
+
+         # 1) normalize incoming values
+         normalized: Dict[str, Any] = {}
+         for k, v in var_dict.items():
+             if v is None:
+                 normalized[k] = None
+                 continue
+             if isinstance(v, str):
+                 s = v.strip()
+                 if s == "" or s.lower() == "none":
+                     normalized[k] = None
+                 else:
+                     normalized[k] = _try_json_or_literal(s)
+             else:
+                 normalized[k] = v
+
+         modality = normalized.get("smf_modality")
+         if isinstance(modality, (list, tuple)) and len(modality) > 0:
+             modality = modality[0]
+
+         defaults_loaded = {}
+         defaults_source_chain: List[str] = []
+         if merge_with_defaults:
+             if defaults_map and modality in defaults_map:
+                 defaults_loaded = dict(defaults_map[modality] or {})
+                 defaults_source_chain = [f"defaults_map['{modality}']"]
+             elif defaults_dir is not None:
+                 defaults_loaded, defaults_source_chain = load_defaults_with_inheritance(defaults_dir, modality)
+
+         # If CSV asks to extend defaults, load those and merge
+         merged = dict(defaults_loaded or {})
+
+         if allow_csv_extends:
+             extends = normalized.get("extends") or normalized.get("inherits")
+             if extends:
+                 if isinstance(extends, str):
+                     ext_list = [extends]
+                 elif isinstance(extends, (list, tuple)):
+                     ext_list = list(extends)
+                 else:
+                     ext_list = []
+                 for ext in ext_list:
+                     ext_defaults, ext_sources = (load_defaults_with_inheritance(defaults_dir, ext) if defaults_dir else ({}, []))
+                     merged = deep_merge(merged, ext_defaults)
+                     for s in ext_sources:
+                         if s not in defaults_source_chain:
+                             defaults_source_chain.append(s)
+
+         # Now overlay CSV values
+         # Prepare csv_effective depending on allow_null_override
+         csv_effective = {}
+         for k, v in normalized.items():
+             if k in ("extends", "inherits"):
+                 continue
+             if v is None and not allow_null_override:
+                 # skip: keep default
+                 continue
+             csv_effective[k] = v
+
+         if override_with_csv:
+             merged = deep_merge(merged, csv_effective)
+         else:
+             # defaults take precedence: only set keys missing in merged
+             for k, v in csv_effective.items():
+                 if k not in merged:
+                     merged[k] = v
+
+         # experiment_name default
+         if merged.get("experiment_name") is None and date_str:
+             merged["experiment_name"] = f"{date_str}_SMF_experiment"
+
+         # final normalization
+         if "strands" in merged:
+             merged["strands"] = _parse_list(merged["strands"])
+         if "conversions" in merged:
+             merged["conversions"] = _parse_list(merged["conversions"])
+         if "mod_target_bases" in merged:
+             merged["mod_target_bases"] = _parse_list(merged["mod_target_bases"])
+         if "conversion_types" in merged:
+             merged["conversion_types"] = _parse_list(merged["conversion_types"])
+
+         merged["filter_threshold"] = float(_parse_numeric(merged.get("filter_threshold", 0.8), 0.8))
+         merged["m6A_threshold"] = float(_parse_numeric(merged.get("m6A_threshold", 0.7), 0.7))
+         merged["m5C_threshold"] = float(_parse_numeric(merged.get("m5C_threshold", 0.7), 0.7))
+         merged["hm5C_threshold"] = float(_parse_numeric(merged.get("hm5C_threshold", 0.7), 0.7))
+         merged["thresholds"] = [
+             merged["filter_threshold"],
+             merged["m6A_threshold"],
+             merged["m5C_threshold"],
+             merged["hm5C_threshold"],
+         ]
+
+         for bkey in ("barcode_both_ends", "trim", "input_already_demuxed", "make_bigwigs", "skip_unclassified", "delete_batch_hdfs"):
+             if bkey in merged:
+                 merged[bkey] = _parse_bool(merged[bkey])
+
+         if "batch_size" in merged:
+             merged["batch_size"] = int(_parse_numeric(merged.get("batch_size", 4), 4))
+         if "threads" in merged:
+             tval = _parse_numeric(merged.get("threads", None), None)
+             merged["threads"] = None if tval is None else int(tval)
+
+         if "aligner_args" in merged and merged.get("aligner_args") is None:
+             merged.pop("aligner_args", None)
+
+         # --- Resolve aligner_args into concrete list for the chosen aligner ---
+         merged['aligner_args'] = resolve_aligner_args(merged)
+
+         if "mod_list" in merged:
+             merged["mod_list"] = _parse_list(merged.get("mod_list"))
+
+         # HMM feature set handling
+         if "hmm_feature_sets" in merged:
+             merged["hmm_feature_sets"] = normalize_hmm_feature_sets(merged["hmm_feature_sets"])
+         else:
+             # allow older names (footprint_ranges, accessible_ranges, cpg_ranges), optionally:
+             maybe_fs = {}
+             if "footprint_ranges" in merged or "hmm_footprint_ranges" in merged:
+                 maybe_fs["footprint"] = {"features": merged.get("hmm_footprint_ranges", merged.get("footprint_ranges")), "state": merged.get("hmm_footprint_state", "Non-Modified")}
+             if "accessible_ranges" in merged or "hmm_accessible_ranges" in merged:
+                 maybe_fs["accessible"] = {"features": merged.get("hmm_accessible_ranges", merged.get("accessible_ranges")), "state": merged.get("hmm_accessible_state", "Modified")}
+             if "cpg_ranges" in merged or "hmm_cpg_ranges" in merged:
+                 maybe_fs["cpg"] = {"features": merged.get("hmm_cpg_ranges", merged.get("cpg_ranges")), "state": merged.get("hmm_cpg_state", "Modified")}
+             if maybe_fs:
+                 merged.setdefault("hmm_feature_sets", {})
+                 for k, v in maybe_fs.items():
+                     merged["hmm_feature_sets"].setdefault(k, v)
+
+             # final normalization is done once below; do not set a local hmm_feature_sets here
+
+         # Final normalization of hmm_feature_sets and canonical local variables
+         merged["hmm_feature_sets"] = normalize_hmm_feature_sets(merged.get("hmm_feature_sets", {}))
+         hmm_feature_sets = merged.get("hmm_feature_sets", {})
+         hmm_annotation_threshold = merged.get("hmm_annotation_threshold", 0.5)
+         hmm_batch_size = int(merged.get("hmm_batch_size", 1024))
+         hmm_use_viterbi = bool(merged.get("hmm_use_viterbi", False))
+         hmm_device = merged.get("hmm_device", None)
+         hmm_methbases = _parse_list(merged.get("hmm_methbases", None))
+         if not hmm_methbases:  # None or []
+             hmm_methbases = _parse_list(merged.get("mod_target_bases", None))
+         if not hmm_methbases:
+             hmm_methbases = ['C']
+         hmm_methbases = list(hmm_methbases)
+         hmm_merge_layer_features = _parse_list(merged.get("hmm_merge_layer_features", None))
+
+
+         # instantiate dataclass
+         instance = cls(
+             smf_modality = merged.get("smf_modality"),
+             input_data_path = merged.get("input_data_path"),
+             recursive_input_search = merged.get("recursive_input_search", True),
+             output_directory = merged.get("output_directory"),
+             fasta = merged.get("fasta"),
+             sequencer = merged.get("sequencer"),
+             model_dir = merged.get("model_dir"),
+             barcode_kit = merged.get("barcode_kit"),
+             fastq_barcode_map = merged.get("fastq_barcode_map"),
+             fastq_auto_pairing = merged.get("fastq_auto_pairing", True),
+             bam_suffix = merged.get("bam_suffix", ".bam"),
+             split_dir = merged.get("split_dir", "demultiplexed_BAMs"),
+             strands = merged.get("strands", ["bottom", "top"]),
+             conversions = merged.get("conversions", ["unconverted"]),
+             fasta_regions_of_interest = merged.get("fasta_regions_of_interest"),
+             mapping_threshold = float(merged.get("mapping_threshold", 0.01)),
+             experiment_name = merged.get("experiment_name"),
+             model = merged.get("model", "hac"),
+             barcode_both_ends = merged.get("barcode_both_ends", False),
+             trim = merged.get("trim", False),
+             input_already_demuxed = merged.get("input_already_demuxed", False),
+             threads = merged.get("threads"),
+             sample_sheet_path = merged.get("sample_sheet_path"),
+             sample_sheet_mapping_column = merged.get("sample_sheet_mapping_column"),
+             aligner = merged.get("aligner", "minimap2"),
+             aligner_args = merged.get("aligner_args", None),
+             device = merged.get("device", "auto"),
+             make_bigwigs = merged.get("make_bigwigs", False),
+             delete_intermediate_hdfs = merged.get("delete_intermediate_hdfs", True),
+             mod_target_bases = merged.get("mod_target_bases", ["GpC", "CpG"]),
+             enzyme_target_bases = merged.get("enzyme_target_bases", ["GpC"]),
+             conversion_types = merged.get("conversion_types", ["5mC"]),
+             filter_threshold = merged.get("filter_threshold", 0.8),
+             m6A_threshold = merged.get("m6A_threshold", 0.7),
+             m5C_threshold = merged.get("m5C_threshold", 0.7),
+             hm5C_threshold = merged.get("hm5C_threshold", 0.7),
+             thresholds = merged.get("thresholds", []),
+             mod_list = merged.get("mod_list", ["5mC_5hmC", "6mA"]),
+             batch_size = merged.get("batch_size", 4),
+             skip_unclassified = merged.get("skip_unclassified", True),
+             delete_batch_hdfs = merged.get("delete_batch_hdfs", True),
+             reference_column = merged.get("reference_column", 'Reference_strand'),
+             sample_column = merged.get("sample_column", 'Barcode'),
+             sample_name_col_for_plotting = merged.get("sample_name_col_for_plotting", 'Barcode'),
+             layer_for_clustermap_plotting = merged.get("layer_for_clustermap_plotting", 'nan0_0minus1'),
+             layer_for_umap_plotting = merged.get("layer_for_umap_plotting", 'nan_half'),
+             umap_layers_to_plot = merged.get("umap_layers_to_plot", ["mapped_length", "Raw_modification_signal"]),
+             rows_per_qc_histogram_grid = merged.get("rows_per_qc_histogram_grid", 12),
+             rows_per_qc_autocorr_grid = merged.get("rows_per_qc_autocorr_grid", 12),
+             autocorr_rolling_window_size = merged.get("autocorr_rolling_window_size", 25),
+             autocorr_max_lag = merged.get("autocorr_max_lag", 800),
+             autocorr_site_types = merged.get("autocorr_site_types", ['GpC', 'CpG', 'any_C']),
+             hmm_n_states = merged.get("hmm_n_states", 2),
+             hmm_init_emission_probs = merged.get("hmm_init_emission_probs", [[0.8, 0.2], [0.2, 0.8]]),
+             hmm_init_transition_probs = merged.get("hmm_init_transition_probs", [[0.9, 0.1], [0.1, 0.9]]),
+             hmm_init_start_probs = merged.get("hmm_init_start_probs", [0.5, 0.5]),
+             hmm_eps = merged.get("hmm_eps", 1e-8),
+             hmm_dtype = merged.get("hmm_dtype", "float64"),
+             hmm_feature_sets = hmm_feature_sets,
+             hmm_annotation_threshold = hmm_annotation_threshold,
+             hmm_batch_size = hmm_batch_size,
+             hmm_use_viterbi = hmm_use_viterbi,
+             hmm_methbases = hmm_methbases,
+             hmm_device = hmm_device,
+             hmm_merge_layer_features = hmm_merge_layer_features,
+             footprints = merged.get("footprints", True),
+             accessible_patches = merged.get("accessible_patches", True),
+             cpg = merged.get("cpg", False),
+             read_coord_filter = merged.get("read_coord_filter", [None, None]),
+             read_len_filter_thresholds = merged.get("read_len_filter_thresholds", [200, None]),
+             read_len_to_ref_ratio_filter_thresholds = merged.get("read_len_to_ref_ratio_filter_thresholds", [0.4, 1.1]),
+             read_quality_filter_thresholds = merged.get("read_quality_filter_thresholds", [20, None]),
+             read_mapping_quality_filter_thresholds = merged.get("read_mapping_quality_filter_thresholds", [None, None]),
+             read_mod_filtering_gpc_thresholds = merged.get("read_mod_filtering_gpc_thresholds", [0.025, 0.975]),
+             read_mod_filtering_cpg_thresholds = merged.get("read_mod_filtering_cpg_thresholds", [0.0, 1.0]),
+             read_mod_filtering_any_c_thresholds = merged.get("read_mod_filtering_any_c_thresholds", [0.025, 0.975]),
+             read_mod_filtering_a_thresholds = merged.get("read_mod_filtering_a_thresholds", [0.025, 0.975]),
+             read_mod_filtering_use_other_c_as_background = merged.get("read_mod_filtering_use_other_c_as_background", True),
+             min_valid_fraction_positions_in_read_vs_ref = merged.get("min_valid_fraction_positions_in_read_vs_ref", 0.2),
+             duplicate_detection_site_types = merged.get("duplicate_detection_site_types", ['GpC', 'CpG', 'ambiguous_GpC_CpG']),
+             duplicate_detection_distance_threshold = merged.get("duplicate_detection_distance_threshold", 0.07),
+             duplicate_detection_keep_best_metric = merged.get("duplicate_detection_keep_best_metric", "read_quality"),
+             duplicate_detection_window_size_for_hamming_neighbors = merged.get("duplicate_detection_window_size_for_hamming_neighbors", 50),
+             duplicate_detection_min_overlapping_positions = merged.get("duplicate_detection_min_overlapping_positions", 20),
+             duplicate_detection_do_hierarchical = merged.get("duplicate_detection_do_hierarchical", True),
+             duplicate_detection_hierarchical_linkage = merged.get("duplicate_detection_hierarchical_linkage", "average"),
+             duplicate_detection_do_pca = merged.get("duplicate_detection_do_pca", False),
+             correlation_matrix_types = merged.get("correlation_matrix_types", ["pearson", "binary_covariance"]),
+             correlation_matrix_cmaps = merged.get("correlation_matrix_cmaps", ["seismic", "viridis"]),
+             correlation_matrix_site_types = merged.get("correlation_matrix_site_types", ["GpC_site"]),
+             hamming_vs_metric_keys = merged.get("hamming_vs_metric_keys", ['Fraction_any_C_site_modified']),
+             force_redo_preprocessing = merged.get("force_redo_preprocessing", False),
+             force_reload_sample_sheet = merged.get("force_reload_sample_sheet", True),
+             bypass_add_read_length_and_mapping_qc = merged.get("bypass_add_read_length_and_mapping_qc", False),
+             force_redo_add_read_length_and_mapping_qc = merged.get("force_redo_add_read_length_and_mapping_qc", False),
+             bypass_clean_nan = merged.get("bypass_clean_nan", False),
+             force_redo_clean_nan = merged.get("force_redo_clean_nan", False),
+             bypass_append_base_context = merged.get("bypass_append_base_context", False),
+             force_redo_append_base_context = merged.get("force_redo_append_base_context", False),
+             invert_adata = merged.get("invert_adata", False),
+             bypass_append_binary_layer_by_base_context = merged.get("bypass_append_binary_layer_by_base_context", False),
+             force_redo_append_binary_layer_by_base_context = merged.get("force_redo_append_binary_layer_by_base_context", False),
+             bypass_calculate_read_modification_stats = merged.get("bypass_calculate_read_modification_stats", False),
+             force_redo_calculate_read_modification_stats = merged.get("force_redo_calculate_read_modification_stats", False),
+             bypass_filter_reads_on_modification_thresholds = merged.get("bypass_filter_reads_on_modification_thresholds", False),
+             force_redo_filter_reads_on_modification_thresholds = merged.get("force_redo_filter_reads_on_modification_thresholds", False),
+             bypass_flag_duplicate_reads = merged.get("bypass_flag_duplicate_reads", False),
+             force_redo_flag_duplicate_reads = merged.get("force_redo_flag_duplicate_reads", False),
+             bypass_complexity_analysis = merged.get("bypass_complexity_analysis", False),
+             force_redo_complexity_analysis = merged.get("force_redo_complexity_analysis", False),
+             force_redo_basic_analyses = merged.get("force_redo_basic_analyses", False),
+             bypass_basic_clustermaps = merged.get("bypass_basic_clustermaps", False),
+             force_redo_basic_clustermaps = merged.get("force_redo_basic_clustermaps", False),
+             bypass_basic_umap = merged.get("bypass_basic_umap", False),
+             force_redo_basic_umap = merged.get("force_redo_basic_umap", False),
+             bypass_spatial_autocorr_calculations = merged.get("bypass_spatial_autocorr_calculations", False),
+             force_redo_spatial_autocorr_calculations = merged.get("force_redo_spatial_autocorr_calculations", False),
+             bypass_spatial_autocorr_plotting = merged.get("bypass_spatial_autocorr_plotting", False),
+             force_redo_spatial_autocorr_plotting = merged.get("force_redo_spatial_autocorr_plotting", False),
+             bypass_matrix_corr_calculations = merged.get("bypass_matrix_corr_calculations", False),
+             force_redo_matrix_corr_calculations = merged.get("force_redo_matrix_corr_calculations", False),
+             bypass_matrix_corr_plotting = merged.get("bypass_matrix_corr_plotting", False),
+             force_redo_matrix_corr_plotting = merged.get("force_redo_matrix_corr_plotting", False),
+             bypass_hmm_fit = merged.get("bypass_hmm_fit", False),
+             force_redo_hmm_fit = merged.get("force_redo_hmm_fit", False),
+             bypass_hmm_apply = merged.get("bypass_hmm_apply", False),
+             force_redo_hmm_apply = merged.get("force_redo_hmm_apply", False),
+
+             config_source = config_source or "<var_dict>",
+         )
+
+         report = {
+             "modality": modality,
+             "defaults_source_chain": defaults_source_chain,
+             "defaults_loaded": defaults_loaded,
+             "csv_normalized": normalized,
+             "merged": merged,
+         }
+         return instance, report
+
+     # convenience: load from CSV via LoadExperimentConfig
+     @classmethod
+     def from_csv(
+         cls,
+         csv_input: Union[str, Path, IO, pd.DataFrame],
+         date_str: Optional[str] = None,
+         config_source: Optional[str] = None,
+         defaults_dir: Optional[Union[str, Path]] = None,
+         defaults_map: Optional[Dict[str, Dict[str, Any]]] = None,
+         **kwargs,
+     ) -> Tuple["ExperimentConfig", Dict[str, Any]]:
+         """
+         Load CSV using LoadExperimentConfig (or accept DataFrame) and build ExperimentConfig.
+         Additional kwargs passed to from_var_dict().
+         """
+         loader = LoadExperimentConfig(csv_input) if not isinstance(csv_input, pd.DataFrame) else LoadExperimentConfig(pd.DataFrame(csv_input))
+         var_dict = loader.var_dict
+         return cls.from_var_dict(var_dict, date_str=date_str, config_source=config_source, defaults_dir=defaults_dir, defaults_map=defaults_map, **kwargs)
+
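Putting the pieces together (paths hypothetical): build a config from a CSV and let packaged defaults fill whatever the CSV leaves blank:

    cfg, report = ExperimentConfig.from_csv(
        "experiment_config.csv",
        date_str="20240101",
        defaults_dir="smftools/config",  # directory holding default.yaml etc.
    )
    print(cfg.experiment_name)              # "20240101_SMF_experiment" if the CSV set none
    print(report["defaults_source_chain"])  # which defaults files were merged, in order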
+     # -------------------------
+     # validation & serialization
+     # -------------------------
+     @staticmethod
+     def _validate_hmm_features_structure(hfs: dict) -> List[str]:
+         # staticmethod so validate() can call it on self without binding hfs to the instance
+         errs = []
+         if not isinstance(hfs, dict):
+             errs.append("hmm_feature_sets must be a mapping if provided.")
+             return errs
+         for g, info in hfs.items():
+             if not isinstance(info, dict):
+                 errs.append(f"hmm_feature_sets['{g}'] must be a mapping with 'features' and 'state'.")
+                 continue
+             feats = info.get("features")
+             if not isinstance(feats, dict) or len(feats) == 0:
+                 errs.append(f"hmm_feature_sets['{g}'] must include non-empty 'features' mapping.")
+                 continue
+             for fname, rng in feats.items():
+                 try:
+                     lo, hi = float(rng[0]), float(rng[1])
+                     if lo < 0 or hi <= lo:
+                         errs.append(f"Feature range for {g}:{fname} must satisfy 0 <= lo < hi; got {rng}.")
+                 except Exception:
+                     errs.append(f"Feature range for {g}:{fname} is invalid: {rng}")
+         return errs
+
+     def validate(self, require_paths: bool = True, raise_on_error: bool = True) -> List[str]:
+         """
+         Validate the config. If require_paths True, check paths (input_data_path, fasta) exist;
+         attempt to create output_directory if missing.
+         Returns a list of error messages (empty if none). Raises ValueError if raise_on_error True.
+         """
+         errors: List[str] = []
+         if not self.input_data_path:
+             errors.append("input_data_path is required but missing.")
+         if not self.output_directory:
+             errors.append("output_directory is required but missing.")
+         if not self.fasta:
+             errors.append("fasta (reference FASTA) is required but missing.")
+
+         if require_paths:
+             if self.input_data_path and not Path(self.input_data_path).exists():
+                 errors.append(f"input_data_path does not exist: {self.input_data_path}")
+             if self.fasta and not Path(self.fasta).exists():
+                 errors.append(f"fasta does not exist: {self.fasta}")
+             outp = Path(self.output_directory) if self.output_directory else None
+             if outp and not outp.exists():
+                 try:
+                     outp.mkdir(parents=True, exist_ok=True)
+                 except Exception as e:
+                     errors.append(f"Could not create output_directory {self.output_directory}: {e}")
+
+         if not (0.0 <= float(self.mapping_threshold) <= 1.0):
+             errors.append("mapping_threshold must be in [0,1].")
+         for t in (self.filter_threshold, self.m6A_threshold, self.m5C_threshold, self.hm5C_threshold):
+             if not (0.0 <= float(t) <= 1.0):
+                 errors.append(f"threshold value {t} must be in [0,1].")
+
+         # collect HMM feature-set errors before deciding whether to raise,
+         # so they are covered by raise_on_error as well
+         errors.extend(self._validate_hmm_features_structure(self.hmm_feature_sets))
+
+         if raise_on_error and errors:
+             raise ValueError("ExperimentConfig validation failed:\n  " + "\n  ".join(errors))
+
+         return errors
+
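A short sketch of the two validation modes; with raise_on_error=False the method returns the messages instead of raising:

    cfg = ExperimentConfig(smf_modality="conversion")  # no paths set
    problems = cfg.validate(require_paths=False, raise_on_error=False)
    # ['input_data_path is required but missing.',
    #  'output_directory is required but missing.',
    #  'fasta (reference FASTA) is required but missing.']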
+     def to_dict(self) -> Dict[str, Any]:
+         return asdict(self)
+
+     def to_yaml(self, path: Optional[Union[str, Path]] = None) -> str:
+         """
+         Dump config to YAML (string if path None) or save to file at path.
+         If pyyaml is not installed, fall back to JSON.
+         """
+         data = self.to_dict()
+         if path is None:
+             if yaml is None:
+                 return json.dumps(data, indent=2)
+             return yaml.safe_dump(data, sort_keys=False)
+         else:
+             p = Path(path)
+             if yaml is None:
+                 p.write_text(json.dumps(data, indent=2), encoding="utf8")
+             else:
+                 p.write_text(yaml.safe_dump(data, sort_keys=False), encoding="utf8")
+             return str(p)
+
+     def save(self, path: Union[str, Path]) -> str:
+         return self.to_yaml(path)
+
+     def __repr__(self) -> str:
+         return f"<ExperimentConfig modality={self.smf_modality} experiment_name={self.experiment_name} source={self.config_source}>"