invrs-opt 0.3.2__py3-none-any.whl → 0.10.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invrs_opt/__init__.py +14 -3
- invrs_opt/experimental/client.py +155 -0
- invrs_opt/experimental/labels.py +23 -0
- invrs_opt/optimizers/__init__.py +0 -0
- invrs_opt/{base.py → optimizers/base.py} +16 -1
- invrs_opt/optimizers/lbfgsb.py +939 -0
- invrs_opt/optimizers/wrapped_optax.py +347 -0
- invrs_opt/parameterization/__init__.py +0 -0
- invrs_opt/parameterization/base.py +208 -0
- invrs_opt/parameterization/filter_project.py +138 -0
- invrs_opt/parameterization/gaussian_levelset.py +671 -0
- invrs_opt/parameterization/pixel.py +75 -0
- invrs_opt/{lbfgsb/transform.py → parameterization/transforms.py} +76 -11
- invrs_opt-0.10.3.dist-info/LICENSE +504 -0
- invrs_opt-0.10.3.dist-info/METADATA +560 -0
- invrs_opt-0.10.3.dist-info/RECORD +20 -0
- {invrs_opt-0.3.2.dist-info → invrs_opt-0.10.3.dist-info}/WHEEL +1 -1
- invrs_opt/lbfgsb/lbfgsb.py +0 -670
- invrs_opt-0.3.2.dist-info/LICENSE +0 -21
- invrs_opt-0.3.2.dist-info/METADATA +0 -73
- invrs_opt-0.3.2.dist-info/RECORD +0 -11
- /invrs_opt/{lbfgsb → experimental}/__init__.py +0 -0
- {invrs_opt-0.3.2.dist-info → invrs_opt-0.10.3.dist-info}/top_level.txt +0 -0
invrs_opt/__init__.py
CHANGED
@@ -3,8 +3,19 @@
|
|
3
3
|
Copyright (c) 2023 The INVRS-IO authors.
|
4
4
|
"""
|
5
5
|
|
6
|
-
__version__ = "v0.3
|
6
|
+
__version__ = "v0.10.3"
|
7
7
|
__author__ = "Martin F. Schubert <mfschubert@gmail.com>"
|
8
8
|
|
9
|
-
from invrs_opt
|
10
|
-
|
9
|
+
from invrs_opt import parameterization as parameterization
|
10
|
+
|
11
|
+
from invrs_opt.optimizers.lbfgsb import (
|
12
|
+
density_lbfgsb as density_lbfgsb,
|
13
|
+
lbfgsb as lbfgsb,
|
14
|
+
levelset_lbfgsb as levelset_lbfgsb,
|
15
|
+
)
|
16
|
+
|
17
|
+
from invrs_opt.optimizers.wrapped_optax import (
|
18
|
+
density_wrapped_optax as density_wrapped_optax,
|
19
|
+
levelset_wrapped_optax as levelset_wrapped_optax,
|
20
|
+
wrapped_optax as wrapped_optax,
|
21
|
+
)
|
@@ -0,0 +1,155 @@
|
|
1
|
+
"""Defines basic client optimizers for use with an optimization service.
|
2
|
+
|
3
|
+
Copyright (c) 2023 The INVRS-IO authors.
|
4
|
+
"""
|
5
|
+
|
6
|
+
import json
|
7
|
+
import time
|
8
|
+
from typing import Any, Dict, Optional
|
9
|
+
|
10
|
+
import requests
|
11
|
+
from totypes import json_utils
|
12
|
+
|
13
|
+
from invrs_opt.experimental import labels
|
14
|
+
from invrs_opt.optimizers import base
|
15
|
+
|
16
|
+
# Type aliases used by the client optimizers.
PyTree = Any
StateToken = str

# Module-level connection state. Both stay `None` until `login` is called.
SESSION = None
SERVER_ADDRESS = None


def login(server_address: str) -> None:
    """Configure the module-level connection used by the client optimizers.

    A fresh `requests.Session` is created and stored in the global `SESSION`,
    and `server_address` is stored in the global `SERVER_ADDRESS`. Any values
    from a previous call are replaced.

    Args:
        server_address: Base URL of the optimization service.
    """
    global SESSION, SERVER_ADDRESS
    new_session = requests.Session()
    SESSION, SERVER_ADDRESS = new_session, server_address
|
29
|
+
|
30
|
+
|
31
|
+
def optimizer_client(
    algorithm: str,
    hparams: Dict[str, Any],
    server_address: Optional[str],
    session: Optional[requests.Session],
) -> base.Optimizer:
    """Generic optimizer whose steps are executed by a remote service.

    The returned optimizer serializes its inputs with `json_utils` and sends
    them to the server over HTTP. Its state is an opaque token string that the
    server issues.

    Args:
        algorithm: Name of the server-side algorithm, e.g. `"lbfgsb"`.
        hparams: Hyperparameters sent to the server alongside `algorithm`.
        server_address: Base URL of the service. If `None`, the global
            `SERVER_ADDRESS` (set by `login`) is used.
        session: Session used for all requests. If `None`, the global
            `SESSION` (set by `login`) is used.

    Returns:
        The client optimizer.

    Raises:
        ValueError: If an argument and its corresponding global are both `None`.
    """
    if server_address is None:
        if SERVER_ADDRESS is None:
            raise ValueError(
                "Argument `server_address` and the global `SERVER_ADDRESS` cannot "
                "both be `None`. Use the `login` method to set the global, or "
                "explicitly provide a value."
            )
        server_address = SERVER_ADDRESS
    if session is None:
        if SESSION is None:
            raise ValueError(
                "Argument `session` and the global `SESSION` cannot "
                "both be `None`. Use the `login` method to set the global, or "
                "explicitly provide a value."
            )
        session = SESSION

    opt_config = {
        labels.ALGORITHM: algorithm,
        labels.HPARAMS: hparams,
    }

    def _new_state_token(post_response: requests.Response) -> StateToken:
        """Extract the state token from a POST response, raising on failure."""
        if post_response.status_code != 200:
            raise requests.RequestException(post_response.text)
        response = json.loads(post_response.text)
        new_state_token: str = response[labels.STATE_TOKEN]
        return new_state_token

    def init_fn(params: PyTree) -> StateToken:
        """Handles 'init' requests."""
        serialized_data = json_utils.json_from_pytree(
            {
                labels.OPT_CONFIG: opt_config,
                labels.DATA: {labels.PARAMS: params},
            }
        )
        # Use the resolved `server_address`. The original always read the
        # global, which silently ignored an explicitly passed address.
        post_response = session.post(
            f"{server_address}/{labels.ROUTE_INIT}/", data=serialized_data
        )
        return _new_state_token(post_response)

    def update_fn(
        *,
        grad: PyTree,
        value: float,
        params: PyTree,
        state: StateToken,
    ) -> StateToken:
        """Handles 'update' requests."""
        state_token = state
        del state
        serialized_data = json_utils.json_from_pytree(
            {
                labels.OPT_CONFIG: opt_config,
                labels.DATA: {
                    labels.PARAMS: params,
                    labels.VALUE: value,
                    labels.GRAD: grad,
                    labels.STATE_TOKEN: state_token,
                },
            }
        )
        post_response = session.post(
            f"{server_address}/{labels.ROUTE_UPDATE}/{state_token}/",
            data=serialized_data,
        )
        return _new_state_token(post_response)

    def params_fn(
        state: StateToken,
        timeout: float = 60.0,
        poll_interval: float = 0.1,
    ) -> PyTree:
        """Handles 'params' requests.

        Polls the server until the params for `state` are available.

        Args:
            state: The state token issued by `init` or `update`.
            timeout: Seconds to wait before giving up.
            poll_interval: Seconds to sleep between "not ready" responses.

        Returns:
            The params, deserialized with `json_utils.pytree_from_json`.

        Raises:
            ValueError: If `timeout` is smaller than `poll_interval`.
            requests.RequestException: On any unexpected server response.
            requests.Timeout: If the params are not ready within `timeout`.
        """
        state_token = state
        del state
        # Explicit check instead of `assert`, which is stripped under `-O`.
        if timeout < poll_interval:
            raise ValueError(
                f"`timeout` ({timeout}) must be at least `poll_interval` "
                f"({poll_interval})."
            )
        url = f"{server_address}/{labels.ROUTE_PARAMS}/{state_token}"
        not_ready_message = labels.MESSAGE_STATE_NOT_READY.format(state_token)
        deadline = time.time() + timeout
        # Always issue at least one request. This avoids an unbound
        # `get_response` when the deadline has already passed (e.g. timeout=0).
        while True:
            get_response = session.get(url)
            if get_response.status_code == 200:
                break
            state_not_ready = (
                get_response.status_code == 404
                and get_response.text.endswith(not_ready_message)
            )
            if not state_not_ready:
                raise requests.RequestException(get_response.text)
            time.sleep(poll_interval)
            if time.time() >= deadline:
                raise requests.Timeout("Timed out while waiting for params.")

        response = json.loads(get_response.text)
        return json_utils.pytree_from_json(response[labels.PARAMS])

    return base.Optimizer(
        init=init_fn,
        update=update_fn,  # type: ignore[arg-type]
        params=params_fn,
    )
|
140
|
+
|
141
|
+
|
142
|
+
# -----------------------------------------------------------------------------
|
143
|
+
# Specific optimizers implemented here.
|
144
|
+
# -----------------------------------------------------------------------------
|
145
|
+
|
146
|
+
|
147
|
+
def lbfgsb(maxcor: int = 20, line_search_max_steps: int = 100) -> base.Optimizer:
    """Return a client optimizer that runs L-BFGS-B on the optimization server.

    The server address and session are taken from the module globals that
    `login` sets.

    Args:
        maxcor: Forwarded to the server as the `maxcor` hyperparameter.
        line_search_max_steps: Forwarded to the server as the
            `line_search_max_steps` hyperparameter.

    Returns:
        The client optimizer.
    """
    return optimizer_client(
        algorithm="lbfgsb",
        hparams=dict(
            maxcor=maxcor,
            line_search_max_steps=line_search_max_steps,
        ),
        server_address=None,
        session=None,
    )
|
@@ -0,0 +1,23 @@
|
|
1
|
+
"""Defines labels and messages used in the context of an optimization service.
|
2
|
+
|
3
|
+
Copyright (c) 2023 The INVRS-IO authors.
|
4
|
+
"""
|
5
|
+
|
6
|
+
# Keys of the request/response payloads.
VALUE = "value"
GRAD = "grad"
PARAMS = "params"
STATE_TOKEN = "state_token"
MESSAGE = "message"

# Keys describing the requested optimizer and the accompanying data.
OPT_CONFIG = "opt_config"
ALGORITHM = "algorithm"
HPARAMS = "hparams"
DATA = "data"

# Path components of the service endpoints.
ROUTE_INIT = "init"
ROUTE_UPDATE = "update"
ROUTE_PARAMS = "params"

# Message templates; the `{}` placeholder is filled with a state token.
MESSAGE_STATE_NOT_KNOWN = "State token {} was not recognized."
MESSAGE_STATE_NOT_READY = "State for token {} is not ready."
MESSAGE_STATE_NOT_VALID = "State for token {} is not valid."
|
File without changes
|
@@ -4,8 +4,13 @@ Copyright (c) 2023 The INVRS-IO authors.
|
|
4
4
|
"""
|
5
5
|
|
6
6
|
import dataclasses
|
7
|
+
import inspect
|
7
8
|
from typing import Any, Protocol
|
8
9
|
|
10
|
+
import jax.numpy as jnp
|
11
|
+
import optax # type: ignore[import-untyped]
|
12
|
+
from totypes import json_utils
|
13
|
+
|
9
14
|
PyTree = Any
|
10
15
|
|
11
16
|
|
@@ -30,7 +35,7 @@ class UpdateFn(Protocol):
|
|
30
35
|
self,
|
31
36
|
*,
|
32
37
|
grad: PyTree,
|
33
|
-
value:
|
38
|
+
value: jnp.ndarray,
|
34
39
|
params: PyTree,
|
35
40
|
state: PyTree,
|
36
41
|
) -> PyTree:
|
@@ -44,3 +49,13 @@ class Optimizer:
|
|
44
49
|
init: InitFn
|
45
50
|
params: ParamsFn
|
46
51
|
update: UpdateFn
|
52
|
+
|
53
|
+
|
54
|
+
# Register all optax state types for serialization.
#
# `inspect.getmembers` can return the same class under several attribute
# names, so the classes are de-duplicated in discovery order before each one
# is registered once. A dict (values all `True`) serves as an ordered set.
# A generator expression is used so the loop variables do not leak into the
# module namespace as public `name`/`obj` attributes.
optax_types = dict.fromkeys(
    (
        obj
        for name, obj in inspect.getmembers(optax)
        if name.endswith("State") and isinstance(obj, type)
    ),
    True,
)

for _optax_state_type in optax_types:
    json_utils.register_custom_type(_optax_state_type)
|