adaptivepy-sampling 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,151 @@
1
+ """Coordinate trajectory loading and frame extraction via mdtraj."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ from pathlib import Path
7
+ from typing import Dict, List, Optional
8
+
9
+ import mdtraj as md
10
+
11
+ from adaptivepy.io.loader import list_trajectory_files
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
def build_trajectory_map(
    trajectories_dir: Path,
    traj_names: List[str],
) -> Dict[int, Path]:
    """Resolve each feature basename to its coordinate trajectory file.

    The position of a name in ``traj_names`` becomes its ``traj_id``; the
    name itself is matched against the stem (filename without the final
    suffix) of every file returned by ``list_trajectory_files``.

    Parameters
    ----------
    trajectories_dir : Path
        Directory that is searched for coordinate trajectories.
    traj_names : list of str
        Basenames corresponding to feature files (e.g. ``traj_0``).

    Returns
    -------
    dict
        Mapping from ``traj_id`` to the matching trajectory file path.

    Raises
    ------
    ValueError
        Raised for the first name in ``traj_names`` that has no file with
        a matching stem.
    """
    # Index the available files by stem. When two files share a stem the
    # one listed later replaces the earlier entry.
    paths_by_stem: Dict[str, Path] = {}
    for candidate in list_trajectory_files(trajectories_dir):
        paths_by_stem[candidate.stem] = candidate

    resolved: Dict[int, Path] = {}
    for index, stem in enumerate(traj_names):
        match = paths_by_stem.get(stem)
        if match is None:
            raise ValueError(
                f"No trajectory file found matching feature stem '{stem}' "
                f"in {trajectories_dir}"
            )
        resolved[index] = match
    return resolved
51
+
52
+
53
def load_trajectory(topology: Path, trajectory_path: Path) -> md.Trajectory:
    """Read one coordinate trajectory from disk with mdtraj.

    Both paths are converted to strings and handed to ``md.load``; an
    info-level log line is emitted before the read starts.

    Parameters
    ----------
    topology : Path
        Topology file (PDB, parm7, etc.).
    trajectory_path : Path
        Coordinate trajectory file.

    Returns
    -------
    mdtraj.Trajectory
        The trajectory returned by ``md.load``.
    """
    logger.info("Loading trajectory %s with topology %s", trajectory_path, topology)
    coordinates_file = str(trajectory_path)
    topology_file = str(topology)
    return md.load(coordinates_file, top=topology_file)
70
+
71
+
72
def extract_frame(
    topology: Path,
    trajectory_path: Path,
    frame_id: int,
) -> md.Trajectory:
    """Load a trajectory and slice out one frame.

    The trajectory is loaded through ``load_trajectory`` and then indexed
    with ``frame_id`` after a bounds check.

    Parameters
    ----------
    topology : Path
        Topology file path.
    trajectory_path : Path
        Coordinate trajectory file path.
    frame_id : int
        Zero-based frame index to extract.

    Returns
    -------
    mdtraj.Trajectory
        Single-frame trajectory suitable for PDB export.

    Raises
    ------
    IndexError
        If ``frame_id`` is negative or not smaller than the number of
        frames in the loaded trajectory.
    """
    loaded = load_trajectory(topology, trajectory_path)
    # Negative indices are rejected explicitly rather than being allowed
    # to wrap around to the end of the trajectory.
    out_of_range = frame_id < 0 or frame_id >= loaded.n_frames
    if out_of_range:
        raise IndexError(
            f"frame_id {frame_id} out of range for trajectory with "
            f"{loaded.n_frames} frames ({trajectory_path})"
        )
    single_frame = loaded[frame_id]
    return single_frame
100
+
101
+
102
def get_trajectory_frame_count(topology: Path, trajectory_path: Path) -> int:
    """Return the number of frames in a trajectory without loading all coordinates.

    The trajectory is consumed in fixed-size chunks through
    ``mdtraj.iterload`` and the per-chunk frame counts are summed, so the
    whole coordinate array is never held in memory at once.

    Parameters
    ----------
    topology : Path
        Topology file path.
    trajectory_path : Path
        Coordinate trajectory file path.

    Returns
    -------
    int
        Number of frames in the trajectory (``0`` if no chunk is yielded).
    """
    # Frames per chunk; only trades peak memory against the number of reads,
    # it does not affect the returned count.
    chunk_size = 1000
    # NOTE(review): mdtraj may still read some formats (e.g. PDB) in a single
    # pass inside iterload; for those the memory saving does not apply.
    chunks = md.iterload(
        str(trajectory_path),
        top=str(topology),
        chunk=chunk_size,
    )
    return sum(chunk.n_frames for chunk in chunks)
119
+
120
+
121
def validate_trajectory_frame_counts(
    topology: Path,
    trajectory_map: Dict[int, Path],
    expected_counts: Dict[int, int],
) -> None:
    """Verify trajectory frame counts match feature frame counts.

    Trajectories whose ``traj_id`` has no entry in ``expected_counts`` are
    skipped without being read.

    Parameters
    ----------
    topology : Path
        Topology file path.
    trajectory_map : dict
        Mapping from ``traj_id`` to trajectory file.
    expected_counts : dict
        Expected frame count per ``traj_id`` from features.

    Raises
    ------
    ValueError
        If any trajectory has a different number of frames than its features.
    """
    for traj_id, traj_path in trajectory_map.items():
        n_feature_frames = expected_counts.get(traj_id)
        if n_feature_frames is None:
            # Nothing to compare against: skip before paying for a
            # trajectory read whose result would be discarded.
            continue
        n_traj_frames = get_trajectory_frame_count(topology, traj_path)
        if n_traj_frames != n_feature_frames:
            raise ValueError(
                f"Frame count mismatch for traj_id {traj_id} ({traj_path.name}): "
                f"trajectory has {n_traj_frames}, features have {n_feature_frames}"
            )
adaptivepy/models.py ADDED
@@ -0,0 +1,83 @@
1
+ """Core data models for AdaptivePy."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from dataclasses import dataclass, field
6
+ from typing import Dict, List, Optional
7
+
8
+ import numpy as np
9
+
10
+
11
@dataclass
class FrameRecord:
    """A single frame tracked through the adaptive sampling pipeline.

    Equality is defined explicitly: scalar fields are compared with ``==``
    and ``features`` with ``numpy.array_equal``. The comparison generated
    by ``@dataclass`` would evaluate ``features == features`` element-wise
    and raise ``ValueError`` for multi-element arrays.

    Attributes
    ----------
    traj_id : int
        Index of the source trajectory.
    frame_id : int
        Frame index within the source trajectory.
    features : np.ndarray
        Feature vector for this frame, shape ``(n_features,)``.
    cluster_id : int or None
        Assigned cluster label after clustering.
    global_index : int or None
        Row index in the concatenated feature matrix.
    """

    traj_id: int
    frame_id: int
    features: np.ndarray
    cluster_id: Optional[int] = None
    global_index: Optional[int] = None

    def __eq__(self, other: object) -> bool:
        """Compare two records field by field, arrays by content."""
        # Same class check the dataclass-generated method would perform.
        if other.__class__ is not self.__class__:
            return NotImplemented
        # Cheap scalar fields first so the array comparison only runs when
        # everything else already matches.
        scalars_match = (
            self.traj_id == other.traj_id
            and self.frame_id == other.frame_id
            and self.cluster_id == other.cluster_id
            and self.global_index == other.global_index
        )
        if not scalars_match:
            return False
        return bool(np.array_equal(self.features, other.features))
