invrs-opt 0.3.2__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invrs_opt/__init__.py +1 -1
- invrs_opt/experimental/__init__.py +0 -0
- invrs_opt/experimental/client.py +152 -0
- invrs_opt/experimental/labels.py +23 -0
- invrs_opt/lbfgsb/lbfgsb.py +4 -2
- {invrs_opt-0.3.2.dist-info → invrs_opt-0.4.0.dist-info}/METADATA +4 -2
- invrs_opt-0.4.0.dist-info/RECORD +14 -0
- invrs_opt-0.3.2.dist-info/RECORD +0 -11
- {invrs_opt-0.3.2.dist-info → invrs_opt-0.4.0.dist-info}/LICENSE +0 -0
- {invrs_opt-0.3.2.dist-info → invrs_opt-0.4.0.dist-info}/WHEEL +0 -0
- {invrs_opt-0.3.2.dist-info → invrs_opt-0.4.0.dist-info}/top_level.txt +0 -0
invrs_opt/__init__.py
CHANGED
File without changes
|
@@ -0,0 +1,152 @@
|
|
1
|
+
"""Defines basic client optimizers for use with an optimization service.
|
2
|
+
|
3
|
+
Copyright (c) 2023 The INVRS-IO authors.
|
4
|
+
"""
|
5
|
+
|
6
|
+
import json
|
7
|
+
import requests
|
8
|
+
import time
|
9
|
+
from typing import Any, Dict, Optional
|
10
|
+
|
11
|
+
from totypes import json_utils
|
12
|
+
|
13
|
+
from invrs_opt import base
|
14
|
+
from invrs_opt.experimental import labels
|
15
|
+
|
16
|
+
|
17
|
+
# Alias for an arbitrary pytree (e.g. parameters or gradients) sent to the service.
PyTree = Any
# Token returned by the service (under `labels.STATE_TOKEN`) that identifies
# optimizer state held server-side; the client never inspects its contents.
StateToken = str

# Module-level connection settings. `login` populates both; `optimizer_client`
# falls back to them when it is not given an explicit session / address.
SESSION: Optional[requests.Session] = None
SERVER_ADDRESS: Optional[str] = None
|
22
|
+
|
23
|
+
|
24
|
+
def login(server_address: str) -> None:
    """Configure the module-level connection used by client optimizers.

    Creates a fresh `requests.Session` and stores it, together with the given
    server address, in the module globals `SESSION` and `SERVER_ADDRESS`.
    Optimizers created afterwards with `session=None` / `server_address=None`
    use these values.

    Args:
        server_address: Base address of the optimization service.
    """
    global SESSION, SERVER_ADDRESS
    # Create the session first so that neither global is touched if session
    # construction fails.
    new_session = requests.Session()
    SESSION, SERVER_ADDRESS = new_session, server_address
