invrs-opt 0.6.0__py3-none-any.whl → 0.7.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invrs_opt/__init__.py +12 -5
- invrs_opt/experimental/client.py +1 -1
- invrs_opt/{base.py → optimizers/base.py} +9 -3
- invrs_opt/{lbfgsb → optimizers}/lbfgsb.py +293 -152
- invrs_opt/optimizers/wrapped_optax.py +300 -0
- invrs_opt/parameterization/base.py +148 -0
- invrs_opt/parameterization/filter_project.py +92 -0
- invrs_opt/parameterization/gaussian_levelset.py +643 -0
- invrs_opt/parameterization/pixel.py +45 -0
- invrs_opt/{transform.py → parameterization/transforms.py} +76 -11
- invrs_opt-0.7.1.dist-info/LICENSE +504 -0
- invrs_opt-0.7.1.dist-info/METADATA +559 -0
- invrs_opt-0.7.1.dist-info/RECORD +20 -0
- {invrs_opt-0.6.0.dist-info → invrs_opt-0.7.1.dist-info}/WHEEL +1 -1
- invrs_opt/wrapped_optax/wrapped_optax.py +0 -150
- invrs_opt-0.6.0.dist-info/LICENSE +0 -21
- invrs_opt-0.6.0.dist-info/METADATA +0 -76
- invrs_opt-0.6.0.dist-info/RECORD +0 -16
- /invrs_opt/{lbfgsb → optimizers}/__init__.py +0 -0
- /invrs_opt/{wrapped_optax → parameterization}/__init__.py +0 -0
- {invrs_opt-0.6.0.dist-info → invrs_opt-0.7.1.dist-info}/top_level.txt +0 -0
invrs_opt/__init__.py
CHANGED
@@ -3,12 +3,19 @@
|
|
3
3
|
Copyright (c) 2023 The INVRS-IO authors.
|
4
4
|
"""
|
5
5
|
|
6
|
-
__version__ = "v0.
|
6
|
+
__version__ = "v0.7.1"
|
7
7
|
__author__ = "Martin F. Schubert <mfschubert@gmail.com>"
|
8
8
|
|
9
|
-
from invrs_opt
|
10
|
-
|
11
|
-
from invrs_opt.
|
9
|
+
from invrs_opt import parameterization as parameterization
|
10
|
+
|
11
|
+
from invrs_opt.optimizers.lbfgsb import (
|
12
|
+
density_lbfgsb as density_lbfgsb,
|
13
|
+
lbfgsb as lbfgsb,
|
14
|
+
levelset_lbfgsb as levelset_lbfgsb,
|
15
|
+
)
|
16
|
+
|
17
|
+
from invrs_opt.optimizers.wrapped_optax import (
|
12
18
|
density_wrapped_optax as density_wrapped_optax,
|
19
|
+
levelset_wrapped_optax as levelset_wrapped_optax,
|
20
|
+
wrapped_optax as wrapped_optax,
|
13
21
|
)
|
14
|
-
from invrs_opt.wrapped_optax.wrapped_optax import wrapped_optax as wrapped_optax
|
invrs_opt/experimental/client.py
CHANGED
@@ -4,6 +4,7 @@ Copyright (c) 2023 The INVRS-IO authors.
|
|
4
4
|
"""
|
5
5
|
|
6
6
|
import dataclasses
|
7
|
+
import inspect
|
7
8
|
from typing import Any, Protocol
|
8
9
|
|
9
10
|
import optax # type: ignore[import-untyped]
|
@@ -49,6 +50,11 @@ class Optimizer:
|
|
49
50
|
update: UpdateFn
|
50
51
|
|
51
52
|
|
52
|
-
#
|
53
|
-
|
54
|
-
|
53
|
+
# Register all optax state types for serialization.
|
54
|
+
optax_types = {}
|
55
|
+
for name, obj in inspect.getmembers(optax):
|
56
|
+
if name.endswith("State") and isinstance(obj, type):
|
57
|
+
optax_types[obj] = True
|
58
|
+
|
59
|
+
for obj in optax_types.keys():
|
60
|
+
json_utils.register_custom_type(obj)
|
@@ -5,18 +5,25 @@ Copyright (c) 2023 The INVRS-IO authors.
|
|
5
5
|
|
6
6
|
import copy
|
7
7
|
import dataclasses
|
8
|
-
from typing import Any,
|
8
|
+
from typing import Any, Dict, Optional, Sequence, Tuple, Union
|
9
9
|
|
10
10
|
import jax
|
11
11
|
import jax.numpy as jnp
|
12
12
|
import numpy as onp
|
13
|
+
import optax # type: ignore[import-untyped]
|
13
14
|
from jax import flatten_util, tree_util
|
14
15
|
from scipy.optimize._lbfgsb_py import ( # type: ignore[import-untyped]
|
15
16
|
_lbfgsb as scipy_lbfgsb,
|
16
17
|
)
|
17
18
|
from totypes import types
|
18
19
|
|
19
|
-
from invrs_opt import base
|
20
|
+
from invrs_opt.optimizers import base
|
21
|
+
from invrs_opt.parameterization import (
|
22
|
+
base as parameterization_base,
|
23
|
+
filter_project,
|
24
|
+
gaussian_levelset,
|
25
|
+
pixel,
|
26
|
+
)
|
20
27
|
|
21
28
|
NDArray = onp.ndarray[Any, Any]
|
22
29
|
PyTree = Any
|
@@ -34,10 +41,10 @@ UPDATE_IPRINT = -1
|
|
34
41
|
|
35
42
|
# Maximum value for the `maxcor` parameter in the L-BFGS-B scheme.
|
36
43
|
MAXCOR_MAX_VALUE = 100
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
44
|
+
DEFAULT_MAXCOR = 20
|
45
|
+
DEFAULT_LINE_SEARCH_MAX_STEPS = 100
|
46
|
+
DEFAULT_FTOL = 0.0
|
47
|
+
DEFAULT_GTOL = 0.0
|
41
48
|
|
42
49
|
# Maps bound scenarios to integers.
|
43
50
|
BOUNDS_MAP: Dict[Tuple[bool, bool], int] = {
|
@@ -51,175 +58,225 @@ FORTRAN_INT = scipy_lbfgsb.types.intvar.dtype
|
|
51
58
|
|
52
59
|
|
53
60
|
def lbfgsb(
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
61
|
+
*,
|
62
|
+
maxcor: int = DEFAULT_MAXCOR,
|
63
|
+
line_search_max_steps: int = DEFAULT_LINE_SEARCH_MAX_STEPS,
|
64
|
+
ftol: float = DEFAULT_FTOL,
|
65
|
+
gtol: float = DEFAULT_GTOL,
|
58
66
|
) -> base.Optimizer:
|
59
|
-
"""
|
67
|
+
"""Optimizer implementing the standard L-BFGS-B algorithm.
|
60
68
|
|
61
|
-
|
62
|
-
|
63
|
-
|
69
|
+
The standard L-BFGS-B algorithm uses the direct pixel parameterization for density
|
70
|
+
arrays, which simply enforces that values are between the declared upper and lower
|
71
|
+
bounds of the density.
|
64
72
|
|
65
|
-
|
66
|
-
|
67
|
-
def fn(x):
|
68
|
-
leaves_sum_sq = [jnp.sum(y)**2 for y in tree_util.tree_leaves(x)]
|
69
|
-
return jnp.sum(jnp.asarray(leaves_sum_sq))
|
70
|
-
|
71
|
-
x0 = {
|
72
|
-
"a": jnp.ones((3,)),
|
73
|
-
"b": BoundedArray(
|
74
|
-
value=-jnp.ones((2, 5)),
|
75
|
-
lower_bound=-5,
|
76
|
-
upper_bound=5,
|
77
|
-
),
|
78
|
-
}
|
79
|
-
opt = lbfgsb(maxcor=20, line_search_max_steps=100)
|
80
|
-
state = opt.init(x0)
|
81
|
-
for _ in range(10):
|
82
|
-
x = opt.params(state)
|
83
|
-
value, grad = jax.value_and_grad(fn)(x)
|
84
|
-
state = opt.update(grad, value, state)
|
85
|
-
|
86
|
-
While the algorithm can work with pytrees of jax arrays, numpy arrays can
|
87
|
-
also be used. Thus, e.g. the optimizer can directly be used with autograd.
|
88
|
-
|
89
|
-
When the optimization has converged (according to `ftol` or `gtol` criteria), the
|
90
|
-
optimizer simply returns the parameters which obtained the converged result. The
|
73
|
+
When an optimization is determined to have converged (by `ftol` or `gtol` criteria)
|
74
|
+
the optimizer `params` function will simply return the optimal parameters. The
|
91
75
|
convergence can be queried by `is_converged(state)`.
|
92
76
|
|
93
77
|
Args:
|
94
|
-
maxcor: The maximum number of variable metric corrections used to define
|
95
|
-
|
78
|
+
maxcor: The maximum number of variable metric corrections used to define the
|
79
|
+
limited memory matrix, in the L-BFGS-B scheme.
|
96
80
|
line_search_max_steps: The maximum number of steps in the line search.
|
97
|
-
ftol:
|
98
|
-
|
99
|
-
gtol:
|
81
|
+
ftol: Convergence criteria based on function values. See scipy documentation
|
82
|
+
for details.
|
83
|
+
gtol: Convergence criteria based on gradient.
|
100
84
|
|
101
85
|
Returns:
|
102
|
-
The `
|
86
|
+
The `Optimizer` implementing the L-BFGS-B optimizer.
|
103
87
|
"""
|
104
|
-
return
|
88
|
+
return parameterized_lbfgsb(
|
89
|
+
density_parameterization=None,
|
90
|
+
penalty=0.0,
|
105
91
|
maxcor=maxcor,
|
106
92
|
line_search_max_steps=line_search_max_steps,
|
107
93
|
ftol=ftol,
|
108
94
|
gtol=gtol,
|
109
|
-
transform_fn=lambda x: x,
|
110
|
-
initialize_latent_fn=lambda x: x,
|
111
95
|
)
|
112
96
|
|
113
97
|
|
114
98
|
def density_lbfgsb(
|
99
|
+
*,
|
115
100
|
beta: float,
|
116
|
-
maxcor: int =
|
117
|
-
line_search_max_steps: int =
|
118
|
-
ftol: float =
|
119
|
-
gtol: float =
|
101
|
+
maxcor: int = DEFAULT_MAXCOR,
|
102
|
+
line_search_max_steps: int = DEFAULT_LINE_SEARCH_MAX_STEPS,
|
103
|
+
ftol: float = DEFAULT_FTOL,
|
104
|
+
gtol: float = DEFAULT_GTOL,
|
120
105
|
) -> base.Optimizer:
|
121
|
-
"""
|
122
|
-
|
123
|
-
Parameters that are of type `DensityArray2D` are represented as latent parameters
|
124
|
-
that are transformed (in the case where lower and upper bounds are `(-1, 1)`) by,
|
106
|
+
"""Optimizer using L-BFGS-B algorithm with filter-project density parameterization.
|
125
107
|
|
126
|
-
|
108
|
+
In the filter-project density parameterization, the optimization variable
|
109
|
+
associated with a density array is a latent density array; the density is obtained
|
110
|
+
by convolving (i.e. "filtering") the latent density with a Gaussian kernel having
|
111
|
+
full-width at half-maximum equal to the length scale (the mean of declared minimum
|
112
|
+
width and minimum spacing). Then, a tanh nonlinearity is used as a smooth threshold
|
113
|
+
operation ("projection").
|
127
114
|
|
128
|
-
|
129
|
-
|
130
|
-
density is scaled before the transform is applied, and then unscaled afterwards.
|
131
|
-
|
132
|
-
When the optimization has converged (according to `ftol` or `gtol` criteria), the
|
133
|
-
optimizer simply returns the parameters which obtained the converged result. The
|
115
|
+
When an optimization is determined to have converged (by `ftol` or `gtol` criteria)
|
116
|
+
the optimizer `params` function will simply return the optimal parameters. The
|
134
117
|
convergence can be queried by `is_converged(state)`.
|
135
118
|
|
136
119
|
Args:
|
137
|
-
beta: Determines the
|
138
|
-
maxcor: The maximum number of variable metric corrections used to define
|
139
|
-
|
120
|
+
beta: Determines the sharpness of the thresholding operation.
|
121
|
+
maxcor: The maximum number of variable metric corrections used to define the
|
122
|
+
limited memory matrix, in the L-BFGS-B scheme.
|
140
123
|
line_search_max_steps: The maximum number of steps in the line search.
|
141
|
-
ftol:
|
142
|
-
|
143
|
-
gtol:
|
124
|
+
ftol: Convergence criteria based on function values. See scipy documentation
|
125
|
+
for details.
|
126
|
+
gtol: Convergence criteria based on gradient.
|
144
127
|
|
145
128
|
Returns:
|
146
|
-
The `
|
129
|
+
The `Optimizer` implementing the L-BFGS-B optimizer.
|
147
130
|
"""
|
131
|
+
return parameterized_lbfgsb(
|
132
|
+
density_parameterization=filter_project.filter_project(beta=beta),
|
133
|
+
penalty=0.0,
|
134
|
+
maxcor=maxcor,
|
135
|
+
line_search_max_steps=line_search_max_steps,
|
136
|
+
ftol=ftol,
|
137
|
+
gtol=gtol,
|
138
|
+
)
|
148
139
|
|
149
|
-
def transform_fn(tree: PyTree) -> PyTree:
|
150
|
-
return tree_util.tree_map(
|
151
|
-
lambda x: transform_density(x) if _is_density(x) else x,
|
152
|
-
tree,
|
153
|
-
is_leaf=_is_density,
|
154
|
-
)
|
155
140
|
|
156
|
-
|
157
|
-
|
158
|
-
|
159
|
-
|
160
|
-
|
161
|
-
|
141
|
+
def levelset_lbfgsb(
|
142
|
+
*,
|
143
|
+
penalty: float,
|
144
|
+
length_scale_spacing_factor: float = (
|
145
|
+
gaussian_levelset.DEFAULT_LENGTH_SCALE_SPACING_FACTOR
|
146
|
+
),
|
147
|
+
length_scale_fwhm_factor: float = (
|
148
|
+
gaussian_levelset.DEFAULT_LENGTH_SCALE_FWHM_FACTOR
|
149
|
+
),
|
150
|
+
length_scale_constraint_factor: float = (
|
151
|
+
gaussian_levelset.DEFAULT_LENGTH_SCALE_CONSTRAINT_FACTOR
|
152
|
+
),
|
153
|
+
smoothing_factor: int = gaussian_levelset.DEFAULT_SMOOTHING_FACTOR,
|
154
|
+
length_scale_constraint_beta: float = (
|
155
|
+
gaussian_levelset.DEFAULT_LENGTH_SCALE_CONSTRAINT_BETA
|
156
|
+
),
|
157
|
+
length_scale_constraint_weight: float = (
|
158
|
+
gaussian_levelset.DEFAULT_LENGTH_SCALE_CONSTRAINT_WEIGHT
|
159
|
+
),
|
160
|
+
curvature_constraint_weight: float = (
|
161
|
+
gaussian_levelset.DEFAULT_CURVATURE_CONSTRAINT_WEIGHT
|
162
|
+
),
|
163
|
+
fixed_pixel_constraint_weight: float = (
|
164
|
+
gaussian_levelset.DEFAULT_FIXED_PIXEL_CONSTRAINT_WEIGHT
|
165
|
+
),
|
166
|
+
init_optimizer: optax.GradientTransformation = (
|
167
|
+
gaussian_levelset.DEFAULT_INIT_OPTIMIZER
|
168
|
+
),
|
169
|
+
init_steps: int = gaussian_levelset.DEFAULT_INIT_STEPS,
|
170
|
+
maxcor: int = DEFAULT_MAXCOR,
|
171
|
+
line_search_max_steps: int = DEFAULT_LINE_SEARCH_MAX_STEPS,
|
172
|
+
ftol: float = DEFAULT_FTOL,
|
173
|
+
gtol: float = DEFAULT_GTOL,
|
174
|
+
) -> base.Optimizer:
|
175
|
+
"""Optimizer using L-BFGS-B algorithm with levelset density parameterization.
|
176
|
+
|
177
|
+
In the levelset parameterization, the optimization variable associated with a
|
178
|
+
density array is an array giving the amplitudes of Gaussian radial basis functions
|
179
|
+
that represent a levelset function over the domain of the density. In the levelset
|
180
|
+
parameterization, gradients are nonzero only at the edges of features, and in
|
181
|
+
general the topology of a solution does not change during the course of
|
182
|
+
optimization.
|
183
|
+
|
184
|
+
The spacing and full-width at half-maximum of the Gaussian basis functions give
|
185
|
+
some amount of control over length scales. In addition, constraints associated with
|
186
|
+
length scale, radius of curvature, and deviation from fixed pixels are
|
187
|
+
automatically computed and penalized with a weight given by `penalty`. In general,
|
188
|
+
this helps ensure that features in an optimized density array violate the specified
|
189
|
+
constraints to a lesser degree. The constraints are based on "Analytical level set
|
190
|
+
fabrication constraints for inverse design," by D. Vercruysse et al. (2019).
|
191
|
+
|
192
|
+
When an optimization is determined to have converged (by `ftol` or `gtol` criteria)
|
193
|
+
the optimizer `params` function will simply return the optimal parameters. The
|
194
|
+
convergence can be queried by `is_converged(state)`.
|
162
195
|
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
196
|
+
Args:
|
197
|
+
penalty: The weight of the fabrication penalty, which combines length scale,
|
198
|
+
curvature, and fixed pixel constraints.
|
199
|
+
length_scale_spacing_factor: The number of levelset control points per unit of
|
200
|
+
minimum length scale (mean of density minimum width and minimum spacing).
|
201
|
+
length_scale_fwhm_factor: The ratio of Gaussian full-width at half-maximum to
|
202
|
+
the minimum length scale.
|
203
|
+
length_scale_constraint_factor: Multiplies the target length scale in the
|
204
|
+
levelset constraints. A value greater than 1 is pessimistic and drives the
|
205
|
+
solution to have a larger length scale (relative to smaller values).
|
206
|
+
smoothing_factor: For values greater than 1, the density is initially computed
|
207
|
+
at higher resolution and then downsampled, yielding smoother geometries.
|
208
|
+
length_scale_constraint_beta: Controls relaxation of the length scale
|
209
|
+
constraint near the zero level.
|
210
|
+
length_scale_constraint_weight: The weight of the length scale constraint in
|
211
|
+
the overall fabrication constraint penalty.
|
212
|
+
curvature_constraint_weight: The weight of the curvature constraint.
|
213
|
+
fixed_pixel_constraint_weight: The weight of the fixed pixel constraint.
|
214
|
+
init_optimizer: The optimizer used in the initialization of the levelset
|
215
|
+
parameterization. At initialization, the latent parameters are optimized so
|
216
|
+
that the initial parameters match the binarized initial density.
|
217
|
+
init_steps: The number of optimization steps used in the initialization.
|
218
|
+
maxcor: The maximum number of variable metric corrections used to define the
|
219
|
+
limited memory matrix, in the L-BFGS-B scheme.
|
220
|
+
line_search_max_steps: The maximum number of steps in the line search.
|
221
|
+
ftol: Convergence criteria based on function values. See scipy documentation
|
222
|
+
for details.
|
223
|
+
gtol: Convergence criteria based on gradient.
|
224
|
+
|
225
|
+
Returns:
|
226
|
+
The `Optimizer` implementing the L-BFGS-B optimizer.
|
227
|
+
"""
|
228
|
+
return parameterized_lbfgsb(
|
229
|
+
density_parameterization=gaussian_levelset.gaussian_levelset(
|
230
|
+
length_scale_spacing_factor=length_scale_spacing_factor,
|
231
|
+
length_scale_fwhm_factor=length_scale_fwhm_factor,
|
232
|
+
length_scale_constraint_factor=length_scale_constraint_factor,
|
233
|
+
smoothing_factor=smoothing_factor,
|
234
|
+
length_scale_constraint_beta=length_scale_constraint_beta,
|
235
|
+
length_scale_constraint_weight=length_scale_constraint_weight,
|
236
|
+
curvature_constraint_weight=curvature_constraint_weight,
|
237
|
+
fixed_pixel_constraint_weight=fixed_pixel_constraint_weight,
|
238
|
+
init_optimizer=init_optimizer,
|
239
|
+
init_steps=init_steps,
|
240
|
+
),
|
241
|
+
penalty=penalty,
|
184
242
|
maxcor=maxcor,
|
185
243
|
line_search_max_steps=line_search_max_steps,
|
186
244
|
ftol=ftol,
|
187
245
|
gtol=gtol,
|
188
|
-
transform_fn=transform_fn,
|
189
|
-
initialize_latent_fn=initialize_latent_fn,
|
190
246
|
)
|
191
247
|
|
192
248
|
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
ftol: float,
|
197
|
-
gtol: float,
|
198
|
-
transform_fn: Callable[[PyTree], PyTree],
|
199
|
-
initialize_latent_fn: Callable[[PyTree], PyTree],
|
200
|
-
) -> base.Optimizer:
|
201
|
-
"""Construct an latent parameter L-BFGS-B optimizer.
|
249
|
+
# -----------------------------------------------------------------------------
|
250
|
+
# Base parameterized L-BFGS-B optimizer.
|
251
|
+
# -----------------------------------------------------------------------------
|
202
252
|
|
203
|
-
The optimized parameters are termed latent parameters, from which the
|
204
|
-
actual parameters returned by the optimizer are obtained using the
|
205
|
-
`transform_fn`. In the simple case where this is just `lambda x: x` (i.e.
|
206
|
-
the identity), this is equivalent to the standard L-BFGS-B algorithm.
|
207
253
|
|
208
|
-
|
209
|
-
|
210
|
-
|
254
|
+
def parameterized_lbfgsb(
|
255
|
+
density_parameterization: Optional[parameterization_base.Density2DParameterization],
|
256
|
+
penalty: float,
|
257
|
+
maxcor: int = DEFAULT_MAXCOR,
|
258
|
+
line_search_max_steps: int = DEFAULT_LINE_SEARCH_MAX_STEPS,
|
259
|
+
ftol: float = DEFAULT_FTOL,
|
260
|
+
gtol: float = DEFAULT_GTOL,
|
261
|
+
) -> base.Optimizer:
|
262
|
+
"""Optimizer using L-BFGS-B algorithm with specified density parameterization.
|
263
|
+
|
264
|
+
This optimizer wraps scipy's implementation of the algorithm, and provides
|
265
|
+
a jax-style API to the scheme. The optimizer works with custom types such
|
266
|
+
as the `BoundedArray` to constrain the optimization variable.
|
211
267
|
|
212
268
|
Args:
|
213
|
-
|
214
|
-
the
|
269
|
+
density_parameterization: The parameterization to be used, or `None`. When no
|
270
|
+
parameterization is given, the direct pixel parameterization is used for
|
271
|
+
density arrays.
|
272
|
+
penalty: The weight of the scalar penalty formed from the constraints of the
|
273
|
+
parameterization.
|
274
|
+
maxcor: The maximum number of variable metric corrections used to define the
|
275
|
+
limited memory matrix, in the L-BFGS-B scheme.
|
215
276
|
line_search_max_steps: The maximum number of steps in the line search.
|
216
|
-
ftol:
|
217
|
-
|
218
|
-
gtol:
|
219
|
-
transform_fn: Function which transforms the internal latent parameters to
|
220
|
-
the parameters returned by the optimizer.
|
221
|
-
initialize_latent_fn: Function which computes the initial latent parameters
|
222
|
-
given the initial parameters.
|
277
|
+
ftol: Convergence criteria based on function values. See scipy documentation
|
278
|
+
for details.
|
279
|
+
gtol: Convergence criteria based on gradient.
|
223
280
|
|
224
281
|
Returns:
|
225
282
|
The `base.Optimizer`.
|
@@ -236,33 +293,73 @@ def transformed_lbfgsb(
|
|
236
293
|
f"{line_search_max_steps}"
|
237
294
|
)
|
238
295
|
|
296
|
+
if density_parameterization is None:
|
297
|
+
density_parameterization = pixel.pixel()
|
298
|
+
|
299
|
+
def _init_latents(params: PyTree) -> PyTree:
|
300
|
+
def _leaf_init_latents(leaf: Any) -> Any:
|
301
|
+
leaf = _clip(leaf)
|
302
|
+
if not _is_density(leaf) or density_parameterization is None:
|
303
|
+
return leaf
|
304
|
+
return density_parameterization.from_density(leaf)
|
305
|
+
|
306
|
+
return tree_util.tree_map(_leaf_init_latents, params, is_leaf=_is_custom_type)
|
307
|
+
|
308
|
+
def _params_from_latents(latent_params: PyTree) -> PyTree:
|
309
|
+
def _leaf_params_from_latents(leaf: Any) -> Any:
|
310
|
+
if not _is_parameterized_density(leaf) or density_parameterization is None:
|
311
|
+
return leaf
|
312
|
+
return density_parameterization.to_density(leaf)
|
313
|
+
|
314
|
+
return tree_util.tree_map(
|
315
|
+
_leaf_params_from_latents,
|
316
|
+
latent_params,
|
317
|
+
is_leaf=_is_parameterized_density,
|
318
|
+
)
|
319
|
+
|
320
|
+
def _constraint_loss(latent_params: PyTree) -> jnp.ndarray:
|
321
|
+
def _constraint_loss_leaf(
|
322
|
+
params: parameterization_base.ParameterizedDensity2DArrayBase,
|
323
|
+
) -> jnp.ndarray:
|
324
|
+
constraints = density_parameterization.constraints(params)
|
325
|
+
constraints = tree_util.tree_map(
|
326
|
+
lambda x: jnp.sum(jnp.maximum(x, 0.0)),
|
327
|
+
constraints,
|
328
|
+
)
|
329
|
+
return jnp.sum(jnp.asarray(constraints))
|
330
|
+
|
331
|
+
losses = [0.0] + [
|
332
|
+
_constraint_loss_leaf(p)
|
333
|
+
for p in tree_util.tree_leaves(
|
334
|
+
latent_params, is_leaf=_is_parameterized_density
|
335
|
+
)
|
336
|
+
if _is_parameterized_density(p)
|
337
|
+
]
|
338
|
+
return penalty * jnp.sum(jnp.asarray(losses))
|
339
|
+
|
239
340
|
def init_fn(params: PyTree) -> LbfgsbState:
|
240
341
|
"""Initializes the optimization state."""
|
241
342
|
|
242
|
-
def
|
243
|
-
lower_bound = types.extract_lower_bound(
|
244
|
-
upper_bound = types.extract_upper_bound(
|
343
|
+
def _init_state_pure(latent_params: PyTree) -> Tuple[PyTree, JaxLbfgsbDict]:
|
344
|
+
lower_bound = types.extract_lower_bound(latent_params)
|
345
|
+
upper_bound = types.extract_upper_bound(latent_params)
|
245
346
|
scipy_lbfgsb_state = ScipyLbfgsbState.init(
|
246
|
-
x0=_to_numpy(
|
247
|
-
lower_bound=_bound_for_params(lower_bound,
|
248
|
-
upper_bound=_bound_for_params(upper_bound,
|
347
|
+
x0=_to_numpy(latent_params),
|
348
|
+
lower_bound=_bound_for_params(lower_bound, latent_params),
|
349
|
+
upper_bound=_bound_for_params(upper_bound, latent_params),
|
249
350
|
maxcor=maxcor,
|
250
351
|
line_search_max_steps=line_search_max_steps,
|
251
352
|
ftol=ftol,
|
252
353
|
gtol=gtol,
|
253
354
|
)
|
254
|
-
latent_params = _to_pytree(scipy_lbfgsb_state.x,
|
355
|
+
latent_params = _to_pytree(scipy_lbfgsb_state.x, latent_params)
|
255
356
|
return latent_params, scipy_lbfgsb_state.to_jax()
|
256
357
|
|
257
|
-
(
|
258
|
-
|
259
|
-
|
260
|
-
) = jax.pure_callback(
|
261
|
-
_init_pure,
|
262
|
-
_example_state(params, maxcor),
|
263
|
-
initialize_latent_fn(params),
|
358
|
+
latent_params = _init_latents(params)
|
359
|
+
latent_params, jax_lbfgsb_state = jax.pure_callback(
|
360
|
+
_init_state_pure, _example_state(latent_params, maxcor), latent_params
|
264
361
|
)
|
265
|
-
return
|
362
|
+
return _params_from_latents(latent_params), latent_params, jax_lbfgsb_state
|
266
363
|
|
267
364
|
def params_fn(state: LbfgsbState) -> PyTree:
|
268
365
|
"""Returns the parameters for the given `state`."""
|
@@ -294,16 +391,35 @@ def transformed_lbfgsb(
|
|
294
391
|
return flat_latent_params, scipy_lbfgsb_state.to_jax()
|
295
392
|
|
296
393
|
_, latent_params, jax_lbfgsb_state = state
|
297
|
-
_, vjp_fn = jax.vjp(
|
394
|
+
_, vjp_fn = jax.vjp(_params_from_latents, latent_params)
|
298
395
|
(latent_grad,) = vjp_fn(grad)
|
396
|
+
|
397
|
+
if not (
|
398
|
+
tree_util.tree_structure(latent_grad)
|
399
|
+
== tree_util.tree_structure(latent_params) # type: ignore[operator]
|
400
|
+
):
|
401
|
+
raise ValueError(
|
402
|
+
f"Tree structure of `latent_grad` was different than expected, got \n"
|
403
|
+
f"{tree_util.tree_structure(latent_grad)} but expected \n"
|
404
|
+
f"{tree_util.tree_structure(latent_params)}."
|
405
|
+
)
|
406
|
+
|
407
|
+
(
|
408
|
+
constraint_loss_value,
|
409
|
+
constraint_loss_grad,
|
410
|
+
) = jax.value_and_grad(
|
411
|
+
_constraint_loss
|
412
|
+
)(latent_params)
|
413
|
+
value += constraint_loss_value
|
414
|
+
latent_grad = tree_util.tree_map(
|
415
|
+
lambda a, b: a + b, latent_grad, constraint_loss_grad
|
416
|
+
)
|
417
|
+
|
299
418
|
flat_latent_grad, unflatten_fn = flatten_util.ravel_pytree(
|
300
419
|
latent_grad
|
301
420
|
) # type: ignore[no-untyped-call]
|
302
421
|
|
303
|
-
(
|
304
|
-
flat_latent_params,
|
305
|
-
jax_lbfgsb_state,
|
306
|
-
) = jax.pure_callback(
|
422
|
+
flat_latent_params, jax_lbfgsb_state = jax.pure_callback(
|
307
423
|
_update_pure,
|
308
424
|
(flat_latent_grad, jax_lbfgsb_state),
|
309
425
|
flat_latent_grad,
|
@@ -311,7 +427,7 @@ def transformed_lbfgsb(
|
|
311
427
|
jax_lbfgsb_state,
|
312
428
|
)
|
313
429
|
latent_params = unflatten_fn(flat_latent_params)
|
314
|
-
return
|
430
|
+
return _params_from_latents(latent_params), latent_params, jax_lbfgsb_state
|
315
431
|
|
316
432
|
return base.Optimizer(
|
317
433
|
init=init_fn,
|
@@ -335,6 +451,31 @@ def _is_density(leaf: Any) -> Any:
|
|
335
451
|
return isinstance(leaf, types.Density2DArray)
|
336
452
|
|
337
453
|
|
454
|
+
def _is_parameterized_density(leaf: Any) -> Any:
|
455
|
+
"""Return `True` if `leaf` is a parameterized density array."""
|
456
|
+
return isinstance(leaf, parameterization_base.ParameterizedDensity2DArrayBase)
|
457
|
+
|
458
|
+
|
459
|
+
def _is_custom_type(leaf: Any) -> bool:
|
460
|
+
"""Return `True` if `leaf` is a recognized custom type."""
|
461
|
+
return isinstance(leaf, (types.BoundedArray, types.Density2DArray))
|
462
|
+
|
463
|
+
|
464
|
+
def _clip(pytree: PyTree) -> PyTree:
|
465
|
+
"""Clips leaves on `pytree` to their bounds."""
|
466
|
+
|
467
|
+
def _clip_fn(leaf: Any) -> Any:
|
468
|
+
if not _is_custom_type(leaf):
|
469
|
+
return leaf
|
470
|
+
if leaf.lower_bound is None and leaf.upper_bound is None:
|
471
|
+
return leaf
|
472
|
+
return tree_util.tree_map(
|
473
|
+
lambda x: jnp.clip(x, leaf.lower_bound, leaf.upper_bound), leaf
|
474
|
+
)
|
475
|
+
|
476
|
+
return tree_util.tree_map(_clip_fn, pytree, is_leaf=_is_custom_type)
|
477
|
+
|
478
|
+
|
338
479
|
def _to_numpy(params: PyTree) -> NDArray:
|
339
480
|
"""Flattens a `params` pytree into a single rank-1 numpy array."""
|
340
481
|
x, _ = flatten_util.ravel_pytree(params) # type: ignore[no-untyped-call]
|