invrs-opt 0.3.2__py3-none-any.whl → 0.10.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
invrs_opt/__init__.py CHANGED
@@ -3,8 +3,19 @@
3
3
  Copyright (c) 2023 The INVRS-IO authors.
4
4
  """
5
5
 
6
- __version__ = "v0.3.2"
6
# Package version string (note that it carries a leading "v").
__version__ = "v0.10.3"
__author__ = "Martin F. Schubert <mfschubert@gmail.com>"
8
8
 
9
- from invrs_opt.lbfgsb.lbfgsb import density_lbfgsb as density_lbfgsb
10
- from invrs_opt.lbfgsb.lbfgsb import lbfgsb as lbfgsb
9
+ from invrs_opt import parameterization as parameterization
10
+
11
+ from invrs_opt.optimizers.lbfgsb import (
12
+ density_lbfgsb as density_lbfgsb,
13
+ lbfgsb as lbfgsb,
14
+ levelset_lbfgsb as levelset_lbfgsb,
15
+ )
16
+
17
+ from invrs_opt.optimizers.wrapped_optax import (
18
+ density_wrapped_optax as density_wrapped_optax,
19
+ levelset_wrapped_optax as levelset_wrapped_optax,
20
+ wrapped_optax as wrapped_optax,
21
+ )
@@ -0,0 +1,155 @@
1
+ """Defines basic client optimizers for use with an optimization service.
2
+
3
+ Copyright (c) 2023 The INVRS-IO authors.
4
+ """
5
+
6
+ import json
7
+ import time
8
+ from typing import Any, Dict, Optional
9
+
10
+ import requests
11
+ from totypes import json_utils
12
+
13
+ from invrs_opt.experimental import labels
14
+ from invrs_opt.optimizers import base
15
+
16
# Type aliases used in annotations below.
PyTree = Any  # Arbitrary pytree-structured data; deliberately untyped.
StateToken = str  # Token returned by the service that identifies a remote state.

# Module-level defaults, populated by `login` and consulted by
# `optimizer_client` when no explicit session / address is supplied.
SESSION: Optional[requests.Session] = None
SERVER_ADDRESS: Optional[str] = None
21
+
22
+
23
def login(server_address: str) -> None:
    """Configures the module-level defaults used by `optimizer_client`.

    A fresh `requests.Session` is stored in the global `SESSION`, then
    `server_address` is stored in the global `SERVER_ADDRESS`.

    Args:
        server_address: Base address of the optimization service.
    """
    global SESSION, SERVER_ADDRESS
    new_session = requests.Session()
    SESSION = new_session
    SERVER_ADDRESS = server_address
29
+
30
+
31
def optimizer_client(
    algorithm: str,
    hparams: Dict[str, Any],
    server_address: Optional[str],
    session: Optional[requests.Session],
) -> base.Optimizer:
    """Returns an optimizer whose steps are executed by a remote service.

    The optimizer state is an opaque state token issued by the service. The
    `init`, `update` and `params` functions of the returned optimizer each make
    HTTP requests to the service to create, advance, or read that state.

    Args:
        algorithm: Name of the algorithm, sent in the optimizer config of every
            `init` and `update` request.
        hparams: Algorithm hyperparameters, sent alongside `algorithm`.
        server_address: Base address of the service. If `None`, the global
            `SERVER_ADDRESS` (set by `login`) is used. The address is resolved
            once, when this function is called.
        session: Session used for all requests. If `None`, the global
            `SESSION` (set by `login`) is used.

    Returns:
        The `base.Optimizer` backed by the service.

    Raises:
        ValueError: If `server_address` and `SERVER_ADDRESS` are both `None`, or
            if `session` and `SESSION` are both `None`.
    """
    if server_address is None:
        if SERVER_ADDRESS is None:
            raise ValueError(
                "Argument `server_address` and the global `SERVER_ADDRESS` cannot "
                "both be `None`. Use the `login` method to set the global, or "
                "explicitly provide a value."
            )
        # Resolve the fallback here; the request functions below must use this
        # local so that an explicitly-passed address is honored.
        server_address = SERVER_ADDRESS
    if session is None:
        if SESSION is None:
            raise ValueError(
                "Argument `session` and the global `SESSION` cannot "
                "both be `None`. Use the `login` method to set the global, or "
                "explicitly provide a value."
            )
        session = SESSION

    opt_config = {
        labels.ALGORITHM: algorithm,
        labels.HPARAMS: hparams,
    }

    def _post_for_state_token(url: str, payload: Dict[str, Any]) -> StateToken:
        """POSTs `payload` as JSON and returns the state token in the reply.

        Raises:
            requests.RequestException: If the response status is not 200.
        """
        post_response = session.post(url, data=json_utils.json_from_pytree(payload))
        if post_response.status_code != 200:
            raise requests.RequestException(post_response.text)
        response = json.loads(post_response.text)
        new_state_token: str = response[labels.STATE_TOKEN]
        return new_state_token

    def init_fn(params: PyTree) -> StateToken:
        """Handles 'init' requests: creates a remote state for `params`."""
        return _post_for_state_token(
            f"{server_address}/{labels.ROUTE_INIT}/",
            {
                labels.OPT_CONFIG: opt_config,
                labels.DATA: {labels.PARAMS: params},
            },
        )

    def update_fn(
        *,
        grad: PyTree,
        value: float,
        params: PyTree,
        state: StateToken,
    ) -> StateToken:
        """Handles 'update' requests: advances the remote state `state`."""
        return _post_for_state_token(
            f"{server_address}/{labels.ROUTE_UPDATE}/{state}/",
            {
                labels.OPT_CONFIG: opt_config,
                labels.DATA: {
                    labels.PARAMS: params,
                    labels.VALUE: value,
                    labels.GRAD: grad,
                    labels.STATE_TOKEN: state,
                },
            },
        )

    def params_fn(
        state: StateToken,
        timeout: float = 60.0,
        poll_interval: float = 0.1,
    ) -> PyTree:
        """Handles 'params' requests, polling until the remote state is ready.

        Args:
            state: The state token.
            timeout: Maximum time, in seconds, to keep polling.
            poll_interval: Time, in seconds, to sleep between polls while the
                service reports that the state is not ready.

        Returns:
            The params, deserialized from the service response.

        Raises:
            ValueError: If `timeout` is smaller than `poll_interval`.
            requests.RequestException: If the service returns an error other
                than "state not ready".
            requests.Timeout: If the state is not ready within `timeout`.
        """
        # Explicit check rather than `assert`, which is stripped under `-O`.
        if timeout < poll_interval:
            raise ValueError(
                f"`timeout` must be at least `poll_interval`, but got "
                f"timeout={timeout} and poll_interval={poll_interval}."
            )
        url = f"{server_address}/{labels.ROUTE_PARAMS}/{state}"
        not_ready_message = labels.MESSAGE_STATE_NOT_READY.format(state)

        get_response = None
        start_time = time.time()
        while time.time() < start_time + timeout:
            get_response = session.get(url)
            if get_response.status_code == 200:
                break
            if get_response.status_code == 404 and get_response.text.endswith(
                not_ready_message
            ):
                time.sleep(poll_interval)
            else:
                raise requests.RequestException(get_response.text)

        # `get_response` stays `None` if the loop body never ran (e.g. when
        # `timeout == 0`); treat that as a timeout instead of failing with an
        # `UnboundLocalError`.
        if get_response is None or get_response.status_code != 200:
            raise requests.Timeout("Timed out while waiting for params.")
        response = json.loads(get_response.text)
        return json_utils.pytree_from_json(response[labels.PARAMS])

    return base.Optimizer(
        init=init_fn,
        update=update_fn,  # type: ignore[arg-type]
        params=params_fn,
    )
140
+
141
+
142
+ # -----------------------------------------------------------------------------
143
+ # Specific optimizers implemented here.
144
+ # -----------------------------------------------------------------------------
145
+
146
+
147
def lbfgsb(maxcor: int = 20, line_search_max_steps: int = 100) -> base.Optimizer:
    """Returns a service-backed optimizer using the "lbfgsb" algorithm.

    The session and server address are taken from the module globals set by
    `login`.

    Args:
        maxcor: Forwarded to the service as the "maxcor" hyperparameter.
        line_search_max_steps: Forwarded to the service as the
            "line_search_max_steps" hyperparameter.

    Returns:
        The client optimizer created by `optimizer_client`.
    """
    lbfgsb_hparams = dict(
        maxcor=maxcor,
        line_search_max_steps=line_search_max_steps,
    )
    return optimizer_client(
        algorithm="lbfgsb",
        hparams=lbfgsb_hparams,
        server_address=None,
        session=None,
    )
@@ -0,0 +1,23 @@
1
+ """Defines labels and messages used in the context of an optimization service.
2
+
3
+ Copyright (c) 2023 The INVRS-IO authors.
4
+ """
5
+
6
# Keys used inside the request/response data payload.
VALUE: str = "value"
GRAD: str = "grad"
PARAMS: str = "params"
STATE_TOKEN: str = "state_token"
MESSAGE: str = "message"

# Keys of the top-level request body and of the optimizer configuration.
OPT_CONFIG: str = "opt_config"
ALGORITHM: str = "algorithm"
HPARAMS: str = "hparams"
DATA: str = "data"

# URL path segments of the service endpoints.
ROUTE_INIT: str = "init"
ROUTE_UPDATE: str = "update"
ROUTE_PARAMS: str = "params"

# Message templates; each takes the state token as its single `{}` argument.
MESSAGE_STATE_NOT_KNOWN: str = "State token {} was not recognized."
MESSAGE_STATE_NOT_READY: str = "State for token {} is not ready."
MESSAGE_STATE_NOT_VALID: str = "State for token {} is not valid."
File without changes
@@ -4,8 +4,13 @@ Copyright (c) 2023 The INVRS-IO authors.
4
4
  """
