invrs-opt 0.3.2__py3-none-any.whl → 0.10.3__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
invrs_opt/__init__.py CHANGED
@@ -3,8 +3,19 @@
3
3
  Copyright (c) 2023 The INVRS-IO authors.
4
4
  """
5
5
 
6
- __version__ = "v0.3.2"
6
+ __version__ = "v0.10.3"
7
7
  __author__ = "Martin F. Schubert <mfschubert@gmail.com>"
8
8
 
9
- from invrs_opt.lbfgsb.lbfgsb import density_lbfgsb as density_lbfgsb
10
- from invrs_opt.lbfgsb.lbfgsb import lbfgsb as lbfgsb
9
+ from invrs_opt import parameterization as parameterization
10
+
11
+ from invrs_opt.optimizers.lbfgsb import (
12
+ density_lbfgsb as density_lbfgsb,
13
+ lbfgsb as lbfgsb,
14
+ levelset_lbfgsb as levelset_lbfgsb,
15
+ )
16
+
17
+ from invrs_opt.optimizers.wrapped_optax import (
18
+ density_wrapped_optax as density_wrapped_optax,
19
+ levelset_wrapped_optax as levelset_wrapped_optax,
20
+ wrapped_optax as wrapped_optax,
21
+ )
@@ -0,0 +1,155 @@
1
+ """Defines basic client optimizers for use with an optimization service.
2
+
3
+ Copyright (c) 2023 The INVRS-IO authors.
4
+ """
5
+
6
+ import json
7
+ import time
8
+ from typing import Any, Dict, Optional
9
+
10
+ import requests
11
+ from totypes import json_utils
12
+
13
+ from invrs_opt.experimental import labels
14
+ from invrs_opt.optimizers import base
15
+
16
# Aliases used only to document signatures below: optimizer parameters are
# arbitrary pytrees, and optimizer state is an opaque token issued by the server.
PyTree = Any
StateToken = str

# Connection state populated by `login`. `optimizer_client` falls back to these
# when it is called with `session=None` / `server_address=None`.
SESSION = None
SERVER_ADDRESS = None
21
+
22
+
23
def login(server_address: str) -> None:
    """Store a fresh `requests.Session` and `server_address` as module globals.

    After this call, `optimizer_client` can be invoked with
    `server_address=None` and `session=None`; it then uses these globals.

    Args:
        server_address: Base address of the optimization server.
    """
    global SESSION, SERVER_ADDRESS
    # Create the session before touching either global, so that a failure
    # while constructing it leaves both globals unchanged.
    new_session = requests.Session()
    SESSION = new_session
    SERVER_ADDRESS = server_address
29
+
30
+
31
def optimizer_client(
    algorithm: str,
    hparams: Dict[str, Any],
    server_address: Optional[str],
    session: Optional[requests.Session],
) -> base.Optimizer:
    """Generic optimizer class.

    Builds a `base.Optimizer` whose `init`, `update` and `params` functions are
    implemented as HTTP requests to an optimization server. The optimizer
    state is a token issued by the server.

    Args:
        algorithm: Name of the algorithm the server should run.
        hparams: Hyperparameters for the algorithm; sent with every request.
        server_address: Base address of the server. If `None`, the global
            `SERVER_ADDRESS` (set by `login`) is used.
        session: Session used for all requests. If `None`, the global
            `SESSION` (set by `login`) is used.

    Returns:
        The client-side `base.Optimizer`.

    Raises:
        ValueError: If `server_address` (or `session`) is `None` and the
            corresponding global has not been set.
    """
    if server_address is None:
        if SERVER_ADDRESS is None:
            raise ValueError(
                "Argument `server_address` and the global `SERVER_ADDRESS` cannot "
                "both be `None`. Use the `login` method to set the global, or "
                "explicitly provide a value."
            )
        server_address = SERVER_ADDRESS
    # From here on, `server_address` is what the request URLs are built from,
    # so an explicitly provided address takes precedence over the global.

    if session is None:
        if SESSION is None:
            raise ValueError(
                "Argument `session` and the global `SESSION` cannot "
                "both be `None`. Use the `login` method to set the global, or "
                "explicitly provide a value."
            )
        session = SESSION

    opt_config = {
        labels.ALGORITHM: algorithm,
        labels.HPARAMS: hparams,
    }

    def _state_token_from(response: requests.Response) -> StateToken:
        """Returns the state token from a server response, raising on failure."""
        if response.status_code != 200:
            raise requests.RequestException(response.text)
        new_state_token: str = json.loads(response.text)[labels.STATE_TOKEN]
        return new_state_token

    def init_fn(params: PyTree) -> StateToken:
        """Handles 'init' requests."""
        serialized_data = json_utils.json_from_pytree(
            {
                labels.OPT_CONFIG: opt_config,
                labels.DATA: {labels.PARAMS: params},
            }
        )
        post_response = session.post(
            f"{server_address}/{labels.ROUTE_INIT}/", data=serialized_data
        )
        return _state_token_from(post_response)

    def update_fn(
        *,
        grad: PyTree,
        value: float,
        params: PyTree,
        state: StateToken,
    ) -> StateToken:
        """Handles 'update' requests.

        Sends the current params, objective value, gradient and state token,
        and returns the new state token issued by the server.
        """
        state_token = state
        del state
        serialized_data = json_utils.json_from_pytree(
            {
                labels.OPT_CONFIG: opt_config,
                labels.DATA: {
                    labels.PARAMS: params,
                    labels.VALUE: value,
                    labels.GRAD: grad,
                    labels.STATE_TOKEN: state_token,
                },
            }
        )
        post_response = session.post(
            f"{server_address}/{labels.ROUTE_UPDATE}/{state_token}/",
            data=serialized_data,
        )
        return _state_token_from(post_response)

    def params_fn(
        state: StateToken,
        timeout: float = 60.0,
        poll_interval: float = 0.1,
    ) -> PyTree:
        """Handles 'params' requests.

        Polls the server until the params for `state` are available.

        Args:
            state: The state token.
            timeout: Maximum time, in seconds, to wait for the params.
            poll_interval: Time, in seconds, between successive polls.

        Returns:
            The params reconstructed from the server response.

        Raises:
            ValueError: If `timeout` is smaller than `poll_interval`.
            requests.RequestException: If the server responds with an error
                other than "state not ready".
            requests.Timeout: If the params are not ready within `timeout`.
        """
        state_token = state
        del state
        if timeout < poll_interval:
            raise ValueError(
                f"`timeout` ({timeout}) must be at least `poll_interval` "
                f"({poll_interval})."
            )

        not_ready_message = labels.MESSAGE_STATE_NOT_READY.format(state_token)
        # Stays `None` if the loop body never runs (e.g. a zero timeout), which
        # is then reported as a timeout instead of an unbound-variable error.
        get_response: Optional[requests.Response] = None
        deadline = time.time() + timeout
        while time.time() < deadline:
            get_response = session.get(
                f"{server_address}/{labels.ROUTE_PARAMS}/{state_token}"
            )
            if get_response.status_code == 200:
                break
            elif get_response.status_code == 404 and get_response.text.endswith(
                not_ready_message
            ):
                # The server knows the token but has not finished computing the
                # state yet; wait and poll again.
                time.sleep(poll_interval)
            else:
                raise requests.RequestException(get_response.text)

        if get_response is None or get_response.status_code != 200:
            raise requests.Timeout("Timed out while waiting for params.")
        response = json.loads(get_response.text)
        return json_utils.pytree_from_json(response[labels.PARAMS])

    return base.Optimizer(
        init=init_fn,
        update=update_fn,  # type: ignore[arg-type]
        params=params_fn,
    )
140
+
141
+
142
+ # -----------------------------------------------------------------------------
143
+ # Specific optimizers implemented here.
144
+ # -----------------------------------------------------------------------------
145
+
146
+
147
def lbfgsb(maxcor: int = 20, line_search_max_steps: int = 100) -> base.Optimizer:
    """Optimizer implementing the L-BFGS-B scheme on the optimization server.

    The returned client uses the global session and server address, so `login`
    must have been called first; otherwise `optimizer_client` raises
    `ValueError`.

    Args:
        maxcor: Forwarded to the server as the `maxcor` hyperparameter
            (presumably the number of stored L-BFGS corrections; confirm
            against the server implementation).
        line_search_max_steps: Forwarded as the `line_search_max_steps`
            hyperparameter.

    Returns:
        The client-side `base.Optimizer`.
    """
    return optimizer_client(
        algorithm="lbfgsb",
        hparams=dict(maxcor=maxcor, line_search_max_steps=line_search_max_steps),
        server_address=None,
        session=None,
    )
@@ -0,0 +1,23 @@
1
+ """Defines labels and messages used in the context of an optimization service.
2
+
3
+ Copyright (c) 2023 The INVRS-IO authors.
4
+ """
5
+
6
# Field names for the data exchanged with the server (objective value,
# gradient, params, and the state token identifying server-side state).
# `MESSAGE` is presumably used by the server for error payloads; not used here.
VALUE = "value"
GRAD = "grad"
PARAMS = "params"
STATE_TOKEN = "state_token"
MESSAGE = "message"

# Top-level request keys (`OPT_CONFIG`, `DATA`) and the keys nested inside the
# optimizer configuration (`ALGORITHM`, `HPARAMS`).
OPT_CONFIG = "opt_config"
ALGORITHM = "algorithm"
HPARAMS = "hparams"
DATA = "data"

# URL path segments of the service endpoints.
ROUTE_INIT = "init"
ROUTE_UPDATE = "update"
ROUTE_PARAMS = "params"

# Message templates; each takes the state token as its single `{}` field.
MESSAGE_STATE_NOT_KNOWN = "State token {} was not recognized."
MESSAGE_STATE_NOT_READY = "State for token {} is not ready."
MESSAGE_STATE_NOT_VALID = "State for token {} is not valid."
File without changes
@@ -4,8 +4,13 @@ Copyright (c) 2023 The INVRS-IO authors.
4
4
  """
