invrs-opt 0.3.2__tar.gz → 0.4.0__tar.gz

Sign up to get free protection for your applications and to get access to all the features.
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: invrs_opt
3
- Version: 0.3.2
3
+ Version: 0.4.0
4
4
  Summary: Algorithms for inverse design
5
5
  Author-email: "Martin F. Schubert" <mfschubert@gmail.com>
6
6
  Maintainer-email: "Martin F. Schubert" <mfschubert@gmail.com>
@@ -33,8 +33,10 @@ License-File: LICENSE
33
33
  Requires-Dist: jax
34
34
  Requires-Dist: jaxlib
35
35
  Requires-Dist: numpy
36
+ Requires-Dist: requests
36
37
  Requires-Dist: scipy
37
38
  Requires-Dist: totypes
39
+ Requires-Dist: types-requests
38
40
  Provides-Extra: tests
39
41
  Requires-Dist: parameterized; extra == "tests"
40
42
  Requires-Dist: pytest; extra == "tests"
@@ -47,7 +49,7 @@ Requires-Dist: mypy; extra == "dev"
47
49
  Requires-Dist: pre-commit; extra == "dev"
48
50
 
49
51
  # invrs-opt - Optimization algorithms for inverse design
50
- `v0.3.2`
52
+ `v0.4.0`
51
53
 
52
54
  ## Overview
53
55
 
@@ -1,5 +1,5 @@
1
1
  # invrs-opt - Optimization algorithms for inverse design
2
- `v0.3.2`
2
+ `v0.4.0`
3
3
 
4
4
  ## Overview
5
5
 
@@ -1,7 +1,7 @@
1
1
  [project]
2
2
 
3
3
  name = "invrs_opt"
4
- version = "v0.3.2"
4
+ version = "v0.4.0"
5
5
  description = "Algorithms for inverse design"
6
6
  keywords = ["topology", "optimization", "jax", "inverse design"]
7
7
  readme = "README.md"
@@ -19,8 +19,10 @@ dependencies = [
19
19
  "jax",
20
20
  "jaxlib",
21
21
  "numpy",
22
+ "requests",
22
23
  "scipy",
23
24
  "totypes",
25
+ "types-requests",
24
26
  ]
25
27
 
26
28
  [project.optional-dependencies]
@@ -3,7 +3,7 @@
3
3
  Copyright (c) 2023 The INVRS-IO authors.
