invrs-opt 0.10.6__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
invrs_opt/__init__.py CHANGED
@@ -3,7 +3,7 @@
3
3
  Copyright (c) 2023 The INVRS-IO authors.
4
4
  """
5
5
 
6
- __version__ = "v0.10.6"
6
+ __version__ = "v0.11.0"
7
7
  __author__ = "Martin F. Schubert <mfschubert@gmail.com>"
8
8
 
9
9
  from invrs_opt import parameterization as parameterization
@@ -16,7 +16,7 @@ from jax import flatten_util, tree_util
16
16
  from scipy.optimize._lbfgsb_py import ( # type: ignore[import-untyped]
17
17
  _lbfgsb as scipy_lbfgsb,
18
18
  )
19
- from totypes import types
19
+ from totypes import json_utils, types
20
20
 
21
21
  from invrs_opt.optimizers import base
22
22
  from invrs_opt.parameterization import (
@@ -31,7 +31,26 @@ PyTree = Any
31
31
  ElementwiseBound = Union[NDArray, Sequence[Optional[float]]]
32
32
  NumpyLbfgsbDict = Dict[str, NDArray]
33
33
  JaxLbfgsbDict = Dict[str, jnp.ndarray]
34
- LbfgsbState = Tuple[int, PyTree, PyTree, JaxLbfgsbDict]
34
+
35
+
36
+ @dataclasses.dataclass
37
+ class LbfgsbState:
38
+ """Stores the state of the L-BFGS-B optimizer."""
39
+
40
+ step: int
41
+ params: PyTree
42
+ latent_params: PyTree
43
+ opt_state: JaxLbfgsbDict
44
+
45
+
46
+ tree_util.register_dataclass(
47
+ LbfgsbState,
48
+ data_fields=["step", "params", "latent_params", "opt_state"],
49
+ meta_fields=[],
50
+ )
51
+
52
+
53
+ json_utils.register_custom_type(LbfgsbState)
35
54
 
36
55
 
37
56
  # Task message prefixes for the underlying L-BFGS-B implementation.
@@ -327,17 +346,16 @@ def parameterized_lbfgsb(
327
346
  latents,
328
347
  )
329
348
  latent_params = param_base.combine_density_metadata(metadata, latents)
330
- return (
331
- 0, # step
332
- _params_from_latent_params(latent_params), # params
333
- latent_params, # latent params
334
- jax_lbfgsb_state, # opt state
349
+ return LbfgsbState(
350
+ step=0,
351
+ params=_params_from_latent_params(latent_params),
352
+ latent_params=latent_params,
353
+ opt_state=jax_lbfgsb_state,
335
354
  )
336
355
 
337
356
  def params_fn(state: LbfgsbState) -> PyTree:
338
357
  """Returns the parameters for the given `state`."""
339
- _, params, _, _ = state
340
- return params
358
+ return state.params
341
359
 
342
360
  def update_fn(
343
361
  *,
@@ -366,8 +384,7 @@ def parameterized_lbfgsb(
366
384
  flat_latent_updates = updated_flat_latent_params - flat_latent_params
367
385
  return flat_latent_updates, scipy_lbfgsb_state.to_dict()
368
386
 
369
- step, _, latent_params, jax_lbfgsb_state = state
370
- metadata, latents = param_base.partition_density_metadata(latent_params)
387
+ metadata, latents = param_base.partition_density_metadata(state.latent_params)
371
388
 
372
389
  def _params_from_latents(latents: PyTree) -> PyTree:
373
390
  latent_params = param_base.combine_density_metadata(metadata, latents)
@@ -380,10 +397,9 @@ def parameterized_lbfgsb(
380
397
  _, vjp_fn = jax.vjp(_params_from_latents, latents)
381
398
  (latents_grad,) = vjp_fn(grad)
382
399
 
383
- if not (
384
- tree_util.tree_structure(latents_grad)
385
- == tree_util.tree_structure(latents) # type: ignore[operator]
386
- ):
400
+ treedef = tree_util.tree_structure(latents_grad)
401
+ expected_treedef = tree_util.tree_structure(latents)
402
+ if not treedef == expected_treedef: # type: ignore[operator]
387
403
  raise ValueError(
388
404
  f"Tree structure of `latents_grad` was different than expected, got \n"
389
405
  f"{tree_util.tree_structure(latents_grad)} but expected \n"
@@ -405,23 +421,28 @@ def parameterized_lbfgsb(
405
421
  latents_grad
406
422
  ) # type: ignore[no-untyped-call]
407
423
 
408
- flat_latent_updates, jax_lbfgsb_state = callback_sequential(
424
+ flat_latent_updates, opt_state = callback_sequential(
409
425
  _update_pure,
410
- (flat_latents_grad, jax_lbfgsb_state),
426
+ (flat_latents_grad, state.opt_state),
411
427
  flat_latents_grad,
412
428
  value,
413
- jax_lbfgsb_state,
429
+ state.opt_state,
414
430
  )
415
431
  latent_updates = unflatten_fn(flat_latent_updates)
416
432
  latent_params = _apply_updates(
417
- params=latent_params,
433
+ params=state.latent_params,
418
434
  updates=param_base.combine_density_metadata(metadata, latent_updates),
419
435
  value=value,
420
- step=step,
436
+ step=state.step,
421
437
  )
422
438
  latent_params = _clip(latent_params)
423
439
  params = _params_from_latent_params(latent_params)
424
- return step + 1, params, latent_params, jax_lbfgsb_state
440
+ return LbfgsbState(
441
+ step=state.step + 1,
442
+ params=params,
443
+ latent_params=latent_params,
444
+ opt_state=opt_state,
445
+ )
425
446
 
426
447
  # -------------------------------------------------------------------------
427
448
  # Functions related to the density parameterization.
@@ -502,7 +523,7 @@ def parameterized_lbfgsb(
502
523
 
503
524
  def is_converged(state: LbfgsbState) -> jnp.ndarray:
504
525
  """Returns `True` if the optimization has converged."""
505
- return state[3]["converged"]
526
+ return state.opt_state["converged"]
506
527
 
507
528
 
508
529
  # ------------------------------------------------------------------------------
@@ -3,13 +3,14 @@
3
3
  Copyright (c) 2023 The INVRS-IO authors.
4
4
  """