34
+
35
+
36
@dataclass
class Dataset:
    """Container for the features loaded from a set of trajectories.

    Every field has a default, so ``Dataset()`` yields an empty container
    with fresh (unshared) list and dict instances.

    Attributes
    ----------
    frames : list of FrameRecord
        One record per frame across all trajectories.
    feature_matrix : np.ndarray or None
        Concatenated features, shape ``(n_total_frames, n_features)``;
        ``None`` until it is assigned.
    traj_index_map : dict
        Maps ``traj_id`` to ``(start_index, end_index)`` in ``feature_matrix``.
    traj_names : list of str
        Basenames of feature files (without extension), e.g. ``traj_0``.
    """

    # Mutable defaults go through default_factory so instances never share them.
    frames: List[FrameRecord] = field(default_factory=list)
    feature_matrix: Optional[np.ndarray] = None
    traj_index_map: Dict[int, tuple[int, int]] = field(default_factory=dict)
    traj_names: List[str] = field(default_factory=list)
56
+
57
+
58
@dataclass
class SeedResult:
    """One seed frame chosen by a sampling policy.

    All six fields are required and positional in the order listed below.

    Attributes
    ----------
    seed_id : int
        Sequential identifier within a policy run.
    policy : str
        Name of the policy that selected this seed.
    traj_id : int
        Source trajectory index.
    frame_id : int
        Frame index within the source trajectory.
    cluster_id : int
        Cluster from which the seed was drawn.
    global_index : int
        Row index in the concatenated feature matrix.
    """

    seed_id: int
    policy: str
    traj_id: int
    frame_id: int
    cluster_id: int
    global_index: int
