adaptivepy-sampling 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- adaptivepy_sampling-0.1.0/PKG-INFO +52 -0
- adaptivepy_sampling-0.1.0/README.md +35 -0
- adaptivepy_sampling-0.1.0/adaptivepy/__init__.py +7 -0
- adaptivepy_sampling-0.1.0/adaptivepy/api.py +229 -0
- adaptivepy_sampling-0.1.0/adaptivepy/cli/__init__.py +5 -0
- adaptivepy_sampling-0.1.0/adaptivepy/cli/run.py +68 -0
- adaptivepy_sampling-0.1.0/adaptivepy/clustering/__init__.py +103 -0
- adaptivepy_sampling-0.1.0/adaptivepy/clustering/base.py +73 -0
- adaptivepy_sampling-0.1.0/adaptivepy/clustering/regular_space.py +135 -0
- adaptivepy_sampling-0.1.0/adaptivepy/clustering/sklearn_kmeans.py +93 -0
- adaptivepy_sampling-0.1.0/adaptivepy/clustering/sklearn_minibatch.py +94 -0
- adaptivepy_sampling-0.1.0/adaptivepy/config/__init__.py +17 -0
- adaptivepy_sampling-0.1.0/adaptivepy/config/schema.py +196 -0
- adaptivepy_sampling-0.1.0/adaptivepy/io/__init__.py +27 -0
- adaptivepy_sampling-0.1.0/adaptivepy/io/loader.py +267 -0
- adaptivepy_sampling-0.1.0/adaptivepy/io/trajectory.py +151 -0
- adaptivepy_sampling-0.1.0/adaptivepy/models.py +83 -0
- adaptivepy_sampling-0.1.0/adaptivepy/output/__init__.py +23 -0
- adaptivepy_sampling-0.1.0/adaptivepy/output/pdb_writer.py +59 -0
- adaptivepy_sampling-0.1.0/adaptivepy/output/writer.py +229 -0
- adaptivepy_sampling-0.1.0/adaptivepy/policies/__init__.py +21 -0
- adaptivepy_sampling-0.1.0/adaptivepy/policies/base.py +105 -0
- adaptivepy_sampling-0.1.0/adaptivepy/policies/least_counts.py +43 -0
- adaptivepy_sampling-0.1.0/adaptivepy/policies/random.py +53 -0
- adaptivepy_sampling-0.1.0/adaptivepy/selection/__init__.py +5 -0
- adaptivepy_sampling-0.1.0/adaptivepy/selection/frame_selector.py +132 -0
- adaptivepy_sampling-0.1.0/adaptivepy/stats/__init__.py +15 -0
- adaptivepy_sampling-0.1.0/adaptivepy/stats/cluster_stats.py +118 -0
- adaptivepy_sampling-0.1.0/adaptivepy/utils/__init__.py +6 -0
- adaptivepy_sampling-0.1.0/adaptivepy/utils/io_utils.py +49 -0
- adaptivepy_sampling-0.1.0/adaptivepy/utils/logging.py +55 -0
- adaptivepy_sampling-0.1.0/adaptivepy_sampling.egg-info/PKG-INFO +52 -0
- adaptivepy_sampling-0.1.0/adaptivepy_sampling.egg-info/SOURCES.txt +38 -0
- adaptivepy_sampling-0.1.0/adaptivepy_sampling.egg-info/dependency_links.txt +1 -0
- adaptivepy_sampling-0.1.0/adaptivepy_sampling.egg-info/entry_points.txt +2 -0
- adaptivepy_sampling-0.1.0/adaptivepy_sampling.egg-info/requires.txt +9 -0
- adaptivepy_sampling-0.1.0/adaptivepy_sampling.egg-info/top_level.txt +1 -0
- adaptivepy_sampling-0.1.0/pyproject.toml +30 -0
- adaptivepy_sampling-0.1.0/setup.cfg +4 -0
- adaptivepy_sampling-0.1.0/tests/test_adaptivepy.py +77 -0
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: adaptivepy-sampling
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: Adaptive sampling on MD trajectories via clustering and policy-driven seed selection
|
|
5
|
+
Author: AdaptivePy Contributors
|
|
6
|
+
License: MIT
|
|
7
|
+
Requires-Python: >=3.9
|
|
8
|
+
Description-Content-Type: text/markdown
|
|
9
|
+
Requires-Dist: numpy>=1.20
|
|
10
|
+
Requires-Dist: scikit-learn>=1.0
|
|
11
|
+
Requires-Dist: pyyaml>=6.0
|
|
12
|
+
Requires-Dist: click>=8.0
|
|
13
|
+
Requires-Dist: joblib>=1.0
|
|
14
|
+
Requires-Dist: mdtraj>=1.9
|
|
15
|
+
Provides-Extra: dev
|
|
16
|
+
Requires-Dist: pytest>=7.0; extra == "dev"
|
|
17
|
+
|
|
18
|
+
# AdaptivePy
|
|
19
|
+
|
|
20
|
+
Adaptive sampling on molecular dynamics trajectories using clustering-based state space partitioning and policy-driven seed selection.
|
|
21
|
+
|
|
22
|
+
## Installation
|
|
23
|
+
|
|
24
|
+
```bash
|
|
25
|
+
pip install -e .
|
|
26
|
+
```
|
|
27
|
+
|
|
28
|
+
## Quick start
|
|
29
|
+
|
|
30
|
+
1. Prepare feature files (`features/traj_0.npy`, ...) with shape `(n_frames, n_features)`.
|
|
31
|
+
2. Optionally add matching coordinate trajectories (`trajectories/traj_0.xtc`, ...) and a topology file.
|
|
32
|
+
3. Edit `examples/config.yaml` and run:
|
|
33
|
+
|
|
34
|
+
```bash
|
|
35
|
+
adaptivepy run examples/config.yaml
|
|
36
|
+
```
|
|
37
|
+
|
|
38
|
+
## CLI
|
|
39
|
+
|
|
40
|
+
```bash
|
|
41
|
+
adaptivepy run config.yaml
|
|
42
|
+
adaptivepy validate config.yaml
|
|
43
|
+
adaptivepy list-policies
|
|
44
|
+
```
|
|
45
|
+
|
|
46
|
+
## Python API
|
|
47
|
+
|
|
48
|
+
```python
|
|
49
|
+
from adaptivepy import run_adaptive_sampling
|
|
50
|
+
|
|
51
|
+
results = run_adaptive_sampling("config.yaml")
|
|
52
|
+
```
|
|
@@ -0,0 +1,35 @@
|
|
|
1
|
+
# AdaptivePy
|
|
2
|
+
|
|
3
|
+
Adaptive sampling on molecular dynamics trajectories using clustering-based state space partitioning and policy-driven seed selection.
|
|
4
|
+
|
|
5
|
+
## Installation
|
|
6
|
+
|
|
7
|
+
```bash
|
|
8
|
+
pip install -e .
|
|
9
|
+
```
|
|
10
|
+
|
|
11
|
+
## Quick start
|
|
12
|
+
|
|
13
|
+
1. Prepare feature files (`features/traj_0.npy`, ...) with shape `(n_frames, n_features)`.
|
|
14
|
+
2. Optionally add matching coordinate trajectories (`trajectories/traj_0.xtc`, ...) and a topology file.
|
|
15
|
+
3. Edit `examples/config.yaml` and run:
|
|
16
|
+
|
|
17
|
+
```bash
|
|
18
|
+
adaptivepy run examples/config.yaml
|
|
19
|
+
```
|
|
20
|
+
|
|
21
|
+
## CLI
|
|
22
|
+
|
|
23
|
+
```bash
|
|
24
|
+
adaptivepy run config.yaml
|
|
25
|
+
adaptivepy validate config.yaml
|
|
26
|
+
adaptivepy list-policies
|
|
27
|
+
```
|
|
28
|
+
|
|
29
|
+
## Python API
|
|
30
|
+
|
|
31
|
+
```python
|
|
32
|
+
from adaptivepy import run_adaptive_sampling
|
|
33
|
+
|
|
34
|
+
results = run_adaptive_sampling("config.yaml")
|
|
35
|
+
```
|
|
@@ -0,0 +1,229 @@
|
|
|
1
|
+
"""High-level API orchestrating the adaptive sampling workflow."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import logging
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Dict, List, Optional
|
|
8
|
+
|
|
9
|
+
import numpy as np
|
|
10
|
+
|
|
11
|
+
from adaptivepy.clustering import create_clusterer, fit_clusterer
|
|
12
|
+
from adaptivepy.config.schema import RunConfig, load_config
|
|
13
|
+
from adaptivepy.io.loader import (
|
|
14
|
+
list_trajectory_files,
|
|
15
|
+
load_features,
|
|
16
|
+
validate_dataset,
|
|
17
|
+
validate_feature_trajectory_mapping,
|
|
18
|
+
)
|
|
19
|
+
from adaptivepy.io.trajectory import (
|
|
20
|
+
build_trajectory_map,
|
|
21
|
+
validate_trajectory_frame_counts,
|
|
22
|
+
)
|
|
23
|
+
from adaptivepy.models import SeedResult
|
|
24
|
+
from adaptivepy.output.pdb_writer import write_seed_pdbs
|
|
25
|
+
from adaptivepy.output.writer import (
|
|
26
|
+
write_assignments,
|
|
27
|
+
write_cluster_model,
|
|
28
|
+
write_cluster_statistics,
|
|
29
|
+
write_combined_metadata,
|
|
30
|
+
write_policy_outputs,
|
|
31
|
+
write_run_config,
|
|
32
|
+
)
|
|
33
|
+
from adaptivepy.policies import get_policy
|
|
34
|
+
from adaptivepy.selection.frame_selector import select_seeds
|
|
35
|
+
from adaptivepy.stats.cluster_stats import assign_clusters, compute_cluster_stats
|
|
36
|
+
from adaptivepy.utils.io_utils import ensure_dir
|
|
37
|
+
from adaptivepy.utils.logging import setup_logger
|
|
38
|
+
|
|
39
|
+
logger = logging.getLogger(__name__)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def run_adaptive_sampling(
    config_path: str | Path,
    config: Optional[RunConfig] = None,
) -> Dict[str, List[SeedResult]]:
    """Run the complete adaptive sampling pipeline described by a configuration.

    The steps are executed in a fixed order:

    1. Resolve the configuration, create the output directory and attach a
       log file (``logs.txt`` inside the output directory).
    2. Seed NumPy's global random state with ``config.random_seed``.
    3. Load the feature dataset; when a trajectories directory (and a
       topology) is configured, cross-check it against the features.
    4. Fit the configured clusterer, label every frame and compute the
       per-cluster statistics.
    5. Write the run-wide artifacts (config, assignments, model, statistics).
    6. For each configured policy: choose clusters, choose seed frames, write
       the policy outputs and, when enabled and a trajectory map is
       available, the seed PDB files.
    7. Write combined metadata when more than one policy produced seeds.

    Parameters
    ----------
    config_path : str or Path
        Location of the YAML configuration file. It is always used for
        logging and passed to ``write_run_config``, even when ``config`` is
        supplied.
    config : RunConfig or None
        Already-parsed configuration. When ``None`` the configuration is
        read from ``config_path``.

    Returns
    -------
    dict
        Policy name mapped to the list of :class:`SeedResult` selected by
        that policy.

    Raises
    ------
    FileNotFoundError
        If required input paths do not exist.
    ValueError
        If validation checks fail.
    """
    config_path = Path(config_path)
    cfg = load_config(config_path) if config is None else config

    output_dir = ensure_dir(cfg.output_dir)
    setup_logger("adaptivepy", log_file=output_dir / "logs.txt")
    logger.info("Starting AdaptivePy run with config %s", config_path)

    np.random.seed(cfg.random_seed)

    # --- Input loading and validation ---
    feature_files = sorted(Path(cfg.features_dir).glob("*.npy"))
    trajectory_files: Optional[List[Path]] = None
    trajectory_map: Optional[Dict[int, Path]] = None

    has_trajectories = cfg.trajectories_dir is not None
    if has_trajectories:
        trajectory_files = list_trajectory_files(cfg.trajectories_dir)
        validate_feature_trajectory_mapping(feature_files, trajectory_files)

    dataset = load_features(cfg.features_dir)
    validate_dataset(dataset, trajectory_files)

    # Frame counts can only be checked when both coordinates and a topology
    # are configured.
    if has_trajectories and cfg.topology is not None:
        trajectory_map = build_trajectory_map(cfg.trajectories_dir, dataset.traj_names)
        expected_counts = {}
        for traj_id, (start, end) in dataset.traj_index_map.items():
            expected_counts[traj_id] = end - start
        validate_trajectory_frame_counts(cfg.topology, trajectory_map, expected_counts)

    # --- Clustering ---
    clusterer = create_clusterer(
        method=cfg.clustering.method,
        n_clusters=cfg.clustering.n_clusters,
        random_state=cfg.random_seed,
        params=cfg.clustering.params,
    )
    fit_clusterer(clusterer, dataset.feature_matrix)
    labels = clusterer.predict(dataset.feature_matrix)
    assign_clusters(dataset, labels)

    cluster_stats = compute_cluster_stats(dataset)
    centers = clusterer.cluster_centers_

    n_clusters_found = len(cluster_stats)
    n_frames_total = len(dataset.frames)
    logger.info(
        "Clustering complete: %d clusters, %d total frames",
        n_clusters_found,
        n_frames_total,
    )

    # --- Run-wide artifacts ---
    write_run_config(cfg, output_dir, config_path)
    write_assignments(labels, output_dir)
    write_cluster_model(clusterer.model, output_dir)
    write_cluster_statistics(cluster_stats, output_dir)

    # --- Per-policy seed selection ---
    policy_seeds: Dict[str, List[SeedResult]] = {}
    for policy_name in cfg.policies:
        logger.info("Applying policy: %s", policy_name)

        # Only the "random" policy is given the run's seed.
        policy_kwargs = (
            {"random_state": cfg.random_seed} if policy_name == "random" else {}
        )
        policy = get_policy(policy_name, **policy_kwargs)

        selected_clusters = policy.select_clusters(cluster_stats, cfg.n_seeds)
        seeds = select_seeds(
            policy_name=policy_name,
            selected_clusters=selected_clusters,
            cluster_stats=cluster_stats,
            cluster_centers=centers,
            method=cfg.seed_selection.method,
            random_state=cfg.random_seed,
        )
        policy_seeds[policy_name] = seeds

        policy_dir = write_policy_outputs(policy_name, seeds, cluster_stats, output_dir)

        if cfg.write_pdbs and trajectory_map is not None and cfg.topology is not None:
            write_seed_pdbs(seeds, cfg.topology, trajectory_map, policy_dir)

        logger.info("Policy %s selected %d seeds", policy_name, len(seeds))

    if len(policy_seeds) > 1:
        write_combined_metadata(policy_seeds, output_dir)

    logger.info("AdaptivePy run finished. Results written to %s", output_dir)
    return policy_seeds
|
|
179
|
+
|
|
180
|
+
|
|
181
|
+
def validate_config(config_path: str | Path) -> RunConfig:
    """Check a configuration file and its input data without clustering.

    The configuration is parsed, the feature directory must contain at least
    one ``*.npy`` file, the features are loaded and validated, the optional
    trajectories are cross-checked against them, and every configured policy
    name is looked up via ``get_policy``.

    Parameters
    ----------
    config_path : str or Path
        Location of the YAML configuration file.

    Returns
    -------
    RunConfig
        The parsed configuration, returned only when every check passed.

    Raises
    ------
    ValueError
        If the feature directory holds no ``*.npy`` files. The validation
        helpers called here may raise further errors of their own.
    """
    cfg = load_config(Path(config_path))

    feature_files = sorted(Path(cfg.features_dir).glob("*.npy"))
    if len(feature_files) == 0:
        raise ValueError(f"No feature files in {cfg.features_dir}")

    has_trajectories = cfg.trajectories_dir is not None
    trajectory_files = None
    if has_trajectories:
        trajectory_files = list_trajectory_files(cfg.trajectories_dir)
        validate_feature_trajectory_mapping(feature_files, trajectory_files)

    dataset = load_features(cfg.features_dir)
    validate_dataset(dataset, trajectory_files)

    # Frame counts can only be compared when a topology is available to read
    # the coordinate trajectories.
    if has_trajectories and cfg.topology is not None:
        trajectory_map = build_trajectory_map(cfg.trajectories_dir, dataset.traj_names)
        expected_counts = {}
        for traj_id, (start, end) in dataset.traj_index_map.items():
            expected_counts[traj_id] = end - start
        validate_trajectory_frame_counts(cfg.topology, trajectory_map, expected_counts)

    # The returned policy objects are discarded; the lookup itself is the check.
    for policy_name in cfg.policies:
        get_policy(policy_name)

    return cfg
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
"""Command-line interface for AdaptivePy."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import sys
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
|
|
8
|
+
import click
|
|
9
|
+
|
|
10
|
+
from adaptivepy import __version__
|
|
11
|
+
from adaptivepy.api import run_adaptive_sampling, validate_config
|
|
12
|
+
from adaptivepy.policies import list_policies
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
# Root command group of the CLI; subcommands in this module attach to it with
# ``@main.command(...)``. The docstring below is shown by Click as help text.
@click.group()
@click.version_option(version=__version__, prog_name="adaptivepy")
def main() -> None:
    """AdaptivePy: clustering-based adaptive sampling for MD trajectories."""
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
@main.command("run")
@click.argument("config", type=click.Path(exists=True, path_type=Path))
def run_cmd(config: Path) -> None:
    """Run adaptive sampling from a YAML configuration file.
    \f

    Parameters
    ----------
    config : Path
        Path to the YAML configuration file.
    """
    # Click renders this docstring as the command's ``--help`` text. The form
    # feed after the summary tells Click to stop there, so the developer-facing
    # "Parameters" section no longer leaks into the user-facing help output.
    try:
        run_adaptive_sampling(config)
    except Exception as exc:
        # Top-level CLI boundary: report any failure on stderr and exit with a
        # non-zero status instead of showing a traceback.
        click.echo(f"Error: {exc}", err=True)
        sys.exit(1)
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
@main.command("validate")
@click.argument("config", type=click.Path(exists=True, path_type=Path))
def validate_cmd(config: Path) -> None:
    """Validate configuration and input data without running clustering.
    \f

    Parameters
    ----------
    config : Path
        Path to the YAML configuration file.
    """
    # Click renders this docstring as the command's ``--help`` text. The form
    # feed after the summary tells Click to stop there, so the developer-facing
    # "Parameters" section no longer leaks into the user-facing help output.
    try:
        validate_config(config)
        click.echo(f"Configuration valid: {config}")
    except Exception as exc:
        # Top-level CLI boundary: report any failure on stderr and exit with a
        # non-zero status instead of showing a traceback.
        click.echo(f"Validation failed: {exc}", err=True)
        sys.exit(1)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