|
30
|
+
|
31
|
+
|
32
|
+
def optimizer_client(
    algorithm: str,
    hparams: Dict[str, Any],
    server_address: Optional[str],
    session: Optional[requests.Session],
) -> base.Optimizer:
    """Generic optimizer whose state is held by a remote optimization service.

    The returned optimizer does no numerical work locally. Each of its `init`,
    `update` and `params` functions serializes its inputs with
    `json_utils.json_from_pytree` and sends them over HTTP. The "state" handed
    back to the caller is only a token identifying server-side state.

    Args:
        algorithm: Name of the algorithm the server should run (e.g. "lbfgsb").
        hparams: Hyperparameters forwarded to the server with every request.
        server_address: Base address of the service. If `None`, the global
            `SERVER_ADDRESS` (set by `login`) is used.
        session: Session used for all requests. If `None`, the global
            `SESSION` (set by `login`) is used.

    Returns:
        A `base.Optimizer` whose functions talk to the service.

    Raises:
        ValueError: If `server_address` or `session` is `None` and the
            corresponding global has not been set.
    """
    if server_address is None:
        if SERVER_ADDRESS is None:
            raise ValueError(
                "Argument `server_address` and the global `SERVER_ADDRESS` cannot "
                "both be `None`. Use the `login` method to set the global, or "
                "explicitly provide a value."
            )
        server_address = SERVER_ADDRESS
    if session is None:
        if SESSION is None:
            raise ValueError(
                "Argument `session` and the global `SESSION` cannot "
                "both be `None`. Use the `login` method to set the global, or "
                "explicitly provide a value."
            )
        session = SESSION

    # The resolved address and session are captured by the closures below, so
    # an explicitly supplied `server_address` takes precedence over the global
    # and every request of this optimizer goes through the same connection.
    address = server_address
    http = session

    opt_config = {
        labels.ALGORITHM: algorithm,
        labels.HPARAMS: hparams,
    }

    def _decode_ok_response(response: requests.Response) -> Any:
        """Return the JSON-decoded body of `response`, raising on non-200 status."""
        if response.status_code != 200:
            raise requests.RequestException(response.text)
        return json.loads(response.text)

    def init_fn(params: PyTree) -> StateToken:
        """Handles 'init' requests.

        Sends the initial `params` and returns the token of the new state.

        Raises:
            requests.RequestException: If the server does not respond with 200.
        """
        serialized_data = json_utils.json_from_pytree(
            {
                labels.OPT_CONFIG: opt_config,
                labels.DATA: {labels.PARAMS: params},
            }
        )
        post_response = http.post(
            f"{address}/{labels.ROUTE_INIT}/", data=serialized_data
        )
        response = _decode_ok_response(post_response)
        new_state_token: str = response[labels.STATE_TOKEN]
        return new_state_token

    def update_fn(
        *,
        grad: PyTree,
        value: float,
        params: PyTree,
        state: StateToken,
    ) -> StateToken:
        """Handles 'update' requests.

        Sends the objective `value`, its `grad`, and `params` for the state
        identified by `state`, and returns the token of the updated state.

        Raises:
            requests.RequestException: If the server does not respond with 200.
        """
        state_token = state
        del state
        serialized_data = json_utils.json_from_pytree(
            {
                labels.OPT_CONFIG: opt_config,
                labels.DATA: {
                    labels.PARAMS: params,
                    labels.VALUE: value,
                    labels.GRAD: grad,
                    labels.STATE_TOKEN: state_token,
                },
            }
        )
        post_response = http.post(
            f"{address}/{labels.ROUTE_UPDATE}/{state_token}/",
            data=serialized_data,
        )
        response = _decode_ok_response(post_response)
        new_state_token: str = response[labels.STATE_TOKEN]
        return new_state_token

    def params_fn(
        state: StateToken,
        timeout: float = 60.0,
        poll_interval: float = 0.1,
    ) -> PyTree:
        """Handles 'params' requests.

        Polls the server until the parameters for `state` are available. A 404
        response whose body ends with the "state not ready" message means
        "try again after `poll_interval` seconds".

        Raises:
            ValueError: If `timeout` is smaller than `poll_interval`.
            requests.RequestException: On any other non-200 response.
            requests.Timeout: If the state is still not ready after `timeout`.
        """
        state_token = state
        del state
        # Explicit check rather than `assert`, which is stripped under `-O`.
        if timeout < poll_interval:
            raise ValueError(
                f"`timeout` must be at least `poll_interval`, but got "
                f"timeout={timeout} and poll_interval={poll_interval}."
            )
        not_ready_message = labels.MESSAGE_STATE_NOT_READY.format(state_token)
        # Monotonic clock: immune to system clock adjustments.
        deadline = time.monotonic() + timeout
        # Always issue at least one request, even for a zero timeout.
        while True:
            get_response = http.get(
                f"{address}/{labels.ROUTE_PARAMS}/{state_token}"
            )
            if get_response.status_code == 200:
                break
            is_not_ready = get_response.status_code == 404 and (
                get_response.text.endswith(not_ready_message)
            )
            if not is_not_ready:
                raise requests.RequestException(get_response.text)
            if time.monotonic() >= deadline:
                raise requests.Timeout("Timed out while waiting for params.")
            time.sleep(poll_interval)

        response = json.loads(get_response.text)
        return json_utils.pytree_from_json(response[labels.PARAMS])

    return base.Optimizer(init=init_fn, update=update_fn, params=params_fn)
