invrs-opt 0.10.7__py3-none-any.whl → 0.11.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
invrs_opt/__init__.py CHANGED
@@ -3,7 +3,7 @@
3
3
  Copyright (c) 2023 The INVRS-IO authors.
4
4
  """
5
5
 
6
- __version__ = "v0.10.7"
6
+ __version__ = "v0.11.0"
7
7
  __author__ = "Martin F. Schubert <mfschubert@gmail.com>"
8
8
 
9
9
  from invrs_opt import parameterization as parameterization
@@ -16,7 +16,7 @@ from jax import flatten_util, tree_util
16
16
  from scipy.optimize._lbfgsb_py import ( # type: ignore[import-untyped]
17
17
  _lbfgsb as scipy_lbfgsb,
18
18
  )
19
- from totypes import types
19
+ from totypes import json_utils, types
20
20
 
21
21
  from invrs_opt.optimizers import base
22
22
  from invrs_opt.parameterization import (
@@ -31,7 +31,26 @@ PyTree = Any
31
31
  ElementwiseBound = Union[NDArray, Sequence[Optional[float]]]
32
32
  NumpyLbfgsbDict = Dict[str, NDArray]
33
33
  JaxLbfgsbDict = Dict[str, jnp.ndarray]
34
- LbfgsbState = Tuple[int, PyTree, PyTree, JaxLbfgsbDict]
34
+
35
+
36
+ @dataclasses.dataclass
37
+ class LbfgsbState:
38
+ """Stores the state of the L-BFGS-B optimizer."""
39
+
40
+ step: int
41
+ params: PyTree
42
+ latent_params: PyTree
43
+ opt_state: JaxLbfgsbDict
44
+
45
+
46
+ tree_util.register_dataclass(
47
+ LbfgsbState,
48
+ data_fields=["step", "params", "latent_params", "opt_state"],
49
+ meta_fields=[],
50
+ )
51
+
52
+
53
+ json_utils.register_custom_type(LbfgsbState)
35
54
 
36
55
 
37
56
  # Task message prefixes for the underlying L-BFGS-B implementation.
@@ -327,17 +346,16 @@ def parameterized_lbfgsb(
327
346
  latents,
328
347
  )
329
348
  latent_params = param_base.combine_density_metadata(metadata, latents)
330
- return (
331
- 0, # step
332
- _params_from_latent_params(latent_params), # params
333
- latent_params, # latent params
334
- jax_lbfgsb_state, # opt state
349
+ return LbfgsbState(
350
+ step=0,
351
+ params=_params_from_latent_params(latent_params),
352
+ latent_params=latent_params,
353
+ opt_state=jax_lbfgsb_state,
335
354
  )
336
355
 
337
356
  def params_fn(state: LbfgsbState) -> PyTree:
338
357
  """Returns the parameters for the given `state`."""
339
- _, params, _, _ = state
340
- return params
358
+ return state.params
341
359
 
342
360
  def update_fn(
343
361
  *,
@@ -366,8 +384,7 @@ def parameterized_lbfgsb(
366
384
  flat_latent_updates = updated_flat_latent_params - flat_latent_params
367
385
  return flat_latent_updates, scipy_lbfgsb_state.to_dict()
368
386
 
369
- step, _, latent_params, jax_lbfgsb_state = state
370
- metadata, latents = param_base.partition_density_metadata(latent_params)
387
+ metadata, latents = param_base.partition_density_metadata(state.latent_params)
371
388
 
372
389
  def _params_from_latents(latents: PyTree) -> PyTree:
373
390
  latent_params = param_base.combine_density_metadata(metadata, latents)
@@ -404,23 +421,28 @@ def parameterized_lbfgsb(
404
421
  latents_grad
405
422
  ) # type: ignore[no-untyped-call]
406
423
 
407
- flat_latent_updates, jax_lbfgsb_state = callback_sequential(
424
+ flat_latent_updates, opt_state = callback_sequential(
408
425
  _update_pure,
409
- (flat_latents_grad, jax_lbfgsb_state),
426
+ (flat_latents_grad, state.opt_state),
410
427
  flat_latents_grad,
411
428
  value,
412
- jax_lbfgsb_state,
429
+ state.opt_state,
413
430
  )
414
431
  latent_updates = unflatten_fn(flat_latent_updates)
415
432
  latent_params = _apply_updates(
416
- params=latent_params,
433
+ params=state.latent_params,
417
434
  updates=param_base.combine_density_metadata(metadata, latent_updates),
418
435
  value=value,
419
- step=step,
436
+ step=state.step,
420
437
  )
421
438
  latent_params = _clip(latent_params)
422
439
  params = _params_from_latent_params(latent_params)
423
- return step + 1, params, latent_params, jax_lbfgsb_state
440
+ return LbfgsbState(
441
+ step=state.step + 1,
442
+ params=params,
443
+ latent_params=latent_params,
444
+ opt_state=opt_state,
445
+ )
424
446
 
425
447
  # -------------------------------------------------------------------------
426
448
  # Functions related to the density parameterization.
@@ -501,7 +523,7 @@ def parameterized_lbfgsb(
501
523
 
502
524
  def is_converged(state: LbfgsbState) -> jnp.ndarray:
503
525
  """Returns `True` if the optimization has converged."""
504
- return state[3]["converged"]
526
+ return state.opt_state["converged"]
505
527
 
506
528
 
507
529
  # ------------------------------------------------------------------------------
@@ -3,13 +3,14 @@
3
3
  Copyright (c) 2023 The INVRS-IO authors.
4
4
  """
