adaptivepy-sampling 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- adaptivepy/__init__.py +7 -0
- adaptivepy/api.py +229 -0
- adaptivepy/cli/__init__.py +5 -0
- adaptivepy/cli/run.py +68 -0
- adaptivepy/clustering/__init__.py +103 -0
- adaptivepy/clustering/base.py +73 -0
- adaptivepy/clustering/regular_space.py +135 -0
- adaptivepy/clustering/sklearn_kmeans.py +93 -0
- adaptivepy/clustering/sklearn_minibatch.py +94 -0
- adaptivepy/config/__init__.py +17 -0
- adaptivepy/config/schema.py +196 -0
- adaptivepy/io/__init__.py +27 -0
- adaptivepy/io/loader.py +267 -0
- adaptivepy/io/trajectory.py +151 -0
- adaptivepy/models.py +83 -0
- adaptivepy/output/__init__.py +23 -0
- adaptivepy/output/pdb_writer.py +59 -0
- adaptivepy/output/writer.py +229 -0
- adaptivepy/policies/__init__.py +21 -0
- adaptivepy/policies/base.py +105 -0
- adaptivepy/policies/least_counts.py +43 -0
- adaptivepy/policies/random.py +53 -0
- adaptivepy/selection/__init__.py +5 -0
- adaptivepy/selection/frame_selector.py +132 -0
- adaptivepy/stats/__init__.py +15 -0
- adaptivepy/stats/cluster_stats.py +118 -0
- adaptivepy/utils/__init__.py +6 -0
- adaptivepy/utils/io_utils.py +49 -0
- adaptivepy/utils/logging.py +55 -0
- adaptivepy_sampling-0.1.0.dist-info/METADATA +52 -0
- adaptivepy_sampling-0.1.0.dist-info/RECORD +34 -0
- adaptivepy_sampling-0.1.0.dist-info/WHEEL +5 -0
- adaptivepy_sampling-0.1.0.dist-info/entry_points.txt +2 -0
- adaptivepy_sampling-0.1.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,151 @@
|
|
|
1
|
+
"""Coordinate trajectory loading and frame extraction via mdtraj."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Dict, List, Optional
|
|
8
|
+
|
|
9
|
+
import mdtraj as md
|
|
10
|
+
|
|
11
|
+
from adaptivepy.io.loader import list_trajectory_files
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def build_trajectory_map(
    trajectories_dir: Path,
    traj_names: List[str],
) -> Dict[int, Path]:
    """Map trajectory IDs to coordinate file paths by matching stems.

    The ``traj_id`` assigned to each name is its position in ``traj_names``.
    A coordinate file matches a name when its stem (file name without the
    final extension) equals that name.

    Parameters
    ----------
    trajectories_dir : Path
        Directory containing coordinate trajectories.
    traj_names : list of str
        Basenames corresponding to feature files (e.g. ``traj_0``).

    Returns
    -------
    dict
        Mapping from ``traj_id`` to trajectory file path.

    Raises
    ------
    ValueError
        If a trajectory file cannot be found for any ``traj_name``, or if
        more than one coordinate file shares a requested stem (for example
        ``traj_0.xtc`` and ``traj_0.dcd``), which makes the match ambiguous.
    """
    traj_files = list_trajectory_files(trajectories_dir)

    # Group candidates by stem so that duplicate stems are detected instead
    # of being silently resolved to whichever file happened to be listed last.
    stem_to_paths: Dict[str, List[Path]] = {}
    for path in traj_files:
        stem_to_paths.setdefault(path.stem, []).append(path)

    mapping: Dict[int, Path] = {}
    for traj_id, name in enumerate(traj_names):
        candidates = stem_to_paths.get(name)
        if not candidates:
            raise ValueError(
                f"No trajectory file found matching feature stem '{name}' "
                f"in {trajectories_dir}"
            )
        if len(candidates) > 1:
            listed = ", ".join(sorted(candidate.name for candidate in candidates))
            raise ValueError(
                f"Ambiguous trajectory files for feature stem '{name}' "
                f"in {trajectories_dir}: {listed}"
            )
        mapping[traj_id] = candidates[0]
    return mapping
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def load_trajectory(topology: Path, trajectory_path: Path) -> md.Trajectory:
    """Read a complete coordinate trajectory from disk with mdtraj.

    Parameters
    ----------
    topology : Path
        Topology file (PDB, parm7, etc.) supplying the atom/residue layout.
    trajectory_path : Path
        Coordinate trajectory file to read.

    Returns
    -------
    mdtraj.Trajectory
        All frames stored in ``trajectory_path``.
    """
    logger.info("Loading trajectory %s with topology %s", trajectory_path, topology)
    coords_file = str(trajectory_path)
    topology_file = str(topology)
    return md.load(coords_file, top=topology_file)
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def extract_frame(
    topology: Path,
    trajectory_path: Path,
    frame_id: int,
) -> md.Trajectory:
    """Return one frame of a trajectory as a single-frame trajectory.

    The full trajectory is read through :func:`load_trajectory` and then
    indexed, so the cost grows with the trajectory length rather than with
    the single requested frame.

    Parameters
    ----------
    topology : Path
        Topology file path.
    trajectory_path : Path
        Coordinate trajectory file path.
    frame_id : int
        Zero-based index of the frame to keep.

    Returns
    -------
    mdtraj.Trajectory
        Single-frame trajectory suitable for PDB export.

    Raises
    ------
    IndexError
        If ``frame_id`` is negative or not smaller than the frame count.
    """
    full_traj = load_trajectory(topology, trajectory_path)
    n_frames = full_traj.n_frames
    # Negative indices are rejected on purpose rather than counted from the end.
    if not 0 <= frame_id < n_frames:
        raise IndexError(
            f"frame_id {frame_id} out of range for trajectory with "
            f"{n_frames} frames ({trajectory_path})"
        )
    return full_traj[frame_id]
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def get_trajectory_frame_count(topology: Path, trajectory_path: Path) -> int:
    """Return the number of frames in a trajectory without holding all coordinates.

    Two strategies are used:

    1. Fast path: ``mdtraj.open`` returns the format's file handle, and its
       length is used when the format supports ``len()``. No ``Trajectory``
       object is built.
    2. Fallback: used when the extension is unsupported by ``mdtraj.open``,
       the format needs extra open arguments, or the handle has no length.
       The file is streamed with ``mdtraj.iterload`` and frames are counted
       chunk by chunk, so only one chunk is held in memory at a time.

    Parameters
    ----------
    topology : Path
        Topology file path (only needed by the streaming fallback).
    trajectory_path : Path
        Coordinate trajectory file path.

    Returns
    -------
    int
        Number of frames in the trajectory.
    """
    path_str = str(trajectory_path)

    try:
        handle = md.open(path_str)
    except (OSError, ValueError, TypeError):
        # Unsupported extension or the format needs extra arguments to open.
        handle = None

    if handle is not None:
        try:
            return len(handle)
        except (TypeError, ValueError, NotImplementedError):
            pass  # Handle cannot report its length; fall through to streaming.
        finally:
            handle.close()

    n_frames = 0
    for chunk in md.iterload(path_str, top=str(topology), chunk=1000):
        n_frames += chunk.n_frames
    return n_frames
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
def validate_trajectory_frame_counts(
    topology: Path,
    trajectory_map: Dict[int, Path],
    expected_counts: Dict[int, int],
) -> None:
    """Verify trajectory frame counts match feature frame counts.

    Trajectories whose ``traj_id`` has no entry in ``expected_counts`` are
    skipped without being opened. Counting frames may require reading the
    whole coordinate file.

    Parameters
    ----------
    topology : Path
        Topology file path.
    trajectory_map : dict
        Mapping from ``traj_id`` to trajectory file.
    expected_counts : dict
        Expected frame count per ``traj_id`` from features.

    Raises
    ------
    ValueError
        If any trajectory has a different number of frames than its features.
    """
    for traj_id, traj_path in trajectory_map.items():
        n_feature_frames = expected_counts.get(traj_id)
        if n_feature_frames is None:
            # Nothing to compare against: avoid the (potentially expensive)
            # frame count for this trajectory.
            logger.debug(
                "No expected frame count for traj_id %d (%s); skipping",
                traj_id,
                traj_path,
            )
            continue
        n_traj_frames = get_trajectory_frame_count(topology, traj_path)
        if n_traj_frames != n_feature_frames:
            raise ValueError(
                f"Frame count mismatch for traj_id {traj_id} ({traj_path.name}): "
                f"trajectory has {n_traj_frames}, features have {n_feature_frames}"
            )
|
adaptivepy/models.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
1
|
+
"""Core data models for AdaptivePy."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from dataclasses import dataclass, field
|
|
6
|
+
from typing import Dict, List, Optional
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
@dataclass
class FrameRecord:
    """Bookkeeping for one trajectory frame as it moves through the pipeline.

    Attributes
    ----------
    traj_id : int
        Index of the trajectory the frame belongs to.
    frame_id : int
        Position of the frame inside that trajectory.
    features : np.ndarray
        The frame's feature vector, shape ``(n_features,)``.
    cluster_id : int or None
        Cluster label assigned by clustering; ``None`` by default.
    global_index : int or None
        Row of this frame in the concatenated feature matrix; ``None`` by
        default.
    """

    # Identity of the frame and its data (required).
    traj_id: int
    frame_id: int
    features: np.ndarray
    # Annotations filled in later; both start out as ``None``.
    cluster_id: Optional[int] = field(default=None)
    global_index: Optional[int] = field(default=None)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@dataclass
