invrs-opt 0.10.7__tar.gz → 0.11.0__tar.gz
This diff shows the contents of the two publicly released package versions as they appear in their public registry. It is provided for informational purposes only and reflects the changes between those versions as published.
- {invrs_opt-0.10.7 → invrs_opt-0.11.0}/PKG-INFO +1 -1
- {invrs_opt-0.10.7 → invrs_opt-0.11.0}/pyproject.toml +1 -1
- {invrs_opt-0.10.7 → invrs_opt-0.11.0}/src/invrs_opt/__init__.py +1 -1
- {invrs_opt-0.10.7 → invrs_opt-0.11.0}/src/invrs_opt/optimizers/lbfgsb.py +40 -18
- {invrs_opt-0.10.7 → invrs_opt-0.11.0}/src/invrs_opt/optimizers/wrapped_optax.py +41 -16
- {invrs_opt-0.10.7 → invrs_opt-0.11.0}/src/invrs_opt.egg-info/PKG-INFO +1 -1
- {invrs_opt-0.10.7 → invrs_opt-0.11.0}/LICENSE +0 -0
- {invrs_opt-0.10.7 → invrs_opt-0.11.0}/README.md +0 -0
- {invrs_opt-0.10.7 → invrs_opt-0.11.0}/setup.cfg +0 -0
- {invrs_opt-0.10.7 → invrs_opt-0.11.0}/src/invrs_opt/experimental/__init__.py +0 -0
- {invrs_opt-0.10.7 → invrs_opt-0.11.0}/src/invrs_opt/experimental/client.py +0 -0
- {invrs_opt-0.10.7 → invrs_opt-0.11.0}/src/invrs_opt/experimental/labels.py +0 -0
- {invrs_opt-0.10.7 → invrs_opt-0.11.0}/src/invrs_opt/optimizers/__init__.py +0 -0
- {invrs_opt-0.10.7 → invrs_opt-0.11.0}/src/invrs_opt/optimizers/base.py +0 -0
- {invrs_opt-0.10.7 → invrs_opt-0.11.0}/src/invrs_opt/parameterization/__init__.py +0 -0
- {invrs_opt-0.10.7 → invrs_opt-0.11.0}/src/invrs_opt/parameterization/base.py +0 -0
- {invrs_opt-0.10.7 → invrs_opt-0.11.0}/src/invrs_opt/parameterization/filter_project.py +0 -0
- {invrs_opt-0.10.7 → invrs_opt-0.11.0}/src/invrs_opt/parameterization/gaussian_levelset.py +0 -0
- {invrs_opt-0.10.7 → invrs_opt-0.11.0}/src/invrs_opt/parameterization/pixel.py +0 -0
- {invrs_opt-0.10.7 → invrs_opt-0.11.0}/src/invrs_opt/parameterization/transforms.py +0 -0
- {invrs_opt-0.10.7 → invrs_opt-0.11.0}/src/invrs_opt/py.typed +0 -0
- {invrs_opt-0.10.7 → invrs_opt-0.11.0}/src/invrs_opt.egg-info/SOURCES.txt +0 -0
- {invrs_opt-0.10.7 → invrs_opt-0.11.0}/src/invrs_opt.egg-info/dependency_links.txt +0 -0
- {invrs_opt-0.10.7 → invrs_opt-0.11.0}/src/invrs_opt.egg-info/requires.txt +0 -0
- {invrs_opt-0.10.7 → invrs_opt-0.11.0}/src/invrs_opt.egg-info/top_level.txt +0 -0

src/invrs_opt/optimizers/lbfgsb.py (+40 -18)

@@ -16,7 +16,7 @@ from jax import flatten_util, tree_util
 from scipy.optimize._lbfgsb_py import (  # type: ignore[import-untyped]
     _lbfgsb as scipy_lbfgsb,
 )
-from totypes import types
+from totypes import json_utils, types
 
 from invrs_opt.optimizers import base
 from invrs_opt.parameterization import (

@@ -31,7 +31,26 @@ PyTree = Any
 ElementwiseBound = Union[NDArray, Sequence[Optional[float]]]
 NumpyLbfgsbDict = Dict[str, NDArray]
 JaxLbfgsbDict = Dict[str, jnp.ndarray]
-
+
+
+@dataclasses.dataclass
+class LbfgsbState:
+    """Stores the state of the L-BFGS-B optimizer."""
+
+    step: int
+    params: PyTree
+    latent_params: PyTree
+    opt_state: JaxLbfgsbDict
+
+
+tree_util.register_dataclass(
+    LbfgsbState,
+    data_fields=["step", "params", "latent_params", "opt_state"],
+    meta_fields=[],
+)
+
+
+json_utils.register_custom_type(LbfgsbState)
 
 
 # Task message prefixes for the underlying L-BFGS-B implementation.

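The registration above is what makes the new state object usable inside jit- or grad-transformed code: the dataclass becomes an ordinary JAX pytree. A minimal sketch of the same pattern with a hypothetical ToyState (requires a JAX version that provides tree_util.register_dataclass); this is illustrative only, not invrs-opt source:

    import dataclasses

    import jax.numpy as jnp
    from jax import tree_util


    @dataclasses.dataclass
    class ToyState:  # hypothetical stand-in for LbfgsbState
        step: int
        params: jnp.ndarray


    # After registration the dataclass is a pytree: tree_map, jit, and grad can
    # flatten and rebuild it, with `step` and `params` treated as leaves.
    tree_util.register_dataclass(
        ToyState, data_fields=["step", "params"], meta_fields=[]
    )

    state = ToyState(step=0, params=jnp.zeros(3))
    print(tree_util.tree_leaves(state))  # [0, Array([0., 0., 0.], dtype=float32)]
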
@@ -327,17 +346,16 @@ def parameterized_lbfgsb(
             latents,
         )
         latent_params = param_base.combine_density_metadata(metadata, latents)
-        return (
-            0,
-            _params_from_latent_params(latent_params),
-            latent_params,
-            jax_lbfgsb_state,
+        return LbfgsbState(
+            step=0,
+            params=_params_from_latent_params(latent_params),
+            latent_params=latent_params,
+            opt_state=jax_lbfgsb_state,
         )
 
     def params_fn(state: LbfgsbState) -> PyTree:
         """Returns the parameters for the given `state`."""
-
-        return params
+        return state.params
 
     def update_fn(
         *,

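For downstream code that inspected the raw state, the practical change is from positional to named access. A hedged migration sketch, assuming the top-level `lbfgsb` constructor and a toy parameter pytree:

    import jax.numpy as jnp

    import invrs_opt

    opt = invrs_opt.lbfgsb()
    state = opt.init({"x": jnp.ones((3,))})

    # 0.10.x returned a tuple, unpacked as `step, params, latent_params, opt_state = state`.
    # 0.11.0 returns an LbfgsbState dataclass, so fields are read by name.
    params = state.params
    step = state.step
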
@@ -366,8 +384,7 @@ def parameterized_lbfgsb(
             flat_latent_updates = updated_flat_latent_params - flat_latent_params
             return flat_latent_updates, scipy_lbfgsb_state.to_dict()
 
-
-        metadata, latents = param_base.partition_density_metadata(latent_params)
+        metadata, latents = param_base.partition_density_metadata(state.latent_params)
 
         def _params_from_latents(latents: PyTree) -> PyTree:
             latent_params = param_base.combine_density_metadata(metadata, latents)

@@ -404,23 +421,28 @@ def parameterized_lbfgsb(
             latents_grad
         )  # type: ignore[no-untyped-call]
 
-        flat_latent_updates,
+        flat_latent_updates, opt_state = callback_sequential(
             _update_pure,
-            (flat_latents_grad,
+            (flat_latents_grad, state.opt_state),
             flat_latents_grad,
             value,
-
+            state.opt_state,
         )
         latent_updates = unflatten_fn(flat_latent_updates)
         latent_params = _apply_updates(
-            params=latent_params,
+            params=state.latent_params,
             updates=param_base.combine_density_metadata(metadata, latent_updates),
             value=value,
-            step=step,
+            step=state.step,
         )
         latent_params = _clip(latent_params)
         params = _params_from_latent_params(latent_params)
-        return
+        return LbfgsbState(
+            step=state.step + 1,
+            params=params,
+            latent_params=latent_params,
+            opt_state=opt_state,
+        )
 
     # -------------------------------------------------------------------------
     # Functions related to the density parameterization.

@@ -501,7 +523,7 @@ def parameterized_lbfgsb(
 
 def is_converged(state: LbfgsbState) -> jnp.ndarray:
     """Returns `True` if the optimization has converged."""
-    return state[
+    return state.opt_state["converged"]
 
 
 # ------------------------------------------------------------------------------

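With the flag now read from the named `opt_state` field, a user-side stopping check can wrap it in a small helper. A hypothetical sketch; it assumes the scipy-backed state dict carries a boolean "converged" entry, as the code above indicates:

    def stop_early(state) -> bool:
        """Hypothetical helper: stop once L-BFGS-B reports convergence."""
        return bool(state.opt_state["converged"])
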
src/invrs_opt/optimizers/wrapped_optax.py (+41 -16)

@@ -3,13 +3,14 @@
 Copyright (c) 2023 The INVRS-IO authors.
 """
 
-
+import dataclasses
+from typing import Any, Optional
 
 import jax
 import jax.numpy as jnp
 import optax  # type: ignore[import-untyped]
 from jax import tree_util
-from totypes import types
+from totypes import json_utils, types
 
 from invrs_opt.optimizers import base
 from invrs_opt.parameterization import (

@@ -20,7 +21,26 @@ from invrs_opt.parameterization import (
 )
 
 PyTree = Any
-
+
+
+@dataclasses.dataclass
+class WrappedOptaxState:
+    """Stores the state of a wrapped optax optimizer."""
+
+    step: int
+    params: PyTree
+    latent_params: PyTree
+    opt_state: PyTree
+
+
+tree_util.register_dataclass(
+    WrappedOptaxState,
+    data_fields=["step", "params", "latent_params", "opt_state"],
+    meta_fields=[],
+)
+
+
+json_utils.register_custom_type(WrappedOptaxState)
 
 
 def wrapped_optax(opt: optax.GradientTransformation) -> base.Optimizer:

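Registering both state classes with totypes' JSON utilities means the full optimizer state can be serialized, e.g. for checkpointing. A hedged sketch; the helper names json_from_pytree / pytree_from_json are assumed from the totypes package and are not part of this diff:

    import jax.numpy as jnp
    from totypes import json_utils

    # Any registered pytree round-trips through the JSON helpers; with the
    # registration above, that now includes LbfgsbState and WrappedOptaxState.
    serialized = json_utils.json_from_pytree({"x": jnp.arange(3)})
    restored = json_utils.pytree_from_json(serialized)
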
@@ -182,17 +202,16 @@ def parameterized_wrapped_optax(
         """Initializes the optimization state."""
         latent_params = _init_latents(params)
         _, latents = param_base.partition_density_metadata(latent_params)
-        return (
-            0,
-            _params_from_latent_params(latent_params),
-            latent_params,
-            opt.init(latents),
+        return WrappedOptaxState(
+            step=0,
+            params=_params_from_latent_params(latent_params),
+            latent_params=latent_params,
+            opt_state=opt.init(latents),
         )
 
     def params_fn(state: WrappedOptaxState) -> PyTree:
         """Returns the parameters for the given `state`."""
-
-        return params
+        return state.params
 
     def update_fn(
         *,

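The init_fn / params_fn / update_fn trio above follows the usual invrs-opt driver pattern, so existing optimization loops should keep working unchanged. A hedged end-to-end sketch with a toy quadratic loss, assuming the top-level `wrapped_optax` constructor and the keyword-only update signature shown above:

    import jax
    import jax.numpy as jnp
    import optax

    import invrs_opt


    def loss_fn(params):
        return jnp.sum(params["x"] ** 2)


    opt = invrs_opt.wrapped_optax(optax.adam(learning_rate=0.1))
    state = opt.init({"x": jnp.ones((3,))})  # a WrappedOptaxState as of 0.11.0
    for _ in range(10):
        params = opt.params(state)
        value, grad = jax.value_and_grad(loss_fn)(params)
        state = opt.update(grad=grad, value=value, params=params, state=state)
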
@@ -204,8 +223,7 @@ def parameterized_wrapped_optax(
         """Updates the state."""
         del params
 
-
-        metadata, latents = param_base.partition_density_metadata(latent_params)
+        metadata, latents = param_base.partition_density_metadata(state.latent_params)
 
         def _params_from_latents(latents: PyTree) -> PyTree:
             latent_params = param_base.combine_density_metadata(metadata, latents)

@@ -232,16 +250,23 @@ def parameterized_wrapped_optax(
             lambda a, b: a + b, latents_grad, constraint_loss_grad
         )
 
-        latent_updates, opt_state = opt.update(
+        latent_updates, opt_state = opt.update(
+            latents_grad, state.opt_state, params=latents
+        )
         latent_params = _apply_updates(
-            params=latent_params,
+            params=state.latent_params,
             updates=param_base.combine_density_metadata(metadata, latent_updates),
             value=value,
-            step=step,
+            step=state.step,
         )
         latent_params = _clip(latent_params)
         params = _params_from_latent_params(latent_params)
-        return (
+        return WrappedOptaxState(
+            step=state.step + 1,
+            params=params,
+            latent_params=latent_params,
+            opt_state=opt_state,
+        )
 
     # -------------------------------------------------------------------------
     # Functions related to the density parameterization.

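The optax update now receives the current latents via the `params` keyword, which standard optax transforms such as weight decay need in order to compute their updates. A small standalone optax sketch of that calling convention (plain optax, not invrs-opt code):

    import jax.numpy as jnp
    import optax

    opt = optax.chain(optax.add_decayed_weights(1e-4), optax.sgd(learning_rate=0.1))
    params = {"x": jnp.ones((3,))}
    grads = {"x": jnp.full((3,), 0.5)}

    opt_state = opt.init(params)
    # Transforms like add_decayed_weights raise unless `params` is supplied here.
    updates, opt_state = opt.update(grads, opt_state, params=params)
    params = optax.apply_updates(params, updates)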