|
137
|
+
|
138
|
+
|
139
|
+
# -----------------------------------------------------------------------------
|
140
|
+
# Specific optimizers implemented here.
|
141
|
+
# -----------------------------------------------------------------------------
|
142
|
+
|
143
|
+
|
144
|
+
def lbfgsb(maxcor: int = 20, line_search_max_steps: int = 100) -> base.Optimizer:
    """Optimizer implementing the L-BFGS-B scheme on the optimization service.

    The optimizer uses the global session and server address configured by
    `login` (both connection arguments are passed to `optimizer_client` as
    `None`).

    Args:
        maxcor: Forwarded to the server as the "maxcor" hyperparameter.
        line_search_max_steps: Forwarded to the server as the
            "line_search_max_steps" hyperparameter.

    Returns:
        The client `base.Optimizer` built by `optimizer_client`.
    """
    lbfgsb_hparams = dict(
        maxcor=maxcor,
        line_search_max_steps=line_search_max_steps,
    )
    return optimizer_client(
        algorithm="lbfgsb",
        hparams=lbfgsb_hparams,
        server_address=None,
        session=None,
    )
|
@@ -0,0 +1,23 @@
|
|
1
|
+
"""Defines labels and messages used in the context of an optimization service.
|
2
|
+
|
3
|
+
Copyright (c) 2023 The INVRS-IO authors.
|
4
|
+
"""
|
5
|
+
|
6
|
+
# Keys used inside the `data` payload of requests and in JSON responses.
VALUE = "value"
GRAD = "grad"
PARAMS = "params"
STATE_TOKEN = "state_token"
# NOTE(review): not referenced by the client; presumably used by the server.
MESSAGE = "message"

# Top-level request keys (`OPT_CONFIG`, `DATA`) and the keys of the
# optimizer configuration (`ALGORITHM`, `HPARAMS`).
OPT_CONFIG = "opt_config"
ALGORITHM = "algorithm"
HPARAMS = "hparams"
DATA = "data"

# URL path segments of the service endpoints.
ROUTE_INIT = "init"
ROUTE_UPDATE = "update"
ROUTE_PARAMS = "params"