@@ -0,0 +1,23 @@
1
+ """Output writers for AdaptivePy."""
2
+
3
+ from adaptivepy.output.pdb_writer import write_seed_pdbs
4
+ from adaptivepy.output.writer import (
5
+ write_assignments,
6
+ write_cluster_model,
7
+ write_cluster_statistics,
8
+ write_combined_metadata,
9
+ write_policy_outputs,
10
+ write_run_config,
11
+ write_seeds_csv,
12
+ )
13
+
14
+ __all__ = [
15
+ "write_assignments",
16
+ "write_cluster_model",
17
+ "write_cluster_statistics",
18
+ "write_combined_metadata",
19
+ "write_policy_outputs",
20
+ "write_run_config",
21
+ "write_seed_pdbs",
22
+ "write_seeds_csv",
23
+ ]
@@ -0,0 +1,59 @@
1
+ """PDB export for selected seed frames using mdtraj."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ from pathlib import Path
7
+ from typing import Dict, List
8
+
9
+ from adaptivepy.io.trajectory import extract_frame
10
+ from adaptivepy.models import SeedResult
11
+ from adaptivepy.utils.io_utils import ensure_dir
12
+
13
+ logger = logging.getLogger(__name__)
14
+
15
+
16
def write_seed_pdbs(
    seeds: List[SeedResult],
    topology: Path,
    trajectory_map: Dict[int, Path],
    output_dir: Path,
) -> List[Path]:
    """Export every seed frame as its own PDB file.

    For each seed the source trajectory is looked up in ``trajectory_map``,
    the frame is obtained via ``extract_frame`` and saved with
    ``save_pdb``. Files are named
    ``seed_<seed_id>_traj<traj_id>_frame<frame_id>.pdb``.

    Parameters
    ----------
    seeds : list of SeedResult
        Selected seed frames.
    topology : Path
        Topology file for mdtraj loading.
    trajectory_map : dict
        Mapping from ``traj_id`` to trajectory file path.
    output_dir : Path
        Parent directory; PDB files are written into its ``pdbs``
        subdirectory.

    Returns
    -------
    list of Path
        Paths to written PDB files, in the order of ``seeds``.

    Raises
    ------
    KeyError
        If a seed's ``traj_id`` is missing from ``trajectory_map``.
    """
    # NOTE(review): ensure_dir presumably creates the directory and returns
    # it; its return value is used as the target path below.
    target_dir = ensure_dir(output_dir / "pdbs")
    exported: List[Path] = []

    for item in seeds:
        source = trajectory_map[item.traj_id]
        snapshot = extract_frame(topology, source, item.frame_id)
        filename = f"seed_{item.seed_id}_traj{item.traj_id}_frame{item.frame_id}.pdb"
        destination = target_dir / filename
        snapshot.save_pdb(str(destination))
        exported.append(destination)
        logger.info(
            "Wrote PDB for seed %d: traj=%d frame=%d -> %s",
            item.seed_id,
            item.traj_id,
            item.frame_id,
            destination.name,
        )

    return exported
@@ -0,0 +1,229 @@
1
+ """Write run outputs: CSV metadata, numpy arrays, and serialized models."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import csv
6
+ import logging
7
+ from pathlib import Path
8
+ from typing import Any, Dict, Iterable, List, Optional
9
+
10
+ import joblib
11
+ import numpy as np
12
+ import yaml
13
+
14
+ from adaptivepy.config.schema import RunConfig, config_to_dict
15
+ from adaptivepy.models import SeedResult
16
+ from adaptivepy.stats.cluster_stats import ClusterStats, cluster_stats_to_rows
17
+ from adaptivepy.utils.io_utils import copy_file, ensure_dir
18
+
19
+ logger = logging.getLogger(__name__)
20
+
21
+
22
def write_run_config(config: RunConfig, output_dir: Path, source_path: Path) -> Path:
    """Store the run configuration as ``run_config.yaml`` in the output directory.

    When ``source_path`` is an existing file it is copied as-is; otherwise
    the parsed ``config`` is converted with ``config_to_dict`` and dumped
    as YAML.

    Parameters
    ----------
    config : RunConfig
        Parsed configuration object.
    output_dir : Path
        Run output directory.
    source_path : Path
        Original YAML file path (copied verbatim when available).

    Returns
    -------
    Path
        Path to the saved configuration file.
    """
    destination = ensure_dir(output_dir) / "run_config.yaml"

    if not source_path.is_file():
        # No original file to copy: regenerate the YAML from the parsed
        # object, keeping the key order produced by config_to_dict.
        with destination.open("w", encoding="utf-8") as stream:
            yaml.safe_dump(config_to_dict(config), stream, sort_keys=False)
        return destination

    copy_file(source_path, destination)
    return destination
