badr 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (42)
  1. badr/__base__.py +126 -0
  2. badr/__init__.py +13 -0
  3. badr/_version.py +1 -0
  4. badr/algorithms/__base__.py +90 -0
  5. badr/algorithms/__init__.py +13 -0
  6. badr/algorithms/badr.py +247 -0
  7. badr/algorithms/frank_wolfe.py +137 -0
  8. badr/algorithms/slsqp.py +130 -0
  9. badr/algorithms/trust_constr.py +111 -0
  10. badr/datasets/__init__.py +207 -0
  11. badr/datasets/adult.py +200 -0
  12. badr/datasets/arrhythmia.py +404 -0
  13. badr/datasets/communities_and_crime.py +143 -0
  14. badr/datasets/compas.py +178 -0
  15. badr/datasets/dataframe.py +168 -0
  16. badr/datasets/folktables.py +881 -0
  17. badr/datasets/german_credit.py +255 -0
  18. badr/datasets/law_school.py +150 -0
  19. badr/datasets/parkinsons_telemonitoring.py +117 -0
  20. badr/datasets/student_performance.py +159 -0
  21. badr/metrics/__base__.py +90 -0
  22. badr/metrics/__init__.py +19 -0
  23. badr/metrics/demographic_parity.py +53 -0
  24. badr/metrics/disparate_mistreatment.py +35 -0
  25. badr/metrics/equal_opportunity.py +43 -0
  26. badr/metrics/equalized_odds.py +55 -0
  27. badr/metrics/group_variance.py +61 -0
  28. badr/metrics/hsic.py +38 -0
  29. badr/metrics/individual_fairness.py +58 -0
  30. badr/models/__base__.py +106 -0
  31. badr/models/__init__.py +14 -0
  32. badr/models/_logistic_regression.py +267 -0
  33. badr/models/_ridge_regression.py +363 -0
  34. badr/models/_smooth_svm.py +467 -0
  35. badr/oracles/__base__.py +101 -0
  36. badr/oracles/__init__.py +5 -0
  37. badr/oracles/implicit_oracle.py +125 -0
  38. badr/oracles/stochastic_oracle.py +385 -0
  39. badr-0.1.0.dist-info/METADATA +56 -0
  40. badr-0.1.0.dist-info/RECORD +42 -0
  41. badr-0.1.0.dist-info/WHEEL +4 -0
  42. badr-0.1.0.dist-info/licenses/LICENSE +674 -0
