SearchLibrium 0.0.87 (tar.gz) → 0.0.89 (tar.gz)

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions exactly as they were published in their respective public registries.
Files changed (49)
  1. {searchlibrium-0.0.87 → searchlibrium-0.0.89}/PKG-INFO +1 -1
  2. {searchlibrium-0.0.87 → searchlibrium-0.0.89}/pyproject.toml +1 -1
  3. {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/Halton.py +21 -6
  4. {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/MixedLogit.py +25 -9
  5. {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/__init__.py +9 -2
  6. {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/_choice_model.py +0 -1
  7. searchlibrium-0.0.89/src/SearchLibrium/banditsa.py +300 -0
  8. {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/call_meta.py +105 -2
  9. searchlibrium-0.0.89/src/SearchLibrium/sapbil.py +698 -0
  10. {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/search.py +35 -5
  11. {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/siman.py +13 -1
  12. searchlibrium-0.0.89/src/SearchLibrium/test_sapbil_vs_banditsa.py +246 -0
  13. searchlibrium-0.0.89/src/SearchLibrium/version.txt +1 -0
  14. {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium.egg-info/PKG-INFO +1 -1
  15. {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium.egg-info/SOURCES.txt +3 -0
  16. {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium.egg-info/top_level.txt +1 -0
  17. searchlibrium-0.0.87/src/SearchLibrium/version.txt +0 -1
  18. {searchlibrium-0.0.87 → searchlibrium-0.0.89}/README.md +0 -0
  19. {searchlibrium-0.0.87 → searchlibrium-0.0.89}/setup.cfg +0 -0
  20. {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/Mode_Activity_Nested.py +0 -0
  21. {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/RandomP.py +0 -0
  22. {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/SEARCH_SM_MARIO.py +0 -0
  23. {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/Two_Level_Nest.py +0 -0
  24. {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/__main__.py +0 -0
  25. {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/_device.py +0 -0
  26. {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/bhhh/minimize.py +0 -0
  27. {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/boxcox_functions.py +0 -0
  28. {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/constraints_builder.py +0 -0
  29. {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/harmony.py +0 -0
  30. {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/latent_class.py +0 -0
  31. {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/main.py +0 -0
  32. {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/main_debug.py +0 -0
  33. {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/mdcev.py +0 -0
  34. {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/misc.py +0 -0
  35. {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/mixed_logit.py +0 -0
  36. {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/mixed_nested.py +0 -0
  37. {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/mixedrrm.py +0 -0
  38. {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/multinomial_logit.py +0 -0
  39. {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/multinomial_nested.py +0 -0
  40. {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/multinomial_probit.py +0 -0
  41. {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/ordered_logit.py +0 -0
  42. {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/ordered_logit_mixed.py +0 -0
  43. {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/rrm.py +0 -0
  44. {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/selection_models.py +0 -0
  45. {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/setup.py +0 -0
  46. {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium/threshold.py +0 -0
  47. {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium.egg-info/dependency_links.txt +0 -0
  48. {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium.egg-info/entry_points.txt +0 -0
  49. {searchlibrium-0.0.87 → searchlibrium-0.0.89}/src/SearchLibrium.egg-info/requires.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: SearchLibrium
3
- Version: 0.0.87
3
+ Version: 0.0.89
4
4
  Summary: A Python package for econometric models driven by search
5
5
  Author: Alexander Paz Prithvi Beeramole, Robert Burdett
6
6
  Author-email: Zeke Ahern <z.ahern@qut.edu.au>
@@ -59,7 +59,7 @@ Homepage = "https://github.com/zahern/HypothesisX"
59
59
  realpython = "SearchLibrium.__main__:main"
60
60
 
61
61
  [tool.bumpver]
62
- current_version = "0.0.87"
62
+ current_version = "0.0.89"
63
63
  version_pattern = "MAJOR.MINOR.PATCH"
64
64
  commit_message = "[skip ci] Bump version {old_version} -> {new_version}"
65
65
  commit = true
@@ -67,7 +67,7 @@ import scipy.stats as ss
67
67
  class Halton:
68
68
  """Class for generating Halton sequences and Halton-based draws."""
69
69
 
70
- def __init__(self, primes=None, drop=100, shuffled=False):
70
+ def __init__(self, primes=None, drop=100, shuffled=False, antithetic=False):
71
71
  self.primes = primes or [2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37, 41, 43, 47,
72
72
  53, 59, 61, 71, 73, 79, 83, 89, 97, 101, 103, 107, 109,
73
73
  113, 127, 131, 137, 139, 149, 151, 157, 163, 167, 173,
@@ -76,15 +76,24 @@ class Halton:
76
76
  307, 311]
77
77
  self.drop = drop
78
78
  self.shuffled = shuffled
79
+ self.antithetic = antithetic
79
80
 
80
81
  def generate_draws(self, sample_size, n_draws, n_vars):
81
- """Generate Halton draws for multiple variables using different primes."""
82
+ """Generate Halton draws for multiple variables using different primes.
83
+
84
+ When ``antithetic=True``, draws of size ``n_draws // 2`` are generated
85
+ then mirrored (1 - u) to produce negatively-correlated antithetic pairs.
86
+ For normal-based distributions this halves variance at zero extra cost.
87
+ """
88
+ base = n_draws // 2 if self.antithetic else n_draws
82
89
  draws = [
83
- self.halton_seq(sample_size * n_draws, prime=self.primes[i % len(self.primes)]).reshape(sample_size,
84
- n_draws)
90
+ self.halton_seq(sample_size * base, prime=self.primes[i % len(self.primes)]).reshape(sample_size, base)
85
91
  for i in range(n_vars)
86
92
  ]
87
- return np.stack(draws, axis=1) # (N, Kr, R)
93
+ draws = np.stack(draws, axis=1) # (N, Kr, R_base)
94
+ if self.antithetic:
95
+ draws = np.concatenate([draws, 1.0 - draws], axis=2) # (N, Kr, n_draws)
96
+ return draws
88
97
 
89
98
  def halton_seq(self, length, prime):
90
99
  """Generates a scrambled Halton sequence for a given prime number."""
@@ -159,11 +168,17 @@ class Draws:
159
168
  draws = np.atleast_3d(draws)
160
169
  return draws
161
170
 
171
+ # Clip uniform draws away from 0/1 to prevent norm.ppf -> ±inf.
172
+ # This matters especially for antithetic draws where the complement of a
173
+ # small Halton value is very close to 1.
174
+ _PPF_CLIP = 1e-10
175
+
162
176
  def evaluate_distribution(self, distr, values):
163
177
  """Transform uniform values to the specified distribution."""
164
178
  for k, distr_k in enumerate(distr):
165
179
  if distr_k in ['n', 'ln', 'tn']: # Normal-based
166
- values[:, k, :] = ss.norm.ppf(values[:, k, :])
180
+ u = np.clip(values[:, k, :], self._PPF_CLIP, 1.0 - self._PPF_CLIP)
181
+ values[:, k, :] = ss.norm.ppf(u)
167
182
  elif distr_k == 't': # Triangular
168
183
  values_k = values[:, k, :]
169
184
  values[:, k, :] = (np.sqrt(2 * values_k) - 1) * (values_k <= .5) + \
@@ -51,9 +51,11 @@ class MixedLogit(DiscreteChoiceModel):
51
51
  n_vars (int): Number of variables.
52
52
  """
53
53
  if n_vars == 0:
54
- return np.ndarray((1,0,1))
55
- draws_s = Draws(k=n_vars, halton_opts=None)
56
- draws = draws_s.generate_draws(sample_size, n_draws, n_vars)
54
+ return np.ndarray((1, 0, 1))
55
+ # Use self.halton_opts so options like antithetic/shuffled are respected.
56
+ opts = self.halton_opts or {}
57
+ draws_s = Draws(k=n_vars, halton_opts=opts)
58
+ draws = draws_s.generate_draws(sample_size, n_draws)
57
59
  return draws
58
60
 
59
61
 
@@ -210,7 +212,6 @@ class MixedLogit(DiscreteChoiceModel):
210
212
  means_1 = np.mean(self.y, axis=3) # means_1[i,j] = avg(y[i,j,:])
211
213
  means_2 = np.mean(means_1, axis=1) # means_2[i] = avg(means_1[i,:])
212
214
  self.obs_prob = np.mean(means_2, axis=0) # obs_prob = avg(means_2[:])
213
- print(f'observed probs debug{self.obs_prob}')
214
215
  # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
215
216
  # DEFINE MEMBER FUNCTIONS TO APPLY
216
217
  # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
@@ -230,8 +231,6 @@ class MixedLogit(DiscreteChoiceModel):
230
231
  #print(f"self.rvtrans: {self.rvtransdist}")
231
232
  self.rvdist = [item for item in self.rvdist if item is not False]
232
233
  self.rvtransdist = [item for item in self.rvtransdist if item is not False]
233
- print(f"self.randvars: {self.rvdist}")
234
- print(f"self.rvtrans: {self.rvtransdist}")
235
234
  draws = self.generate_draws(self.N, self.n_draws, len(self.rvdist))
236
235
  drawstrans = self.generate_draws(self.N, self.n_draws, len(self.rvtransdist))
237
236
  self.draws, self.drawstrans = draws, drawstrans # Record generated values
@@ -282,13 +281,30 @@ class MixedLogit(DiscreteChoiceModel):
282
281
  # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
283
282
 
284
283
  arr = self.init_coeff[:lower]
285
- rep = np.repeat(0.1, self.Kchol + self.Kbw) # Array of 0.1s
284
+
285
+ # Better scale initialisation: use half the absolute value of the MNL
286
+ # mean estimates for the corresponding random variable, floored at 0.05.
287
+ # This gives BFGS a much better starting point than a flat 0.1 when
288
+ # coefficients are very small (e.g. cost in money units) or very large.
289
+ br_means = arr[self.Kf + 2 * self.Kftrans: self.Kf + 2 * self.Kftrans + self.Kr]
290
+
291
+ # Cholesky elements (correlated random vars) — keep at 0.1; the diagonal
292
+ # scaling is embedded in off-diagonal terms and is harder to infer.
293
+ chol_init = np.repeat(0.1, self.Kchol)
294
+
295
+ # Bandwidth (std dev) for non-correlated random vars.
296
+ bw_means = br_means[self.correlationLength:]
297
+ bw_init = np.maximum(np.abs(bw_means) * 0.5, 0.05)
298
+
299
+ rep = np.concatenate([chol_init, bw_init])
286
300
  self.init_coeff = np.concatenate((arr, rep, self.init_coeff[lower:upper],))
287
301
 
288
302
  if self.Krtrans: # CHECK ">0"
289
303
  # {
290
- rep = np.repeat(0.1, self.Krtrans) # An array with 0.1 repeated Krtrans times
291
- self.init_coeff = np.concatenate((self.init_coeff, rep, self.init_coeff[-self.Krtrans:]))
304
+ # Similarly scale random-transformed std devs from MNL means.
305
+ rtrans_means = self.init_coeff[lower:upper]
306
+ rtrans_scale_init = np.maximum(np.abs(rtrans_means) * 0.5, 0.05)
307
+ self.init_coeff = np.concatenate((self.init_coeff, rtrans_scale_init, self.init_coeff[-self.Krtrans:]))
292
308
  # }
293
309
  # }
294
310
 
@@ -12,6 +12,11 @@ from addicty import Dict
12
12
 
13
13
  import os
14
14
 
15
+ try:
16
+ from banditsa import BanditSA, PerturbationBandit
17
+ except ImportError:
18
+ from .banditsa import BanditSA, PerturbationBandit
19
+
15
20
  def new_features():
16
21
  '''ADDICTY DICT'''
17
22
  '''vars you want to ensure are included'''
@@ -103,7 +108,8 @@ try:
103
108
  from .search import Parameters
104
109
 
105
110
  from . import misc
106
- from .call_meta import call_harmony, call_siman, call_parsa, call_search, estimate_ctrl
111
+ from .sapbil import SAPBIL, ProbabilityMatrix
112
+ from .call_meta import call_harmony, call_siman, call_parsa, call_search, call_sapbil, estimate_ctrl
107
113
 
108
114
  except ImportError as e:
109
115
  from _choice_model import DiscreteChoiceModel
@@ -121,7 +127,8 @@ except ImportError as e:
121
127
  from RandomP import RandomParameters
122
128
  from constraints_builder import ConstraintBuilder, create_constraints
123
129
  from search import Parameters
124
- from call_meta import call_siman, call_harmony, call_search, estimate_ctrl
130
+ from sapbil import SAPBIL, ProbabilityMatrix
131
+ from call_meta import call_siman, call_harmony, call_search, call_sapbil, estimate_ctrl
125
132
  try:
126
133
  from .main import print_ascii_art_logo
127
134
  except Exception:
@@ -391,7 +391,6 @@ class DiscreteChoiceModel(ABC):
391
391
  # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
392
392
  non_sigs = self.num_of_exceeding_pvalues(self.pvalues, 0.0)
393
393
  #print('log like is before', self.loglik)
394
- print('apply p val')
395
394
  self.loglik -= non_sigs*self.pval_penalty # penalise the non-sigs
396
395
  logging.info('applying pval')
397
396
  #print('log like is', self.loglik)
@@ -0,0 +1,300 @@
1
+ """Bandit-guided Simulated Annealing for SearchLibrium.
2
+
3
+ This module provides:
4
+ - PerturbationBandit: Thompson-sampling policy over perturbation action types
5
+ - BanditSA: SA variant that uses the bandit to pick perturbation actions
6
+ """
7
+
8
+ import logging
9
+
10
+ import numpy as np
11
+
12
+ try:
13
+ from siman import SA
14
+ except ImportError:
15
+ from .siman import SA
16
+
17
+
18
+ class PerturbationBandit:
19
+ """Thompson Sampling over perturbation action-type arms."""
20
+
21
+ def __init__(self, arm_names, prior_alpha=1.0, prior_beta=1.0, epsilon=0.05, rng=None):
22
+ self.arm_names = list(arm_names)
23
+ self.n_arms = len(self.arm_names)
24
+ self.alpha = np.full(self.n_arms, float(prior_alpha), dtype=float)
25
+ self.beta = np.full(self.n_arms, float(prior_beta), dtype=float)
26
+ self.counts = np.zeros(self.n_arms, dtype=int)
27
+ self.epsilon = float(epsilon)
28
+ self.rng = rng if rng is not None else np.random.default_rng()
29
+
30
+ def select_arm(self, available_indices):
31
+ if not available_indices:
32
+ raise ValueError("No available bandit arms to select from.")
33
+
34
+ if self.rng.random() < self.epsilon:
35
+ return int(self.rng.choice(available_indices))
36
+
37
+ samples = {
38
+ idx: self.rng.beta(self.alpha[idx], self.beta[idx])
39
+ for idx in available_indices
40
+ }
41
+ return max(samples, key=samples.get)
42
+
43
+ def update(self, arm_index, reward):
44
+ reward = float(reward)
45
+ self.counts[arm_index] += 1
46
+
47
+ if reward >= 0.0:
48
+ self.alpha[arm_index] += 1.0 + reward
49
+ else:
50
+ self.beta[arm_index] += 1.0 + abs(reward)
51
+
52
+ def summary(self):
53
+ rows = []
54
+ total = np.maximum(self.alpha + self.beta, 1e-12)
55
+ means = self.alpha / total
56
+ for idx, name in enumerate(self.arm_names):
57
+ rows.append(
58
+ {
59
+ "arm": name,
60
+ "alpha": float(self.alpha[idx]),
61
+ "beta": float(self.beta[idx]),
62
+ "mean": float(means[idx]),
63
+ "count": int(self.counts[idx]),
64
+ }
65
+ )
66
+ rows.sort(key=lambda row: row["mean"], reverse=True)
67
+ return rows
68
+
69
+
70
+ class BanditSA(SA):
71
+ """SA variant with Thompson-sampling perturbation selection."""
72
+
73
+ def __init__(
74
+ self,
75
+ param,
76
+ init_sol,
77
+ ctrl,
78
+ idnum=0,
79
+ bandit_prior_alpha=1.0,
80
+ bandit_prior_beta=1.0,
81
+ bandit_epsilon=0.05,
82
+ **kwargs,
83
+ ):
84
+ super().__init__(param, init_sol, ctrl, idnum=idnum, **kwargs)
85
+
86
+ self._actions = self._build_action_table()
87
+ self.bandit = PerturbationBandit(
88
+ arm_names=[name for name, _ in self._actions],
89
+ prior_alpha=bandit_prior_alpha,
90
+ prior_beta=bandit_prior_beta,
91
+ epsilon=bandit_epsilon,
92
+ )
93
+ self.bandit_history = []
94
+
95
+ def _build_action_table(self):
96
+ return [
97
+ ("add_asfeature", self.perturb_add_asfeature),
98
+ ("remove_asfeature", self.perturb_remove_asfeature),
99
+ ("add_isfeature", self.perturb_add_isfeature),
100
+ ("remove_isfeature", self.perturb_remove_isfeature),
101
+ ("add_randfeature", self.perturb_add_randfeature),
102
+ ("remove_randfeature", self.perturb_remove_randfeature),
103
+ ("add_bcfeature", self.perturb_add_bcfeature),
104
+ ("remove_bcfeature", self.perturb_remove_bcfeature),
105
+ ("add_corfeature", self.perturb_add_corfeature),
106
+ ("remove_corfeature", self.perturb_remove_corfeature),
107
+ ("change_model", self.perturb_model_t),
108
+ ("change_distribution", self.perturb_distribution),
109
+ ]
110
+
111
+ def _available_arm_indices(self, sol):
112
+ available = []
113
+
114
+ if self.param.asvarnames and len(sol.get("asvars", [])) < len(self.param.asvarnames):
115
+ available.append(0) # add_asfeature
116
+ if len(sol.get("asvars", [])) > 1:
117
+ available.append(1) # remove_asfeature
118
+
119
+ if self.param.isvarnames:
120
+ is_candidates = [v for v in self.param.isvarnames if v not in sol.get("isvars", [])]
121
+ if is_candidates:
122
+ available.append(2) # add_isfeature
123
+ if sol.get("isvars"):
124
+ available.append(3) # remove_isfeature
125
+
126
+ if self.param.allow_random:
127
+ add_r_candidates = [
128
+ v for v in sol.get("asvars", []) if v not in sol.get("randvars", {})
129
+ ]
130
+ if add_r_candidates:
131
+ available.append(4) # add_randfeature
132
+
133
+ rem_r_candidates = [
134
+ v for v in sol.get("randvars", {}) if v not in self.param.ps_randvars
135
+ ]
136
+ if rem_r_candidates:
137
+ available.append(5) # remove_randfeature
138
+
139
+ dist_candidates = [
140
+ v for v in sol.get("randvars", {}) if v not in self.param.ps_randvars
141
+ ]
142
+ if dist_candidates and len(getattr(self.param, "distr", []) or []) > 1:
143
+ available.append(11) # change_distribution
144
+
145
+ if self.param.allow_bcvars:
146
+ add_bc_candidates = [
147
+ v
148
+ for v in sol.get("asvars", [])
149
+ if v not in sol.get("bcvars", []) and v not in self.param.ps_corvars
150
+ ]
151
+ if add_bc_candidates:
152
+ available.append(6) # add_bcfeature
153
+
154
+ rem_bc_candidates = [
155
+ v for v in sol.get("bcvars", []) if v not in self.param.ps_bcvars
156
+ ]
157
+ if rem_bc_candidates:
158
+ available.append(7) # remove_bcfeature
159
+
160
+ if self.param.allow_corvars and self.param.allow_random:
161
+ cor_eligible = [
162
+ v for v in sol.get("randvars", {}) if v not in sol.get("bcvars", [])
163
+ ]
164
+ if len(cor_eligible) >= 2:
165
+ available.append(8) # add_corfeature
166
+
167
+ rem_cor_candidates = [
168
+ v for v in sol.get("corvars", []) if v not in self.param.ps_corvars
169
+ ]
170
+ if rem_cor_candidates:
171
+ available.append(9) # remove_corfeature
172
+
173
+ if self.param.avail_models is not None and len(self.param.avail_models) > 1:
174
+ available.append(10) # change_model
175
+
176
+ return sorted(set(available))
177
+
178
+ def _compute_bandit_reward(self, old_obj, new_obj, accepted, converged):
179
+ if not converged:
180
+ return -1.0
181
+
182
+ sign = self.param.sign_crit(0)
183
+ # Positive value means objective improved under current criterion direction.
184
+ delta = sign * (new_obj - old_obj)
185
+ scale = max(abs(old_obj), 1.0)
186
+ normalized = delta / scale
187
+
188
+ if accepted:
189
+ if normalized > 0:
190
+ return min(1.0, 0.1 + 10.0 * normalized)
191
+ return max(-1.0, 10.0 * normalized)
192
+
193
+ return -0.2
194
+
195
+ def _apply_selected_action(self, sol):
196
+ baseline_sig = self.setup_signature(sol)
197
+ max_attempts = 12
198
+ last_arm_idx = None
199
+
200
+ for _ in range(max_attempts):
201
+ candidate = self.copy_solution(sol)
202
+ available = self._available_arm_indices(candidate)
203
+ if not available:
204
+ return sol, None, False
205
+
206
+ arm_idx = self.bandit.select_arm(available)
207
+ last_arm_idx = arm_idx
208
+ _, action = self._actions[arm_idx]
209
+ result = action(candidate)
210
+ if result is not None:
211
+ candidate = result
212
+
213
+ candidate = self.apply_constraints(candidate)
214
+ candidate = self.repair_solution_for_clarity(candidate)
215
+ if self.setup_signature(candidate) != baseline_sig:
216
+ return candidate, arm_idx, True
217
+
218
+ return sol, last_arm_idx, False
219
+
220
+ def perturb_solution(self, sol):
221
+ curr_score = [sol.obj(i) for i in range(self.nb_crit)]
222
+
223
+ new_sol, arm_idx, changed = self._apply_selected_action(sol)
224
+ if not changed:
225
+ if arm_idx is not None:
226
+ self.bandit.update(arm_idx, -0.5)
227
+ self.bandit_history.append(
228
+ {
229
+ "step": int(self.step),
230
+ "arm": self._actions[arm_idx][0],
231
+ "accepted": False,
232
+ "converged": False,
233
+ "reward": -0.5,
234
+ "reason": "no_change",
235
+ }
236
+ )
237
+ return self.current_sol
238
+
239
+ new_sol, converged = self.evaluate(new_sol)
240
+ if not converged:
241
+ self.not_converged += 1
242
+ reward = self._compute_bandit_reward(curr_score[0], curr_score[0], False, False)
243
+ self.bandit.update(arm_idx, reward)
244
+ self.bandit_history.append(
245
+ {
246
+ "step": int(self.step),
247
+ "arm": self._actions[arm_idx][0],
248
+ "accepted": False,
249
+ "converged": False,
250
+ "reward": float(reward),
251
+ "reason": "non_converged",
252
+ }
253
+ )
254
+ return self.current_sol
255
+
256
+ new_score = [new_sol.obj(i) for i in range(self.nb_crit)]
257
+ accepted = bool(self.accept_change(curr_score, new_score))
258
+
259
+ if accepted:
260
+ self.no_impr = 0
261
+ self.accepted += 1
262
+ self.current_sol = new_sol
263
+ self.update_best(new_sol)
264
+ else:
265
+ self.not_accepted += 1
266
+
267
+ reward = self._compute_bandit_reward(curr_score[0], new_score[0], accepted, converged)
268
+ self.bandit.update(arm_idx, reward)
269
+ self.bandit_history.append(
270
+ {
271
+ "step": int(self.step),
272
+ "arm": self._actions[arm_idx][0],
273
+ "accepted": bool(accepted),
274
+ "converged": bool(converged),
275
+ "reward": float(reward),
276
+ "old_obj": float(curr_score[0]),
277
+ "new_obj": float(new_score[0]),
278
+ }
279
+ )
280
+
281
+ self.log_kpi(new_sol, self.debug_file, accepted)
282
+ return self.current_sol
283
+
284
+ def finalise(self):
285
+ super().finalise()
286
+
287
+ lines = ["Bandit arm summary (sorted by posterior mean):"]
288
+ for row in self.bandit.summary():
289
+ lines.append(
290
+ " {arm:<20s} mean={mean:.4f} alpha={alpha:.2f} beta={beta:.2f} count={count}".format(
291
+ **row
292
+ )
293
+ )
294
+
295
+ for line in lines:
296
+ print(line)
297
+ try:
298
+ print(line, file=self.results_file)
299
+ except Exception:
300
+ logging.debug("Unable to write bandit summary to results_file")
@@ -21,11 +21,15 @@
21
21
  try:
22
22
  from harmony import*
23
23
  from siman import*
24
+ from banditsa import*
24
25
  from threshold import*
26
+ from sapbil import SAPBIL, ProbabilityMatrix
25
27
  except ImportError:
26
28
  from .harmony import*
27
29
  from .siman import*
30
+ from .banditsa import*
28
31
  from .threshold import*
32
+ from .sapbil import SAPBIL, ProbabilityMatrix
29
33
 
30
34
  import numpy as np
31
35
 
@@ -301,6 +305,99 @@ def call_siman(parameters, init_sol=None, ctrl=None, **kwargs):
301
305
  return best
302
306
 
303
307
 
308
+ def call_sapbil(parameters, init_sol=None, ctrl=None, **kwargs):
309
+ """
310
+ Run SA+PBIL (Simulated Annealing coupled with Population-Based Incremental
311
+ Learning) search.
312
+
313
+ Parameters
314
+ ----------
315
+ parameters : Parameters
316
+ Problem definition (variables, data, criteria, models).
317
+ init_sol : Solution, optional
318
+ Warm-start solution. None = generate automatically.
319
+ ctrl : tuple, optional
320
+ ``(tI, tF, max_temp_steps, max_iter)``.
321
+ If omitted the values are estimated from the problem size.
322
+ **kwargs
323
+ ``id_num`` — run identifier (int, used in log file names).
324
+ Any other kwargs are forwarded to the SAPBIL constructor.
325
+
326
+ Returns
327
+ -------
328
+ Solution
329
+ Best converged, all-significant solution found.
330
+ """
331
+ if ctrl is None:
332
+ ctrl = kwargs.pop("ctrl", None)
333
+
334
+ id_num = kwargs.pop("id_num", None)
335
+
336
+ if ctrl is None:
337
+ ctrl = estimate_ctrl(parameters, algorithm="sa")
338
+ print(
339
+ f"[SA+PBIL] Auto-estimated hyperparameters (problem complexity "
340
+ f"= {_problem_size(parameters)['complexity']}):"
341
+ )
342
+ else:
343
+ print("[SA+PBIL] Using provided hyperparameters:")
344
+
345
+ print(_describe_ctrl(ctrl, "sa"))
346
+ print()
347
+
348
+ solver = SAPBIL(parameters, init_sol, ctrl, id_num, **kwargs)
349
+ solver.run()
350
+ solver.close_files()
351
+ best = solver.return_best()
352
+ _print_dashboard(solver, best, algorithm="SA+PBIL")
353
+ return best
354
+
355
+
356
+ def call_banditsa(parameters, init_sol=None, ctrl=None, **kwargs):
357
+ """
358
+ Run Bandit-guided Simulated Annealing (Thompson Sampling on perturbation arms).
359
+
360
+ Parameters
361
+ ----------
362
+ parameters : Parameters
363
+ Problem definition (variables, data, criteria, models).
364
+ init_sol : Solution, optional
365
+ Warm-start solution. None = generate automatically.
366
+ ctrl : tuple, optional
367
+ ``(tI, tF, max_temp_steps, max_iter)``.
368
+ If omitted the values are estimated from the problem size.
369
+ **kwargs
370
+ ``id_num`` - run identifier (int, used in log file names).
371
+ Any other kwargs are forwarded to the BanditSA constructor.
372
+
373
+ Returns
374
+ -------
375
+ Solution
376
+ Best converged, all-significant solution found.
377
+ """
378
+ if ctrl is None:
379
+ ctrl = kwargs.pop('ctrl', None)
380
+
381
+ id_num = kwargs.pop('id_num', None)
382
+
383
+ if ctrl is None:
384
+ ctrl = estimate_ctrl(parameters, algorithm='sa')
385
+ print(f"[BanditSA] Auto-estimated hyperparameters (problem complexity "
386
+ f"= {_problem_size(parameters)['complexity']}):")
387
+ else:
388
+ print("[BanditSA] Using provided hyperparameters:")
389
+
390
+ print(_describe_ctrl(ctrl, 'sa'))
391
+ print()
392
+
393
+ solver = BanditSA(parameters, init_sol, ctrl, id_num, **kwargs)
394
+ solver.run()
395
+ solver.close_files()
396
+ best = solver.return_best()
397
+ _print_dashboard(solver, best, algorithm='BanditSA')
398
+ return best
399
+
400
+
304
401
  # ─────────────────────────────────────────────────────────────────────────────
305
402
  # Harmony Search
306
403
  # ─────────────────────────────────────────────────────────────────────────────
@@ -364,8 +461,9 @@ def call_search(parameters, init_sol=None, algorithm='sa', ctrl=None, **kwargs):
364
461
  Problem definition.
365
462
  init_sol : Solution, optional
366
463
  Warm-start solution.
367
- algorithm : {'sa', 'hs'}
464
+ algorithm : {'sa', 'banditsa', 'hs'}
368
465
  ``'sa'`` — Simulated Annealing (default)
466
+ ``'banditsa'`` — Bandit-guided Simulated Annealing
369
467
  ``'hs'`` — Harmony Search
370
468
  ctrl : tuple, optional
371
469
  Algorithm-specific control tuple. Auto-estimated if omitted.
@@ -383,6 +481,7 @@ def call_search(parameters, init_sol=None, algorithm='sa', ctrl=None, **kwargs):
383
481
  Examples
384
482
  --------
385
483
  >>> best = call_search(params) # SA, auto ctrl
484
+ >>> best = call_search(params, algorithm='banditsa') # BanditSA, auto ctrl
386
485
  >>> best = call_search(params, algorithm='hs') # HS, auto ctrl
387
486
  >>> best = call_search(params, ctrl=(500,0.001,80,15))# SA, manual ctrl
388
487
  >>> best = call_search(params, algorithm='hs',
@@ -391,12 +490,16 @@ def call_search(parameters, init_sol=None, algorithm='sa', ctrl=None, **kwargs):
391
490
  algorithm = algorithm.lower().strip()
392
491
  if algorithm in ('sa', 'siman', 'simulated_annealing'):
393
492
  return call_siman(parameters, init_sol=init_sol, ctrl=ctrl, **kwargs)
493
+ elif algorithm in ('sapbil', 'sa_pbil', 'sa+pbil', 'pbil'):
494
+ return call_sapbil(parameters, init_sol=init_sol, ctrl=ctrl, **kwargs)
495
+ elif algorithm in ('banditsa', 'bandit_sa', 'bandit-simulated-annealing', 'bsa'):
496
+ return call_banditsa(parameters, init_sol=init_sol, ctrl=ctrl, **kwargs)
394
497
  elif algorithm in ('hs', 'harmony', 'harmony_search'):
395
498
  return call_harmony(parameters, init_sol=init_sol, ctrl=ctrl, **kwargs)
396
499
  else:
397
500
  raise ValueError(
398
501
  f"Unknown algorithm '{algorithm}'. "
399
- f"Choose 'sa' (Simulated Annealing) or 'hs' (Harmony Search)."
502
+ f"Choose 'sa', 'sapbil', 'banditsa', or 'hs'."
400
503
  )
401
504
 
402
505