5
5
 
6
- from typing import Any, Optional, Tuple
6
+ import dataclasses
7
+ from typing import Any, Optional
7
8
 
8
9
  import jax
9
10
  import jax.numpy as jnp
10
11
  import optax # type: ignore[import-untyped]
11
12
  from jax import tree_util
12
- from totypes import types
13
+ from totypes import json_utils, types
13
14
 
14
15
  from invrs_opt.optimizers import base
15
16
  from invrs_opt.parameterization import (
@@ -20,7 +21,26 @@ from invrs_opt.parameterization import (
20
21
  )
21
22
 
22
23
  PyTree = Any
23
- WrappedOptaxState = Tuple[int, PyTree, PyTree, PyTree]
24
+
25
+
26
+ @dataclasses.dataclass
27
+ class WrappedOptaxState:
28
+ """Stores the state of a wrapped optax optimizer."""
29
+
30
+ step: int
31
+ params: PyTree
32
+ latent_params: PyTree
33
+ opt_state: PyTree
34
+
35
+
36
+ tree_util.register_dataclass(
37
+ WrappedOptaxState,
38
+ data_fields=["step", "params", "latent_params", "opt_state"],
39
+ meta_fields=[],
40
+ )
41
+
42
+
43
+ json_utils.register_custom_type(WrappedOptaxState)
24
44
 
25
45
 
26
46
  def wrapped_optax(opt: optax.GradientTransformation) -> base.Optimizer:
@@ -182,17 +202,16 @@ def parameterized_wrapped_optax(
182
202
  """Initializes the optimization state."""
183
203
  latent_params = _init_latents(params)
184
204
  _, latents = param_base.partition_density_metadata(latent_params)
185
- return (
186
- 0, # step
187
- _params_from_latent_params(latent_params), # params
188
- latent_params, # latent params
189
- opt.init(latents), # opt state
205
+ return WrappedOptaxState(
206
+ step=0,
207
+ params=_params_from_latent_params(latent_params),
208
+ latent_params=latent_params,
209
+ opt_state=opt.init(latents),
190
210
  )
191
211
 
192
212
  def params_fn(state: WrappedOptaxState) -> PyTree:
193
213
  """Returns the parameters for the given `state`."""
194
- _, params, _, _ = state
195
- return params
214
+ return state.params
196
215
 
197
216
  def update_fn(
198
217
  *,
@@ -204,8 +223,7 @@ def parameterized_wrapped_optax(
204
223
  """Updates the state."""
205
224
  del params
206
225
 
207
- step, params, latent_params, opt_state = state
208
- metadata, latents = param_base.partition_density_metadata(latent_params)
226
+ metadata, latents = param_base.partition_density_metadata(state.latent_params)
209
227
 
210
228
  def _params_from_latents(latents: PyTree) -> PyTree:
211
229
  latent_params = param_base.combine_density_metadata(metadata, latents)
@@ -232,16 +250,23 @@ def parameterized_wrapped_optax(
232
250
  lambda a, b: a + b, latents_grad, constraint_loss_grad
233
251
  )
234
252
 
235
- latent_updates, opt_state = opt.update(latents_grad, opt_state, params=latents)
253
+ latent_updates, opt_state = opt.update(
254
+ latents_grad, state.opt_state, params=latents
255
+ )
236
256
  latent_params = _apply_updates(
237
- params=latent_params,
257
+ params=state.latent_params,
238
258
  updates=param_base.combine_density_metadata(metadata, latent_updates),
239
259
  value=value,
240
- step=step,
260
+ step=state.step,
241
261
  )
242
262
  latent_params = _clip(latent_params)
243
263
  params = _params_from_latent_params(latent_params)
244
- return (step + 1, params, latent_params, opt_state)
264
+ return WrappedOptaxState(
265
+ step=state.step + 1,
266
+ params=params,
267
+ latent_params=latent_params,
268
+ opt_state=opt_state,
269
+ )
245
270
 
246
271
  # -------------------------------------------------------------------------
247
272
  # Functions related to the density parameterization.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: invrs_opt
3
- Version: 0.10.7
3
+ Version: 0.11.0
4
4
  Summary: Algorithms for inverse design
5
5
  Author-email: "Martin F. Schubert" <mfschubert@gmail.com>
6
6
  Maintainer-email: "Martin F. Schubert" <mfschubert@gmail.com>
@@ -1,20 +1,20 @@
1
- invrs_opt/__init__.py,sha256=4ZhV7JC3OdPDzBPTgWsWk8NNPrlg_Hthf7M4BTmUQ4g,586
1
+ invrs_opt/__init__.py,sha256=CQM3bUeeV9fgJgAee7NFy1B2HJyw2uhNSzoKynmVyOQ,586
2
2
  invrs_opt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  invrs_opt/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
4
  invrs_opt/experimental/client.py,sha256=tbtH13FrA65XmTZfTO71CxJ78jeAEj3Zf85R-MTwbiU,4909
5
5
  invrs_opt/experimental/labels.py,sha256=dQDAMPyFMV6mXnMy295z8Ap205DRdVzysXny_Be8FmY,562
6
6
  invrs_opt/optimizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
7
7
  invrs_opt/optimizers/base.py,sha256=uFfkN2LwWzAtwh6ktwWNy2iHNOY-sW3JzI46iSFkgok,1306
8
- invrs_opt/optimizers/lbfgsb.py,sha256=uyeVoTq1ZtCLILw5Rv0ALNRiEXMxhEqdVElukmeuoZ4,36693
9
- invrs_opt/optimizers/wrapped_optax.py,sha256=v-pezJDDdtUckh-jL33Wj-H6a_a3xqrxkE9tOBWcCNI,13719
8
+ invrs_opt/optimizers/lbfgsb.py,sha256=QB8lD02sMr-2V0d_k4UB8Y7SOlNektp0vzwrJId0u44,37059
9
+ invrs_opt/optimizers/wrapped_optax.py,sha256=hftrGCpg4kVv2NarZAVHLg6Gkk87zcc4_yyfvSvQHTo,14161
10
10
  invrs_opt/parameterization/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
11
  invrs_opt/parameterization/base.py,sha256=GFubMydcww6MXlGNkxHGMH_sOCyj9M5R-iwZYTqCo2I,6539
12
12
  invrs_opt/parameterization/filter_project.py,sha256=XL3HTEBLrF-q_75TjhOWLNdfUOSEEjKcoM7Qj844QpQ,4590
13
13
  invrs_opt/parameterization/gaussian_levelset.py,sha256=bmVU1We92zPpNIJI8sCq3OCeHZw6emZl86unJYwnWbc,24826
14
14
  invrs_opt/parameterization/pixel.py,sha256=YWkyBhfYtzI8cQ-M90PAZqRAbabwVaUh0UiYIGegQHI,1955
15
15
  invrs_opt/parameterization/transforms.py,sha256=mqDKuAg4wpSL9kh0oYKxtSoH0mHOQeKG1RND2fJSYaU,9441
16
- invrs_opt-0.10.7.dist-info/licenses/LICENSE,sha256=IMF9i4xIpgCADf0U-V1cuf9HBmqWQd3qtI3FSuyW4zE,26526
17
- invrs_opt-0.10.7.dist-info/METADATA,sha256=2uHpZfOyDdaMlLtnUoYuj6rGcG-T3KYchXU-oORyRQY,32838
18
- invrs_opt-0.10.7.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
19
- invrs_opt-0.10.7.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
20
- invrs_opt-0.10.7.dist-info/RECORD,,
16
+ invrs_opt-0.11.0.dist-info/licenses/LICENSE,sha256=IMF9i4xIpgCADf0U-V1cuf9HBmqWQd3qtI3FSuyW4zE,26526
17
+ invrs_opt-0.11.0.dist-info/METADATA,sha256=M6IzHHGIctk49mBxnswyKcvE3BikRoqZz_19UDOqB9c,32838
18
+ invrs_opt-0.11.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
19
+ invrs_opt-0.11.0.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
20
+ invrs_opt-0.11.0.dist-info/RECORD,,