invrs-opt 0.3.1__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
invrs_opt/__init__.py CHANGED
@@ -3,7 +3,7 @@
3
3
  Copyright (c) 2023 The INVRS-IO authors.
4
4
  """
5
5
 
6
- __version__ = "v0.3.1"
6
+ __version__ = "v0.4.0"
7
7
  __author__ = "Martin F. Schubert <mfschubert@gmail.com>"
8
8
 
9
9
  from invrs_opt.lbfgsb.lbfgsb import density_lbfgsb as density_lbfgsb
File without changes
@@ -0,0 +1,152 @@
1
+ """Defines basic client optimizers for use with an optimization service.
2
+
3
+ Copyright (c) 2023 The INVRS-IO authors.
4
+ """
5
+
6
+ import json
7
+ import requests
8
+ import time
9
+ from typing import Any, Dict, Optional
10
+
11
+ from totypes import json_utils
12
+
13
+ from invrs_opt import base
14
+ from invrs_opt.experimental import labels
15
+
16
+
17
+ PyTree = Any
18
+ StateToken = str
19
+
20
+ SESSION = None
21
+ SERVER_ADDRESS = None
22
+
23
+
24
+ def login(server_address: str) -> None:
25
+ """Set the global server address and create a requests session."""
26
+ global SESSION
27
+ global SERVER_ADDRESS
28
+ SESSION = requests.Session()
29
+ SERVER_ADDRESS = server_address
30
+
31
+
32
+ def optimizer_client(
33
+ algorithm: str,
34
+ hparams: Dict[str, Any],
35
+ server_address: Optional[str],
36
+ session: Optional[requests.Session],
37
+ ) -> base.Optimizer:
38
+ """Generic optimizer class."""
39
+
40
+ if server_address is None:
41
+ if SERVER_ADDRESS is None:
42
+ raise ValueError(
43
+ "Argument `server_address` and the global `SERVER_ADDRESS` cannot "
44
+ "both be `None`. Use the `login` method to set the global, or "
45
+ "explicitly provide a value."
46
+ )
47
+ if session is None:
48
+ if SESSION is None:
49
+ raise ValueError(
50
+ "Argument `session` and the global `SESSION` cannot "
51
+ "both be `None`. Use the `login` method to set the global, or "
52
+ "explicitly provide a value."
53
+ )
54
+ session = SESSION
55
+
56
+ opt_config = {
57
+ labels.ALGORITHM: algorithm,
58
+ labels.HPARAMS: hparams,
59
+ }
60
+
61
+ def init_fn(params: PyTree) -> StateToken:
62
+ """Handles 'init' requests."""
63
+ serialized_data = json_utils.json_from_pytree(
64
+ dict(opt_config=opt_config, data={"params": params})
65
+ )
66
+ post_response = session.post(
67
+ f"{SERVER_ADDRESS}/{labels.ROUTE_INIT}/", data=serialized_data
68
+ )
69
+
70
+ if not post_response.status_code == 200:
71
+ raise requests.RequestException(post_response.text)
72
+ response = json.loads(post_response.text)
73
+ new_state_token: str = response[labels.STATE_TOKEN]
74
+ return new_state_token
75
+
76
+ def update_fn(
77
+ *,
78
+ grad: PyTree,
79
+ value: float,
80
+ params: PyTree,
81
+ state: StateToken,
82
+ ) -> StateToken:
83
+ """Handles 'update' requests."""
84
+ state_token = state
85
+ del state
86
+ serialized_data = json_utils.json_from_pytree(
87
+ {
88
+ labels.OPT_CONFIG: opt_config,
89
+ labels.DATA: {
90
+ labels.PARAMS: params,
91
+ labels.VALUE: value,
92
+ labels.GRAD: grad,
93
+ labels.STATE_TOKEN: state_token,
94
+ },
95
+ }
96
+ )
97
+ post_response = session.post(
98
+ f"{SERVER_ADDRESS}/{labels.ROUTE_UPDATE}/{state_token}/",
99
+ data=serialized_data,
100
+ )
101
+
102
+ if not post_response.status_code == 200:
103
+ raise requests.RequestException(post_response.text)
104
+ response = json.loads(post_response.text)
105
+ new_state_token: str = response[labels.STATE_TOKEN]
106
+ return new_state_token
107
+
108
+ def params_fn(
109
+ state: StateToken,
110
+ timeout: float = 60.0,
111
+ poll_interval: float = 0.1,
112
+ ) -> PyTree:
113
+ """Handles 'params' requests."""
114
+ state_token = state
115
+ del state
116
+ assert timeout >= poll_interval
117
+ start_time = time.time()
118
+ while time.time() < start_time + timeout:
119
+ get_response = session.get(
120
+ f"{SERVER_ADDRESS}/{labels.ROUTE_PARAMS}/{state_token}"
121
+ )
122
+ if get_response.status_code == 200:
123
+ break
124
+ elif get_response.status_code == 404 and get_response.text.endswith(
125
+ labels.MESSAGE_STATE_NOT_READY.format(state_token)
126
+ ):
127
+ time.sleep(poll_interval)
128
+ else:
129
+ raise requests.RequestException(get_response.text)
130
+
131
+ if not get_response.status_code == 200:
132
+ raise requests.Timeout("Timed out while waiting for params.")
133
+ response = json.loads(get_response.text)
134
+ return json_utils.pytree_from_json(response[labels.PARAMS])
135
+
136
+ return base.Optimizer(init=init_fn, update=update_fn, params=params_fn)
137
+
138
+
139
+ # -----------------------------------------------------------------------------
140
+ # Specific optimizers implemented here.
141
+ # -----------------------------------------------------------------------------
142
+
143
+
144
+ def lbfgsb(maxcor: int = 20, line_search_max_steps: int = 100) -> base.Optimizer:
145
+ """Optimizer implementing the L-BFGS-B scheme."""
146
+ hparams = {
147
+ "maxcor": maxcor,
148
+ "line_search_max_steps": line_search_max_steps,
149
+ }
150
+ return optimizer_client(
151
+ algorithm="lbfgsb", hparams=hparams, server_address=None, session=None
152
+ )
@@ -0,0 +1,23 @@
1
+ """Defines labels and messages used in the context of an optimization service.
2
+
3
+ Copyright (c) 2023 The INVRS-IO authors.
4
+ """
5
+
6
+ VALUE = "value"
7
+ GRAD = "grad"
8
+ PARAMS = "params"
9
+ STATE_TOKEN = "state_token"
10
+ MESSAGE = "message"
11
+
12
+ OPT_CONFIG = "opt_config"
13
+ ALGORITHM = "algorithm"
14
+ HPARAMS = "hparams"
15
+ DATA = "data"
16
+
17
+ ROUTE_INIT = "init"
18
+ ROUTE_UPDATE = "update"
19
+ ROUTE_PARAMS = "params"
20
+
21
+ MESSAGE_STATE_NOT_KNOWN = "State token {} was not recognized."
22
+ MESSAGE_STATE_NOT_READY = "State for token {} is not ready."
23
+ MESSAGE_STATE_NOT_VALID = "State for token {} is not valid."
@@ -248,32 +248,37 @@ def transformed_lbfgsb(
248
248
  del params
249
249
 
250
250
  def _update_pure(
251
- latent_grad: PyTree,
251
+ flat_latent_grad: PyTree,
252
252
  value: jnp.ndarray,
253
253
  jax_lbfgsb_state: JaxLbfgsbDict,
254
254
  ) -> Tuple[PyTree, JaxLbfgsbDict]:
255
255
  assert onp.size(value) == 1
256
256
  scipy_lbfgsb_state = ScipyLbfgsbState.from_jax(jax_lbfgsb_state)
257
257
  scipy_lbfgsb_state.update(
258
- grad=_to_numpy(latent_grad), value=onp.asarray(value)
258
+ grad=onp.asarray(flat_latent_grad, dtype=onp.float64),
259
+ value=onp.asarray(value, dtype=onp.float64),
259
260
  )
260
- latent_params = _to_pytree(scipy_lbfgsb_state.x, latent_grad)
261
- return latent_params, scipy_lbfgsb_state.to_jax()
261
+ flat_latent_params = jnp.asarray(scipy_lbfgsb_state.x)
262
+ return flat_latent_params, scipy_lbfgsb_state.to_jax()
262
263
 
263
- params, latent_params, jax_lbfgsb_state = state
264
+ _, latent_params, jax_lbfgsb_state = state
264
265
  _, vjp_fn = jax.vjp(transform_fn, latent_params)
265
266
  (latent_grad,) = vjp_fn(grad)
267
+ flat_latent_grad, unflatten_fn = flatten_util.ravel_pytree(
268
+ latent_grad
269
+ ) # type: ignore[no-untyped-call]
266
270
 
267
271
  (
268
- latent_params,
272
+ flat_latent_params,
269
273
  jax_lbfgsb_state,
270
274
  ) = jax.pure_callback( # type: ignore[attr-defined]
271
275
  _update_pure,
272
- (latent_params, jax_lbfgsb_state),
273
- latent_grad,
276
+ (flat_latent_grad, jax_lbfgsb_state),
277
+ flat_latent_grad,
274
278
  value,
275
279
  jax_lbfgsb_state,
276
280
  )
281
+ latent_params = unflatten_fn(flat_latent_params)
277
282
  return transform_fn(latent_params), latent_params, jax_lbfgsb_state
278
283
 
279
284
  return base.Optimizer(
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: invrs_opt
3
- Version: 0.3.1
3
+ Version: 0.4.0
4
4
  Summary: Algorithms for inverse design
5
5
  Author-email: "Martin F. Schubert" <mfschubert@gmail.com>
6
6
  Maintainer-email: "Martin F. Schubert" <mfschubert@gmail.com>
@@ -33,8 +33,10 @@ License-File: LICENSE
33
33
  Requires-Dist: jax
34
34
  Requires-Dist: jaxlib
35
35
  Requires-Dist: numpy
36
+ Requires-Dist: requests
36
37
  Requires-Dist: scipy
37
38
  Requires-Dist: totypes
39
+ Requires-Dist: types-requests
38
40
  Provides-Extra: dev
39
41
  Requires-Dist: bump-my-version ; extra == 'dev'
40
42
  Requires-Dist: darglint ; extra == 'dev'
@@ -47,7 +49,7 @@ Requires-Dist: pytest-cov ; extra == 'tests'
47
49
  Requires-Dist: pytest-subtests ; extra == 'tests'
48
50
 
49
51
  # invrs-opt - Optimization algorithms for inverse design
50
- `v0.3.1`
52
+ `v0.4.0`
51
53
 
52
54
  ## Overview
53
55
 
@@ -0,0 +1,14 @@
1
+ invrs_opt/__init__.py,sha256=sKLkaXTzj4zJf3pOcXI1uiHRM15tEEr2BhSL4RWQEns,309
2
+ invrs_opt/base.py,sha256=dSX9QkMPzI8ROxy2cFNmMwk_89eQbk0rw94CzvLPQoY,907
3
+ invrs_opt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
+ invrs_opt/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
+ invrs_opt/experimental/client.py,sha256=td5o_YqqbcSypDrWCVrHGSoF8UxEdOLtKU0z9Dth9lA,4842
6
+ invrs_opt/experimental/labels.py,sha256=dQDAMPyFMV6mXnMy295z8Ap205DRdVzysXny_Be8FmY,562
7
+ invrs_opt/lbfgsb/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
8
+ invrs_opt/lbfgsb/lbfgsb.py,sha256=YEBM7XcKj65QpEwO5Y8Mgmjud-h8k-1lF6UsEFgu6sM,25130
9
+ invrs_opt/lbfgsb/transform.py,sha256=a_Saj9Wq4lvqCJBrg5L2Z9DZ2NVs1xqrHLqha90a9Ws,5971
10
+ invrs_opt-0.4.0.dist-info/LICENSE,sha256=ty6jHPvpyjHy6dbhnu6aDSY05bbl2jQTjnq9u1sBCfM,1078
11
+ invrs_opt-0.4.0.dist-info/METADATA,sha256=abuCz0t6ZBbkxwL23I-Ef1m5IKMVJveOjvEl25SmfEw,3326
12
+ invrs_opt-0.4.0.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
13
+ invrs_opt-0.4.0.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
14
+ invrs_opt-0.4.0.dist-info/RECORD,,
@@ -1,11 +0,0 @@
1
- invrs_opt/__init__.py,sha256=F8FYw0x4opoCTILqxc9maMG08H6pl70KmGAuYTELLlM,309
2
- invrs_opt/base.py,sha256=dSX9QkMPzI8ROxy2cFNmMwk_89eQbk0rw94CzvLPQoY,907
3
- invrs_opt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
- invrs_opt/lbfgsb/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
- invrs_opt/lbfgsb/lbfgsb.py,sha256=BZoLGUFjZEXGupVVLkbwd5Pa8nRL4_HtEmU7SyOqwvw,24865
6
- invrs_opt/lbfgsb/transform.py,sha256=a_Saj9Wq4lvqCJBrg5L2Z9DZ2NVs1xqrHLqha90a9Ws,5971
7
- invrs_opt-0.3.1.dist-info/LICENSE,sha256=ty6jHPvpyjHy6dbhnu6aDSY05bbl2jQTjnq9u1sBCfM,1078
8
- invrs_opt-0.3.1.dist-info/METADATA,sha256=A30Gz7fLSuuUcAl7x72f-Uu6Zuv8faD7O2AqVVNs67E,3272
9
- invrs_opt-0.3.1.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
10
- invrs_opt-0.3.1.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
11
- invrs_opt-0.3.1.dist-info/RECORD,,