# Message templates; `{}` is filled with a state token. The client compares
# the end of 404 response bodies against MESSAGE_STATE_NOT_READY to decide
# whether to keep polling, so the server must emit it verbatim.
MESSAGE_STATE_NOT_KNOWN = "State token {} was not recognized."
MESSAGE_STATE_NOT_READY = "State for token {} is not ready."
MESSAGE_STATE_NOT_VALID = "State for token {} is not valid."
|
invrs_opt/lbfgsb/lbfgsb.py
CHANGED
@@ -261,10 +261,12 @@ def transformed_lbfgsb(
|
|
261
261
|
flat_latent_params = jnp.asarray(scipy_lbfgsb_state.x)
|
262
262
|
return flat_latent_params, scipy_lbfgsb_state.to_jax()
|
263
263
|
|
264
|
-
|
264
|
+
_, latent_params, jax_lbfgsb_state = state
|
265
265
|
_, vjp_fn = jax.vjp(transform_fn, latent_params)
|
266
266
|
(latent_grad,) = vjp_fn(grad)
|
267
|
-
flat_latent_grad, unflatten_fn = flatten_util.ravel_pytree(
|
267
|
+
flat_latent_grad, unflatten_fn = flatten_util.ravel_pytree(
|
268
|
+
latent_grad
|
269
|
+
) # type: ignore[no-untyped-call]
|
268
270
|
|
269
271
|
(
|
270
272
|
flat_latent_params,
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: invrs_opt
|
3
|
-
Version: 0.3.2
|
3
|
+
Version: 0.4.0
|
4
4
|
Summary: Algorithms for inverse design
|
5
5
|
Author-email: "Martin F. Schubert" <mfschubert@gmail.com>
|
6
6
|
Maintainer-email: "Martin F. Schubert" <mfschubert@gmail.com>
|
@@ -33,8 +33,10 @@ License-File: LICENSE
|
|
33
33
|
Requires-Dist: jax
|
34
34
|
Requires-Dist: jaxlib
|
35
35
|
Requires-Dist: numpy
|
36
|
+
Requires-Dist: requests
|
36
37
|
Requires-Dist: scipy
|
37
38
|
Requires-Dist: totypes
|
39
|
+
Requires-Dist: types-requests
|
38
40
|
Provides-Extra: dev
|
39
41
|
Requires-Dist: bump-my-version ; extra == 'dev'
|
40
42
|
Requires-Dist: darglint ; extra == 'dev'
|
@@ -47,7 +49,7 @@ Requires-Dist: pytest-cov ; extra == 'tests'
|
|
47
49
|
Requires-Dist: pytest-subtests ; extra == 'tests'
|
48
50
|
|
49
51
|
# invrs-opt - Optimization algorithms for inverse design
|
50
|
-
`v0.3.2`
|
52
|
+
`v0.4.0`
|
51
53
|
|
52
54
|
## Overview
|
53
55
|
|
@@ -0,0 +1,14 @@
|
|
1
|
+
invrs_opt/__init__.py,sha256=sKLkaXTzj4zJf3pOcXI1uiHRM15tEEr2BhSL4RWQEns,309
|
2
|
+
invrs_opt/base.py,sha256=dSX9QkMPzI8ROxy2cFNmMwk_89eQbk0rw94CzvLPQoY,907
|
3
|
+
invrs_opt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
4
|
+
invrs_opt/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
5
|
+
invrs_opt/experimental/client.py,sha256=td5o_YqqbcSypDrWCVrHGSoF8UxEdOLtKU0z9Dth9lA,4842
|
6
|
+
invrs_opt/experimental/labels.py,sha256=dQDAMPyFMV6mXnMy295z8Ap205DRdVzysXny_Be8FmY,562
|
7
|
+
invrs_opt/lbfgsb/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
8
|
+
invrs_opt/lbfgsb/lbfgsb.py,sha256=YEBM7XcKj65QpEwO5Y8Mgmjud-h8k-1lF6UsEFgu6sM,25130
|
9
|
+
invrs_opt/lbfgsb/transform.py,sha256=a_Saj9Wq4lvqCJBrg5L2Z9DZ2NVs1xqrHLqha90a9Ws,5971
|
10
|
+
invrs_opt-0.4.0.dist-info/LICENSE,sha256=ty6jHPvpyjHy6dbhnu6aDSY05bbl2jQTjnq9u1sBCfM,1078
|
11
|
+
invrs_opt-0.4.0.dist-info/METADATA,sha256=abuCz0t6ZBbkxwL23I-Ef1m5IKMVJveOjvEl25SmfEw,3326
|
12
|
+
invrs_opt-0.4.0.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
|
13
|
+
invrs_opt-0.4.0.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
|
14
|
+
invrs_opt-0.4.0.dist-info/RECORD,,
|
invrs_opt-0.3.2.dist-info/RECORD
DELETED
@@ -1,11 +0,0 @@
|
|
1
|
-
invrs_opt/__init__.py,sha256=JG9h93Vq0AA2jX6zcqwxjA-LXOws8CJQeEbY9VQvQWQ,309
|
2
|
-
invrs_opt/base.py,sha256=dSX9QkMPzI8ROxy2cFNmMwk_89eQbk0rw94CzvLPQoY,907
|
3
|
-
invrs_opt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
4
|
-
invrs_opt/lbfgsb/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
5
|
-
invrs_opt/lbfgsb/lbfgsb.py,sha256=2X0GVQwMCj2beYHLiAdwoAnBIdJLPThH0Jjduz3GzHA,25080
|
6
|
-
invrs_opt/lbfgsb/transform.py,sha256=a_Saj9Wq4lvqCJBrg5L2Z9DZ2NVs1xqrHLqha90a9Ws,5971
|
7
|
-
invrs_opt-0.3.2.dist-info/LICENSE,sha256=ty6jHPvpyjHy6dbhnu6aDSY05bbl2jQTjnq9u1sBCfM,1078
|
8
|
-
invrs_opt-0.3.2.dist-info/METADATA,sha256=pwrRENMlb_s1bNAeROw-cJ-nIVuqR9815r-jPNM3cEg,3272
|
9
|
-
invrs_opt-0.3.2.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
|
10
|
-
invrs_opt-0.3.2.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
|
11
|
-
invrs_opt-0.3.2.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|