warp-md-cuda 0.1.0__cp311-cp311-manylinux_2_28_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
warp_md/cli.py ADDED
@@ -0,0 +1,1464 @@
1
+ from __future__ import annotations
2
+
3
+ import argparse
4
+ import json
5
+ from pathlib import Path
6
+ from typing import Any, Dict, Iterable, Optional
7
+
8
+ import numpy as np
9
+
10
+ _API_IMPORT_ERROR: Optional[Exception]
11
+ try:
12
+ from . import (
13
+ BondAngleDistributionPlan,
14
+ BondLengthDistributionPlan,
15
+ ChainRgPlan,
16
+ ConductivityPlan,
17
+ ContourLengthPlan,
18
+ DielectricPlan,
19
+ DipoleAlignmentPlan,
20
+ EndToEndPlan,
21
+ EquipartitionPlan,
22
+ HbondPlan,
23
+ IonPairCorrelationPlan,
24
+ MsdPlan,
25
+ PersistenceLengthPlan,
26
+ RdfPlan,
27
+ RgPlan,
28
+ RmsdPlan,
29
+ RotAcfPlan,
30
+ StructureFactorPlan,
31
+ System,
32
+ Trajectory,
33
+ WaterCountPlan,
34
+ )
35
+ _API_IMPORT_ERROR = None
36
+ except Exception as exc: # pragma: no cover - import guard for help/metadata usage
37
+ BondAngleDistributionPlan = None # type: ignore[assignment]
38
+ BondLengthDistributionPlan = None # type: ignore[assignment]
39
+ ChainRgPlan = None # type: ignore[assignment]
40
+ ConductivityPlan = None # type: ignore[assignment]
41
+ ContourLengthPlan = None # type: ignore[assignment]
42
+ DielectricPlan = None # type: ignore[assignment]
43
+ DipoleAlignmentPlan = None # type: ignore[assignment]
44
+ EndToEndPlan = None # type: ignore[assignment]
45
+ EquipartitionPlan = None # type: ignore[assignment]
46
+ HbondPlan = None # type: ignore[assignment]
47
+ IonPairCorrelationPlan = None # type: ignore[assignment]
48
+ MsdPlan = None # type: ignore[assignment]
49
+ PersistenceLengthPlan = None # type: ignore[assignment]
50
+ RdfPlan = None # type: ignore[assignment]
51
+ RgPlan = None # type: ignore[assignment]
52
+ RmsdPlan = None # type: ignore[assignment]
53
+ RotAcfPlan = None # type: ignore[assignment]
54
+ StructureFactorPlan = None # type: ignore[assignment]
55
+ System = None # type: ignore[assignment]
56
+ Trajectory = None # type: ignore[assignment]
57
+ WaterCountPlan = None # type: ignore[assignment]
58
+ _API_IMPORT_ERROR = exc
59
+ from .builder import charges_from_selections, charges_from_table, group_types_from_selections
60
+
61
+
62
def _require_api() -> None:
    """Raise ``RuntimeError`` if the package-level API import failed.

    The exception captured in ``_API_IMPORT_ERROR`` at import time is chained
    as the cause; nothing happens when the import succeeded.
    """
    import_failure = _API_IMPORT_ERROR
    if import_failure is None:
        return
    raise RuntimeError(
        "warp-md Python bindings are unavailable. Run `maturin develop` or install warp-md."
    ) from import_failure
67
+
68
+
69
+ def _load_config(path: str) -> Dict[str, Any]:
70
+ cfg_path = Path(path)
71
+ if not cfg_path.exists():
72
+ raise FileNotFoundError(f"config not found: {path}")
73
+ if cfg_path.suffix in {".yaml", ".yml"}:
74
+ try:
75
+ import yaml # type: ignore
76
+ except Exception as exc: # pragma: no cover - optional dependency
77
+ raise RuntimeError("YAML config requires PyYAML installed") from exc
78
+ return yaml.safe_load(cfg_path.read_text())
79
+ return json.loads(cfg_path.read_text())
80
+
81
+
82
+ def _as_tuple(value: Any, size: int, label: str) -> Optional[tuple[Any, ...]]:
83
+ if value is None:
84
+ return None
85
+ if isinstance(value, tuple):
86
+ if len(value) != size:
87
+ raise ValueError(f"{label} must have length {size}")
88
+ return value
89
+ if isinstance(value, list):
90
+ if len(value) != size:
91
+ raise ValueError(f"{label} must have length {size}")
92
+ return tuple(value)
93
+ raise ValueError(f"{label} must be a list/tuple of length {size}")
94
+
95
+
96
+ def _pick(spec: Dict[str, Any], keys: Iterable[str]) -> Dict[str, Any]:
97
+ return {k: spec[k] for k in keys if k in spec and spec[k] is not None}
98
+
99
+
100
+ def _normalize_system_spec(spec: Any) -> Dict[str, Any]:
101
+ if isinstance(spec, str):
102
+ return {"path": spec}
103
+ if isinstance(spec, dict):
104
+ return spec
105
+ raise ValueError("system spec must be a path or object")
106
+
107
+
108
+ def _normalize_traj_spec(spec: Any) -> Dict[str, Any]:
109
+ if isinstance(spec, str):
110
+ return {"path": spec}
111
+ if isinstance(spec, dict):
112
+ return spec
113
+ raise ValueError("trajectory spec must be a path or object")
114
+
115
+
116
def _load_system(spec: Dict[str, Any]) -> System:
    """Load a ``System`` from a normalized system spec.

    The format comes from ``spec["format"]`` when present, otherwise from the
    lower-cased file suffix of ``spec["path"]``.

    Raises:
        RuntimeError: the API import failed (via ``_require_api``).
        ValueError: the path is missing or the format is not pdb/gro.
    """
    _require_api()
    topology_path = spec.get("path")
    if not topology_path:
        raise ValueError("system.path is required")
    explicit_format = spec.get("format")
    if explicit_format is None:
        kind = Path(topology_path).suffix.lower().lstrip(".")
    else:
        kind = explicit_format
    if kind == "pdb":
        return System.from_pdb(topology_path)
    elif kind == "gro":
        return System.from_gro(topology_path)
    raise ValueError("system.format must be pdb or gro")
129
+
130
+
131
def _load_trajectory(spec: Dict[str, Any], system: System) -> Trajectory:
    """Open a ``Trajectory`` from a normalized trajectory spec.

    The format comes from ``spec["format"]`` when present, otherwise from the
    lower-cased file suffix of ``spec["path"]``. ``length_scale`` from the
    spec is forwarded only to the DCD opener.

    Raises:
        RuntimeError: the API import failed (via ``_require_api``).
        ValueError: the path is missing or the format is not dcd/xtc.
    """
    _require_api()
    traj_path = spec.get("path")
    if not traj_path:
        raise ValueError("trajectory.path is required")
    explicit_format = spec.get("format")
    if explicit_format is None:
        kind = Path(traj_path).suffix.lower().lstrip(".")
    else:
        kind = explicit_format
    if kind == "dcd":
        scale = spec.get("length_scale")
        return Trajectory.open_dcd(traj_path, system, length_scale=scale)
    elif kind == "xtc":
        return Trajectory.open_xtc(traj_path, system)
    raise ValueError("trajectory.format must be dcd or xtc")