5
5
 
6
6
  import dataclasses
7
+ import inspect
7
8
  from typing import Any, Protocol
8
9
 
10
+ import jax.numpy as jnp
11
+ import optax # type: ignore[import-untyped]
12
+ from totypes import json_utils
13
+
9
14
  PyTree = Any
10
15
 
11
16
 
@@ -30,7 +35,7 @@ class UpdateFn(Protocol):
30
35
  self,
31
36
  *,
32
37
  grad: PyTree,
33
- value: float,
38
+ value: jnp.ndarray,
34
39
  params: PyTree,
35
40
  state: PyTree,
36
41
  ) -> PyTree:
@@ -44,3 +49,13 @@ class Optimizer:
44
49
  init: InitFn
45
50
  params: ParamsFn
46
51
  update: UpdateFn
52
+
53
+
54
# Register all optax state types for serialization.
#
# A class whose name ends in "State" is treated as an optax state type.
# `inspect.getmembers` can report the same class under several attribute
# names, so the types are collected as keys of a dict (an insertion-ordered
# set) and each is registered exactly once. The comprehension and the private
# loop variable keep helper names out of this module's public namespace.
optax_types = {
    member: True
    for member_name, member in inspect.getmembers(optax)
    if member_name.endswith("State") and isinstance(member, type)
}

for _optax_state_type in optax_types:
    json_utils.register_custom_type(_optax_state_type)