class Dataset:
    """Container for the trajectory features loaded for one run.

    Attributes
    ----------
    frames : list of FrameRecord
        A record for every frame of every trajectory; empty by default.
    feature_matrix : np.ndarray or None
        All feature vectors stacked row-wise, shape
        ``(n_total_frames, n_features)``; ``None`` by default.
    traj_index_map : dict
        ``traj_id`` -> ``(start_index, end_index)`` row range inside
        ``feature_matrix``.
    traj_names : list of str
        Feature-file basenames without extension, e.g. ``traj_0``.
    """

    # Mutable containers use factories so instances never share state.
    frames: List[FrameRecord] = field(default_factory=list)
    feature_matrix: Optional[np.ndarray] = field(default=None)
    traj_index_map: Dict[int, tuple[int, int]] = field(default_factory=dict)
    traj_names: List[str] = field(default_factory=list)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
@dataclass
class SeedResult:
    """One seed frame chosen by a sampling policy.

    Attributes
    ----------
    seed_id : int
        Running number of the seed within a single policy run.
    policy : str
        Name of the policy that produced the seed.
    traj_id : int
        Index of the trajectory the seed frame comes from.
    frame_id : int
        Index of the seed frame inside that trajectory.
    cluster_id : int
        Cluster the seed was drawn from.
    global_index : int
        Row of the seed frame in the concatenated feature matrix.
    """

    # Who selected it.
    seed_id: int
    policy: str
    # Where the frame lives.
    traj_id: int
    frame_id: int
    # How it relates to the clustering.
    cluster_id: int
    global_index: int
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
"""Output writers for AdaptivePy."""
|
|
2
|
+
|
|
3
|
+
from adaptivepy.output.pdb_writer import write_seed_pdbs
|
|
4
|
+
from adaptivepy.output.writer import (
|
|
5
|
+
write_assignments,
|
|
6
|
+
write_cluster_model,
|
|
7
|
+
write_cluster_statistics,
|
|
8
|
+
write_combined_metadata,
|
|
9
|
+
write_policy_outputs,
|
|
10
|
+
write_run_config,
|
|
11
|
+
write_seeds_csv,
|
|
12
|
+
)
|
|
13
|
+
|
|
14
|
+
# Public API of ``adaptivepy.output``: the writer functions imported from
# ``pdb_writer`` and ``writer`` at the top of this module, kept alphabetical.
__all__ = [
    "write_assignments",
    "write_cluster_model",
    "write_cluster_statistics",
    "write_combined_metadata",
    "write_policy_outputs",
    "write_run_config",
    "write_seed_pdbs",
    "write_seeds_csv",
]
|
|
@@ -0,0 +1,59 @@
|
|
|
1
|
+
"""PDB export for selected seed frames using mdtraj."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Dict, List
|
|
8
|
+
|
|
9
|
+
from adaptivepy.io.trajectory import extract_frame
|
|
10
|
+
from adaptivepy.models import SeedResult
|
|
11
|
+
from adaptivepy.utils.io_utils import ensure_dir
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
def write_seed_pdbs(
    seeds: List[SeedResult],
    topology: Path,
    trajectory_map: Dict[int, Path],
    output_dir: Path,
) -> List[Path]:
    """Extract coordinate frames and save each seed as a PDB file.

    Seeds are grouped by source trajectory, so each coordinate file is
    loaded once regardless of how many seeds were drawn from it. The loop
    keeps a reference to only one loaded trajectory at a time.

    Files are written to ``output_dir / "pdbs"`` and named
    ``seed_<seed_id>_traj<traj_id>_frame<frame_id>.pdb``.

    Parameters
    ----------
    seeds : list of SeedResult
        Selected seed frames.
    topology : Path
        Topology file for mdtraj loading.
    trajectory_map : dict
        Mapping from ``traj_id`` to trajectory file path.
    output_dir : Path
        Directory under which the ``pdbs/`` subdirectory is created.

    Returns
    -------
    list of Path
        Paths to written PDB files, in the same order as ``seeds``.

    Raises
    ------
    KeyError
        If a seed's ``traj_id`` is missing from ``trajectory_map``.
    IndexError
        If a seed's ``frame_id`` is outside its trajectory.
    """
    # Local import: loading each trajectory once requires the full loader
    # rather than the per-frame ``extract_frame`` helper.
    from adaptivepy.io.trajectory import load_trajectory

    seed_list = list(seeds)
    pdb_dir = ensure_dir(output_dir / "pdbs")

    # Remember each seed's position so the result keeps the input order.
    positions_by_traj: Dict[int, List[int]] = {}
    for position, seed in enumerate(seed_list):
        positions_by_traj.setdefault(seed.traj_id, []).append(position)

    written_by_position: Dict[int, Path] = {}
    for traj_id, positions in positions_by_traj.items():
        traj_path = trajectory_map[traj_id]
        traj = load_trajectory(topology, traj_path)
        for position in positions:
            seed = seed_list[position]
            if seed.frame_id < 0 or seed.frame_id >= traj.n_frames:
                raise IndexError(
                    f"frame_id {seed.frame_id} out of range for trajectory with "
                    f"{traj.n_frames} frames ({traj_path})"
                )
            pdb_path = pdb_dir / (
                f"seed_{seed.seed_id}_traj{seed.traj_id}_frame{seed.frame_id}.pdb"
            )
            traj[seed.frame_id].save_pdb(str(pdb_path))
            written_by_position[position] = pdb_path
            logger.info(
                "Wrote PDB for seed %d: traj=%d frame=%d -> %s",
                seed.seed_id,
                seed.traj_id,
                seed.frame_id,
                pdb_path.name,
            )
        # Release this trajectory before the next one is loaded.
        del traj

    return [written_by_position[position] for position in range(len(seed_list))]
|
|
@@ -0,0 +1,229 @@
|
|
|
1
|
+
"""Write run outputs: CSV metadata, numpy arrays, and serialized models."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import csv
|
|
6
|
+
import logging
|
|
7
|
+
from pathlib import Path
|
|
8
|
+
from typing import Any, Dict, Iterable, List, Optional
|
|
9
|
+
|
|
10
|
+
import joblib
|
|
11
|
+
import numpy as np
|
|
12
|
+
import yaml
|
|
13
|
+
|
|
14
|
+
from adaptivepy.config.schema import RunConfig, config_to_dict
|
|
15
|
+
from adaptivepy.models import SeedResult
|
|
16
|
+
from adaptivepy.stats.cluster_stats import ClusterStats, cluster_stats_to_rows
|
|
17
|
+
from adaptivepy.utils.io_utils import copy_file, ensure_dir
|
|
18
|
+
|
|
19
|
+
logger = logging.getLogger(__name__)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
def write_run_config(config: RunConfig, output_dir: Path, source_path: Path) -> Path:
    """Persist the run configuration as ``run_config.yaml`` in ``output_dir``.

    If ``source_path`` is an existing file, it is copied verbatim. Otherwise
    ``config`` is converted with ``config_to_dict`` and written with
    ``yaml.safe_dump`` without sorting keys.

    Parameters
    ----------
    config : RunConfig
        Parsed configuration object (used only when no source file exists).
    output_dir : Path
        Run output directory (created if needed).
    source_path : Path
        Original YAML file path.

    Returns
    -------
    Path
        Location of the written ``run_config.yaml``.
    """
    target = ensure_dir(output_dir) / "run_config.yaml"
    if not source_path.is_file():
        with target.open("w", encoding="utf-8") as stream:
            yaml.safe_dump(config_to_dict(config), stream, sort_keys=False)
        return target
    copy_file(source_path, target)
    return target