144
+
145
+
146
+ def _select(system: System, expr: str, label: str):
147
+ if not expr:
148
+ raise ValueError(f"{label} selection is required")
149
+ return system.select(expr)
150
+
151
+
152
+ def _resolve_charges(system: System, spec: Any) -> list[float]:
153
+ if isinstance(spec, list):
154
+ return [float(x) for x in spec]
155
+ if isinstance(spec, dict):
156
+ mode = spec.get("from")
157
+ default = spec.get("default", 0.0)
158
+ if mode == "table":
159
+ path = spec.get("path")
160
+ if not path:
161
+ raise ValueError("charges.from=table requires path")
162
+ return charges_from_table(system, path, delimiter=spec.get("delimiter"), default=default)
163
+ if mode == "selections":
164
+ entries = spec.get("entries")
165
+ if not entries:
166
+ raise ValueError("charges.from=selections requires entries")
167
+ return charges_from_selections(system, entries, default=default)
168
+ raise ValueError("charges must be a list or {from: table|selections}")
169
+
170
+
171
+ def _resolve_group_types(
172
+ system: System,
173
+ selection,
174
+ group_by: str,
175
+ spec: Any,
176
+ ) -> Optional[list[int]]:
177
+ if spec is None:
178
+ return None
179
+ if isinstance(spec, list):
180
+ return [int(x) for x in spec]
181
+ if isinstance(spec, dict):
182
+ if spec.get("from") != "selections":
183
+ raise ValueError("group_types.from must be selections")
184
+ type_selections = spec.get("type_selections")
185
+ if not type_selections:
186
+ raise ValueError("group_types.type_selections required")
187
+ sel_expr = spec.get("selection")
188
+ sel = selection if sel_expr is None else system.select(sel_expr)
189
+ group_by = spec.get("group_by", group_by)
190
+ return group_types_from_selections(system, sel, group_by, type_selections)
191
+ raise ValueError("group_types must be a list or {from: selections}")
192
+
193
+
194
+ def _save_output(path: str, output: Any) -> None:
195
+ out_path = Path(path)
196
+ suffix = out_path.suffix.lower()
197
+ if suffix == "":
198
+ out_path = out_path.with_suffix(".npz")
199
+ suffix = ".npz"
200
+
201
+ if suffix == ".npy":
202
+ if not isinstance(output, np.ndarray):
203
+ raise ValueError(".npy output requires a single array")
204
+ np.save(out_path, output)
205
+ return
206
+
207
+ if suffix == ".csv":
208
+ if not isinstance(output, np.ndarray):
209
+ raise ValueError(".csv output requires a single array")
210
+ np.savetxt(out_path, output, delimiter=",")
211
+ return
212
+
213
+ if suffix == ".json":
214
+ out_path.write_text(json.dumps(_to_jsonable(output), indent=2))
215
+ return
216
+
217
+ if suffix == ".npz":
218
+ arrays = _to_npz_dict(output)
219
+ np.savez(out_path, **arrays)
220
+ return
221
+
222
+ raise ValueError("output extension must be .npz, .npy, .csv, or .json")
223
+
224
+
225
+ def _to_npz_dict(output: Any) -> Dict[str, np.ndarray]:
226
+ if isinstance(output, np.ndarray):
227
+ return {"data": output}
228
+ if isinstance(output, dict):
229
+ return {str(k): np.asarray(v) for k, v in output.items()}
230
+ if isinstance(output, (list, tuple)):
231
+ return {f"arr_{i}": np.asarray(v) for i, v in enumerate(output)}
232
+ return {"data": np.asarray(output)}
233
+
234
+
235
+ def _to_jsonable(output: Any) -> Any:
236
+ if isinstance(output, np.ndarray):
237
+ return output.tolist()
238
+ if isinstance(output, dict):
239
+ return {str(k): _to_jsonable(v) for k, v in output.items()}
240
+ if isinstance(output, (list, tuple)):
241
+ return [_to_jsonable(v) for v in output]
242
+ if isinstance(output, (np.floating, np.integer)):
243
+ return output.item()
244
+ return output
245
+
246
+
247
def _build_rg(system: System, spec: Dict[str, Any]):
    """Build an ``RgPlan`` from an ``rg`` spec.

    Requires ``selection``; ``mass_weighted`` is forwarded when set.
    """
    atoms = _select(system, spec.get("selection"), "rg.selection")
    options = _pick(spec, ("mass_weighted",))
    return RgPlan(atoms, **options)
251
+
252
+
253
def _build_rmsd(system: System, spec: Dict[str, Any]):
    """Build an ``RmsdPlan`` from an ``rmsd`` spec.

    Requires ``selection``; ``reference`` and ``align`` are forwarded when set.
    """
    atoms = _select(system, spec.get("selection"), "rmsd.selection")
    options = _pick(spec, ("reference", "align"))
    return RmsdPlan(atoms, **options)
257
+
258
+
259
def _build_msd(system: System, spec: Dict[str, Any]):
    """Build an ``MsdPlan`` from an ``msd`` spec.

    Requires ``selection``; ``group_by`` defaults to ``"resid"``. The optional
    keys listed below are forwarded when set, with the fixed-length ones
    coerced to tuples by ``_as_tuple``.
    """
    atoms = _select(system, spec.get("selection"), "msd.selection")
    grouping = spec.get("group_by", "resid")
    type_ids = _resolve_group_types(system, atoms, grouping, spec.get("group_types"))
    options = _pick(
        spec,
        (
            "axis",
            "length_scale",
            "frame_decimation",
            "dt_decimation",
            "time_binning",
            "lag_mode",
            "max_lag",
            "memory_budget_bytes",
            "multi_tau_m",
            "multi_tau_levels",
        ),
    )
    # Fixed-length options and the number of items each must have.
    for key, width in (
        ("axis", 3),
        ("frame_decimation", 2),
        ("dt_decimation", 4),
        ("time_binning", 2),
    ):
        if key in options:
            options[key] = _as_tuple(options[key], width, key)
    if type_ids is not None:
        options["group_types"] = type_ids
    return MsdPlan(atoms, group_by=grouping, **options)
289
+
290
+
291
def _build_rotacf(system: System, spec: Dict[str, Any]):
    """Build a ``RotAcfPlan`` from a ``rotacf`` spec.

    Requires ``selection``; ``group_by`` defaults to ``"resid"``. When
    ``orientation`` is given it must be a list/tuple of length 2 or 3. The
    decimation and binning options are coerced to tuples by ``_as_tuple``.

    Raises:
        ValueError: ``orientation`` has the wrong type or length.
    """
    atoms = _select(system, spec.get("selection"), "rotacf.selection")
    grouping = spec.get("group_by", "resid")
    type_ids = _resolve_group_types(system, atoms, grouping, spec.get("group_types"))
    options = _pick(
        spec,
        (
            "orientation",
            "p2_legendre",
            "length_scale",
            "frame_decimation",
            "dt_decimation",
            "time_binning",
            "lag_mode",
            "max_lag",
            "memory_budget_bytes",
            "multi_tau_m",
            "multi_tau_levels",
        ),
    )
    if "orientation" in options:
        indices = options["orientation"]
        is_sequence = isinstance(indices, (list, tuple))
        if not (is_sequence and len(indices) in (2, 3)):
            raise ValueError("rotacf.orientation must be length 2 or 3")
    # Fixed-length options and the number of items each must have.
    for key, width in (
        ("frame_decimation", 2),
        ("dt_decimation", 4),
        ("time_binning", 2),
    ):
        if key in options:
            options[key] = _as_tuple(options[key], width, key)
    if type_ids is not None:
        options["group_types"] = type_ids
    return RotAcfPlan(atoms, group_by=grouping, **options)
324
+
325
+
326
def _build_conductivity(system: System, spec: Dict[str, Any]):
    """Build a ``ConductivityPlan`` from a ``conductivity`` spec.

    Requires ``selection``, ``charges`` and ``temperature``; ``group_by``
    defaults to ``"resid"``. Charges are resolved by ``_resolve_charges`` and
    the decimation and binning options are coerced by ``_as_tuple``.

    Raises:
        ValueError: ``charges`` or ``temperature`` is missing.
    """
    atoms = _select(system, spec.get("selection"), "conductivity.selection")
    grouping = spec.get("group_by", "resid")
    raw_charges = spec.get("charges")
    if raw_charges is None:
        raise ValueError("conductivity.charges is required")
    charge_values = _resolve_charges(system, raw_charges)
    temperature = spec.get("temperature")
    if temperature is None:
        raise ValueError("conductivity.temperature is required")
    type_ids = _resolve_group_types(system, atoms, grouping, spec.get("group_types"))
    options = _pick(
        spec,
        (
            "transference",
            "length_scale",
            "frame_decimation",
            "dt_decimation",
            "time_binning",
            "lag_mode",
            "max_lag",
            "memory_budget_bytes",
            "multi_tau_m",
            "multi_tau_levels",
        ),
    )
    # Fixed-length options and the number of items each must have.
    for key, width in (
        ("frame_decimation", 2),
        ("dt_decimation", 4),
        ("time_binning", 2),
    ):
        if key in options:
            options[key] = _as_tuple(options[key], width, key)
    if type_ids is not None:
        options["group_types"] = type_ids
    return ConductivityPlan(
        atoms, charge_values, temperature, group_by=grouping, **options
    )
361
+
362
+
363
def _build_dielectric(system: System, spec: Dict[str, Any]):
    """Build a ``DielectricPlan`` from a ``dielectric`` spec.

    Requires ``selection`` and ``charges``; ``group_by`` defaults to
    ``"resid"`` and ``length_scale`` is forwarded when set.

    Raises:
        ValueError: ``charges`` is missing.
    """
    atoms = _select(system, spec.get("selection"), "dielectric.selection")
    grouping = spec.get("group_by", "resid")
    raw_charges = spec.get("charges")
    if raw_charges is None:
        raise ValueError("dielectric.charges is required")
    charge_values = _resolve_charges(system, raw_charges)
    type_ids = _resolve_group_types(system, atoms, grouping, spec.get("group_types"))
    options = _pick(spec, ("length_scale",))
    if type_ids is not None:
        options["group_types"] = type_ids
    return DielectricPlan(atoms, charge_values, group_by=grouping, **options)
375
+
376
+
377
def _build_dipole_alignment(system: System, spec: Dict[str, Any]):
    """Build a ``DipoleAlignmentPlan`` from a ``dipole_alignment`` spec.

    Requires ``selection`` and ``charges``; ``group_by`` defaults to
    ``"resid"`` and ``length_scale`` is forwarded when set.

    Raises:
        ValueError: ``charges`` is missing.
    """
    atoms = _select(system, spec.get("selection"), "dipole_alignment.selection")
    grouping = spec.get("group_by", "resid")
    raw_charges = spec.get("charges")
    if raw_charges is None:
        raise ValueError("dipole_alignment.charges is required")
    charge_values = _resolve_charges(system, raw_charges)
    type_ids = _resolve_group_types(system, atoms, grouping, spec.get("group_types"))
    options = _pick(spec, ("length_scale",))
    if type_ids is not None:
        options["group_types"] = type_ids
    return DipoleAlignmentPlan(atoms, charge_values, group_by=grouping, **options)
