adaptivepy-sampling 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,43 @@
1
+ """Least-counts adaptive sampling policy."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import List
6
+
7
+ from adaptivepy.policies.base import Policy, register_policy
8
+ from adaptivepy.stats.cluster_stats import ClusterStats, sort_clusters_by_population
9
+
10
+
11
@register_policy
class LeastCountsPolicy(Policy):
    """Policy that favours the most sparsely populated clusters.

    Cluster IDs are ranked from smallest to largest population and the
    leading ``n_seeds`` entries of that ranking are returned, so every
    selected cluster contributes exactly one ID.
    """

    name = "least_counts"

    def select_clusters(
        self,
        cluster_stats: ClusterStats,
        n_seeds: int,
    ) -> List[int]:
        """Pick the ``n_seeds`` clusters that hold the fewest frames.

        Parameters
        ----------
        cluster_stats : dict
            Per-cluster statistics keyed by cluster ID.
        n_seeds : int
            Upper bound on the number of cluster IDs to return. The result
            is a plain slice of the ranking, so asking for more clusters
            than exist simply returns all of them.

        Returns
        -------
        list of int
            Cluster IDs in ascending order of population.
        """
        ranked_ids = sort_clusters_by_population(cluster_stats, ascending=True)
        return ranked_ids[:n_seeds]
@@ -0,0 +1,53 @@
1
+ """Random cluster selection policy."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import List, Optional
6
+
7
+ import numpy as np
8
+
9
+ from adaptivepy.policies.base import Policy, register_policy
10
+ from adaptivepy.stats.cluster_stats import ClusterStats
11
+
12
+
13
@register_policy
class RandomPolicy(Policy):
    """Policy that draws cluster IDs uniformly at random.

    Parameters
    ----------
    random_state : int or None
        Seed handed to ``numpy.random.default_rng``; ``None`` lets NumPy
        pick fresh entropy.
    """

    name = "random"

    def __init__(self, random_state: Optional[int] = None) -> None:
        self.random_state = random_state
        # One generator per policy instance, so successive calls to
        # ``select_clusters`` continue the same random stream.
        self._rng = np.random.default_rng(random_state)

    def select_clusters(
        self,
        cluster_stats: ClusterStats,
        n_seeds: int,
    ) -> List[int]:
        """Draw up to ``n_seeds`` distinct cluster IDs without replacement.

        Parameters
        ----------
        cluster_stats : dict
            Per-cluster statistics keyed by cluster ID.
        n_seeds : int
            Requested number of clusters; capped at the number of
            available clusters.

        Returns
        -------
        list of int
            The sampled cluster IDs as plain Python integers. Empty when
            nothing can or should be drawn.
        """
        candidate_ids = list(cluster_stats.keys())
        sample_size = min(n_seeds, len(candidate_ids))
        if sample_size == 0:
            return []
        drawn = self._rng.choice(candidate_ids, size=sample_size, replace=False)
        # Convert NumPy scalars back to built-in ints for callers.
        return list(map(int, drawn))