|
|
46
|
+
|
|
47
|
+
|
|
48
|
+
def write_assignments(assignments: np.ndarray, output_dir: Path) -> Path:
    """Store the per-frame cluster labels as ``assignments.npy``.

    Parameters
    ----------
    assignments : np.ndarray
        One cluster label per frame.
    output_dir : Path
        Run output directory (created if needed).

    Returns
    -------
    Path
        Location of the written ``assignments.npy``.
    """
    destination = ensure_dir(output_dir).joinpath("assignments.npy")
    np.save(destination, assignments)
    return destination
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
def write_cluster_model(model: Any, output_dir: Path) -> Path:
    """Serialize a fitted clustering model to ``cluster_model.pkl``.

    The file is produced by ``joblib.dump``, which is pickle-based; only load
    it back (``joblib.load``) from sources you trust.

    Parameters
    ----------
    model : object
        Fitted clustering model.
    output_dir : Path
        Run output directory (created if needed).

    Returns
    -------
    Path
        Location of the written ``cluster_model.pkl``.
    """
    destination = ensure_dir(output_dir).joinpath("cluster_model.pkl")
    joblib.dump(model, destination)
    return destination
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def write_cluster_statistics(
    cluster_stats: ClusterStats,
    output_dir: Path,
) -> Path:
    """Dump cluster populations to ``metadata.csv``.

    Rows come from ``cluster_stats_to_rows`` and are written with the two
    columns ``cluster_id`` and ``population``. ``csv.DictWriter`` raises
    ``ValueError`` if a row carries any other key.

    Parameters
    ----------
    cluster_stats : ClusterStats
        Per-cluster statistics.
    output_dir : Path
        Run output directory (created if needed).

    Returns
    -------
    Path
        Location of the written ``metadata.csv``.
    """
    csv_path = ensure_dir(output_dir) / "metadata.csv"
    table = cluster_stats_to_rows(cluster_stats)
    columns = ["cluster_id", "population"]
    with csv_path.open("w", newline="", encoding="utf-8") as out:
        dict_writer = csv.DictWriter(out, fieldnames=columns)
        dict_writer.writeheader()
        dict_writer.writerows(table)
    return csv_path
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
# Column order shared by every seed table written by this module.
_SEED_FIELDNAMES = (
    "seed_id",
    "policy",
    "traj_id",
    "frame_id",
    "cluster_id",
    "global_index",
)