389
+
390
+
391
def _build_ion_pair(system: System, spec: Dict[str, Any]):
    """Build an ``IonPairCorrelationPlan`` from an ``ion_pair_correlation`` spec.

    Requires ``selection``, ``rclust_cat`` and ``rclust_ani``; ``group_by``
    defaults to ``"resid"``. The remaining listed keys are forwarded when set.

    Raises:
        ValueError: either cluster radius is missing.
    """
    atoms = _select(system, spec.get("selection"), "ion_pair_correlation.selection")
    grouping = spec.get("group_by", "resid")
    cation_radius = spec.get("rclust_cat")
    anion_radius = spec.get("rclust_ani")
    if cation_radius is None or anion_radius is None:
        raise ValueError("ion_pair_correlation.rclust_cat and rclust_ani are required")
    type_ids = _resolve_group_types(system, atoms, grouping, spec.get("group_types"))
    options = _pick(
        spec,
        (
            "cation_type",
            "anion_type",
            "max_cluster",
            "length_scale",
            "lag_mode",
            "max_lag",
            "memory_budget_bytes",
            "multi_tau_m",
            "multi_tau_levels",
        ),
    )
    if type_ids is not None:
        options["group_types"] = type_ids
    return IonPairCorrelationPlan(
        atoms, cation_radius, anion_radius, group_by=grouping, **options
    )
416
+
417
+
418
def _build_structure_factor(system: System, spec: Dict[str, Any]):
    """Build a ``StructureFactorPlan`` from a ``structure_factor`` spec.

    Requires ``selection``, ``bins``, ``r_max``, ``q_bins`` and ``q_max``;
    ``pbc`` and ``length_scale`` are forwarded when set.

    Raises:
        ValueError: any of the four numeric settings is missing.
    """
    atoms = _select(system, spec.get("selection"), "structure_factor.selection")
    required = (
        spec.get("bins"),
        spec.get("r_max"),
        spec.get("q_bins"),
        spec.get("q_max"),
    )
    if None in required:
        raise ValueError("structure_factor requires bins, r_max, q_bins, q_max")
    options = _pick(spec, ("pbc", "length_scale"))
    return StructureFactorPlan(atoms, *required, **options)
428
+
429
+
430
def _build_water_count(system: System, spec: Dict[str, Any]):
    """Build a ``WaterCountPlan`` from a ``water_count`` spec.

    Requires ``water_selection``, ``center_selection``, ``box_unit`` and
    ``region_size``. ``box_unit``, ``region_size`` and the optional ``shift``
    must each have three items; ``length_scale`` is forwarded when set.

    Raises:
        ValueError: ``box_unit`` or ``region_size`` is missing.
    """
    waters = _select(system, spec.get("water_selection"), "water_count.water_selection")
    centers = _select(system, spec.get("center_selection"), "water_count.center_selection")
    unit_cell = spec.get("box_unit")
    region = spec.get("region_size")
    if unit_cell is None or region is None:
        raise ValueError("water_count requires box_unit and region_size")
    options = _pick(spec, ("shift", "length_scale"))
    options["box_unit"] = _as_tuple(unit_cell, 3, "box_unit")
    options["region_size"] = _as_tuple(region, 3, "region_size")
    if "shift" in options:
        options["shift"] = _as_tuple(options["shift"], 3, "shift")
    return WaterCountPlan(waters, centers, **options)
443
+
444
+
445
def _build_equipartition(system: System, spec: Dict[str, Any]):
    """Build an ``EquipartitionPlan`` from an ``equipartition`` spec.

    Requires ``selection``; ``group_by`` defaults to ``"resid"`` and
    ``velocity_scale``/``length_scale`` are forwarded when set.
    """
    atoms = _select(system, spec.get("selection"), "equipartition.selection")
    grouping = spec.get("group_by", "resid")
    type_ids = _resolve_group_types(system, atoms, grouping, spec.get("group_types"))
    options = _pick(spec, ("velocity_scale", "length_scale"))
    if type_ids is not None:
        options["group_types"] = type_ids
    return EquipartitionPlan(atoms, group_by=grouping, **options)
453
+
454
+
455
def _build_hbond(system: System, spec: Dict[str, Any]):
    """Build an ``HbondPlan`` from an ``hbond`` spec.

    Requires ``donors``, ``acceptors`` and ``dist_cutoff``. The hydrogen
    selection and ``angle_cutoff`` are passed only when ``hydrogens`` is
    non-empty; otherwise ``angle_cutoff`` is not used.

    Raises:
        ValueError: ``dist_cutoff`` is missing.
    """
    donor_sel = _select(system, spec.get("donors"), "hbond.donors")
    acceptor_sel = _select(system, spec.get("acceptors"), "hbond.acceptors")
    cutoff = spec.get("dist_cutoff")
    if cutoff is None:
        raise ValueError("hbond.dist_cutoff is required")
    hydrogen_expr = spec.get("hydrogens")
    if not hydrogen_expr:
        return HbondPlan(donor_sel, acceptor_sel, cutoff)
    hydrogen_sel = _select(system, hydrogen_expr, "hbond.hydrogens")
    return HbondPlan(
        donor_sel,
        acceptor_sel,
        cutoff,
        hydrogens=hydrogen_sel,
        angle_cutoff=spec.get("angle_cutoff"),
    )
467
+
468
+
469
def _build_rdf(system: System, spec: Dict[str, Any]):
    """Build an ``RdfPlan`` from an ``rdf`` spec.

    Requires ``sel_a``, ``sel_b``, ``bins`` and ``r_max``; ``pbc`` is
    forwarded when set.

    Raises:
        ValueError: ``bins`` or ``r_max`` is missing.
    """
    first = _select(system, spec.get("sel_a"), "rdf.sel_a")
    second = _select(system, spec.get("sel_b"), "rdf.sel_b")
    n_bins = spec.get("bins")
    cutoff = spec.get("r_max")
    if n_bins is None or cutoff is None:
        raise ValueError("rdf requires bins and r_max")
    options = _pick(spec, ("pbc",))
    return RdfPlan(first, second, n_bins, cutoff, **options)
478
+
479
+
480
def _build_end_to_end(system: System, spec: Dict[str, Any]):
    """Build an ``EndToEndPlan`` for the spec's required ``selection``."""
    expr = spec.get("selection")
    return EndToEndPlan(_select(system, expr, "end_to_end.selection"))
483
+
484
+
485
def _build_contour_length(system: System, spec: Dict[str, Any]):
    """Build a ``ContourLengthPlan`` for the spec's required ``selection``."""
    expr = spec.get("selection")
    return ContourLengthPlan(_select(system, expr, "contour_length.selection"))
488
+
489
+
490
def _build_chain_rg(system: System, spec: Dict[str, Any]):
    """Build a ``ChainRgPlan`` for the spec's required ``selection``."""
    expr = spec.get("selection")
    return ChainRgPlan(_select(system, expr, "chain_rg.selection"))
493
+
494
+
495
def _build_bond_length(system: System, spec: Dict[str, Any]):
    """Build a ``BondLengthDistributionPlan`` from its spec.

    Requires ``selection``, ``bins`` and ``r_max``.

    Raises:
        ValueError: ``bins`` or ``r_max`` is missing.
    """
    atoms = _select(system, spec.get("selection"), "bond_length_distribution.selection")
    n_bins = spec.get("bins")
    cutoff = spec.get("r_max")
    if n_bins is None or cutoff is None:
        raise ValueError("bond_length_distribution requires bins and r_max")
    return BondLengthDistributionPlan(atoms, n_bins, cutoff)
502
+
503
+
504
def _build_bond_angle(system: System, spec: Dict[str, Any]):
    """Build a ``BondAngleDistributionPlan`` from its spec.

    Requires ``selection`` and ``bins``; ``degrees`` is forwarded when set.

    Raises:
        ValueError: ``bins`` is missing.
    """
    atoms = _select(system, spec.get("selection"), "bond_angle_distribution.selection")
    n_bins = spec.get("bins")
    if n_bins is None:
        raise ValueError("bond_angle_distribution requires bins")
    options = _pick(spec, ("degrees",))
    return BondAngleDistributionPlan(atoms, n_bins, **options)
511
+
512
+
513
def _build_persistence(system: System, spec: Dict[str, Any]):
    """Build a ``PersistenceLengthPlan`` for the spec's required ``selection``."""
    expr = spec.get("selection")
    return PersistenceLengthPlan(_select(system, expr, "persistence_length.selection"))