5
5
 
6
6
  import dataclasses
7
+ import inspect
7
8
  from typing import Any, Protocol
8
9
 
10
+ import jax.numpy as jnp
11
+ import optax # type: ignore[import-untyped]
12
+ from totypes import json_utils
13
+
9
14
# Alias used in annotations for arbitrary pytree-structured data (untyped).
PyTree = Any
10
15
 
11
16
 
@@ -30,7 +35,7 @@ class UpdateFn(Protocol):
30
35
  self,
31
36
  *,
32
37
  grad: PyTree,
33
- value: float,
38
+ value: jnp.ndarray,
34
39
  params: PyTree,
35
40
  state: PyTree,
36
41
  ) -> PyTree:
@@ -44,3 +49,13 @@ class Optimizer:
44
49
  init: InitFn
45
50
  params: ParamsFn
46
51
  update: UpdateFn
52
+
53
+
54
def _register_optax_state_types() -> dict:
    """Registers every `optax` class whose name ends in "State" for JSON use.

    Classes are collected from `inspect.getmembers(optax)`. The result is a
    dict keyed by class (values are always `True`), used as an
    insertion-ordered set so that a class reachable under several exported
    names is registered only once. Each class is passed to
    `json_utils.register_custom_type`.

    Returns:
        The dict of registered classes.
    """
    state_types = {
        obj: True
        for name, obj in inspect.getmembers(optax)
        if name.endswith("State") and isinstance(obj, type)
    }
    for state_type in state_types:
        json_utils.register_custom_type(state_type)
    return state_types


# Register all optax state types for serialization. Registration runs inside a
# helper so that loop temporaries do not leak into the module namespace;
# `optax_types` is kept as a public module attribute with unchanged contents.
optax_types = _register_optax_state_types()