invrs-opt 0.4.0__py3-none-any.whl → 0.10.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
invrs_opt/__init__.py CHANGED
@@ -3,8 +3,19 @@
3
3
  Copyright (c) 2023 The INVRS-IO authors.
4
4
  """
5
5
 
6
- __version__ = "v0.4.0"
6
+ __version__ = "v0.10.3"
7
7
  __author__ = "Martin F. Schubert <mfschubert@gmail.com>"
8
8
 
9
- from invrs_opt.lbfgsb.lbfgsb import density_lbfgsb as density_lbfgsb
10
- from invrs_opt.lbfgsb.lbfgsb import lbfgsb as lbfgsb
9
+ from invrs_opt import parameterization as parameterization
10
+
11
+ from invrs_opt.optimizers.lbfgsb import (
12
+ density_lbfgsb as density_lbfgsb,
13
+ lbfgsb as lbfgsb,
14
+ levelset_lbfgsb as levelset_lbfgsb,
15
+ )
16
+
17
+ from invrs_opt.optimizers.wrapped_optax import (
18
+ density_wrapped_optax as density_wrapped_optax,
19
+ levelset_wrapped_optax as levelset_wrapped_optax,
20
+ wrapped_optax as wrapped_optax,
21
+ )
@@ -4,15 +4,14 @@ Copyright (c) 2023 The INVRS-IO authors.
4
4
  """
5
5
 
6
6
  import json
7
- import requests
8
7
  import time
9
8
  from typing import Any, Dict, Optional
10
9
 
10
+ import requests
11
11
  from totypes import json_utils
12
12
 
13
- from invrs_opt import base
14
13
  from invrs_opt.experimental import labels
15
-
14
+ from invrs_opt.optimizers import base
16
15
 
17
16
  PyTree = Any
18
17
  StateToken = str
@@ -133,7 +132,11 @@ def optimizer_client(
133
132
  response = json.loads(get_response.text)
134
133
  return json_utils.pytree_from_json(response[labels.PARAMS])
135
134
 
136
- return base.Optimizer(init=init_fn, update=update_fn, params=params_fn)
135
+ return base.Optimizer(
136
+ init=init_fn,
137
+ update=update_fn, # type: ignore[arg-type]
138
+ params=params_fn,
139
+ )
137
140
 
138
141
 
139
142
  # -----------------------------------------------------------------------------
@@ -4,8 +4,13 @@ Copyright (c) 2023 The INVRS-IO authors.
4
4
  """
5
5
 
6
6
  import dataclasses
7
+ import inspect
7
8
  from typing import Any, Protocol
8
9
 
10
+ import jax.numpy as jnp
11
+ import optax # type: ignore[import-untyped]
12
+ from totypes import json_utils
13
+
9
14
  PyTree = Any
10
15
 
11
16
 
@@ -30,7 +35,7 @@ class UpdateFn(Protocol):
30
35
  self,
31
36
  *,
32
37
  grad: PyTree,
33
- value: float,
38
+ value: jnp.ndarray,
34
39
  params: PyTree,
35
40
  state: PyTree,
36
41
  ) -> PyTree:
@@ -44,3 +49,13 @@ class Optimizer:
44
49
  init: InitFn
45
50
  params: ParamsFn
46
51
  update: UpdateFn
52
+
53
+
54
+ # Register all optax state types for serialization.
55
+ optax_types = {}
56
+ for name, obj in inspect.getmembers(optax):
57
+ if name.endswith("State") and isinstance(obj, type):
58
+ optax_types[obj] = True
59
+
60
+ for obj in optax_types.keys():
61
+ json_utils.register_custom_type(obj)