badr/__base__.py ADDED
@@ -0,0 +1,126 @@
1
+ from badr.algorithms import SLSQP
2
+ from badr.datasets import Dataset
3
+ from badr.metrics import FairnessMetric
4
+ from badr.models import Model
5
+ from badr.oracles import ImplicitOracle, StochasticOracle
6
+
7
+
8
class Badr:
    """
    End-to-end driver that learns group weights and refits a model with them.

    Builds an oracle over ``(dset, model, metric)``, runs a solver on that
    oracle to obtain group weights, then refits ``model`` with those weights
    and stores the fitted coefficients, fairness value and per-group losses.
    """

    def __init__(
        self,
        dset: Dataset,
        model: Model,
        metric: FairnessMetric,
        train_test: str = "train",
        oracle: str = "implicit",
        solver_cls=None,
        solver_kwargs=None,
    ) -> None:
        """
        Parameters
        ----------
        dset : Dataset
            Dataset with (X_train, y_train), (X_test, y_test), and groups.
        model : Model
            Model with set_group_weights(...), fit(...), and coef_/intercept_.
        metric : FairnessMetric
            Metric; bound to `model` if `metric.model is None`.
        train_test : {"train", "test"}, default="train"
            Which split to use.
        oracle : {"implicit", "stochastic"}, default="implicit"
            Oracle implementation to use.
        solver_cls : type, optional
            Solver class (default: SLSQP).
        solver_kwargs : dict, optional
            Keyword args passed to `solver_cls(...)`.

        Raises
        ------
        ValueError
            If `train_test` is not one of {"train", "test"}, or if `oracle`
            is not one of {"implicit", "stochastic"}.
        """
        # Validate every argument before mutating `metric`, so that a
        # rejected call leaves the caller's objects untouched.
        if train_test not in ("train", "test"):
            raise ValueError(
                f"Unknown split: {train_test}. Supported values are 'train' and 'test'."
            )
        if oracle == "implicit":
            oracle_cls = ImplicitOracle
        elif oracle == "stochastic":
            oracle_cls = StochasticOracle
        else:
            raise ValueError(
                f"Unknown oracle type: {oracle}. Supported types are 'implicit' and 'stochastic'."
            )

        if metric.model is None:
            metric.set_model(model)
        self.dset = dset
        self.model = model
        self.metric = metric
        self.train_test = train_test
        is_train = train_test == "train"
        self.X = dset.X_train if is_train else dset.X_test
        self.y = dset.y_train if is_train else dset.y_test
        self.oracle = oracle_cls(
            dset=dset,
            model=model,
            metric=metric,
            train_test=train_test,
        )
        self.solver_cls = solver_cls or SLSQP
        self.solver_kwargs = solver_kwargs or {}
        self._solver = None

    def set_solver(self, solver):
        """
        Set a solver instance to use in `run`.

        Parameters
        ----------
        solver
            Solver instance. If `solver.oracle is None`, `run` will set it.

        Returns
        -------
        Badr
            Self.
        """
        self._solver = solver
        return self

    def run(self, **run_kwargs) -> None:
        """
        Run the solver, set group weights, refit the model, and compute outputs.

        Parameters
        ----------
        **run_kwargs
            Passed to `solver.run(**run_kwargs)`. If ``verbose`` is given and
            is not positive, the learned group weights are not printed.

        Sets Attributes
        ---------------
        group_weights
            Learned group weights.
        coef_
            Fitted coefficients.
        intercept_
            Fitted intercept.
        fairness
            Metric value on the selected split.
        group_losses
            Per-group losses from the model.
        """
        # 1) Lazily instantiate the solver if the user did not provide one.
        if self._solver is None:
            self._solver = self.solver_cls(**self.solver_kwargs)
        solver = self._solver

        # 2) Bind our oracle unless the user already attached one.
        if solver.oracle is None:
            solver.set_oracle(self.oracle)

        solver.run(**run_kwargs)
        self.group_weights = solver.group_weights
        # Respect the solver's verbosity flag; the default (1) matches
        # Algorithm.run, so default behavior still prints.
        if run_kwargs.get("verbose", 1) > 0:
            print(f"Group weights: {self.group_weights}")

        # 3) Refit the model with the learned weights and collect outputs.
        self.model.set_group_weights(self.group_weights)
        self.model.fit(self.X, self.y, self.dset.groups)
        self.coef_ = self.model.coef_
        self.intercept_ = self.model.intercept_
        self.fairness = self.metric.fun(self.model.coef_, self.dset, self.train_test)
        self.group_losses = self.model._group_loss(self.dset)
badr/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ from . import algorithms, datasets, metrics, models, oracles
2
+ from .__base__ import Badr
3
+ from ._version import __version__
4
+
5
# Public package-level API: the version string, the subpackages imported
# above, and the top-level Badr driver class.
__all__ = [
    "__version__",
    "datasets",
    "metrics",
    "models",
    "oracles",
    "algorithms",
    "Badr",
]
badr/_version.py ADDED
@@ -0,0 +1 @@
1
# Package version string (this file ships in the badr 0.1.0 wheel).
__version__ = "0.1.0"
badr/algorithms/__base__.py ADDED
@@ -0,0 +1,90 @@
1
+ from datetime import datetime
2
+
3
+ import jax.numpy as jnp
4
+ import numpy as np
5
+
6
+ from badr.oracles import Oracle
7
+
8
+
9
class Trace:
    """
    Callback that records iterates (or objective values) and elapsed time.

    Parameters
    ----------
    f : callable, optional
        If given, ``float(f(jnp.array(x)))`` is recorded in ``trace_fx``;
        otherwise a copy of ``x`` is recorded in ``trace_x``.
    freq : int, default=1
        Record only on every ``freq``-th call (calls 0, freq, 2*freq, ...).
    """

    def __init__(self, f=None, freq=1) -> None:
        self.trace_x = []
        self.trace_time = []
        self.trace_fx = []
        self.start = datetime.now()
        self._counter = 0
        self.freq = int(freq)
        self.f = f

    def __call__(self, dl):
        """Record ``dl["x"]`` (or ``f`` of it) if this call is on the schedule."""
        should_record = self._counter % self.freq == 0
        if should_record:
            current = dl["x"]
            if self.f is None:
                self.trace_x.append(np.copy(current))
            else:
                self.trace_fx.append(float(self.f(jnp.array(current))))
            elapsed = datetime.now() - self.start
            self.trace_time.append(elapsed.total_seconds())
        # The counter advances on every call, recorded or not.
        self._counter += 1
