invrs-opt 0.3.2__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
invrs_opt/__init__.py CHANGED
@@ -3,7 +3,7 @@
3
3
  Copyright (c) 2023 The INVRS-IO authors.
4
4
  """
5
5
 
6
- __version__ = "v0.3.2"
6
+ __version__ = "v0.4.0"
7
7
  __author__ = "Martin F. Schubert <mfschubert@gmail.com>"
8
8
 
9
9
  from invrs_opt.lbfgsb.lbfgsb import density_lbfgsb as density_lbfgsb
File without changes
@@ -0,0 +1,152 @@
1
+ """Defines basic client optimizers for use with an optimization service.
2
+
3
+ Copyright (c) 2023 The INVRS-IO authors.
4
+ """
5
+
6
+ import json
7
+ import requests
8
+ import time
9
+ from typing import Any, Dict, Optional
10
+
11
+ from totypes import json_utils
12
+
13
+ from invrs_opt import base
14
+ from invrs_opt.experimental import labels
15
+
16
+
17
# Type alias for parameter/gradient pytrees; left as `Any` here.
PyTree = Any
# Token issued by the optimization service to identify an optimizer state.
StateToken = str

# Connection state used by clients that do not pass their own; set by `login`.
# Annotated explicitly because these start as `None` but are later assigned a
# session / address, which type checkers cannot infer from a bare `= None`.
SESSION: Optional["requests.Session"] = None
SERVER_ADDRESS: Optional[str] = None
22
+
23
+
24
def login(server_address: str) -> None:
    """Configure the module-wide connection used by optimizer clients.

    A new `requests.Session` is created and stored in the global `SESSION`,
    and `server_address` is stored in the global `SERVER_ADDRESS`. Clients
    created with `server_address=None` / `session=None` fall back to these.

    Args:
        server_address: Base URL of the optimization service.
    """
    global SERVER_ADDRESS, SESSION
    SESSION = requests.Session()
    SERVER_ADDRESS = server_address
30
+
31
+
32
def optimizer_client(
    algorithm: str,
    hparams: Dict[str, Any],
    server_address: Optional[str],
    session: Optional[requests.Session],
) -> base.Optimizer:
    """Return an optimizer whose steps are carried out by a remote service.

    The returned optimizer keeps no optimizer state locally. `init` and
    `update` POST a serialized request to the service and return the state
    token found in its JSON response. `params` polls the service until the
    parameters associated with a token are available.

    Args:
        algorithm: Algorithm name, sent under the `labels.ALGORITHM` key.
        hparams: Algorithm hyperparameters, sent under the `labels.HPARAMS` key.
        server_address: Base URL of the service. If `None`, the global
            `SERVER_ADDRESS` (set by `login`) is used.
        session: Session used to issue HTTP requests. If `None`, the global
            `SESSION` (set by `login`) is used.

    Returns:
        A `base.Optimizer` whose `init`, `update` and `params` functions talk
        to the service.

    Raises:
        ValueError: If `server_address` (resp. `session`) is `None` and the
            corresponding global has not been set.
    """
    if server_address is None:
        if SERVER_ADDRESS is None:
            raise ValueError(
                "Argument `server_address` and the global `SERVER_ADDRESS` cannot "
                "both be `None`. Use the `login` method to set the global, or "
                "explicitly provide a value."
            )
        server_address = SERVER_ADDRESS
    if session is None:
        if SESSION is None:
            raise ValueError(
                "Argument `session` and the global `SESSION` cannot "
                "both be `None`. Use the `login` method to set the global, or "
                "explicitly provide a value."
            )
        session = SESSION

    # Bind the resolved, non-optional values so the closures below use the
    # caller's address (not the global) and see non-`Optional` types.
    address: str = server_address
    active_session: requests.Session = session

    opt_config = {
        labels.ALGORITHM: algorithm,
        labels.HPARAMS: hparams,
    }

    def _state_token_from(response: requests.Response) -> StateToken:
        """Return the state token from a response, raising on non-200 status."""
        if response.status_code != 200:
            raise requests.RequestException(response.text)
        new_state_token: StateToken = json.loads(response.text)[labels.STATE_TOKEN]
        return new_state_token

    def init_fn(params: PyTree) -> StateToken:
        """Ask the service to initialize a state for `params`; return its token.

        Raises:
            requests.RequestException: If the service does not return status 200.
        """
        serialized_data = json_utils.json_from_pytree(
            {
                labels.OPT_CONFIG: opt_config,
                labels.DATA: {labels.PARAMS: params},
            }
        )
        post_response = active_session.post(
            f"{address}/{labels.ROUTE_INIT}/", data=serialized_data
        )
        return _state_token_from(post_response)

    def update_fn(
        *,
        grad: PyTree,
        value: float,
        params: PyTree,
        state: StateToken,
    ) -> StateToken:
        """Send one update step for `state` to the service; return the new token.

        Raises:
            requests.RequestException: If the service does not return status 200.
        """
        state_token = state
        del state
        serialized_data = json_utils.json_from_pytree(
            {
                labels.OPT_CONFIG: opt_config,
                labels.DATA: {
                    labels.PARAMS: params,
                    labels.VALUE: value,
                    labels.GRAD: grad,
                    labels.STATE_TOKEN: state_token,
                },
            }
        )
        post_response = active_session.post(
            f"{address}/{labels.ROUTE_UPDATE}/{state_token}/",
            data=serialized_data,
        )
        return _state_token_from(post_response)

    def params_fn(
        state: StateToken,
        timeout: float = 60.0,
        poll_interval: float = 0.1,
    ) -> PyTree:
        """Poll the service until the params for `state` are ready; return them.

        A 404 response whose text ends with the "state not ready" message is
        treated as "try again after `poll_interval` seconds".

        Raises:
            ValueError: If `timeout` is smaller than `poll_interval`.
            requests.RequestException: On any other non-200 response.
            requests.Timeout: If the params are not ready within `timeout`
                seconds.
        """
        state_token = state
        del state
        # Explicit check instead of `assert`, which is stripped under `-O`.
        if timeout < poll_interval:
            raise ValueError(
                f"`timeout` ({timeout}) must be at least `poll_interval` "
                f"({poll_interval})."
            )
        url = f"{address}/{labels.ROUTE_PARAMS}/{state_token}"
        not_ready_message = labels.MESSAGE_STATE_NOT_READY.format(state_token)
        # `time.monotonic` is unaffected by wall-clock adjustments.
        deadline = time.monotonic() + timeout
        # Initialized so a loop that never runs (e.g. `timeout == 0`) reports a
        # timeout rather than an `UnboundLocalError`.
        get_response: Optional[requests.Response] = None
        while time.monotonic() < deadline:
            get_response = active_session.get(url)
            if get_response.status_code == 200:
                break
            if get_response.status_code == 404 and get_response.text.endswith(
                not_ready_message
            ):
                time.sleep(poll_interval)
            else:
                raise requests.RequestException(get_response.text)

        if get_response is None or get_response.status_code != 200:
            raise requests.Timeout("Timed out while waiting for params.")
        response = json.loads(get_response.text)
        return json_utils.pytree_from_json(response[labels.PARAMS])

    return base.Optimizer(init=init_fn, update=update_fn, params=params_fn)
137
+
138
+
139
+ # -----------------------------------------------------------------------------
140
+ # Specific optimizers implemented here.
141
+ # -----------------------------------------------------------------------------
142
+
143
+
144
def lbfgsb(maxcor: int = 20, line_search_max_steps: int = 100) -> base.Optimizer:
    """Return a client optimizer that requests the "lbfgsb" algorithm.

    No server address or session is passed, so the client relies on the
    module globals configured by `login`.

    Args:
        maxcor: Forwarded unchanged as the `"maxcor"` hyperparameter.
        line_search_max_steps: Forwarded unchanged as the
            `"line_search_max_steps"` hyperparameter.

    Returns:
        The optimizer built by `optimizer_client`.
    """
    return optimizer_client(
        algorithm="lbfgsb",
        hparams=dict(maxcor=maxcor, line_search_max_steps=line_search_max_steps),
        server_address=None,
        session=None,
    )
@@ -0,0 +1,23 @@
1
+ """Defines labels and messages used in the context of an optimization service.
2
+
3
+ Copyright (c) 2023 The INVRS-IO authors.
4
+ """
5
+
6
# Keys used inside request/response payloads exchanged with the service.
VALUE = "value"
GRAD = "grad"
PARAMS = "params"
STATE_TOKEN = "state_token"
MESSAGE = "message"

# Top-level request keys and the optimizer-configuration keys.
OPT_CONFIG = "opt_config"
ALGORITHM = "algorithm"
HPARAMS = "hparams"
DATA = "data"

# Route names appended to the server address to build request URLs.
ROUTE_INIT = "init"
ROUTE_UPDATE = "update"
ROUTE_PARAMS = "params"

# Message templates; the `{}` placeholder is filled with the state token.
MESSAGE_STATE_NOT_KNOWN = "State token {} was not recognized."
MESSAGE_STATE_NOT_READY = "State for token {} is not ready."
MESSAGE_STATE_NOT_VALID = "State for token {} is not valid."
@@ -261,10 +261,12 @@ def transformed_lbfgsb(
261
261
  flat_latent_params = jnp.asarray(scipy_lbfgsb_state.x)
262
262
  return flat_latent_params, scipy_lbfgsb_state.to_jax()
263
263
 
264
- params, latent_params, jax_lbfgsb_state = state
264
+ _, latent_params, jax_lbfgsb_state = state
265
265
  _, vjp_fn = jax.vjp(transform_fn, latent_params)
266
266
  (latent_grad,) = vjp_fn(grad)
267
- flat_latent_grad, unflatten_fn = flatten_util.ravel_pytree(latent_grad)
267
+ flat_latent_grad, unflatten_fn = flatten_util.ravel_pytree(
268
+ latent_grad
269
+ ) # type: ignore[no-untyped-call]
268
270
 
269
271
  (
270
272
  flat_latent_params,
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: invrs_opt
3
- Version: 0.3.2
3
+ Version: 0.4.0
4
4
  Summary: Algorithms for inverse design
5
5
  Author-email: "Martin F. Schubert" <mfschubert@gmail.com>
6
6
  Maintainer-email: "Martin F. Schubert" <mfschubert@gmail.com>
@@ -33,8 +33,10 @@ License-File: LICENSE
33
33
  Requires-Dist: jax
34
34
  Requires-Dist: jaxlib
35
35
  Requires-Dist: numpy
36
+ Requires-Dist: requests
36
37
  Requires-Dist: scipy
37
38
  Requires-Dist: totypes
39
+ Requires-Dist: types-requests
38
40
  Provides-Extra: dev
39
41
  Requires-Dist: bump-my-version ; extra == 'dev'
40
42
  Requires-Dist: darglint ; extra == 'dev'
@@ -47,7 +49,7 @@ Requires-Dist: pytest-cov ; extra == 'tests'
47
49
  Requires-Dist: pytest-subtests ; extra == 'tests'
48
50
 
49
51
  # invrs-opt - Optimization algorithms for inverse design
50
- `v0.3.2`
52
+ `v0.4.0`
51
53
 
52
54
  ## Overview
53
55
 
@@ -0,0 +1,14 @@
1
+ invrs_opt/__init__.py,sha256=sKLkaXTzj4zJf3pOcXI1uiHRM15tEEr2BhSL4RWQEns,309
2
+ invrs_opt/base.py,sha256=dSX9QkMPzI8ROxy2cFNmMwk_89eQbk0rw94CzvLPQoY,907
3
+ invrs_opt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
+ invrs_opt/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
+ invrs_opt/experimental/client.py,sha256=td5o_YqqbcSypDrWCVrHGSoF8UxEdOLtKU0z9Dth9lA,4842
6
+ invrs_opt/experimental/labels.py,sha256=dQDAMPyFMV6mXnMy295z8Ap205DRdVzysXny_Be8FmY,562
7
+ invrs_opt/lbfgsb/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
8
+ invrs_opt/lbfgsb/lbfgsb.py,sha256=YEBM7XcKj65QpEwO5Y8Mgmjud-h8k-1lF6UsEFgu6sM,25130
9
+ invrs_opt/lbfgsb/transform.py,sha256=a_Saj9Wq4lvqCJBrg5L2Z9DZ2NVs1xqrHLqha90a9Ws,5971
10
+ invrs_opt-0.4.0.dist-info/LICENSE,sha256=ty6jHPvpyjHy6dbhnu6aDSY05bbl2jQTjnq9u1sBCfM,1078
11
+ invrs_opt-0.4.0.dist-info/METADATA,sha256=abuCz0t6ZBbkxwL23I-Ef1m5IKMVJveOjvEl25SmfEw,3326
12
+ invrs_opt-0.4.0.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
13
+ invrs_opt-0.4.0.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
14
+ invrs_opt-0.4.0.dist-info/RECORD,,
@@ -1,11 +0,0 @@
1
- invrs_opt/__init__.py,sha256=JG9h93Vq0AA2jX6zcqwxjA-LXOws8CJQeEbY9VQvQWQ,309
2
- invrs_opt/base.py,sha256=dSX9QkMPzI8ROxy2cFNmMwk_89eQbk0rw94CzvLPQoY,907
3
- invrs_opt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
- invrs_opt/lbfgsb/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
- invrs_opt/lbfgsb/lbfgsb.py,sha256=2X0GVQwMCj2beYHLiAdwoAnBIdJLPThH0Jjduz3GzHA,25080
6
- invrs_opt/lbfgsb/transform.py,sha256=a_Saj9Wq4lvqCJBrg5L2Z9DZ2NVs1xqrHLqha90a9Ws,5971
7
- invrs_opt-0.3.2.dist-info/LICENSE,sha256=ty6jHPvpyjHy6dbhnu6aDSY05bbl2jQTjnq9u1sBCfM,1078
8
- invrs_opt-0.3.2.dist-info/METADATA,sha256=pwrRENMlb_s1bNAeROw-cJ-nIVuqR9815r-jPNM3cEg,3272
9
- invrs_opt-0.3.2.dist-info/WHEEL,sha256=oiQVh_5PnQM0E3gPdiz09WCNmwiHDMaGer_elqB3coM,92
10
- invrs_opt-0.3.2.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
11
- invrs_opt-0.3.2.dist-info/RECORD,,