invrs-opt 0.3.2__py3-none-any.whl → 0.10.3__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- invrs_opt/__init__.py +14 -3
- invrs_opt/experimental/client.py +155 -0
- invrs_opt/experimental/labels.py +23 -0
- invrs_opt/optimizers/__init__.py +0 -0
- invrs_opt/{base.py → optimizers/base.py} +16 -1
- invrs_opt/optimizers/lbfgsb.py +939 -0
- invrs_opt/optimizers/wrapped_optax.py +347 -0
- invrs_opt/parameterization/__init__.py +0 -0
- invrs_opt/parameterization/base.py +208 -0
- invrs_opt/parameterization/filter_project.py +138 -0
- invrs_opt/parameterization/gaussian_levelset.py +671 -0
- invrs_opt/parameterization/pixel.py +75 -0
- invrs_opt/{lbfgsb/transform.py → parameterization/transforms.py} +76 -11
- invrs_opt-0.10.3.dist-info/LICENSE +504 -0
- invrs_opt-0.10.3.dist-info/METADATA +560 -0
- invrs_opt-0.10.3.dist-info/RECORD +20 -0
- {invrs_opt-0.3.2.dist-info → invrs_opt-0.10.3.dist-info}/WHEEL +1 -1
- invrs_opt/lbfgsb/lbfgsb.py +0 -670
- invrs_opt-0.3.2.dist-info/LICENSE +0 -21
- invrs_opt-0.3.2.dist-info/METADATA +0 -73
- invrs_opt-0.3.2.dist-info/RECORD +0 -11
- /invrs_opt/{lbfgsb → experimental}/__init__.py +0 -0
- {invrs_opt-0.3.2.dist-info → invrs_opt-0.10.3.dist-info}/top_level.txt +0 -0
invrs_opt/__init__.py
CHANGED
@@ -3,8 +3,19 @@
|
|
3
3
|
Copyright (c) 2023 The INVRS-IO authors.
|
4
4
|
"""
|
5
5
|
|
6
|
-
__version__ = "v0.3
|
6
|
+
__version__ = "v0.10.3"
|
7
7
|
__author__ = "Martin F. Schubert <mfschubert@gmail.com>"
|
8
8
|
|
9
|
-
from invrs_opt
|
10
|
-
|
9
|
+
from invrs_opt import parameterization as parameterization
|
10
|
+
|
11
|
+
from invrs_opt.optimizers.lbfgsb import (
|
12
|
+
density_lbfgsb as density_lbfgsb,
|
13
|
+
lbfgsb as lbfgsb,
|
14
|
+
levelset_lbfgsb as levelset_lbfgsb,
|
15
|
+
)
|
16
|
+
|
17
|
+
from invrs_opt.optimizers.wrapped_optax import (
|
18
|
+
density_wrapped_optax as density_wrapped_optax,
|
19
|
+
levelset_wrapped_optax as levelset_wrapped_optax,
|
20
|
+
wrapped_optax as wrapped_optax,
|
21
|
+
)
|
@@ -0,0 +1,155 @@
|
|
1
|
+
"""Defines basic client optimizers for use with an optimization service.
|
2
|
+
|
3
|
+
Copyright (c) 2023 The INVRS-IO authors.
|
4
|
+
"""
|
5
|
+
|
6
|
+
import json
|
7
|
+
import time
|
8
|
+
from typing import Any, Dict, Optional
|
9
|
+
|
10
|
+
import requests
|
11
|
+
from totypes import json_utils
|
12
|
+
|
13
|
+
from invrs_opt.experimental import labels
|
14
|
+
from invrs_opt.optimizers import base
|
15
|
+
|
16
|
+
# Aliases used in the type annotations of this module.
StateToken = str
PyTree = Any

# Module-wide defaults, both assigned by `login`. `optimizer_client` consults
# these when its `server_address` / `session` arguments are `None`.
SERVER_ADDRESS = None
SESSION = None
|
21
|
+
|
22
|
+
|
23
|
+
def login(server_address: str) -> None:
    """Configures the module-wide defaults used by `optimizer_client`.

    A brand-new `requests.Session` is created and stored in the global
    `SESSION`, and `server_address` is stored in the global `SERVER_ADDRESS`.

    Args:
        server_address: Base URL of the optimization server.
    """
    global SERVER_ADDRESS, SESSION
    # The session is created before the address is stored, so a failure in
    # session construction leaves `SERVER_ADDRESS` untouched.
    SESSION = requests.Session()
    SERVER_ADDRESS = server_address
|
29
|
+
|
30
|
+
|
31
|
+
def optimizer_client(
    algorithm: str,
    hparams: Dict[str, Any],
    server_address: Optional[str],
    session: Optional[requests.Session],
) -> base.Optimizer:
    """Creates an optimizer whose steps are computed by a remote service.

    Inputs are serialized with `json_utils.json_from_pytree` and sent to the
    server. The optimizer "state" seen by the caller is only the state token
    returned by the server.

    Args:
        algorithm: Name of the server-side algorithm, sent as `labels.ALGORITHM`.
        hparams: Algorithm hyperparameters, sent as `labels.HPARAMS`.
        server_address: Base URL of the server. If `None`, the global
            `SERVER_ADDRESS` (set by `login`) is used.
        session: Session used for all requests. If `None`, the global
            `SESSION` (set by `login`) is used.

    Returns:
        A `base.Optimizer` whose `init`, `update` and `params` functions issue
        requests to the server.

    Raises:
        ValueError: If an argument is `None` and the corresponding global has
            not been set.
    """
    if server_address is None:
        if SERVER_ADDRESS is None:
            raise ValueError(
                "Argument `server_address` and the global `SERVER_ADDRESS` cannot "
                "both be `None`. Use the `login` method to set the global, or "
                "explicitly provide a value."
            )
        # Resolve the address here so that an explicitly provided address is
        # actually used by the request functions below.
        server_address = SERVER_ADDRESS
    if session is None:
        if SESSION is None:
            raise ValueError(
                "Argument `session` and the global `SESSION` cannot "
                "both be `None`. Use the `login` method to set the global, or "
                "explicitly provide a value."
            )
        session = SESSION

    opt_config = {
        labels.ALGORITHM: algorithm,
        labels.HPARAMS: hparams,
    }

    def _post_for_state_token(url: str, payload: Dict[str, Any]) -> StateToken:
        """POSTs the JSON-serialized `payload` to `url` and returns the new token.

        Raises:
            requests.RequestException: If the server does not reply with 200.
        """
        post_response = session.post(url, data=json_utils.json_from_pytree(payload))
        if post_response.status_code != 200:
            raise requests.RequestException(post_response.text)
        response = json.loads(post_response.text)
        new_state_token: str = response[labels.STATE_TOKEN]
        return new_state_token

    def init_fn(params: PyTree) -> StateToken:
        """Handles 'init' requests, returning the token of the initial state."""
        return _post_for_state_token(
            f"{server_address}/{labels.ROUTE_INIT}/",
            {labels.OPT_CONFIG: opt_config, labels.DATA: {labels.PARAMS: params}},
        )

    def update_fn(
        *,
        grad: PyTree,
        value: float,
        params: PyTree,
        state: StateToken,
    ) -> StateToken:
        """Handles 'update' requests, returning the token of the new state."""
        state_token = state
        del state
        return _post_for_state_token(
            f"{server_address}/{labels.ROUTE_UPDATE}/{state_token}/",
            {
                labels.OPT_CONFIG: opt_config,
                labels.DATA: {
                    labels.PARAMS: params,
                    labels.VALUE: value,
                    labels.GRAD: grad,
                    labels.STATE_TOKEN: state_token,
                },
            },
        )

    def params_fn(
        state: StateToken,
        timeout: float = 60.0,
        poll_interval: float = 0.1,
    ) -> PyTree:
        """Handles 'params' requests, polling until the server state is ready.

        Args:
            state: Token identifying the server-side optimizer state.
            timeout: Maximum time in seconds to wait for the state.
            poll_interval: Time in seconds to sleep between polls.

        Returns:
            The params deserialized with `json_utils.pytree_from_json`.

        Raises:
            ValueError: If `timeout` is smaller than `poll_interval`.
            requests.RequestException: On any response other than 200 or a
                404 reporting that the state is not yet ready.
            requests.Timeout: If the state is not ready within `timeout`.
        """
        state_token = state
        del state
        if timeout < poll_interval:
            raise ValueError(
                f"`timeout` must be at least `poll_interval`, but got "
                f"{timeout} and {poll_interval}."
            )
        url = f"{server_address}/{labels.ROUTE_PARAMS}/{state_token}"
        not_ready_message = labels.MESSAGE_STATE_NOT_READY.format(state_token)
        # A monotonic clock keeps the timeout immune to wall-clock adjustments.
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            get_response = session.get(url)
            if get_response.status_code == 200:
                response = json.loads(get_response.text)
                return json_utils.pytree_from_json(response[labels.PARAMS])
            if get_response.status_code == 404 and get_response.text.endswith(
                not_ready_message
            ):
                time.sleep(poll_interval)
            else:
                raise requests.RequestException(get_response.text)
        # Reached when polling never succeeded, including when the loop body
        # never ran at all (e.g. a zero timeout).
        raise requests.Timeout("Timed out while waiting for params.")

    return base.Optimizer(
        init=init_fn,
        update=update_fn,  # type: ignore[arg-type]
        params=params_fn,
    )
|
140
|
+
|
141
|
+
|
142
|
+
# -----------------------------------------------------------------------------
|
143
|
+
# Specific optimizers implemented here.
|
144
|
+
# -----------------------------------------------------------------------------
|
145
|
+
|
146
|
+
|
147
|
+
def lbfgsb(maxcor: int = 20, line_search_max_steps: int = 100) -> base.Optimizer:
    """Returns a client optimizer that delegates L-BFGS-B steps to the server.

    Args:
        maxcor: Forwarded to the server as the `"maxcor"` hyperparameter.
        line_search_max_steps: Forwarded as `"line_search_max_steps"`.

    Returns:
        The optimizer built by `optimizer_client` for the `"lbfgsb"` algorithm,
        relying on the module-level server address and session set by `login`.
    """
    return optimizer_client(
        algorithm="lbfgsb",
        hparams=dict(maxcor=maxcor, line_search_max_steps=line_search_max_steps),
        server_address=None,
        session=None,
    )
|
@@ -0,0 +1,23 @@
|
|
1
|
+
"""Defines labels and messages used in the context of an optimization service.
|
2
|
+
|
3
|
+
Copyright (c) 2023 The INVRS-IO authors.
|
4
|
+
"""
|
5
|
+
|
6
|
+
# Top-level keys of a request payload and of the optimizer configuration.
OPT_CONFIG = "opt_config"
ALGORITHM = "algorithm"
HPARAMS = "hparams"
DATA = "data"

# Keys for the quantities exchanged with the server.
PARAMS = "params"
VALUE = "value"
GRAD = "grad"
STATE_TOKEN = "state_token"
MESSAGE = "message"

# URL path segments of the service endpoints.
ROUTE_INIT = "init"
ROUTE_UPDATE = "update"
ROUTE_PARAMS = "params"

# Message templates; the `{}` placeholder is filled with a state token via
# `str.format`.
MESSAGE_STATE_NOT_KNOWN = "State token {} was not recognized."
MESSAGE_STATE_NOT_READY = "State for token {} is not ready."
MESSAGE_STATE_NOT_VALID = "State for token {} is not valid."
|
File without changes
|
@@ -4,8 +4,13 @@ Copyright (c) 2023 The INVRS-IO authors.
|
|
4
4
|
"""
|
5
5
|
|
6
6
|
import dataclasses
|
7
|
+
import inspect
|
7
8
|
from typing import Any, Protocol
|
8
9
|
|
10
|
+
import jax.numpy as jnp
|
11
|
+
import optax # type: ignore[import-untyped]
|
12
|
+
from totypes import json_utils
|
13
|
+
|
9
14
|
PyTree = Any
|
10
15
|
|
11
16
|
|
@@ -30,7 +35,7 @@ class UpdateFn(Protocol):
|
|
30
35
|
self,
|
31
36
|
*,
|
32
37
|
grad: PyTree,
|
33
|
-
value:
|
38
|
+
value: jnp.ndarray,
|
34
39
|
params: PyTree,
|
35
40
|
state: PyTree,
|
36
41
|
) -> PyTree:
|
@@ -44,3 +49,13 @@ class Optimizer:
|
|
44
49
|
init: InitFn
|
45
50
|
params: ParamsFn
|
46
51
|
update: UpdateFn
|
52
|
+
|
53
|
+
|
54
|
+
# Register every class exported by `optax` whose name ends in "State" with
# `json_utils`, so that optimizer states containing them can be serialized.
# `optax_types` maps each class to `True`, acting as an insertion-ordered set;
# presumably this guards against one class being exported under several names.
# The comprehension keeps its iteration variables out of the module namespace
# (the previous `for` loop leaked public `name`/`obj` globals).
optax_types = {
    member: True
    for member_name, member in inspect.getmembers(optax)
    if member_name.endswith("State") and isinstance(member, type)
}

for _optax_type in optax_types:
    json_utils.register_custom_type(_optax_type)
|