516
+
517
+
518
# Dispatch table: analysis name (as used in config files) -> builder that
# takes (system, spec) and returns the plan object to run.
PLAN_BUILDERS = dict(
    rg=_build_rg,
    rmsd=_build_rmsd,
    msd=_build_msd,
    rotacf=_build_rotacf,
    conductivity=_build_conductivity,
    dielectric=_build_dielectric,
    dipole_alignment=_build_dipole_alignment,
    ion_pair_correlation=_build_ion_pair,
    structure_factor=_build_structure_factor,
    water_count=_build_water_count,
    equipartition=_build_equipartition,
    hbond=_build_hbond,
    rdf=_build_rdf,
    end_to_end=_build_end_to_end,
    contour_length=_build_contour_length,
    chain_rg=_build_chain_rg,
    bond_length_distribution=_build_bond_length,
    bond_angle_distribution=_build_bond_angle,
    persistence_length=_build_persistence,
)
539
+
540
+
541
# Hyphenated command-line name -> underscore plan name. Each CLI name is the
# plan name with "_" replaced by "-".
CLI_TO_PLAN = {
    plan_name.replace("_", "-"): plan_name
    for plan_name in (
        "rg",
        "rmsd",
        "msd",
        "rotacf",
        "conductivity",
        "dielectric",
        "dipole_alignment",
        "ion_pair_correlation",
        "structure_factor",
        "water_count",
        "equipartition",
        "hbond",
        "rdf",
        "end_to_end",
        "contour_length",
        "chain_rg",
        "bond_length_distribution",
        "bond_angle_distribution",
        "persistence_length",
    )
}
562
+
563
+
564
+ def _default_out(name: str, output_dir: str, used: Dict[str, int]) -> str:
565
+ count = used.get(name, 0)
566
+ used[name] = count + 1
567
+ suffix = "" if count == 0 else f"_{count}"
568
+ return str(Path(output_dir) / f"{name}{suffix}.npz")
569
+
570
+
571
def run_config(config_path: str, dry_run: bool = False) -> None:
    """Run every analysis listed in a configuration file.

    The system is read from ``system`` (or ``topology``) and the trajectory
    spec from ``trajectory`` (or ``traj``). ``device``, ``chunk_frames`` and
    ``output_dir`` are config-wide defaults; ``device`` and ``chunk_frames``
    can be overridden per analysis. Analysis names may use hyphens in place
    of underscores.

    With ``dry_run`` the planned ``name -> output path`` pairs are printed and
    nothing is created, run or written.

    Raises:
        ValueError: ``analyses`` is missing/empty or a name is unknown.
    """
    cfg = _load_config(config_path)
    system_spec = _normalize_system_spec(cfg.get("system") or cfg.get("topology"))
    traj_spec = _normalize_traj_spec(cfg.get("trajectory") or cfg.get("traj"))
    system = _load_system(system_spec)
    fallback_device = cfg.get("device", "auto")
    fallback_chunk = cfg.get("chunk_frames")
    output_dir = cfg.get("output_dir", ".")
    if not dry_run:
        Path(output_dir).mkdir(parents=True, exist_ok=True)

    analyses = cfg.get("analyses")
    if not analyses:
        raise ValueError("config.analyses is required")

    name_counts: Dict[str, int] = {}
    for entry in analyses:
        requested = entry.get("name")
        plan_name = requested
        if plan_name not in PLAN_BUILDERS:
            # Accept the hyphenated CLI spelling as an alias.
            if isinstance(requested, str):
                plan_name = requested.replace("-", "_")
            if plan_name not in PLAN_BUILDERS:
                raise ValueError(f"unknown analysis name: {requested}")
        destination = entry.get("out") or _default_out(plan_name, output_dir, name_counts)
        if dry_run:
            print(f"{plan_name} -> {destination}")
        else:
            plan = PLAN_BUILDERS[plan_name](system, entry)
            # The trajectory is opened once per analysis.
            trajectory = _load_trajectory(traj_spec, system)
            result = plan.run(
                trajectory,
                system,
                chunk_frames=entry.get("chunk_frames", fallback_chunk),
                device=entry.get("device", fallback_device),
            )
            _save_output(destination, result)
            print(f"{plan_name}: wrote {destination}")
605
+
606
+
607
def list_plans() -> None:
    """Print the CLI plan names from ``CLI_TO_PLAN``, sorted, one per line."""
    for cli_name in sorted(CLI_TO_PLAN):
        print(cli_name)
610
+
611
+
612
def example_config() -> None:
    """Print a sample configuration as indented JSON to stdout."""
    rg_analysis = {
        "name": "rg",
        "selection": "protein",
        "mass_weighted": False,
    }
    rdf_analysis = {
        "name": "rdf",
        "sel_a": "resname SOL and name OW",
        "sel_b": "resname SOL and name OW",
        "bins": 200,
        "r_max": 10.0,
    }
    sample = {
        "system": {"path": "topology.pdb"},
        "trajectory": {"path": "traj.xtc"},
        "device": "auto",
        "chunk_frames": 500,
        "output_dir": "outputs",
        "analyses": [rg_analysis, rdf_analysis],
    }
    print(json.dumps(sample, indent=2))
635
+
636
+
637
+ def _split_values(raw: str) -> list[str]:
638
+ if "," in raw:
639
+ parts = [part.strip() for part in raw.split(",")]
640
+ else:
641
+ parts = raw.split()
642
+ return [part for part in parts if part]
643
+
644
+
645
def _parse_float_tuple(raw: str, size: int, label: str) -> tuple[float, ...]:
    """Parse exactly ``size`` floats from ``raw``, split by ``_split_values``.

    Raises:
        ValueError: wrong number of values, or a value is not a float.
    """
    tokens = _split_values(raw)
    if len(tokens) == size:
        return tuple(map(float, tokens))
    raise ValueError(f"{label} must have {size} values")
650
+
651
+
652
def _parse_int_tuple(raw: str, size: int, label: str) -> tuple[int, ...]:
    """Parse exactly ``size`` ints from ``raw``, split by ``_split_values``.

    Raises:
        ValueError: wrong number of values, or a value is not an int.
    """
    tokens = _split_values(raw)
    if len(tokens) == size:
        return tuple(map(int, tokens))
    raise ValueError(f"{label} must have {size} values")
657
+
658
+
659
def _parse_int_list(raw: str, label: str) -> list[int]:
    """Parse one or more ints from ``raw``, split by ``_split_values``.

    Raises:
        ValueError: no values were found, or a value is not an int.
    """
    tokens = _split_values(raw)
    if tokens:
        return list(map(int, tokens))
    raise ValueError(f"{label} must have at least one value")
664
+
665
+
666
+ def _parse_json_list(raw: str, label: str) -> list[Any]:
667
+ try:
668
+ data = json.loads(raw)
669
+ except json.JSONDecodeError as exc:
670
+ raise ValueError(f"{label} must be valid JSON") from exc
671
+ if not isinstance(data, list):
672
+ raise ValueError(f"{label} must be a JSON list")
673
+ return data
674
+
675
+
676
+ def _parse_charges_arg(raw: str, system: System) -> list[float]:
677
+ if raw.startswith("table:"):
678
+ path = raw[len("table:") :].strip()
679
+ if not path:
680
+ raise ValueError("charges table path is required")
681
+ return charges_from_table(system, path)
682
+ if raw.startswith("selections:"):
683
+ payload = raw[len("selections:") :].strip()
684
+ entries = _parse_json_list(payload, "charges selections")
685
+ return charges_from_selections(system, entries)
686
+ data = _parse_json_list(raw, "charges")
687
+ return [float(x) for x in data]
688
+
689
+
690
+ def _parse_group_types_arg(
691
+ raw: Optional[str],
692
+ system: System,
693
+ selection,
694
+ group_by: str,
695
+ ) -> Optional[list[int]]:
696
+ if raw is None:
697
+ return None
698
+ if raw.startswith("selections:"):
699
+ payload = raw[len("selections:") :].strip()
700
+ if payload.startswith("["):
701
+ selections = _parse_json_list(payload, "group_types selections")
702
+ else:
703
+ selections = [s.strip() for s in payload.split(",") if s.strip()]
704
+ if not selections:
705
+ raise ValueError("group_types selections cannot be empty")
706
+ return group_types_from_selections(system, selection, group_by, selections)
707
+ data = _parse_json_list(raw, "group_types")
708
+ return [int(x) for x in data]
709
+
710
+
711
+ def _summary_from_output(output: Any, analysis: str, out_path: Path) -> Dict[str, Any]:
712
+ if isinstance(output, np.generic):
713
+ output = output.item()
714
+ summary: Dict[str, Any] = {
715
+ "analysis": analysis,
716
+ "out": str(out_path),
717
+ }
718
+ if isinstance(output, np.ndarray):
719
+ summary.update(
720
+ {
721
+ "kind": "array",
722
+ "shape": list(output.shape),
723
+ "dtype": str(output.dtype),
724
+ "keys": ["data"],
725
+ }
726
+ )
727
+ return summary
728
+ if isinstance(output, dict):
729
+ summary["kind"] = "dict"
730
+ summary["keys"] = [str(k) for k in output.keys()]
731
+ summary["shapes"] = {str(k): list(np.asarray(v).shape) for k, v in output.items()}
732
+ return summary
733
+ if isinstance(output, (list, tuple)):
734
+ summary["kind"] = "tuple"
735
+ summary["keys"] = [f"arr_{i}" for i in range(len(output))]
736
+ summary["shapes"] = {
737
+ f"arr_{i}": list(np.asarray(v).shape) for i, v in enumerate(output)
738
+ }
739
+ return summary
740
+ summary["kind"] = "scalar"
741
+ summary["value"] = output
742
+ return summary
743
+
744
+
745
+ def _print_summary(summary: Dict[str, Any], fmt: str) -> None:
746
+ if fmt == "json":
747
+ print(json.dumps(summary, indent=2))
748
+ return
749
+ print(f"analysis: {summary.get('analysis')}")
750
+ print(f"out: {summary.get('out')}")
751
+ print(f"kind: {summary.get('kind')}")
752
+ if "keys" in summary:
753
+ print("keys: " + ", ".join(summary["keys"]))
754
+ if "shape" in summary:
755
+ print(f"shape: {summary['shape']}")
756
+ if "dtype" in summary:
757
+ print(f"dtype: {summary['dtype']}")
758
+ if "value" in summary:
759
+ print(f"value: {summary['value']}")
760
+
761
+
762
+ def _infer_format(path: str) -> str:
763
+ return Path(path).suffix.lower().lstrip(".")
764
+
765
+
766
def _load_system_from_args(args: argparse.Namespace) -> System:
    """Load the system named by ``--topology``.

    ``--topology-format`` wins when given; otherwise the format is inferred
    from the file suffix.
    """
    chosen_format = args.topology_format or _infer_format(args.topology)
    return _load_system({"path": args.topology, "format": chosen_format})
