badr 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- badr/__base__.py +126 -0
- badr/__init__.py +13 -0
- badr/_version.py +1 -0
- badr/algorithms/__base__.py +90 -0
- badr/algorithms/__init__.py +13 -0
- badr/algorithms/badr.py +247 -0
- badr/algorithms/frank_wolfe.py +137 -0
- badr/algorithms/slsqp.py +130 -0
- badr/algorithms/trust_constr.py +111 -0
- badr/datasets/__init__.py +207 -0
- badr/datasets/adult.py +200 -0
- badr/datasets/arrhythmia.py +404 -0
- badr/datasets/communities_and_crime.py +143 -0
- badr/datasets/compas.py +178 -0
- badr/datasets/dataframe.py +168 -0
- badr/datasets/folktables.py +881 -0
- badr/datasets/german_credit.py +255 -0
- badr/datasets/law_school.py +150 -0
- badr/datasets/parkinsons_telemonitoring.py +117 -0
- badr/datasets/student_performance.py +159 -0
- badr/metrics/__base__.py +90 -0
- badr/metrics/__init__.py +19 -0
- badr/metrics/demographic_parity.py +53 -0
- badr/metrics/disparate_mistreatment.py +35 -0
- badr/metrics/equal_opportunity.py +43 -0
- badr/metrics/equalized_odds.py +55 -0
- badr/metrics/group_variance.py +61 -0
- badr/metrics/hsic.py +38 -0
- badr/metrics/individual_fairness.py +58 -0
- badr/models/__base__.py +106 -0
- badr/models/__init__.py +14 -0
- badr/models/_logistic_regression.py +267 -0
- badr/models/_ridge_regression.py +363 -0
- badr/models/_smooth_svm.py +467 -0
- badr/oracles/__base__.py +101 -0
- badr/oracles/__init__.py +5 -0
- badr/oracles/implicit_oracle.py +125 -0
- badr/oracles/stochastic_oracle.py +385 -0
- badr-0.1.0.dist-info/METADATA +56 -0
- badr-0.1.0.dist-info/RECORD +42 -0
- badr-0.1.0.dist-info/WHEEL +4 -0
- badr-0.1.0.dist-info/licenses/LICENSE +674 -0
badr/__base__.py
ADDED
|
@@ -0,0 +1,126 @@
|
|
|
1
|
+
from badr.algorithms import SLSQP
|
|
2
|
+
from badr.datasets import Dataset
|
|
3
|
+
from badr.metrics import FairnessMetric
|
|
4
|
+
from badr.models import Model
|
|
5
|
+
from badr.oracles import ImplicitOracle, StochasticOracle
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
class Badr:
    """
    Front end tying a dataset, model, fairness metric and oracle to a solver.

    The constructor builds the requested oracle; :meth:`run` executes the solver,
    pushes the resulting group weights into the model, refits it and stores the
    fitted parameters together with the fairness value and per-group losses.
    """

    def __init__(
        self,
        dset: Dataset,
        model: Model,
        metric: FairnessMetric,
        train_test: str = "train",
        oracle: str = "implicit",
        solver_cls=None,
        solver_kwargs=None,
    ) -> None:
        """
        Parameters
        ----------
        dset : Dataset
            Dataset with (X_train, y_train), (X_test, y_test), and groups.
        model : Model
            Model with set_group_weights(...), fit(...), and coef_/intercept_.
        metric : FairnessMetric
            Metric; bound to `model` if `metric.model is None`.
        train_test : {"train", "test"}, default="train"
            Which split to use.
        oracle : {"implicit", "stochastic"}, default="implicit"
            Oracle implementation to use.
        solver_cls : type, optional
            Solver class (default: SLSQP).
        solver_kwargs : dict, optional
            Keyword args passed to `solver_cls(...)`.

        Raises
        ------
        ValueError
            If `train_test` is not one of {"train", "test"}, or if `oracle` is
            not one of {"implicit", "stochastic"}.
        """
        # Validate both selectors up front so that a bad argument neither
        # silently falls back to the test split nor leaves `metric` mutated.
        if train_test not in ("train", "test"):
            raise ValueError(
                f"Unknown split: {train_test!r}. Supported values are 'train' and 'test'."
            )
        if oracle not in ("implicit", "stochastic"):
            raise ValueError(
                f"Unknown oracle type: {oracle}. Supported types are 'implicit' and 'stochastic'."
            )

        if metric.model is None:
            metric.set_model(model)

        self.dset = dset
        self.model = model
        self.metric = metric
        self.train_test = train_test

        use_train = train_test == "train"
        self.X = dset.X_train if use_train else dset.X_test
        self.y = dset.y_train if use_train else dset.y_test

        if oracle == "implicit":
            self.oracle = ImplicitOracle(
                dset=dset,
                model=model,
                metric=metric,
                train_test=train_test,
            )
        else:  # "stochastic" (validated above)
            self.oracle = StochasticOracle(
                dset=dset,
                model=model,
                metric=metric,
                train_test=train_test,
            )

        self.solver_cls = solver_cls or SLSQP
        self.solver_kwargs = solver_kwargs or {}
        self._solver = None  # created lazily in `run` unless `set_solver` is used

    def set_solver(self, solver):
        """
        Set a solver instance to use in `run`.

        Parameters
        ----------
        solver
            Solver instance. If `solver.oracle is None`, `run` will set it.

        Returns
        -------
        Badr
            Self.
        """
        self._solver = solver
        return self

    def run(self, **run_kwargs) -> None:
        """
        Run the solver, set group weights, refit the model, and compute outputs.

        Parameters
        ----------
        **run_kwargs
            Passed to `solver.run(**run_kwargs)`. If a ``verbose`` entry is
            present and not greater than 0, the group-weight summary line is
            not printed; without it the line is printed as before.

        Sets Attributes
        ---------------
        group_weights
            Learned group weights.
        coef_
            Fitted coefficients.
        intercept_
            Fitted intercept.
        fairness
            Metric value on the selected split.
        group_losses
            Per-group losses from the model.
        """
        # 1) lazily instantiate the solver if none was supplied
        if self._solver is None:
            self._solver = self.solver_cls(**self.solver_kwargs)
        solver = self._solver

        # 2) bind our oracle unless the user already attached one
        if solver.oracle is None:
            solver.set_oracle(self.oracle)

        solver.run(**run_kwargs)
        self.group_weights = solver.group_weights

        # Honour the caller's verbosity: `run(verbose=0)` should stay silent.
        if run_kwargs.get("verbose", 1) > 0:
            print(f"Group weights: {self.group_weights}")

        # 3) refit the model under the learned weights
        # NOTE(review): `dset.groups` is passed for either split — confirm it
        # matches `self.X` when train_test == "test".
        self.model.set_group_weights(self.group_weights)
        self.model.fit(self.X, self.y, self.dset.groups)

        # 4) expose the fitted parameters and diagnostics
        self.coef_ = self.model.coef_
        self.intercept_ = self.model.intercept_
        self.fairness = self.metric.fun(self.model.coef_, self.dset, self.train_test)
        self.group_losses = self.model._group_loss(self.dset)
