invrs-opt 0.9.4__py3-none-any.whl → 0.10.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
invrs_opt/__init__.py CHANGED
@@ -3,7 +3,7 @@
3
3
  Copyright (c) 2023 The INVRS-IO authors.
4
4
  """
5
5
 
6
- __version__ = "v0.9.4"
6
+ __version__ = "v0.10.0"
7
7
  __author__ = "Martin F. Schubert <mfschubert@gmail.com>"
8
8
 
9
9
  from invrs_opt import parameterization as parameterization
@@ -132,7 +132,11 @@ def optimizer_client(
132
132
  response = json.loads(get_response.text)
133
133
  return json_utils.pytree_from_json(response[labels.PARAMS])
134
134
 
135
- return base.Optimizer(init=init_fn, update=update_fn, params=params_fn)
135
+ return base.Optimizer(
136
+ init=init_fn,
137
+ update=update_fn, # type: ignore[arg-type]
138
+ params=params_fn,
139
+ )
136
140
 
137
141
 
138
142
  # -----------------------------------------------------------------------------
@@ -7,6 +7,7 @@ import dataclasses
7
7
  import inspect
8
8
  from typing import Any, Protocol
9
9
 
10
+ import jax.numpy as jnp
10
11
  import optax # type: ignore[import-untyped]
11
12
  from totypes import json_utils
12
13
 
@@ -34,7 +35,7 @@ class UpdateFn(Protocol):
34
35
  self,
35
36
  *,
36
37
  grad: PyTree,
37
- value: float,
38
+ value: jnp.ndarray,
38
39
  params: PyTree,
39
40
  state: PyTree,
40
41
  ) -> PyTree:
@@ -335,7 +335,7 @@ def parameterized_lbfgsb(
335
335
  def update_fn(
336
336
  *,
337
337
  grad: PyTree,
338
- value: float,
338
+ value: jnp.ndarray,
339
339
  params: PyTree,
340
340
  state: LbfgsbState,
341
341
  ) -> LbfgsbState:
@@ -349,12 +349,14 @@ def parameterized_lbfgsb(
349
349
  ) -> Tuple[PyTree, NumpyLbfgsbDict]:
350
350
  assert onp.size(value) == 1
351
351
  scipy_lbfgsb_state = ScipyLbfgsbState.from_jax(jax_lbfgsb_state)
352
+ flat_latent_params = scipy_lbfgsb_state.x.copy()
352
353
  scipy_lbfgsb_state.update(
353
354
  grad=onp.array(flat_latent_grad, dtype=onp.float64),
354
355
  value=onp.array(value, dtype=onp.float64),
355
356
  )
356
- flat_latent_params = jnp.asarray(scipy_lbfgsb_state.x)
357
- return flat_latent_params, scipy_lbfgsb_state.to_dict()
357
+ updated_flat_latent_params = scipy_lbfgsb_state.x
358
+ flat_latent_updates = updated_flat_latent_params - flat_latent_params
359
+ return flat_latent_updates, scipy_lbfgsb_state.to_dict()
358
360
 
359
361
  step, _, latent_params, jax_lbfgsb_state = state
360
362
  metadata, latents = param_base.partition_density_metadata(latent_params)
@@ -395,16 +397,21 @@ def parameterized_lbfgsb(
395
397
  latents_grad
396
398
  ) # type: ignore[no-untyped-call]
397
399
 
398
- flat_latents, jax_lbfgsb_state = jax.pure_callback(
400
+ flat_latent_updates, jax_lbfgsb_state = jax.pure_callback(
399
401
  _update_pure,
400
402
  (flat_latents_grad, jax_lbfgsb_state),
401
403
  flat_latents_grad,
402
404
  value,
403
405
  jax_lbfgsb_state,
404
406
  )
405
- latents = unflatten_fn(flat_latents)
406
- latent_params = param_base.combine_density_metadata(metadata, latents)
407
- latent_params = _update_parameterized_densities(latent_params, step)
407
+ latent_updates = unflatten_fn(flat_latent_updates)
408
+ latent_params = _apply_updates(
409
+ params=latent_params,
410
+ updates=param_base.combine_density_metadata(metadata, latent_updates),
411
+ value=value,
412
+ step=step,
413
+ )
414
+ latent_params = _clip(latent_params)
408
415
  params = _params_from_latent_params(latent_params)
409
416
  return step + 1, params, latent_params, jax_lbfgsb_state
410
417
 
@@ -433,15 +440,24 @@ def parameterized_lbfgsb(
433
440
  is_leaf=_is_parameterized_density,
434
441
  )
435
442
 
436
- def _update_parameterized_densities(latent_params: PyTree, step: int) -> PyTree:
437
- def _update_leaf(leaf: Any) -> Any:
438
- if not _is_parameterized_density(leaf):
439
- return leaf
440
- return density_parameterization.update(leaf, step)
443
+ def _apply_updates(
444
+ params: PyTree,
445
+ updates: PyTree,
446
+ value: jnp.ndarray,
447
+ step: int,
448
+ ) -> PyTree:
449
+ def _leaf_apply_updates(update: Any, leaf: Any) -> Any:
450
+ if _is_parameterized_density(leaf):
451
+ return density_parameterization.update(
452
+ params=leaf, updates=update, value=value, step=step
453
+ )
454
+ else:
455
+ return optax.apply_updates(params=leaf, updates=update)
441
456
 
442
457
  return tree_util.tree_map(
443
- _update_leaf,
444
- latent_params,
458
+ _leaf_apply_updates,
459
+ updates,
460
+ params,
445
461
  is_leaf=_is_parameterized_density,
446
462
  )
447
463
 
@@ -197,12 +197,12 @@ def parameterized_wrapped_optax(
197
197
  def update_fn(
198
198
  *,
199
199
  grad: PyTree,
200
- value: float,
200
+ value: jnp.ndarray,
201
201
  params: PyTree,
202
202
  state: WrappedOptaxState,
203
203
  ) -> WrappedOptaxState:
204
204
  """Updates the state."""
205
- del value, params
205
+ del params
206
206
 
207
207
  step, params, latent_params, opt_state = state
208
208
  metadata, latents = param_base.partition_density_metadata(latent_params)
@@ -233,12 +233,14 @@ def parameterized_wrapped_optax(
233
233
  lambda a, b: a + b, latents_grad, constraint_loss_grad
234
234
  )
235
235
 
236
- updates, opt_state = opt.update(latents_grad, state=opt_state, params=latents)
237
- latents = optax.apply_updates(params=latents, updates=updates)
238
-
239
- latent_params = param_base.combine_density_metadata(metadata, latents)
236
+ latent_updates, opt_state = opt.update(latents_grad, opt_state, params=latents)
237
+ latent_params = _apply_updates(
238
+ params=latent_params,
239
+ updates=param_base.combine_density_metadata(metadata, latent_updates),
240
+ value=value,
241
+ step=step,
242
+ )
240
243
  latent_params = _clip(latent_params)
241
- latent_params = _update_parameterized_densities(latent_params, step + 1)
242
244
  params = _params_from_latent_params(latent_params)
243
245
  return (step + 1, params, latent_params, opt_state)
244
246
 
@@ -267,15 +269,24 @@ def parameterized_wrapped_optax(
267
269
  is_leaf=_is_parameterized_density,
268
270
  )
269
271
 
270
- def _update_parameterized_densities(latent_params: PyTree, step: int) -> PyTree:
271
- def _update_leaf(leaf: Any) -> Any:
272
- if not _is_parameterized_density(leaf):
273
- return leaf
274
- return density_parameterization.update(leaf, step)
272
+ def _apply_updates(
273
+ params: PyTree,
274
+ updates: PyTree,
275
+ value: jnp.ndarray,
276
+ step: int,
277
+ ) -> PyTree:
278
+ def _leaf_apply_updates(update: Any, leaf: Any) -> Any:
279
+ if _is_parameterized_density(leaf):
280
+ return density_parameterization.update(
281
+ params=leaf, updates=update, value=value, step=step
282
+ )
283
+ else:
284
+ return optax.apply_updates(params=leaf, updates=update)
275
285
 
276
286
  return tree_util.tree_map(
277
- _update_leaf,
278
- latent_params,
287
+ _leaf_apply_updates,
288
+ updates,
289
+ params,
279
290
  is_leaf=_is_parameterized_density,
280
291
  )
281
292
 
@@ -97,7 +97,13 @@ class ConstraintsFn(Protocol):
97
97
  class UpdateFn(Protocol):
98
98
  """Performs the required update of a parameterized density for the given step."""
99
99
 
100
- def __call__(self, params: PyTree, step: int) -> PyTree:
100
+ def __call__(
101
+ self,
102
+ params: PyTree,
103
+ updates: PyTree,
104
+ value: jnp.ndarray,
105
+ step: int,
106
+ ) -> PyTree:
101
107
  ...
102
108
 
103
109
 
@@ -115,10 +115,20 @@ def filter_project(beta: float) -> base.Density2DParameterization:
115
115
  del params
116
116
  return jnp.asarray(0.0)
117
117
 
118
- def update_fn(params: FilterProjectParams, step: int) -> FilterProjectParams:
118
+ def update_fn(
119
+ params: FilterProjectParams,
120
+ updates: FilterProjectParams,
121
+ value: jnp.ndarray,
122
+ step: int,
123
+ ) -> FilterProjectParams:
119
124
  """Perform updates to `params` required for the given `step`."""
120
- del step
121
- return params
125
+ del step, value
126
+ return FilterProjectParams(
127
+ latents=tree_util.tree_map(
128
+ lambda a, b: a + b, params.latents, updates.latents
129
+ ),
130
+ metadata=params.metadata,
131
+ )
122
132
 
123
133
  return base.Density2DParameterization(
124
134
  to_density=to_density_fn,
@@ -229,10 +229,20 @@ def gaussian_levelset(
229
229
  pad_pixels=pad_pixels,
230
230
  )
231
231
 
232
- def update_fn(params: GaussianLevelsetParams, step: int) -> GaussianLevelsetParams:
232
+ def update_fn(
233
+ params: GaussianLevelsetParams,
234
+ updates: GaussianLevelsetParams,
235
+ value: jnp.ndarray,
236
+ step: int,
237
+ ) -> GaussianLevelsetParams:
233
238
  """Perform updates to `params` required for the given `step`."""
234
- del step
235
- return params
239
+ del step, value
240
+ return GaussianLevelsetParams(
241
+ latents=tree_util.tree_map(
242
+ lambda a, b: a + b, params.latents, updates.latents
243
+ ),
244
+ metadata=params.metadata,
245
+ )
236
246
 
237
247
  return base.Density2DParameterization(
238
248
  to_density=to_density_fn,
@@ -509,11 +519,13 @@ def _levelset_constraints(
509
519
  )
510
520
 
511
521
  d = minimum_length_scale * length_scale_constraint_factor
512
- length_scale_constraint = (
513
- jnp.abs(phi_vv) / (jnp.pi / d * jnp.abs(phi) + beta * phi_v) - jnp.pi / d
514
- )
522
+ denom = jnp.pi / d * jnp.abs(phi) + beta * phi_v
523
+ denom_safe = jnp.where(jnp.isclose(phi_vv, 0.0), 1.0, denom)
524
+ length_scale_constraint = jnp.abs(phi_vv) / denom_safe - jnp.pi / d
525
+
526
+ curvature_denom_safe = jnp.where(jnp.isclose(phi_v, 0.0), 1.0, phi)
515
527
  curvature_constraint = (
516
- jnp.abs(inverse_radius * jnp.arctan(phi_v / phi)) - jnp.pi / d
528
+ jnp.abs(inverse_radius * jnp.arctan(phi_v / curvature_denom_safe)) - jnp.pi / d
517
529
  )
518
530
 
519
531
  # Downsample so that constraints shape matches the density shape.
@@ -52,9 +52,20 @@ def pixel() -> base.Density2DParameterization:
52
52
  del params
53
53
  return jnp.asarray(0.0)
54
54
 
55
- def update_fn(params: PixelParams, step: int) -> PixelParams:
56
- del step
57
- return params
55
+ def update_fn(
56
+ params: PixelParams,
57
+ updates: PixelParams,
58
+ value: jnp.ndarray,
59
+ step: int,
60
+ ) -> PixelParams:
61
+ """Perform updates to `params` required for the given `step`."""
62
+ del step, value
63
+ return PixelParams(
64
+ latents=tree_util.tree_map(
65
+ lambda a, b: a + b, params.latents, updates.latents
66
+ ),
67
+ metadata=params.metadata,
68
+ )
58
69
 
59
70
  return base.Density2DParameterization(
60
71
  from_density=from_density_fn,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: invrs_opt
3
- Version: 0.9.4
3
+ Version: 0.10.0
4
4
  Summary: Algorithms for inverse design
5
5
  Author-email: "Martin F. Schubert" <mfschubert@gmail.com>
6
6
  Maintainer-email: "Martin F. Schubert" <mfschubert@gmail.com>
@@ -533,7 +533,7 @@ Requires-Dist: pytest-cov; extra == "tests"
533
533
  Requires-Dist: pytest-subtests; extra == "tests"
534
534
 
535
535
  # invrs-opt - Optimization algorithms for inverse design
536
- `v0.9.4`
536
+ `v0.10.0`
537
537
 
538
538
  ## Overview
539
539
 
@@ -0,0 +1,20 @@
1
+ invrs_opt/__init__.py,sha256=8rVngIALR7klCVPflhoqcaf745snBRiTOkwQuq6xYvE,586
2
+ invrs_opt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
+ invrs_opt/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
+ invrs_opt/experimental/client.py,sha256=tbtH13FrA65XmTZfTO71CxJ78jeAEj3Zf85R-MTwbiU,4909
5
+ invrs_opt/experimental/labels.py,sha256=dQDAMPyFMV6mXnMy295z8Ap205DRdVzysXny_Be8FmY,562
6
+ invrs_opt/optimizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
7
+ invrs_opt/optimizers/base.py,sha256=uFfkN2LwWzAtwh6ktwWNy2iHNOY-sW3JzI46iSFkgok,1306
8
+ invrs_opt/optimizers/lbfgsb.py,sha256=C85Ejvq9kJO6BorrLwxeaHLzMCQwFfJ7227JDHACCd0,36840
9
+ invrs_opt/optimizers/wrapped_optax.py,sha256=781-8v_TlHsGaQF9Se9_iOEvtOLOr-BesTLudYarSlg,13685
10
+ invrs_opt/parameterization/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
+ invrs_opt/parameterization/base.py,sha256=jSwrEO86lGkYQG5gWsHvcIMWpZnnbdiKpn--2qaU02g,5362
12
+ invrs_opt/parameterization/filter_project.py,sha256=XL3HTEBLrF-q_75TjhOWLNdfUOSEEjKcoM7Qj844QpQ,4590
13
+ invrs_opt/parameterization/gaussian_levelset.py,sha256=-6foekLTFoZDtMKuoMEvdxMJt0_zTxrKNJo0Vn-Rv80,26073
14
+ invrs_opt/parameterization/pixel.py,sha256=YWkyBhfYtzI8cQ-M90PAZqRAbabwVaUh0UiYIGegQHI,1955
15
+ invrs_opt/parameterization/transforms.py,sha256=8GzaIsUuuXvMCLiqAEEfxmi9qE9KqHzbuTj_m0GjH3w,8216
16
+ invrs_opt-0.10.0.dist-info/LICENSE,sha256=IMF9i4xIpgCADf0U-V1cuf9HBmqWQd3qtI3FSuyW4zE,26526
17
+ invrs_opt-0.10.0.dist-info/METADATA,sha256=LN2Csi1bX4q2iagAKnwNl-jXZxKQmkxtX-fYq9QX7hs,32643
18
+ invrs_opt-0.10.0.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
19
+ invrs_opt-0.10.0.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
20
+ invrs_opt-0.10.0.dist-info/RECORD,,
@@ -1,20 +0,0 @@
1
- invrs_opt/__init__.py,sha256=bVIs0NxPxNuRUOypBIE68qx-SA1lFCmqmo5cqvRpxmU,585
2
- invrs_opt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
- invrs_opt/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
- invrs_opt/experimental/client.py,sha256=t4XxnditYbM9DWZeyBPj0Sa2acvkikT0ybhUdmH2r-Y,4852
5
- invrs_opt/experimental/labels.py,sha256=dQDAMPyFMV6mXnMy295z8Ap205DRdVzysXny_Be8FmY,562
6
- invrs_opt/optimizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
7
- invrs_opt/optimizers/base.py,sha256=-wNwH0J475y8FzB5aLAkc_1602LvYeF4Hddr9OiBkDY,1276
8
- invrs_opt/optimizers/lbfgsb.py,sha256=9DrmyCj4Ny04NFVRUTz3pJypbw_j5Gw4wpKfe0WKEv4,36336
9
- invrs_opt/optimizers/wrapped_optax.py,sha256=VXdCteT2kumqhP81l3p6QiEqwBffoUuJ3UjrAyX5ToA,13468
10
- invrs_opt/parameterization/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
- invrs_opt/parameterization/base.py,sha256=BObzbz6efT2nBjib0_5BSdkCmFi2f0mcZ9VJYpDzO6Q,5278
12
- invrs_opt/parameterization/filter_project.py,sha256=7Jb8JVENmBTdx3-XmI-VRm4aMjxg_Wtin8tMKKKxWvQ,4309
13
- invrs_opt/parameterization/gaussian_levelset.py,sha256=Ka4hW_OLxUaIPHQsyIOlryG7i1mC-LTIgiQQhCPwHPk,25626
14
- invrs_opt/parameterization/pixel.py,sha256=AwC4GBNNOysdICvYHv_D2tZdqJmYiRzOUZNq_-R9Z70,1617
15
- invrs_opt/parameterization/transforms.py,sha256=8GzaIsUuuXvMCLiqAEEfxmi9qE9KqHzbuTj_m0GjH3w,8216
16
- invrs_opt-0.9.4.dist-info/LICENSE,sha256=IMF9i4xIpgCADf0U-V1cuf9HBmqWQd3qtI3FSuyW4zE,26526
17
- invrs_opt-0.9.4.dist-info/METADATA,sha256=SOt6aECrbTQNI0QuznoPrWKzUdMeahaST0G4Q7sccxE,32641
18
- invrs_opt-0.9.4.dist-info/WHEEL,sha256=OVMc5UfuAQiSplgO0_WdW7vXVGAt9Hdd6qtN4HotdyA,91
19
- invrs_opt-0.9.4.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
20
- invrs_opt-0.9.4.dist-info/RECORD,,