invrs-opt 0.3.1__py3-none-any.whl → 0.4.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
invrs_opt/__init__.py CHANGED
@@ -3,7 +3,7 @@
3
3
  Copyright (c) 2023 The INVRS-IO authors.
4
4
  """
5
5
 
6
- __version__ = "v0.3.1"
6
+ __version__ = "v0.4.0"
7
7
  __author__ = "Martin F. Schubert <mfschubert@gmail.com>"
8
8
 
9
9
  from invrs_opt.lbfgsb.lbfgsb import density_lbfgsb as density_lbfgsb
File without changes
@@ -0,0 +1,152 @@
1
+ """Defines basic client optimizers for use with an optimization service.
2
+
3
+ Copyright (c) 2023 The INVRS-IO authors.
4
+ """
5
+
6
+ import json
7
+ import requests
8
+ import time
9
+ from typing import Any, Dict, Optional
10
+
11
+ from totypes import json_utils
12
+
13
+ from invrs_opt import base
14
+ from invrs_opt.experimental import labels
15
+
16
+
17
# Alias for pytree-valued arguments (nested containers of arrays/scalars);
# typed as `Any` because pytrees have no precise static type.
PyTree = Any
# Opaque string token that the optimization service returns to identify
# optimizer state held on the server side.
StateToken = str

# Module-level connection defaults, assigned by `login()` and read by
# `optimizer_client`. Annotated explicitly so type checkers do not infer the
# type `None` and reject the reassignment inside `login()`. The session
# annotation is a string so it is not evaluated at import time.
SESSION: "Optional[requests.Session]" = None
SERVER_ADDRESS: Optional[str] = None
22
+
23
+
24
def login(server_address: str) -> None:
    """Configure the module-level connection used by client optimizers.

    A brand-new `requests.Session` is stored in the global `SESSION`, and
    `server_address` is stored in the global `SERVER_ADDRESS`.

    Args:
        server_address: Base URL of the optimization service.
    """
    global SESSION, SERVER_ADDRESS
    new_session = requests.Session()
    SESSION = new_session
    SERVER_ADDRESS = server_address
30
+
31
+
32
def optimizer_client(
    algorithm: str,
    hparams: Dict[str, Any],
    server_address: Optional[str],
    session: Optional[requests.Session],
) -> base.Optimizer:
    """Return an optimizer whose computations are delegated to a remote service.

    The returned optimizer keeps no optimizer state locally.

    - `init` and `update` POST their inputs to the service and return the
      state token the service replies with.
    - `params` polls the service for the parameters associated with a token.

    Args:
        algorithm: Name of the optimization algorithm, sent to the service
            under `labels.ALGORITHM`.
        hparams: Algorithm hyperparameters, sent under `labels.HPARAMS`.
        server_address: Base URL of the service. If `None`, the global
            `SERVER_ADDRESS` set by `login` is used.
        session: Session used for all requests. If `None`, the global
            `SESSION` created by `login` is used.

    Returns:
        A `base.Optimizer` whose `init`, `update` and `params` functions talk
        to the service.

    Raises:
        ValueError: If `server_address` (or `session`) is `None` and the
            corresponding global has not been set via `login`.
    """
    if server_address is None:
        if SERVER_ADDRESS is None:
            raise ValueError(
                "Argument `server_address` and the global `SERVER_ADDRESS` cannot "
                "both be `None`. Use the `login` method to set the global, or "
                "explicitly provide a value."
            )
        server_address = SERVER_ADDRESS
    if session is None:
        if SESSION is None:
            raise ValueError(
                "Argument `session` and the global `SESSION` cannot "
                "both be `None`. Use the `login` method to set the global, or "
                "explicitly provide a value."
            )
        session = SESSION

    # Bind the resolved, non-optional values to fresh names. The closures below
    # therefore always use the address/session chosen here, including an
    # explicitly passed `server_address`. Type checkers also see them as
    # non-`None`.
    address: str = server_address
    client: requests.Session = session

    opt_config = {
        labels.ALGORITHM: algorithm,
        labels.HPARAMS: hparams,
    }

    def _post_for_state_token(url: str, data: Dict[str, Any]) -> StateToken:
        """POST `opt_config` plus `data` as JSON to `url`; return the new token."""
        serialized_data = json_utils.json_from_pytree(
            {labels.OPT_CONFIG: opt_config, labels.DATA: data}
        )
        post_response = client.post(url, data=serialized_data)
        if post_response.status_code != 200:
            raise requests.RequestException(post_response.text)
        response = json.loads(post_response.text)
        new_state_token: StateToken = response[labels.STATE_TOKEN]
        return new_state_token

    def init_fn(params: PyTree) -> StateToken:
        """Handles 'init' requests: sends `params`, returns a state token."""
        return _post_for_state_token(
            f"{address}/{labels.ROUTE_INIT}/",
            {labels.PARAMS: params},
        )

    def update_fn(
        *,
        grad: PyTree,
        value: float,
        params: PyTree,
        state: StateToken,
    ) -> StateToken:
        """Handles 'update' requests: sends the step data, returns a new token."""
        state_token = state
        del state
        return _post_for_state_token(
            f"{address}/{labels.ROUTE_UPDATE}/{state_token}/",
            {
                labels.PARAMS: params,
                labels.VALUE: value,
                labels.GRAD: grad,
                labels.STATE_TOKEN: state_token,
            },
        )

    def params_fn(
        state: StateToken,
        timeout: float = 60.0,
        poll_interval: float = 0.1,
    ) -> PyTree:
        """Handles 'params' requests by polling until the state is ready.

        Args:
            state: State token whose parameters are requested.
            timeout: Maximum time, in seconds, to keep polling.
            poll_interval: Time, in seconds, to sleep between polls while the
                service reports the state as not ready.

        Returns:
            The parameters decoded from the service's JSON response.

        Raises:
            ValueError: If `timeout` is smaller than `poll_interval`.
            requests.RequestException: On any unexpected server response.
            requests.Timeout: If the state is not ready before `timeout`.
        """
        state_token = state
        del state
        # Explicit check rather than `assert`, which is stripped under `-O`.
        if timeout < poll_interval:
            raise ValueError(
                f"`timeout` ({timeout}) must be at least `poll_interval` "
                f"({poll_interval})."
            )
        url = f"{address}/{labels.ROUTE_PARAMS}/{state_token}"
        not_ready_message = labels.MESSAGE_STATE_NOT_READY.format(state_token)
        deadline = time.time() + timeout
        while time.time() < deadline:
            get_response = client.get(url)
            if get_response.status_code == 200:
                response = json.loads(get_response.text)
                return json_utils.pytree_from_json(response[labels.PARAMS])
            if get_response.status_code == 404 and get_response.text.endswith(
                not_ready_message
            ):
                time.sleep(poll_interval)
                continue
            raise requests.RequestException(get_response.text)
        # Reached both when polling ran out of time and when the loop never
        # executed, so no unbound response variable is ever read.
        raise requests.Timeout("Timed out while waiting for params.")

    return base.Optimizer(init=init_fn, update=update_fn, params=params_fn)
137
+
138
+
139
+ # -----------------------------------------------------------------------------
140
+ # Specific optimizers implemented here.
141
+ # -----------------------------------------------------------------------------
142
+
143
+
144
def lbfgsb(maxcor: int = 20, line_search_max_steps: int = 100) -> base.Optimizer:
    """Return a client optimizer that requests the "lbfgsb" algorithm.

    Both hyperparameters are forwarded verbatim to the optimization service.
    Connection details come from the globals configured by `login`.

    Args:
        maxcor: Sent as the "maxcor" hyperparameter (presumably the number of
            stored L-BFGS correction pairs, as in SciPy; confirm on server).
        line_search_max_steps: Sent as the "line_search_max_steps"
            hyperparameter.

    Returns:
        The client optimizer produced by `optimizer_client`.
    """
    algorithm_hparams = dict(
        maxcor=maxcor,
        line_search_max_steps=line_search_max_steps,
    )
    return optimizer_client(
        algorithm="lbfgsb",
        hparams=algorithm_hparams,
        server_address=None,
        session=None,
    )
@@ -0,0 +1,23 @@
1
+ """Defines labels and messages used in the context of an optimization service.
2
+
3
+ Copyright (c) 2023 The INVRS-IO authors.
4
+ """
5
+
6
# Field names for values carried in request/response payloads.
VALUE: str = "value"
GRAD: str = "grad"
PARAMS: str = "params"
STATE_TOKEN: str = "state_token"
MESSAGE: str = "message"

# Field names for the optimizer configuration and the data envelope.
OPT_CONFIG: str = "opt_config"
ALGORITHM: str = "algorithm"
HPARAMS: str = "hparams"
DATA: str = "data"

# URL path segments naming the service endpoints.
ROUTE_INIT: str = "init"
ROUTE_UPDATE: str = "update"
ROUTE_PARAMS: str = "params"

# Message templates; fill the `{}` placeholder with a state token via
# `str.format`.
MESSAGE_STATE_NOT_KNOWN: str = "State token {} was not recognized."
MESSAGE_STATE_NOT_READY: str = "State for token {} is not ready."
MESSAGE_STATE_NOT_VALID: str = "State for token {} is not valid."
@@ -248,32 +248,37 @@ def transformed_lbfgsb(
248
248
  del params
249
249
 
250
250
  def _update_pure(
251
- latent_grad: PyTree,
251
+ flat_latent_grad: PyTree,
252
252
  value: jnp.ndarray,
253
253
  jax_lbfgsb_state: JaxLbfgsbDict,
254
254
  ) -> Tuple[PyTree, JaxLbfgsbDict]:
255
255
  assert onp.size(value) == 1
256
256
  scipy_lbfgsb_state = ScipyLbfgsbState.from_jax(jax_lbfgsb_state)
257
257
  scipy_lbfgsb_state.update(
258
- grad=_to_numpy(latent_grad), value=onp.asarray(value)
258
+ grad=onp.asarray(flat_latent_grad, dtype=onp.float64),
259
+ value=onp.asarray(value, dtype=onp.float64),
259
260
  )
260
- latent_params = _to_pytree(scipy_lbfgsb_state.x, latent_grad)
261
- return latent_params, scipy_lbfgsb_state.to_jax()
261
+ flat_latent_params = jnp.asarray(scipy_lbfgsb_state.x)
262
+ return flat_latent_params, scipy_lbfgsb_state.to_jax()
262
263
 
263
- params, latent_params, jax_lbfgsb_state = state
264
+ _, latent_params, jax_lbfgsb_state = state
264
265
  _, vjp_fn = jax.vjp(transform_fn, latent_params)
265
266
  (latent_grad,) = vjp_fn(grad)
267
+ flat_latent_grad, unflatten_fn = flatten_util.ravel_pytree(
268
+ latent_grad
269
+ ) # type: ignore[no-untyped-call]
266
270
 
267
271
  (
268
- latent_params,
272
+ flat_latent_params,
269
273
  jax_lbfgsb_state,
270
274
  ) = jax.pure_callback( # type: ignore[attr-defined]
271
275
  _update_pure,
272
- (latent_params, jax_lbfgsb_state),
273
- latent_grad,
276
+ (flat_latent_grad, jax_lbfgsb_state),
277
+ flat_latent_grad,
274
278
  value,
275
279
  jax_lbfgsb_state,
276
280
  )
281
+ latent_params = unflatten_fn(flat_latent_params)
277
282
  return transform_fn(latent_params), latent_params, jax_lbfgsb_state
278
283
 
279
284
  return base.Optimizer(
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: invrs_opt
3
- Version: 0.3.1
3
+ Version: 0.4.0
4
4
  Summary: Algorithms for inverse design
5
5
  Author-email: "Martin F. Schubert" <mfschubert@gmail.com>
6
6
  Maintainer-email: "Martin F. Schubert" <mfschubert@gmail.com>
@@ -33,8 +33,10 @@ License-File: LICENSE
33
33
  Requires-Dist: jax
34
34
  Requires-Dist: jaxlib
35
35
  Requires-Dist: numpy
36
+ Requires-Dist: requests
36
37
  Requires-Dist: scipy
37
38
  Requires-Dist: totypes
39
+ Requires-Dist: types-requests
38
40
  Provides-Extra: dev
39
41
  Requires-Dist: bump-my-version ; extra == 'dev'
40
42
  Requires-Dist: darglint ; extra == 'dev'
@@ -47,7 +49,7 @@ Requires-Dist: pytest-cov ; extra == 'tests'
47
49
  Requires-Dist: pytest-subtests ; extra == 'tests'
48
50
 
49
51
  # invrs-opt - Optimization algorithms for inverse design
50
- `v0.3.1`
52
+ `v0.4.0`
51
53
 
52
54
  ## Overview
53
55
 
@@ -0,0 +1,14 @@
1
+ invrs_opt/__init__.py,sha256=sKLkaXTzj4zJf3pOcXI1uiHRM15tEEr2BhSL4RWQEns,309
2
+ invrs_opt/base.py,sha256=dSX9QkMPzI8ROxy2cFNmMwk_89eQbk0rw94CzvLPQoY,907
3
+ invrs_opt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
+ invrs_opt/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
+ invrs_opt/experimental/client.py,sha256=td5o_YqqbcSypDrWCVrHGSoF8UxEdOLtKU0z9Dth9lA,4842
6
+ invrs_opt/experimental/labels.py,sha256=dQDAMPyFMV6mXnMy295z8Ap205DRdVzysXny_Be8FmY,562
7
+ invrs_opt/lbfgsb/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
8
+ invrs_opt/lbfgsb/lbfgsb.py,sha256=YEBM7XcKj65QpEwO5Y8Mgmjud-h8k-1lF6UsEFgu6sM,25130
9
+ invrs_opt/lbfgsb/transform.py,sha256=a_Saj9Wq4lvqCJBrg5L2Z9DZ2NVs1xqrHLqha90a9Ws,5971
10
+ invrs_opt-0.4.0.dist-info/LICENSE,sha256=ty6jHPvpyjHy6dbhnu6aDSY05bbl2jQTjnq9u1sBCfM,1078
11
+ invrs_opt-0.4.0.dist-info/METADATA,sha256=abuCz0t6ZBbkxwL23I-Ef1m5IKMVJveOjvEl25SmfEw,3326
12
+ invrs_opt-0.4.0.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
13
+ invrs_opt-0.4.0.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
14
+ invrs_opt-0.4.0.dist-info/RECORD,,
@@ -1,11 +0,0 @@
1
- invrs_opt/__init__.py,sha256=F8FYw0x4opoCTILqxc9maMG08H6pl70KmGAuYTELLlM,309
2
- invrs_opt/base.py,sha256=dSX9QkMPzI8ROxy2cFNmMwk_89eQbk0rw94CzvLPQoY,907
3
- invrs_opt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
- invrs_opt/lbfgsb/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
- invrs_opt/lbfgsb/lbfgsb.py,sha256=BZoLGUFjZEXGupVVLkbwd5Pa8nRL4_HtEmU7SyOqwvw,24865
6
- invrs_opt/lbfgsb/transform.py,sha256=a_Saj9Wq4lvqCJBrg5L2Z9DZ2NVs1xqrHLqha90a9Ws,5971
7
- invrs_opt-0.3.1.dist-info/LICENSE,sha256=ty6jHPvpyjHy6dbhnu6aDSY05bbl2jQTjnq9u1sBCfM,1078
8
- invrs_opt-0.3.1.dist-info/METADATA,sha256=A30Gz7fLSuuUcAl7x72f-Uu6Zuv8faD7O2AqVVNs67E,3272
9
- invrs_opt-0.3.1.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
10
- invrs_opt-0.3.1.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
11
- invrs_opt-0.3.1.dist-info/RECORD,,