|
badr/__init__.py
ADDED
badr/_version.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
1
|
+
# Version string of the package.
__version__ = "0.1.0"
|
|
@@ -0,0 +1,90 @@
|
|
|
1
|
+
from datetime import datetime
|
|
2
|
+
|
|
3
|
+
import jax.numpy as jnp
|
|
4
|
+
import numpy as np
|
|
5
|
+
|
|
6
|
+
from badr.oracles import Oracle
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class Trace:
    """
    Callback that records iterates (or objective values) with elapsed time.

    Every ``freq``-th call stores one sample: ``f(x)`` when an objective ``f``
    was supplied, otherwise a copy of ``x`` itself, plus the number of seconds
    elapsed since the instance was created.

    Parameters
    ----------
    f : callable, optional
        Objective evaluated on ``jnp.array(x)``; its result is stored as a float.
    freq : int, default=1
        Sampling period in calls. Must be at least 1 after conversion to ``int``.

    Raises
    ------
    ValueError
        If ``int(freq) < 1``.

    Attributes
    ----------
    trace_x : list
        Copies of the recorded iterates (filled only when ``f`` is None).
    trace_fx : list[float]
        Recorded objective values (filled only when ``f`` is given).
    trace_time : list[float]
        Elapsed seconds at each recorded sample.
    """

    def __init__(self, f=None, freq=1) -> None:
        freq = int(freq)
        # A period of 0 would only surface later as a ZeroDivisionError inside
        # the solver callback; reject it where the mistake is made.
        if freq < 1:
            raise ValueError(f"freq must be a positive integer, got {freq}.")
        self.trace_x = []
        self.trace_time = []
        self.trace_fx = []
        self.start = datetime.now()
        self._counter = 0  # number of calls seen so far
        self.freq = freq
        self.f = f

    def __call__(self, dl):
        """
        Record a sample from ``dl["x"]`` if this call falls on the sampling period.

        Parameters
        ----------
        dl : mapping
            Must provide the current iterate under the key ``"x"``.
        """
        if self._counter % self.freq == 0:
            if self.f is not None:
                self.trace_fx.append(float(self.f(jnp.array(dl["x"]))))
            else:
                # copy so that later in-place updates by the caller do not
                # alter the stored iterate
                self.trace_x.append(np.copy(dl["x"]))
            delta = (datetime.now() - self.start).total_seconds()
            self.trace_time.append(delta)
        self._counter += 1
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class Algorithm:
    """
    Common base for solvers that optimize group weights.

    Holds the shared state every solver uses: a display name, the attached
    :class:`~badr.oracles.Oracle`, the latest group-weight vector and the
    history lists. Subclasses provide the actual optimization in :meth:`run`.

    Parameters
    ----------
    name : str
        Display name for the algorithm.

    Attributes
    ----------
    oracle : Oracle or None
        Attached oracle; ``None`` until :meth:`set_oracle` is called.
    n_groups : int
        Number of groups, copied from the oracle by :meth:`set_oracle`
        (the attribute does not exist before that call).
    group_weights : jax.numpy.ndarray
        Latest group-weight iterate; starts as an empty array.
    history_f : list[float]
        Traced objective values.
    history_time : list[float]
        Elapsed time trace (seconds).
    history_lambda : list[jax.numpy.ndarray]
        Traced group-weight iterates.
    """

    def __init__(self, name: str) -> None:
        self.name = name

        # per-run traces, filled by subclasses
        self.history_f, self.history_time, self.history_lambda = [], [], []
        self._last_callback_time = None

        # solver state
        self.group_weights: jnp.ndarray = jnp.array([])
        self.oracle = None

    def set_oracle(self, oracle: Oracle) -> None:
        """Attach ``oracle`` and cache its ``n_groups`` on the algorithm."""
        self.oracle = oracle
        self.n_groups = oracle.n_groups

    def run(self, max_iter: int = 1, verbose: int = 1, trace: bool = False):
        """
        Run the algorithm (abstract; subclasses must override).

        Parameters
        ----------
        max_iter : int, default=1
            Maximum number of iterations.
        verbose : int, default=1
            Verbosity level (interpretation is algorithm-specific).
        trace : bool, default=False
            If True, record per-iteration history (objective/iterates/times).

        Raises
        ------
        NotImplementedError
            Always, in this base class.
        """
        raise NotImplementedError("Algorithm.run() must be implemented in subclasses.")