46
+
47
+
48
def write_assignments(assignments: np.ndarray, output_dir: Path) -> Path:
    """Persist per-frame cluster assignments with ``numpy.save``.

    Parameters
    ----------
    assignments : np.ndarray
        Cluster label per frame.
    output_dir : Path
        Run output directory.

    Returns
    -------
    Path
        Path to ``assignments.npy``.
    """
    run_dir = ensure_dir(output_dir)
    target = run_dir / "assignments.npy"
    np.save(target, assignments)
    return target
66
+
67
+
68
def write_cluster_model(model: Any, output_dir: Path) -> Path:
    """Dump the fitted clustering model to disk with ``joblib.dump``.

    Parameters
    ----------
    model : object
        Fitted clustering model.
    output_dir : Path
        Run output directory.

    Returns
    -------
    Path
        Path to ``cluster_model.pkl``.
    """
    run_dir = ensure_dir(output_dir)
    target = run_dir / "cluster_model.pkl"
    joblib.dump(model, target)
    return target
86
+
87
+
88
def write_cluster_statistics(
    cluster_stats: ClusterStats,
    output_dir: Path,
) -> Path:
    """Write cluster populations to ``metadata.csv``.

    The statistics are flattened with ``cluster_stats_to_rows`` and written
    as CSV with the columns ``cluster_id`` and ``population``.

    Parameters
    ----------
    cluster_stats : ClusterStats
        Per-cluster statistics.
    output_dir : Path
        Run output directory.

    Returns
    -------
    Path
        Path to ``metadata.csv``.
    """
    target = ensure_dir(output_dir) / "metadata.csv"
    records = cluster_stats_to_rows(cluster_stats)
    columns = ["cluster_id", "population"]
    # newline="" leaves line-ending handling to the csv module.
    with target.open("w", newline="", encoding="utf-8") as stream:
        table = csv.DictWriter(stream, fieldnames=columns)
        table.writeheader()
        for record in records:
            table.writerow(record)
    return target
113
+
114
+
115
def write_seeds_csv(seeds: Iterable[SeedResult], output_dir: Path) -> Path:
    """Write the seeds of one policy to ``seeds.csv``.

    One row is written per seed; each column holds the seed attribute of
    the same name.

    Parameters
    ----------
    seeds : iterable of SeedResult
        Seed records for one policy.
    output_dir : Path
        Policy-specific output directory.

    Returns
    -------
    Path
        Path to ``seeds.csv``.
    """
    target = ensure_dir(output_dir) / "seeds.csv"
    # Column names double as the SeedResult attribute names read per row.
    columns = [
        "seed_id",
        "policy",
        "traj_id",
        "frame_id",
        "cluster_id",
        "global_index",
    ]
    with target.open("w", newline="", encoding="utf-8") as stream:
        table = csv.DictWriter(stream, fieldnames=columns)
        table.writeheader()
        for record in seeds:
            table.writerow({column: getattr(record, column) for column in columns})
    return target
154
+
155
+
156
def write_combined_metadata(
    policy_seeds: Dict[str, List[SeedResult]],
    output_dir: Path,
) -> Path:
    """Write the seeds of all policies into one CSV table.

    Seed lists are written in the iteration order of ``policy_seeds``. The
    ``policy`` column comes from each seed's own ``policy`` attribute; the
    dictionary keys are not written.

    Parameters
    ----------
    policy_seeds : dict
        Mapping from policy name to seed lists.
    output_dir : Path
        Top-level results directory.

    Returns
    -------
    Path
        Path to ``combined_metadata.csv``.
    """
    target = ensure_dir(output_dir) / "combined_metadata.csv"
    # Column names double as the SeedResult attribute names read per row.
    columns = [
        "seed_id",
        "policy",
        "traj_id",
        "frame_id",
        "cluster_id",
        "global_index",
    ]
    with target.open("w", newline="", encoding="utf-8") as stream:
        table = csv.DictWriter(stream, fieldnames=columns)
        table.writeheader()
        for seed_list in policy_seeds.values():
            for record in seed_list:
                table.writerow(
                    {column: getattr(record, column) for column in columns}
                )
    logger.info("Wrote combined metadata to %s", target)
    return target
