invrs-opt 0.10.6__py3-none-any.whl → 0.11.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invrs_opt/__init__.py +1 -1
- invrs_opt/optimizers/lbfgsb.py +43 -22
- invrs_opt/optimizers/wrapped_optax.py +44 -20
- invrs_opt/parameterization/base.py +31 -0
- invrs_opt/parameterization/gaussian_levelset.py +1 -1
- {invrs_opt-0.10.6.dist-info → invrs_opt-0.11.0.dist-info}/METADATA +3 -2
- {invrs_opt-0.10.6.dist-info → invrs_opt-0.11.0.dist-info}/RECORD +10 -10
- {invrs_opt-0.10.6.dist-info → invrs_opt-0.11.0.dist-info}/WHEEL +1 -1
- {invrs_opt-0.10.6.dist-info → invrs_opt-0.11.0.dist-info/licenses}/LICENSE +0 -0
- {invrs_opt-0.10.6.dist-info → invrs_opt-0.11.0.dist-info}/top_level.txt +0 -0
invrs_opt/__init__.py
CHANGED
invrs_opt/optimizers/lbfgsb.py
CHANGED
|
@@ -16,7 +16,7 @@ from jax import flatten_util, tree_util
|
|
|
16
16
|
from scipy.optimize._lbfgsb_py import ( # type: ignore[import-untyped]
|
|
17
17
|
_lbfgsb as scipy_lbfgsb,
|
|
18
18
|
)
|
|
19
|
-
from totypes import types
|
|
19
|
+
from totypes import json_utils, types
|
|
20
20
|
|
|
21
21
|
from invrs_opt.optimizers import base
|
|
22
22
|
from invrs_opt.parameterization import (
|
|
@@ -31,7 +31,26 @@ PyTree = Any
|
|
|
31
31
|
ElementwiseBound = Union[NDArray, Sequence[Optional[float]]]
|
|
32
32
|
NumpyLbfgsbDict = Dict[str, NDArray]
|
|
33
33
|
JaxLbfgsbDict = Dict[str, jnp.ndarray]
|
|
34
|
-
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
@dataclasses.dataclass
|
|
37
|
+
class LbfgsbState:
|
|
38
|
+
"""Stores the state of the L-BFGS-B optimizer."""
|
|
39
|
+
|
|
40
|
+
step: int
|
|
41
|
+
params: PyTree
|
|
42
|
+
latent_params: PyTree
|
|
43
|
+
opt_state: JaxLbfgsbDict
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
tree_util.register_dataclass(
|
|
47
|
+
LbfgsbState,
|
|
48
|
+
data_fields=["step", "params", "latent_params", "opt_state"],
|
|
49
|
+
meta_fields=[],
|
|
50
|
+
)
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
json_utils.register_custom_type(LbfgsbState)
|
|
35
54
|
|
|
36
55
|
|
|
37
56
|
# Task message prefixes for the underlying L-BFGS-B implementation.
|
|
@@ -327,17 +346,16 @@ def parameterized_lbfgsb(
|
|
|
327
346
|
latents,
|
|
328
347
|
)
|
|
329
348
|
latent_params = param_base.combine_density_metadata(metadata, latents)
|
|
330
|
-
return (
|
|
331
|
-
0,
|
|
332
|
-
_params_from_latent_params(latent_params),
|
|
333
|
-
latent_params,
|
|
334
|
-
jax_lbfgsb_state,
|
|
349
|
+
return LbfgsbState(
|
|
350
|
+
step=0,
|
|
351
|
+
params=_params_from_latent_params(latent_params),
|
|
352
|
+
latent_params=latent_params,
|
|
353
|
+
opt_state=jax_lbfgsb_state,
|
|
335
354
|
)
|
|
336
355
|
|
|
337
356
|
def params_fn(state: LbfgsbState) -> PyTree:
|
|
338
357
|
"""Returns the parameters for the given `state`."""
|
|
339
|
-
|
|
340
|
-
return params
|
|
358
|
+
return state.params
|
|
341
359
|
|
|
342
360
|
def update_fn(
|
|
343
361
|
*,
|
|
@@ -366,8 +384,7 @@ def parameterized_lbfgsb(
|
|
|
366
384
|
flat_latent_updates = updated_flat_latent_params - flat_latent_params
|
|
367
385
|
return flat_latent_updates, scipy_lbfgsb_state.to_dict()
|
|
368
386
|
|
|
369
|
-
|
|
370
|
-
metadata, latents = param_base.partition_density_metadata(latent_params)
|
|
387
|
+
metadata, latents = param_base.partition_density_metadata(state.latent_params)
|
|
371
388
|
|
|
372
389
|
def _params_from_latents(latents: PyTree) -> PyTree:
|
|
373
390
|
latent_params = param_base.combine_density_metadata(metadata, latents)
|
|
@@ -380,10 +397,9 @@ def parameterized_lbfgsb(
|
|
|
380
397
|
_, vjp_fn = jax.vjp(_params_from_latents, latents)
|
|
381
398
|
(latents_grad,) = vjp_fn(grad)
|
|
382
399
|
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
):
|
|
400
|
+
treedef = tree_util.tree_structure(latents_grad)
|
|
401
|
+
expected_treedef = tree_util.tree_structure(latents)
|
|
402
|
+
if not treedef == expected_treedef: # type: ignore[operator]
|
|
387
403
|
raise ValueError(
|
|
388
404
|
f"Tree structure of `latents_grad` was different than expected, got \n"
|
|
389
405
|
f"{tree_util.tree_structure(latents_grad)} but expected \n"
|
|
@@ -405,23 +421,28 @@ def parameterized_lbfgsb(
|
|
|
405
421
|
latents_grad
|
|
406
422
|
) # type: ignore[no-untyped-call]
|
|
407
423
|
|
|
408
|
-
flat_latent_updates,
|
|
424
|
+
flat_latent_updates, opt_state = callback_sequential(
|
|
409
425
|
_update_pure,
|
|
410
|
-
(flat_latents_grad,
|
|
426
|
+
(flat_latents_grad, state.opt_state),
|
|
411
427
|
flat_latents_grad,
|
|
412
428
|
value,
|
|
413
|
-
|
|
429
|
+
state.opt_state,
|
|
414
430
|
)
|
|
415
431
|
latent_updates = unflatten_fn(flat_latent_updates)
|
|
416
432
|
latent_params = _apply_updates(
|
|
417
|
-
params=latent_params,
|
|
433
|
+
params=state.latent_params,
|
|
418
434
|
updates=param_base.combine_density_metadata(metadata, latent_updates),
|
|
419
435
|
value=value,
|
|
420
|
-
step=step,
|
|
436
|
+
step=state.step,
|
|
421
437
|
)
|
|
422
438
|
latent_params = _clip(latent_params)
|
|
423
439
|
params = _params_from_latent_params(latent_params)
|
|
424
|
-
return
|
|
440
|
+
return LbfgsbState(
|
|
441
|
+
step=state.step + 1,
|
|
442
|
+
params=params,
|
|
443
|
+
latent_params=latent_params,
|
|
444
|
+
opt_state=opt_state,
|
|
445
|
+
)
|
|
425
446
|
|
|
426
447
|
# -------------------------------------------------------------------------
|
|
427
448
|
# Functions related to the density parameterization.
|
|
@@ -502,7 +523,7 @@ def parameterized_lbfgsb(
|
|
|
502
523
|
|
|
503
524
|
def is_converged(state: LbfgsbState) -> jnp.ndarray:
|
|
504
525
|
"""Returns `True` if the optimization has converged."""
|
|
505
|
-
return state[
|
|
526
|
+
return state.opt_state["converged"]
|
|
506
527
|
|
|
507
528
|
|
|
508
529
|
# ------------------------------------------------------------------------------
|
|
@@ -3,13 +3,14 @@
|
|
|
3
3
|
Copyright (c) 2023 The INVRS-IO authors.
|
|
4
4
|
"""
|
|
5
5
|
|
|
6
|
-
|
|
6
|
+
import dataclasses
|
|
7
|
+
from typing import Any, Optional
|
|
7
8
|
|
|
8
9
|
import jax
|
|
9
10
|
import jax.numpy as jnp
|
|
10
11
|
import optax # type: ignore[import-untyped]
|
|
11
12
|
from jax import tree_util
|
|
12
|
-
from totypes import types
|
|
13
|
+
from totypes import json_utils, types
|
|
13
14
|
|
|
14
15
|
from invrs_opt.optimizers import base
|
|
15
16
|
from invrs_opt.parameterization import (
|
|
@@ -20,7 +21,26 @@ from invrs_opt.parameterization import (
|
|
|
20
21
|
)
|
|
21
22
|
|
|
22
23
|
PyTree = Any
|
|
23
|
-
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
@dataclasses.dataclass
|
|
27
|
+
class WrappedOptaxState:
|
|
28
|
+
"""Stores the state of a wrapped optax optimizer."""
|
|
29
|
+
|
|
30
|
+
step: int
|
|
31
|
+
params: PyTree
|
|
32
|
+
latent_params: PyTree
|
|
33
|
+
opt_state: PyTree
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
tree_util.register_dataclass(
|
|
37
|
+
WrappedOptaxState,
|
|
38
|
+
data_fields=["step", "params", "latent_params", "opt_state"],
|
|
39
|
+
meta_fields=[],
|
|
40
|
+
)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
json_utils.register_custom_type(WrappedOptaxState)
|
|
24
44
|
|
|
25
45
|
|
|
26
46
|
def wrapped_optax(opt: optax.GradientTransformation) -> base.Optimizer:
|
|
@@ -182,17 +202,16 @@ def parameterized_wrapped_optax(
|
|
|
182
202
|
"""Initializes the optimization state."""
|
|
183
203
|
latent_params = _init_latents(params)
|
|
184
204
|
_, latents = param_base.partition_density_metadata(latent_params)
|
|
185
|
-
return (
|
|
186
|
-
0,
|
|
187
|
-
_params_from_latent_params(latent_params),
|
|
188
|
-
latent_params,
|
|
189
|
-
opt.init(latents),
|
|
205
|
+
return WrappedOptaxState(
|
|
206
|
+
step=0,
|
|
207
|
+
params=_params_from_latent_params(latent_params),
|
|
208
|
+
latent_params=latent_params,
|
|
209
|
+
opt_state=opt.init(latents),
|
|
190
210
|
)
|
|
191
211
|
|
|
192
212
|
def params_fn(state: WrappedOptaxState) -> PyTree:
|
|
193
213
|
"""Returns the parameters for the given `state`."""
|
|
194
|
-
|
|
195
|
-
return params
|
|
214
|
+
return state.params
|
|
196
215
|
|
|
197
216
|
def update_fn(
|
|
198
217
|
*,
|
|
@@ -204,8 +223,7 @@ def parameterized_wrapped_optax(
|
|
|
204
223
|
"""Updates the state."""
|
|
205
224
|
del params
|
|
206
225
|
|
|
207
|
-
|
|
208
|
-
metadata, latents = param_base.partition_density_metadata(latent_params)
|
|
226
|
+
metadata, latents = param_base.partition_density_metadata(state.latent_params)
|
|
209
227
|
|
|
210
228
|
def _params_from_latents(latents: PyTree) -> PyTree:
|
|
211
229
|
latent_params = param_base.combine_density_metadata(metadata, latents)
|
|
@@ -218,10 +236,9 @@ def parameterized_wrapped_optax(
|
|
|
218
236
|
_, vjp_fn = jax.vjp(_params_from_latents, latents)
|
|
219
237
|
(latents_grad,) = vjp_fn(grad)
|
|
220
238
|
|
|
221
|
-
|
|
222
|
-
|
|
223
|
-
|
|
224
|
-
):
|
|
239
|
+
treedef = tree_util.tree_structure(latents_grad)
|
|
240
|
+
expected_treedef = tree_util.tree_structure(latents)
|
|
241
|
+
if not treedef == expected_treedef: # type: ignore[operator]
|
|
225
242
|
raise ValueError(
|
|
226
243
|
f"Tree structure of `latents_grad` was different than expected, got \n"
|
|
227
244
|
f"{tree_util.tree_structure(latents_grad)} but expected \n"
|
|
@@ -233,16 +250,23 @@ def parameterized_wrapped_optax(
|
|
|
233
250
|
lambda a, b: a + b, latents_grad, constraint_loss_grad
|
|
234
251
|
)
|
|
235
252
|
|
|
236
|
-
latent_updates, opt_state = opt.update(
|
|
253
|
+
latent_updates, opt_state = opt.update(
|
|
254
|
+
latents_grad, state.opt_state, params=latents
|
|
255
|
+
)
|
|
237
256
|
latent_params = _apply_updates(
|
|
238
|
-
params=latent_params,
|
|
257
|
+
params=state.latent_params,
|
|
239
258
|
updates=param_base.combine_density_metadata(metadata, latent_updates),
|
|
240
259
|
value=value,
|
|
241
|
-
step=step,
|
|
260
|
+
step=state.step,
|
|
242
261
|
)
|
|
243
262
|
latent_params = _clip(latent_params)
|
|
244
263
|
params = _params_from_latent_params(latent_params)
|
|
245
|
-
return (
|
|
264
|
+
return WrappedOptaxState(
|
|
265
|
+
step=state.step + 1,
|
|
266
|
+
params=params,
|
|
267
|
+
latent_params=latent_params,
|
|
268
|
+
opt_state=opt_state,
|
|
269
|
+
)
|
|
246
270
|
|
|
247
271
|
# -------------------------------------------------------------------------
|
|
248
272
|
# Functions related to the density parameterization.
|
|
@@ -124,6 +124,22 @@ class Density2DMetadata:
|
|
|
124
124
|
self.periodic = tuple(self.periodic)
|
|
125
125
|
self.symmetries = tuple(self.symmetries)
|
|
126
126
|
|
|
127
|
+
def __eq__(self, other: Any) -> bool:
|
|
128
|
+
if not isinstance(other, Density2DMetadata):
|
|
129
|
+
return False
|
|
130
|
+
if not (
|
|
131
|
+
self.lower_bound == other.lower_bound
|
|
132
|
+
and self.upper_bound == other.upper_bound
|
|
133
|
+
and _arrays_equal_or_both_none(self.fixed_solid, other.fixed_solid)
|
|
134
|
+
and _arrays_equal_or_both_none(self.fixed_void, other.fixed_void)
|
|
135
|
+
and self.minimum_width == other.minimum_width
|
|
136
|
+
and self.minimum_spacing == other.minimum_spacing
|
|
137
|
+
and self.periodic == other.periodic
|
|
138
|
+
and self.symmetries == other.symmetries
|
|
139
|
+
):
|
|
140
|
+
return False
|
|
141
|
+
return True
|
|
142
|
+
|
|
127
143
|
@classmethod
|
|
128
144
|
def from_density(self, density: types.Density2DArray) -> "Density2DMetadata":
|
|
129
145
|
density_metadata_dict = dataclasses.asdict(density)
|
|
@@ -131,6 +147,21 @@ class Density2DMetadata:
|
|
|
131
147
|
return Density2DMetadata(**density_metadata_dict)
|
|
132
148
|
|
|
133
149
|
|
|
150
|
+
def _arrays_equal_or_both_none(a: Optional[Array], b: Optional[Array]) -> bool:
|
|
151
|
+
"""Return `True` if `a` and `b` are equal arrays or both `None`."""
|
|
152
|
+
if (a is None, b is None) not in ((True, True), (False, False)):
|
|
153
|
+
return False
|
|
154
|
+
if a is None and b is None:
|
|
155
|
+
return True
|
|
156
|
+
assert isinstance(a, onp.ndarray)
|
|
157
|
+
assert isinstance(b, onp.ndarray)
|
|
158
|
+
if a.dtype != b.dtype:
|
|
159
|
+
return False
|
|
160
|
+
if a.shape != b.shape:
|
|
161
|
+
return False
|
|
162
|
+
return bool(onp.all(a == b))
|
|
163
|
+
|
|
164
|
+
|
|
134
165
|
def _flatten_density_2d_metadata(
|
|
135
166
|
metadata: Density2DMetadata,
|
|
136
167
|
) -> Tuple[
|
|
@@ -388,7 +388,7 @@ def _phi_from_params(
|
|
|
388
388
|
assert array.shape[-2] % s_factor == 0
|
|
389
389
|
assert array.shape[-1] % s_factor == 0
|
|
390
390
|
array = symmetry.symmetrize(array, tuple(example_density.symmetries))
|
|
391
|
-
return array
|
|
391
|
+
return jnp.asarray(array)
|
|
392
392
|
|
|
393
393
|
|
|
394
394
|
# -----------------------------------------------------------------------------
|
|
@@ -1,6 +1,6 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
2
|
Name: invrs_opt
|
|
3
|
-
Version: 0.
|
|
3
|
+
Version: 0.11.0
|
|
4
4
|
Summary: Algorithms for inverse design
|
|
5
5
|
Author-email: "Martin F. Schubert" <mfschubert@gmail.com>
|
|
6
6
|
Maintainer-email: "Martin F. Schubert" <mfschubert@gmail.com>
|
|
@@ -531,6 +531,7 @@ Requires-Dist: bump-my-version; extra == "dev"
|
|
|
531
531
|
Requires-Dist: darglint; extra == "dev"
|
|
532
532
|
Requires-Dist: mypy; extra == "dev"
|
|
533
533
|
Requires-Dist: pre-commit; extra == "dev"
|
|
534
|
+
Dynamic: license-file
|
|
534
535
|
|
|
535
536
|
# invrs-opt - Optimization algorithms for inverse design
|
|
536
537
|

|
|
@@ -1,20 +1,20 @@
|
|
|
1
|
-
invrs_opt/__init__.py,sha256=
|
|
1
|
+
invrs_opt/__init__.py,sha256=CQM3bUeeV9fgJgAee7NFy1B2HJyw2uhNSzoKynmVyOQ,586
|
|
2
2
|
invrs_opt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
3
3
|
invrs_opt/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
4
4
|
invrs_opt/experimental/client.py,sha256=tbtH13FrA65XmTZfTO71CxJ78jeAEj3Zf85R-MTwbiU,4909
|
|
5
5
|
invrs_opt/experimental/labels.py,sha256=dQDAMPyFMV6mXnMy295z8Ap205DRdVzysXny_Be8FmY,562
|
|
6
6
|
invrs_opt/optimizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
7
7
|
invrs_opt/optimizers/base.py,sha256=uFfkN2LwWzAtwh6ktwWNy2iHNOY-sW3JzI46iSFkgok,1306
|
|
8
|
-
invrs_opt/optimizers/lbfgsb.py,sha256=
|
|
9
|
-
invrs_opt/optimizers/wrapped_optax.py,sha256=
|
|
8
|
+
invrs_opt/optimizers/lbfgsb.py,sha256=QB8lD02sMr-2V0d_k4UB8Y7SOlNektp0vzwrJId0u44,37059
|
|
9
|
+
invrs_opt/optimizers/wrapped_optax.py,sha256=hftrGCpg4kVv2NarZAVHLg6Gkk87zcc4_yyfvSvQHTo,14161
|
|
10
10
|
invrs_opt/parameterization/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
11
|
-
invrs_opt/parameterization/base.py,sha256=
|
|
11
|
+
invrs_opt/parameterization/base.py,sha256=GFubMydcww6MXlGNkxHGMH_sOCyj9M5R-iwZYTqCo2I,6539
|
|
12
12
|
invrs_opt/parameterization/filter_project.py,sha256=XL3HTEBLrF-q_75TjhOWLNdfUOSEEjKcoM7Qj844QpQ,4590
|
|
13
|
-
invrs_opt/parameterization/gaussian_levelset.py,sha256=
|
|
13
|
+
invrs_opt/parameterization/gaussian_levelset.py,sha256=bmVU1We92zPpNIJI8sCq3OCeHZw6emZl86unJYwnWbc,24826
|
|
14
14
|
invrs_opt/parameterization/pixel.py,sha256=YWkyBhfYtzI8cQ-M90PAZqRAbabwVaUh0UiYIGegQHI,1955
|
|
15
15
|
invrs_opt/parameterization/transforms.py,sha256=mqDKuAg4wpSL9kh0oYKxtSoH0mHOQeKG1RND2fJSYaU,9441
|
|
16
|
-
invrs_opt-0.
|
|
17
|
-
invrs_opt-0.
|
|
18
|
-
invrs_opt-0.
|
|
19
|
-
invrs_opt-0.
|
|
20
|
-
invrs_opt-0.
|
|
16
|
+
invrs_opt-0.11.0.dist-info/licenses/LICENSE,sha256=IMF9i4xIpgCADf0U-V1cuf9HBmqWQd3qtI3FSuyW4zE,26526
|
|
17
|
+
invrs_opt-0.11.0.dist-info/METADATA,sha256=M6IzHHGIctk49mBxnswyKcvE3BikRoqZz_19UDOqB9c,32838
|
|
18
|
+
invrs_opt-0.11.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
19
|
+
invrs_opt-0.11.0.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
|
|
20
|
+
invrs_opt-0.11.0.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|