ennbo 0.1.0__py3-none-any.whl → 0.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
enn/turbo/turbo_utils.py CHANGED
@@ -24,15 +24,23 @@ def _next_power_of_2(n: int) -> int:
24
24
 
25
25
 
26
26
  @contextlib.contextmanager
27
- def torch_rng_context(generator: torch.Generator | Any) -> Iterator[None]:
27
+ def torch_seed_context(
28
+ seed: int, device: torch.device | Any | None = None
29
+ ) -> Iterator[None]:
28
30
  import torch
29
31
 
30
- old_state = torch.get_rng_state()
31
- try:
32
- torch.set_rng_state(generator.get_state())
32
+ devices: list[int] | None = None
33
+ if device is not None and getattr(device, "type", None) == "cuda":
34
+ idx = 0 if getattr(device, "index", None) is None else int(device.index)
35
+ devices = [idx]
36
+ with torch.random.fork_rng(devices=devices, enabled=True):
37
+ torch.manual_seed(int(seed))
38
+ if device is not None and getattr(device, "type", None) == "cuda":
39
+ torch.cuda.manual_seed_all(int(seed))
40
+ if device is not None and getattr(device, "type", None) == "mps":
41
+ if hasattr(torch, "mps") and hasattr(torch.mps, "manual_seed"):
42
+ torch.mps.manual_seed(int(seed))
33
43
  yield
34
- finally:
35
- torch.set_rng_state(old_state)
36
44
 
37
45
 
