invrs-opt 0.7.2__py3-none-any.whl → 0.8.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
invrs_opt/__init__.py CHANGED
@@ -3,7 +3,7 @@
3
3
  Copyright (c) 2023 The INVRS-IO authors.
4
4
  """
5
5
 
6
- __version__ = "v0.7.2"
6
+ __version__ = "v0.8.0"
7
7
  __author__ = "Martin F. Schubert <mfschubert@gmail.com>"
8
8
 
9
9
  from invrs_opt import parameterization as parameterization
@@ -29,7 +29,7 @@ NDArray = onp.ndarray[Any, Any]
29
29
  PyTree = Any
30
30
  ElementwiseBound = Union[NDArray, Sequence[Optional[float]]]
31
31
  JaxLbfgsbDict = Dict[str, jnp.ndarray]
32
- LbfgsbState = Tuple[PyTree, PyTree, JaxLbfgsbDict]
32
+ LbfgsbState = Tuple[int, PyTree, PyTree, JaxLbfgsbDict]
33
33
 
34
34
 
35
35
  # Task message prefixes for the underlying L-BFGS-B implementation.
@@ -323,7 +323,7 @@ def parameterized_lbfgsb(
323
323
  ) -> jnp.ndarray:
324
324
  constraints = density_parameterization.constraints(params)
325
325
  constraints = tree_util.tree_map(
326
- lambda x: jnp.sum(jnp.maximum(x, 0.0)),
326
+ lambda x: jnp.sum(jnp.maximum(x, 0.0) ** 2),
327
327
  constraints,
328
328
  )
329
329
  return jnp.sum(jnp.asarray(constraints))
@@ -337,6 +337,18 @@ def parameterized_lbfgsb(
337
337
  ]
338
338
  return penalty * jnp.sum(jnp.asarray(losses))
339
339
 
340
+ def _update_parameterized_densities(latent_params: PyTree, step: int) -> PyTree:
341
+ def _update_leaf(leaf: Any) -> Any:
342
+ if not _is_parameterized_density(leaf):
343
+ return leaf
344
+ return density_parameterization.update(leaf, step)
345
+
346
+ return tree_util.tree_map(
347
+ _update_leaf,
348
+ latent_params,
349
+ is_leaf=_is_parameterized_density,
350
+ )
351
+
340
352
  def init_fn(params: PyTree) -> LbfgsbState:
341
353
  """Initializes the optimization state."""
342
354
 
@@ -359,11 +371,11 @@ def parameterized_lbfgsb(
359
371
  latent_params, jax_lbfgsb_state = jax.pure_callback(
360
372
  _init_state_pure, _example_state(latent_params, maxcor), latent_params
361
373
  )
362
- return _params_from_latents(latent_params), latent_params, jax_lbfgsb_state
374
+ return 0, _params_from_latents(latent_params), latent_params, jax_lbfgsb_state
363
375
 
364
376
  def params_fn(state: LbfgsbState) -> PyTree:
365
377
  """Returns the parameters for the given `state`."""
366
- params, _, _ = state
378
+ _, params, _, _ = state
367
379
  return params
368
380
 
369
381
  def update_fn(
@@ -390,7 +402,7 @@ def parameterized_lbfgsb(
390
402
  flat_latent_params = jnp.asarray(scipy_lbfgsb_state.x)
391
403
  return flat_latent_params, scipy_lbfgsb_state.to_jax()
392
404
 
393
- _, latent_params, jax_lbfgsb_state = state
405
+ step, _, latent_params, jax_lbfgsb_state = state
394
406
  _, vjp_fn = jax.vjp(_params_from_latents, latent_params)
395
407
  (latent_grad,) = vjp_fn(grad)
396
408
 
@@ -427,7 +439,9 @@ def parameterized_lbfgsb(
427
439
  jax_lbfgsb_state,
428
440
  )
429
441
  latent_params = unflatten_fn(flat_latent_params)
430
- return _params_from_latents(latent_params), latent_params, jax_lbfgsb_state
442
+ latent_params = _update_parameterized_densities(latent_params, step)
443
+ params = _params_from_latents(latent_params)
444
+ return step + 1, params, latent_params, jax_lbfgsb_state
431
445
 
432
446
  return base.Optimizer(
433
447
  init=init_fn,
@@ -438,7 +452,7 @@ def parameterized_lbfgsb(
438
452
 
439
453
  def is_converged(state: LbfgsbState) -> jnp.ndarray:
440
454
  """Returns `True` if the optimization has converged."""
441
- return state[2]["converged"]
455
+ return state[3]["converged"]
442
456
 
443
457
 
444
458
  # ------------------------------------------------------------------------------
@@ -20,7 +20,7 @@ from invrs_opt.parameterization import (
20
20
  )
21
21
 
22
22
  PyTree = Any
23
- WrappedOptaxState = Tuple[PyTree, PyTree, PyTree]
23
+ WrappedOptaxState = Tuple[int, PyTree, PyTree, PyTree]
24
24
 
25
25
 
26
26
  def wrapped_optax(opt: optax.GradientTransformation) -> base.Optimizer:
@@ -205,7 +205,7 @@ def parameterized_wrapped_optax(
205
205
  ) -> jnp.ndarray:
206
206
  constraints = density_parameterization.constraints(params)
207
207
  constraints = tree_util.tree_map(
208
- lambda x: jnp.sum(jnp.maximum(x, 0.0)),
208
+ lambda x: jnp.sum(jnp.maximum(x, 0.0) ** 2),
209
209
  constraints,
210
210
  )
211
211
  return jnp.sum(jnp.asarray(constraints))
@@ -219,15 +219,27 @@ def parameterized_wrapped_optax(
219
219
  ]
220
220
  return penalty * jnp.sum(jnp.asarray(losses))
221
221
 
222
+ def _update_parameterized_densities(latent_params: PyTree, step: int) -> PyTree:
223
+ def _update_leaf(leaf: Any) -> Any:
224
+ if not _is_parameterized_density(leaf):
225
+ return leaf
226
+ return density_parameterization.update(leaf, step)
227
+
228
+ return tree_util.tree_map(
229
+ _update_leaf,
230
+ latent_params,
231
+ is_leaf=_is_parameterized_density,
232
+ )
233
+
222
234
  def init_fn(params: PyTree) -> WrappedOptaxState:
223
235
  """Initializes the optimization state."""
224
236
  latent_params = _init_latents(params)
225
237
  params = _params_from_latents(latent_params)
226
- return params, latent_params, opt.init(latent_params)
238
+ return 0, params, latent_params, opt.init(latent_params)
227
239
 
228
240
  def params_fn(state: WrappedOptaxState) -> PyTree:
229
241
  """Returns the parameters for the given `state`."""
230
- params, _, _ = state
242
+ _, params, _, _ = state
231
243
  return params
232
244
 
233
245
  def update_fn(
@@ -240,7 +252,7 @@ def parameterized_wrapped_optax(
240
252
  """Updates the state."""
241
253
  del value, params
242
254
 
243
- _, latent_params, opt_state = state
255
+ step, _, latent_params, opt_state = state
244
256
  _, vjp_fn = jax.vjp(_params_from_latents, latent_params)
245
257
  (latent_grad,) = vjp_fn(grad)
246
258
 
@@ -264,8 +276,9 @@ def parameterized_wrapped_optax(
264
276
  )
265
277
  latent_params = optax.apply_updates(params=latent_params, updates=updates)
266
278
  latent_params = _clip(latent_params)
279
+ latent_params = _update_parameterized_densities(latent_params, step)
267
280
  params = _params_from_latents(latent_params)
268
- return params, latent_params, opt_state
281
+ return step + 1, params, latent_params, opt_state
269
282
 
270
283
  return base.Optimizer(init=init_fn, params=params_fn, update=update_fn)
271
284
 
@@ -44,6 +44,13 @@ class ConstraintsFn(Protocol):
44
44
  ...
45
45
 
46
46
 
47
+ class UpdateFn(Protocol):
48
+ """Performs the required update of a parameterized density for the given step."""
49
+
50
+ def __call__(self, params: PyTree, step: int) -> PyTree:
51
+ ...
52
+
53
+
47
54
  @dataclasses.dataclass
48
55
  class Density2DParameterization:
49
56
  """Stores `(from_density, to_density, constraints)` function triple."""
@@ -51,6 +58,7 @@ class Density2DParameterization:
51
58
  from_density: FromDensityFn
52
59
  to_density: ToDensityFn
53
60
  constraints: ConstraintsFn
61
+ update: UpdateFn
54
62
 
55
63
 
56
64
  @dataclasses.dataclass
@@ -85,8 +85,14 @@ def filter_project(beta: float) -> base.Density2DParameterization:
85
85
  del params
86
86
  return jnp.asarray(0.0)
87
87
 
88
+ def update_fn(params: FilterAndProjectParams, step: int) -> FilterAndProjectParams:
89
+ """Perform updates to `params` required for the given `step`."""
90
+ del step
91
+ return params
92
+
88
93
  return base.Density2DParameterization(
89
94
  to_density=to_density_fn,
90
95
  from_density=from_density_fn,
91
96
  constraints=constraints_fn,
97
+ update=update_fn,
92
98
  )
@@ -214,7 +214,7 @@ def gaussian_levelset(
214
214
  return params, state
215
215
 
216
216
  state = init_optimizer.init(params)
217
- params, state = jax.lax.fori_loop(
217
+ params, _ = jax.lax.fori_loop(
218
218
  0, init_steps, body_fun=step_fn, init_val=(params, state)
219
219
  )
220
220
 
@@ -269,10 +269,16 @@ def gaussian_levelset(
269
269
  )
270
270
  return constraints / length_scale**2
271
271
 
272
+ def update_fn(params: GaussianLevelsetParams, step: int) -> GaussianLevelsetParams:
273
+ """Perform updates to `params` required for the given `step`."""
274
+ del step
275
+ return params
276
+
272
277
  return base.Density2DParameterization(
273
278
  to_density=to_density_fn,
274
279
  from_density=from_density_fn,
275
280
  constraints=constraints_fn,
281
+ update=update_fn,
276
282
  )
277
283
 
278
284
 
@@ -38,8 +38,13 @@ def pixel() -> base.Density2DParameterization:
38
38
  del params
39
39
  return jnp.asarray(0.0)
40
40
 
41
+ def update_fn(params: PixelParams, step: int) -> PixelParams:
42
+ del step
43
+ return params
44
+
41
45
  return base.Density2DParameterization(
42
46
  from_density=from_density_fn,
43
47
  to_density=to_density_fn,
44
48
  constraints=constraints_fn,
49
+ update=update_fn,
45
50
  )
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: invrs_opt
3
- Version: 0.7.2
3
+ Version: 0.8.0
4
4
  Summary: Algorithms for inverse design
5
5
  Author-email: "Martin F. Schubert" <mfschubert@gmail.com>
6
6
  Maintainer-email: "Martin F. Schubert" <mfschubert@gmail.com>
@@ -533,7 +533,7 @@ Requires-Dist: pytest-cov; extra == "tests"
533
533
  Requires-Dist: pytest-subtests; extra == "tests"
534
534
 
535
535
  # invrs-opt - Optimization algorithms for inverse design
536
- `v0.7.2`
536
+ `v0.8.0`
537
537
 
538
538
  ## Overview
539
539
 
@@ -0,0 +1,20 @@
1
+ invrs_opt/__init__.py,sha256=kTrg48iZu7i5OlH0Nqtfh_wBn3be9u2eZgJBRqUr6uQ,585
2
+ invrs_opt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
+ invrs_opt/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
+ invrs_opt/experimental/client.py,sha256=t4XxnditYbM9DWZeyBPj0Sa2acvkikT0ybhUdmH2r-Y,4852
5
+ invrs_opt/experimental/labels.py,sha256=dQDAMPyFMV6mXnMy295z8Ap205DRdVzysXny_Be8FmY,562
6
+ invrs_opt/optimizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
7
+ invrs_opt/optimizers/base.py,sha256=-wNwH0J475y8FzB5aLAkc_1602LvYeF4Hddr9OiBkDY,1276
8
+ invrs_opt/optimizers/lbfgsb.py,sha256=2NcyFllUM5CrjOZLEBocYAlCrkWt1fiRGJrxg05waog,35149
9
+ invrs_opt/optimizers/wrapped_optax.py,sha256=zPC2j_KkfF2RqeorxB38ovuUsg1SNwdxwhjA7gvOMC4,12387
10
+ invrs_opt/parameterization/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
+ invrs_opt/parameterization/base.py,sha256=k2VXZorMnkmaIWWY_G2CvS0NtxYojfvKlPccDGZyGNA,3779
12
+ invrs_opt/parameterization/filter_project.py,sha256=D2w8xrg34V8ysIbbr1RPvegM5WoLdz8QKUCGi80ieOI,3466
13
+ invrs_opt/parameterization/gaussian_levelset.py,sha256=DllvpkBpVxuDFQFu941f8v_Xh2D0m9Eh-JxBH5afNHU,25221
14
+ invrs_opt/parameterization/pixel.py,sha256=yWdGpf0x6Om_-7CYOcHTgNK9UF-XGAPM7GbOBNijnBw,1305
15
+ invrs_opt/parameterization/transforms.py,sha256=8GzaIsUuuXvMCLiqAEEfxmi9qE9KqHzbuTj_m0GjH3w,8216
16
+ invrs_opt-0.8.0.dist-info/LICENSE,sha256=IMF9i4xIpgCADf0U-V1cuf9HBmqWQd3qtI3FSuyW4zE,26526
17
+ invrs_opt-0.8.0.dist-info/METADATA,sha256=WSzpZwByiBqOZfiwSGvaVaFa5TUAPj4b-mr2qZRGLq0,32633
18
+ invrs_opt-0.8.0.dist-info/WHEEL,sha256=HiCZjzuy6Dw0hdX5R3LCFPDmFS4BWl8H-8W39XfmgX4,91
19
+ invrs_opt-0.8.0.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
20
+ invrs_opt-0.8.0.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (72.1.0)
2
+ Generator: setuptools (72.2.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5
 
@@ -1,20 +0,0 @@
1
- invrs_opt/__init__.py,sha256=_hiwGthyPapJUx8-nP7JrJYEFX3mJYTsdUh_X0MyhEQ,585
2
- invrs_opt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
- invrs_opt/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
- invrs_opt/experimental/client.py,sha256=t4XxnditYbM9DWZeyBPj0Sa2acvkikT0ybhUdmH2r-Y,4852
5
- invrs_opt/experimental/labels.py,sha256=dQDAMPyFMV6mXnMy295z8Ap205DRdVzysXny_Be8FmY,562
6
- invrs_opt/optimizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
7
- invrs_opt/optimizers/base.py,sha256=-wNwH0J475y8FzB5aLAkc_1602LvYeF4Hddr9OiBkDY,1276
8
- invrs_opt/optimizers/lbfgsb.py,sha256=d7i02NZZ3yYdJg7wkERDMPpdBD3GaRPhmQexGrZPz_Y,34597
9
- invrs_opt/optimizers/wrapped_optax.py,sha256=L836gwzWbwxPNWh8Y7PSgFHV4PZluSklnbh3BR5djlc,11859
10
- invrs_opt/parameterization/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
- invrs_opt/parameterization/base.py,sha256=QTnhOfMYbDchZOFzk9graryMd6rYHlyd7E2T1TtucB8,3570
12
- invrs_opt/parameterization/filter_project.py,sha256=XCPqQ2ECv7DDTLRtVGJePfnjKYB2XndI6DssSr-4MZw,3239
13
- invrs_opt/parameterization/gaussian_levelset.py,sha256=IGdQl3XMEHYNxNPSoKmUD7ZJu_fImIZTqqqmvilfmag,24998
14
- invrs_opt/parameterization/pixel.py,sha256=4qCYDUCcFPr8W94whX5YFllrXgyZbPBVIJBf8m_Dv4k,1173
15
- invrs_opt/parameterization/transforms.py,sha256=8GzaIsUuuXvMCLiqAEEfxmi9qE9KqHzbuTj_m0GjH3w,8216
16
- invrs_opt-0.7.2.dist-info/LICENSE,sha256=IMF9i4xIpgCADf0U-V1cuf9HBmqWQd3qtI3FSuyW4zE,26526
17
- invrs_opt-0.7.2.dist-info/METADATA,sha256=jx_MxaNkrJRUBRubRlDIytDRznP_1P9sWRHrqihaKi8,32633
18
- invrs_opt-0.7.2.dist-info/WHEEL,sha256=R0nc6qTxuoLk7ShA2_Y-UWkN8ZdfDBG2B6Eqpz2WXbs,91
19
- invrs_opt-0.7.2.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
20
- invrs_opt-0.7.2.dist-info/RECORD,,