5
5
 
6
- from typing import Any, Optional, Tuple
6
+ import dataclasses
7
+ from typing import Any, Optional
7
8
 
8
9
  import jax
9
10
  import jax.numpy as jnp
10
11
  import optax # type: ignore[import-untyped]
11
12
  from jax import tree_util
12
- from totypes import types
13
+ from totypes import json_utils, types
13
14
 
14
15
  from invrs_opt.optimizers import base
15
16
  from invrs_opt.parameterization import (
@@ -20,7 +21,26 @@ from invrs_opt.parameterization import (
20
21
  )
21
22
 
22
23
  PyTree = Any
23
- WrappedOptaxState = Tuple[int, PyTree, PyTree, PyTree]
24
+
25
+
26
+ @dataclasses.dataclass
27
+ class WrappedOptaxState:
28
+ """Stores the state of a wrapped optax optimizer."""
29
+
30
+ step: int
31
+ params: PyTree
32
+ latent_params: PyTree
33
+ opt_state: PyTree
34
+
35
+
36
+ tree_util.register_dataclass(
37
+ WrappedOptaxState,
38
+ data_fields=["step", "params", "latent_params", "opt_state"],
39
+ meta_fields=[],
40
+ )
41
+
42
+
43
+ json_utils.register_custom_type(WrappedOptaxState)
24
44
 
25
45
 
26
46
  def wrapped_optax(opt: optax.GradientTransformation) -> base.Optimizer:
@@ -182,17 +202,16 @@ def parameterized_wrapped_optax(
182
202
  """Initializes the optimization state."""
183
203
  latent_params = _init_latents(params)
184
204
  _, latents = param_base.partition_density_metadata(latent_params)
185
- return (
186
- 0, # step
187
- _params_from_latent_params(latent_params), # params
188
- latent_params, # latent params
189
- opt.init(latents), # opt state
205
+ return WrappedOptaxState(
206
+ step=0,
207
+ params=_params_from_latent_params(latent_params),
208
+ latent_params=latent_params,
209
+ opt_state=opt.init(latents),
190
210
  )
191
211
 
192
212
  def params_fn(state: WrappedOptaxState) -> PyTree:
193
213
  """Returns the parameters for the given `state`."""
194
- _, params, _, _ = state
195
- return params
214
+ return state.params
196
215
 
197
216
  def update_fn(
198
217
  *,
@@ -204,8 +223,7 @@ def parameterized_wrapped_optax(
204
223
  """Updates the state."""
205
224
  del params
206
225
 
207
- step, params, latent_params, opt_state = state
208
- metadata, latents = param_base.partition_density_metadata(latent_params)
226
+ metadata, latents = param_base.partition_density_metadata(state.latent_params)
209
227
 
210
228
  def _params_from_latents(latents: PyTree) -> PyTree:
211
229
  latent_params = param_base.combine_density_metadata(metadata, latents)
@@ -218,10 +236,9 @@ def parameterized_wrapped_optax(
218
236
  _, vjp_fn = jax.vjp(_params_from_latents, latents)
219
237
  (latents_grad,) = vjp_fn(grad)
220
238
 
221
- if not (
222
- tree_util.tree_structure(latents_grad)
223
- == tree_util.tree_structure(latents) # type: ignore[operator]
224
- ):
239
+ treedef = tree_util.tree_structure(latents_grad)
240
+ expected_treedef = tree_util.tree_structure(latents)
241
+ if not treedef == expected_treedef: # type: ignore[operator]
225
242
  raise ValueError(
226
243
  f"Tree structure of `latents_grad` was different than expected, got \n"
227
244
  f"{tree_util.tree_structure(latents_grad)} but expected \n"
@@ -233,16 +250,23 @@ def parameterized_wrapped_optax(
233
250
  lambda a, b: a + b, latents_grad, constraint_loss_grad
234
251
  )
235
252
 
236
- latent_updates, opt_state = opt.update(latents_grad, opt_state, params=latents)
253
+ latent_updates, opt_state = opt.update(
254
+ latents_grad, state.opt_state, params=latents
255
+ )
237
256
  latent_params = _apply_updates(
238
- params=latent_params,
257
+ params=state.latent_params,
239
258
  updates=param_base.combine_density_metadata(metadata, latent_updates),
240
259
  value=value,
241
- step=step,
260
+ step=state.step,
242
261
  )
243
262
  latent_params = _clip(latent_params)
244
263
  params = _params_from_latent_params(latent_params)
245
- return (step + 1, params, latent_params, opt_state)
264
+ return WrappedOptaxState(
265
+ step=state.step + 1,
266
+ params=params,
267
+ latent_params=latent_params,
268
+ opt_state=opt_state,
269
+ )
246
270
 
247
271
  # -------------------------------------------------------------------------
248
272
  # Functions related to the density parameterization.
@@ -124,6 +124,22 @@ class Density2DMetadata:
124
124
  self.periodic = tuple(self.periodic)
125
125
  self.symmetries = tuple(self.symmetries)
126
126
 
127
+ def __eq__(self, other: Any) -> bool:
128
+ if not isinstance(other, Density2DMetadata):
129
+ return False
130
+ if not (
131
+ self.lower_bound == other.lower_bound
132
+ and self.upper_bound == other.upper_bound
133
+ and _arrays_equal_or_both_none(self.fixed_solid, other.fixed_solid)
134
+ and _arrays_equal_or_both_none(self.fixed_void, other.fixed_void)
135
+ and self.minimum_width == other.minimum_width
136
+ and self.minimum_spacing == other.minimum_spacing
137
+ and self.periodic == other.periodic
138
+ and self.symmetries == other.symmetries
139
+ ):
140
+ return False
141
+ return True
142
+
127
143
  @classmethod
128
144
  def from_density(self, density: types.Density2DArray) -> "Density2DMetadata":
129
145
  density_metadata_dict = dataclasses.asdict(density)
@@ -131,6 +147,21 @@ class Density2DMetadata:
131
147
  return Density2DMetadata(**density_metadata_dict)
132
148
 
133
149
 
150
+ def _arrays_equal_or_both_none(a: Optional[Array], b: Optional[Array]) -> bool:
151
+ """Return `True` if `a` and `b` are equal arrays or both `None`."""
152
+ if (a is None, b is None) not in ((True, True), (False, False)):
153
+ return False
154
+ if a is None and b is None:
155
+ return True
156
+ assert isinstance(a, onp.ndarray)
157
+ assert isinstance(b, onp.ndarray)
158
+ if a.dtype != b.dtype:
159
+ return False
160
+ if a.shape != b.shape:
161
+ return False
162
+ return bool(onp.all(a == b))
163
+
164
+
134
165
  def _flatten_density_2d_metadata(
135
166
  metadata: Density2DMetadata,
136
167
  ) -> Tuple[
@@ -388,7 +388,7 @@ def _phi_from_params(
388
388
  assert array.shape[-2] % s_factor == 0
389
389
  assert array.shape[-1] % s_factor == 0
390
390
  array = symmetry.symmetrize(array, tuple(example_density.symmetries))
391
- return array
391
+ return jnp.asarray(array)
392
392
 
393
393
 
394
394
  # -----------------------------------------------------------------------------
@@ -1,6 +1,6 @@
1
- Metadata-Version: 2.2
1
+ Metadata-Version: 2.4
2
2
  Name: invrs_opt
3
- Version: 0.10.6
3
+ Version: 0.11.0
4
4
  Summary: Algorithms for inverse design
5
5
  Author-email: "Martin F. Schubert" <mfschubert@gmail.com>
6
6
  Maintainer-email: "Martin F. Schubert" <mfschubert@gmail.com>
@@ -531,6 +531,7 @@ Requires-Dist: bump-my-version; extra == "dev"
531
531
  Requires-Dist: darglint; extra == "dev"
532
532
  Requires-Dist: mypy; extra == "dev"
533
533
  Requires-Dist: pre-commit; extra == "dev"
534
+ Dynamic: license-file
534
535
 
535
536
  # invrs-opt - Optimization algorithms for inverse design
536
537
  ![Continuous integration](https://github.com/invrs-io/opt/actions/workflows/build-ci.yml/badge.svg)
@@ -1,20 +1,20 @@
1
- invrs_opt/__init__.py,sha256=4uq04SwbUcX0hVAhaX0-GJ4R4tRJAAXfIzzfU-VcalA,586
1
+ invrs_opt/__init__.py,sha256=CQM3bUeeV9fgJgAee7NFy1B2HJyw2uhNSzoKynmVyOQ,586
2
2
  invrs_opt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  invrs_opt/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
4
  invrs_opt/experimental/client.py,sha256=tbtH13FrA65XmTZfTO71CxJ78jeAEj3Zf85R-MTwbiU,4909
5
5
  invrs_opt/experimental/labels.py,sha256=dQDAMPyFMV6mXnMy295z8Ap205DRdVzysXny_Be8FmY,562
6
6
  invrs_opt/optimizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
7
7
  invrs_opt/optimizers/base.py,sha256=uFfkN2LwWzAtwh6ktwWNy2iHNOY-sW3JzI46iSFkgok,1306
8
- invrs_opt/optimizers/lbfgsb.py,sha256=WP6ouVtLaXSwJBh7CSzWR7rnRdHZuSmOr57TKF4UxMg,36659
9
- invrs_opt/optimizers/wrapped_optax.py,sha256=781-8v_TlHsGaQF9Se9_iOEvtOLOr-BesTLudYarSlg,13685
8
+ invrs_opt/optimizers/lbfgsb.py,sha256=QB8lD02sMr-2V0d_k4UB8Y7SOlNektp0vzwrJId0u44,37059
9
+ invrs_opt/optimizers/wrapped_optax.py,sha256=hftrGCpg4kVv2NarZAVHLg6Gkk87zcc4_yyfvSvQHTo,14161
10
10
  invrs_opt/parameterization/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
- invrs_opt/parameterization/base.py,sha256=jSwrEO86lGkYQG5gWsHvcIMWpZnnbdiKpn--2qaU02g,5362
11
+ invrs_opt/parameterization/base.py,sha256=GFubMydcww6MXlGNkxHGMH_sOCyj9M5R-iwZYTqCo2I,6539
12
12
  invrs_opt/parameterization/filter_project.py,sha256=XL3HTEBLrF-q_75TjhOWLNdfUOSEEjKcoM7Qj844QpQ,4590
13
- invrs_opt/parameterization/gaussian_levelset.py,sha256=PDvjdgBzklRTCUoBpo4ZMcmXeTkn0BpZEzQj7ojtYGE,24813
13
+ invrs_opt/parameterization/gaussian_levelset.py,sha256=bmVU1We92zPpNIJI8sCq3OCeHZw6emZl86unJYwnWbc,24826
14
14
  invrs_opt/parameterization/pixel.py,sha256=YWkyBhfYtzI8cQ-M90PAZqRAbabwVaUh0UiYIGegQHI,1955
15
15
  invrs_opt/parameterization/transforms.py,sha256=mqDKuAg4wpSL9kh0oYKxtSoH0mHOQeKG1RND2fJSYaU,9441
16
- invrs_opt-0.10.6.dist-info/LICENSE,sha256=IMF9i4xIpgCADf0U-V1cuf9HBmqWQd3qtI3FSuyW4zE,26526
17
- invrs_opt-0.10.6.dist-info/METADATA,sha256=40XUQ7i3S4nYRGUqkEVOwKrJwGtpvdwmpD9ojFpAeAM,32816
18
- invrs_opt-0.10.6.dist-info/WHEEL,sha256=nn6H5-ilmfVryoAQl3ZQ2l8SH5imPWFpm1A5FgEuFV4,91
19
- invrs_opt-0.10.6.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
20
- invrs_opt-0.10.6.dist-info/RECORD,,
16
+ invrs_opt-0.11.0.dist-info/licenses/LICENSE,sha256=IMF9i4xIpgCADf0U-V1cuf9HBmqWQd3qtI3FSuyW4zE,26526
17
+ invrs_opt-0.11.0.dist-info/METADATA,sha256=M6IzHHGIctk49mBxnswyKcvE3BikRoqZz_19UDOqB9c,32838
18
+ invrs_opt-0.11.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
19
+ invrs_opt-0.11.0.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
20
+ invrs_opt-0.11.0.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (75.8.1)
2
+ Generator: setuptools (80.9.0)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5