|
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
from .__base__ import Algorithm
|
|
2
|
+
from .badr import BADRSGD
|
|
3
|
+
from .frank_wolfe import FrankWolfe
|
|
4
|
+
from .slsqp import SLSQP
|
|
5
|
+
from .trust_constr import TrustConstr
|
|
6
|
+
|
|
7
|
+
# Public names of this package: exactly the classes imported from the
# submodules above.
__all__ = [
    "Algorithm",
    "BADRSGD",
    "FrankWolfe",
    "SLSQP",
    "TrustConstr",
]
|
badr/algorithms/badr.py
ADDED
|
@@ -0,0 +1,247 @@
|
|
|
1
|
+
import time
|
|
2
|
+
import jax.numpy as jnp
|
|
3
|
+
from .__base__ import Algorithm
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class BADRSGD(Algorithm):
    """
    Single-loop stochastic BADR solver over the triple ``(w, v, lambda)``.

    Each iteration samples one minibatch from the oracle and reuses it for three
    simultaneous updates computed from the current iterates:

    * ``w`` takes a gradient step on the group-weighted inner loss,
    * ``v`` takes a step along ``grad_w H + (Hessian-vector product with v)``,
    * ``lambda`` takes a norm-clipped step and is projected back onto the
      probability simplex.

    Parameters
    ----------
    w0 : jax.numpy.ndarray
        Initial value of the lower-level variable ``w``.
    batch_size : int, default=1
        Size passed to ``oracle.sample_batch``.
    step_w : float, default=1e-1
        Step size applied to both the ``w`` and the ``v`` update.
    step_v : float, default=1e-1
        Stored on the instance but not read by :meth:`run`.
    step_lambda : float, default=1.0
        Step size of the ``lambda`` update.
    clip_value : float, default=1.0
        L2-norm threshold for clipping the ``lambda`` gradient estimate.

    Attributes
    ----------
    primal_solution : jax.numpy.ndarray or None
        Final ``w`` iterate.
    aux_solution : jax.numpy.ndarray or None
        Final ``v`` iterate.
    message : str or None
        Status message set by :meth:`run`.
    history_w, history_v, history_lambda : list[jax.numpy.ndarray]
        Iterates, starting with the initial point.
    history_time : list[float]
        Elapsed seconds per stored iterate, starting with ``0.0``.
    history_inner_loss_batch : list[float]
        ``f_hat(w, lambda; batch)`` evaluated at the updated iterates.
    history_outer_metric : list[float]
        Metric value at the updated ``w``.
    history_norm_grad_w, history_norm_v, history_norm_jt_v, history_norm_grad_H_w : list[float]
        Diagnostic L2 norms per iteration.
    history_lambda_entropy : list[float]
        ``-sum(lambda * log(lambda + 1e-12))`` per iteration.
    history_clip_fraction : list[float]
        Running share of iterations in which clipping reduced the norm.

    Notes
    -----
    - The simplex projection is the sorting-based Euclidean projection.
    - Clipping is detected by comparing the gradient norm before and after.
    """

    # names of the per-iteration diagnostic lists, all reset together
    _DIAGNOSTIC_HISTORIES = (
        "history_inner_loss_batch",
        "history_outer_metric",
        "history_norm_grad_w",
        "history_norm_v",
        "history_norm_jt_v",
        "history_norm_grad_H_w",
        "history_lambda_entropy",
        "history_clip_fraction",
    )

    def __init__(
        self,
        w0: jnp.ndarray,
        batch_size: int = 1,
        step_w: float = 1e-1,
        step_v: float = 1e-1,
        step_lambda: float = 1.0,
        clip_value: float = 1.0,
    ) -> None:
        super().__init__("BADRSGD")

        # hyper-parameters
        self.w0 = jnp.array(w0)
        self.batch_size = batch_size
        self.step_w = step_w
        self.step_v = step_v
        self.step_lambda = step_lambda
        self.clip_value = clip_value

        # iterate and timing histories
        self.history_w = []
        self.history_v = []
        self.history_lambda = []
        self.history_time = []

        # results of the last run
        self.message = None
        self.primal_solution = None
        self.aux_solution = None

        # diagnostic histories
        for attr in self._DIAGNOSTIC_HISTORIES:
            setattr(self, attr, [])

    @staticmethod
    def _project_simplex(v: jnp.ndarray, radius: float = 1.0) -> jnp.ndarray:
        """Return the Euclidean projection of ``v`` onto the simplex of the given radius."""
        point = jnp.asarray(v)
        descending = jnp.sort(point)[::-1]
        shifted_cumsum = jnp.cumsum(descending) - radius
        ranks = jnp.arange(1, point.shape[0] + 1)
        is_active = descending - shifted_cumsum / ranks > 0
        # index of the last coordinate that stays positive after the shift
        last_active = max(int(jnp.sum(is_active)) - 1, 0)
        threshold = shifted_cumsum[last_active] / (last_active + 1)
        return jnp.maximum(point - threshold, 0.0)

    def _clip(self, g: jnp.ndarray) -> jnp.ndarray:
        """Rescale ``g`` so that its L2 norm does not exceed ``self.clip_value``."""
        g_norm = jnp.linalg.norm(g)
        # the small constant guards the division when the gradient is zero
        scale = jnp.minimum(1.0, self.clip_value / (g_norm + 1e-12))
        return scale * g

    def run(self, max_iter: int = 1000, verbose: int = 1, trace: bool = False):
        """
        Run stochastic BADR iterations.

        Parameters
        ----------
        max_iter : int, default=1000
            Number of iterations.
        verbose : int, default=1
            If > 0, prints a completion message.
        trace : bool, default=False
            If True, ``history_f`` holds the outer metric at the initial point
            and after every iteration; otherwise it is left empty.

        Returns
        -------
        jax.numpy.ndarray
            Final group weights ``lambda``.

        Raises
        ------
        ValueError
            If no oracle has been set.
        """
        if self.oracle is None:
            raise ValueError("Oracle not set. Please set the oracle before running.")
        oracle = self.oracle

        # initial point: user-provided w, zero v, uniform lambda
        w = jnp.array(self.w0)
        v = jnp.zeros_like(w)
        lmbda = jnp.ones(oracle.n_groups) / oracle.n_groups

        # start every history afresh
        self.history_w = [w]
        self.history_v = [v]
        self.history_lambda = [lmbda]
        self.history_time = [0.0]
        self.history_f = (
            [float(oracle.metric.fun(w, oracle.dset, oracle.train_test))]
            if trace
            else []
        )
        for attr in self._DIAGNOSTIC_HISTORIES:
            setattr(self, attr, [])

        t_start = time.perf_counter()
        lr_inner = self.step_w
        lr_lambda = self.step_lambda
        n_clipped = 0  # running count of clipped lambda updates

        for it in range(max_iter):
            # one minibatch, shared by every inner term of this iteration
            batch, oracle.key = oracle.sample_batch(oracle.key, self.batch_size)

            # 1) w-step on the lambda-weighted per-group inner gradients
            per_group_grads = oracle.grad_lower_groups(w, batch)
            grad_w = per_group_grads.T @ lmbda
            new_w = w - lr_inner * grad_w

            # 2) gradients of the outer objective (metric)
            grad_H_w, grad_H_lambda = oracle.grad_upper(w, lmbda)

            # 3) v-step: v <- v - step * (grad_w H + HVP(v))
            hvp = oracle.hvp_w_f_hat(w, lmbda, v, batch)
            new_v = v - lr_inner * (grad_H_w + hvp)

            # 4) lambda-step: lambda <- Proj_simplex(lambda - step * Clip(J^T v + grad_lambda H))
            jt_v = oracle.jt_v_lambda_of_grad_w_f_hat(w, v, batch)
            raw_lambda_grad = jt_v + grad_H_lambda
            norm_before = float(jnp.linalg.norm(raw_lambda_grad))
            clipped_grad = self._clip(raw_lambda_grad)
            norm_after = float(jnp.linalg.norm(clipped_grad))
            # clipping happened iff the norm was actually reduced
            n_clipped += 1.0 if norm_after + 1e-12 < norm_before else 0.0
            new_lmbda = self._project_simplex(lmbda - lr_lambda * clipped_grad)

            # commit the three updates together
            w, v, lmbda = new_w, new_v, new_lmbda

            # --- diagnostics, evaluated at the updated iterates ---
            batch_loss = float(oracle.f_hat(w, lmbda, batch))
            metric_value = float(
                oracle.metric.fun(w, oracle.dset, oracle.train_test)
            )
            entropy = float(-jnp.sum(lmbda * jnp.log(lmbda + 1e-12)))

            # --- bookkeeping ---
            self.history_w.append(w)
            self.history_v.append(v)
            self.history_lambda.append(lmbda)
            self.history_time.append(time.perf_counter() - t_start)
            if trace:
                self.history_f.append(metric_value)

            self.history_inner_loss_batch.append(batch_loss)
            self.history_outer_metric.append(metric_value)
            self.history_norm_grad_w.append(float(jnp.linalg.norm(grad_w)))
            self.history_norm_v.append(float(jnp.linalg.norm(v)))
            self.history_norm_jt_v.append(float(jnp.linalg.norm(jt_v)))
            self.history_norm_grad_H_w.append(float(jnp.linalg.norm(grad_H_w)))
            self.history_lambda_entropy.append(entropy)
            self.history_clip_fraction.append(n_clipped / float(it + 1))

        self.primal_solution = w
        self.aux_solution = v
        self.group_weights = lmbda
        self.message = f"Completed {max_iter} (stochastic) iterations."
        if verbose > 0:
            print("[BADR] Message:", self.message)
        return self.group_weights