def _seed_to_row(seed: SeedResult) -> Dict[str, Any]:
    """Return the CSV row for one seed, keyed by ``_SEED_FIELDNAMES``."""
    return {
        "seed_id": seed.seed_id,
        "policy": seed.policy,
        "traj_id": seed.traj_id,
        "frame_id": seed.frame_id,
        "cluster_id": seed.cluster_id,
        "global_index": seed.global_index,
    }


def _write_seed_rows(path: Path, seeds: Iterable[SeedResult]) -> None:
    """Write a header plus one row per seed to ``path`` (overwriting it)."""
    with path.open("w", newline="", encoding="utf-8") as handle:
        writer = csv.DictWriter(handle, fieldnames=_SEED_FIELDNAMES)
        writer.writeheader()
        writer.writerows(_seed_to_row(seed) for seed in seeds)


def write_seeds_csv(seeds: Iterable[SeedResult], output_dir: Path) -> Path:
    """Write selected seeds to ``seeds.csv``.

    Parameters
    ----------
    seeds : iterable of SeedResult
        Seed records for one policy.
    output_dir : Path
        Policy-specific output directory (created if needed).

    Returns
    -------
    Path
        Path to ``seeds.csv``.
    """
    path = ensure_dir(output_dir) / "seeds.csv"
    _write_seed_rows(path, seeds)
    return path