28
+
29
+
30
class Algorithm:
    """
    Base class for algorithms optimizing group weights.

    Stores common bookkeeping (objective trace, timings, iterates) and a reference
    to an :class:`~badr.oracles.Oracle`.

    Parameters
    ----------
    name : str
        Display name for the algorithm.

    Attributes
    ----------
    oracle : Oracle or None
        Oracle providing ``fun``/``grad`` (and possibly stochastic primitives).
    n_groups : int or None
        Number of groups, copied from the oracle by :meth:`set_oracle`;
        ``None`` until an oracle has been set.
    group_weights : jax.numpy.ndarray
        Latest group-weight iterate (typically on the simplex).
    history_f : list[float]
        Traced objective values (when enabled by the algorithm / ``trace`` flag).
    history_time : list[float]
        Elapsed time trace (seconds).
    history_lambda : list[jax.numpy.ndarray]
        Traced group-weight iterates (when enabled).
    """

    def __init__(self, name: str) -> None:
        self.name = name
        self.history_f = []
        self.history_time = []
        self.history_lambda = []
        self._last_callback_time = None
        self.group_weights: jnp.ndarray = jnp.array([])
        self.oracle = None
        # Defined up front so the documented attribute exists (as None)
        # before set_oracle() is called, instead of raising AttributeError.
        self.n_groups = None

    def set_oracle(self, oracle: Oracle) -> None:
        """
        Attach ``oracle`` to this algorithm and copy its ``n_groups``.

        Parameters
        ----------
        oracle : Oracle
            Oracle whose ``n_groups`` attribute is read here.
        """
        self.oracle = oracle
        self.n_groups = oracle.n_groups

    def run(self, max_iter: int = 1, verbose: int = 1, trace: bool = False):
        """
        Run the algorithm.

        Parameters
        ----------
        max_iter : int, default=1
            Maximum number of iterations.
        verbose : int, default=1
            Verbosity level (interpretation is algorithm-specific).
        trace : bool, default=False
            If True, record per-iteration history (objective/iterates/times).

        Raises
        ------
        NotImplementedError
            If not implemented by a subclass.
        """
        raise NotImplementedError("Algorithm.run() must be implemented in subclasses.")
