invrs-opt 0.3.1__py3-none-any.whl → 0.4.0__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- invrs_opt/__init__.py +1 -1
- invrs_opt/experimental/__init__.py +0 -0
- invrs_opt/experimental/client.py +152 -0
- invrs_opt/experimental/labels.py +23 -0
- invrs_opt/lbfgsb/lbfgsb.py +13 -8
- {invrs_opt-0.3.1.dist-info → invrs_opt-0.4.0.dist-info}/METADATA +4 -2
- invrs_opt-0.4.0.dist-info/RECORD +14 -0
- invrs_opt-0.3.1.dist-info/RECORD +0 -11
- {invrs_opt-0.3.1.dist-info → invrs_opt-0.4.0.dist-info}/LICENSE +0 -0
- {invrs_opt-0.3.1.dist-info → invrs_opt-0.4.0.dist-info}/WHEEL +0 -0
- {invrs_opt-0.3.1.dist-info → invrs_opt-0.4.0.dist-info}/top_level.txt +0 -0
invrs_opt/__init__.py
CHANGED
File without changes
|
@@ -0,0 +1,152 @@
|
|
1
|
+
"""Defines basic client optimizers for use with an optimization service.
|
2
|
+
|
3
|
+
Copyright (c) 2023 The INVRS-IO authors.
|
4
|
+
"""
|
5
|
+
|
6
|
+
import json
|
7
|
+
import requests
|
8
|
+
import time
|
9
|
+
from typing import Any, Dict, Optional
|
10
|
+
|
11
|
+
from totypes import json_utils
|
12
|
+
|
13
|
+
from invrs_opt import base
|
14
|
+
from invrs_opt.experimental import labels
|
15
|
+
|
16
|
+
|
17
|
+
PyTree = Any
|
18
|
+
StateToken = str
|
19
|
+
|
20
|
+
SESSION = None
|
21
|
+
SERVER_ADDRESS = None
|
22
|
+
|
23
|
+
|
24
|
+
def login(server_address: str) -> None:
    """Record the service address and open a shared `requests` session.

    Both values are stored in the module-level globals `SERVER_ADDRESS` and
    `SESSION`. Calling this function again overwrites both globals.

    Args:
        server_address: Address of the optimization service.
    """
    global SESSION, SERVER_ADDRESS
    SERVER_ADDRESS = server_address
    SESSION = requests.Session()