|
|
@@ -0,0 +1,137 @@
|
|
|
1
|
+
from datetime import datetime
|
|
2
|
+
from functools import partial
|
|
3
|
+
|
|
4
|
+
import jax.numpy as jnp
|
|
5
|
+
from jax import jit
|
|
6
|
+
|
|
7
|
+
from .__base__ import Algorithm
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class FrankWolfe(Algorithm):
    """
    Frank--Wolfe on the simplex using a linear minimization oracle.

    Uses the oracle gradient ``g = oracle.grad(x)`` and the simplex LMO that
    returns the vertex at ``argmin_i g_i``. The update uses a diminishing step
    size given in :meth:`_step`. Convergence is checked with the quantity
    ``g_t = -g^T (x - x_prev)``; because ``x - x_prev`` is the step-size-scaled
    direction, this is the step size times ``-g^T (vertex - x_prev)``.

    Parameters
    ----------
    starting_point : jax.numpy.ndarray or None, optional
        Initial point on the simplex. If None, uses the uniform distribution.
    eps : float, default=1e-6
        Stopping threshold on ``g_t``.

    Attributes
    ----------
    iterates : list[jax.numpy.ndarray]
        Iterates of the last run, starting point included (always filled).
    history_x, history_time : list
        Iterates and elapsed seconds of the last run (filled only with
        ``trace=True``; the starting point is not included).
    success : bool
        True if the stopping condition was reached in the last run.
    message : str or None
        Status message after :meth:`run`.
    """

    def __init__(self, starting_point=None, eps=1e-6):
        super().__init__("Frank-Wolfe")
        self.oracle = None
        self.eps = eps
        self.starting_point = starting_point
        self.success = False
        self.iterates = []
        self.history_x = []
        self.history_time = []
        self.history_fx = []
        self.message = None

    @staticmethod
    def _lmo(grad: jnp.ndarray) -> jnp.ndarray:
        """Return the simplex vertex (one-hot vector) at the smallest gradient entry."""
        v = jnp.zeros_like(grad)
        idx = jnp.argmin(grad)
        return v.at[idx].set(1.0)

    @staticmethod
    @jit
    def _step(x: jnp.ndarray, g: jnp.ndarray, iteration: int) -> jnp.ndarray:
        """
        Move ``x`` towards the LMO vertex for gradient ``g``.

        The step size is
        ``min(1, (2 + log(k + 1)) / (k + 2 + log(k + 1)))`` with ``k = iteration``.
        ``iteration`` is a traced argument: marking it static would compile a
        new kernel for every iteration number.
        """
        v = FrankWolfe._lmo(g)
        d = v - x
        log_term = jnp.log(iteration + 1)
        alpha = jnp.minimum(1.0, (2.0 + log_term) / (iteration + 2 + log_term))
        return x + alpha * d

    def run(self, max_iter: int = 300, verbose: int = 1, trace: bool = False):
        """
        Run Frank--Wolfe iterations.

        State from a previous run (``success``, ``iterates``, ``history_x``,
        ``history_time``) is cleared first.

        Parameters
        ----------
        max_iter : int, default=300
            Maximum number of iterations.
        verbose : int, default=1
            If > 0, prints success and message.
        trace : bool, default=False
            If True, records iterates and elapsed time.

        Returns
        -------
        jax.numpy.ndarray
            Final group weights.

        Raises
        ------
        ValueError
            If no oracle has been set.
        """
        # Check the oracle before touching anything derived from it, so the
        # documented ValueError is what the caller actually sees.
        if self.oracle is None:
            raise ValueError("Oracle not set. Please set the oracle before running.")

        n_groups = self.oracle.n_groups
        self.group_weights = jnp.zeros(n_groups)

        # Keep the configured starting point untouched: a `None` keeps meaning
        # "uniform over the current oracle's groups" on every run.
        if self.starting_point is None:
            x = jnp.full(n_groups, 1.0 / n_groups)
        else:
            x = jnp.asarray(self.starting_point)

        self.success = False
        self.iterates = [x]
        self.history_x = []
        self.history_time = []
        if trace:
            self.start = datetime.now()

        for it in range(max_iter):
            g = self.oracle.grad(x)
            x_prev = x
            x = FrankWolfe._step(x_prev, g, it)
            self.iterates.append(x)

            if trace:
                now = datetime.now()
                self.history_time.append((now - self.start).total_seconds())
                self.history_x.append(x)

            # gradient at x_prev against the move that was just taken
            g_t = jnp.dot(-g, x - x_prev)
            if g_t < self.eps:
                self.success = True
                break

        self.group_weights = x
        self.message = f"Converged in {len(self.iterates) - 1} iterations (over {max_iter} iterations)."
        if verbose > 0:
            print(f"[FW] Success: {self.success}")
            print(f"[FW] Message: {self.message}")
        return self.group_weights

    def postprocess(self):
        """
        Populate ``history_f`` from the stored ``history_x``.

        ``history_f`` is rebuilt from scratch, so repeated calls do not
        duplicate entries.

        Raises
        ------
        ValueError
            If no oracle has been set.
        """
        if self.oracle is None:
            raise ValueError("Oracle not set. Please set the oracle before running.")
        self.history_f = [self.oracle.fun(x) for x in self.history_x]
|