edmkit-search 0.0.1a1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24) hide show
  1. edmkit_search-0.0.1a1/LICENSE +21 -0
  2. edmkit_search-0.0.1a1/PKG-INFO +70 -0
  3. edmkit_search-0.0.1a1/README.md +46 -0
  4. edmkit_search-0.0.1a1/pyproject.toml +48 -0
  5. edmkit_search-0.0.1a1/src/edmkit/search/__init__.py +3 -0
  6. edmkit_search-0.0.1a1/src/edmkit/search/dataset/__init__.py +5 -0
  7. edmkit_search-0.0.1a1/src/edmkit/search/dataset/containers.py +90 -0
  8. edmkit_search-0.0.1a1/src/edmkit/search/dataset/transforms.py +110 -0
  9. edmkit_search-0.0.1a1/src/edmkit/search/energy/__init__.py +16 -0
  10. edmkit_search-0.0.1a1/src/edmkit/search/energy/energy.py +28 -0
  11. edmkit_search-0.0.1a1/src/edmkit/search/energy/folds.py +64 -0
  12. edmkit_search-0.0.1a1/src/edmkit/search/energy/holdout.py +47 -0
  13. edmkit_search-0.0.1a1/src/edmkit/search/energy/loo.py +46 -0
  14. edmkit_search-0.0.1a1/src/edmkit/search/energy/weight.py +30 -0
  15. edmkit_search-0.0.1a1/src/edmkit/search/neighborhood/__init__.py +4 -0
  16. edmkit_search-0.0.1a1/src/edmkit/search/neighborhood/forward.py +45 -0
  17. edmkit_search-0.0.1a1/src/edmkit/search/neighborhood/neighborhood.py +23 -0
  18. edmkit_search-0.0.1a1/src/edmkit/search/state/__init__.py +3 -0
  19. edmkit_search-0.0.1a1/src/edmkit/search/state/states.py +13 -0
  20. edmkit_search-0.0.1a1/src/edmkit/search/strategy/__init__.py +6 -0
  21. edmkit_search-0.0.1a1/src/edmkit/search/strategy/beam.py +37 -0
  22. edmkit_search-0.0.1a1/src/edmkit/search/strategy/frontier.py +23 -0
  23. edmkit_search-0.0.1a1/src/edmkit/search/strategy/greedy.py +13 -0
  24. edmkit_search-0.0.1a1/src/edmkit/search/strategy/run.py +28 -0
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2026 temma
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,70 @@
1
+ Metadata-Version: 2.4
2
+ Name: edmkit-search
3
+ Version: 0.0.1a1
4
+ Summary: Trajectory construction on an energy landscape for edmkit
5
+ Keywords: edm,empirical-dynamic-modeling,feature-selection,search,time-series
6
+ Author: FUJISHIGE TEMMA
7
+ Author-email: FUJISHIGE TEMMA <tenma.x0@gmail.com>
8
+ License-Expression: MIT
9
+ License-File: LICENSE
10
+ Classifier: Development Status :: 3 - Alpha
11
+ Classifier: Intended Audience :: Science/Research
12
+ Classifier: Operating System :: OS Independent
13
+ Classifier: Programming Language :: Python :: 3
14
+ Classifier: Programming Language :: Python :: 3.13
15
+ Classifier: Topic :: Scientific/Engineering
16
+ Classifier: Typing :: Typed
17
+ Requires-Dist: edmkit>=0.0.7
18
+ Requires-Dist: numpy>=2.4.4
19
+ Requires-Python: >=3.13
20
+ Project-URL: Homepage, https://github.com/FujishigeTemma/edmkit-search
21
+ Project-URL: Repository, https://github.com/FujishigeTemma/edmkit-search
22
+ Project-URL: Issues, https://github.com/FujishigeTemma/edmkit-search/issues
23
+ Description-Content-Type: text/markdown
24
+
25
+ # edmkit-search
26
+
27
+ Trajectory construction on an energy landscape, packaged as a library on top
28
+ of [`edmkit`](https://github.com/FujishigeTemma/edmkit).
29
+
30
+ ## Install
31
+
32
+ ```bash
33
+ pip install edmkit-search
34
+ # or
35
+ uv add edmkit-search
36
+ ```
37
+
38
+ ## Overview
39
+
40
+ The search loop builds a trajectory step by step:
41
+
42
+ 1. From the current frontier (batch of `(state, context)`),
43
+ 2. expand neighbors via a `Neighborhood`,
44
+ 3. score children with an `Energy` to obtain `(energies, contexts)`,
45
+ 4. and pick the next frontier with a `Strategy`.
46
+
47
+ See `CODING.md` in the project repository for the design rules. The subpackages mirror the four
48
+ abstractions: `energy/`, `neighborhood/`, `state/`, `strategy/`, plus
49
+ `dataset/` for the input containers.
50
+
51
+ ## Quick start
52
+
53
+ ```bash
54
+ uv sync
55
+ PYTHON_GIL=0 uv run python e2e/synthetic.py
56
+ ```
57
+
58
+ `e2e/synthetic.py` generates a Lorenz-96 trajectory mixed with noise columns
59
+ and uses greedy forward selection to recover the informative subset.
60
+
61
+ ## Tests
62
+
63
+ ```bash
64
+ uv run pytest
65
+ ```
66
+
67
+ ## Experiments
68
+
69
+ Real-data experiments (fly, fish) and the analysis pipeline live in a
70
+ separate repository: [`edmkit-search-experiments`](https://github.com/FujishigeTemma/edmkit-search-experiments).
@@ -0,0 +1,46 @@
1
+ # edmkit-search
2
+
3
+ Trajectory construction on an energy landscape, packaged as a library on top
4
+ of [`edmkit`](https://github.com/FujishigeTemma/edmkit).
5
+
6
+ ## Install
7
+
8
+ ```bash
9
+ pip install edmkit-search
10
+ # or
11
+ uv add edmkit-search
12
+ ```
13
+
14
+ ## Overview
15
+
16
+ The search loop builds a trajectory step by step:
17
+
18
+ 1. From the current frontier (batch of `(state, context)`),
19
+ 2. expand neighbors via a `Neighborhood`,
20
+ 3. score children with an `Energy` to obtain `(energies, contexts)`,
21
+ 4. and pick the next frontier with a `Strategy`.
22
+
23
+ See `CODING.md` in the project repository for the design rules. The subpackages mirror the four
24
+ abstractions: `energy/`, `neighborhood/`, `state/`, `strategy/`, plus
25
+ `dataset/` for the input containers.
26
+
27
+ ## Quick start
28
+
29
+ ```bash
30
+ uv sync
31
+ PYTHON_GIL=0 uv run python e2e/synthetic.py
32
+ ```
33
+
34
+ `e2e/synthetic.py` generates a Lorenz-96 trajectory mixed with noise columns
35
+ and uses greedy forward selection to recover the informative subset.
36
+
37
+ ## Tests
38
+
39
+ ```bash
40
+ uv run pytest
41
+ ```
42
+
43
+ ## Experiments
44
+
45
+ Real-data experiments (fly, fish) and the analysis pipeline live in a
46
+ separate repository: [`edmkit-search-experiments`](https://github.com/FujishigeTemma/edmkit-search-experiments).
@@ -0,0 +1,48 @@
1
+ [project]
2
+ name = "edmkit-search"
3
+ version = "0.0.1a1"
4
+ description = "Trajectory construction on an energy landscape for edmkit"
5
+ authors = [{ name = "FUJISHIGE TEMMA", email = "tenma.x0@gmail.com" }]
6
+ license = "MIT"
7
+ license-files = ["LICENSE"]
8
+ readme = "README.md"
9
+ requires-python = ">=3.13"
10
+ keywords = ["edm", "empirical-dynamic-modeling", "feature-selection", "search", "time-series"]
11
+ classifiers = [
12
+ "Development Status :: 3 - Alpha",
13
+ "Intended Audience :: Science/Research",
14
+ "Operating System :: OS Independent",
15
+ "Programming Language :: Python :: 3",
16
+ "Programming Language :: Python :: 3.13",
17
+ "Topic :: Scientific/Engineering",
18
+ "Typing :: Typed",
19
+ ]
20
+ dependencies = [
21
+ "edmkit>=0.0.7",
22
+ "numpy>=2.4.4",
23
+ ]
24
+
25
+ [project.urls]
26
+ Homepage = "https://github.com/FujishigeTemma/edmkit-search"
27
+ Repository = "https://github.com/FujishigeTemma/edmkit-search"
28
+ Issues = "https://github.com/FujishigeTemma/edmkit-search/issues"
29
+
30
+ [dependency-groups]
31
+ dev = ["hypothesis>=6.152.5", "pytest>=9.0.3", "ruff>=0.15.12", "ty>=0.0.35"]
32
+
33
+ [tool.pytest.ini_options]
34
+ testpaths = ["tests"]
35
+ addopts = "-x --tb=short"
36
+
37
+ [build-system]
38
+ requires = ["uv_build>=0.10.9,<0.11.0"]
39
+ build-backend = "uv_build"
40
+
41
+ [tool.uv.build-backend]
42
+ module-name = "edmkit.search"
43
+
44
+ [tool.ruff]
45
+ exclude = [".claude"]
46
+
47
+ [tool.ty.src]
48
+ exclude = [".claude"]
@@ -0,0 +1,3 @@
1
"""Public namespace of edmkit-search: re-exports the five subpackages.

The search loop is assembled from these parts: `dataset` (input containers),
`neighborhood` (state expansion), `energy` (state scoring), `state` (the
index-batch representation), and `strategy` (frontier selection and the
driver loop).
"""

from . import dataset, energy, neighborhood, state, strategy

__all__ = ["dataset", "energy", "neighborhood", "state", "strategy"]
@@ -0,0 +1,5 @@
1
# ruff: noqa: F401
"""Dataset subpackage: input containers and `(x, y)` transforms.

Re-exports `Dataset`/`Subset` (time-series containers) and the transform
helpers (`Transform`, `compose`, `gaussian_noise`, `zscore_normalize`).
"""

from .containers import Dataset, Subset
from .transforms import Transform, compose, gaussian_noise, zscore_normalize
@@ -0,0 +1,90 @@
1
+ from functools import cached_property
2
+
3
+ import numpy as np
4
+
5
+ from .transforms import Transform
6
+
7
+
8
+ class Dataset:
9
+ """Time series dataset for X -> Y mapping.
10
+
11
+ Parameters
12
+ ----------
13
+ `X` : `np.ndarray` of shape `(T, D_x)`
14
+ Input time series.
15
+ `Y` : `np.ndarray` of shape `(T, D_y)` or `(T,)`
16
+ Output time series. 1D is auto-promoted to `(T, 1)`.
17
+ `transform` : :type: `Transform` or `None`, default `None`
18
+ `(x, y)` -> `(x', y')` closure for preprocessing / augmentation.
19
+
20
+ Raises
21
+ ------
22
+ `ValueError`
23
+ If `X` is not 2-dimensional, `Y` is not 1D or 2D,
24
+ or if `T` dimensions mismatch.
25
+ """
26
+
27
+ def __init__(
28
+ self,
29
+ X: np.ndarray,
30
+ Y: np.ndarray,
31
+ *,
32
+ transform: Transform | None = None,
33
+ ):
34
+ if X.ndim != 2:
35
+ raise ValueError(f"X must be 2-dimensional (T, D_x), got shape {X.shape}")
36
+ if Y.ndim == 1:
37
+ Y = Y[:, np.newaxis]
38
+ if Y.ndim != 2:
39
+ raise ValueError(f"Y must be 1D or 2D, got shape {Y.shape}")
40
+ if X.shape[0] != Y.shape[0]:
41
+ raise ValueError(
42
+ f"T mismatch: X has {X.shape[0]} timesteps, Y has {Y.shape[0]}"
43
+ )
44
+
45
+ self.X = X.astype(np.float32)
46
+ self.Y = Y.astype(np.float32)
47
+ self.transform = transform
48
+
49
+ def __len__(self) -> int:
50
+ return self.X.shape[0]
51
+
52
+ def __getitem__(self, idx: int) -> tuple[np.ndarray, np.ndarray]:
53
+ x = self.X[idx].copy()
54
+ y = self.Y[idx].copy()
55
+ if self.transform is not None:
56
+ x, y = self.transform(x, y)
57
+ return x, y
58
+
59
+
60
class Subset(Dataset):
    """A view into a `Dataset` selected by index, without copying data eagerly.

    This is a `Dataset` subtype, so row views can be passed anywhere a
    dataset is expected.

    Parameters
    ----------
    `dataset` : :type: `Dataset`
        The underlying dataset.
    `indices` : `np.ndarray`
        Indices into `dataset` to expose.
    """

    def __init__(self, dataset: Dataset, indices: np.ndarray):
        # Deliberately does not call Dataset.__init__: X and Y are lazy
        # gathers over the parent (see the cached properties below).
        self.dataset = dataset
        self.indices = indices
        # BUG FIX: mirror the parent's transform so a Subset exposes the full
        # Dataset attribute interface (previously `subset.transform` raised
        # AttributeError). __getitem__ still delegates to the parent, which
        # applies the transform exactly once.
        self.transform = dataset.transform

    @cached_property
    def X(self) -> np.ndarray:
        # Materialized (and cached) on first access only.
        return self.dataset.X[self.indices]

    @cached_property
    def Y(self) -> np.ndarray:
        return self.dataset.Y[self.indices]

    def __len__(self) -> int:
        return len(self.indices)

    def __getitem__(self, idx: int) -> tuple[np.ndarray, np.ndarray]:
        # Delegate so the parent's transform logic is reused.
        return self.dataset[self.indices[idx]]
@@ -0,0 +1,110 @@
1
+ from collections.abc import Callable
2
+ from functools import reduce
3
+ from typing import TypeAlias
4
+
5
+ import numpy as np
6
+
7
+ Transform: TypeAlias = Callable[[np.ndarray, np.ndarray], tuple[np.ndarray, np.ndarray]]
8
+ """Type alias for a function that takes `(x, y)` and returns `(x', y')`.
9
+ Typically a closure returned by a higher-order preprocessing or data-augmentation function.
10
+ """
11
+
12
+
13
+ def zscore_normalize(
14
+ data: np.ndarray,
15
+ *,
16
+ target: str,
17
+ ) -> Transform:
18
+ """Return a `Transform` that applies z-score normalization using statistics computed from `data`.
19
+
20
+ Parameters
21
+ ----------
22
+ `data` : `np.ndarray` of shape `(T, D)` or `(N, T, D)`
23
+ Data from which mean and std are computed.
24
+ `target` : `str`, default `"x"`
25
+ Which element of the `(x, y)` pair to normalize.
26
+ `"x"` normalizes input, `"y"` normalizes output, `"both"` normalizes both
27
+ (using the same statistics — only valid when `D_x == D_y`).
28
+
29
+ Returns
30
+ -------
31
+ :type: `Transform`
32
+ `(x, y)` -> `(x', y')` where the selected target(s) are normalized.
33
+
34
+ Examples
35
+ --------
36
+ ```python
37
+ zscore_x = zscore_normalize(X_train, target="x") # normalize input
38
+ zscore_y = zscore_normalize(Y_train, target="y") # normalize output
39
+ tf = compose(zscore_x, zscore_y) # both, independent stats
40
+ ```
41
+ """
42
+ if target not in ("x", "y", "both"):
43
+ raise ValueError(f"target must be 'x', 'y', or 'both', got {target!r}")
44
+
45
+ flat = data.reshape(-1, data.shape[-1]) # (N*T, D) or (T, D)
46
+ mean = flat.mean(axis=0).astype(np.float32) # (D,)
47
+ std = np.clip(flat.std(axis=0).astype(np.float32), 1e-8, None) # (D,)
48
+
49
+ def normalize(a: np.ndarray) -> np.ndarray:
50
+ return (a - mean) / std
51
+
52
+ if target == "x":
53
+
54
+ def transform(x: np.ndarray, y: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
55
+ return (normalize(x), y)
56
+ elif target == "y":
57
+
58
+ def transform(x: np.ndarray, y: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
59
+ return (x, normalize(y))
60
+ else:
61
+
62
+ def transform(x: np.ndarray, y: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
63
+ return (normalize(x), normalize(y))
64
+
65
+ return transform
66
+
67
+
68
+ def gaussian_noise(
69
+ sigma: float = 0.1,
70
+ rng: np.random.Generator | None = None,
71
+ ) -> Transform:
72
+ """Return a `Transform` that adds Gaussian noise to the input (for data augmentation).
73
+
74
+ Parameters
75
+ ----------
76
+ `sigma` : `float`, default `0.1`
77
+ Standard deviation of the noise.
78
+ `rng` : `np.random.Generator` or `None`, default `None`
79
+ Random number generator for reproducibility.
80
+ If `None`, a new unseeded generator is created.
81
+
82
+ Returns
83
+ -------
84
+ :type: `Transform`
85
+ `(x, y)` -> `(x + noise, y)`
86
+ """
87
+ rng = np.random.default_rng(rng)
88
+
89
+ def transform(x: np.ndarray, y: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
90
+ noise = rng.standard_normal(x.shape).astype(x.dtype) * sigma
91
+ return (x + noise, y)
92
+
93
+ return transform
94
+
95
+
96
def compose(*transforms: Transform) -> Transform:
    """Return a `Transform` that chains the given transforms left to right.

    With no arguments the result is the identity transform.

    Examples
    --------
    ```python
    transform = compose(zscore_normalize(X_train, target="x"), gaussian_noise(0.05))
    # zscore is applied first, then noise
    ```
    """

    def transform(x: np.ndarray, y: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
        result = (x, y)
        for fn in transforms:
            result = fn(*result)
        return result

    return transform
@@ -0,0 +1,16 @@
1
"""Energy subpackage: the `Energy` scoring interface and built-in schemes.

Re-exports the `Energy` dataclass with its `Energies`/`Contexts` aliases,
the evaluation constructors (`folds`, `holdout`, `loo`), and fold weighting
(`WeightFunc`, `softmax`).
"""

from edmkit.search.energy.energy import Contexts, Energies, Energy
from edmkit.search.energy.folds import folds
from edmkit.search.energy.holdout import holdout
from edmkit.search.energy.loo import loo
from edmkit.search.energy.weight import WeightFunc, softmax

__all__ = [
    "Contexts",
    "Energies",
    "Energy",
    "WeightFunc",
    "folds",
    "holdout",
    "loo",
    "softmax",
]
@@ -0,0 +1,28 @@
1
+ from collections.abc import Callable
2
+ from dataclasses import dataclass
3
+
4
+ import numpy as np
5
+ import numpy.typing as npt
6
+
7
+ from edmkit.search.state import States
8
+
9
type Energies = npt.NDArray[np.float64]
"""
Energies is a 1D array of energy values, one for each state in a batch.
"""
type Contexts = npt.NDArray[np.float64]
"""
Contexts is an opaque ndarray that can be used to store any additional information needed for the next energy computation.
"""


@dataclass(frozen=True)
class Energy:
    """Scoring interface for one search step.

    `step` receives a batch of candidate states together with each state's
    parent context row and returns one energy and one fresh context row per
    state (lower energy is better — strategies keep the smallest values).
    """

    initial: Callable[[], npt.NDArray[np.float64]]
    """
    initial returns the initial Contexts entry to feed into the first step call.
    """
    # step: (states, parent contexts) -> (energies, new contexts),
    # one row of each per state in the batch.
    step: Callable[
        [States, Contexts],
        tuple[Energies, Contexts],
    ]
@@ -0,0 +1,64 @@
1
+ from collections.abc import Sequence
2
+
3
+ import numpy as np
4
+ import numpy.typing as npt
5
+ from edmkit.metrics import MetricFunc
6
+ from edmkit.splits import Fold
7
+ from edmkit.types import PredictFunc
8
+
9
+ from edmkit.search import dataset
10
+ from edmkit.search.state import States
11
+
12
+ from .energy import Contexts, Energies, Energy
13
+ from .weight import WeightFunc
14
+
15
# Cap on how many candidate states are scored per predict() call, bounding
# the peak size of the (batch, T, d) embedding arrays.
BATCH_SIZE = 10000


def folds(
    *,
    data: dataset.Dataset,
    folds: Sequence[Fold],
    predict: PredictFunc,
    metric: MetricFunc,
    weight: WeightFunc,
) -> Energy:
    """Cross-validated `Energy`: every state is scored on every fold.

    Parameters
    ----------
    data : dataset whose rows are selected by each fold and whose columns
        are selected by each state.
    folds : non-empty sequence of train/validation index splits.
    predict : batched predictor mapping `(X, Y, Q)` to predictions for `Q`.
    metric : error metric comparing predictions against validation targets.
    weight : maps a state's parent per-fold metrics to per-fold weights.

    Raises
    ------
    ValueError
        If `folds` is empty.

    Notes
    -----
    The returned `Energy.step` emits the raw per-fold metrics as the new
    contexts; a state's energy is the fold-weighted change relative to its
    parent's metrics, `weight(parent) @ (current - parent)`.
    """
    if len(folds) == 0:
        raise ValueError("folds must be non-empty")

    n_folds = len(folds)
    # Materialize the (train, validation) row views once, outside step().
    subsets = [
        (dataset.Subset(data, f.train), dataset.Subset(data, f.validation))
        for f in folds
    ]

    def initial() -> npt.NDArray[np.float64]:
        # One context row for the root state: zero metric per fold.
        return np.zeros((1, n_folds), dtype=np.float64)

    def step(
        states: States,
        contexts: Contexts,
    ) -> tuple[Energies, Contexts]:
        n = states.shape[0]
        metrics = np.empty((n, n_folds), dtype=np.float64)
        for start in range(0, n, BATCH_SIZE):
            end = min(start + BATCH_SIZE, n)
            size = end - start
            idx = states[start:end].astype(np.intp)
            for i, (train, validation) in enumerate(subsets):
                # Gather each state's columns: (T, batch, d) -> (batch, T, d).
                X = np.ascontiguousarray(train.X[:, idx].transpose(1, 0, 2))
                Y = np.broadcast_to(train.Y, (size, *train.Y.shape))
                Q = np.ascontiguousarray(validation.X[:, idx].transpose(1, 0, 2))
                metrics[start:end, i] = metric(
                    predict(X, Y, Q),
                    np.broadcast_to(validation.Y, (size, *validation.Y.shape)),
                )

        # contexts[i] holds the parent's per-fold metrics for child i; the
        # energy is the weighted per-fold improvement over that parent.
        energies = np.array(
            [weight(contexts[i]) @ (metrics[i] - contexts[i]) for i in range(n)],
            dtype=np.float64,
        )

        return energies, metrics

    return Energy(initial=initial, step=step)
@@ -0,0 +1,47 @@
1
+ import numpy as np
2
+ import numpy.typing as npt
3
+ from edmkit.metrics import MetricFunc
4
+ from edmkit.splits import Fold
5
+ from edmkit.types import PredictFunc
6
+
7
+ from edmkit.search import dataset
8
+ from edmkit.search.state import States
9
+
10
+ from .energy import Contexts, Energies, Energy
11
+
12
# Cap on how many candidate states are scored per predict() call, bounding
# the peak size of the (batch, T, d) embedding arrays.
BATCH_SIZE = 10000


def holdout(
    *,
    data: dataset.Dataset,
    fold: Fold,
    predict: PredictFunc,
    metric: MetricFunc,
) -> Energy:
    """Single train/validation split `Energy`.

    Each state selects columns of `data`; its energy is the `metric` of
    `predict` trained on the fold's train rows and evaluated on its
    validation rows. Holdout keeps no cross-step context, so all context
    rows have width 0.
    """
    train = dataset.Subset(data, fold.train)
    validation = dataset.Subset(data, fold.validation)

    def initial() -> npt.NDArray[np.float64]:
        # One zero-width context row for the root state (stateless energy).
        return np.empty((1, 0), dtype=np.float64)

    def step(
        states: States,
        contexts: Contexts,  # not used
    ) -> tuple[Energies, Contexts]:
        n = states.shape[0]
        energies = np.empty(n, dtype=np.float64)
        for start in range(0, n, BATCH_SIZE):
            end = min(start + BATCH_SIZE, n)
            size = end - start
            idx = states[start:end].astype(np.intp)
            # Gather each state's columns: (T, batch, d) -> (batch, T, d).
            X = np.ascontiguousarray(train.X[:, idx].transpose(1, 0, 2))
            Y = np.broadcast_to(train.Y, (size, *train.Y.shape))
            Q = np.ascontiguousarray(validation.X[:, idx].transpose(1, 0, 2))
            energies[start:end] = metric(
                predict(X, Y, Q),
                np.broadcast_to(validation.Y, (size, *validation.Y.shape)),
            )
        return energies, np.empty((n, 0), dtype=np.float64)

    return Energy(initial=initial, step=step)
@@ -0,0 +1,46 @@
1
+ import numpy as np
2
+ import numpy.typing as npt
3
+ from edmkit.metrics import MetricFunc
4
+
5
+ from edmkit import simplex_projection
6
+ from edmkit.search import dataset
7
+ from edmkit.search.state import States
8
+
9
+ from .energy import Contexts, Energies, Energy
10
+
11
+ BATCH_SIZE = 10000
12
+
13
+
14
def loo(
    *,
    data: dataset.Dataset,
    metric: MetricFunc,
    theiler_window: int = 0,
    max_batch_size: int | None = None,
) -> Energy:
    """Leave-one-out `Energy` over the whole dataset.

    Each state selects columns of `data`; its energy is the `metric` of the
    leave-one-out simplex-projection predictions against `data.Y`. LOO keeps
    no cross-step context, so all context rows have width 0.

    Parameters
    ----------
    data : dataset whose columns are selected by each state.
    metric : error metric comparing LOO predictions against `data.Y`.
    theiler_window : non-negative exclusion window forwarded to
        `simplex_projection.loo`.
    max_batch_size : optional cap on how many states are scored per
        `simplex_projection.loo` call; defaults to `BATCH_SIZE`.

    Raises
    ------
    ValueError
        If `theiler_window` is negative or `max_batch_size` is not positive.
    """
    if theiler_window < 0:
        raise ValueError("theiler_window must be non-negative")
    # BUG FIX: `max_batch_size` was previously accepted but silently ignored;
    # it now overrides the module-level BATCH_SIZE default.
    if max_batch_size is not None and max_batch_size < 1:
        raise ValueError(f"max_batch_size must be positive, got {max_batch_size}")
    batch_size = BATCH_SIZE if max_batch_size is None else max_batch_size

    def initial() -> npt.NDArray[np.float64]:
        # One zero-width context row for the root state (stateless energy).
        return np.empty((1, 0), dtype=np.float64)

    def step(
        states: States,
        contexts: Contexts,
    ) -> tuple[Energies, Contexts]:
        del contexts  # stateless: the parent context is never consulted
        n = states.shape[0]
        energies = np.empty(n, dtype=np.float64)
        for start in range(0, n, batch_size):
            end = min(start + batch_size, n)
            size = end - start
            idx = states[start:end].astype(np.intp)
            # Gather each state's columns: (T, batch, d) -> (batch, T, d).
            X = np.ascontiguousarray(data.X[:, idx].transpose(1, 0, 2))
            Y = np.broadcast_to(data.Y, (size, *data.Y.shape))
            energies[start:end] = metric(
                simplex_projection.loo(X, Y, theiler_window=theiler_window),
                Y,
            )
        return energies, np.empty((n, 0), dtype=np.float64)

    return Energy(initial=initial, step=step)
@@ -0,0 +1,30 @@
1
+ from collections.abc import Callable
2
+
3
+ import numpy as np
4
+ import numpy.typing as npt
5
+
6
+ type WeightFunc = Callable[[npt.NDArray[np.float64]], npt.NDArray[np.float64]]
7
+
8
+
9
+ def normalized(values: npt.NDArray[np.float64]) -> npt.NDArray[np.float64]:
10
+ total = values.sum()
11
+ if not np.isfinite(total) or total <= 0:
12
+ raise ValueError("weights must have a positive finite sum")
13
+ return values / total
14
+
15
+
16
+ def softmax(temperature: float = 1.0) -> WeightFunc:
17
+ if temperature <= 0:
18
+ raise ValueError(f"temperature must be positive, got {temperature}")
19
+
20
+ def weight(values: npt.NDArray[np.float64]) -> npt.NDArray[np.float64]:
21
+ state = np.asarray(values, dtype=np.float64)
22
+ if state.ndim != 1:
23
+ raise ValueError(f"weight state must be 1D, got {state.ndim}D")
24
+ if state.size == 0:
25
+ return np.empty(0, dtype=np.float64)
26
+ logits = state / temperature
27
+ logits = logits - np.max(logits)
28
+ return normalized(np.exp(logits))
29
+
30
+ return weight
@@ -0,0 +1,4 @@
1
"""Neighborhood subpackage: state-expansion interfaces and implementations.

Re-exports the `Neighborhood` callable alias and the `forward` selection
expander.
"""

from edmkit.search.neighborhood.forward import forward
from edmkit.search.neighborhood.neighborhood import Neighborhood

__all__ = ["Neighborhood", "forward"]
@@ -0,0 +1,45 @@
1
+ import numpy as np
2
+ import numpy.typing as npt
3
+
4
+ from edmkit.search.state import States
5
+
6
+ from .neighborhood import Neighborhood
7
+
8
+
9
def forward(n: int) -> Neighborhood:
    """Forward selection over `n` candidate indices.

    A parent state of length `d` (assumed to hold unique indices in `[0, n)`)
    expands into exactly `n - d` children, one per index not yet selected.
    Children of a single parent are emitted in a per-row random order, while
    the relative order of parents is preserved.

    Raises `ValueError` if `n` is negative.
    """
    if n < 0:
        raise ValueError(f"n must be non-negative, got {n}")

    def expand(
        states: States, rng: np.random.Generator
    ) -> tuple[States, npt.NDArray[np.int64]]:
        n_parents, depth = states.shape
        fanout = n - depth
        if n_parents == 0 or fanout == 0:
            # Nothing to expand: an empty child batch at the next depth.
            return (
                np.empty((0, depth + 1), dtype=np.int64),
                np.empty(0, dtype=np.int64),
            )

        # One independent random permutation of [0, n) per parent row.
        shuffled = rng.permuted(
            np.tile(np.arange(n, dtype=np.int64), (n_parents, 1)), axis=1
        )
        # used[i, j] is True iff index j already appears in states[i].
        used = np.zeros((n_parents, n), dtype=bool)
        used[np.arange(n_parents)[:, None], states] = True
        # Keep each permutation's unused entries; every row keeps exactly
        # `fanout` of them because a parent's indices are unique.
        fresh = ~np.take_along_axis(used, shuffled, axis=1)
        picks = shuffled[fresh].reshape(n_parents, fanout)

        parents = np.repeat(np.arange(n_parents, dtype=np.int64), fanout)
        stems = np.repeat(states, fanout, axis=0)
        children = np.concatenate([stems, picks.reshape(-1, 1)], axis=1)
        return children, parents

    return expand
@@ -0,0 +1,23 @@
1
+ from collections.abc import Callable
2
+
3
+ import numpy as np
4
+ import numpy.typing as npt
5
+
6
+ from edmkit.search.state import States
7
+
8
type Neighborhood = Callable[
    [States, np.random.Generator],
    tuple[States, npt.NDArray[np.int64]],
]
"""
A Neighborhood expands a batch of N parent states into M children.

Given parents of shape (N, d), it returns:
* `children`: shape (M, d') — the next-step states
* `parents`: shape (M,) — `parents[i] in [0, N)` points to the parent row that
  produced `children[i]`

The `np.random.Generator` argument supplies any randomness a neighborhood
uses during expansion, keeping runs reproducible from a single seed.

`parents` lets callers (e.g. beam search) replicate parent-side data such as the
energy context in lockstep with the children, without the neighborhood needing to
know about that data.
"""
@@ -0,0 +1,3 @@
1
"""State subpackage: the batched index-array representation of search states."""

from edmkit.search.state.states import States, initial

__all__ = ["States", "initial"]
@@ -0,0 +1,13 @@
1
+ import numpy as np
2
+ import numpy.typing as npt
3
+
4
+ type States = npt.NDArray[np.int64]
5
+ """
6
+ States is a 2D array of shape (N, d): N states, each holding d indices into the original dataset.
7
+ A single step's batch always contains states of equal length d.
8
+ """
9
+
10
+
11
+ def initial() -> States:
12
+ """Initial states batch: a single empty state, shape (1, 0)."""
13
+ return np.empty((1, 0), dtype=np.int64)
@@ -0,0 +1,6 @@
1
"""Strategy subpackage: frontier selection policies and the driver loop.

Re-exports the `Frontier` batch container and `Step` alias, the `beam` and
`greedy` selection policies, and the `run` loop that drives a `Step`.
"""

from edmkit.search.strategy.beam import beam
from edmkit.search.strategy.frontier import Frontier, Step
from edmkit.search.strategy.greedy import greedy
from edmkit.search.strategy.run import run

__all__ = ["Frontier", "Step", "beam", "greedy", "run"]
@@ -0,0 +1,37 @@
1
+ import numpy as np
2
+
3
+ from edmkit.search.energy import Energy
4
+ from edmkit.search.neighborhood import Neighborhood
5
+
6
+ from .frontier import Frontier, Step
7
+
8
+
9
def beam(
    E: Energy,
    N: Neighborhood,
    *,
    width: int,
    cutoff: float = float("inf"),
) -> Step:
    """Beam-search step: expand, score, and keep the `width` lowest energies.

    Children whose energy exceeds `cutoff` are discarded before ranking
    (NaN energies never satisfy the cutoff, so they are dropped as well).
    Ties are broken stably, preserving expansion order.
    Raises `ValueError` if `width < 1`.
    """
    if width < 1:
        raise ValueError(f"width must be >= 1, got {width}")

    def step(frontier: Frontier, rng: np.random.Generator) -> Frontier:
        children, parents = N(frontier.states, rng)
        if children.shape[0] == 0:
            # Dead end: propagate an empty frontier with matching layouts.
            return Frontier(
                states=children,
                contexts=frontier.contexts[:0],
                energies=np.empty(0, dtype=np.float64),
            )

        # `parents` replicates each parent's context alongside its children.
        energies, contexts = E.step(children, frontier.contexts[parents])
        survivors = np.flatnonzero(energies <= cutoff)
        ranked = survivors[np.argsort(energies[survivors], kind="stable")]
        best = ranked[:width]
        return Frontier(
            states=children[best],
            contexts=contexts[best],
            energies=energies[best],
        )

    return step
@@ -0,0 +1,23 @@
1
+ from collections.abc import Callable
2
+ from dataclasses import dataclass
3
+
4
+ import numpy as np
5
+
6
+ from edmkit.search.energy import Contexts, Energies
7
+ from edmkit.search.state import States
8
+
9
+
10
@dataclass(frozen=True)
class Frontier:
    """One step's surviving batch: parallel arrays, one row per state.

    `states[i]`, `contexts[i]`, and `energies[i]` describe the same
    candidate (strategies construct all three from the same selection
    indices, so they share their leading dimension).
    """

    states: States
    contexts: Contexts
    energies: Energies

    def __len__(self) -> int:
        # Number of states in the batch; an empty frontier is falsy.
        return self.states.shape[0]


# A Step advances one frontier to the next (expand -> score -> select),
# drawing any randomness from the supplied generator.
type Step = Callable[
    [Frontier, np.random.Generator],
    Frontier,
]
@@ -0,0 +1,13 @@
1
+ from edmkit.search.energy import Energy
2
+ from edmkit.search.neighborhood import Neighborhood
3
+ from edmkit.search.strategy.beam import beam
4
+ from edmkit.search.strategy.frontier import Step
5
+
6
+
7
def greedy(
    E: Energy,
    N: Neighborhood,
    *,
    cutoff: float = float("inf"),
) -> Step:
    """Greedy forward step: a beam of width 1.

    Each step keeps only the single lowest-energy child whose energy is at
    most `cutoff`; see `beam` for the full selection semantics.
    """
    return beam(E, N, width=1, cutoff=cutoff)
@@ -0,0 +1,28 @@
1
+ from collections.abc import Iterator
2
+
3
+ import numpy as np
4
+
5
+ from .frontier import Frontier, Step
6
+
7
+
8
def run(
    initial: Frontier,
    step: Step,
    *,
    max_steps: int,
    rng: np.random.Generator,
) -> Iterator[Frontier]:
    """Drive `step` for up to `max_steps` iterations, yielding the best state.

    After each step, the single lowest-energy entry of the new frontier is
    yielded as a one-row `Frontier`; the full (unreduced) frontier is what
    feeds the next step. Iteration stops early when a step produces a falsy
    (empty) frontier. Raises `ValueError` if `max_steps` is negative.
    """
    if max_steps < 0:
        raise ValueError(f"max_steps must be non-negative, got {max_steps}")

    frontier = initial
    for _ in range(max_steps):
        frontier = step(frontier, rng)
        if not frontier:
            return
        best = int(np.argmin(frontier.energies))
        yield Frontier(
            states=frontier.states[best : best + 1],
            contexts=frontier.contexts[best : best + 1],
            energies=frontier.energies[best : best + 1],
        )