invrs-opt 0.4.0__py3-none-any.whl → 0.10.3__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- invrs_opt/__init__.py +14 -3
- invrs_opt/experimental/client.py +7 -4
- invrs_opt/{base.py → optimizers/base.py} +16 -1
- invrs_opt/optimizers/lbfgsb.py +939 -0
- invrs_opt/optimizers/wrapped_optax.py +347 -0
- invrs_opt/parameterization/__init__.py +0 -0
- invrs_opt/parameterization/base.py +208 -0
- invrs_opt/parameterization/filter_project.py +138 -0
- invrs_opt/parameterization/gaussian_levelset.py +671 -0
- invrs_opt/parameterization/pixel.py +75 -0
- invrs_opt/{lbfgsb/transform.py → parameterization/transforms.py} +76 -11
- invrs_opt-0.10.3.dist-info/LICENSE +504 -0
- invrs_opt-0.10.3.dist-info/METADATA +560 -0
- invrs_opt-0.10.3.dist-info/RECORD +20 -0
- {invrs_opt-0.4.0.dist-info → invrs_opt-0.10.3.dist-info}/WHEEL +1 -1
- invrs_opt/lbfgsb/lbfgsb.py +0 -672
- invrs_opt-0.4.0.dist-info/LICENSE +0 -21
- invrs_opt-0.4.0.dist-info/METADATA +0 -75
- invrs_opt-0.4.0.dist-info/RECORD +0 -14
- invrs_opt/{lbfgsb → optimizers}/__init__.py +0 -0
- {invrs_opt-0.4.0.dist-info → invrs_opt-0.10.3.dist-info}/top_level.txt +0 -0
invrs_opt/__init__.py
CHANGED
@@ -3,8 +3,19 @@
|
|
3
3
|
Copyright (c) 2023 The INVRS-IO authors.
|
4
4
|
"""
|
5
5
|
|
6
|
-
__version__ = "v0.4.0"
|
6
|
+
__version__ = "v0.10.3"
|
7
7
|
__author__ = "Martin F. Schubert <mfschubert@gmail.com>"
|
8
8
|
|
9
|
-
from invrs_opt
|
10
|
-
|
9
|
+
from invrs_opt import parameterization as parameterization
|
10
|
+
|
11
|
+
from invrs_opt.optimizers.lbfgsb import (
|
12
|
+
density_lbfgsb as density_lbfgsb,
|
13
|
+
lbfgsb as lbfgsb,
|
14
|
+
levelset_lbfgsb as levelset_lbfgsb,
|
15
|
+
)
|
16
|
+
|
17
|
+
from invrs_opt.optimizers.wrapped_optax import (
|
18
|
+
density_wrapped_optax as density_wrapped_optax,
|
19
|
+
levelset_wrapped_optax as levelset_wrapped_optax,
|
20
|
+
wrapped_optax as wrapped_optax,
|
21
|
+
)
|
invrs_opt/experimental/client.py
CHANGED
@@ -4,15 +4,14 @@ Copyright (c) 2023 The INVRS-IO authors.
|
|
4
4
|
"""
|
5
5
|
|
6
6
|
import json
|
7
|
-
import requests
|
8
7
|
import time
|
9
8
|
from typing import Any, Dict, Optional
|
10
9
|
|
10
|
+
import requests
|
11
11
|
from totypes import json_utils
|
12
12
|
|
13
|
-
from invrs_opt import base
|
14
13
|
from invrs_opt.experimental import labels
|
15
|
-
|
14
|
+
from invrs_opt.optimizers import base
|
16
15
|
|
17
16
|
PyTree = Any
|
18
17
|
StateToken = str
|
@@ -133,7 +132,11 @@ def optimizer_client(
|
|
133
132
|
response = json.loads(get_response.text)
|
134
133
|
return json_utils.pytree_from_json(response[labels.PARAMS])
|
135
134
|
|
136
|
-
return base.Optimizer(
|
135
|
+
return base.Optimizer(
|
136
|
+
init=init_fn,
|
137
|
+
update=update_fn, # type: ignore[arg-type]
|
138
|
+
params=params_fn,
|
139
|
+
)
|
137
140
|
|
138
141
|
|
139
142
|
# -----------------------------------------------------------------------------
|
invrs_opt/{base.py → optimizers/base.py}
CHANGED
@@ -4,8 +4,13 @@ Copyright (c) 2023 The INVRS-IO authors.
|
|
4
4
|
"""
|
5
5
|
|
6
6
|
import dataclasses
|
7
|
+
import inspect
|
7
8
|
from typing import Any, Protocol
|
8
9
|
|
10
|
+
import jax.numpy as jnp
|
11
|
+
import optax # type: ignore[import-untyped]
|
12
|
+
from totypes import json_utils
|
13
|
+
|
9
14
|
PyTree = Any
|
10
15
|
|
11
16
|
|
@@ -30,7 +35,7 @@ class UpdateFn(Protocol):
|
|
30
35
|
self,
|
31
36
|
*,
|
32
37
|
grad: PyTree,
|
33
|
-
value:
|
38
|
+
value: jnp.ndarray,
|
34
39
|
params: PyTree,
|
35
40
|
state: PyTree,
|
36
41
|
) -> PyTree:
|
@@ -44,3 +49,13 @@ class Optimizer:
|
|
44
49
|
init: InitFn
|
45
50
|
params: ParamsFn
|
46
51
|
update: UpdateFn
|
52
|
+
|
53
|
+
|
54
|
+
# Register all optax state types for serialization.
|
55
|
+
optax_types = {}
|
56
|
+
for name, obj in inspect.getmembers(optax):
|
57
|
+
if name.endswith("State") and isinstance(obj, type):
|
58
|
+
optax_types[obj] = True
|
59
|
+
|
60
|
+
for obj in optax_types.keys():
|
61
|
+
json_utils.register_custom_type(obj)
|