ennbo 0.0.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- ennbo-0.0.4/.cursorrules +54 -0
- ennbo-0.0.4/.cursorrules~ +51 -0
- ennbo-0.0.4/.pre-commit-config.yaml +34 -0
- ennbo-0.0.4/LICENSE +21 -0
- ennbo-0.0.4/PKG-INFO +52 -0
- ennbo-0.0.4/README.md +7 -0
- ennbo-0.0.4/admin/find_forgotten_py.sh +43 -0
- ennbo-0.0.4/examples/demo_enn.ipynb +133 -0
- ennbo-0.0.4/examples/demo_turbo_enn.ipynb +223 -0
- ennbo-0.0.4/pyproject.toml +39 -0
- ennbo-0.0.4/requirements.md +15 -0
- ennbo-0.0.4/requirements.txt~ +6 -0
- ennbo-0.0.4/src/enn/__init__.py +10 -0
- ennbo-0.0.4/src/enn/core.py +158 -0
- ennbo-0.0.4/src/enn/enn_normal.py +27 -0
- ennbo-0.0.4/src/enn/enn_params.py +9 -0
- ennbo-0.0.4/src/enn/fit.py +122 -0
- ennbo-0.0.4/src/enn/proposal.py +147 -0
- ennbo-0.0.4/src/enn/trust_region_state.py +99 -0
- ennbo-0.0.4/src/enn/turbo.py +11 -0
- ennbo-0.0.4/src/enn/turbo_gp.py +48 -0
- ennbo-0.0.4/src/enn/turbo_mode.py +10 -0
- ennbo-0.0.4/src/enn/turbo_optimizer.py +306 -0
- ennbo-0.0.4/src/enn/turbo_utils.py +226 -0
- ennbo-0.0.4/style.md +75 -0
- ennbo-0.0.4/tests/conftest.py +43 -0
- ennbo-0.0.4/tests/test_enn_core.py +149 -0
- ennbo-0.0.4/tests/test_enn_fit.py +39 -0
- ennbo-0.0.4/tests/test_turbo.py +753 -0
ennbo-0.0.4/.cursorrules
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
1
|
+
# FIRST RULE OF CURSOR - MANDATORY: READ THIS FIRST
|
|
2
|
+
MANDATORY FIRST ACTION: When the user sends their first request in a conversation, you MUST immediately read style.md using the read_file tool BEFORE performing any other action (no searches, no code reading, no tool calls of any kind). This is the absolute first step - do not answer the question, do not search, do not read other files. Read style.md first, then proceed with the user's request.
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
--
|
|
6
|
+
|
|
7
|
+
### Cursor Rule: Claims vs Hypotheses
|
|
8
|
+
|
|
9
|
+
- Label uncertain reasoning as Hypothesis; only use Claim with explicit evidence.
|
|
10
|
+
- Claims must cite evidence (code refs, logs, metrics). Otherwise, downgrade to Hypothesis.
|
|
11
|
+
- For each Hypothesis, include:
|
|
12
|
+
- Hypothesis: concise, falsifiable statement.
|
|
13
|
+
- Predictions: measurable outcomes if true.
|
|
14
|
+
- Test: minimal experiment (setup, variables, metrics, pass/fail).
|
|
15
|
+
- Confounders: likely alternatives and controls.
|
|
16
|
+
- Language:
|
|
17
|
+
- Hypothesis: “suggests”, “may”, “indicates”.
|
|
18
|
+
- Claim (with evidence): “shows”, “demonstrates”, “causes”.
|
|
19
|
+
- Trigger this rule when explaining performance differences, algorithm behavior, or proposing changes.
|
|
20
|
+
- Whenever you see an error happen two or more times in a row, stop, state the problem clearly, then hypothesize the cause. Then test to determine the cause. Once the cause is determined, fix the problem.
|
|
21
|
+
|
|
22
|
+
--
|
|
23
|
+
|
|
24
|
+
When running python code, you'll need to set PYTHONPATH
|
|
25
|
+
|
|
26
|
+
--
|
|
27
|
+
|
|
28
|
+
When you write code, always make sure `pytest -sv tests` and `ruff check` pass.
|
|
29
|
+
Iterate until they do.
|
|
30
|
+
|
|
31
|
+
--
|
|
32
|
+
|
|
33
|
+
# Debugging process rule
|
|
34
|
+
|
|
35
|
+
When debugging: run the broken code first to observe the actual failure, write a test that reproduces that exact failure, then fix it. Don't guess—observe empirically. Replicate the exact code path that fails in the test, not a simplified version. The broken code is the source of truth; observe it directly before fixing.
|
|
36
|
+
--
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
Acronyms we can use to make communication more efficient:
|
|
40
|
+
DRY: Don't Repeat Yourself. The user would like you to look for opportunities to factor out common code.
|
|
41
|
+
DCC: Don't Change Code. Respond to the user's query, but don't change *any* code.
|
|
42
|
+
DCT: Don't Change Tests. Respond to the user's query, but don't change any tests. You may change non-test code, though.
|
|
43
|
+
OCT: *Only* Change Tests. Respond to the user's query, but don't change any non-test code. You may change test code, though.
|
|
44
|
+
RR: Restate my request clearly. Lmk if you have any questions. DCC
|
|
45
|
+
|
|
46
|
+
--
|
|
47
|
+
|
|
48
|
+
# Cursor Roles
|
|
49
|
+
The Tester: The Tester writes unit tests but does not implement the code being tested. Tests for unimplemented code should be failing tests.
|
|
50
|
+
The Implementor: The Implementor writes code but does not write or change tests.
|
|
51
|
+
|
|
52
|
+
--
|
|
53
|
+
|
|
54
|
+
After every 5th request, `cat .cursorrules`
|
|
@@ -0,0 +1,51 @@
|
|
|
1
|
+
# FIRST RULE OF CURSOR - MANDATORY: READ THIS FIRST
|
|
2
|
+
MANDATORY FIRST ACTION: When the user sends their first request in a conversation, you MUST immediately read style.md using the read_file tool BEFORE performing any other action (no searches, no code reading, no tool calls of any kind). This is the absolute first step - do not answer the question, do not search, do not read other files. Read style.md first, then proceed with the user's request.
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
--
|
|
6
|
+
|
|
7
|
+
### Cursor Rule: Claims vs Hypotheses
|
|
8
|
+
|
|
9
|
+
- Label uncertain reasoning as Hypothesis; only use Claim with explicit evidence.
|
|
10
|
+
- Claims must cite evidence (code refs, logs, metrics). Otherwise, downgrade to Hypothesis.
|
|
11
|
+
- For each Hypothesis, include:
|
|
12
|
+
- Hypothesis: concise, falsifiable statement.
|
|
13
|
+
- Predictions: measurable outcomes if true.
|
|
14
|
+
- Test: minimal experiment (setup, variables, metrics, pass/fail).
|
|
15
|
+
- Confounders: likely alternatives and controls.
|
|
16
|
+
- Language:
|
|
17
|
+
- Hypothesis: “suggests”, “may”, “indicates”.
|
|
18
|
+
- Claim (with evidence): “shows”, “demonstrates”, “causes”.
|
|
19
|
+
- Trigger this rule when explaining performance differences, algorithm behavior, or proposing changes.
|
|
20
|
+
- Whenever you see an error happen two or more times in a row, stop, state the problem clearly, then hypothesize the cause. Then test to determine the cause. Once the cause is determined, fix the problem.
|
|
21
|
+
|
|
22
|
+
--
|
|
23
|
+
|
|
24
|
+
When running python code, you'll need to set PYTHONPATH
|
|
25
|
+
|
|
26
|
+
--
|
|
27
|
+
|
|
28
|
+
When you write code, always make sure `pytest -sv tests` and `ruff check` pass.
|
|
29
|
+
Iterate until they do.
|
|
30
|
+
|
|
31
|
+
--
|
|
32
|
+
|
|
33
|
+
# Debugging process rule
|
|
34
|
+
|
|
35
|
+
When debugging: run the broken code first to observe the actual failure, write a test that reproduces that exact failure, then fix it. Don't guess—observe empirically. Replicate the exact code path that fails in the test, not a simplified version. The broken code is the source of truth; observe it directly before fixing.
|
|
36
|
+
--
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
Acronyms we can use to make communication more efficient:
|
|
40
|
+
DRY: Don't Repeat Yourself. The user would like you to look for opportunities to factor out common code.
|
|
41
|
+
DCC: Don't Change Code. Respond to the user's query, but don't change *any* code.
|
|
42
|
+
DCT: Don't Change Tests. Respond to the user's query, but don't change any tests. You may change non-test code, though.
|
|
43
|
+
OCT: *Only* Change Tests. Respond to the user's query, but don't change any non-test code. You may change test code, though.
|
|
44
|
+
RR: Restate my request clearly. Lmk if you have any questions. DCC
|
|
45
|
+
|
|
46
|
+
--
|
|
47
|
+
|
|
48
|
+
# Cursor Roles
|
|
49
|
+
The Tester: The Tester writes unit tests but does not implement the code being tested. Tests for unimplemented code should be failing tests.
|
|
50
|
+
The Implementor: The Implementor writes code but does not write or change tests.
|
|
51
|
+
|
|
@@ -0,0 +1,34 @@
|
|
|
1
|
+
repos:
|
|
2
|
+
- repo: https://github.com/pre-commit/pre-commit-hooks
|
|
3
|
+
rev: v4.6.0
|
|
4
|
+
hooks:
|
|
5
|
+
- id: trailing-whitespace
|
|
6
|
+
- id: end-of-file-fixer
|
|
7
|
+
- id: check-yaml
|
|
8
|
+
- id: check-json
|
|
9
|
+
- id: check-added-large-files
|
|
10
|
+
args: ["--maxkb=500"]
|
|
11
|
+
|
|
12
|
+
# Forbid committing binary files and run custom checks
|
|
13
|
+
- repo: local
|
|
14
|
+
hooks:
|
|
15
|
+
- id: forbid-binary
|
|
16
|
+
name: Forbid committing binary files
|
|
17
|
+
entry: bash -c 'echo \"Binary files are not allowed in this repository.\"; exit 1'
|
|
18
|
+
language: system
|
|
19
|
+
types: [binary]
|
|
20
|
+
|
|
21
|
+
- id: find-forgotten-py
|
|
22
|
+
name: Find forgotten Python files via git status
|
|
23
|
+
entry: admin/find_forgotten_py.sh
|
|
24
|
+
language: system
|
|
25
|
+
pass_filenames: false
|
|
26
|
+
always_run: true
|
|
27
|
+
|
|
28
|
+
# Linting, formatting, and import sorting via Ruff
|
|
29
|
+
- repo: https://github.com/astral-sh/ruff-pre-commit
|
|
30
|
+
rev: v0.7.0
|
|
31
|
+
hooks:
|
|
32
|
+
- id: ruff
|
|
33
|
+
args: ["--fix"]
|
|
34
|
+
- id: ruff-format
|
ennbo-0.0.4/LICENSE
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025 yubo research
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
ennbo-0.0.4/PKG-INFO
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: ennbo
|
|
3
|
+
Version: 0.0.4
|
|
4
|
+
Summary: Epistemic Nearest Neighbors
|
|
5
|
+
Project-URL: Homepage, https://github.com/yubo-research/enn
|
|
6
|
+
Project-URL: Source, https://github.com/yubo-research/enn
|
|
7
|
+
Author-email: YUBO Lab <david.sweet@yu.edu>
|
|
8
|
+
License: MIT License
|
|
9
|
+
|
|
10
|
+
Copyright (c) 2025 yubo research
|
|
11
|
+
|
|
12
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
13
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
14
|
+
in the Software without restriction, including without limitation the rights
|
|
15
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
16
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
17
|
+
furnished to do so, subject to the following conditions:
|
|
18
|
+
|
|
19
|
+
The above copyright notice and this permission notice shall be included in all
|
|
20
|
+
copies or substantial portions of the Software.
|
|
21
|
+
|
|
22
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
23
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
24
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
25
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
26
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
27
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
28
|
+
SOFTWARE.
|
|
29
|
+
License-File: LICENSE
|
|
30
|
+
Classifier: Intended Audience :: Science/Research
|
|
31
|
+
Classifier: License :: OSI Approved :: MIT License
|
|
32
|
+
Classifier: Programming Language :: Python :: 3
|
|
33
|
+
Classifier: Programming Language :: Python :: 3.11
|
|
34
|
+
Classifier: Programming Language :: Python :: 3.12
|
|
35
|
+
Classifier: Topic :: Scientific/Engineering
|
|
36
|
+
Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
|
|
37
|
+
Classifier: Topic :: Scientific/Engineering :: Mathematics
|
|
38
|
+
Requires-Python: >=3.11
|
|
39
|
+
Requires-Dist: faiss==1.9.0
|
|
40
|
+
Requires-Dist: gpytorch==1.13
|
|
41
|
+
Requires-Dist: numpy==1.26.4
|
|
42
|
+
Requires-Dist: scipy==1.15.3
|
|
43
|
+
Requires-Dist: torch==2.5.1
|
|
44
|
+
Description-Content-Type: text/markdown
|
|
45
|
+
|
|
46
|
+
# enn
|
|
47
|
+
Epistemic Nearest Neighbors
|
|
48
|
+
|
|
49
|
+
- ENN model
|
|
50
|
+
- TuRBO-ENN optimizer
|
|
51
|
+
|
|
52
|
+
|
ennbo-0.0.4/README.md
ADDED
|
@@ -0,0 +1,43 @@
|
|
|
1
|
+
#!/usr/bin/env bash
|
|
2
|
+
set -euo pipefail
|
|
3
|
+
|
|
4
|
+
# Run from repo root so paths from git status are correct
|
|
5
|
+
REPO_ROOT="$(cd "$(dirname "${BASH_SOURCE[0]}")/.." && pwd)"
|
|
6
|
+
cd "$REPO_ROOT"
|
|
7
|
+
|
|
8
|
+
echo "Scanning for untracked or modified Python files (git status)..."
|
|
9
|
+
|
|
10
|
+
status_output="$(git status --porcelain --untracked-files=all)"
|
|
11
|
+
|
|
12
|
+
if [[ -z "$status_output" ]]; then
|
|
13
|
+
echo "Working tree clean. No forgotten .py files."
|
|
14
|
+
exit 0
|
|
15
|
+
fi
|
|
16
|
+
|
|
17
|
+
# Select only:
|
|
18
|
+
# - untracked files ("??")
|
|
19
|
+
# - or files with unstaged changes (second status column non-space)
|
|
20
|
+
# We parse the raw porcelain line as:
|
|
21
|
+
# XY <space> PATH
|
|
22
|
+
# where X is index status, Y is work-tree status.
|
|
23
|
+
py_files="$(
|
|
24
|
+
printf '%s\n' "$status_output" |
|
|
25
|
+
awk '{
|
|
26
|
+
line = $0
|
|
27
|
+
x = substr(line, 1, 1)
|
|
28
|
+
y = substr(line, 2, 1)
|
|
29
|
+
path = substr(line, 4)
|
|
30
|
+
if ((x == "?" && y == "?") || y != " ")
|
|
31
|
+
print path
|
|
32
|
+
}' |
|
|
33
|
+
grep -E '\.py$' || true
|
|
34
|
+
)"
|
|
35
|
+
|
|
36
|
+
if [[ -z "$py_files" ]]; then
|
|
37
|
+
echo "No forgotten .py files detected."
|
|
38
|
+
exit 0
|
|
39
|
+
fi
|
|
40
|
+
|
|
41
|
+
echo "Potential forgotten Python files:"
|
|
42
|
+
printf "%s\n" "$py_files"
|
|
43
|
+
exit 1
|
|
@@ -0,0 +1,133 @@
|
|
|
1
|
+
{
|
|
2
|
+
"cells": [
|
|
3
|
+
{
|
|
4
|
+
"cell_type": "markdown",
|
|
5
|
+
"id": "534cb992",
|
|
6
|
+
"metadata": {},
|
|
7
|
+
"source": [
|
|
8
|
+
"# Epistemic Nearest Neighbors (ENN)\n",
|
|
9
|
+
"\n",
|
|
10
|
+
"ENN is a non-parametric surrogate with $O(N)$ computation-time scaling, where $N$ is the number of observations in the data set. ENN can be used in Bayesian optimization as a scalable alternative to a GP (which scales as $O(N^2)$.)\n",
|
|
11
|
+
"\n",
|
|
12
|
+
"**Sweet, D., & Jadhav, S. A. (2025).** Taking the GP Out of the Loop. *arXiv preprint arXiv:2506.12818*. \n",
|
|
13
|
+
" https://arxiv.org/abs/2506.12818\n",
|
|
14
|
+
"\n",
|
|
15
|
+
" ---"
|
|
16
|
+
]
|
|
17
|
+
},
|
|
18
|
+
{
|
|
19
|
+
"cell_type": "code",
|
|
20
|
+
"execution_count": null,
|
|
21
|
+
"id": "8792c830",
|
|
22
|
+
"metadata": {},
|
|
23
|
+
"outputs": [],
|
|
24
|
+
"source": [
|
|
25
|
+
"import numpy as np\n",
|
|
26
|
+
"\n",
|
|
27
|
+
"from enn import EpistemicNearestNeighbors, enn_fit\n",
|
|
28
|
+
"\n",
|
|
29
|
+
"\n",
|
|
30
|
+
"def plot_enn_demo(ax, num_samples: int, k: int, noise: float, m: int = 1) -> None:\n",
|
|
31
|
+
" x = np.random.rand(num_samples)\n",
|
|
32
|
+
" eps = np.random.randn(num_samples)\n",
|
|
33
|
+
" y = np.sin(2 * m * np.pi * x) + noise * eps\n",
|
|
34
|
+
" yvar = (noise**2) * np.ones_like(y)\n",
|
|
35
|
+
" train_x = x[:, None]\n",
|
|
36
|
+
" train_y = y[:, None]\n",
|
|
37
|
+
" train_yvar = yvar[:, None]\n",
|
|
38
|
+
" model = EpistemicNearestNeighbors(train_x, train_y, train_yvar, hnsw_threshold=None)\n",
|
|
39
|
+
" rng = np.random.default_rng(0)\n",
|
|
40
|
+
" result = enn_fit(\n",
|
|
41
|
+
" model,\n",
|
|
42
|
+
" num_fit_candidates=30,\n",
|
|
43
|
+
" num_fit_samples=min(10, num_samples),\n",
|
|
44
|
+
" rng=rng,\n",
|
|
45
|
+
" )\n",
|
|
46
|
+
" print(result)\n",
|
|
47
|
+
" params = result\n",
|
|
48
|
+
" x_hat = np.linspace(0.0, 1.0, 30)\n",
|
|
49
|
+
" x_hat_2d = x_hat[:, None]\n",
|
|
50
|
+
" posterior = model.posterior(x_hat_2d, params=params, exclude_nearest=False)\n",
|
|
51
|
+
" mu = posterior.mu[:, 0]\n",
|
|
52
|
+
" se = posterior.se[:, 0]\n",
|
|
53
|
+
" marker_size = 3 if num_samples >= 100 else 15\n",
|
|
54
|
+
" ax.scatter(x, y, s=marker_size, color=\"black\", alpha=0.5)\n",
|
|
55
|
+
" ax.plot(x_hat, mu, linestyle=\"--\", color=\"tab:blue\", alpha=0.7)\n",
|
|
56
|
+
" ax.fill_between(x_hat, mu - 2 * se, mu + 2 * se, color=\"tab:blue\", alpha=0.2)\n",
|
|
57
|
+
" ax.set_ylim(-5, 5)\n",
|
|
58
|
+
" ax.set_title(f\"n={num_samples}, noise={noise}\")"
|
|
59
|
+
]
|
|
60
|
+
},
|
|
61
|
+
{
|
|
62
|
+
"cell_type": "code",
|
|
63
|
+
"execution_count": null,
|
|
64
|
+
"id": "992d16f9",
|
|
65
|
+
"metadata": {},
|
|
66
|
+
"outputs": [],
|
|
67
|
+
"source": [
|
|
68
|
+
"import matplotlib.pyplot as plt\n",
|
|
69
|
+
"\n",
|
|
70
|
+
"\n",
|
|
71
|
+
"k = 5\n",
|
|
72
|
+
"fig, axes = plt.subplots(2, 3, figsize=(9, 6), sharex=True, sharey=True)\n",
|
|
73
|
+
"num_samples_list = [5, 10]\n",
|
|
74
|
+
"noise_list = [0.0, 0.1, 0.3]\n",
|
|
75
|
+
"for row_idx, num_samples in enumerate(num_samples_list):\n",
|
|
76
|
+
" for col_idx, noise in enumerate(noise_list):\n",
|
|
77
|
+
" ax = axes[row_idx, col_idx]\n",
|
|
78
|
+
" np.random.seed(1)\n",
|
|
79
|
+
" plot_enn_demo(ax, num_samples=num_samples, k=k, noise=noise)\n",
|
|
80
|
+
"for ax in axes[-1, :]:\n",
|
|
81
|
+
" ax.set_xlabel(\"x\")\n",
|
|
82
|
+
"for ax in axes[:, 0]:\n",
|
|
83
|
+
" ax.set_ylabel(\"y\")\n",
|
|
84
|
+
"fig.tight_layout()"
|
|
85
|
+
]
|
|
86
|
+
},
|
|
87
|
+
{
|
|
88
|
+
"cell_type": "code",
|
|
89
|
+
"execution_count": null,
|
|
90
|
+
"id": "84968071",
|
|
91
|
+
"metadata": {},
|
|
92
|
+
"outputs": [],
|
|
93
|
+
"source": [
|
|
94
|
+
"import time\n",
|
|
95
|
+
"import matplotlib.pyplot as plt\n",
|
|
96
|
+
"\n",
|
|
97
|
+
"np.random.seed(1)\n",
|
|
98
|
+
"fig, ax = plt.subplots(figsize=(5, 3))\n",
|
|
99
|
+
"t_0 = time.time()\n",
|
|
100
|
+
"plot_enn_demo(ax, num_samples=1_000_000, k=5, noise=0.3, m=3)\n",
|
|
101
|
+
"t_1 = time.time()\n",
|
|
102
|
+
"print(f\"Time taken: {t_1 - t_0:.2f} seconds\")\n",
|
|
103
|
+
"ax.set_xlabel(\"x\")\n",
|
|
104
|
+
"ax.set_ylabel(\"y\")\n",
|
|
105
|
+
"fig.tight_layout()"
|
|
106
|
+
]
|
|
107
|
+
},
|
|
108
|
+
{
|
|
109
|
+
"cell_type": "code",
|
|
110
|
+
"execution_count": null,
|
|
111
|
+
"id": "7f4a8ae6",
|
|
112
|
+
"metadata": {},
|
|
113
|
+
"outputs": [],
|
|
114
|
+
"source": []
|
|
115
|
+
},
|
|
116
|
+
{
|
|
117
|
+
"cell_type": "code",
|
|
118
|
+
"execution_count": null,
|
|
119
|
+
"id": "b8f018dc",
|
|
120
|
+
"metadata": {},
|
|
121
|
+
"outputs": [],
|
|
122
|
+
"source": []
|
|
123
|
+
}
|
|
124
|
+
],
|
|
125
|
+
"metadata": {
|
|
126
|
+
"language_info": {
|
|
127
|
+
"name": "python",
|
|
128
|
+
"pygments_lexer": "ipython3"
|
|
129
|
+
}
|
|
130
|
+
},
|
|
131
|
+
"nbformat": 4,
|
|
132
|
+
"nbformat_minor": 5
|
|
133
|
+
}
|
|
@@ -0,0 +1,223 @@
|
|
|
1
|
+
{
|
|
2
|
+
"cells": [
|
|
3
|
+
{
|
|
4
|
+
"cell_type": "markdown",
|
|
5
|
+
"metadata": {},
|
|
6
|
+
"source": [
|
|
7
|
+
"# TuRBO-ENN\n",
|
|
8
|
+
"\n",
|
|
9
|
+
"This code implements TuRBO [1], a SOTA Bayesian optimization algorithm.\n",
|
|
10
|
+
"\n",
|
|
11
|
+
"The optimization class, `Turbo`, supports four modes of operation.\n",
|
|
12
|
+
"\n",
|
|
13
|
+
"**LHD_ONLY** \n",
|
|
14
|
+
"Generate a Latin Hypercube Design (LHD) for every batch of arms. This is included as a simple baseline.\n",
|
|
15
|
+
"\n",
|
|
16
|
+
"**TURBO_ZERO** \n",
|
|
17
|
+
"Initialize with LHD. Afterward, sample near the best-so-far x value, x_best. Samples are \"near\" x_best in two senses: (i) They are in a trust region, an adaptively-sized box around x_best, and (ii) They perturb only a small number of dimensions using RAASP sampling [2]. Other dimensions take the same value as in x_best. The num_arms proposals are chosen randomly from RAASP candidates inside the trust region.\n",
|
|
18
|
+
"\n",
|
|
19
|
+
"This is included to help differentiate the impact of the trust region from the impact of the surrogate. Notice (below) that the trust region has high impact.\n",
|
|
20
|
+
"\n",
|
|
21
|
+
"**TURBO_ONE** \n",
|
|
22
|
+
"This adds a GP surrogate to TURBO_ZERO. The num_arms proposals are chosen via Thompson sampling from RAASP candidates inside the trust region. Occasionally, the trust region adapter resets and (i) discards all observations, and (ii) begins anew with an LHD design.\n",
|
|
23
|
+
"\n",
|
|
24
|
+
"This is the standard SOTA method. It should match the TuRBO reference [implementation](https://github.com/uber-research/TuRBO). \n",
|
|
25
|
+
"\n",
|
|
26
|
+
"**TURBO_ENN** \n",
|
|
27
|
+
"This replaces the GP surrogate with a simpler, more scalable surrogate called Epistemic Nearest Neighbors (ENN). ENN's proposal time scales as $O(N)$ rather than the $O(N^2)$ of a GP surrogate. [3]\n",
|
|
28
|
+
"\n",
|
|
29
|
+
"\n",
|
|
30
|
+
"## References\n",
|
|
31
|
+
"\n",
|
|
32
|
+
"1. **Eriksson, D., Pearce, M., Gardner, J. R., Turner, R., & Poloczek, M. (2019).** Scalable Global Optimization via Local Bayesian Optimization. *Advances in Neural Information Processing Systems, 32*. \n",
|
|
33
|
+
" https://arxiv.org/abs/1910.01739\n",
|
|
34
|
+
"\n",
|
|
35
|
+
"2. **Rashidi, B., Johnstonbaugh, K., & Gao, C. (2024).** Cylindrical Thompson Sampling for High-Dimensional Bayesian Optimization. *Proceedings of The 27th International Conference on Artificial Intelligence and Statistics* (pp. 3502–3510). PMLR. \n",
|
|
36
|
+
" https://proceedings.mlr.press/v238/rashidi24a.html\n",
|
|
37
|
+
"\n",
|
|
38
|
+
"3. **Sweet, D., & Jadhav, S. A. (2025).** Taking the GP Out of the Loop. *arXiv preprint arXiv:2506.12818*. \n",
|
|
39
|
+
" https://arxiv.org/abs/2506.12818\n"
|
|
40
|
+
]
|
|
41
|
+
},
|
|
42
|
+
{
|
|
43
|
+
"cell_type": "markdown",
|
|
44
|
+
"metadata": {},
|
|
45
|
+
"source": [
|
|
46
|
+
"---"
|
|
47
|
+
]
|
|
48
|
+
},
|
|
49
|
+
{
|
|
50
|
+
"cell_type": "code",
|
|
51
|
+
"execution_count": null,
|
|
52
|
+
"metadata": {},
|
|
53
|
+
"outputs": [],
|
|
54
|
+
"source": [
|
|
55
|
+
"import numpy as np\n",
|
|
56
|
+
"\n",
|
|
57
|
+
"\n",
|
|
58
|
+
"class Ackley:\n",
|
|
59
|
+
" def __init__(self):\n",
|
|
60
|
+
" self.a = 20.0\n",
|
|
61
|
+
" self.b = 0.2\n",
|
|
62
|
+
" self.c = 2 * np.pi\n",
|
|
63
|
+
" self.bounds = [-32.768, 32.768]\n",
|
|
64
|
+
"\n",
|
|
65
|
+
" def __call__(self, x):\n",
|
|
66
|
+
" x = np.asarray(x, dtype=float)\n",
|
|
67
|
+
" if x.ndim == 1:\n",
|
|
68
|
+
" x = x[None, :]\n",
|
|
69
|
+
" x = x - 1\n",
|
|
70
|
+
" y = (\n",
|
|
71
|
+
" -self.a * np.exp(-self.b * np.sqrt((x**2).mean(axis=1)))\n",
|
|
72
|
+
" - np.exp(np.cos(self.c * x).mean(axis=1))\n",
|
|
73
|
+
" + self.a\n",
|
|
74
|
+
" + np.e\n",
|
|
75
|
+
" )\n",
|
|
76
|
+
" result = -y\n",
|
|
77
|
+
" return result if result.ndim > 0 else float(result)"
|
|
78
|
+
]
|
|
79
|
+
},
|
|
80
|
+
{
|
|
81
|
+
"cell_type": "code",
|
|
82
|
+
"execution_count": null,
|
|
83
|
+
"metadata": {},
|
|
84
|
+
"outputs": [],
|
|
85
|
+
"source": [
|
|
86
|
+
"import time\n",
|
|
87
|
+
"\n",
|
|
88
|
+
"from enn import TurboMode, Turbo\n",
|
|
89
|
+
"\n",
|
|
90
|
+
"\n",
|
|
91
|
+
"def run_optimization(turbo_mode: TurboMode):\n",
|
|
92
|
+
" num_dim = 100\n",
|
|
93
|
+
" num_iterations = 100\n",
|
|
94
|
+
" num_arms = 100\n",
|
|
95
|
+
"\n",
|
|
96
|
+
" objective = Ackley()\n",
|
|
97
|
+
" bounds = np.array([objective.bounds] * num_dim, dtype=float)\n",
|
|
98
|
+
"\n",
|
|
99
|
+
" rng = np.random.default_rng(42)\n",
|
|
100
|
+
" optimizer = Turbo(\n",
|
|
101
|
+
" bounds=bounds,\n",
|
|
102
|
+
" mode=turbo_mode,\n",
|
|
103
|
+
" num_arms=num_arms,\n",
|
|
104
|
+
" rng=rng,\n",
|
|
105
|
+
" k=10,\n",
|
|
106
|
+
" )\n",
|
|
107
|
+
"\n",
|
|
108
|
+
" best_values = []\n",
|
|
109
|
+
" proposal_times = []\n",
|
|
110
|
+
" best_y = -np.inf\n",
|
|
111
|
+
"\n",
|
|
112
|
+
" for iteration in range(num_iterations):\n",
|
|
113
|
+
" t_0 = time.time()\n",
|
|
114
|
+
" x_candidates = optimizer.ask(num_arms=num_arms)\n",
|
|
115
|
+
" t_1 = time.time()\n",
|
|
116
|
+
" proposal_times.append(t_1 - t_0)\n",
|
|
117
|
+
"\n",
|
|
118
|
+
" y_values = objective(x_candidates)\n",
|
|
119
|
+
"\n",
|
|
120
|
+
" optimizer.tell(x_candidates, y_values)\n",
|
|
121
|
+
"\n",
|
|
122
|
+
" current_best = float(np.max(y_values))\n",
|
|
123
|
+
" if current_best > best_y:\n",
|
|
124
|
+
" best_y = current_best\n",
|
|
125
|
+
" best_values.append(best_y)\n",
|
|
126
|
+
" if iteration % 10 == 0:\n",
|
|
127
|
+
" print(f\"Iteration {iteration} best value: {best_y}\")\n",
|
|
128
|
+
"\n",
|
|
129
|
+
" evals = num_arms * np.arange(len(best_values))\n",
|
|
130
|
+
" return best_values, proposal_times, evals"
|
|
131
|
+
]
|
|
132
|
+
},
|
|
133
|
+
{
|
|
134
|
+
"cell_type": "code",
|
|
135
|
+
"execution_count": null,
|
|
136
|
+
"metadata": {},
|
|
137
|
+
"outputs": [],
|
|
138
|
+
"source": [
|
|
139
|
+
"import matplotlib.pyplot as plt\n",
|
|
140
|
+
"\n",
|
|
141
|
+
"# TURBO_ONE is too slow for my patience right now.\n",
|
|
142
|
+
"# It's now an exercise for the reader.\n",
|
|
143
|
+
"RUN_TURBO_ONE = False\n",
|
|
144
|
+
"\n",
|
|
145
|
+
"best_values_zero, proposal_times_zero, evals_zero = run_optimization(\n",
|
|
146
|
+
" TurboMode.TURBO_ZERO\n",
|
|
147
|
+
")\n",
|
|
148
|
+
"if RUN_TURBO_ONE:\n",
|
|
149
|
+
" best_values_one, proposal_times_one, evals_one = run_optimization(\n",
|
|
150
|
+
" TurboMode.TURBO_ONE\n",
|
|
151
|
+
" )\n",
|
|
152
|
+
"best_values_enn, proposal_times_enn, evals_enn = run_optimization(TurboMode.TURBO_ENN)\n",
|
|
153
|
+
"best_values_lhd, proposal_times_lhd, evals_lhd = run_optimization(TurboMode.LHD_ONLY)"
|
|
154
|
+
]
|
|
155
|
+
},
|
|
156
|
+
{
|
|
157
|
+
"cell_type": "code",
|
|
158
|
+
"execution_count": null,
|
|
159
|
+
"metadata": {},
|
|
160
|
+
"outputs": [],
|
|
161
|
+
"source": [
|
|
162
|
+
"plt.figure(figsize=(10, 6))\n",
|
|
163
|
+
"plt.plot(evals_zero, best_values_zero, linewidth=2, label=\"TURBO_ZERO\")\n",
|
|
164
|
+
"plt.plot(evals_enn, best_values_enn, linewidth=2, label=\"TURBO_ENN\")\n",
|
|
165
|
+
"plt.plot(evals_lhd, best_values_lhd, linewidth=2, label=\"LHD_ONLY\")\n",
|
|
166
|
+
"if RUN_TURBO_ONE:\n",
|
|
167
|
+
" plt.plot(evals_one, best_values_one, linewidth=2, label=\"TURBO_ONE\")\n",
|
|
168
|
+
"plt.xlabel(\"Function Evaluations\")\n",
|
|
169
|
+
"plt.ylabel(\"Best Function Value\")\n",
|
|
170
|
+
"plt.title(\"Convergence Comparison: All Turbo Modes\")\n",
|
|
171
|
+
"plt.legend()\n",
|
|
172
|
+
"plt.grid(True, alpha=0.3)\n",
|
|
173
|
+
"plt.tight_layout()\n",
|
|
174
|
+
"plt.show()"
|
|
175
|
+
]
|
|
176
|
+
},
|
|
177
|
+
{
|
|
178
|
+
"cell_type": "code",
|
|
179
|
+
"execution_count": null,
|
|
180
|
+
"metadata": {},
|
|
181
|
+
"outputs": [],
|
|
182
|
+
"source": [
|
|
183
|
+
"plt.figure(figsize=(10, 6))\n",
|
|
184
|
+
"plt.plot(evals_zero, proposal_times_zero, linewidth=2, label=\"TURBO_ZERO\")\n",
|
|
185
|
+
"plt.plot(evals_enn, proposal_times_enn, linewidth=2, label=\"TURBO_ENN\")\n",
|
|
186
|
+
"plt.plot(evals_lhd, proposal_times_lhd, linewidth=2, label=\"LHD_ONLY\")\n",
|
|
187
|
+
"if RUN_TURBO_ONE:\n",
|
|
188
|
+
" plt.plot(evals_one, proposal_times_one, linewidth=2, label=\"TURBO_ONE\")\n",
|
|
189
|
+
"plt.xlabel(\"Function Evaluations\")\n",
|
|
190
|
+
"plt.ylabel(\"Proposal Time (seconds)\")\n",
|
|
191
|
+
"plt.title(\"Proposal Time vs Function Evaluations: All Turbo Modes\")\n",
|
|
192
|
+
"# c = plt.axis()\n",
|
|
193
|
+
"# plt.axis([c[0], c[1], 0, 5])\n",
|
|
194
|
+
"plt.legend()\n",
|
|
195
|
+
"plt.grid(True, alpha=0.3)\n",
|
|
196
|
+
"plt.tight_layout()\n",
|
|
197
|
+
"plt.show()"
|
|
198
|
+
]
|
|
199
|
+
},
|
|
200
|
+
{
|
|
201
|
+
"cell_type": "code",
|
|
202
|
+
"execution_count": null,
|
|
203
|
+
"metadata": {},
|
|
204
|
+
"outputs": [],
|
|
205
|
+
"source": []
|
|
206
|
+
},
|
|
207
|
+
{
|
|
208
|
+
"cell_type": "code",
|
|
209
|
+
"execution_count": null,
|
|
210
|
+
"metadata": {},
|
|
211
|
+
"outputs": [],
|
|
212
|
+
"source": []
|
|
213
|
+
}
|
|
214
|
+
],
|
|
215
|
+
"metadata": {
|
|
216
|
+
"language_info": {
|
|
217
|
+
"name": "python",
|
|
218
|
+
"pygments_lexer": "ipython3"
|
|
219
|
+
}
|
|
220
|
+
},
|
|
221
|
+
"nbformat": 4,
|
|
222
|
+
"nbformat_minor": 2
|
|
223
|
+
}
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
[build-system]
|
|
2
|
+
requires = ["hatchling"]
|
|
3
|
+
build-backend = "hatchling.build"
|
|
4
|
+
|
|
5
|
+
[project]
|
|
6
|
+
name = "ennbo"
|
|
7
|
+
version = "0.0.4"
|
|
8
|
+
description = "Epistemic Nearest Neighbors"
|
|
9
|
+
readme = "README.md"
|
|
10
|
+
requires-python = ">=3.11"
|
|
11
|
+
license = { file = "LICENSE" }
|
|
12
|
+
authors = [
|
|
13
|
+
{ name = "YUBO Lab", email = "david.sweet@yu.edu" },
|
|
14
|
+
]
|
|
15
|
+
dependencies = [
|
|
16
|
+
"numpy==1.26.4",
|
|
17
|
+
"torch==2.5.1",
|
|
18
|
+
"gpytorch==1.13",
|
|
19
|
+
"faiss==1.9.0",
|
|
20
|
+
"scipy==1.15.3",
|
|
21
|
+
]
|
|
22
|
+
|
|
23
|
+
classifiers = [
|
|
24
|
+
"Programming Language :: Python :: 3",
|
|
25
|
+
"Programming Language :: Python :: 3.11",
|
|
26
|
+
"Programming Language :: Python :: 3.12",
|
|
27
|
+
"Topic :: Scientific/Engineering",
|
|
28
|
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
|
29
|
+
"Topic :: Scientific/Engineering :: Mathematics",
|
|
30
|
+
"Intended Audience :: Science/Research",
|
|
31
|
+
"License :: OSI Approved :: MIT License",
|
|
32
|
+
]
|
|
33
|
+
|
|
34
|
+
[project.urls]
|
|
35
|
+
Homepage = "https://github.com/yubo-research/enn"
|
|
36
|
+
Source = "https://github.com/yubo-research/enn"
|
|
37
|
+
|
|
38
|
+
[tool.hatch.build.targets.wheel]
|
|
39
|
+
packages = ["src/enn"]
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
|
|
2
|
+
YOU MAY NOT ALTER THIS DOCUMENT.
|
|
3
|
+
|
|
4
|
+
# REQUIREMENTS
|
|
5
|
+
- Follow ../bbo/turbo_m_ref/gp.py for the GP. Use GPyTorch and Adam.
|
|
6
|
+
- Use ExactGP.posterior().sample() for Thompson sampling. You need to take a joint sample for TS to work.
|
|
7
|
+
- No comments or docstrings.
|
|
8
|
+
- All code should take a user-defined rng. Take torch.Generator and/or np.random.generator objects.
|
|
9
|
+
- Only one class per file. MyClass -> my_class.py.
|
|
10
|
+
- The surrogate (GP or ENN) should be refit for *every* proposal (ask()).
|
|
11
|
+
- Use lazy importing: If at all possible, do not import at the module level. This helps keep imports fast. `if TYPE_CHECKING:` is your friend.
|
|
12
|
+
- Optimizer parameter defaults should follow ../bbo/turbo_ref/*.py
|
|
13
|
+
- There should be at least one unit test for each new function. Make it a good one.
|
|
14
|
+
|
|
15
|
+
YOU MAY NOT ALTER THIS DOCUMENT.
|