770
+
771
+
772
def _load_traj_from_args(args: argparse.Namespace, system: System) -> Trajectory:
    """Open the trajectory named by ``--traj``.

    ``--traj-format`` wins when given; otherwise the format is inferred from
    the file suffix. ``--traj-length-scale`` is passed along in the spec.
    """
    chosen_format = args.traj_format or _infer_format(args.traj)
    traj_spec = {
        "path": args.traj,
        "format": chosen_format,
        "length_scale": args.traj_length_scale,
    }
    return _load_trajectory(traj_spec, system)
780
+
781
+
782
def add_shared_args(parser: argparse.ArgumentParser) -> None:
    """Register the input, device, output and summary options on ``parser``.

    ``--topology`` and ``--traj`` are required. ``--print-summary`` and
    ``--no-summary`` are mutually exclusive and share the ``print_summary``
    destination, which defaults to ``True``.
    """
    add = parser.add_argument
    add("--topology", required=True, help="Topology file (.pdb or .gro)")
    add("--traj", required=True, help="Trajectory file (.dcd or .xtc)")
    add("--topology-format", choices=["pdb", "gro"], help="Override topology format")
    add("--traj-format", choices=["dcd", "xtc"], help="Override trajectory format")
    add("--traj-length-scale", type=float, help="DCD length scale (e.g., 10.0 for nm->A)")
    add("--device", default="auto", help="auto|cpu|cuda|cuda:0")
    add("--chunk-frames", type=int, help="Frames per chunk")
    add("--out", help="Output path (.npz/.npy/.csv/.json)")

    summary_toggle = parser.add_mutually_exclusive_group()
    summary_toggle.add_argument(
        "--print-summary",
        dest="print_summary",
        action="store_true",
        help="Print a JSON/text summary to stdout",
    )
    summary_toggle.add_argument(
        "--no-summary",
        dest="print_summary",
        action="store_false",
        help="Disable summary output",
    )
    parser.set_defaults(print_summary=True)

    add("--summary-format", choices=["json", "text"], default="json", help="Summary format")
823
+
824
+
825
def add_dynamics_args(parser: argparse.ArgumentParser) -> None:
    """Register the decimation, binning and lag options on ``parser``.

    The decimation and binning options are kept as raw strings; the four
    trailing options are parsed as ints.
    """
    add = parser.add_argument
    add("--frame-decimation", help="start,stride (e.g., 0,10)")
    add("--dt-decimation", help="cut1,stride1,cut2,stride2")
    add("--time-binning", help="eps_num,eps_add")
    add(
        "--lag-mode",
        choices=["auto", "multi_tau", "ring", "fft"],
        help="Lag mode (auto/multi_tau/ring/fft)",
    )
    for flag, description in (
        ("--max-lag", "Max lag (ring mode)"),
        ("--memory-budget-bytes", "Memory budget"),
        ("--multi-tau-m", "Multi-tau m"),
        ("--multi-tau-levels", "Multi-tau levels"),
    ):
        add(flag, type=int, help=description)
847
+
848
+
849
def add_group_types_args(parser: argparse.ArgumentParser) -> None:
    """Register the optional ``--group-types`` option, kept as a raw string."""
    usage_text = (
        "JSON list or selections:<sel1,sel2>. "
        "Example: --group-types '[0,1,1]' or --group-types 'selections:resname NA,resname CL'"
    )
    parser.add_argument("--group-types", help=usage_text)
857
+
858
+
859
def setup_rg_args(parser: argparse.ArgumentParser) -> None:
    """Register the ``rg`` options: required ``--selection`` and the ``--mass-weighted`` flag."""
    add = parser.add_argument
    add("--selection", required=True, help="Selection string")
    add("--mass-weighted", action="store_true", help="Mass-weighted Rg")
862
+
863
+
864
def setup_rmsd_args(parser: argparse.ArgumentParser) -> None:
    """Register the ``rmsd`` sub-command flags on *parser*.

    ``--reference`` defaults to ``"topology"``. ``--align`` is a paired
    ``--align`` / ``--no-align`` switch that defaults to ``True``.
    """
    add = parser.add_argument
    add("--selection", help="Selection string", required=True)

    reference_opts = {
        "choices": ["topology", "frame0"],
        "default": "topology",
        "help": "Reference frame",
    }
    add("--reference", **reference_opts)

    align_opts = {
        "action": argparse.BooleanOptionalAction,
        "default": True,
        "help": "Align before RMSD",
    }
    add("--align", **align_opts)
878
+
879
+
880
def setup_msd_args(parser: argparse.ArgumentParser) -> None:
    """Register the ``msd`` sub-command flags on *parser*.

    The msd-specific options are added first, followed by the shared
    ``--group-types`` flag and the dynamics tuning flags.
    """
    add = parser.add_argument
    add("--selection", help="Selection string", required=True)
    add(
        "--group-by",
        default="resid",
        choices=["resid", "chain", "resid_chain"],
        help="Group-by mode",
    )
    add("--axis", help="x,y,z axis components")
    add("--length-scale", help="Length scale", type=float)

    # Shared flag groups, in the order they appear in --help.
    for register in (add_group_types_args, add_dynamics_args):
        register(parser)
892
+
893
+
894
def setup_rotacf_args(parser: argparse.ArgumentParser) -> None:
    """Register the ``rotacf`` sub-command flags on *parser*.

    ``--selection`` and ``--orientation`` are mandatory. ``--p2-legendre``
    is a paired on/off switch defaulting to ``True``. The shared
    ``--group-types`` and dynamics flags are appended last.
    """
    add = parser.add_argument
    add("--selection", help="Selection string", required=True)
    add(
        "--group-by",
        default="resid",
        choices=["resid", "chain", "resid_chain"],
        help="Group-by mode",
    )
    add("--orientation", help="Indices (2 or 3) within group", required=True)
    add(
        "--p2-legendre",
        default=True,
        action=argparse.BooleanOptionalAction,
        help="Use P2 Legendre",
    )
    add("--length-scale", help="Length scale", type=float)

    for register in (add_group_types_args, add_dynamics_args):
        register(parser)
912
+
913
+
914
def setup_conductivity_args(parser: argparse.ArgumentParser) -> None:
    """Register the ``conductivity`` sub-command flags on *parser*.

    ``--selection``, ``--charges`` and ``--temperature`` are mandatory.
    ``--transference`` is a paired on/off switch defaulting to ``False``.
    The shared ``--group-types`` and dynamics flags are appended last.
    """
    add = parser.add_argument
    add("--selection", help="Selection string", required=True)

    charges_usage = "Charges: JSON list, table:path, or selections:[{selection,charge},...]"
    add("--charges", help=charges_usage, required=True)

    add("--temperature", help="Temperature (K)", type=float, required=True)
    add(
        "--group-by",
        default="resid",
        choices=["resid", "chain", "resid_chain"],
        help="Group-by mode",
    )
    add(
        "--transference",
        default=False,
        action=argparse.BooleanOptionalAction,
        help="Compute transference matrix",
    )
    add("--length-scale", help="Length scale", type=float)

    for register in (add_group_types_args, add_dynamics_args):
        register(parser)
939
+
940
+
941
def setup_dielectric_args(parser: argparse.ArgumentParser) -> None:
    """Register the ``dielectric`` sub-command flags on *parser*.

    ``--selection`` and ``--charges`` are mandatory; the shared
    ``--group-types`` flag is appended last.
    """
    add = parser.add_argument
    add("--selection", help="Selection string", required=True)
    add(
        "--charges",
        help="Charges: JSON list, table:path, or selections:[{selection,charge},...]",
        required=True,
    )
    add(
        "--group-by",
        default="resid",
        choices=["resid", "chain", "resid_chain"],
        help="Group-by mode",
    )
    add("--length-scale", help="Length scale", type=float)
    add_group_types_args(parser)
