SearchLibrium 0.0.87__tar.gz → 0.0.89__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {searchlibrium-0.0.87 → searchlibrium-0.0.89}/PKG-INFO +1 -1
- {searchlibrium-0.0.87 → searchlibrium-0.0.89}/pyproject.toml +1 -1
- {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/Halton.py +21 -6
- {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/MixedLogit.py +25 -9
- {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/__init__.py +9 -2
- {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/_choice_model.py +0 -1
- searchlibrium-0.0.89/src/SearchLibrium/banditsa.py +300 -0
- {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/call_meta.py +105 -2
- searchlibrium-0.0.89/src/SearchLibrium/sapbil.py +698 -0
- {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/search.py +35 -5
- {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/siman.py +13 -1
- searchlibrium-0.0.89/src/SearchLibrium/test_sapbil_vs_banditsa.py +246 -0
- searchlibrium-0.0.89/src/SearchLibrium/version.txt +1 -0
- {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium.egg-info/PKG-INFO +1 -1
- {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium.egg-info/SOURCES.txt +3 -0
- {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium.egg-info/top_level.txt +1 -0
- searchlibrium-0.0.87/src/SearchLibrium/version.txt +0 -1
- {searchlibrium-0.0.87 → searchlibrium-0.0.89}/README.md +0 -0
- {searchlibrium-0.0.87 → searchlibrium-0.0.89}/setup.cfg +0 -0
- {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/Mode_Activity_Nested.py +0 -0
- {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/RandomP.py +0 -0
- {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/SEARCH_SM_MARIO.py +0 -0
- {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/Two_Level_Nest.py +0 -0
- {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/__main__.py +0 -0
- {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/_device.py +0 -0
- {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/bhhh/minimize.py +0 -0
- {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/boxcox_functions.py +0 -0
- {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/constraints_builder.py +0 -0
- {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/harmony.py +0 -0
- {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/latent_class.py +0 -0
- {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/main.py +0 -0
- {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/main_debug.py +0 -0
- {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/mdcev.py +0 -0
- {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/misc.py +0 -0
- {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/mixed_logit.py +0 -0
- {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/mixed_nested.py +0 -0
- {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/mixedrrm.py +0 -0
- {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/multinomial_logit.py +0 -0
- {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/multinomial_nested.py +0 -0
- {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/multinomial_probit.py +0 -0
- {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/ordered_logit.py +0 -0
- {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/ordered_logit_mixed.py +0 -0
- {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/rrm.py +0 -0
- {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/selection_models.py +0 -0
- {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/setup.py +0 -0
- {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/threshold.py +0 -0
- {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium.egg-info/dependency_links.txt +0 -0
- {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium.egg-info/entry_points.txt +0 -0
- {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium.egg-info/requires.txt +0 -0
|
@@ -59,7 +59,7 @@ Homepage = "https://github.com/zahern/HypothesisX"
|
|
|
59
59
|
realpython = "SearchLibrium.__main__:main"
|
|
60
60
|
|
|
61
61
|
[tool.bumpver]
|
|
62
|
-
current_version = "0.0.
|
|
62
|
+
current_version = "0.0.89"
|
|
63
63
|
version_pattern = "MAJOR.MINOR.PATCH"
|
|
64
64
|
commit_message = "[skip ci] Bump version {old_version} -> {new_version}"
|
|
65
65
|
commit = true
|
|
@@ -67,7 +67,7 @@ import scipy.stats as ss
|
|
|
67
67
|
class Halton:
|
|
68
68
|
"""Class for generating Halton sequences and Halton-based draws."""
|
|
69
69
|
|
|
70
|
-
def __init__(self, primes=None, drop=100, shuffled=False):
|
|
70
|
+
def __init__(self, primes=None, drop=100, shuffled=False, antithetic=False):
|
|
71
71
|
self.primes = primes or [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47,
|
|
72
72
|
53, 59, 61, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109,
|
|
73
73
|
113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173,
|
|
@@ -76,15 +76,24 @@ class Halton:
|
|
|
76
76
|
307, 311]
|
|
77
77
|
self.drop = drop
|
|
78
78
|
self.shuffled = shuffled
|
|
79
|
+
self.antithetic = antithetic
|
|
79
80
|
|
|
80
81
|
def generate_draws(self, sample_size, n_draws, n_vars):
    """Generate Halton draws for multiple variables using different primes.

    One Halton sequence is produced per variable via ``self.halton_seq``,
    cycling through ``self.primes`` when ``n_vars`` exceeds the number of
    primes, and reshaped to ``(sample_size, base)``. The per-variable arrays
    are then stacked along axis 1.

    When ``self.antithetic`` is true, only ``ceil(n_draws / 2)`` base draws
    are generated per observation and mirrored (``1 - u``) to produce
    negatively-correlated antithetic pairs. For an odd ``n_draws`` the
    surplus mirrored draw is dropped, so the last axis always holds exactly
    ``n_draws`` values (previously an odd request silently returned
    ``n_draws - 1`` draws).

    Args:
        sample_size (int): Number of observations (first axis).
        n_draws (int): Number of draws per observation (last axis).
        n_vars (int): Number of variables (middle axis).

    Returns:
        numpy.ndarray: Uniform draws of shape ``(sample_size, n_vars, n_draws)``.
    """
    # Round up so that mirroring yields at least n_draws values.
    base = (n_draws + 1) // 2 if self.antithetic else n_draws
    n_primes = len(self.primes)
    draws = [
        self.halton_seq(sample_size * base, prime=self.primes[i % n_primes]).reshape(sample_size, base)
        for i in range(n_vars)
    ]
    draws = np.stack(draws, axis=1)  # (sample_size, n_vars, base)
    if self.antithetic:
        mirrored = np.concatenate([draws, 1.0 - draws], axis=2)
        # Truncation is a no-op for even n_draws; trims one draw when odd.
        draws = mirrored[:, :, :n_draws]
    return draws
|
|
88
97
|
|
|
89
98
|
def halton_seq(self, length, prime):
|
|
90
99
|
"""Generates a scrambled Halton sequence for a given prime number."""
|
|
@@ -159,11 +168,17 @@ class Draws:
|
|
|
159
168
|
draws = np.atleast_3d(draws)
|
|
160
169
|
return draws
|
|
161
170
|
|
|
171
|
+
# Clip uniform draws away from 0/1 to prevent norm.ppf -> ±inf.
|
|
172
|
+
# This matters especially for antithetic draws where the complement of a
|
|
173
|
+
# small Halton value is very close to 1.
|
|
174
|
+
_PPF_CLIP = 1e-10
|
|
175
|
+
|
|
162
176
|
def evaluate_distribution(self, distr, values):
|
|
163
177
|
"""Transform uniform values to the specified distribution."""
|
|
164
178
|
for k, distr_k in enumerate(distr):
|
|
165
179
|
if distr_k in ['n', 'ln', 'tn']: # Normal-based
|
|
166
|
-
values[:, k, :]
|
|
180
|
+
u = np.clip(values[:, k, :], self._PPF_CLIP, 1.0 - self._PPF_CLIP)
|
|
181
|
+
values[:, k, :] = ss.norm.ppf(u)
|
|
167
182
|
elif distr_k == 't': # Triangular
|
|
168
183
|
values_k = values[:, k, :]
|
|
169
184
|
values[:, k, :] = (np.sqrt(2 * values_k) - 1) * (values_k <= .5) + \
|
|
@@ -51,9 +51,11 @@ class MixedLogit(DiscreteChoiceModel):
|
|
|
51
51
|
n_vars (int): Number of variables.
|
|
52
52
|
"""
|
|
53
53
|
if n_vars == 0:
|
|
54
|
-
return np.ndarray((1,0,1))
|
|
55
|
-
|
|
56
|
-
|
|
54
|
+
return np.ndarray((1, 0, 1))
|
|
55
|
+
# Use self.halton_opts so options like antithetic/shuffled are respected.
|
|
56
|
+
opts = self.halton_opts or {}
|
|
57
|
+
draws_s = Draws(k=n_vars, halton_opts=opts)
|
|
58
|
+
draws = draws_s.generate_draws(sample_size, n_draws)
|
|
57
59
|
return draws
|
|
58
60
|
|
|
59
61
|
|
|
@@ -210,7 +212,6 @@ class MixedLogit(DiscreteChoiceModel):
|
|
|
210
212
|
means_1 = np.mean(self.y, axis=3) # means_1[i,j] = avg(y[i,j,:])
|
|
211
213
|
means_2 = np.mean(means_1, axis=1) # means_2[i] = avg(means_1[i,:])
|
|
212
214
|
self.obs_prob = np.mean(means_2, axis=0) # obs_prob = avg(means_2[:])
|
|
213
|
-
print(f'observed probs debug{self.obs_prob}')
|
|
214
215
|
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
215
216
|
# DEFINE MEMBER FUNCTIONS TO APPLY
|
|
216
217
|
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
@@ -230,8 +231,6 @@ class MixedLogit(DiscreteChoiceModel):
|
|
|
230
231
|
#print(f"self.rvtrans: {self.rvtransdist}")
|
|
231
232
|
self.rvdist = [item for item in self.rvdist if item is not False]
|
|
232
233
|
self.rvtransdist = [item for item in self.rvtransdist if item is not False]
|
|
233
|
-
print(f"self.randvars: {self.rvdist}")
|
|
234
|
-
print(f"self.rvtrans: {self.rvtransdist}")
|
|
235
234
|
draws = self.generate_draws(self.N, self.n_draws, len(self.rvdist))
|
|
236
235
|
drawstrans = self.generate_draws(self.N, self.n_draws, len(self.rvtransdist))
|
|
237
236
|
self.draws, self.drawstrans = draws, drawstrans # Record generated values
|
|
@@ -282,13 +281,30 @@ class MixedLogit(DiscreteChoiceModel):
|
|
|
282
281
|
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
283
282
|
|
|
284
283
|
arr = self.init_coeff[:lower]
|
|
285
|
-
|
|
284
|
+
|
|
285
|
+
# Better scale initialisation: use half the absolute value of the MNL
|
|
286
|
+
# mean estimates for the corresponding random variable, floored at 0.05.
|
|
287
|
+
# This gives BFGS a much better starting point than a flat 0.1 when
|
|
288
|
+
# coefficients are very small (e.g. cost in money units) or very large.
|
|
289
|
+
br_means = arr[self.Kf + 2 * self.Kftrans: self.Kf + 2 * self.Kftrans + self.Kr]
|
|
290
|
+
|
|
291
|
+
# Cholesky elements (correlated random vars) — keep at 0.1; the diagonal
|
|
292
|
+
# scaling is embedded in off-diagonal terms and is harder to infer.
|
|
293
|
+
chol_init = np.repeat(0.1, self.Kchol)
|
|
294
|
+
|
|
295
|
+
# Bandwidth (std dev) for non-correlated random vars.
|
|
296
|
+
bw_means = br_means[self.correlationLength:]
|
|
297
|
+
bw_init = np.maximum(np.abs(bw_means) * 0.5, 0.05)
|
|
298
|
+
|
|
299
|
+
rep = np.concatenate([chol_init, bw_init])
|
|
286
300
|
self.init_coeff = np.concatenate((arr, rep, self.init_coeff[lower:upper],))
|
|
287
301
|
|
|
288
302
|
if self.Krtrans: # CHECK ">0"
|
|
289
303
|
# {
|
|
290
|
-
|
|
291
|
-
|
|
304
|
+
# Similarly scale random-transformed std devs from MNL means.
|
|
305
|
+
rtrans_means = self.init_coeff[lower:upper]
|
|
306
|
+
rtrans_scale_init = np.maximum(np.abs(rtrans_means) * 0.5, 0.05)
|
|
307
|
+
self.init_coeff = np.concatenate((self.init_coeff, rtrans_scale_init, self.init_coeff[-self.Krtrans:]))
|
|
292
308
|
# }
|
|
293
309
|
# }
|
|
294
310
|
|
|
@@ -12,6 +12,11 @@ from addicty import Dict
|
|
|
12
12
|
|
|
13
13
|
import os
|
|
14
14
|
|
|
15
|
+
try:
|
|
16
|
+
from banditsa import BanditSA, PerturbationBandit
|
|
17
|
+
except ImportError:
|
|
18
|
+
from .banditsa import BanditSA, PerturbationBandit
|
|
19
|
+
|
|
15
20
|
def new_features():
|
|
16
21
|
'''ADDICTY DICT'''
|
|
17
22
|
'''vars you want to ensure are included'''
|
|
@@ -103,7 +108,8 @@ try:
|
|
|
103
108
|
from .search import Parameters
|
|
104
109
|
|
|
105
110
|
from . import misc
|
|
106
|
-
from .
|
|
111
|
+
from .sapbil import SAPBIL, ProbabilityMatrix
|
|
112
|
+
from .call_meta import call_harmony, call_siman, call_parsa, call_search, call_sapbil, estimate_ctrl
|
|
107
113
|
|
|
108
114
|
except ImportError as e:
|
|
109
115
|
from _choice_model import DiscreteChoiceModel
|
|
@@ -121,7 +127,8 @@ except ImportError as e:
|
|
|
121
127
|
from RandomP import RandomParameters
|
|
122
128
|
from constraints_builder import ConstraintBuilder, create_constraints
|
|
123
129
|
from search import Parameters
|
|
124
|
-
from
|
|
130
|
+
from sapbil import SAPBIL, ProbabilityMatrix
|
|
131
|
+
from call_meta import call_siman, call_harmony, call_search, call_sapbil, estimate_ctrl
|
|
125
132
|
try:
|
|
126
133
|
from .main import print_ascii_art_logo
|
|
127
134
|
except Exception:
|
|
@@ -391,7 +391,6 @@ class DiscreteChoiceModel(ABC):
|
|
|
391
391
|
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
|
|
392
392
|
non_sigs = self.num_of_exceeding_pvalues(self.pvalues, 0.0)
|
|
393
393
|
#print('log like is before', self.loglik)
|
|
394
|
-
print('apply p val')
|
|
395
394
|
self.loglik -= non_sigs*self.pval_penalty # penalise the non-sigs
|
|
396
395
|
logging.info('applying pval')
|
|
397
396
|
#print('log like is', self.loglik)
|
|
@@ -0,0 +1,300 @@
|
|
|
1
|
+
"""Bandit-guided Simulated Annealing for SearchLibrium.
|
|
2
|
+
|
|
3
|
+
This module provides:
|
|
4
|
+
- PerturbationBandit: Thompson-sampling policy over perturbation action types
|
|
5
|
+
- BanditSA: SA variant that uses the bandit to pick perturbation actions
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
import logging
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
|
|
12
|
+
try:
|
|
13
|
+
from siman import SA
|
|
14
|
+
except ImportError:
|
|
15
|
+
from .siman import SA
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
class PerturbationBandit:
    """Thompson Sampling over perturbation action-type arms.

    Every arm keeps an independent ``Beta(alpha, beta)`` posterior. Arms are
    chosen by drawing one sample per available arm and taking the largest,
    with an ``epsilon`` chance of picking uniformly at random instead.

    Attributes are NumPy arrays indexed by arm position in ``arm_names``.
    """

    def __init__(self, arm_names, prior_alpha=1.0, prior_beta=1.0, epsilon=0.05, rng=None):
        """Initialise the posteriors.

        Args:
            arm_names: Iterable of arm labels; their order defines arm indices.
            prior_alpha: Initial ``alpha`` of every arm's Beta posterior.
            prior_beta: Initial ``beta`` of every arm's Beta posterior.
            epsilon: Probability of a uniformly random (exploratory) pick.
            rng: Optional ``numpy.random.Generator``; a fresh default
                generator is created when omitted.
        """
        self.arm_names = list(arm_names)
        self.n_arms = len(self.arm_names)
        self.alpha = np.full(self.n_arms, float(prior_alpha), dtype=float)
        self.beta = np.full(self.n_arms, float(prior_beta), dtype=float)
        self.counts = np.zeros(self.n_arms, dtype=int)
        self.epsilon = float(epsilon)
        self.rng = rng if rng is not None else np.random.default_rng()

    def select_arm(self, available_indices):
        """Pick one arm index out of ``available_indices``.

        Args:
            available_indices: Any iterable of arm indices (list, tuple,
                NumPy array, ...).

        Returns:
            int: The selected arm index.

        Raises:
            ValueError: If ``available_indices`` is empty.
        """
        # Materialise first: truth-testing a multi-element NumPy array raises,
        # and plain ints keep the return type uniform across both branches.
        candidates = [int(idx) for idx in available_indices]
        if not candidates:
            raise ValueError("No available bandit arms to select from.")

        if self.rng.random() < self.epsilon:
            return int(self.rng.choice(candidates))

        samples = {
            idx: self.rng.beta(self.alpha[idx], self.beta[idx])
            for idx in candidates
        }
        return max(samples, key=samples.get)

    def update(self, arm_index, reward):
        """Fold one observed reward into the posterior of ``arm_index``.

        A non-negative reward adds ``1 + reward`` to ``alpha``; a negative
        reward adds ``1 + |reward|`` to ``beta``. Non-finite rewards carry no
        usable magnitude, so they count as a unit success (``+inf``) or a
        unit failure (``-inf`` and NaN) instead of writing NaN/inf into the
        posterior, which would break every later ``rng.beta`` call.

        Args:
            arm_index: Index of the arm that was played.
            reward: Observed reward (anything ``float()`` accepts).
        """
        reward = float(reward)
        self.counts[arm_index] += 1

        magnitude = abs(reward) if np.isfinite(reward) else 0.0
        # NaN compares False here and therefore lands in the failure branch.
        if reward >= 0.0:
            self.alpha[arm_index] += 1.0 + magnitude
        else:
            self.beta[arm_index] += 1.0 + magnitude

    def summary(self):
        """Return one dict per arm, sorted by posterior mean (descending).

        Each dict has the keys ``arm``, ``alpha``, ``beta``, ``mean`` and
        ``count``; values are plain Python scalars.
        """
        # Guard the denominator so degenerate priors cannot divide by zero.
        total = np.maximum(self.alpha + self.beta, 1e-12)
        means = self.alpha / total
        rows = [
            {
                "arm": name,
                "alpha": float(self.alpha[idx]),
                "beta": float(self.beta[idx]),
                "mean": float(means[idx]),
                "count": int(self.counts[idx]),
            }
            for idx, name in enumerate(self.arm_names)
        ]
        rows.sort(key=lambda row: row["mean"], reverse=True)
        return rows
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
class BanditSA(SA):
    """SA variant with Thompson-sampling perturbation selection.

    Instead of choosing a perturbation type by the base class's own rule,
    each step asks a :class:`PerturbationBandit` which action to try, applies
    it, evaluates the result and feeds a reward back to the bandit. Every
    bandit decision is also appended to ``self.bandit_history``.
    """

    def __init__(
        self,
        param,
        init_sol,
        ctrl,
        idnum=0,
        bandit_prior_alpha=1.0,
        bandit_prior_beta=1.0,
        bandit_epsilon=0.05,
        **kwargs,
    ):
        """Build the annealer and its bandit.

        Args:
            param, init_sol, ctrl, idnum, **kwargs: Forwarded unchanged to
                ``SA.__init__``.
            bandit_prior_alpha: Prior ``alpha`` for every bandit arm.
            bandit_prior_beta: Prior ``beta`` for every bandit arm.
            bandit_epsilon: Random-exploration probability of the bandit.
        """
        super().__init__(param, init_sol, ctrl, idnum=idnum, **kwargs)

        self._actions = self._build_action_table()
        # Name -> index map so availability checks never depend on the
        # positional order of the action table.
        self._arm_index = {name: idx for idx, (name, _) in enumerate(self._actions)}
        self.bandit = PerturbationBandit(
            arm_names=[name for name, _ in self._actions],
            prior_alpha=bandit_prior_alpha,
            prior_beta=bandit_prior_beta,
            epsilon=bandit_epsilon,
        )
        self.bandit_history = []

    def _build_action_table(self):
        """Return the ``(arm name, bound perturbation method)`` pairs."""
        return [
            ("add_asfeature", self.perturb_add_asfeature),
            ("remove_asfeature", self.perturb_remove_asfeature),
            ("add_isfeature", self.perturb_add_isfeature),
            ("remove_isfeature", self.perturb_remove_isfeature),
            ("add_randfeature", self.perturb_add_randfeature),
            ("remove_randfeature", self.perturb_remove_randfeature),
            ("add_bcfeature", self.perturb_add_bcfeature),
            ("remove_bcfeature", self.perturb_remove_bcfeature),
            ("add_corfeature", self.perturb_add_corfeature),
            ("remove_corfeature", self.perturb_remove_corfeature),
            ("change_model", self.perturb_model_t),
            ("change_distribution", self.perturb_distribution),
        ]

    def _available_arm_indices(self, sol):
        """Return the sorted arm indices whose action can apply to ``sol``.

        An arm is offered only when the problem settings on ``self.param``
        allow that kind of change and ``sol`` currently has at least one
        candidate variable for it.
        """
        param = self.param
        arm = self._arm_index
        available = set()

        asvars = sol.get("asvars", [])
        isvars = sol.get("isvars", [])
        randvars = sol.get("randvars", {})
        bcvars = sol.get("bcvars", [])

        if param.asvarnames and len(asvars) < len(param.asvarnames):
            available.add(arm["add_asfeature"])
        # Removal is offered only while more than one asvar is present.
        if len(asvars) > 1:
            available.add(arm["remove_asfeature"])

        if param.isvarnames:
            if any(v not in isvars for v in param.isvarnames):
                available.add(arm["add_isfeature"])
            if isvars:
                available.add(arm["remove_isfeature"])

        if param.allow_random:
            if any(v not in randvars for v in asvars):
                available.add(arm["add_randfeature"])

            # Pre-specified random variables may be neither removed nor
            # re-distributed, so one check serves both arms.
            has_free_randvar = any(v not in param.ps_randvars for v in randvars)
            if has_free_randvar:
                available.add(arm["remove_randfeature"])
                if len(getattr(param, "distr", []) or []) > 1:
                    available.add(arm["change_distribution"])

        if param.allow_bcvars:
            if any(v not in bcvars and v not in param.ps_corvars for v in asvars):
                available.add(arm["add_bcfeature"])
            if any(v not in param.ps_bcvars for v in bcvars):
                available.add(arm["remove_bcfeature"])

        if param.allow_corvars and param.allow_random:
            # Adding a correlation needs at least two eligible random variables.
            n_cor_eligible = sum(1 for v in randvars if v not in bcvars)
            if n_cor_eligible >= 2:
                available.add(arm["add_corfeature"])
            if any(v not in param.ps_corvars for v in sol.get("corvars", [])):
                available.add(arm["remove_corfeature"])

        if param.avail_models is not None and len(param.avail_models) > 1:
            available.add(arm["change_model"])

        return sorted(available)

    def _compute_bandit_reward(self, old_obj, new_obj, accepted, converged):
        """Map one SA step outcome to a bandit reward in ``[-1, 1]``.

        * not converged          -> ``-1.0``
        * converged but rejected -> ``-0.2``
        * accepted, improved     -> ``min(1, 0.1 + 10 * relative gain)``
        * accepted, not improved -> ``max(-1, 10 * relative change)``
        """
        if not converged:
            return -1.0
        if not accepted:
            return -0.2

        sign = self.param.sign_crit(0)
        # Positive value means objective improved under current criterion direction.
        delta = sign * (new_obj - old_obj)
        scale = max(abs(old_obj), 1.0)
        normalized = delta / scale

        # A NaN objective gives no usable signal; score it as a full failure
        # explicitly instead of relying on how max() orders NaN comparisons.
        if np.isnan(normalized):
            return -1.0
        if normalized > 0:
            return min(1.0, 0.1 + 10.0 * normalized)
        return max(-1.0, 10.0 * normalized)

    def _apply_selected_action(self, sol):
        """Try bandit-chosen perturbations until the setup actually changes.

        Returns:
            tuple: ``(solution, arm_idx, changed)``. On success ``solution``
            is the perturbed copy. Otherwise it is ``sol`` itself, together
            with the last arm tried (``None`` if no arm was available).
        """
        baseline_sig = self.setup_signature(sol)
        max_attempts = 12
        last_arm_idx = None

        for _ in range(max_attempts):
            candidate = self.copy_solution(sol)
            available = self._available_arm_indices(candidate)
            if not available:
                return sol, None, False

            arm_idx = self.bandit.select_arm(available)
            last_arm_idx = arm_idx
            _, action = self._actions[arm_idx]
            # Actions may mutate in place (returning None) or return a new solution.
            result = action(candidate)
            if result is not None:
                candidate = result

            candidate = self.apply_constraints(candidate)
            candidate = self.repair_solution_for_clarity(candidate)
            if self.setup_signature(candidate) != baseline_sig:
                return candidate, arm_idx, True

        return sol, last_arm_idx, False

    def _record_bandit_step(self, arm_idx, reward, accepted, converged, **extra):
        """Update the bandit with ``reward`` and append one history entry."""
        self.bandit.update(arm_idx, reward)
        entry = {
            "step": int(self.step),
            "arm": self._actions[arm_idx][0],
            "accepted": bool(accepted),
            "converged": bool(converged),
            "reward": float(reward),
        }
        entry.update(extra)
        self.bandit_history.append(entry)

    def perturb_solution(self, sol):
        """Run one bandit-guided SA move and return the current solution.

        The chosen arm is rewarded according to whether the move changed the
        setup, converged on evaluation, and was accepted by the annealer.
        """
        curr_score = [sol.obj(i) for i in range(self.nb_crit)]

        new_sol, arm_idx, changed = self._apply_selected_action(sol)
        if not changed:
            # Penalise the last arm for producing no structural change.
            if arm_idx is not None:
                self._record_bandit_step(arm_idx, -0.5, False, False, reason="no_change")
            return self.current_sol

        new_sol, converged = self.evaluate(new_sol)
        if not converged:
            self.not_converged += 1
            reward = self._compute_bandit_reward(curr_score[0], curr_score[0], False, False)
            self._record_bandit_step(arm_idx, reward, False, False, reason="non_converged")
            return self.current_sol

        new_score = [new_sol.obj(i) for i in range(self.nb_crit)]
        accepted = bool(self.accept_change(curr_score, new_score))

        if accepted:
            self.no_impr = 0
            self.accepted += 1
            self.current_sol = new_sol
            self.update_best(new_sol)
        else:
            self.not_accepted += 1

        reward = self._compute_bandit_reward(curr_score[0], new_score[0], accepted, converged)
        self._record_bandit_step(
            arm_idx,
            reward,
            accepted,
            converged,
            old_obj=float(curr_score[0]),
            new_obj=float(new_score[0]),
        )

        self.log_kpi(new_sol, self.debug_file, accepted)
        return self.current_sol

    def finalise(self):
        """Finalise the run, then report the per-arm bandit statistics.

        The summary is printed to stdout and, on a best-effort basis, also
        written to ``self.results_file``.
        """
        super().finalise()

        lines = ["Bandit arm summary (sorted by posterior mean):"]
        lines.extend(
            " {arm:<20s} mean={mean:.4f} alpha={alpha:.2f} beta={beta:.2f} count={count}".format(
                **row
            )
            for row in self.bandit.summary()
        )
        report = "\n".join(lines)
        print(report)

        # print(file=None) falls back to stdout, which would duplicate the
        # summary, so a missing/None results_file is skipped explicitly.
        results_file = getattr(self, "results_file", None)
        if results_file is None:
            logging.debug("Unable to write bandit summary to results_file")
            return
        try:
            print(report, file=results_file)
        except Exception:
            # Deliberately best-effort: a closed or broken results file must
            # not abort finalisation; keep the cause in the debug log.
            logging.debug("Unable to write bandit summary to results_file", exc_info=True)
|
|
@@ -21,11 +21,15 @@
|
|
|
21
21
|
try:
|
|
22
22
|
from harmony import*
|
|
23
23
|
from siman import*
|
|
24
|
+
from banditsa import*
|
|
24
25
|
from threshold import*
|
|
26
|
+
from sapbil import SAPBIL, ProbabilityMatrix
|
|
25
27
|
except ImportError:
|
|
26
28
|
from .harmony import*
|
|
27
29
|
from .siman import*
|
|
30
|
+
from .banditsa import*
|
|
28
31
|
from .threshold import*
|
|
32
|
+
from .sapbil import SAPBIL, ProbabilityMatrix
|
|
29
33
|
|
|
30
34
|
import numpy as np
|
|
31
35
|
|
|
@@ -301,6 +305,99 @@ def call_siman(parameters, init_sol=None, ctrl=None, **kwargs):
|
|
|
301
305
|
return best
|
|
302
306
|
|
|
303
307
|
|
|
308
|
+
def call_sapbil(parameters, init_sol=None, ctrl=None, **kwargs):
    """
    Run SA+PBIL (Simulated Annealing coupled with Population-Based Incremental
    Learning) search.

    Parameters
    ----------
    parameters : Parameters
        Problem definition (variables, data, criteria, models).
    init_sol : Solution, optional
        Warm-start solution. None = generate automatically.
    ctrl : tuple, optional
        ``(tI, tF, max_temp_steps, max_iter)``.
        If omitted the values are estimated from the problem size.
    **kwargs
        ``id_num`` — run identifier (int, used in log file names).
        Any other kwargs are forwarded to the SAPBIL constructor.

    Returns
    -------
    Solution
        Best converged, all-significant solution found.
    """
    # ``ctrl`` is a named parameter, so it can never arrive through kwargs;
    # only ``id_num`` has to be separated from the constructor kwargs.
    id_num = kwargs.pop("id_num", None)

    if ctrl is None:
        ctrl = estimate_ctrl(parameters, algorithm="sa")
        print(
            f"[SA+PBIL] Auto-estimated hyperparameters (problem complexity "
            f"= {_problem_size(parameters)['complexity']}):"
        )
    else:
        print("[SA+PBIL] Using provided hyperparameters:")

    print(_describe_ctrl(ctrl, "sa"))
    print()

    solver = SAPBIL(parameters, init_sol, ctrl, id_num, **kwargs)
    try:
        solver.run()
    finally:
        # Always call close_files(), even when run() raises.
        solver.close_files()
    best = solver.return_best()
    _print_dashboard(solver, best, algorithm="SA+PBIL")
    return best
|
|
354
|
+
|
|
355
|
+
|
|
356
|
+
def call_banditsa(parameters, init_sol=None, ctrl=None, **kwargs):
    """
    Run Bandit-guided Simulated Annealing (Thompson Sampling on perturbation arms).

    Parameters
    ----------
    parameters : Parameters
        Problem definition (variables, data, criteria, models).
    init_sol : Solution, optional
        Warm-start solution. None = generate automatically.
    ctrl : tuple, optional
        ``(tI, tF, max_temp_steps, max_iter)``.
        If omitted the values are estimated from the problem size.
    **kwargs
        ``id_num`` - run identifier (int, used in log file names).
        Any other kwargs are forwarded to the BanditSA constructor.

    Returns
    -------
    Solution
        Best converged, all-significant solution found.
    """
    # ``ctrl`` is a named parameter, so it can never arrive through kwargs;
    # only ``id_num`` has to be separated from the constructor kwargs.
    id_num = kwargs.pop('id_num', None)

    if ctrl is None:
        ctrl = estimate_ctrl(parameters, algorithm='sa')
        print(f"[BanditSA] Auto-estimated hyperparameters (problem complexity "
              f"= {_problem_size(parameters)['complexity']}):")
    else:
        print("[BanditSA] Using provided hyperparameters:")

    print(_describe_ctrl(ctrl, 'sa'))
    print()

    solver = BanditSA(parameters, init_sol, ctrl, id_num, **kwargs)
    try:
        solver.run()
    finally:
        # Always call close_files(), even when run() raises.
        solver.close_files()
    best = solver.return_best()
    _print_dashboard(solver, best, algorithm='BanditSA')
    return best
|
|
399
|
+
|
|
400
|
+
|
|
304
401
|
# ─────────────────────────────────────────────────────────────────────────────
|
|
305
402
|
# Harmony Search
|
|
306
403
|
# ─────────────────────────────────────────────────────────────────────────────
|
|
@@ -364,8 +461,9 @@ def call_search(parameters, init_sol=None, algorithm='sa', ctrl=None, **kwargs):
|
|
|
364
461
|
Problem definition.
|
|
365
462
|
init_sol : Solution, optional
|
|
366
463
|
Warm-start solution.
|
|
367
|
-
algorithm : {'sa', 'hs'}
|
|
464
|
+
algorithm : {'sa', 'banditsa', 'hs'}
|
|
368
465
|
``'sa'`` — Simulated Annealing (default)
|
|
466
|
+
``'banditsa'`` — Bandit-guided Simulated Annealing
|
|
369
467
|
``'hs'`` — Harmony Search
|
|
370
468
|
ctrl : tuple, optional
|
|
371
469
|
Algorithm-specific control tuple. Auto-estimated if omitted.
|
|
@@ -383,6 +481,7 @@ def call_search(parameters, init_sol=None, algorithm='sa', ctrl=None, **kwargs):
|
|
|
383
481
|
Examples
|
|
384
482
|
--------
|
|
385
483
|
>>> best = call_search(params) # SA, auto ctrl
|
|
484
|
+
>>> best = call_search(params, algorithm='banditsa') # BanditSA, auto ctrl
|
|
386
485
|
>>> best = call_search(params, algorithm='hs') # HS, auto ctrl
|
|
387
486
|
>>> best = call_search(params, ctrl=(500,0.001,80,15))# SA, manual ctrl
|
|
388
487
|
>>> best = call_search(params, algorithm='hs',
|
|
@@ -391,12 +490,16 @@ def call_search(parameters, init_sol=None, algorithm='sa', ctrl=None, **kwargs):
|
|
|
391
490
|
algorithm = algorithm.lower().strip()
|
|
392
491
|
if algorithm in ('sa', 'siman', 'simulated_annealing'):
|
|
393
492
|
return call_siman(parameters, init_sol=init_sol, ctrl=ctrl, **kwargs)
|
|
493
|
+
elif algorithm in ('sapbil', 'sa_pbil', 'sa+pbil', 'pbil'):
|
|
494
|
+
return call_sapbil(parameters, init_sol=init_sol, ctrl=ctrl, **kwargs)
|
|
495
|
+
elif algorithm in ('banditsa', 'bandit_sa', 'bandit-simulated-annealing', 'bsa'):
|
|
496
|
+
return call_banditsa(parameters, init_sol=init_sol, ctrl=ctrl, **kwargs)
|
|
394
497
|
elif algorithm in ('hs', 'harmony', 'harmony_search'):
|
|
395
498
|
return call_harmony(parameters, init_sol=init_sol, ctrl=ctrl, **kwargs)
|
|
396
499
|
else:
|
|
397
500
|
raise ValueError(
|
|
398
501
|
f"Unknown algorithm '{algorithm}'. "
|
|
399
|
-
f"Choose 'sa'
|
|
502
|
+
f"Choose 'sa', 'sapbil', 'banditsa', or 'hs'."
|
|
400
503
|
)
|
|
401
504
|
|
|
402
505
|
|