38
46
  def fit_gp(
@@ -45,8 +53,8 @@ def fit_gp(
45
53
  ) -> tuple[
46
54
  "TurboGP | TurboGPNoisy | None",
47
55
  "GaussianLikelihood | None",
48
- float,
49
- float,
56
+ float | np.ndarray,
57
+ float | np.ndarray,
50
58
  ]:
51
59
  import numpy as np
52
60
  import torch
@@ -60,22 +68,40 @@ def fit_gp(
60
68
  x = np.asarray(x_obs_list, dtype=float)
61
69
  y = np.asarray(y_obs_list, dtype=float)
62
70
  n = x.shape[0]
71
+ if y.ndim not in (1, 2):
72
+ raise ValueError(y.shape)
73
+ is_multi_output = y.ndim == 2 and y.shape[1] > 1
63
74
  if yvar_obs_list is not None:
64
75
  if len(yvar_obs_list) != len(y_obs_list):
65
76
  raise ValueError(
66
77
  f"yvar_obs_list length {len(yvar_obs_list)} != y_obs_list length {len(y_obs_list)}"
67
78
  )
79
+ if is_multi_output:
80
+ raise ValueError("yvar_obs_list not supported for multi-output GP")
68
81
  if n == 0:
82
+ if is_multi_output:
83
+ num_outputs = int(y.shape[1])
84
+ return None, None, np.zeros(num_outputs), np.ones(num_outputs)
69
85
  return None, None, 0.0, 1.0
70
- if n == 1:
71
- gp_y_mean = float(y[0])
72
- gp_y_std = 1.0
86
+ if n == 1 and is_multi_output:
87
+ gp_y_mean = y[0].copy()
88
+ gp_y_std = np.ones(int(y.shape[1]), dtype=float)
73
89
  return None, None, gp_y_mean, gp_y_std
74
- gp_y_mean, gp_y_std = standardize_y(y)
75
- y_centered = y - gp_y_mean
76
- z = y_centered / gp_y_std
90
+
91
+ if is_multi_output:
92
+ gp_y_mean = y.mean(axis=0)
93
+ gp_y_std = y.std(axis=0)
94
+ gp_y_std = np.where(gp_y_std < 1e-6, 1.0, gp_y_std)
95
+ z = (y - gp_y_mean) / gp_y_std
96
+ else:
97
+ gp_y_mean, gp_y_std = standardize_y(y)
98
+ y_centered = y - gp_y_mean
99
+ z = y_centered / gp_y_std
77
100
  train_x = torch.as_tensor(x, dtype=torch.float64)
78
- train_y = torch.as_tensor(z, dtype=torch.float64)
101
+ if is_multi_output:
102
+ train_y = torch.as_tensor(z.T, dtype=torch.float64)
103
+ else:
104
+ train_y = torch.as_tensor(z, dtype=torch.float64)
79
105
  lengthscale_constraint = Interval(0.005, 2.0)
80
106
  outputscale_constraint = Interval(0.05, 20.0)
81
107
  if yvar_obs_list is not None:
@@ -92,9 +118,16 @@ def fit_gp(
92
118
  likelihood = model.likelihood
93
119
  else:
94
120
  noise_constraint = Interval(5e-4, 0.2)
95
- likelihood = GaussianLikelihood(noise_constraint=noise_constraint).to(
96
- dtype=train_y.dtype
97
- )
121
+ if is_multi_output:
122
+ num_outputs = int(y.shape[1])
123
+ likelihood = GaussianLikelihood(
124
+ noise_constraint=noise_constraint,
125
+ batch_shape=torch.Size([num_outputs]),
126
+ ).to(dtype=train_y.dtype)
127
+ else:
128
+ likelihood = GaussianLikelihood(noise_constraint=noise_constraint).to(
129
+ dtype=train_y.dtype
130
+ )
98
131
  model = TurboGP(
99
132
  train_x=train_x,
100
133
  train_y=train_y,
@@ -103,11 +136,23 @@ def fit_gp(
103
136
  outputscale_constraint=outputscale_constraint,
104
137
  ard_dims=num_dim,
105
138
  ).to(dtype=train_x.dtype)
106
- likelihood.noise = torch.tensor(0.005, dtype=train_y.dtype)
107
- model.covar_module.outputscale = torch.tensor(1.0, dtype=train_x.dtype)
108
- model.covar_module.base_kernel.lengthscale = torch.full(
109
- (num_dim,), 0.5, dtype=train_x.dtype
110
- )
139
+ if is_multi_output:
140
+ likelihood.noise = torch.full(
141
+ (int(y.shape[1]),), 0.005, dtype=train_y.dtype
142
+ )
143
+ else:
144
+ likelihood.noise = torch.tensor(0.005, dtype=train_y.dtype)
145
+ if is_multi_output:
146
+ num_outputs = int(y.shape[1])
147
+ model.covar_module.outputscale = torch.ones(num_outputs, dtype=train_x.dtype)
148
+ model.covar_module.base_kernel.lengthscale = torch.full(
149
+ (num_outputs, 1, num_dim), 0.5, dtype=train_x.dtype
150
+ )
151
+ else:
152
+ model.covar_module.outputscale = torch.tensor(1.0, dtype=train_x.dtype)
153
+ model.covar_module.base_kernel.lengthscale = torch.full(
154
+ (num_dim,), 0.5, dtype=train_x.dtype
155
+ )
111
156
  model.train()
112
157
  likelihood.train()
113
158
  mll = ExactMarginalLogLikelihood(likelihood, model)
@@ -116,6 +161,8 @@ def fit_gp(
116
161
  optimizer.zero_grad()
117
162
  output = model(train_x)
118
163
  loss = -mll(output, train_y)
164
+ if loss.ndim != 0:
165
+ loss = loss.sum()
119
166
  loss.backward()
120
167
  optimizer.step()
121
168
  model.eval()
@@ -197,6 +244,49 @@ def raasp(
197
244
  )
198
245
 
199
246
 
247
+ def generate_raasp_candidates(
248
+ center: np.ndarray | Any,
249
+ lb: np.ndarray | list[float] | Any,
250
+ ub: np.ndarray | list[float] | Any,
251
+ num_candidates: int,
252
+ *,
253
+ rng: Generator | Any,
254
+ sobol_engine: QMCEngine | Any,
255
+ num_pert: int = 20,
256
+ ) -> np.ndarray:
257
+ if num_candidates <= 0:
258
+ raise ValueError(num_candidates)
259
+ return raasp(
260
+ center,
261
+ lb,
262
+ ub,
263
+ num_candidates,
264
+ num_pert=num_pert,
265
+ rng=rng,
266
+ sobol_engine=sobol_engine,
267
+ )
268
+
269
+
270
+ def generate_trust_region_candidates(
271
+ x_center: np.ndarray | Any,
272
+ lengthscales: np.ndarray | None,
273
+ num_candidates: int,
274
+ *,
275
+ compute_bounds_1d: Any,
276
+ rng: Generator | Any,
277
+ sobol_engine: QMCEngine | Any,
278
+ ) -> np.ndarray:
279
+ """
280
+ Small DRY helper for trust-region candidate generation.
281
+
282
+ `compute_bounds_1d` is typically a TR object's bound computation method.
283
+ """
284
+ lb, ub = compute_bounds_1d(x_center, lengthscales)
285
+ return generate_raasp_candidates(
286
+ x_center, lb, ub, num_candidates, rng=rng, sobol_engine=sobol_engine
287
+ )
288
+
289
+
200
290
  def to_unit(x: np.ndarray | Any, bounds: np.ndarray | Any) -> np.ndarray:
201
291
  import numpy as np
202
292
 
@@ -229,15 +319,15 @@ def gp_thompson_sample(
229
319
 
230
320
  x_torch = torch.as_tensor(x_cand, dtype=torch.float64)
231
321
  seed = int(rng.integers(2**31 - 1))
232
- gen = torch.Generator(device=x_torch.device)
233
- gen.manual_seed(seed)
234
322
  with (
235
323
  torch.no_grad(),
236
324
  gpytorch.settings.fast_pred_var(),
237
- torch_rng_context(gen),
325
+ torch_seed_context(seed, device=x_torch.device),
238
326
  ):
239
327
  posterior = model.posterior(x_torch)
240
328
  samples = posterior.sample(sample_shape=torch.Size([1]))
329
+ if samples.ndim != 2:
330
+ raise ValueError(samples.shape)
241
331
  ts = samples[0].reshape(-1)
242
332
  scores = ts.detach().cpu().numpy().reshape(-1)
243
333
  scores = gp_y_mean + gp_y_std * scores
@@ -7,9 +7,13 @@ if TYPE_CHECKING:
7
7
  from numpy.random import Generator
8
8
 
9
9
  from .base_turbo_impl import BaseTurboImpl
10
+ from .turbo_config import TurboZeroConfig
10
11
 
11
12
 
12
13
  class TurboZeroImpl(BaseTurboImpl):
14
+ def __init__(self, config: TurboZeroConfig) -> None:
15
+ super().__init__(config)
16
+
13
17
  def select_candidates(
14
18
  self,
15
19
  x_cand: np.ndarray,
@@ -18,6 +22,7 @@ class TurboZeroImpl(BaseTurboImpl):
18
22
  rng: Generator,
19
23
  fallback_fn: Callable[[np.ndarray, int], np.ndarray],
20
24
  from_unit_fn: Callable[[np.ndarray], np.ndarray],
25
+ tr_state: object | None = None, # noqa: ARG002
21
26
  ) -> np.ndarray:
22
27
  from .proposal import select_uniform
23
28
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: ennbo
3
- Version: 0.1.0
3
+ Version: 0.1.2
4
4
  Summary: Epistemic Nearest Neighbors
5
5
  Project-URL: Homepage, https://github.com/yubo-research/enn
6
6
  Project-URL: Source, https://github.com/yubo-research/enn
@@ -36,10 +36,10 @@ Classifier: Topic :: Scientific/Engineering
36
36
  Classifier: Topic :: Scientific/Engineering :: Artificial Intelligence
37
37
  Classifier: Topic :: Scientific/Engineering :: Mathematics
38
38
  Requires-Python: >=3.11
39
- Requires-Dist: faiss-cpu==1.9.0
39
+ Requires-Dist: faiss-cpu>=1.9.0
40
40
  Requires-Dist: gpytorch==1.13
41
41
  Requires-Dist: nds==0.4.3
42
- Requires-Dist: numpy==1.26.4
42
+ Requires-Dist: numpy<2.0.0,>=1.26.4
43
43
  Requires-Dist: scipy==1.15.3
44
44
  Requires-Dist: torch==2.5.1
45
45
  Description-Content-Type: text/markdown
@@ -80,9 +80,10 @@ On my MacBook I can run into problems with dependencies and compatibilities.
80
80
 
81
81
  On MacOS try:
82
82
  ```
83
- micromamba env create -n ennbo -f conda-macos.yml
83
+ micromamba env create -n ennbo -f admin/conda-macos.yml
84
84
  micromamba activate ennbo
85
85
  pip install --no-deps ennbo
86
+ pytest -sv tests
86
87
  ```
87
88
 
88
89
  You may replace `micromamba` with `conda` and this will probably still work.
@@ -0,0 +1,29 @@
1
+ enn/__init__.py,sha256=VYIuOTCjhUFIJm78IoJv0WXtvA_IuZhY1sSMJJM3dx8,507
2
+ enn/enn/__init__.py,sha256=K3rntg_ZkITStmXMTBcEhxeS1kel1bb7wB_C7-2WE5Y,135
3
+ enn/enn/enn.py,sha256=HfdrK2gXoI1JvvARsh4NdGOpVpCY2qY_A2RpK4JFVZ4,9310
4
+ enn/enn/enn_fit.py,sha256=RkyFYX4-nUteGivNNS195M2mdRWiGOrLzypVL2b_FsE,4450
5
+ enn/enn/enn_normal.py,sha256=Lm9n-eW5WRn33nb3b9xTGv44Dfn9xAhjys5UJZm2xlc,662
6
+ enn/enn/enn_params.py,sha256=v53qHKwUxnZFNBlcSWI5WqpgpoyFeOzvWs0BxcDqt4o,747
7
+ enn/enn/enn_util.py,sha256=PSeYmxZHz4xLJ6pMr9n22MLxvVmmgBqmQl1ckPrLzDo,4142
8
+ enn/turbo/__init__.py,sha256=utnD3CLZgjCvw-46AAu5Tv2M2Vbg5YXK-_TycGk5BU4,197
9
+ enn/turbo/base_turbo_impl.py,sha256=y4SP9FDT8OaNDOBhKvx1WuTksBd0d-vD8NQxj91QAuA,4396
10
+ enn/turbo/lhd_only_impl.py,sha256=czqLTwhb8d-pKq6jy7g28JcDWMSoopLjNrDeP1dd-3A,1254
11
+ enn/turbo/morbo_trust_region.py,sha256=9z_DgXHEEUfgajRnZ5ieJdeoVAUt7BlLNGu-Thi5Tg4,5973
12
+ enn/turbo/no_trust_region.py,sha256=IxZB1nvmLFgb6GjAXrNpBe1TzgzwJDPYEB8wa_ZPX3k,1813
13
+ enn/turbo/proposal.py,sha256=obFqVyXtZ49veqwnktTJ0_F0nqCERvgFnInstgqhllM,4252
14
+ enn/turbo/turbo_config.py,sha256=tci_GODIED3UHE63XiB9XyxPu_0J5_8R90lxBVxidOQ,2500
15
+ enn/turbo/turbo_enn_impl.py,sha256=qcUssC4xaoMu7usIlp5oJtXrsLY1Rn-mvHUW95zlWZQ,7249
16
+ enn/turbo/turbo_gp.py,sha256=Hi11t0nw5YEG4WM6DeoOW4X-w-M5KGG-P3Zc1sPPx1k,1069
17
+ enn/turbo/turbo_gp_base.py,sha256=tnE5uX_eAt1Db-gemyy83ZvKpdNbMg_tsWkh6sG7zaM,638
18
+ enn/turbo/turbo_gp_noisy.py,sha256=itTL9jUCjE566jwDODT0P36fozsfU_bXACyuKqxYMXs,1080
19
+ enn/turbo/turbo_mode.py,sha256=JMP1jkFCRwPtOzU95MWWd04Sgze7eKF0xNkiPqtQ8SI,181
20
+ enn/turbo/turbo_mode_impl.py,sha256=ubUkV4reOPJH3jbAh6R65cutEHOF23z7Uw1bBE3s9T0,1923
21
+ enn/turbo/turbo_one_impl.py,sha256=PXaBNdLKCgtsLC8Q2z4pHoHYu0edPqKAXb6Tmr4Guvs,11098
22
+ enn/turbo/turbo_optimizer.py,sha256=h6Mu3Pqb2yQRrrqEh6ODxhKIt4Nnt-C241XrNluqxtk,20274
23
+ enn/turbo/turbo_trust_region.py,sha256=0wlN_LhsfMeLqqjhq3xhkJpXtTYIzUon0zbvUjRm2_Q,3797
24
+ enn/turbo/turbo_utils.py,sha256=bEe1F3hBUOUVmodF1WJQv_EKRmp4X74kMxWLNnuV-z0,10733
25
+ enn/turbo/turbo_zero_impl.py,sha256=SvPexeUTQzDPbAwPdZib5lRcIHPOmwD3-ewMxED-nlQ,832
26
+ ennbo-0.1.2.dist-info/METADATA,sha256=-Twds_sAT4LLkcN00vzmS9_a5jUSWRsqi1bO2_RKtHw,5960
27
+ ennbo-0.1.2.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
28
+ ennbo-0.1.2.dist-info/licenses/LICENSE,sha256=KTA0NjGalsl_JGrjT_x6SSq9ZYVO3gQ-hLVMEaekc5w,1070
29
+ ennbo-0.1.2.dist-info/RECORD,,
@@ -1,27 +0,0 @@
1
- enn/__init__.py,sha256=VYIuOTCjhUFIJm78IoJv0WXtvA_IuZhY1sSMJJM3dx8,507
2
- enn/enn/__init__.py,sha256=K3rntg_ZkITStmXMTBcEhxeS1kel1bb7wB_C7-2WE5Y,135
3
- enn/enn/enn.py,sha256=ZdDPivZj4SL9e87FolU1oscdPdcwUeIByIrvBLsoCfE,8060
4
- enn/enn/enn_fit.py,sha256=uv1BHO-nbxVXkR_tM1Ggoh6YNuR-VrjVECFxLquC7u8,4328
5
- enn/enn/enn_normal.py,sha256=3kOymSx2kzcBMavScXLflPm_gDDLGF9fYLBJ816I3xg,596
6
- enn/enn/enn_params.py,sha256=fwLZTA8ciRp4XUF5L_VAVsC3EvFuOzR85OYLVtv6TSw,184
7
- enn/enn/enn_util.py,sha256=ZELPVeyUl0wiHOxjHYKjxeDz88ExmKMeX3P-bQ6tCoE,3075
8
- enn/turbo/__init__.py,sha256=utnD3CLZgjCvw-46AAu5Tv2M2Vbg5YXK-_TycGk5BU4,197
9
- enn/turbo/base_turbo_impl.py,sha256=wThjwXGboRrVTamsnvzmM0WNIOZ91GNJ-BmGzjgqdhg,2699
10
- enn/turbo/lhd_only_impl.py,sha256=yWsOw7Oq0xfEnyXg5AXJSzZFjM7162pqNY37fHQtJQ4,1023
11
- enn/turbo/proposal.py,sha256=w1izo3ooiiravNRoFWK5ZK7BH-f_HWgqYP8heVtLmYs,3977
12
- enn/turbo/turbo_config.py,sha256=J0ww_qKDDMpbFVXdntuSbJtUTbdnXrFJyGD1svzG3RM,980
13
- enn/turbo/turbo_enn_impl.py,sha256=YMAS4krpPXPNtlh46RRG3VLMuGyYLFw5UkPRBU29mzA,5837
14
- enn/turbo/turbo_gp.py,sha256=i1bxVHima0Nv4MCLlADtlRzt1cENcnVLYk3S9vCoF4c,797
15
- enn/turbo/turbo_gp_base.py,sha256=tnE5uX_eAt1Db-gemyy83ZvKpdNbMg_tsWkh6sG7zaM,638
16
- enn/turbo/turbo_gp_noisy.py,sha256=itTL9jUCjE566jwDODT0P36fozsfU_bXACyuKqxYMXs,1080
17
- enn/turbo/turbo_mode.py,sha256=JMP1jkFCRwPtOzU95MWWd04Sgze7eKF0xNkiPqtQ8SI,181
18
- enn/turbo/turbo_mode_impl.py,sha256=3HKBjOS96Wn-R_znctQm9Ivrm3FhgZFTuBp7McNDQ88,1749
19
- enn/turbo/turbo_one_impl.py,sha256=nS02RdRMcEsi3II07jzcrQbsFsfWYTeahUcqoyhig4Q,5207
20
- enn/turbo/turbo_optimizer.py,sha256=IlofW9_ogCeQMVXa7n8xWEg5fbJBUkvAkeLKe3MoXlA,11902
21
- enn/turbo/turbo_trust_region.py,sha256=VHNYKWtKLt3iKHI0enL9qMMu1Bwi1nupo20L0Sv-vYY,3759
22
- enn/turbo/turbo_utils.py,sha256=XU9-YtW1u5-HKk3bA_M-hVNFPAuNcIYozAmej7ulVsY,7532
23
- enn/turbo/turbo_zero_impl.py,sha256=S4TEHYkVDowtyWSVxWO0ncd1OUIFpeV3IR-eanGr1vg,643
24
- ennbo-0.1.0.dist-info/METADATA,sha256=slkhtsGXaO31u8w35LNKXN2noxUJYTqHQF7bv1DZMmA,5930
25
- ennbo-0.1.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
26
- ennbo-0.1.0.dist-info/licenses/LICENSE,sha256=KTA0NjGalsl_JGrjT_x6SSq9ZYVO3gQ-hLVMEaekc5w,1070
27
- ennbo-0.1.0.dist-info/RECORD,,
File without changes