@@ -0,0 +1,5 @@
1
+ """Seed selection utilities."""
2
+
3
+ from adaptivepy.selection.frame_selector import select_seeds
4
+
5
+ __all__ = ["select_seeds"]
@@ -0,0 +1,132 @@
1
+ """Frame-level seed selection within chosen clusters."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import List, Optional
6
+
7
+ import numpy as np
8
+
9
+ from adaptivepy.models import FrameRecord, SeedResult
10
+ from adaptivepy.stats.cluster_stats import ClusterStats
11
+
12
+
13
def _nearest_center_frame(
    frames: List[FrameRecord],
    center: np.ndarray,
) -> FrameRecord:
    """Find the frame whose feature vector lies closest to ``center``.

    Parameters
    ----------
    frames : list of FrameRecord
        Frames of a single cluster; each must expose a ``features`` array.
        Must be non-empty, since ``np.stack`` rejects an empty list.
    center : np.ndarray
        Reference point to measure against, shape ``(n_features,)``.

    Returns
    -------
    FrameRecord
        The frame with the smallest Euclidean distance to ``center``.
        ``argmin`` returns the first minimum, so the earliest frame wins
        a tie.
    """
    # Rows of ``feature_matrix`` line up one-to-one with ``frames``.
    feature_matrix = np.stack([frame.features for frame in frames])
    offsets = feature_matrix - center
    closest = np.linalg.norm(offsets, axis=1).argmin()
    return frames[int(closest)]
34
+
35
+
36
def _random_frame(
    frames: List[FrameRecord],
    rng: np.random.Generator,
) -> FrameRecord:
    """Draw one frame from a cluster with equal probability.

    Parameters
    ----------
    frames : list of FrameRecord
        Frames of a single cluster. Must be non-empty; an empty list
        leaves the generator with an empty range to draw from.
    rng : np.random.Generator
        Generator supplying the random index; it is advanced by one draw.

    Returns
    -------
    FrameRecord
        The frame at the drawn position.
    """
    # ``high`` is exclusive, so the drawn index is always valid.
    position = rng.integers(low=0, high=len(frames))
    return frames[int(position)]
56
+
57
+
58
def select_seeds(
    policy_name: str,
    selected_clusters: List[int],
    cluster_stats: ClusterStats,
    cluster_centers: Optional[np.ndarray],
    method: str = "nearest_center",
    random_state: Optional[int] = None,
) -> List[SeedResult]:
    """Select one seed frame from each chosen cluster.

    Parameters
    ----------
    policy_name : str
        Name of the policy that selected the clusters; copied verbatim
        into every result.
    selected_clusters : list of int
        Cluster IDs chosen by the policy.
    cluster_stats : dict
        Per-cluster frame lists and populations.
    cluster_centers : np.ndarray or None
        Cluster centroids indexed by cluster ID, shape
        ``(n_clusters, n_features)``. Only consulted for
        ``nearest_center`` selection. When it is ``None``, or a cluster ID
        is not a valid non-negative row index into it, the mean of that
        cluster's own frame features is used as the centroid instead.
    method : str
        Selection method: ``nearest_center`` or ``random_frame``.
    random_state : int or None
        Random seed for ``random_frame`` selection.

    Returns
    -------
    list of SeedResult
        Selected seed frames with metadata. Clusters that are absent from
        ``cluster_stats`` or have no frames are skipped. ``seed_id`` is
        the cluster's position in ``selected_clusters``, so skipped
        clusters leave gaps in the numbering.

    Raises
    ------
    ValueError
        If ``method`` is not one of the supported names.
    """
    if method not in {"nearest_center", "random_frame"}:
        raise ValueError(
            f"Unknown seed selection method '{method}'. "
            "Use 'nearest_center' or 'random_frame'."
        )

    rng = np.random.default_rng(random_state)
    seeds: List[SeedResult] = []

    for seed_id, cluster_id in enumerate(selected_clusters):
        entry = cluster_stats.get(cluster_id)
        if entry is None or not entry["frames"]:
            continue

        frames = entry["frames"]

        if method == "random_frame":
            chosen = _random_frame(frames, rng)
        else:
            # The lower bound matters: a negative cluster ID would pass a
            # bare ``< len(...)`` check and then wrap around to a row from
            # the end of ``cluster_centers``, i.e. another cluster's
            # centroid.
            if (
                cluster_centers is not None
                and 0 <= cluster_id < len(cluster_centers)
            ):
                center = cluster_centers[cluster_id]
            else:
                # No usable precomputed centroid: use the cluster's own mean.
                center = np.mean(
                    np.stack([f.features for f in frames]), axis=0
                )
            chosen = _nearest_center_frame(frames, center)

        seeds.append(
            SeedResult(
                seed_id=seed_id,
                policy=policy_name,
                traj_id=chosen.traj_id,
                frame_id=chosen.frame_id,
                cluster_id=cluster_id,
                # A missing (falsy) global index is reported as 0.
                global_index=chosen.global_index or 0,
            )
        )

    return seeds
@@ -0,0 +1,15 @@
1
+ """Cluster statistics for AdaptivePy."""
2
+
3
+ from adaptivepy.stats.cluster_stats import (
4
+ assign_clusters,
5
+ cluster_stats_to_rows,
6
+ compute_cluster_stats,
7
+ sort_clusters_by_population,
8
+ )
9
+
10
+ __all__ = [
11
+ "assign_clusters",
12
+ "cluster_stats_to_rows",
13
+ "compute_cluster_stats",
14
+ "sort_clusters_by_population",
15
+ ]
@@ -0,0 +1,118 @@
1
+ """Cluster population statistics and frame assignments."""
2
+
3
+ from __future__ import annotations
4
+
5
+ from typing import Dict, List, TypedDict
6
+
7
+ import numpy as np
8
+
9
+ from adaptivepy.models import Dataset, FrameRecord
10
+
11
+
12
class ClusterStatEntry(TypedDict):
    """Record kept for one cluster: its frame count and its frames."""

    # Number of frames assigned to the cluster.
    population: int
    # The frame records assigned to the cluster.
    frames: List[FrameRecord]


# Mapping from integer cluster ID to that cluster's statistics entry.
ClusterStats = Dict[int, ClusterStatEntry]
20
+
21
+
22
def assign_clusters(dataset: Dataset, labels: np.ndarray) -> None:
    """Write one cluster label onto each frame record of ``dataset``.

    Parameters
    ----------
    dataset : Dataset
        Dataset whose ``frames`` are modified in place.
    labels : np.ndarray
        One cluster label per frame, in frame order, shape ``(n_frames,)``.

    Raises
    ------
    ValueError
        If the number of labels differs from the number of frames.
    """
    n_frames = len(dataset.frames)
    n_labels = len(labels)
    if n_labels != n_frames:
        raise ValueError(f"Expected {n_frames} labels, got {n_labels}.")

    for label, frame in zip(labels, dataset.frames):
        # Store a built-in int rather than whatever scalar type was given.
        frame.cluster_id = int(label)
43
+
44
+
45
def compute_cluster_stats(dataset: Dataset) -> ClusterStats:
    """Group the frames of ``dataset`` by cluster and count them.

    Parameters
    ----------
    dataset : Dataset
        Dataset whose frame records already carry a ``cluster_id``.

    Returns
    -------
    dict
        Mapping from ``cluster_id`` to an entry with the cluster's
        ``population`` and its ``frames`` in dataset order. Clusters
        appear in order of first occurrence.

    Raises
    ------
    ValueError
        If a frame is encountered whose ``cluster_id`` is ``None``.
    """
    grouped: ClusterStats = {}

    for frame in dataset.frames:
        label = frame.cluster_id
        if label is None:
            raise ValueError("All frames must have cluster assignments.")
        # Create the bucket the first time a label is seen.
        bucket = grouped.setdefault(label, {"population": 0, "frames": []})
        bucket["population"] += 1
        bucket["frames"].append(frame)

    return grouped
75
+
76
+
77
def sort_clusters_by_population(
    cluster_stats: ClusterStats,
    ascending: bool = True,
) -> List[int]:
    """Order cluster IDs by how many frames each cluster holds.

    Parameters
    ----------
    cluster_stats : dict
        Per-cluster statistics from :func:`compute_cluster_stats`.
    ascending : bool
        ``True`` puts the least populated clusters first, ``False`` the
        most populated.

    Returns
    -------
    list of int
        Cluster IDs in the requested order. The sort is stable, so
        clusters with equal populations keep their relative order from
        ``cluster_stats`` in both directions.
    """

    def population_of(cluster_id: int) -> int:
        return cluster_stats[cluster_id]["population"]

    ordered_ids = list(cluster_stats.keys())
    ordered_ids.sort(key=population_of, reverse=not ascending)
    return ordered_ids
100
+
101
+
102
def cluster_stats_to_rows(cluster_stats: ClusterStats) -> List[Dict[str, int]]:
    """Flatten cluster statistics into one row per cluster for CSV export.

    Parameters
    ----------
    cluster_stats : dict
        Per-cluster statistics keyed by cluster ID.

    Returns
    -------
    list of dict
        Rows with keys ``cluster_id`` and ``population``, ordered by
        ascending cluster ID. Frame lists are not included.
    """
    return [
        {
            "cluster_id": cluster_id,
            "population": cluster_stats[cluster_id]["population"],
        }
        for cluster_id in sorted(cluster_stats)
    ]
@@ -0,0 +1,6 @@
1
+ """Utility helpers for AdaptivePy."""
2
+
3
+ from adaptivepy.utils.io_utils import copy_file, ensure_dir
4
+ from adaptivepy.utils.logging import setup_logger
5
+
6
+ __all__ = ["copy_file", "ensure_dir", "setup_logger"]
@@ -0,0 +1,49 @@
1
+ """Shared I/O utilities."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import shutil
6
+ from pathlib import Path
7
+ from typing import Union
8
+
9
+ PathLike = Union[str, Path]
10
+
11
+
12
def ensure_dir(path: PathLike) -> Path:
    """Make sure a directory exists, creating missing parents on the way.

    Parameters
    ----------
    path : str or Path
        Directory to create. An already existing directory is left as is.

    Returns
    -------
    Path
        The resolved form of ``path``.
    """
    directory = Path(path).resolve()
    directory.mkdir(parents=True, exist_ok=True)
    return directory
