invrs-opt 0.4.0__py3-none-any.whl → 0.10.4__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
invrs_opt/__init__.py CHANGED
@@ -3,8 +3,19 @@
3
3
  Copyright (c) 2023 The INVRS-IO authors.
4
4
  """
5
5
 
6
- __version__ = "v0.4.0"
6
+ __version__ = "v0.10.4"
7
7
  __author__ = "Martin F. Schubert <mfschubert@gmail.com>"
8
8
 
9
- from invrs_opt.lbfgsb.lbfgsb import density_lbfgsb as density_lbfgsb
10
- from invrs_opt.lbfgsb.lbfgsb import lbfgsb as lbfgsb
9
+ from invrs_opt import parameterization as parameterization
10
+
11
+ from invrs_opt.optimizers.lbfgsb import (
12
+ density_lbfgsb as density_lbfgsb,
13
+ lbfgsb as lbfgsb,
14
+ levelset_lbfgsb as levelset_lbfgsb,
15
+ )
16
+
17
+ from invrs_opt.optimizers.wrapped_optax import (
18
+ density_wrapped_optax as density_wrapped_optax,
19
+ levelset_wrapped_optax as levelset_wrapped_optax,
20
+ wrapped_optax as wrapped_optax,
21
+ )
@@ -4,15 +4,14 @@ Copyright (c) 2023 The INVRS-IO authors.
4
4
  """
5
5
 
6
6
  import json
7
- import requests
8
7
  import time
9
8
  from typing import Any, Dict, Optional
10
9
 
10
+ import requests
11
11
  from totypes import json_utils
12
12
 
13
- from invrs_opt import base
14
13
  from invrs_opt.experimental import labels
15
-
14
+ from invrs_opt.optimizers import base
16
15
 
17
16
  PyTree = Any
18
17
  StateToken = str
@@ -133,7 +132,11 @@ def optimizer_client(
133
132
  response = json.loads(get_response.text)
134
133
  return json_utils.pytree_from_json(response[labels.PARAMS])
135
134
 
136
- return base.Optimizer(init=init_fn, update=update_fn, params=params_fn)
135
+ return base.Optimizer(
136
+ init=init_fn,
137
+ update=update_fn, # type: ignore[arg-type]
138
+ params=params_fn,
139
+ )
137
140
 
138
141
 
139
142
  # -----------------------------------------------------------------------------
@@ -4,8 +4,13 @@ Copyright (c) 2023 The INVRS-IO authors.
4
4
  """
5
5
 
6
6
  import dataclasses
7
+ import inspect
7
8
  from typing import Any, Protocol
8
9
 
10
+ import jax.numpy as jnp
11
+ import optax # type: ignore[import-untyped]
12
+ from totypes import json_utils
13
+
9
14
  PyTree = Any
10
15
 
11
16
 
@@ -30,7 +35,7 @@ class UpdateFn(Protocol):
30
35
  self,
31
36
  *,
32
37
  grad: PyTree,
33
- value: float,
38
+ value: jnp.ndarray,
34
39
  params: PyTree,
35
40
  state: PyTree,
36
41
  ) -> PyTree:
@@ -44,3 +49,13 @@ class Optimizer:
44
49
  init: InitFn
45
50
  params: ParamsFn
46
51
  update: UpdateFn
52
+
53
+
54
+ # Register all optax state types for serialization.
55
+ optax_types = {}
56
+ for name, obj in inspect.getmembers(optax):
57
+ if name.endswith("State") and isinstance(obj, type):
58
+ optax_types[obj] = True
59
+
60
+ for obj in optax_types.keys():
61
+ json_utils.register_custom_type(obj)