def write_combined_metadata(
    policy_seeds: Dict[str, List[SeedResult]],
    output_dir: Path,
) -> Path:
    """Write a combined seed table across all policies.

    Rows are emitted policy by policy in the mapping's iteration order, using
    the same columns as ``seeds.csv``. The policy name is taken from each
    seed's own ``policy`` field, not from the mapping key.

    Parameters
    ----------
    policy_seeds : dict
        Mapping from policy name to seed lists.
    output_dir : Path
        Top-level results directory (created if needed).

    Returns
    -------
    Path
        Path to ``combined_metadata.csv``.
    """
    path = ensure_dir(output_dir) / "combined_metadata.csv"
    _write_seed_rows(
        path,
        (seed for seeds in policy_seeds.values() for seed in seeds),
    )
    logger.info("Wrote combined metadata to %s", path)
    return path
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
def write_policy_outputs(
    policy_name: str,
    seeds: List[SeedResult],
    cluster_stats: ClusterStats,
    results_dir: Path,
) -> Path:
    """Populate ``results_dir / policy_name`` with one policy's outputs.

    Writes ``seeds.csv`` first and then ``metadata.csv`` into that
    subdirectory, creating it if necessary.

    Parameters
    ----------
    policy_name : str
        Policy identifier, used as the subdirectory name.
    seeds : list of SeedResult
        Seeds chosen by the policy.
    cluster_stats : ClusterStats
        Global cluster statistics (identical for every policy).
    results_dir : Path
        Top-level results directory.

    Returns
    -------
    Path
        The policy's output directory.
    """
    policy_dir = ensure_dir(results_dir.joinpath(policy_name))
    outputs = (
        (write_seeds_csv, seeds),
        (write_cluster_statistics, cluster_stats),
    )
    for write, payload in outputs:
        write(payload, policy_dir)
    return policy_dir
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
"""Adaptive sampling policies for AdaptivePy."""
|
|
2
|
+
|
|
3
|
+
from adaptivepy.policies.base import (
|
|
4
|
+
POLICY_REGISTRY,
|
|
5
|
+
Policy,
|
|
6
|
+
get_policy,
|
|
7
|
+
list_policies,
|
|
8
|
+
register_policy,
|
|
9
|
+
)
|
|
10
|
+
|
|
11
|
+
# Import concrete policies so they self-register.
|
|
12
|
+
from adaptivepy.policies import least_counts # noqa: F401
|
|
13
|
+
from adaptivepy.policies import random # noqa: F401
|
|
14
|
+
|
|
15
|
+
# Public API of ``adaptivepy.policies``. The concrete policy modules
# (``least_counts``, ``random``) are imported above only so they
# self-register, and are intentionally not re-exported here.
__all__ = [
    "POLICY_REGISTRY",
    "Policy",
    "get_policy",
    "list_policies",
    "register_policy",
]
|
|
@@ -0,0 +1,105 @@
|
|
|
1
|
+
"""Adaptive sampling policy base class and registry."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from abc import ABC, abstractmethod
|
|
6
|
+
from typing import Dict, List, Type
|
|
7
|
+
|
|
8
|
+
from adaptivepy.stats.cluster_stats import ClusterStats
|
|
9
|
+
|
|
10
|
+
# Global name -> policy-class table filled in by ``register_policy``.
POLICY_REGISTRY: Dict[str, Type["Policy"]] = {}


def register_policy(cls: Type["Policy"]) -> Type["Policy"]:
    """Class decorator that records ``cls`` in :data:`POLICY_REGISTRY`.

    The class's ``name`` attribute becomes its registry key; the class itself
    is returned untouched so the decorator can be stacked.

    Parameters
    ----------
    cls : type
        Policy subclass with a ``name`` class attribute.

    Returns
    -------
    type
        ``cls``, unchanged.

    Raises
    ------
    ValueError
        If ``name`` is missing or empty, or already present in the registry.
    """
    policy_name = getattr(cls, "name", None)
    if not policy_name:
        raise ValueError(f"Policy {cls.__name__} must define a 'name' attribute.")
    if policy_name in POLICY_REGISTRY:
        raise ValueError(f"Policy '{policy_name}' is already registered.")
    POLICY_REGISTRY[policy_name] = cls
    return cls
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class Policy(ABC):
    """Abstract interface for strategies that choose clusters to seed from.

    A concrete policy sets a non-empty ``name`` (the key used by
    ``register_policy``) and implements :meth:`select_clusters`.
    """

    # Registry key; empty on the base class and overridden by subclasses.
    name: str = ""

    @abstractmethod
    def select_clusters(self, cluster_stats: ClusterStats, n_seeds: int) -> List[int]:
        """Choose the clusters that will each contribute seed frames.

        Parameters
        ----------
        cluster_stats : ClusterStats
            Per-cluster population and frame lists.
        n_seeds : int
            Upper bound on the number of clusters (seeds) to return.

        Returns
        -------
        list of int
            IDs of the chosen clusters.
        """
        ...
|
|
69
|
+
|
|
70
|
+
|
|
71
|
+
def get_policy(name: str, **kwargs) -> Policy:
    """Build an instance of the policy registered under ``name``.

    Parameters
    ----------
    name : str
        Key previously registered via ``register_policy``.
    **kwargs
        Passed straight through to the policy's constructor.

    Returns
    -------
    Policy
        Freshly constructed policy instance.

    Raises
    ------
    ValueError
        If no policy is registered under ``name``; the message lists the
        registered names in sorted order.
    """
    registry = POLICY_REGISTRY
    if name in registry:
        return registry[name](**kwargs)
    available = ", ".join(sorted(registry))
    raise ValueError(f"Unknown policy '{name}'. Available: {available}")
|
|
95
|
+
|
|
96
|
+
|
|
97
|
+
def list_policies() -> List[str]:
    """List every registered policy name.

    Returns
    -------
    list of str
        Registry keys in ascending sorted order.
    """
    return sorted(POLICY_REGISTRY)
|