invrs-opt 0.7.2__py3-none-any.whl → 0.8.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invrs_opt/__init__.py +1 -1
- invrs_opt/optimizers/lbfgsb.py +21 -7
- invrs_opt/optimizers/wrapped_optax.py +19 -6
- invrs_opt/parameterization/base.py +8 -0
- invrs_opt/parameterization/filter_project.py +6 -0
- invrs_opt/parameterization/gaussian_levelset.py +7 -1
- invrs_opt/parameterization/pixel.py +5 -0
- {invrs_opt-0.7.2.dist-info → invrs_opt-0.8.0.dist-info}/METADATA +2 -2
- invrs_opt-0.8.0.dist-info/RECORD +20 -0
- {invrs_opt-0.7.2.dist-info → invrs_opt-0.8.0.dist-info}/WHEEL +1 -1
- invrs_opt-0.7.2.dist-info/RECORD +0 -20
- {invrs_opt-0.7.2.dist-info → invrs_opt-0.8.0.dist-info}/LICENSE +0 -0
- {invrs_opt-0.7.2.dist-info → invrs_opt-0.8.0.dist-info}/top_level.txt +0 -0
invrs_opt/__init__.py
CHANGED
invrs_opt/optimizers/lbfgsb.py
CHANGED
@@ -29,7 +29,7 @@ NDArray = onp.ndarray[Any, Any]
|
|
29
29
|
PyTree = Any
|
30
30
|
ElementwiseBound = Union[NDArray, Sequence[Optional[float]]]
|
31
31
|
JaxLbfgsbDict = Dict[str, jnp.ndarray]
|
32
|
-
LbfgsbState = Tuple[PyTree, PyTree, JaxLbfgsbDict]
|
32
|
+
LbfgsbState = Tuple[int, PyTree, PyTree, JaxLbfgsbDict]
|
33
33
|
|
34
34
|
|
35
35
|
# Task message prefixes for the underlying L-BFGS-B implementation.
|
@@ -323,7 +323,7 @@ def parameterized_lbfgsb(
|
|
323
323
|
) -> jnp.ndarray:
|
324
324
|
constraints = density_parameterization.constraints(params)
|
325
325
|
constraints = tree_util.tree_map(
|
326
|
-
lambda x: jnp.sum(jnp.maximum(x, 0.0)),
|
326
|
+
lambda x: jnp.sum(jnp.maximum(x, 0.0) ** 2),
|
327
327
|
constraints,
|
328
328
|
)
|
329
329
|
return jnp.sum(jnp.asarray(constraints))
|
@@ -337,6 +337,18 @@ def parameterized_lbfgsb(
|
|
337
337
|
]
|
338
338
|
return penalty * jnp.sum(jnp.asarray(losses))
|
339
339
|
|
340
|
+
def _update_parameterized_densities(latent_params: PyTree, step: int) -> PyTree:
|
341
|
+
def _update_leaf(leaf: Any) -> Any:
|
342
|
+
if not _is_parameterized_density(leaf):
|
343
|
+
return leaf
|
344
|
+
return density_parameterization.update(leaf, step)
|
345
|
+
|
346
|
+
return tree_util.tree_map(
|
347
|
+
_update_leaf,
|
348
|
+
latent_params,
|
349
|
+
is_leaf=_is_parameterized_density,
|
350
|
+
)
|
351
|
+
|
340
352
|
def init_fn(params: PyTree) -> LbfgsbState:
|
341
353
|
"""Initializes the optimization state."""
|
342
354
|
|
@@ -359,11 +371,11 @@ def parameterized_lbfgsb(
|
|
359
371
|
latent_params, jax_lbfgsb_state = jax.pure_callback(
|
360
372
|
_init_state_pure, _example_state(latent_params, maxcor), latent_params
|
361
373
|
)
|
362
|
-
return _params_from_latents(latent_params), latent_params, jax_lbfgsb_state
|
374
|
+
return 0, _params_from_latents(latent_params), latent_params, jax_lbfgsb_state
|
363
375
|
|
364
376
|
def params_fn(state: LbfgsbState) -> PyTree:
|
365
377
|
"""Returns the parameters for the given `state`."""
|
366
|
-
params, _, _ = state
|
378
|
+
_, params, _, _ = state
|
367
379
|
return params
|
368
380
|
|
369
381
|
def update_fn(
|
@@ -390,7 +402,7 @@ def parameterized_lbfgsb(
|
|
390
402
|
flat_latent_params = jnp.asarray(scipy_lbfgsb_state.x)
|
391
403
|
return flat_latent_params, scipy_lbfgsb_state.to_jax()
|
392
404
|
|
393
|
-
_, latent_params, jax_lbfgsb_state = state
|
405
|
+
step, _, latent_params, jax_lbfgsb_state = state
|
394
406
|
_, vjp_fn = jax.vjp(_params_from_latents, latent_params)
|
395
407
|
(latent_grad,) = vjp_fn(grad)
|
396
408
|
|
@@ -427,7 +439,9 @@ def parameterized_lbfgsb(
|
|
427
439
|
jax_lbfgsb_state,
|
428
440
|
)
|
429
441
|
latent_params = unflatten_fn(flat_latent_params)
|
430
|
-
|
442
|
+
latent_params = _update_parameterized_densities(latent_params, step)
|
443
|
+
params = _params_from_latents(latent_params)
|
444
|
+
return step + 1, params, latent_params, jax_lbfgsb_state
|
431
445
|
|
432
446
|
return base.Optimizer(
|
433
447
|
init=init_fn,
|
@@ -438,7 +452,7 @@ def parameterized_lbfgsb(
|
|
438
452
|
|
439
453
|
def is_converged(state: LbfgsbState) -> jnp.ndarray:
|
440
454
|
"""Returns `True` if the optimization has converged."""
|
441
|
-
return state[2]["converged"]
|
455
|
+
return state[3]["converged"]
|
442
456
|
|
443
457
|
|
444
458
|
# ------------------------------------------------------------------------------
|
@@ -20,7 +20,7 @@ from invrs_opt.parameterization import (
|
|
20
20
|
)
|
21
21
|
|
22
22
|
PyTree = Any
|
23
|
-
WrappedOptaxState = Tuple[PyTree, PyTree, PyTree]
|
23
|
+
WrappedOptaxState = Tuple[int, PyTree, PyTree, PyTree]
|
24
24
|
|
25
25
|
|
26
26
|
def wrapped_optax(opt: optax.GradientTransformation) -> base.Optimizer:
|
@@ -205,7 +205,7 @@ def parameterized_wrapped_optax(
|
|
205
205
|
) -> jnp.ndarray:
|
206
206
|
constraints = density_parameterization.constraints(params)
|
207
207
|
constraints = tree_util.tree_map(
|
208
|
-
lambda x: jnp.sum(jnp.maximum(x, 0.0)),
|
208
|
+
lambda x: jnp.sum(jnp.maximum(x, 0.0) ** 2),
|
209
209
|
constraints,
|
210
210
|
)
|
211
211
|
return jnp.sum(jnp.asarray(constraints))
|
@@ -219,15 +219,27 @@ def parameterized_wrapped_optax(
|
|
219
219
|
]
|
220
220
|
return penalty * jnp.sum(jnp.asarray(losses))
|
221
221
|
|
222
|
+
def _update_parameterized_densities(latent_params: PyTree, step: int) -> PyTree:
|
223
|
+
def _update_leaf(leaf: Any) -> Any:
|
224
|
+
if not _is_parameterized_density(leaf):
|
225
|
+
return leaf
|
226
|
+
return density_parameterization.update(leaf, step)
|
227
|
+
|
228
|
+
return tree_util.tree_map(
|
229
|
+
_update_leaf,
|
230
|
+
latent_params,
|
231
|
+
is_leaf=_is_parameterized_density,
|
232
|
+
)
|
233
|
+
|
222
234
|
def init_fn(params: PyTree) -> WrappedOptaxState:
|
223
235
|
"""Initializes the optimization state."""
|
224
236
|
latent_params = _init_latents(params)
|
225
237
|
params = _params_from_latents(latent_params)
|
226
|
-
return params, latent_params, opt.init(latent_params)
|
238
|
+
return 0, params, latent_params, opt.init(latent_params)
|
227
239
|
|
228
240
|
def params_fn(state: WrappedOptaxState) -> PyTree:
|
229
241
|
"""Returns the parameters for the given `state`."""
|
230
|
-
params, _, _ = state
|
242
|
+
_, params, _, _ = state
|
231
243
|
return params
|
232
244
|
|
233
245
|
def update_fn(
|
@@ -240,7 +252,7 @@ def parameterized_wrapped_optax(
|
|
240
252
|
"""Updates the state."""
|
241
253
|
del value, params
|
242
254
|
|
243
|
-
_, latent_params, opt_state = state
|
255
|
+
step, _, latent_params, opt_state = state
|
244
256
|
_, vjp_fn = jax.vjp(_params_from_latents, latent_params)
|
245
257
|
(latent_grad,) = vjp_fn(grad)
|
246
258
|
|
@@ -264,8 +276,9 @@ def parameterized_wrapped_optax(
|
|
264
276
|
)
|
265
277
|
latent_params = optax.apply_updates(params=latent_params, updates=updates)
|
266
278
|
latent_params = _clip(latent_params)
|
279
|
+
latent_params = _update_parameterized_densities(latent_params, step)
|
267
280
|
params = _params_from_latents(latent_params)
|
268
|
-
return params, latent_params, opt_state
|
281
|
+
return step + 1, params, latent_params, opt_state
|
269
282
|
|
270
283
|
return base.Optimizer(init=init_fn, params=params_fn, update=update_fn)
|
271
284
|
|
@@ -44,6 +44,13 @@ class ConstraintsFn(Protocol):
|
|
44
44
|
...
|
45
45
|
|
46
46
|
|
47
|
+
class UpdateFn(Protocol):
|
48
|
+
"""Performs the required update of a parameterized density for the given step."""
|
49
|
+
|
50
|
+
def __call__(self, params: PyTree, step: int) -> PyTree:
|
51
|
+
...
|
52
|
+
|
53
|
+
|
47
54
|
@dataclasses.dataclass
|
48
55
|
class Density2DParameterization:
|
49
56
|
"""Stores `(from_density, to_density, constraints)` function triple."""
|
@@ -51,6 +58,7 @@ class Density2DParameterization:
|
|
51
58
|
from_density: FromDensityFn
|
52
59
|
to_density: ToDensityFn
|
53
60
|
constraints: ConstraintsFn
|
61
|
+
update: UpdateFn
|
54
62
|
|
55
63
|
|
56
64
|
@dataclasses.dataclass
|
@@ -85,8 +85,14 @@ def filter_project(beta: float) -> base.Density2DParameterization:
|
|
85
85
|
del params
|
86
86
|
return jnp.asarray(0.0)
|
87
87
|
|
88
|
+
def update_fn(params: FilterAndProjectParams, step: int) -> FilterAndProjectParams:
|
89
|
+
"""Perform updates to `params` required for the given `step`."""
|
90
|
+
del step
|
91
|
+
return params
|
92
|
+
|
88
93
|
return base.Density2DParameterization(
|
89
94
|
to_density=to_density_fn,
|
90
95
|
from_density=from_density_fn,
|
91
96
|
constraints=constraints_fn,
|
97
|
+
update=update_fn,
|
92
98
|
)
|
@@ -214,7 +214,7 @@ def gaussian_levelset(
|
|
214
214
|
return params, state
|
215
215
|
|
216
216
|
state = init_optimizer.init(params)
|
217
|
-
params,
|
217
|
+
params, _ = jax.lax.fori_loop(
|
218
218
|
0, init_steps, body_fun=step_fn, init_val=(params, state)
|
219
219
|
)
|
220
220
|
|
@@ -269,10 +269,16 @@ def gaussian_levelset(
|
|
269
269
|
)
|
270
270
|
return constraints / length_scale**2
|
271
271
|
|
272
|
+
def update_fn(params: GaussianLevelsetParams, step: int) -> GaussianLevelsetParams:
|
273
|
+
"""Perform updates to `params` required for the given `step`."""
|
274
|
+
del step
|
275
|
+
return params
|
276
|
+
|
272
277
|
return base.Density2DParameterization(
|
273
278
|
to_density=to_density_fn,
|
274
279
|
from_density=from_density_fn,
|
275
280
|
constraints=constraints_fn,
|
281
|
+
update=update_fn,
|
276
282
|
)
|
277
283
|
|
278
284
|
|
@@ -38,8 +38,13 @@ def pixel() -> base.Density2DParameterization:
|
|
38
38
|
del params
|
39
39
|
return jnp.asarray(0.0)
|
40
40
|
|
41
|
+
def update_fn(params: PixelParams, step: int) -> PixelParams:
|
42
|
+
del step
|
43
|
+
return params
|
44
|
+
|
41
45
|
return base.Density2DParameterization(
|
42
46
|
from_density=from_density_fn,
|
43
47
|
to_density=to_density_fn,
|
44
48
|
constraints=constraints_fn,
|
49
|
+
update=update_fn,
|
45
50
|
)
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: invrs_opt
|
3
|
-
Version: 0.7.2
|
3
|
+
Version: 0.8.0
|
4
4
|
Summary: Algorithms for inverse design
|
5
5
|
Author-email: "Martin F. Schubert" <mfschubert@gmail.com>
|
6
6
|
Maintainer-email: "Martin F. Schubert" <mfschubert@gmail.com>
|
@@ -533,7 +533,7 @@ Requires-Dist: pytest-cov; extra == "tests"
|
|
533
533
|
Requires-Dist: pytest-subtests; extra == "tests"
|
534
534
|
|
535
535
|
# invrs-opt - Optimization algorithms for inverse design
|
536
|
-
`v0.7.2`
|
536
|
+
`v0.8.0`
|
537
537
|
|
538
538
|
## Overview
|
539
539
|
|
@@ -0,0 +1,20 @@
|
|
1
|
+
invrs_opt/__init__.py,sha256=kTrg48iZu7i5OlH0Nqtfh_wBn3be9u2eZgJBRqUr6uQ,585
|
2
|
+
invrs_opt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
|
+
invrs_opt/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
4
|
+
invrs_opt/experimental/client.py,sha256=t4XxnditYbM9DWZeyBPj0Sa2acvkikT0ybhUdmH2r-Y,4852
|
5
|
+
invrs_opt/experimental/labels.py,sha256=dQDAMPyFMV6mXnMy295z8Ap205DRdVzysXny_Be8FmY,562
|
6
|
+
invrs_opt/optimizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
7
|
+
invrs_opt/optimizers/base.py,sha256=-wNwH0J475y8FzB5aLAkc_1602LvYeF4Hddr9OiBkDY,1276
|
8
|
+
invrs_opt/optimizers/lbfgsb.py,sha256=2NcyFllUM5CrjOZLEBocYAlCrkWt1fiRGJrxg05waog,35149
|
9
|
+
invrs_opt/optimizers/wrapped_optax.py,sha256=zPC2j_KkfF2RqeorxB38ovuUsg1SNwdxwhjA7gvOMC4,12387
|
10
|
+
invrs_opt/parameterization/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
11
|
+
invrs_opt/parameterization/base.py,sha256=k2VXZorMnkmaIWWY_G2CvS0NtxYojfvKlPccDGZyGNA,3779
|
12
|
+
invrs_opt/parameterization/filter_project.py,sha256=D2w8xrg34V8ysIbbr1RPvegM5WoLdz8QKUCGi80ieOI,3466
|
13
|
+
invrs_opt/parameterization/gaussian_levelset.py,sha256=DllvpkBpVxuDFQFu941f8v_Xh2D0m9Eh-JxBH5afNHU,25221
|
14
|
+
invrs_opt/parameterization/pixel.py,sha256=yWdGpf0x6Om_-7CYOcHTgNK9UF-XGAPM7GbOBNijnBw,1305
|
15
|
+
invrs_opt/parameterization/transforms.py,sha256=8GzaIsUuuXvMCLiqAEEfxmi9qE9KqHzbuTj_m0GjH3w,8216
|
16
|
+
invrs_opt-0.8.0.dist-info/LICENSE,sha256=IMF9i4xIpgCADf0U-V1cuf9HBmqWQd3qtI3FSuyW4zE,26526
|
17
|
+
invrs_opt-0.8.0.dist-info/METADATA,sha256=WSzpZwByiBqOZfiwSGvaVaFa5TUAPj4b-mr2qZRGLq0,32633
|
18
|
+
invrs_opt-0.8.0.dist-info/WHEEL,sha256=HiCZjzuy6Dw0hdX5R3LCFPDmFS4BWl8H-8W39XfmgX4,91
|
19
|
+
invrs_opt-0.8.0.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
|
20
|
+
invrs_opt-0.8.0.dist-info/RECORD,,
|
invrs_opt-0.7.2.dist-info/RECORD
DELETED
@@ -1,20 +0,0 @@
|
|
1
|
-
invrs_opt/__init__.py,sha256=_hiwGthyPapJUx8-nP7JrJYEFX3mJYTsdUh_X0MyhEQ,585
|
2
|
-
invrs_opt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
|
-
invrs_opt/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
4
|
-
invrs_opt/experimental/client.py,sha256=t4XxnditYbM9DWZeyBPj0Sa2acvkikT0ybhUdmH2r-Y,4852
|
5
|
-
invrs_opt/experimental/labels.py,sha256=dQDAMPyFMV6mXnMy295z8Ap205DRdVzysXny_Be8FmY,562
|
6
|
-
invrs_opt/optimizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
7
|
-
invrs_opt/optimizers/base.py,sha256=-wNwH0J475y8FzB5aLAkc_1602LvYeF4Hddr9OiBkDY,1276
|
8
|
-
invrs_opt/optimizers/lbfgsb.py,sha256=d7i02NZZ3yYdJg7wkERDMPpdBD3GaRPhmQexGrZPz_Y,34597
|
9
|
-
invrs_opt/optimizers/wrapped_optax.py,sha256=L836gwzWbwxPNWh8Y7PSgFHV4PZluSklnbh3BR5djlc,11859
|
10
|
-
invrs_opt/parameterization/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
11
|
-
invrs_opt/parameterization/base.py,sha256=QTnhOfMYbDchZOFzk9graryMd6rYHlyd7E2T1TtucB8,3570
|
12
|
-
invrs_opt/parameterization/filter_project.py,sha256=XCPqQ2ECv7DDTLRtVGJePfnjKYB2XndI6DssSr-4MZw,3239
|
13
|
-
invrs_opt/parameterization/gaussian_levelset.py,sha256=IGdQl3XMEHYNxNPSoKmUD7ZJu_fImIZTqqqmvilfmag,24998
|
14
|
-
invrs_opt/parameterization/pixel.py,sha256=4qCYDUCcFPr8W94whX5YFllrXgyZbPBVIJBf8m_Dv4k,1173
|
15
|
-
invrs_opt/parameterization/transforms.py,sha256=8GzaIsUuuXvMCLiqAEEfxmi9qE9KqHzbuTj_m0GjH3w,8216
|
16
|
-
invrs_opt-0.7.2.dist-info/LICENSE,sha256=IMF9i4xIpgCADf0U-V1cuf9HBmqWQd3qtI3FSuyW4zE,26526
|
17
|
-
invrs_opt-0.7.2.dist-info/METADATA,sha256=jx_MxaNkrJRUBRubRlDIytDRznP_1P9sWRHrqihaKi8,32633
|
18
|
-
invrs_opt-0.7.2.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
|
19
|
-
invrs_opt-0.7.2.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
|
20
|
-
invrs_opt-0.7.2.dist-info/RECORD,,
|
File without changes
|
File without changes
|