|
30
|
+
|
31
|
+
|
32
|
+
def optimizer_client(
    algorithm: str,
    hparams: Dict[str, Any],
    server_address: Optional[str],
    session: Optional[requests.Session],
) -> base.Optimizer:
    """Build an optimizer that delegates its work to an optimization service.

    The returned optimizer's `init`, `update` and `params` functions issue HTTP
    requests to the service. `init` and `update` return the state token found
    in the service response; `params` polls the service until the parameters
    for a given state token are available.

    Args:
        algorithm: Name of the algorithm, sent to the service in the optimizer
            configuration.
        hparams: Hyperparameters, sent to the service in the optimizer
            configuration.
        server_address: Address of the service. If `None`, the global
            `SERVER_ADDRESS` (set by `login`) is used.
        session: Session used for all requests. If `None`, the global `SESSION`
            (set by `login`) is used.

    Returns:
        The `base.Optimizer` built from the three request functions.

    Raises:
        ValueError: If `server_address` (or `session`) is `None` and the
            corresponding global has not been set.
    """

    if server_address is None:
        if SERVER_ADDRESS is None:
            raise ValueError(
                "Argument `server_address` and the global `SERVER_ADDRESS` cannot "
                "both be `None`. Use the `login` method to set the global, or "
                "explicitly provide a value."
            )
        server_address = SERVER_ADDRESS
    if session is None:
        if SESSION is None:
            raise ValueError(
                "Argument `session` and the global `SESSION` cannot "
                "both be `None`. Use the `login` method to set the global, or "
                "explicitly provide a value."
            )
        session = SESSION

    # Bind the resolved values to dedicated names: the closures below must use
    # the address selected here (an explicit argument takes precedence over the
    # global), and never the raw global.
    address: str = server_address
    http: requests.Session = session

    opt_config = {
        labels.ALGORITHM: algorithm,
        labels.HPARAMS: hparams,
    }

    def _post_for_state_token(url: str, serialized_data: Any) -> StateToken:
        """Posts `serialized_data` to `url` and returns the new state token."""
        post_response = http.post(url, data=serialized_data)
        if post_response.status_code != 200:
            raise requests.RequestException(post_response.text)
        response = json.loads(post_response.text)
        new_state_token: str = response[labels.STATE_TOKEN]
        return new_state_token

    def init_fn(params: PyTree) -> StateToken:
        """Handles 'init' requests."""
        serialized_data = json_utils.json_from_pytree(
            {
                labels.OPT_CONFIG: opt_config,
                labels.DATA: {labels.PARAMS: params},
            }
        )
        return _post_for_state_token(
            f"{address}/{labels.ROUTE_INIT}/", serialized_data
        )

    def update_fn(
        *,
        grad: PyTree,
        value: float,
        params: PyTree,
        state: StateToken,
    ) -> StateToken:
        """Handles 'update' requests."""
        state_token = state
        del state
        serialized_data = json_utils.json_from_pytree(
            {
                labels.OPT_CONFIG: opt_config,
                labels.DATA: {
                    labels.PARAMS: params,
                    labels.VALUE: value,
                    labels.GRAD: grad,
                    labels.STATE_TOKEN: state_token,
                },
            }
        )
        return _post_for_state_token(
            f"{address}/{labels.ROUTE_UPDATE}/{state_token}/", serialized_data
        )

    def params_fn(
        state: StateToken,
        timeout: float = 60.0,
        poll_interval: float = 0.1,
    ) -> PyTree:
        """Handles 'params' requests.

        Polls the service every `poll_interval` seconds while it answers 404
        with the "not ready" message, for at most `timeout` seconds.

        Raises:
            requests.RequestException: If the service returns any other
                non-200 response.
            requests.Timeout: If no successful response arrives in time.
        """
        state_token = state
        del state
        assert timeout >= poll_interval

        url = f"{address}/{labels.ROUTE_PARAMS}/{state_token}"
        not_ready_message = labels.MESSAGE_STATE_NOT_READY.format(state_token)

        # Stays `None` if the loop body never runs (e.g. a zero timeout), so
        # that case is reported as a timeout rather than an unbound-name error.
        get_response = None
        start_time = time.time()
        while time.time() < start_time + timeout:
            get_response = http.get(url)
            if get_response.status_code == 200:
                break
            if get_response.status_code == 404 and get_response.text.endswith(
                not_ready_message
            ):
                time.sleep(poll_interval)
                continue
            raise requests.RequestException(get_response.text)

        if get_response is None or get_response.status_code != 200:
            raise requests.Timeout("Timed out while waiting for params.")
        response = json.loads(get_response.text)
        return json_utils.pytree_from_json(response[labels.PARAMS])

    return base.Optimizer(init=init_fn, update=update_fn, params=params_fn)
|
137
|
+
|
138
|
+
|
139
|
+
# -----------------------------------------------------------------------------
|
140
|
+
# Specific optimizers implemented here.
|
141
|
+
# -----------------------------------------------------------------------------
|
142
|
+
|
143
|
+
|
144
|
+
def lbfgsb(maxcor: int = 20, line_search_max_steps: int = 100) -> base.Optimizer:
    """Return a client optimizer that requests the "lbfgsb" algorithm.

    The server address and session are not passed explicitly, so
    `optimizer_client` falls back to the module-level globals set by `login`.

    Args:
        maxcor: Forwarded to the service as the `maxcor` hyperparameter.
        line_search_max_steps: Forwarded to the service as the
            `line_search_max_steps` hyperparameter.

    Returns:
        The optimizer produced by `optimizer_client`.
    """
    return optimizer_client(
        algorithm="lbfgsb",
        hparams={
            "maxcor": maxcor,
            "line_search_max_steps": line_search_max_steps,
        },
        server_address=None,
        session=None,
    )