956
+
957
+
958
def setup_dipole_alignment_args(parser: argparse.ArgumentParser) -> None:
    """Register the ``dipole-alignment`` sub-command flags on *parser*.

    ``--selection`` and ``--charges`` are mandatory; the shared
    ``--group-types`` flag is appended last.
    """
    add = parser.add_argument
    add("--selection", help="Selection string", required=True)
    add(
        "--charges",
        help="Charges: JSON list, table:path, or selections:[{selection,charge},...]",
        required=True,
    )
    add(
        "--group-by",
        default="resid",
        choices=["resid", "chain", "resid_chain"],
        help="Group-by mode",
    )
    add("--length-scale", help="Length scale", type=float)
    add_group_types_args(parser)
973
+
974
+
975
def setup_ion_pair_args(parser: argparse.ArgumentParser) -> None:
    """Register the ``ion-pair-correlation`` sub-command flags on *parser*.

    ``--selection`` and both cluster cutoffs are mandatory. The type
    indices default to 0 (cation) and 1 (anion) and ``--max-cluster`` to 10.
    The shared ``--group-types`` and dynamics flags are appended last.
    """
    add = parser.add_argument
    add("--selection", help="Selection string", required=True)

    for flag, usage in (
        ("--rclust-cat", "Cation cutoff"),
        ("--rclust-ani", "Anion cutoff"),
    ):
        add(flag, type=float, required=True, help=usage)

    add(
        "--group-by",
        default="resid",
        choices=["resid", "chain", "resid_chain"],
        help="Group-by mode",
    )

    for flag, fallback, usage in (
        ("--cation-type", 0, "Cation type index"),
        ("--anion-type", 1, "Anion type index"),
        ("--max-cluster", 10, "Max cluster size"),
    ):
        add(flag, type=int, default=fallback, help=usage)

    add("--length-scale", help="Length scale", type=float)

    for register in (add_group_types_args, add_dynamics_args):
        register(parser)
991
+
992
+
993
def setup_structure_factor_args(parser: argparse.ArgumentParser) -> None:
    """Register the ``structure-factor`` sub-command flags on *parser*.

    The selection and the four binning parameters are mandatory; ``--pbc``
    defaults to ``"orthorhombic"`` and ``--length-scale`` is optional.
    """
    add = parser.add_argument
    add("--selection", help="Selection string", required=True)

    # Mandatory numeric parameters: (flag, converter, help text).
    for flag, convert, usage in (
        ("--bins", int, "r-space bins"),
        ("--r-max", float, "r-space max (A)"),
        ("--q-bins", int, "q-space bins"),
        ("--q-max", float, "q-space max (1/A)"),
    ):
        add(flag, type=convert, required=True, help=usage)

    add(
        "--pbc",
        default="orthorhombic",
        choices=["orthorhombic", "none"],
        help="PBC mode",
    )
    add("--length-scale", help="Length scale", type=float)
1006
+
1007
+
1008
def setup_water_count_args(parser: argparse.ArgumentParser) -> None:
    """Register the ``water-count`` sub-command flags on *parser*.

    The two selections, ``--box-unit`` and ``--region-size`` are mandatory.
    The x,y,z triples are stored as raw strings at this stage.
    """
    add = parser.add_argument

    for flag, usage in (
        ("--water-selection", "Water selection"),
        ("--center-selection", "Center selection"),
        ("--box-unit", "Box unit (x,y,z)"),
        ("--region-size", "Region size (x,y,z)"),
    ):
        add(flag, required=True, help=usage)

    add("--shift", help="Shift (x,y,z)")
    add("--length-scale", help="Length scale", type=float)
1015
+
1016
+
1017
def setup_equipartition_args(parser: argparse.ArgumentParser) -> None:
    """Register the ``equipartition`` sub-command flags on *parser*.

    Only ``--selection`` is mandatory; the shared ``--group-types`` flag is
    appended last.
    """
    add = parser.add_argument
    add("--selection", help="Selection string", required=True)
    add(
        "--group-by",
        default="resid",
        choices=["resid", "chain", "resid_chain"],
        help="Group-by mode",
    )

    for flag, usage in (
        ("--velocity-scale", "Velocity scale"),
        ("--length-scale", "Length scale"),
    ):
        add(flag, type=float, help=usage)

    add_group_types_args(parser)
1028
+
1029
+
1030
def setup_hbond_args(parser: argparse.ArgumentParser) -> None:
    """Register the ``hbond`` sub-command flags on *parser*.

    Donor and acceptor selections plus the distance cutoff are mandatory;
    the hydrogen selection and angle cutoff are optional.
    """
    add = parser.add_argument
    add("--donors", help="Donor selection", required=True)
    add("--acceptors", help="Acceptor selection", required=True)
    add("--dist-cutoff", help="Distance cutoff (A)", type=float, required=True)
    add("--hydrogens", help="Hydrogen selection")
    add("--angle-cutoff", help="Angle cutoff (deg)", type=float)
1036
+
1037
+
1038
def setup_rdf_args(parser: argparse.ArgumentParser) -> None:
    """Register the ``rdf`` sub-command flags on *parser*.

    Both selections, ``--bins`` and ``--r-max`` are mandatory; ``--pbc``
    defaults to ``"orthorhombic"``.
    """
    add = parser.add_argument

    for flag, usage in (("--sel-a", "Selection A"), ("--sel-b", "Selection B")):
        add(flag, required=True, help=usage)

    add("--bins", help="Number of bins", type=int, required=True)
    add("--r-max", help="Max distance (A)", type=float, required=True)
    add(
        "--pbc",
        default="orthorhombic",
        choices=["orthorhombic", "none"],
        help="PBC mode",
    )
1049
+
1050
+
1051
def setup_end_to_end_args(parser: argparse.ArgumentParser) -> None:
    """Register the single mandatory ``--selection`` flag for ``end-to-end``."""
    parser.add_argument("--selection", help="Selection string", required=True)
1053
+
1054
+
1055
def setup_contour_length_args(parser: argparse.ArgumentParser) -> None:
    """Register the single mandatory ``--selection`` flag for ``contour-length``."""
    parser.add_argument("--selection", help="Selection string", required=True)
1057
+
1058
+
1059
def setup_chain_rg_args(parser: argparse.ArgumentParser) -> None:
    """Register the single mandatory ``--selection`` flag for ``chain-rg``."""
    parser.add_argument("--selection", help="Selection string", required=True)
1061
+
1062
+
1063
def setup_bond_length_args(parser: argparse.ArgumentParser) -> None:
    """Register the ``bond-length-distribution`` flags on *parser*.

    All three flags (selection, bin count, maximum distance) are mandatory.
    """
    add = parser.add_argument
    add("--selection", help="Selection string", required=True)
    add("--bins", help="Number of bins", type=int, required=True)
    add("--r-max", help="Max distance (A)", type=float, required=True)
1067
+
1068
+
1069
def setup_bond_angle_args(parser: argparse.ArgumentParser) -> None:
    """Register the ``bond-angle-distribution`` flags on *parser*.

    ``--selection`` and ``--bins`` are mandatory; ``--degrees`` is a paired
    ``--degrees`` / ``--no-degrees`` switch that defaults to ``True``.
    """
    add = parser.add_argument
    add("--selection", help="Selection string", required=True)
    add("--bins", help="Number of bins", type=int, required=True)

    degrees_opts = {
        "action": argparse.BooleanOptionalAction,
        "default": True,
        "help": "Return degrees (default true)",
    }
    add("--degrees", **degrees_opts)
1078
+
1079
+
1080
def setup_persistence_args(parser: argparse.ArgumentParser) -> None:
    """Register the single mandatory ``--selection`` flag for ``persistence-length``."""
    parser.add_argument("--selection", help="Selection string", required=True)
1082
+
1083
+
1084
# Sub-command name -> function that registers that analysis's flags on its
# sub-parser. build_parser() iterates this mapping, so insertion order is the
# order in which the sub-commands are added to the CLI.
REGISTRY = {
    # Flags take a plain selection (plus simple switches).
    "rg": setup_rg_args,
    "rmsd": setup_rmsd_args,
    # Flags include the shared dynamics / group-types options.
    "msd": setup_msd_args,
    "rotacf": setup_rotacf_args,
    "conductivity": setup_conductivity_args,
    "dielectric": setup_dielectric_args,
    "dipole-alignment": setup_dipole_alignment_args,
    "ion-pair-correlation": setup_ion_pair_args,
    # Binned / region-based analyses.
    "structure-factor": setup_structure_factor_args,
    "water-count": setup_water_count_args,
    "equipartition": setup_equipartition_args,
    "hbond": setup_hbond_args,
    "rdf": setup_rdf_args,
    # Selection-driven chain analyses.
    "end-to-end": setup_end_to_end_args,
    "contour-length": setup_contour_length_args,
    "chain-rg": setup_chain_rg_args,
    "bond-length-distribution": setup_bond_length_args,
    "bond-angle-distribution": setup_bond_angle_args,
    "persistence-length": setup_persistence_args,
}
1105
+
1106
+
1107
def _spec_rg(args: argparse.Namespace, system: System) -> Dict[str, Any]:
    """Build the ``rg`` plan spec from parsed CLI arguments.

    *system* is accepted for signature parity with the other spec builders
    and is not used here.
    """
    wanted = ("selection", "mass_weighted")
    return {name: getattr(args, name) for name in wanted}