200
+
201
+
202
def write_policy_outputs(
    policy_name: str,
    seeds: List[SeedResult],
    cluster_stats: ClusterStats,
    results_dir: Path,
) -> Path:
    """Write the per-policy output files into ``results_dir / policy_name``.

    Delegates to ``write_seeds_csv`` and ``write_cluster_statistics``, in
    that order, both targeting the policy subdirectory.

    Parameters
    ----------
    policy_name : str
        Policy identifier used as subdirectory name.
    seeds : list of SeedResult
        Seeds selected by the policy.
    cluster_stats : ClusterStats
        Global cluster statistics (same for all policies).
    results_dir : Path
        Top-level results directory.

    Returns
    -------
    Path
        Policy output directory path.
    """
    destination = ensure_dir(results_dir / policy_name)
    write_seeds_csv(seeds, destination)
    write_cluster_statistics(cluster_stats, destination)
    return destination
@@ -0,0 +1,21 @@
1
+ """Adaptive sampling policies for AdaptivePy."""
2
+
3
+ from adaptivepy.policies.base import (
4
+ POLICY_REGISTRY,
5
+ Policy,
6
+ get_policy,
7
+ list_policies,
8
+ register_policy,
9
+ )
10
+
11
+ # Import concrete policies so they self-register.
12
+ from adaptivepy.policies import least_counts # noqa: F401
13
+ from adaptivepy.policies import random # noqa: F401
14
+
15
+ __all__ = [
16
+ "POLICY_REGISTRY",
17
+ "Policy",
18
+ "get_policy",
19
+ "list_policies",
20
+ "register_policy",
21
+ ]
@@ -0,0 +1,105 @@
1
+ """Adaptive sampling policy base class and registry."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from abc import ABC, abstractmethod
6
+ from typing import Dict, List, Type
7
+
8
+ from adaptivepy.stats.cluster_stats import ClusterStats
9
+
10
# Global lookup table: policy name -> policy class. Filled by register_policy.
POLICY_REGISTRY: Dict[str, Type["Policy"]] = {}


def register_policy(cls: Type["Policy"]) -> Type["Policy"]:
    """Add a policy class to :data:`POLICY_REGISTRY` under its ``name``.

    Returns the class itself, so the function can be applied as a class
    decorator.

    Parameters
    ----------
    cls : type
        Policy subclass with a ``name`` class attribute.

    Returns
    -------
    type
        The registered policy class (unchanged).

    Raises
    ------
    ValueError
        If the policy name is missing or already registered.
    """
    policy_name = getattr(cls, "name", None)
    # A missing attribute and an empty name are both rejected.
    if not policy_name:
        raise ValueError(f"Policy {cls.__name__} must define a 'name' attribute.")
    if policy_name in POLICY_REGISTRY:
        raise ValueError(f"Policy '{policy_name}' is already registered.")
    POLICY_REGISTRY[policy_name] = cls
    return cls
37
+
38
+
39
class Policy(ABC):
    """Abstract interface for cluster selection policies.

    A concrete policy overrides :meth:`select_clusters` to decide which
    clusters contribute seed frames. The class cannot be instantiated
    until that method is overridden.
    """

    # Identifier of the policy; empty on the abstract base.
    name: str = ""

    @abstractmethod
    def select_clusters(
        self,
        cluster_stats: ClusterStats,
        n_seeds: int,
    ) -> List[int]:
        """Choose the cluster IDs that seed frames are drawn from.

        Parameters
        ----------
        cluster_stats : ClusterStats
            Per-cluster population and frame lists.
        n_seeds : int
            Maximum number of clusters (seeds) to select.

        Returns
        -------
        list of int
            Selected cluster IDs.
        """
69
+
70
+
71
def get_policy(name: str, **kwargs) -> Policy:
    """Look up a policy class by name and construct an instance of it.

    Parameters
    ----------
    name : str
        Registered policy name.
    **kwargs
        Constructor arguments forwarded to the policy class.

    Returns
    -------
    Policy
        Policy instance.

    Raises
    ------
    ValueError
        If the policy name is unknown; the message lists the registered
        names in sorted order.
    """
    policy_cls = POLICY_REGISTRY.get(name)
    if policy_cls is None:
        known_names = sorted(POLICY_REGISTRY)
        available = ", ".join(known_names)
        raise ValueError(f"Unknown policy '{name}'. Available: {available}")
    return policy_cls(**kwargs)
95
+
96
+
97
def list_policies() -> List[str]:
    """List the names currently present in the policy registry.

    Returns
    -------
    list of str
        Sorted policy names.
    """
    # Iterating a dict yields its keys, so no explicit .keys() is needed.
    registered_names = sorted(POLICY_REGISTRY)
    return registered_names