invrs-opt 0.4.0__py3-none-any.whl → 0.10.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invrs_opt/__init__.py +14 -3
- invrs_opt/experimental/client.py +7 -4
- invrs_opt/{base.py → optimizers/base.py} +16 -1
- invrs_opt/optimizers/lbfgsb.py +939 -0
- invrs_opt/optimizers/wrapped_optax.py +347 -0
- invrs_opt/parameterization/__init__.py +0 -0
- invrs_opt/parameterization/base.py +208 -0
- invrs_opt/parameterization/filter_project.py +138 -0
- invrs_opt/parameterization/gaussian_levelset.py +671 -0
- invrs_opt/parameterization/pixel.py +75 -0
- invrs_opt/{lbfgsb/transform.py → parameterization/transforms.py} +76 -11
- invrs_opt-0.10.3.dist-info/LICENSE +504 -0
- invrs_opt-0.10.3.dist-info/METADATA +560 -0
- invrs_opt-0.10.3.dist-info/RECORD +20 -0
- {invrs_opt-0.4.0.dist-info → invrs_opt-0.10.3.dist-info}/WHEEL +1 -1
- invrs_opt/lbfgsb/lbfgsb.py +0 -672
- invrs_opt-0.4.0.dist-info/LICENSE +0 -21
- invrs_opt-0.4.0.dist-info/METADATA +0 -75
- invrs_opt-0.4.0.dist-info/RECORD +0 -14
- /invrs_opt/{lbfgsb → optimizers}/__init__.py +0 -0
- {invrs_opt-0.4.0.dist-info → invrs_opt-0.10.3.dist-info}/top_level.txt +0 -0
invrs_opt/__init__.py
CHANGED
@@ -3,8 +3,19 @@
|
|
3
3
|
Copyright (c) 2023 The INVRS-IO authors.
|
4
4
|
"""
|
5
5
|
|
6
|
-
__version__ = "v0.
|
6
|
+
__version__ = "v0.10.3"
|
7
7
|
__author__ = "Martin F. Schubert <mfschubert@gmail.com>"
|
8
8
|
|
9
|
-
from invrs_opt
|
10
|
-
|
9
|
+
from invrs_opt import parameterization as parameterization
|
10
|
+
|
11
|
+
from invrs_opt.optimizers.lbfgsb import (
|
12
|
+
density_lbfgsb as density_lbfgsb,
|
13
|
+
lbfgsb as lbfgsb,
|
14
|
+
levelset_lbfgsb as levelset_lbfgsb,
|
15
|
+
)
|
16
|
+
|
17
|
+
from invrs_opt.optimizers.wrapped_optax import (
|
18
|
+
density_wrapped_optax as density_wrapped_optax,
|
19
|
+
levelset_wrapped_optax as levelset_wrapped_optax,
|
20
|
+
wrapped_optax as wrapped_optax,
|
21
|
+
)
|
invrs_opt/experimental/client.py
CHANGED
@@ -4,15 +4,14 @@ Copyright (c) 2023 The INVRS-IO authors.
|
|
4
4
|
"""
|
5
5
|
|
6
6
|
import json
|
7
|
-
import requests
|
8
7
|
import time
|
9
8
|
from typing import Any, Dict, Optional
|
10
9
|
|
10
|
+
import requests
|
11
11
|
from totypes import json_utils
|
12
12
|
|
13
|
-
from invrs_opt import base
|
14
13
|
from invrs_opt.experimental import labels
|
15
|
-
|
14
|
+
from invrs_opt.optimizers import base
|
16
15
|
|
17
16
|
PyTree = Any
|
18
17
|
StateToken = str
|
@@ -133,7 +132,11 @@ def optimizer_client(
|
|
133
132
|
response = json.loads(get_response.text)
|
134
133
|
return json_utils.pytree_from_json(response[labels.PARAMS])
|
135
134
|
|
136
|
-
return base.Optimizer(
|
135
|
+
return base.Optimizer(
|
136
|
+
init=init_fn,
|
137
|
+
update=update_fn, # type: ignore[arg-type]
|
138
|
+
params=params_fn,
|
139
|
+
)
|
137
140
|
|
138
141
|
|
139
142
|
# -----------------------------------------------------------------------------
|
@@ -4,8 +4,13 @@ Copyright (c) 2023 The INVRS-IO authors.
|
|
4
4
|
"""
|
5
5
|
|
6
6
|
import dataclasses
|
7
|
+
import inspect
|
7
8
|
from typing import Any, Protocol
|
8
9
|
|
10
|
+
import jax.numpy as jnp
|
11
|
+
import optax # type: ignore[import-untyped]
|
12
|
+
from totypes import json_utils
|
13
|
+
|
9
14
|
PyTree = Any
|
10
15
|
|
11
16
|
|
@@ -30,7 +35,7 @@ class UpdateFn(Protocol):
|
|
30
35
|
self,
|
31
36
|
*,
|
32
37
|
grad: PyTree,
|
33
|
-
value:
|
38
|
+
value: jnp.ndarray,
|
34
39
|
params: PyTree,
|
35
40
|
state: PyTree,
|
36
41
|
) -> PyTree:
|
@@ -44,3 +49,13 @@ class Optimizer:
|
|
44
49
|
init: InitFn
|
45
50
|
params: ParamsFn
|
46
51
|
update: UpdateFn
|
52
|
+
|
53
|
+
|
54
|
+
# Register all optax state types for serialization.
|
55
|
+
optax_types = {}
|
56
|
+
for name, obj in inspect.getmembers(optax):
|
57
|
+
if name.endswith("State") and isinstance(obj, type):
|
58
|
+
optax_types[obj] = True
|
59
|
+
|
60
|
+
for obj in optax_types.keys():
|
61
|
+
json_utils.register_custom_type(obj)
|