1112
+
1113
+
1114
def _spec_rmsd(args: argparse.Namespace, system: System) -> Dict[str, Any]:
    """Build the ``rmsd`` plan spec from parsed CLI arguments.

    *system* is accepted for signature parity with the other spec builders
    and is not used here.
    """
    wanted = ("selection", "reference", "align")
    return {name: getattr(args, name) for name in wanted}
1120
+
1121
+
1122
def _spec_msd(args: argparse.Namespace, system: System) -> Dict[str, Any]:
    """Build the ``msd`` plan spec from parsed CLI arguments.

    Optional flags are only copied into the spec when the user supplied
    them. Comma-separated flags are converted to fixed-width tuples by the
    sibling parsers, and ``group_types`` is resolved against the selection.
    """
    spec: Dict[str, Any] = {
        "selection": args.selection,
        "group_by": args.group_by,
    }

    if args.axis:
        spec["axis"] = _parse_float_tuple(args.axis, 3, "axis")
    if args.length_scale is not None:
        spec["length_scale"] = args.length_scale

    # (spec key, parser, expected number of components)
    tuple_flags = (
        ("frame_decimation", _parse_int_tuple, 2),
        ("dt_decimation", _parse_int_tuple, 4),
        ("time_binning", _parse_float_tuple, 2),
    )
    for key, parse, width in tuple_flags:
        raw = getattr(args, key)
        if raw:
            spec[key] = parse(raw, width, key)

    if args.lag_mode:
        spec["lag_mode"] = args.lag_mode

    # Integer knobs: 0 is a legitimate value, so test against None.
    for key in ("max_lag", "memory_budget_bytes", "multi_tau_m", "multi_tau_levels"):
        value = getattr(args, key)
        if value is not None:
            spec[key] = value

    atoms = _select(system, args.selection, "msd.selection")
    resolved_types = _parse_group_types_arg(args.group_types, system, atoms, args.group_by)
    if resolved_types is not None:
        spec["group_types"] = resolved_types
    return spec
1152
+
1153
+
1154
def _spec_rotacf(args: argparse.Namespace, system: System) -> Dict[str, Any]:
    """Build the ``rotacf`` plan spec from parsed CLI arguments.

    Raises:
        ValueError: if ``--orientation`` does not contain 2 or 3 indices.
    """
    spec: Dict[str, Any] = {
        "selection": args.selection,
        "group_by": args.group_by,
        "p2_legendre": args.p2_legendre,
    }

    indices = _parse_int_list(args.orientation, "orientation")
    if not 2 <= len(indices) <= 3:
        raise ValueError("orientation must have 2 or 3 indices")
    spec["orientation"] = indices

    if args.length_scale is not None:
        spec["length_scale"] = args.length_scale

    # (spec key, parser, expected number of components)
    tuple_flags = (
        ("frame_decimation", _parse_int_tuple, 2),
        ("dt_decimation", _parse_int_tuple, 4),
        ("time_binning", _parse_float_tuple, 2),
    )
    for key, parse, width in tuple_flags:
        raw = getattr(args, key)
        if raw:
            spec[key] = parse(raw, width, key)

    if args.lag_mode:
        spec["lag_mode"] = args.lag_mode

    # Integer knobs: 0 is a legitimate value, so test against None.
    for key in ("max_lag", "memory_budget_bytes", "multi_tau_m", "multi_tau_levels"):
        value = getattr(args, key)
        if value is not None:
            spec[key] = value

    atoms = _select(system, args.selection, "rotacf.selection")
    resolved_types = _parse_group_types_arg(args.group_types, system, atoms, args.group_by)
    if resolved_types is not None:
        spec["group_types"] = resolved_types
    return spec
1187
+
1188
+
1189
def _spec_conductivity(args: argparse.Namespace, system: System) -> Dict[str, Any]:
    """Build the ``conductivity`` plan spec from parsed CLI arguments.

    Charges are resolved first via the sibling charge parser; optional
    flags are only copied into the spec when the user supplied them.
    """
    charges = _parse_charges_arg(args.charges, system)
    spec: Dict[str, Any] = {
        "selection": args.selection,
        "group_by": args.group_by,
        "temperature": args.temperature,
        "transference": args.transference,
        "charges": charges,
    }

    if args.length_scale is not None:
        spec["length_scale"] = args.length_scale

    # (spec key, parser, expected number of components)
    tuple_flags = (
        ("frame_decimation", _parse_int_tuple, 2),
        ("dt_decimation", _parse_int_tuple, 4),
        ("time_binning", _parse_float_tuple, 2),
    )
    for key, parse, width in tuple_flags:
        raw = getattr(args, key)
        if raw:
            spec[key] = parse(raw, width, key)

    if args.lag_mode:
        spec["lag_mode"] = args.lag_mode

    # Integer knobs: 0 is a legitimate value, so test against None.
    for key in ("max_lag", "memory_budget_bytes", "multi_tau_m", "multi_tau_levels"):
        value = getattr(args, key)
        if value is not None:
            spec[key] = value

    atoms = _select(system, args.selection, "conductivity.selection")
    resolved_types = _parse_group_types_arg(args.group_types, system, atoms, args.group_by)
    if resolved_types is not None:
        spec["group_types"] = resolved_types
    return spec
1220
+
1221
+
1222
def _spec_dielectric(args: argparse.Namespace, system: System) -> Dict[str, Any]:
    """Build the ``dielectric`` plan spec from parsed CLI arguments.

    ``length_scale`` and ``group_types`` are only included when present.
    """
    charges = _parse_charges_arg(args.charges, system)
    spec: Dict[str, Any] = {
        "selection": args.selection,
        "group_by": args.group_by,
        "charges": charges,
    }

    scale = args.length_scale
    if scale is not None:
        spec["length_scale"] = scale

    atoms = _select(system, args.selection, "dielectric.selection")
    resolved_types = _parse_group_types_arg(args.group_types, system, atoms, args.group_by)
    if resolved_types is not None:
        spec["group_types"] = resolved_types
    return spec
1235
+
1236
+
1237
def _spec_dipole_alignment(args: argparse.Namespace, system: System) -> Dict[str, Any]:
    """Build the ``dipole-alignment`` plan spec from parsed CLI arguments.

    ``length_scale`` and ``group_types`` are only included when present.
    """
    charges = _parse_charges_arg(args.charges, system)
    spec: Dict[str, Any] = {
        "selection": args.selection,
        "group_by": args.group_by,
        "charges": charges,
    }

    scale = args.length_scale
    if scale is not None:
        spec["length_scale"] = scale

    atoms = _select(system, args.selection, "dipole_alignment.selection")
    resolved_types = _parse_group_types_arg(args.group_types, system, atoms, args.group_by)
    if resolved_types is not None:
        spec["group_types"] = resolved_types
    return spec
1250
+
1251
+
1252
def _spec_ion_pair(args: argparse.Namespace, system: System) -> Dict[str, Any]:
    """Build the ``ion-pair-correlation`` plan spec from parsed CLI arguments.

    Optional flags are only copied into the spec when the user supplied
    them; ``group_types`` is resolved against the selection.
    """
    # NOTE(review): the parser for this sub-command also registers
    # --frame-decimation, --dt-decimation and --time-binning (through
    # add_dynamics_args), but none of them is forwarded here. Confirm
    # whether dropping them is intentional.
    always = (
        "selection",
        "group_by",
        "rclust_cat",
        "rclust_ani",
        "cation_type",
        "anion_type",
        "max_cluster",
    )
    spec: Dict[str, Any] = {name: getattr(args, name) for name in always}

    if args.length_scale is not None:
        spec["length_scale"] = args.length_scale
    if args.lag_mode:
        spec["lag_mode"] = args.lag_mode

    # Integer knobs: 0 is a legitimate value, so test against None.
    for key in ("max_lag", "memory_budget_bytes", "multi_tau_m", "multi_tau_levels"):
        value = getattr(args, key)
        if value is not None:
            spec[key] = value

    atoms = _select(system, args.selection, "ion_pair_correlation.selection")
    resolved_types = _parse_group_types_arg(args.group_types, system, atoms, args.group_by)
    if resolved_types is not None:
        spec["group_types"] = resolved_types
    return spec
1279
+
1280
+
1281
def _spec_structure_factor(args: argparse.Namespace, system: System) -> Dict[str, Any]:
    """Build the ``structure-factor`` plan spec from parsed CLI arguments.

    ``length_scale`` is only included when the user supplied it. *system*
    is not used here.
    """
    always = ("selection", "bins", "r_max", "q_bins", "q_max", "pbc")
    spec: Dict[str, Any] = {name: getattr(args, name) for name in always}

    scale = args.length_scale
    if scale is not None:
        spec["length_scale"] = scale
    return spec