@main.command("list-policies")
def list_policies_cmd() -> None:
    """List all registered adaptive sampling policies."""
    registered = list_policies()
    if registered:
        # One policy name per output line.
        for policy_name in registered:
            click.echo(policy_name)
    else:
        click.echo("No policies registered.")
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
# Invoke the Click command group when this module is executed as a script.
if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,103 @@
|
|
|
1
|
+
"""Clustering backend factory and registry."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from typing import Any, Dict, Type
|
|
6
|
+
|
|
7
|
+
from adaptivepy.clustering.base import Clusterer
|
|
8
|
+
from adaptivepy.clustering.regular_space import SklearnRegularSpaceClusterer
|
|
9
|
+
from adaptivepy.clustering.sklearn_kmeans import SklearnKMeansClusterer
|
|
10
|
+
from adaptivepy.clustering.sklearn_minibatch import SklearnMiniBatchClusterer
|
|
11
|
+
|
|
12
|
+
# Maps the clustering method name used in configuration to the backend class
# that implements it. ``create_clusterer`` resolves methods through this table.
CLUSTERER_REGISTRY: Dict[str, Type[Clusterer]] = dict(
    kmeans=SklearnKMeansClusterer,
    minibatch_kmeans=SklearnMiniBatchClusterer,
    regular_space=SklearnRegularSpaceClusterer,
)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
def create_clusterer(
    method: str,
    n_clusters: int,
    random_state: int | None = None,
    params: Dict[str, Any] | None = None,
) -> Clusterer:
    """Build an unfitted clustering backend from its registered name.

    Parameters
    ----------
    method : str
        Key into ``CLUSTERER_REGISTRY`` (``kmeans``, ``minibatch_kmeans``
        or ``regular_space``).
    n_clusters : int
        Number of clusters requested. For ``regular_space`` it serves as the
        fallback for ``max_clusters`` when that key is absent from ``params``.
    random_state : int or None
        Seed forwarded to the backend.
    params : dict or None
        Extra backend-specific keyword arguments. The mapping is copied, so
        the caller's dictionary is never modified.

    Returns
    -------
    Clusterer
        A new, unfitted clusterer.

    Raises
    ------
    ValueError
        If ``method`` is not registered, or if ``regular_space`` is requested
        without ``min_dist`` in ``params``.
    """
    if method not in CLUSTERER_REGISTRY:
        available = ", ".join(sorted(CLUSTERER_REGISTRY))
        raise ValueError(f"Unknown clustering method '{method}'. Available: {available}")

    # Work on a private copy so that popping keys below cannot affect the caller.
    extra = dict(params) if params else {}
    clusterer_cls = CLUSTERER_REGISTRY[method]

    if method != "regular_space":
        return clusterer_cls(n_clusters=n_clusters, random_state=random_state, **extra)

    # Regular-space clustering is driven by a distance cutoff rather than a
    # cluster count, so it takes a different set of constructor arguments.
    if "min_dist" not in extra:
        raise ValueError(
            "regular_space clustering requires 'min_dist' in clustering.params."
        )

    max_clusters = extra.pop("max_clusters", n_clusters)
    min_dist = float(extra.pop("min_dist"))
    # A falsy limit (``None`` or ``0``) is passed on as ``None``.
    cluster_cap = int(max_clusters) if max_clusters else None
    return SklearnRegularSpaceClusterer(
        min_dist=min_dist,
        max_clusters=cluster_cap,
        random_state=random_state,
        **extra,
    )
|
|
75
|
+
|
|
76
|
+
|
|
77
|
+
def fit_clusterer(clusterer: Clusterer, X) -> Clusterer:
    """Fit ``clusterer`` on ``X`` and hand back whatever ``fit`` returns.

    Parameters
    ----------
    clusterer : Clusterer
        Clusterer to fit.
    X : np.ndarray
        Feature matrix passed straight through to ``clusterer.fit``.

    Returns
    -------
    Clusterer
        The return value of ``clusterer.fit(X)``, allowing call chaining.
    """
    fitted = clusterer.fit(X)
    return fitted
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
# Names exported by ``from adaptivepy.clustering import *``.
__all__ = [
    "CLUSTERER_REGISTRY",
    "Clusterer",
    "SklearnKMeansClusterer",
    "SklearnMiniBatchClusterer",
    "SklearnRegularSpaceClusterer",
    "create_clusterer",
    "fit_clusterer",
]
|
|
@@ -0,0 +1,73 @@
|
|
|
1
|
+
"""Abstract base class for clustering backends."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
from abc import ABC, abstractmethod
|
|
6
|
+
from typing import Any, Optional
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class Clusterer(ABC):
    """Common contract every feature-space clustering backend must satisfy.

    Concrete subclasses wrap a scikit-learn estimator or a custom algorithm
    and must provide ``fit``, ``predict``, ``cluster_centers_`` and ``model``;
    the class cannot be instantiated until all four are implemented.
    """

    @abstractmethod
    def fit(self, X: np.ndarray) -> "Clusterer":
        """Learn the clustering from the given features.

        Parameters
        ----------
        X : np.ndarray
            Feature matrix with shape ``(n_samples, n_features)``.

        Returns
        -------
        Clusterer
            ``self``, once fitted.
        """

    @abstractmethod
    def predict(self, X: np.ndarray) -> np.ndarray:
        """Compute a cluster label for every row of ``X``.

        Parameters
        ----------
        X : np.ndarray
            Feature matrix with shape ``(n_samples, n_features)``.

        Returns
        -------
        np.ndarray
            Integer labels with shape ``(n_samples,)``.
        """

    @property
    @abstractmethod
    def cluster_centers_(self) -> Optional[np.ndarray]:
        """Cluster centroids, when the backend defines them.

        Returns
        -------
        np.ndarray or None
            Array with shape ``(n_clusters, n_features)``; ``None`` for
            methods without explicit centers.
        """

    @property
    @abstractmethod
    def model(self) -> Any:
        """Underlying fitted model object, exposed for serialization.

        Returns
        -------
        object
            The backend-specific model instance.
        """
|