badr/algorithms/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ from .__base__ import Algorithm
2
+ from .badr import BADRSGD
3
+ from .frank_wolfe import FrankWolfe
4
+ from .slsqp import SLSQP
5
+ from .trust_constr import TrustConstr
6
+
7
# Names re-exported by the algorithms subpackage: the Algorithm base class
# and the concrete solvers imported above.
__all__ = [
    "Algorithm",
    "BADRSGD",
    "FrankWolfe",
    "SLSQP",
    "TrustConstr",
]
badr/algorithms/badr.py ADDED
@@ -0,0 +1,247 @@
1
+ import time
2
+ import jax.numpy as jnp
3
+ from .__base__ import Algorithm
4
+
5
+
6
class BADRSGD(Algorithm):
    """
    Stochastic BADR updates over (w, v, lambda) with simplex projection.

    Uses a :class:`~badr.oracles.StochasticOracle`-style interface:
    draws a minibatch, forms a group-weighted inner gradient for ``w``, updates an
    auxiliary vector ``v`` using Hessian-vector products, and updates ``lambda``
    via a clipped step followed by projection onto the simplex.

    Parameters
    ----------
    w0 : jax.numpy.ndarray
        Initial parameter vector for the lower-level variable ``w``.
    batch_size : int, default=1
        Total minibatch size across groups (passed to ``oracle.sample_batch``).
    step_w : float, default=1e-1
        Step size shared by the ``w`` and ``v`` updates.
    step_v : float, default=1e-1
        Stored on the instance but not read by :meth:`run`; ``v`` is updated
        with ``step_w``.
    step_lambda : float, default=1.0
        Step size for the ``lambda`` update.
    clip_value : float, default=1.0
        L2-norm clipping threshold applied to the ``lambda`` gradient estimate.

    Attributes
    ----------
    primal_solution : jax.numpy.ndarray or None
        Final ``w`` iterate.
    aux_solution : jax.numpy.ndarray or None
        Final ``v`` iterate.
    message : str or None
        Status message after :meth:`run`.
    history_w, history_v, history_lambda : list[jax.numpy.ndarray]
        Iterates recorded during :meth:`run` (index 0 is the starting point).
    history_inner_loss_batch : list[float]
        ``f_hat(w, lambda; batch)`` over iterations.
    history_outer_metric : list[float]
        Metric value on the full split at current ``w``.
    history_norm_grad_w, history_norm_v, history_norm_jt_v, history_norm_grad_H_w : list[float]
        Diagnostic norms recorded per iteration.
    history_lambda_entropy : list[float]
        Entropy ``-sum(lambda log lambda)`` per iteration.
    history_clip_fraction : list[float]
        Running fraction of iterations where lambda-gradient clipping activated.

    Notes
    -----
    - Simplex projection uses sorting-based Euclidean projection.
    - Clipping is detected by comparing the pre/post L2 norms of the gradient.
    """

    def __init__(
        self,
        w0: jnp.ndarray,
        batch_size: int = 1,
        step_w: float = 1e-1,
        step_v: float = 1e-1,
        step_lambda: float = 1.0,
        clip_value: float = 1.0,
    ) -> None:
        super().__init__("BADRSGD")
        self.w0 = jnp.array(w0)
        self.batch_size = batch_size
        self.step_w = step_w
        self.step_v = step_v
        self.step_lambda = step_lambda
        self.clip_value = clip_value

        # Iterate histories and results (re-initialized by run()).
        self.history_w = []
        self.history_v = []
        self.history_lambda = []
        self.history_time = []
        self.message = None
        self.primal_solution = None
        self.aux_solution = None

        # Per-iteration diagnostic histories.
        self._reset_diagnostics()

    def _reset_diagnostics(self) -> None:
        """Clear every per-iteration diagnostic history list."""
        self.history_inner_loss_batch = []
        self.history_outer_metric = []
        self.history_norm_grad_w = []
        self.history_norm_v = []
        self.history_norm_jt_v = []
        self.history_norm_grad_H_w = []
        self.history_lambda_entropy = []
        self.history_clip_fraction = []  # running fraction in [0,1]

    def _outer_metric(self, w: jnp.ndarray) -> float:
        """Evaluate the oracle's metric at ``w`` on the oracle's dataset split."""
        return float(
            self.oracle.metric.fun(w, self.oracle.dset, self.oracle.train_test)
        )

    @staticmethod
    def _project_simplex(v: jnp.ndarray, radius: float = 1.0) -> jnp.ndarray:
        """Euclidean projection of ``v`` onto ``{x >= 0, sum(x) = radius}``."""
        v = jnp.asarray(v)
        v_sorted = jnp.sort(v)[::-1]
        cssv = jnp.cumsum(v_sorted) - radius
        ind = jnp.arange(1, v.shape[0] + 1)
        cond = v_sorted - cssv / ind > 0
        # `cond` is True on a prefix of the sorted vector, so its count - 1
        # is the last index satisfying the condition.
        rho = max(int(jnp.sum(cond)) - 1, 0)
        theta = cssv[rho] / (rho + 1)
        return jnp.maximum(v - theta, 0.0)

    def _clip(self, g: jnp.ndarray) -> jnp.ndarray:
        """Rescale ``g`` so its L2 norm is at most ``self.clip_value``."""
        norm = jnp.linalg.norm(g)
        factor = jnp.minimum(1.0, self.clip_value / (norm + 1e-12))
        return factor * g

    def run(self, max_iter: int = 1000, verbose: int = 1, trace: bool = False):
        """
        Run stochastic BADR iterations.

        Parameters
        ----------
        max_iter : int, default=1000
            Number of iterations.
        verbose : int, default=1
            If > 0, prints a completion message.
        trace : bool, default=False
            If True, also appends the outer metric to ``history_f`` each iteration.

        Returns
        -------
        jax.numpy.ndarray
            Final group weights ``lambda``.

        Raises
        ------
        ValueError
            If no oracle has been set.
        """
        if self.oracle is None:
            raise ValueError("Oracle not set. Please set the oracle before running.")

        w = jnp.array(self.w0)
        v = jnp.zeros_like(w)
        lmbda = jnp.ones(self.oracle.n_groups) / self.oracle.n_groups

        # Reset histories; index 0 holds the starting point.
        self.history_w = [w]
        self.history_v = [v]
        self.history_lambda = [lmbda]
        self.history_time = [0.0]
        self.history_f = [self._outer_metric(w)] if trace else []
        self._reset_diagnostics()

        start_t = time.perf_counter()
        step_1 = self.step_w
        step_2 = self.step_lambda
        clipped_count = 0  # running count of clipped λ-updates

        for t in range(max_iter):
            # Draw ONE inner batch and reuse it for all inner terms (SOBA).
            batch, self.oracle.key = self.oracle.sample_batch(
                self.oracle.key, self.batch_size
            )

            # 1) inner gradient g_w = ∇_w f_inner(w, λ; batch)
            grad_F = self.oracle.grad_lower_groups(w, batch)  # shape (S, d)
            grad_w = grad_F.T @ lmbda  # shape (d,)
            w_next = w - step_1 * grad_w

            # 2) outer gradient pieces (metric)
            grad_H_w, grad_H_lambda = self.oracle.grad_upper(w, lmbda)

            # 3) v-update: v ← v - ρ( H_{ww} v + ∇_w H )
            hvp_weighted = self.oracle.hvp_w_f_hat(w, lmbda, v, batch)  # shape (d,)
            v_next = v - step_1 * (grad_H_w + hvp_weighted)

            # 4) λ-update: λ ← Proj_Δ( λ - γ * Clip( J^T v + ∇_λ H ) )
            jt_v = self.oracle.jt_v_lambda_of_grad_w_f_hat(w, v, batch)  # shape (S,)
            lambda_grad_est = jt_v + grad_H_lambda
            # Clipping is considered active when it strictly reduced the norm.
            pre_norm = float(jnp.linalg.norm(lambda_grad_est))
            clipped = self._clip(lambda_grad_est)
            post_norm = float(jnp.linalg.norm(clipped))
            clipped_count += 1.0 if post_norm + 1e-12 < pre_norm else 0.0

            lambda_next = self._project_simplex(lmbda - step_2 * clipped)

            # All three updates above use the OLD (w, v, λ); commit them together.
            w, v, lmbda = w_next, v_next, lambda_next

            # --- logging (evaluated at the new iterate) ---
            inner_loss_batch = float(self.oracle.f_hat(w, lmbda, batch))
            outer_metric_val = self._outer_metric(w)
            lambda_entropy = float(-jnp.sum(lmbda * jnp.log(lmbda + 1e-12)))
            clip_fraction = clipped_count / float(t + 1)

            elapsed = time.perf_counter() - start_t
            self.history_w.append(w)
            self.history_v.append(v)
            self.history_lambda.append(lmbda)
            self.history_time.append(elapsed)
            if trace:
                self.history_f.append(outer_metric_val)

            self.history_inner_loss_batch.append(inner_loss_batch)
            self.history_outer_metric.append(outer_metric_val)
            self.history_norm_grad_w.append(float(jnp.linalg.norm(grad_w)))
            self.history_norm_v.append(float(jnp.linalg.norm(v)))
            self.history_norm_jt_v.append(float(jnp.linalg.norm(jt_v)))
            self.history_norm_grad_H_w.append(float(jnp.linalg.norm(grad_H_w)))
            self.history_lambda_entropy.append(lambda_entropy)
            self.history_clip_fraction.append(clip_fraction)

        self.primal_solution = w
        self.aux_solution = v
        self.group_weights = lmbda
        self.message = f"Completed {max_iter} (stochastic) iterations."
        if verbose > 0:
            print("[BADR] Message:", self.message)
        return self.group_weights