|
@@ -0,0 +1,23 @@
|
|
1
|
+
"""Defines labels and messages used in the context of an optimization service.
|
2
|
+
|
3
|
+
Copyright (c) 2023 The INVRS-IO authors.
|
4
|
+
"""
|
5
|
+
|
6
|
+
# Keys of the data exchanged with the service.
VALUE = "value"
GRAD = "grad"
PARAMS = "params"
STATE_TOKEN = "state_token"
MESSAGE = "message"

# Keys of the request payload and of the optimizer configuration.
OPT_CONFIG = "opt_config"
ALGORITHM = "algorithm"
HPARAMS = "hparams"
DATA = "data"

# Route names used to build request URLs.
ROUTE_INIT = "init"
ROUTE_UPDATE = "update"
ROUTE_PARAMS = "params"

# Message templates; `{}` is filled with the state token via `str.format`.
MESSAGE_STATE_NOT_KNOWN = "State token {} was not recognized."
MESSAGE_STATE_NOT_READY = "State for token {} is not ready."
MESSAGE_STATE_NOT_VALID = "State for token {} is not valid."
|
invrs_opt/lbfgsb/lbfgsb.py
CHANGED
@@ -248,32 +248,37 @@ def transformed_lbfgsb(
|
|
248
248
|
del params
|
249
249
|
|
250
250
|
def _update_pure(
|
251
|
-
|
251
|
+
flat_latent_grad: PyTree,
|
252
252
|
value: jnp.ndarray,
|
253
253
|
jax_lbfgsb_state: JaxLbfgsbDict,
|
254
254
|
) -> Tuple[PyTree, JaxLbfgsbDict]:
|
255
255
|
assert onp.size(value) == 1
|
256
256
|
scipy_lbfgsb_state = ScipyLbfgsbState.from_jax(jax_lbfgsb_state)
|
257
257
|
scipy_lbfgsb_state.update(
|
258
|
-
grad=
|
258
|
+
grad=onp.asarray(flat_latent_grad, dtype=onp.float64),
|
259
|
+
value=onp.asarray(value, dtype=onp.float64),
|
259
260
|
)
|
260
|
-
|
261
|
-
return
|
261
|
+
flat_latent_params = jnp.asarray(scipy_lbfgsb_state.x)
|
262
|
+
return flat_latent_params, scipy_lbfgsb_state.to_jax()
|
262
263
|
|
263
|
-
|
264
|
+
_, latent_params, jax_lbfgsb_state = state
|
264
265
|
_, vjp_fn = jax.vjp(transform_fn, latent_params)
|
265
266
|
(latent_grad,) = vjp_fn(grad)
|
267
|
+
flat_latent_grad, unflatten_fn = flatten_util.ravel_pytree(
|
268
|
+
latent_grad
|
269
|
+
) # type: ignore[no-untyped-call]
|
266
270
|
|
267
271
|
(
|
268
|
-
|
272
|
+
flat_latent_params,
|
269
273
|
jax_lbfgsb_state,
|
270
274
|
) = jax.pure_callback( # type: ignore[attr-defined]
|
271
275
|
_update_pure,
|
272
|
-
(
|
273
|
-
|
276
|
+
(flat_latent_grad, jax_lbfgsb_state),
|
277
|
+
flat_latent_grad,
|
274
278
|
value,
|
275
279
|
jax_lbfgsb_state,
|
276
280
|
)
|
281
|
+
latent_params = unflatten_fn(flat_latent_params)
|
277
282
|
return transform_fn(latent_params), latent_params, jax_lbfgsb_state
|
278
283
|
|
279
284
|
return base.Optimizer(
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: invrs_opt
|
3
|
-
Version: 0.3.1
|
3
|
+
Version: 0.4.0
|
4
4
|
Summary: Algorithms for inverse design
|
5
5
|
Author-email: "Martin F. Schubert" <mfschubert@gmail.com>
|
6
6
|
Maintainer-email: "Martin F. Schubert" <mfschubert@gmail.com>
|
@@ -33,8 +33,10 @@ License-File: LICENSE
|
|
33
33
|
Requires-Dist: jax
|
34
34
|
Requires-Dist: jaxlib
|
35
35
|
Requires-Dist: numpy
|
36
|
+
Requires-Dist: requests
|
36
37
|
Requires-Dist: scipy
|
37
38
|
Requires-Dist: totypes
|
39
|
+
Requires-Dist: types-requests
|
38
40
|
Provides-Extra: dev
|
39
41
|
Requires-Dist: bump-my-version ; extra == 'dev'
|
40
42
|
Requires-Dist: darglint ; extra == 'dev'
|
@@ -47,7 +49,7 @@ Requires-Dist: pytest-cov ; extra == 'tests'
|
|
47
49
|
Requires-Dist: pytest-subtests ; extra == 'tests'
|
48
50
|
|
49
51
|
# invrs-opt - Optimization algorithms for inverse design
|
50
|
-
`v0.3.1`
|
52
|
+
`v0.4.0`
|
51
53
|
|
52
54
|
## Overview
|
53
55
|
|
@@ -0,0 +1,14 @@
|
|
1
|
+
invrs_opt/__init__.py,sha256=sKLkaXTzj4zJf3pOcXI1uiHRM15tEEr2BhSL4RWQEns,309
|
2
|
+
invrs_opt/base.py,sha256=dSX9QkMPzI8ROxy2cFNmMwk_89eQbk0rw94CzvLPQoY,907
|
3
|
+
invrs_opt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
4
|
+
invrs_opt/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
5
|
+
invrs_opt/experimental/client.py,sha256=td5o_YqqbcSypDrWCVrHGSoF8UxEdOLtKU0z9Dth9lA,4842
|
6
|
+
invrs_opt/experimental/labels.py,sha256=dQDAMPyFMV6mXnMy295z8Ap205DRdVzysXny_Be8FmY,562
|
7
|
+
invrs_opt/lbfgsb/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
8
|
+
invrs_opt/lbfgsb/lbfgsb.py,sha256=YEBM7XcKj65QpEwO5Y8Mgmjud-h8k-1lF6UsEFgu6sM,25130
|
9
|
+
invrs_opt/lbfgsb/transform.py,sha256=a_Saj9Wq4lvqCJBrg5L2Z9DZ2NVs1xqrHLqha90a9Ws,5971
|
10
|
+
invrs_opt-0.4.0.dist-info/LICENSE,sha256=ty6jHPvpyjHy6dbhnu6aDSY05bbl2jQTjnq9u1sBCfM,1078
|
11
|
+
invrs_opt-0.4.0.dist-info/METADATA,sha256=abuCz0t6ZBbkxwL23I-Ef1m5IKMVJveOjvEl25SmfEw,3326
|
12
|
+
invrs_opt-0.4.0.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
|
13
|
+
invrs_opt-0.4.0.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
|
14
|
+
invrs_opt-0.4.0.dist-info/RECORD,,
|
invrs_opt-0.3.1.dist-info/RECORD
DELETED
@@ -1,11 +0,0 @@
|
|
1
|
-
invrs_opt/__init__.py,sha256=F8FYw0x4opoCTILqxc9maMG08H6pl70KmGAuYTELLlM,309
|
2
|
-
invrs_opt/base.py,sha256=dSX9QkMPzI8ROxy2cFNmMwk_89eQbk0rw94CzvLPQoY,907
|
3
|
-
invrs_opt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
4
|
-
invrs_opt/lbfgsb/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
5
|
-
invrs_opt/lbfgsb/lbfgsb.py,sha256=BZoLGUFjZEXGupVVLkbwd5Pa8nRL4_HtEmU7SyOqwvw,24865
|
6
|
-
invrs_opt/lbfgsb/transform.py,sha256=a_Saj9Wq4lvqCJBrg5L2Z9DZ2NVs1xqrHLqha90a9Ws,5971
|
7
|
-
invrs_opt-0.3.1.dist-info/LICENSE,sha256=ty6jHPvpyjHy6dbhnu6aDSY05bbl2jQTjnq9u1sBCfM,1078
|
8
|
-
invrs_opt-0.3.1.dist-info/METADATA,sha256=A30Gz7fLSuuUcAl7x72f-Uu6Zuv8faD7O2AqVVNs67E,3272
|
9
|
-
invrs_opt-0.3.1.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
|
10
|
-
invrs_opt-0.3.1.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
|
11
|
-
invrs_opt-0.3.1.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|