4
4
  """
5
5
 
6
- __version__ = "v0.3.2"
6
+ __version__ = "v0.4.0"
7
7
  __author__ = "Martin F. Schubert <mfschubert@gmail.com>"
8
8
 
9
9
  from invrs_opt.lbfgsb.lbfgsb import density_lbfgsb as density_lbfgsb
@@ -0,0 +1,152 @@
1
+ """Defines basic client optimizers for use with an optimization service.
2
+
3
+ Copyright (c) 2023 The INVRS-IO authors.
4
+ """
5
+
6
+ import json
7
+ import requests
8
+ import time
9
+ from typing import Any, Dict, Optional
10
+
11
+ from totypes import json_utils
12
+
13
+ from invrs_opt import base
14
+ from invrs_opt.experimental import labels
15
+
16
+
17
PyTree = Any
StateToken = str

# Module-level defaults populated by `login`. `optimizer_client` falls back to
# these when called with `server_address=None` / `session=None`. The explicit
# `Optional[...]` annotations let type checkers accept the later assignment of
# a real `requests.Session` / `str` inside `login`; without them mypy infers
# the type `None` for both names.
SESSION: Optional[requests.Session] = None
SERVER_ADDRESS: Optional[str] = None
22
+
23
+
24
def login(server_address: str) -> None:
    """Record the server address and open a new `requests.Session` globally.

    Optimizer clients created afterwards with `server_address=None` and
    `session=None` use these module-level values.

    Args:
        server_address: Base URL of the optimization server.
    """
    global SESSION, SERVER_ADDRESS
    new_session = requests.Session()
    SESSION = new_session
    SERVER_ADDRESS = server_address
30
+
31
+
32
def optimizer_client(
    algorithm: str,
    hparams: Dict[str, Any],
    server_address: Optional[str],
    session: Optional[requests.Session],
) -> base.Optimizer:
    """Create an optimizer whose computations are carried out by a remote server.

    Args:
        algorithm: Name of the algorithm the server should run.
        hparams: Hyperparameters for the algorithm; included in every init and
            update request.
        server_address: Base URL of the optimization server. If `None`, the
            global `SERVER_ADDRESS` (set by `login`) is used.
        session: Session used for all HTTP requests. If `None`, the global
            `SESSION` (set by `login`) is used.

    Returns:
        A `base.Optimizer` whose `init`, `update` and `params` functions talk to
        the server. Optimizer state is the string token returned by the server.

    Raises:
        ValueError: If `server_address` or `session` is `None` and the
            corresponding global has not been set.
    """
    if server_address is None:
        if SERVER_ADDRESS is None:
            raise ValueError(
                "Argument `server_address` and the global `SERVER_ADDRESS` cannot "
                "both be `None`. Use the `login` method to set the global, or "
                "explicitly provide a value."
            )
        # Resolve the global here so that the nested request functions always
        # use the address chosen at construction time, whether it was passed
        # explicitly or taken from `login` (mirrors how `session` is handled).
        server_address = SERVER_ADDRESS
    if session is None:
        if SESSION is None:
            raise ValueError(
                "Argument `session` and the global `SESSION` cannot "
                "both be `None`. Use the `login` method to set the global, or "
                "explicitly provide a value."
            )
        session = SESSION

    # Re-bind to non-optional locals: type checkers do not carry the narrowing
    # above into the closures below.
    address: str = server_address
    client_session: requests.Session = session

    opt_config = {
        labels.ALGORITHM: algorithm,
        labels.HPARAMS: hparams,
    }

    def _post_for_state_token(url: str, data: Dict[str, Any]) -> StateToken:
        """POST `data` (with the optimizer config) to `url`; return the new token.

        Raises:
            requests.RequestException: If the server does not answer with 200.
        """
        serialized_data = json_utils.json_from_pytree(
            {labels.OPT_CONFIG: opt_config, labels.DATA: data}
        )
        post_response = client_session.post(url, data=serialized_data)
        if post_response.status_code != 200:
            raise requests.RequestException(post_response.text)
        response = json.loads(post_response.text)
        new_state_token: StateToken = response[labels.STATE_TOKEN]
        return new_state_token

    def init_fn(params: PyTree) -> StateToken:
        """Handles 'init' requests."""
        return _post_for_state_token(
            f"{address}/{labels.ROUTE_INIT}/", {labels.PARAMS: params}
        )

    def update_fn(
        *,
        grad: PyTree,
        value: float,
        params: PyTree,
        state: StateToken,
    ) -> StateToken:
        """Handles 'update' requests."""
        state_token = state
        del state
        return _post_for_state_token(
            f"{address}/{labels.ROUTE_UPDATE}/{state_token}/",
            {
                labels.PARAMS: params,
                labels.VALUE: value,
                labels.GRAD: grad,
                labels.STATE_TOKEN: state_token,
            },
        )

    def params_fn(
        state: StateToken,
        timeout: float = 60.0,
        poll_interval: float = 0.1,
    ) -> PyTree:
        """Handles 'params' requests, polling until the server state is ready.

        Args:
            state: State token returned by `init_fn` or `update_fn`.
            timeout: Maximum number of seconds to keep polling.
            poll_interval: Seconds to sleep between polls while the server
                reports that the state is not ready.

        Returns:
            The params pytree decoded from the server response.

        Raises:
            ValueError: If `timeout` is smaller than `poll_interval`.
            requests.RequestException: If the server returns an unexpected
                response.
            requests.Timeout: If the params are not ready within `timeout`.
        """
        state_token = state
        del state
        # Explicit check instead of `assert`, which is stripped under `-O`.
        if timeout < poll_interval:
            raise ValueError(
                f"`timeout` ({timeout}) must be at least `poll_interval` "
                f"({poll_interval})."
            )
        url = f"{address}/{labels.ROUTE_PARAMS}/{state_token}"
        not_ready_message = labels.MESSAGE_STATE_NOT_READY.format(state_token)
        # Monotonic clock: wall-clock adjustments cannot shorten or extend the
        # timeout.
        deadline = time.monotonic() + timeout
        while time.monotonic() < deadline:
            get_response = client_session.get(url)
            if get_response.status_code == 200:
                response = json.loads(get_response.text)
                return json_utils.pytree_from_json(response[labels.PARAMS])
            if get_response.status_code == 404 and get_response.text.endswith(
                not_ready_message
            ):
                time.sleep(poll_interval)
            else:
                raise requests.RequestException(get_response.text)
        # Also reached when the loop never runs (e.g. `timeout <= 0`), which
        # previously failed with an `UnboundLocalError`.
        raise requests.Timeout("Timed out while waiting for params.")

    return base.Optimizer(init=init_fn, update=update_fn, params=params_fn)
137
+
138
+
139
+ # -----------------------------------------------------------------------------
140
+ # Specific optimizers implemented here.
141
+ # -----------------------------------------------------------------------------
142
+
143
+
144
def lbfgsb(maxcor: int = 20, line_search_max_steps: int = 100) -> base.Optimizer:
    """Return a remote optimizer client running the L-BFGS-B scheme.

    The server address and HTTP session come from the module globals, so
    `login` must have been called beforehand; otherwise `optimizer_client`
    raises `ValueError`.

    Args:
        maxcor: Sent to the server as the `maxcor` hyperparameter (presumably
            the L-BFGS-B correction-history size — confirm against the server).
        line_search_max_steps: Sent to the server as the
            `line_search_max_steps` hyperparameter.

    Returns:
        A `base.Optimizer` backed by the optimization service.
    """
    algorithm_hparams = dict(
        maxcor=maxcor,
        line_search_max_steps=line_search_max_steps,
    )
    return optimizer_client(
        algorithm="lbfgsb",
        hparams=algorithm_hparams,
        server_address=None,
        session=None,
    )
@@ -0,0 +1,23 @@
1
+ """Defines labels and messages used in the context of an optimization service.
2
+
3
+ Copyright (c) 2023 The INVRS-IO authors.
4
+ """
5
+
6
# Keys for individual values exchanged with the optimization service.
VALUE = "value"
GRAD = "grad"
PARAMS = "params"
STATE_TOKEN = "state_token"
MESSAGE = "message"

# Keys for the top-level request payload and the optimizer configuration.
OPT_CONFIG = "opt_config"
ALGORITHM = "algorithm"
HPARAMS = "hparams"
DATA = "data"

# Route names joined onto the server address when building request URLs.
ROUTE_INIT = "init"
ROUTE_UPDATE = "update"
ROUTE_PARAMS = "params"

# Message templates; the `{}` placeholder is filled with a state token.
MESSAGE_STATE_NOT_KNOWN = "State token {} was not recognized."
MESSAGE_STATE_NOT_READY = "State for token {} is not ready."
MESSAGE_STATE_NOT_VALID = "State for token {} is not valid."
@@ -261,10 +261,12 @@ def transformed_lbfgsb(
261
261
  flat_latent_params = jnp.asarray(scipy_lbfgsb_state.x)
262
262
  return flat_latent_params, scipy_lbfgsb_state.to_jax()
263
263
 
264
- params, latent_params, jax_lbfgsb_state = state
264
+ _, latent_params, jax_lbfgsb_state = state
265
265
  _, vjp_fn = jax.vjp(transform_fn, latent_params)
266
266
  (latent_grad,) = vjp_fn(grad)
267
- flat_latent_grad, unflatten_fn = flatten_util.ravel_pytree(latent_grad)
267
+ flat_latent_grad, unflatten_fn = flatten_util.ravel_pytree(
268
+ latent_grad
269
+ ) # type: ignore[no-untyped-call]
268
270
 
269
271
  (
270
272
  flat_latent_params,
File without changes
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: invrs_opt
3
- Version: 0.3.2
3
+ Version: 0.4.0
4
4
  Summary: Algorithms for inverse design
5
5
  Author-email: "Martin F. Schubert" <mfschubert@gmail.com>
6
6
  Maintainer-email: "Martin F. Schubert" <mfschubert@gmail.com>
@@ -33,8 +33,10 @@ License-File: LICENSE
33
33
  Requires-Dist: jax
34
34
  Requires-Dist: jaxlib
35
35
  Requires-Dist: numpy
36
+ Requires-Dist: requests
36
37
  Requires-Dist: scipy
37
38
  Requires-Dist: totypes
39
+ Requires-Dist: types-requests
38
40
  Provides-Extra: tests
39
41
  Requires-Dist: parameterized; extra == "tests"
40
42
  Requires-Dist: pytest; extra == "tests"
@@ -47,7 +49,7 @@ Requires-Dist: mypy; extra == "dev"
47
49
  Requires-Dist: pre-commit; extra == "dev"
48
50
 
49
51
  # invrs-opt - Optimization algorithms for inverse design
50
- `v0.3.2`
52
+ `v0.4.0`
51
53
 
52
54
  ## Overview
53
55
 
@@ -9,6 +9,9 @@ src/invrs_opt.egg-info/SOURCES.txt
9
9
  src/invrs_opt.egg-info/dependency_links.txt
10
10
  src/invrs_opt.egg-info/requires.txt
11
11
  src/invrs_opt.egg-info/top_level.txt
12
+ src/invrs_opt/experimental/__init__.py
13
+ src/invrs_opt/experimental/client.py
14
+ src/invrs_opt/experimental/labels.py
12
15
  src/invrs_opt/lbfgsb/__init__.py
13
16
  src/invrs_opt/lbfgsb/lbfgsb.py
14
17
  src/invrs_opt/lbfgsb/transform.py
@@ -1,8 +1,10 @@
1
1
  jax
2
2
  jaxlib
3
3
  numpy
4
+ requests
4
5
  scipy
5
6
  totypes
7
+ types-requests
6
8
 
7
9
  [dev]
8
10
  bump-my-version
File without changes
File without changes
File without changes