28
+
29
+
30
def copy_file(src: PathLike, dst: PathLike) -> Path:
    """Copy ``src`` to ``dst``, creating the destination's parent directory.

    Parameters
    ----------
    src : str or Path
        File to copy.
    dst : str or Path
        Path the copy is written to.

    Returns
    -------
    Path
        The resolved form of ``dst``.
    """
    destination = Path(dst)
    # The target directory must exist before ``copy2`` can write into it.
    ensure_dir(destination.parent)
    shutil.copy2(Path(src), destination)
    return destination.resolve()
@@ -0,0 +1,55 @@
1
+ """Logging utilities for AdaptivePy runs."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import logging
6
+ import sys
7
+ from pathlib import Path
8
+ from typing import Optional
9
+
10
+
11
def setup_logger(
    name: str = "adaptivepy",
    log_file: Optional[Path] = None,
    level: int = logging.INFO,
) -> logging.Logger:
    """Configure and return a logger with console and optional file handlers.

    Any handlers already attached to the named logger are removed and
    closed first, so calling this function repeatedly neither duplicates
    output nor leaks open log files.

    Parameters
    ----------
    name : str
        Logger name.
    log_file : Path or None
        If provided, log messages are also written to this file (UTF-8).
        Missing parent directories are created.
    level : int
        Logging level applied to the logger and to both handlers.

    Returns
    -------
    logging.Logger
        Configured logger instance. Propagation to ancestor loggers is
        disabled.
    """
    logger = logging.getLogger(name)
    logger.setLevel(level)

    # Detach AND close the previous handlers. Merely clearing the handler
    # list would drop the references but leave a FileHandler's underlying
    # file open. Iterate over a copy because removeHandler mutates the list.
    for stale_handler in list(logger.handlers):
        logger.removeHandler(stale_handler)
        stale_handler.close()

    logger.propagate = False

    formatter = logging.Formatter(
        "%(asctime)s [%(levelname)s] %(name)s: %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
    )

    console_handler = logging.StreamHandler(sys.stdout)
    console_handler.setLevel(level)
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)

    if log_file is not None:
        log_file = Path(log_file)
        log_file.parent.mkdir(parents=True, exist_ok=True)
        file_handler = logging.FileHandler(log_file, encoding="utf-8")
        file_handler.setLevel(level)
        file_handler.setFormatter(formatter)
        logger.addHandler(file_handler)

    return logger
@@ -0,0 +1,52 @@
1
+ Metadata-Version: 2.4
2
+ Name: adaptivepy-sampling
3
+ Version: 0.1.0
4
+ Summary: Adaptive sampling on MD trajectories via clustering and policy-driven seed selection
5
+ Author: AdaptivePy Contributors
6
+ License: MIT
7
+ Requires-Python: >=3.9
8
+ Description-Content-Type: text/markdown
9
+ Requires-Dist: numpy>=1.20
10
+ Requires-Dist: scikit-learn>=1.0
11
+ Requires-Dist: pyyaml>=6.0
12
+ Requires-Dist: click>=8.0
13
+ Requires-Dist: joblib>=1.0
14
+ Requires-Dist: mdtraj>=1.9
15
+ Provides-Extra: dev
16
+ Requires-Dist: pytest>=7.0; extra == "dev"
17
+
18
+ # AdaptivePy
19
+
20
+ Adaptive sampling on molecular dynamics trajectories using clustering-based state space partitioning and policy-driven seed selection.
21
+
22
+ ## Installation
23
+
24
+ ```bash
25
+ pip install -e .
26
+ ```
27
+
28
+ ## Quick start
29
+
30
+ 1. Prepare feature files (`features/traj_0.npy`, ...) with shape `(n_frames, n_features)`.
31
+ 2. Optionally add matching coordinate trajectories (`trajectories/traj_0.xtc`, ...) and a topology file.
32
+ 3. Edit `examples/config.yaml` and run:
33
+
34
+ ```bash
35
+ adaptivepy run examples/config.yaml
36
+ ```
37
+
38
+ ## CLI
39
+
40
+ ```bash
41
+ adaptivepy run config.yaml
42
+ adaptivepy validate config.yaml
43
+ adaptivepy list-policies
44
+ ```
45
+
46
+ ## Python API
47
+
48
+ ```python
49
+ from adaptivepy import run_adaptive_sampling
50
+
51
+ results = run_adaptive_sampling("config.yaml")
52
+ ```
@@ -0,0 +1,34 @@
1
+ adaptivepy/__init__.py,sha256=Z9oSeaY0nBjZGv5aA1-8XJQ4oB5HDPwc9FnpSj6Xddw,198
2
+ adaptivepy/api.py,sha256=OrnkbOmWMUusbxSMY9C-oilk1-k0Yu_7wZBsBTwzWgg,7258
3
+ adaptivepy/models.py,sha256=2CUiOj9nvD1Nt0keFDSUFcR1b0DZaB0gywT5NWLB4Bs,2239
4
+ adaptivepy/cli/__init__.py,sha256=IzJw4MHHsGRl5adW3RptsZi5uGXLSzkJhWzIjlc8VCM,96
5
+ adaptivepy/cli/run.py,sha256=ItFjNxU7RRzdiqNGVBELXsE9K8GaIUOdbv7EpNfRgpE,1741
6
+ adaptivepy/clustering/__init__.py,sha256=k1ee5Q1t0-42N2jOplA9Y0r3XdtLuIfAUwhHcUnqtRg,2841
7
+ adaptivepy/clustering/base.py,sha256=UAxMJ6FHZZa5OY0sgrVZDSRDNKKDdsgJ89GxFOgtI88,1754
8
+ adaptivepy/clustering/regular_space.py,sha256=FnbOycEesf3G9UGjp9UiJAtNlNcyUhaIIu_97cuuGzo,4511
9
+ adaptivepy/clustering/sklearn_kmeans.py,sha256=7INdhM6wSMD2yzK95NiQJI2LyT8iYty71ludnVReV8E,2405
10
+ adaptivepy/clustering/sklearn_minibatch.py,sha256=jQL4i5I9iMIW8jOUH6oPZd2DpGrr0BrjtrGCI9LG3sA,2563
11
+ adaptivepy/config/__init__.py,sha256=JLrboCd7_fsG-GqLpCmTcv1m715lLLFQ5u6tKh1bQIk,309
12
+ adaptivepy/config/schema.py,sha256=5SZT6XWg2X1rqEiSZt4MFBET9pifeJcbY8AhGQxb-SI,5929
13
+ adaptivepy/io/__init__.py,sha256=nbrwleheNEBidhH8VGgdAzErIutK8mMeTLGSJw2hReg,631
14
+ adaptivepy/io/loader.py,sha256=HZSTsVwjpbXy8YLowuKKVvZAgMU9mqmLvqsjO-I_ldE,7878
15
+ adaptivepy/io/trajectory.py,sha256=kkyuf5KFcPHBAGxjYbDPxJ8ZFc0_jJx2mPHKYTszT9A,4208
16
+ adaptivepy/output/__init__.py,sha256=O2F1DG8dDo9T9aI-lfQm8dlJ4izndRLOTbShR-DUJpI,540
17
+ adaptivepy/output/pdb_writer.py,sha256=kLsExDbvCMiLKXPAQsJu8289FDcEEDgqZZhn2insEW4,1596
18
+ adaptivepy/output/writer.py,sha256=dl3LmgfOLHIlaK1jXks-KBBNg8Dj2nXzIKEBl8q9NQY,6146
19
+ adaptivepy/policies/__init__.py,sha256=rc8cTwHa5WG-6K9xaRBv_kBNFmsCaJ-SKlYlmyfa82I,457
20
+ adaptivepy/policies/base.py,sha256=D8TjMlWDM-0uv0Klk-zxXTSkCWf0X8t7kDQvpibh2yA,2497
21
+ adaptivepy/policies/least_counts.py,sha256=PKYu5j_lza_c3NC8xx8Qz6e9rMXLT4JM0hN-erMroXo,1121
22
+ adaptivepy/policies/random.py,sha256=P0ZyTg801IBkB0qJbLX5o6aUkTkIQOUngIwnZJ6j6is,1365
23
+ adaptivepy/selection/__init__.py,sha256=4UdLHfz5X4DNksHK7_XXr7W_pFpXNSDJOwezpgq72w0,122
24
+ adaptivepy/selection/frame_selector.py,sha256=oKEC095iXUiTy4lieLU_xCQDdLsYnhy0AJAU6Nl9qlE,3806
25
+ adaptivepy/stats/__init__.py,sha256=5e6zIKUUbYfxcDCfKCha39Lw4yuB8zTQXtVFEvEdtuE,328
26
+ adaptivepy/stats/cluster_stats.py,sha256=B46tPK0IhweEQ4OgX2HMWMYXmK1KDiGudarhXoiFGFA,2989
27
+ adaptivepy/utils/__init__.py,sha256=RzCDgCQgs4GpVCN-CzxEWgPwtB7ArzvWdE9l_sF4cCM,204
28
+ adaptivepy/utils/io_utils.py,sha256=SXel2DmgCV_5C-xCo7nu_hdCKHllMwVgfa8-8DWycwc,1051
29
+ adaptivepy/utils/logging.py,sha256=qtIMxtRPY79ArtTuJ5oOA0ZjvrccQHHIQ5v_71TDzLk,1455
30
+ adaptivepy_sampling-0.1.0.dist-info/METADATA,sha256=mrBmBlzdWFonaqtZ0aYc2seEQP7LQzh3aiyomp8I_3Y,1222
31
+ adaptivepy_sampling-0.1.0.dist-info/WHEEL,sha256=aeYiig01lYGDzBgS8HxWXOg3uV61G9ijOsup-k9o1sk,91
32
+ adaptivepy_sampling-0.1.0.dist-info/entry_points.txt,sha256=DDQkzgiBjliB_Fyy0y6qs956Y8-RIzgrRVdVG_ac3b8,55
33
+ adaptivepy_sampling-0.1.0.dist-info/top_level.txt,sha256=mObCepJKVRgCqGBCOOO9d135EvogjEfTwCa5U7X0VKU,11
34
+ adaptivepy_sampling-0.1.0.dist-info/RECORD,,
@@ -0,0 +1,5 @@
1
+ Wheel-Version: 1.0
2
+ Generator: setuptools (82.0.1)
3
+ Root-Is-Purelib: true
4
+ Tag: py3-none-any
5
+
@@ -0,0 +1,2 @@
1
+ [console_scripts]
2
+ adaptivepy = adaptivepy.cli.run:main
@@ -0,0 +1 @@
1
+ adaptivepy