badr/algorithms/frank_wolfe.py ADDED
@@ -0,0 +1,137 @@
1
+ from datetime import datetime
2
+ from functools import partial
3
+
4
+ import jax.numpy as jnp
5
+ from jax import jit
6
+
7
+ from .__base__ import Algorithm
8
+
9
+
10
class FrankWolfe(Algorithm):
    """
    Frank--Wolfe on the simplex using a linear minimization oracle.

    Uses the oracle gradient ``g = oracle.grad(x)`` and the simplex LMO that
    returns the vertex at ``argmin_i g_i``. The update uses a diminishing step
    size given in :meth:`_step`. Convergence is checked with
    ``g_t = -g^T (x_{t+1} - x_t)``, which equals the FW gap scaled by the
    step size ``alpha_t``.

    Parameters
    ----------
    starting_point : jax.numpy.ndarray or None, optional
        Initial point on the simplex. If None, each run starts from the
        uniform distribution over ``n_groups``.
    eps : float, default=1e-6
        Stopping threshold on the (scaled) FW gap.

    Attributes
    ----------
    iterates : list[jax.numpy.ndarray]
        Iterates of the latest run (always filled; index 0 is the start).
    success : bool
        True if the latest run reached the stopping condition.
    message : str or None
        Status message after :meth:`run`.
    history_x, history_time : list
        Iterates and elapsed seconds of the latest run (filled when ``trace``).
    """

    def __init__(self, starting_point=None, eps=1e-6):
        super().__init__("Frank-Wolfe")
        self.oracle = None
        self.eps = eps
        self.starting_point = starting_point
        self.success = False
        self.iterates = []
        self.history_x = []
        self.history_time = []
        self.history_fx = []
        self.message = None

    @staticmethod
    def _lmo(grad: jnp.ndarray) -> jnp.ndarray:
        """Simplex LMO: the one-hot vertex at the smallest gradient entry."""
        v = jnp.zeros_like(grad)
        idx = jnp.argmin(grad)
        return v.at[idx].set(1.0)

    @staticmethod
    @jit
    def _step(x: jnp.ndarray, g: jnp.ndarray, iteration: int) -> jnp.ndarray:
        """
        One FW step from ``x`` toward the LMO vertex for gradient ``g``.

        ``iteration`` is traced rather than static, so this function is
        compiled once instead of once per distinct iteration number.
        """
        v = FrankWolfe._lmo(g)
        d = v - x
        alpha = jnp.minimum(
            1.0,
            (2.0 + jnp.log(iteration + 1)) / (iteration + 2 + jnp.log(iteration + 1)),
        )
        return x + alpha * d

    def run(self, max_iter: int = 300, verbose: int = 1, trace: bool = False):
        """
        Run Frank--Wolfe iterations.

        Parameters
        ----------
        max_iter : int, default=300
            Maximum number of iterations.
        verbose : int, default=1
            If > 0, prints success and message.
        trace : bool, default=False
            If True, records iterates and elapsed time.

        Returns
        -------
        jax.numpy.ndarray
            Final group weights.

        Raises
        ------
        ValueError
            If no oracle has been set.
        """
        # Check the oracle first: n_groups is only known once it is set.
        if self.oracle is None:
            raise ValueError("Oracle not set. Please set the oracle before running.")

        # Use a local start so a None starting_point is re-derived per run
        # (avoids reusing a stale vector if the oracle changes).
        if self.starting_point is None:
            x = jnp.full(self.n_groups, 1.0 / self.n_groups)
        else:
            x = jnp.asarray(self.starting_point)

        # Reset per-run state so repeated runs do not mix results.
        self.success = False
        self.history_x = []
        self.history_time = []
        self.history_f = []
        if trace:
            self.start = datetime.now()

        self.iterates = [x]
        for it in range(max_iter):
            g = self.oracle.grad(x)
            x_prev = x
            x = FrankWolfe._step(x, g, it)
            self.iterates.append(x)

            if trace:
                now = datetime.now()
                self.history_time.append((now - self.start).total_seconds())
                self.history_x.append(x)

            g_t = jnp.dot(-g, x - x_prev)
            if g_t < self.eps:
                self.success = True
                break

        self.group_weights = x
        n_iter = len(self.iterates) - 1
        if self.success:
            self.message = f"Converged in {n_iter} iterations (over {max_iter} iterations)."
        else:
            self.message = f"Did not converge within {max_iter} iterations."
        if verbose > 0:
            print(f"[FW] Success: {self.success}")
            print(f"[FW] Message: {self.message}")
        return self.group_weights

    def postprocess(self):
        """
        Populate ``history_f`` from the stored ``history_x``.

        Raises
        ------
        ValueError
            If no oracle has been set.
        """
        if self.oracle is None:
            raise ValueError("Oracle not set. Please set the oracle before running.")
        for x in self.history_x:
            self.history_f.append(self.oracle.fun(x))