1293
+
1294
+
1295
def _spec_water_count(args: argparse.Namespace, system: System) -> Dict[str, Any]:
    """Build the ``water-count`` plan spec from parsed CLI arguments.

    The x,y,z flags are converted to 3-component tuples by the sibling
    float-tuple parser. ``shift`` and ``length_scale`` are only included
    when supplied. *system* is not used here.
    """

    def as_triple(raw: str, label: str):
        # Every vector flag of this analysis has exactly three components.
        return _parse_float_tuple(raw, 3, label)

    spec: Dict[str, Any] = {
        "water_selection": args.water_selection,
        "center_selection": args.center_selection,
        "box_unit": as_triple(args.box_unit, "box_unit"),
        "region_size": as_triple(args.region_size, "region_size"),
    }

    if args.shift:
        spec["shift"] = as_triple(args.shift, "shift")

    scale = args.length_scale
    if scale is not None:
        spec["length_scale"] = scale
    return spec
1307
+
1308
+
1309
def _spec_equipartition(args: argparse.Namespace, system: System) -> Dict[str, Any]:
    """Build the ``equipartition`` plan spec from parsed CLI arguments.

    The two scale factors and ``group_types`` are only included when
    present.
    """
    spec: Dict[str, Any] = {
        "selection": args.selection,
        "group_by": args.group_by,
    }

    # Optional float scales: 0.0 is a real value, so test against None.
    for key in ("velocity_scale", "length_scale"):
        value = getattr(args, key)
        if value is not None:
            spec[key] = value

    atoms = _select(system, args.selection, "equipartition.selection")
    resolved_types = _parse_group_types_arg(args.group_types, system, atoms, args.group_by)
    if resolved_types is not None:
        spec["group_types"] = resolved_types
    return spec
1323
+
1324
+
1325
def _spec_hbond(args: argparse.Namespace, system: System) -> Dict[str, Any]:
    """Build the ``hbond`` plan spec from parsed CLI arguments.

    The angle cutoff is only copied into the spec when a hydrogen selection
    is given; an angle cutoff passed without hydrogens is left out.
    *system* is not used here.

    Raises:
        ValueError: if hydrogens are given without an angle cutoff.
    """
    spec: Dict[str, Any] = {
        "donors": args.donors,
        "acceptors": args.acceptors,
        "dist_cutoff": args.dist_cutoff,
    }

    # Without a hydrogen selection only the distance criterion is emitted.
    if not args.hydrogens:
        return spec

    if args.angle_cutoff is None:
        raise ValueError("angle_cutoff is required when hydrogens are provided")

    spec["hydrogens"] = args.hydrogens
    spec["angle_cutoff"] = args.angle_cutoff
    return spec
1337
+
1338
+
1339
def _spec_rdf(args: argparse.Namespace, system: System) -> Dict[str, Any]:
    """Build the ``rdf`` plan spec from parsed CLI arguments.

    Every field is copied through unchanged; *system* is not used here.
    """
    wanted = ("sel_a", "sel_b", "bins", "r_max", "pbc")
    return {name: getattr(args, name) for name in wanted}
1347
+
1348
+
1349
def _spec_end_to_end(args: argparse.Namespace, system: System) -> Dict[str, Any]:
    """Build the ``end-to-end`` plan spec; only the selection is forwarded."""
    return dict(selection=args.selection)
1351
+
1352
+
1353
def _spec_contour_length(args: argparse.Namespace, system: System) -> Dict[str, Any]:
    """Build the ``contour-length`` plan spec; only the selection is forwarded."""
    return dict(selection=args.selection)
1355
+
1356
+
1357
def _spec_chain_rg(args: argparse.Namespace, system: System) -> Dict[str, Any]:
    """Build the ``chain-rg`` plan spec; only the selection is forwarded."""
    return dict(selection=args.selection)
1359
+
1360
+
1361
def _spec_bond_length(args: argparse.Namespace, system: System) -> Dict[str, Any]:
    """Build the ``bond-length-distribution`` plan spec.

    Every field is copied through unchanged; *system* is not used here.
    """
    wanted = ("selection", "bins", "r_max")
    return {name: getattr(args, name) for name in wanted}
1367
+
1368
+
1369
def _spec_bond_angle(args: argparse.Namespace, system: System) -> Dict[str, Any]:
    """Build the ``bond-angle-distribution`` plan spec.

    Every field is copied through unchanged; *system* is not used here.
    """
    wanted = ("selection", "bins", "degrees")
    return {name: getattr(args, name) for name in wanted}
1375
+
1376
+
1377
def _spec_persistence(args: argparse.Namespace, system: System) -> Dict[str, Any]:
    """Build the ``persistence-length`` plan spec; only the selection is forwarded."""
    return dict(selection=args.selection)
1379
+
1380
+
1381
# Sub-command name -> function turning the parsed namespace (and the loaded
# system) into the spec dict passed to the plan builder. The keys mirror
# REGISTRY one-to-one; build_plan_from_args() indexes this with args.analysis.
SPEC_BUILDERS = {
    # Builders that copy fields straight through.
    "rg": _spec_rg,
    "rmsd": _spec_rmsd,
    # Builders that parse tuple flags and/or resolve group types.
    "msd": _spec_msd,
    "rotacf": _spec_rotacf,
    "conductivity": _spec_conductivity,
    "dielectric": _spec_dielectric,
    "dipole-alignment": _spec_dipole_alignment,
    "ion-pair-correlation": _spec_ion_pair,
    # Binned / region-based analyses.
    "structure-factor": _spec_structure_factor,
    "water-count": _spec_water_count,
    "equipartition": _spec_equipartition,
    "hbond": _spec_hbond,
    "rdf": _spec_rdf,
    # Selection-driven chain analyses.
    "end-to-end": _spec_end_to_end,
    "contour-length": _spec_contour_length,
    "chain-rg": _spec_chain_rg,
    "bond-length-distribution": _spec_bond_length,
    "bond-angle-distribution": _spec_bond_angle,
    "persistence-length": _spec_persistence,
}
1402
+
1403
+
1404
def build_plan_from_args(args: argparse.Namespace, system: System):
    """Construct the plan object for the analysis named by ``args.analysis``.

    The CLI name is mapped to its plan name, the matching spec builder turns
    the namespace into a spec dict, and the plan builder registered under
    the plan name is called with ``(system, spec)``.
    """
    analysis = args.analysis
    plan_name = CLI_TO_PLAN[analysis]
    # Build the spec before looking up the plan builder, so argument errors
    # surface ahead of any missing-builder KeyError.
    spec = SPEC_BUILDERS[analysis](args, system)
    make_plan = PLAN_BUILDERS[plan_name]
    return make_plan(system, spec)
1408
+
1409
+
1410
def run_single_analysis(args: argparse.Namespace) -> None:
    """Run one analysis sub-command end to end.

    Loads the system and trajectory, builds and runs the plan, saves the
    output and, unless summaries are disabled, prints a summary. When
    ``--out`` is not given the output path is ``<analysis>.npz``.
    """
    system = _load_system_from_args(args)
    trajectory = _load_traj_from_args(args, system)
    plan = build_plan_from_args(args, system)
    result = plan.run(
        trajectory,
        system,
        chunk_frames=args.chunk_frames,
        device=args.device,
    )

    destination = Path(args.out) if args.out else Path(f"{args.analysis}.npz")
    _save_output(str(destination), result)

    if not args.print_summary:
        return
    _print_summary(
        _summary_from_output(result, args.analysis, destination),
        args.summary_format,
    )
1420
+
1421
+
1422
def build_parser() -> argparse.ArgumentParser:
    """Assemble the ``warp-md`` command-line parser.

    Sub-commands are ``run`` (config-driven), ``list-plans``, ``example``,
    and one per REGISTRY entry. Each analysis sub-command receives the
    shared flags followed by its own.
    """
    root = argparse.ArgumentParser(prog="warp-md")
    commands = root.add_subparsers(dest="cmd", required=True)

    config_cmd = commands.add_parser("run", help="run analyses from a JSON/YAML config")
    config_cmd.add_argument("config", help="path to config.json|yaml")
    config_cmd.add_argument("--dry-run", help="validate and show outputs", action="store_true")

    # Argument-less utility commands.
    for utility, blurb in (
        ("list-plans", "list available analysis names"),
        ("example", "print example config"),
    ):
        commands.add_parser(utility, help=blurb)

    for analysis_name, register_flags in REGISTRY.items():
        blurb = f"Run {analysis_name} analysis"
        analysis_cmd = commands.add_parser(analysis_name, help=blurb, description=blurb)
        add_shared_args(analysis_cmd)
        register_flags(analysis_cmd)

    return root
1440
+
1441
+
1442
def main(argv: Optional[list[str]] = None) -> int:
    """CLI entry point: parse *argv* and dispatch to the chosen sub-command.

    Args:
        argv: Argument list handed to the parser as-is.

    Returns:
        0 after a handled sub-command, 1 if the command is not recognised.
    """
    parsed = build_parser().parse_args(argv)
    command = parsed.cmd

    if command == "run":
        run_config(parsed.config, dry_run=parsed.dry_run)
    elif command == "list-plans":
        list_plans()
    elif command == "example":
        example_config()
    elif command in REGISTRY:
        # The analysis runner reads the sub-command name from .analysis.
        parsed.analysis = command
        run_single_analysis(parsed)
    else:
        return 1
    return 0
1461
+
1462
+
1463
# Script entry: hand main()'s return code to the interpreter as exit status.
if __name__ == "__main__":
    exit_status = main()
    raise SystemExit(exit_status)