invrs-opt 0.8.1__py3-none-any.whl → 0.9.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- invrs_opt/__init__.py +1 -1
- invrs_opt/optimizers/lbfgsb.py +103 -77
- invrs_opt/optimizers/wrapped_optax.py +32 -32
- invrs_opt/parameterization/base.py +62 -16
- invrs_opt/parameterization/filter_project.py +48 -18
- invrs_opt/parameterization/gaussian_levelset.py +88 -105
- invrs_opt/parameterization/pixel.py +20 -6
- {invrs_opt-0.8.1.dist-info → invrs_opt-0.9.1.dist-info}/METADATA +2 -2
- invrs_opt-0.9.1.dist-info/RECORD +20 -0
- invrs_opt-0.8.1.dist-info/RECORD +0 -20
- {invrs_opt-0.8.1.dist-info → invrs_opt-0.9.1.dist-info}/LICENSE +0 -0
- {invrs_opt-0.8.1.dist-info → invrs_opt-0.9.1.dist-info}/WHEEL +0 -0
- {invrs_opt-0.8.1.dist-info → invrs_opt-0.9.1.dist-info}/top_level.txt +0 -0
invrs_opt/__init__.py
CHANGED
invrs_opt/optimizers/lbfgsb.py
CHANGED
@@ -19,7 +19,7 @@ from totypes import types
|
|
19
19
|
|
20
20
|
from invrs_opt.optimizers import base
|
21
21
|
from invrs_opt.parameterization import (
|
22
|
-
base as
|
22
|
+
base as param_base,
|
23
23
|
filter_project,
|
24
24
|
gaussian_levelset,
|
25
25
|
pixel,
|
@@ -252,7 +252,7 @@ def levelset_lbfgsb(
|
|
252
252
|
|
253
253
|
|
254
254
|
def parameterized_lbfgsb(
|
255
|
-
density_parameterization: Optional[
|
255
|
+
density_parameterization: Optional[param_base.Density2DParameterization],
|
256
256
|
penalty: float,
|
257
257
|
maxcor: int = DEFAULT_MAXCOR,
|
258
258
|
line_search_max_steps: int = DEFAULT_LINE_SEARCH_MAX_STEPS,
|
@@ -296,59 +296,6 @@ def parameterized_lbfgsb(
|
|
296
296
|
if density_parameterization is None:
|
297
297
|
density_parameterization = pixel.pixel()
|
298
298
|
|
299
|
-
def _init_latents(params: PyTree) -> PyTree:
|
300
|
-
def _leaf_init_latents(leaf: Any) -> Any:
|
301
|
-
leaf = _clip(leaf)
|
302
|
-
if not _is_density(leaf) or density_parameterization is None:
|
303
|
-
return leaf
|
304
|
-
return density_parameterization.from_density(leaf)
|
305
|
-
|
306
|
-
return tree_util.tree_map(_leaf_init_latents, params, is_leaf=_is_custom_type)
|
307
|
-
|
308
|
-
def _params_from_latents(latent_params: PyTree) -> PyTree:
|
309
|
-
def _leaf_params_from_latents(leaf: Any) -> Any:
|
310
|
-
if not _is_parameterized_density(leaf) or density_parameterization is None:
|
311
|
-
return leaf
|
312
|
-
return density_parameterization.to_density(leaf)
|
313
|
-
|
314
|
-
return tree_util.tree_map(
|
315
|
-
_leaf_params_from_latents,
|
316
|
-
latent_params,
|
317
|
-
is_leaf=_is_parameterized_density,
|
318
|
-
)
|
319
|
-
|
320
|
-
def _constraint_loss(latent_params: PyTree) -> jnp.ndarray:
|
321
|
-
def _constraint_loss_leaf(
|
322
|
-
params: parameterization_base.ParameterizedDensity2DArrayBase,
|
323
|
-
) -> jnp.ndarray:
|
324
|
-
constraints = density_parameterization.constraints(params)
|
325
|
-
constraints = tree_util.tree_map(
|
326
|
-
lambda x: jnp.sum(jnp.maximum(x, 0.0) ** 2),
|
327
|
-
constraints,
|
328
|
-
)
|
329
|
-
return jnp.sum(jnp.asarray(constraints))
|
330
|
-
|
331
|
-
losses = [0.0] + [
|
332
|
-
_constraint_loss_leaf(p)
|
333
|
-
for p in tree_util.tree_leaves(
|
334
|
-
latent_params, is_leaf=_is_parameterized_density
|
335
|
-
)
|
336
|
-
if _is_parameterized_density(p)
|
337
|
-
]
|
338
|
-
return penalty * jnp.sum(jnp.asarray(losses))
|
339
|
-
|
340
|
-
def _update_parameterized_densities(latent_params: PyTree, step: int) -> PyTree:
|
341
|
-
def _update_leaf(leaf: Any) -> Any:
|
342
|
-
if not _is_parameterized_density(leaf):
|
343
|
-
return leaf
|
344
|
-
return density_parameterization.update(leaf, step)
|
345
|
-
|
346
|
-
return tree_util.tree_map(
|
347
|
-
_update_leaf,
|
348
|
-
latent_params,
|
349
|
-
is_leaf=_is_parameterized_density,
|
350
|
-
)
|
351
|
-
|
352
299
|
def init_fn(params: PyTree) -> LbfgsbState:
|
353
300
|
"""Initializes the optimization state."""
|
354
301
|
|
@@ -368,10 +315,17 @@ def parameterized_lbfgsb(
|
|
368
315
|
return latent_params, scipy_lbfgsb_state.to_jax()
|
369
316
|
|
370
317
|
latent_params = _init_latents(params)
|
371
|
-
|
372
|
-
|
318
|
+
metadata, latents = param_base.partition_density_metadata(latent_params)
|
319
|
+
latents, jax_lbfgsb_state = jax.pure_callback(
|
320
|
+
_init_state_pure, _example_state(latents, maxcor), latents
|
321
|
+
)
|
322
|
+
latent_params = param_base.combine_density_metadata(metadata, latents)
|
323
|
+
return (
|
324
|
+
0, # step
|
325
|
+
_params_from_latent_params(latent_params), # params
|
326
|
+
latent_params, # latent params
|
327
|
+
jax_lbfgsb_state, # opt state
|
373
328
|
)
|
374
|
-
return 0, _params_from_latents(latent_params), latent_params, jax_lbfgsb_state
|
375
329
|
|
376
330
|
def params_fn(state: LbfgsbState) -> PyTree:
|
377
331
|
"""Returns the parameters for the given `state`."""
|
@@ -403,46 +357,118 @@ def parameterized_lbfgsb(
|
|
403
357
|
return flat_latent_params, scipy_lbfgsb_state.to_jax()
|
404
358
|
|
405
359
|
step, _, latent_params, jax_lbfgsb_state = state
|
406
|
-
|
407
|
-
|
360
|
+
metadata, latents = param_base.partition_density_metadata(latent_params)
|
361
|
+
|
362
|
+
def _params_from_latents(latents: PyTree) -> PyTree:
|
363
|
+
latent_params = param_base.combine_density_metadata(metadata, latents)
|
364
|
+
return _params_from_latent_params(latent_params)
|
365
|
+
|
366
|
+
def _constraint_loss_latents(latents: PyTree) -> jnp.ndarray:
|
367
|
+
latent_params = param_base.combine_density_metadata(metadata, latents)
|
368
|
+
return _constraint_loss(latent_params)
|
369
|
+
|
370
|
+
_, vjp_fn = jax.vjp(_params_from_latents, latents)
|
371
|
+
(latents_grad,) = vjp_fn(grad)
|
408
372
|
|
409
373
|
if not (
|
410
|
-
tree_util.tree_structure(
|
411
|
-
== tree_util.tree_structure(
|
374
|
+
tree_util.tree_structure(latents_grad)
|
375
|
+
== tree_util.tree_structure(latents) # type: ignore[operator]
|
412
376
|
):
|
413
377
|
raise ValueError(
|
414
|
-
f"Tree structure of `
|
415
|
-
f"{tree_util.tree_structure(
|
416
|
-
f"{tree_util.tree_structure(
|
378
|
+
f"Tree structure of `latents_grad` was different than expected, got \n"
|
379
|
+
f"{tree_util.tree_structure(latents_grad)} but expected \n"
|
380
|
+
f"{tree_util.tree_structure(latents)}."
|
417
381
|
)
|
418
382
|
|
419
383
|
(
|
420
384
|
constraint_loss_value,
|
421
385
|
constraint_loss_grad,
|
422
386
|
) = jax.value_and_grad(
|
423
|
-
|
424
|
-
)(
|
387
|
+
_constraint_loss_latents
|
388
|
+
)(latents)
|
425
389
|
value += constraint_loss_value
|
426
|
-
|
427
|
-
lambda a, b: a + b,
|
390
|
+
latents_grad = tree_util.tree_map(
|
391
|
+
lambda a, b: a + b, latents_grad, constraint_loss_grad
|
428
392
|
)
|
429
393
|
|
430
|
-
|
431
|
-
|
394
|
+
flat_latents_grad, unflatten_fn = flatten_util.ravel_pytree(
|
395
|
+
latents_grad
|
432
396
|
) # type: ignore[no-untyped-call]
|
433
397
|
|
434
|
-
|
398
|
+
flat_latents, jax_lbfgsb_state = jax.pure_callback(
|
435
399
|
_update_pure,
|
436
|
-
(
|
437
|
-
|
400
|
+
(flat_latents_grad, jax_lbfgsb_state),
|
401
|
+
flat_latents_grad,
|
438
402
|
value,
|
439
403
|
jax_lbfgsb_state,
|
440
404
|
)
|
441
|
-
|
405
|
+
latents = unflatten_fn(flat_latents)
|
406
|
+
latent_params = param_base.combine_density_metadata(metadata, latents)
|
442
407
|
latent_params = _update_parameterized_densities(latent_params, step)
|
443
|
-
params =
|
408
|
+
params = _params_from_latent_params(latent_params)
|
444
409
|
return step + 1, params, latent_params, jax_lbfgsb_state
|
445
410
|
|
411
|
+
# -------------------------------------------------------------------------
|
412
|
+
# Functions related to the density parameterization.
|
413
|
+
# -------------------------------------------------------------------------
|
414
|
+
|
415
|
+
def _init_latents(params: PyTree) -> PyTree:
|
416
|
+
def _leaf_init_latents(leaf: Any) -> Any:
|
417
|
+
leaf = _clip(leaf)
|
418
|
+
if not _is_density(leaf) or density_parameterization is None:
|
419
|
+
return leaf
|
420
|
+
return density_parameterization.from_density(leaf)
|
421
|
+
|
422
|
+
return tree_util.tree_map(_leaf_init_latents, params, is_leaf=_is_custom_type)
|
423
|
+
|
424
|
+
def _params_from_latent_params(latent_params: PyTree) -> PyTree:
|
425
|
+
def _leaf_params_from_latents(leaf: Any) -> Any:
|
426
|
+
if not _is_parameterized_density(leaf) or density_parameterization is None:
|
427
|
+
return leaf
|
428
|
+
return density_parameterization.to_density(leaf)
|
429
|
+
|
430
|
+
return tree_util.tree_map(
|
431
|
+
_leaf_params_from_latents,
|
432
|
+
latent_params,
|
433
|
+
is_leaf=_is_parameterized_density,
|
434
|
+
)
|
435
|
+
|
436
|
+
def _update_parameterized_densities(latent_params: PyTree, step: int) -> PyTree:
|
437
|
+
def _update_leaf(leaf: Any) -> Any:
|
438
|
+
if not _is_parameterized_density(leaf):
|
439
|
+
return leaf
|
440
|
+
return density_parameterization.update(leaf, step)
|
441
|
+
|
442
|
+
return tree_util.tree_map(
|
443
|
+
_update_leaf,
|
444
|
+
latent_params,
|
445
|
+
is_leaf=_is_parameterized_density,
|
446
|
+
)
|
447
|
+
|
448
|
+
# -------------------------------------------------------------------------
|
449
|
+
# Functions related to the constraints to be minimized.
|
450
|
+
# -------------------------------------------------------------------------
|
451
|
+
|
452
|
+
def _constraint_loss(latent_params: PyTree) -> jnp.ndarray:
|
453
|
+
def _constraint_loss_leaf(
|
454
|
+
leaf: param_base.ParameterizedDensity2DArray,
|
455
|
+
) -> jnp.ndarray:
|
456
|
+
constraints = density_parameterization.constraints(leaf)
|
457
|
+
constraints = tree_util.tree_map(
|
458
|
+
lambda x: jnp.sum(jnp.maximum(x, 0.0) ** 2),
|
459
|
+
constraints,
|
460
|
+
)
|
461
|
+
return jnp.sum(jnp.asarray(constraints))
|
462
|
+
|
463
|
+
losses = [0.0] + [
|
464
|
+
_constraint_loss_leaf(p)
|
465
|
+
for p in tree_util.tree_leaves(
|
466
|
+
latent_params, is_leaf=_is_parameterized_density
|
467
|
+
)
|
468
|
+
if _is_parameterized_density(p)
|
469
|
+
]
|
470
|
+
return penalty * jnp.sum(jnp.asarray(losses))
|
471
|
+
|
446
472
|
return base.Optimizer(
|
447
473
|
init=init_fn,
|
448
474
|
params=params_fn,
|
@@ -467,7 +493,7 @@ def _is_density(leaf: Any) -> Any:
|
|
467
493
|
|
468
494
|
def _is_parameterized_density(leaf: Any) -> Any:
|
469
495
|
"""Return `True` if `leaf` is a parameterized density array."""
|
470
|
-
return isinstance(leaf,
|
496
|
+
return isinstance(leaf, param_base.ParameterizedDensity2DArray)
|
471
497
|
|
472
498
|
|
473
499
|
def _is_custom_type(leaf: Any) -> bool:
|
@@ -13,7 +13,7 @@ from totypes import types
|
|
13
13
|
|
14
14
|
from invrs_opt.optimizers import base
|
15
15
|
from invrs_opt.parameterization import (
|
16
|
-
base as
|
16
|
+
base as param_base,
|
17
17
|
filter_project,
|
18
18
|
gaussian_levelset,
|
19
19
|
pixel,
|
@@ -158,7 +158,7 @@ def levelset_wrapped_optax(
|
|
158
158
|
|
159
159
|
def parameterized_wrapped_optax(
|
160
160
|
opt: optax.GradientTransformation,
|
161
|
-
density_parameterization: Optional[
|
161
|
+
density_parameterization: Optional[param_base.Density2DParameterization],
|
162
162
|
penalty: float,
|
163
163
|
) -> base.Optimizer:
|
164
164
|
"""Wrapped optax optimizer with specified density parameterization.
|
@@ -181,11 +181,12 @@ def parameterized_wrapped_optax(
|
|
181
181
|
def init_fn(params: PyTree) -> WrappedOptaxState:
|
182
182
|
"""Initializes the optimization state."""
|
183
183
|
latent_params = _init_latents(params)
|
184
|
+
_, latents = param_base.partition_density_metadata(latent_params)
|
184
185
|
return (
|
185
186
|
0, # step
|
186
|
-
|
187
|
+
_params_from_latent_params(latent_params), # params
|
187
188
|
latent_params, # latent params
|
188
|
-
opt.init(
|
189
|
+
opt.init(latents), # opt state
|
189
190
|
)
|
190
191
|
|
191
192
|
def params_fn(state: WrappedOptaxState) -> PyTree:
|
@@ -204,42 +205,41 @@ def parameterized_wrapped_optax(
|
|
204
205
|
del value, params
|
205
206
|
|
206
207
|
step, params, latent_params, opt_state = state
|
208
|
+
metadata, latents = param_base.partition_density_metadata(latent_params)
|
207
209
|
|
208
|
-
|
209
|
-
|
210
|
+
def _params_from_latents(latents: PyTree) -> PyTree:
|
211
|
+
latent_params = param_base.combine_density_metadata(metadata, latents)
|
212
|
+
return _params_from_latent_params(latent_params)
|
213
|
+
|
214
|
+
def _constraint_loss_latents(latents: PyTree) -> jnp.ndarray:
|
215
|
+
latent_params = param_base.combine_density_metadata(metadata, latents)
|
216
|
+
return _constraint_loss(latent_params)
|
217
|
+
|
218
|
+
_, vjp_fn = jax.vjp(_params_from_latents, latents)
|
219
|
+
(latents_grad,) = vjp_fn(grad)
|
210
220
|
|
211
221
|
if not (
|
212
|
-
tree_util.tree_structure(
|
213
|
-
== tree_util.tree_structure(
|
222
|
+
tree_util.tree_structure(latents_grad)
|
223
|
+
== tree_util.tree_structure(latents) # type: ignore[operator]
|
214
224
|
):
|
215
225
|
raise ValueError(
|
216
|
-
f"Tree structure of `
|
217
|
-
f"{tree_util.tree_structure(
|
218
|
-
f"{tree_util.tree_structure(
|
226
|
+
f"Tree structure of `latents_grad` was different than expected, got \n"
|
227
|
+
f"{tree_util.tree_structure(latents_grad)} but expected \n"
|
228
|
+
f"{tree_util.tree_structure(latents)}."
|
219
229
|
)
|
220
230
|
|
221
|
-
constraint_loss_grad = jax.grad(
|
222
|
-
|
223
|
-
lambda a, b: a + b,
|
231
|
+
constraint_loss_grad = jax.grad(_constraint_loss_latents)(latents)
|
232
|
+
latents_grad = tree_util.tree_map(
|
233
|
+
lambda a, b: a + b, latents_grad, constraint_loss_grad
|
224
234
|
)
|
225
235
|
|
226
|
-
|
227
|
-
|
228
|
-
state=opt_state,
|
229
|
-
params=tree_util.tree_leaves(latent_params),
|
230
|
-
)
|
231
|
-
latent_params_leaves = optax.apply_updates(
|
232
|
-
params=tree_util.tree_leaves(latent_params),
|
233
|
-
updates=updates_leaves,
|
234
|
-
)
|
235
|
-
latent_params = tree_util.tree_unflatten(
|
236
|
-
treedef=tree_util.tree_structure(latent_params),
|
237
|
-
leaves=latent_params_leaves,
|
238
|
-
)
|
236
|
+
updates, opt_state = opt.update(latents_grad, state=opt_state, params=latents)
|
237
|
+
latents = optax.apply_updates(params=latents, updates=updates)
|
239
238
|
|
239
|
+
latent_params = param_base.combine_density_metadata(metadata, latents)
|
240
240
|
latent_params = _clip(latent_params)
|
241
241
|
latent_params = _update_parameterized_densities(latent_params, step + 1)
|
242
|
-
params =
|
242
|
+
params = _params_from_latent_params(latent_params)
|
243
243
|
return (step + 1, params, latent_params, opt_state)
|
244
244
|
|
245
245
|
# -------------------------------------------------------------------------
|
@@ -255,7 +255,7 @@ def parameterized_wrapped_optax(
|
|
255
255
|
|
256
256
|
return tree_util.tree_map(_leaf_init_latents, params, is_leaf=_is_custom_type)
|
257
257
|
|
258
|
-
def
|
258
|
+
def _params_from_latent_params(params: PyTree) -> PyTree:
|
259
259
|
def _leaf_params_from_latents(leaf: Any) -> Any:
|
260
260
|
if not _is_parameterized_density(leaf):
|
261
261
|
return leaf
|
@@ -285,9 +285,9 @@ def parameterized_wrapped_optax(
|
|
285
285
|
|
286
286
|
def _constraint_loss(latent_params: PyTree) -> jnp.ndarray:
|
287
287
|
def _constraint_loss_leaf(
|
288
|
-
|
288
|
+
leaf: param_base.ParameterizedDensity2DArray,
|
289
289
|
) -> jnp.ndarray:
|
290
|
-
constraints = density_parameterization.constraints(
|
290
|
+
constraints = density_parameterization.constraints(leaf)
|
291
291
|
constraints = tree_util.tree_map(
|
292
292
|
lambda x: jnp.sum(jnp.maximum(x, 0.0) ** 2),
|
293
293
|
constraints,
|
@@ -313,7 +313,7 @@ def _is_density(leaf: Any) -> Any:
|
|
313
313
|
|
314
314
|
def _is_parameterized_density(leaf: Any) -> Any:
|
315
315
|
"""Return `True` if `leaf` is a parameterized density array."""
|
316
|
-
return isinstance(leaf,
|
316
|
+
return isinstance(leaf, param_base.ParameterizedDensity2DArray)
|
317
317
|
|
318
318
|
|
319
319
|
def _is_custom_type(leaf: Any) -> bool:
|
@@ -9,24 +9,74 @@ from typing import Any, Optional, Protocol, Sequence, Tuple
|
|
9
9
|
import jax.numpy as jnp
|
10
10
|
import numpy as onp
|
11
11
|
from jax import tree_util
|
12
|
-
from totypes import json_utils, types
|
12
|
+
from totypes import json_utils, partition_utils, types
|
13
13
|
|
14
14
|
Array = jnp.ndarray | onp.ndarray[Any, Any]
|
15
15
|
PyTree = Any
|
16
16
|
|
17
17
|
|
18
|
-
|
19
|
-
|
18
|
+
@dataclasses.dataclass
|
19
|
+
class ParameterizedDensity2DArray:
|
20
|
+
"""Stores latents and metadata for a parameterized density array."""
|
21
|
+
|
22
|
+
latents: "LatentsBase"
|
23
|
+
metadata: Optional["MetadataBase"]
|
24
|
+
|
25
|
+
|
26
|
+
class LatentsBase:
|
27
|
+
"""Base class for latents of a parameterized density array."""
|
28
|
+
|
29
|
+
pass
|
30
|
+
|
31
|
+
|
32
|
+
class MetadataBase:
|
33
|
+
"""Base class for metadata of a parameterized density array."""
|
20
34
|
|
21
35
|
pass
|
22
36
|
|
23
37
|
|
38
|
+
tree_util.register_dataclass(
|
39
|
+
ParameterizedDensity2DArray,
|
40
|
+
data_fields=["latents", "metadata"],
|
41
|
+
meta_fields=[],
|
42
|
+
)
|
43
|
+
json_utils.register_custom_type(ParameterizedDensity2DArray)
|
44
|
+
|
45
|
+
|
46
|
+
def partition_density_metadata(tree: PyTree) -> Tuple[PyTree, PyTree]:
|
47
|
+
"""Splits a pytree with parameterized densities into metadata from latents."""
|
48
|
+
metadata, latents = partition_utils.partition(
|
49
|
+
tree,
|
50
|
+
select_fn=lambda x: isinstance(x, MetadataBase),
|
51
|
+
is_leaf=_is_metadata_or_none,
|
52
|
+
)
|
53
|
+
return metadata, latents
|
54
|
+
|
55
|
+
|
56
|
+
def combine_density_metadata(metadata: PyTree, latents: PyTree) -> PyTree:
|
57
|
+
"""Combines pytrees containing metadata and latents."""
|
58
|
+
return partition_utils.combine(metadata, latents, is_leaf=_is_metadata_or_none)
|
59
|
+
|
60
|
+
|
61
|
+
def _is_metadata_or_none(leaf: Any) -> bool:
|
62
|
+
"""Return `True` if `leaf` is `None` or density metadata."""
|
63
|
+
return leaf is None or isinstance(leaf, MetadataBase)
|
64
|
+
|
65
|
+
|
66
|
+
@dataclasses.dataclass
|
67
|
+
class Density2DParameterization:
|
68
|
+
"""Stores `(from_density, to_density, constraints, update)` function triple."""
|
69
|
+
|
70
|
+
from_density: "FromDensityFn"
|
71
|
+
to_density: "ToDensityFn"
|
72
|
+
constraints: "ConstraintsFn"
|
73
|
+
update: "UpdateFn"
|
74
|
+
|
75
|
+
|
24
76
|
class FromDensityFn(Protocol):
|
25
77
|
"""Generate the latent representation of a density array."""
|
26
78
|
|
27
|
-
def __call__(
|
28
|
-
self, density: types.Density2DArray
|
29
|
-
) -> ParameterizedDensity2DArrayBase:
|
79
|
+
def __call__(self, density: types.Density2DArray) -> ParameterizedDensity2DArray:
|
30
80
|
...
|
31
81
|
|
32
82
|
|
@@ -51,16 +101,6 @@ class UpdateFn(Protocol):
|
|
51
101
|
...
|
52
102
|
|
53
103
|
|
54
|
-
@dataclasses.dataclass
|
55
|
-
class Density2DParameterization:
|
56
|
-
"""Stores `(from_density, to_density, constraints)` function triple."""
|
57
|
-
|
58
|
-
from_density: FromDensityFn
|
59
|
-
to_density: ToDensityFn
|
60
|
-
constraints: ConstraintsFn
|
61
|
-
update: UpdateFn
|
62
|
-
|
63
|
-
|
64
104
|
@dataclasses.dataclass
|
65
105
|
class Density2DMetadata:
|
66
106
|
"""Stores the metadata of a `Density2DArray`."""
|
@@ -78,6 +118,12 @@ class Density2DMetadata:
|
|
78
118
|
self.periodic = tuple(self.periodic)
|
79
119
|
self.symmetries = tuple(self.symmetries)
|
80
120
|
|
121
|
+
@classmethod
|
122
|
+
def from_density(self, density: types.Density2DArray) -> "Density2DMetadata":
|
123
|
+
density_metadata_dict = dataclasses.asdict(density)
|
124
|
+
del density_metadata_dict["array"]
|
125
|
+
return Density2DMetadata(**density_metadata_dict)
|
126
|
+
|
81
127
|
|
82
128
|
def _flatten_density_2d_metadata(
|
83
129
|
metadata: Density2DMetadata,
|
@@ -13,25 +13,53 @@ from invrs_opt.parameterization import base, transforms
|
|
13
13
|
|
14
14
|
|
15
15
|
@dataclasses.dataclass
|
16
|
-
class
|
17
|
-
"""Stores
|
16
|
+
class FilterProjectParams(base.ParameterizedDensity2DArray):
|
17
|
+
"""Stores parameters for the filter-project parameterization."""
|
18
18
|
|
19
|
-
|
19
|
+
latents: "FilterProjectLatents"
|
20
|
+
metadata: "FilterProjectMetadata"
|
21
|
+
|
22
|
+
|
23
|
+
@dataclasses.dataclass
|
24
|
+
class FilterProjectLatents(base.LatentsBase):
|
25
|
+
"""Stores latent parameters for the filter-project parameterization.
|
26
|
+
|
27
|
+
Attributes:s
|
20
28
|
latent_density: The latent variable from which the density is obtained.
|
21
|
-
beta: Determines the sharpness of the thresholding operation.
|
22
29
|
"""
|
23
30
|
|
24
31
|
latent_density: types.Density2DArray
|
32
|
+
|
33
|
+
|
34
|
+
@dataclasses.dataclass
|
35
|
+
class FilterProjectMetadata(base.MetadataBase):
|
36
|
+
"""Stores metadata for the filter-project parameterization.
|
37
|
+
|
38
|
+
Attributes:
|
39
|
+
beta: Determines the sharpness of the thresholding operation.
|
40
|
+
"""
|
41
|
+
|
25
42
|
beta: float
|
26
43
|
|
27
44
|
|
28
45
|
tree_util.register_dataclass(
|
29
|
-
|
46
|
+
FilterProjectParams,
|
47
|
+
data_fields=["latents", "metadata"],
|
48
|
+
meta_fields=[],
|
49
|
+
)
|
50
|
+
tree_util.register_dataclass(
|
51
|
+
FilterProjectLatents,
|
30
52
|
data_fields=["latent_density"],
|
53
|
+
meta_fields=[],
|
54
|
+
)
|
55
|
+
tree_util.register_dataclass(
|
56
|
+
FilterProjectMetadata,
|
57
|
+
data_fields=[],
|
31
58
|
meta_fields=["beta"],
|
32
59
|
)
|
33
|
-
|
34
|
-
json_utils.register_custom_type(
|
60
|
+
json_utils.register_custom_type(FilterProjectParams)
|
61
|
+
json_utils.register_custom_type(FilterProjectLatents)
|
62
|
+
json_utils.register_custom_type(FilterProjectMetadata)
|
35
63
|
|
36
64
|
|
37
65
|
def filter_project(beta: float) -> base.Density2DParameterization:
|
@@ -55,24 +83,26 @@ def filter_project(beta: float) -> base.Density2DParameterization:
|
|
55
83
|
The `Density2DParameterization`.
|
56
84
|
"""
|
57
85
|
|
58
|
-
def from_density_fn(density: types.Density2DArray) ->
|
86
|
+
def from_density_fn(density: types.Density2DArray) -> FilterProjectParams:
|
59
87
|
"""Return latent parameters for the given `density`."""
|
60
88
|
array = transforms.normalized_array_from_density(density)
|
61
89
|
array = jnp.clip(array, -1, 1)
|
62
90
|
array *= jnp.tanh(beta)
|
63
91
|
latent_array = jnp.arctanh(array) / beta
|
64
92
|
latent_array = transforms.rescale_array_for_density(latent_array, density)
|
65
|
-
|
66
|
-
|
67
|
-
|
93
|
+
latent_density = density = dataclasses.replace(density, array=latent_array)
|
94
|
+
return FilterProjectParams(
|
95
|
+
latents=FilterProjectLatents(latent_density=latent_density),
|
96
|
+
metadata=FilterProjectMetadata(beta=beta),
|
68
97
|
)
|
69
98
|
|
70
|
-
def to_density_fn(params:
|
99
|
+
def to_density_fn(params: FilterProjectParams) -> types.Density2DArray:
|
71
100
|
"""Return a density from the latent parameters."""
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
)
|
101
|
+
latent_density = params.latents.latent_density
|
102
|
+
beta = params.metadata.beta
|
103
|
+
|
104
|
+
transformed = types.symmetrize_density(latent_density)
|
105
|
+
transformed = transforms.density_gaussian_filter_and_tanh(transformed, beta)
|
76
106
|
# Scale to ensure that the full valid range of the density array is reachable.
|
77
107
|
mid_value = (transformed.lower_bound + transformed.upper_bound) / 2
|
78
108
|
transformed = tree_util.tree_map(
|
@@ -80,12 +110,12 @@ def filter_project(beta: float) -> base.Density2DParameterization:
|
|
80
110
|
)
|
81
111
|
return transforms.apply_fixed_pixels(transformed)
|
82
112
|
|
83
|
-
def constraints_fn(params:
|
113
|
+
def constraints_fn(params: FilterProjectParams) -> jnp.ndarray:
|
84
114
|
"""Computes constraints associated with the params."""
|
85
115
|
del params
|
86
116
|
return jnp.asarray(0.0)
|
87
117
|
|
88
|
-
def update_fn(params:
|
118
|
+
def update_fn(params: FilterProjectParams, step: int) -> FilterProjectParams:
|
89
119
|
"""Perform updates to `params` required for the given `step`."""
|
90
120
|
del step
|
91
121
|
return params
|
@@ -29,12 +29,30 @@ DEFAULT_INIT_OPTIMIZER: optax.GradientTransformation = optax.adam(1e-1)
|
|
29
29
|
|
30
30
|
|
31
31
|
@dataclasses.dataclass
|
32
|
-
class GaussianLevelsetParams(base.
|
33
|
-
"""
|
32
|
+
class GaussianLevelsetParams(base.ParameterizedDensity2DArray):
|
33
|
+
"""Stores parameters for the Gaussian levelset parameterization."""
|
34
|
+
|
35
|
+
latents: "GaussianLevelsetLatents"
|
36
|
+
metadata: "GaussianLevelsetMetadata"
|
37
|
+
|
38
|
+
|
39
|
+
@dataclasses.dataclass
|
40
|
+
class GaussianLevelsetLatents(base.LatentsBase):
|
41
|
+
"""Stores latent parameters for the Gaussian levelset parameterization.
|
34
42
|
|
35
43
|
Attributes:
|
36
44
|
amplitude: Array giving the amplitude of the Gaussian basis function at
|
37
45
|
levelset control points.
|
46
|
+
"""
|
47
|
+
|
48
|
+
amplitude: jnp.ndarray
|
49
|
+
|
50
|
+
|
51
|
+
@dataclasses.dataclass
|
52
|
+
class GaussianLevelsetMetadata(base.MetadataBase):
|
53
|
+
"""Stores metadata for the Gaussian levelset parameterization.
|
54
|
+
|
55
|
+
Attributes:
|
38
56
|
length_scale_spacing_factor: The number of levelset control points per unit of
|
39
57
|
minimum length scale (mean of density minimum width and minimum spacing).
|
40
58
|
length_scale_fwhm_factor: The ratio of Gaussian full-width at half-maximum to
|
@@ -45,7 +63,6 @@ class GaussianLevelsetParams(base.ParameterizedDensity2DArrayBase):
|
|
45
63
|
density_metadata: Metadata for the density array obtained from the parameters.
|
46
64
|
"""
|
47
65
|
|
48
|
-
amplitude: jnp.ndarray
|
49
66
|
length_scale_spacing_factor: float
|
50
67
|
length_scale_fwhm_factor: float
|
51
68
|
smoothing_factor: int
|
@@ -55,68 +72,29 @@ class GaussianLevelsetParams(base.ParameterizedDensity2DArrayBase):
|
|
55
72
|
def __post_init__(self) -> None:
|
56
73
|
self.density_shape = tuple(self.density_shape)
|
57
74
|
|
58
|
-
def example_density(self) -> types.Density2DArray:
|
59
|
-
"""Returns an example density with appropriate shape and metadata."""
|
60
|
-
with jax.ensure_compile_time_eval():
|
61
|
-
return types.Density2DArray(
|
62
|
-
array=jnp.zeros(self.density_shape),
|
63
|
-
**dataclasses.asdict(self.density_metadata),
|
64
|
-
)
|
65
|
-
|
66
|
-
|
67
|
-
_GaussianLevelsetParamsAux = Tuple[
|
68
|
-
float, float, int, Tuple[int, ...], tree_util.PyTreeDef
|
69
|
-
]
|
70
|
-
|
71
|
-
|
72
|
-
def _flatten_gaussian_levelset_params(
|
73
|
-
params: GaussianLevelsetParams,
|
74
|
-
) -> Tuple[Tuple[jnp.ndarray], _GaussianLevelsetParamsAux]:
|
75
|
-
_, flat_metadata = tree_util.tree_flatten(params.density_metadata)
|
76
|
-
return (
|
77
|
-
(params.amplitude,),
|
78
|
-
(
|
79
|
-
params.length_scale_spacing_factor,
|
80
|
-
params.length_scale_fwhm_factor,
|
81
|
-
params.smoothing_factor,
|
82
|
-
params.density_shape,
|
83
|
-
flat_metadata,
|
84
|
-
),
|
85
|
-
)
|
86
|
-
|
87
|
-
|
88
|
-
def _unflatten_gaussian_levelset_params(
|
89
|
-
aux: _GaussianLevelsetParamsAux,
|
90
|
-
children: Tuple[jnp.ndarray],
|
91
|
-
) -> GaussianLevelsetParams:
|
92
|
-
(amplitude,) = children
|
93
|
-
(
|
94
|
-
length_scale_spacing_factor,
|
95
|
-
length_scale_fwhm_factor,
|
96
|
-
smoothing_factor,
|
97
|
-
density_shape,
|
98
|
-
flat_metadata,
|
99
|
-
) = aux
|
100
|
-
|
101
|
-
density_metadata = tree_util.tree_unflatten(flat_metadata, ())
|
102
|
-
return GaussianLevelsetParams(
|
103
|
-
amplitude=amplitude,
|
104
|
-
length_scale_spacing_factor=length_scale_spacing_factor,
|
105
|
-
length_scale_fwhm_factor=length_scale_fwhm_factor,
|
106
|
-
smoothing_factor=smoothing_factor,
|
107
|
-
density_shape=tuple(density_shape),
|
108
|
-
density_metadata=density_metadata,
|
109
|
-
)
|
110
|
-
|
111
75
|
|
112
|
-
tree_util.
|
76
|
+
tree_util.register_dataclass(
|
113
77
|
GaussianLevelsetParams,
|
114
|
-
|
115
|
-
|
78
|
+
data_fields=["latents", "metadata"],
|
79
|
+
meta_fields=[],
|
80
|
+
)
|
81
|
+
tree_util.register_dataclass(
|
82
|
+
GaussianLevelsetLatents,
|
83
|
+
data_fields=["amplitude"],
|
84
|
+
meta_fields=[],
|
85
|
+
)
|
86
|
+
tree_util.register_dataclass(
|
87
|
+
GaussianLevelsetMetadata,
|
88
|
+
data_fields=[
|
89
|
+
"length_scale_spacing_factor",
|
90
|
+
"length_scale_fwhm_factor",
|
91
|
+
"density_metadata",
|
92
|
+
],
|
93
|
+
meta_fields=["density_shape", "smoothing_factor"],
|
116
94
|
)
|
117
|
-
|
118
|
-
|
119
95
|
json_utils.register_custom_type(GaussianLevelsetParams)
|
96
|
+
json_utils.register_custom_type(GaussianLevelsetLatents)
|
97
|
+
json_utils.register_custom_type(GaussianLevelsetMetadata)
|
120
98
|
|
121
99
|
|
122
100
|
def gaussian_levelset(
|
@@ -187,23 +165,21 @@ def gaussian_levelset(
|
|
187
165
|
pad_width += ((0, 0),) if density.periodic[1] else ((1, 1),)
|
188
166
|
amplitude = jnp.pad(amplitude, pad_width, mode="edge")
|
189
167
|
|
190
|
-
|
191
|
-
|
192
|
-
density_metadata = base.Density2DMetadata(**density_metadata_dict)
|
193
|
-
params = GaussianLevelsetParams(
|
194
|
-
amplitude=amplitude,
|
168
|
+
latents = GaussianLevelsetLatents(amplitude=amplitude)
|
169
|
+
metadata = GaussianLevelsetMetadata(
|
195
170
|
length_scale_spacing_factor=length_scale_spacing_factor,
|
196
171
|
length_scale_fwhm_factor=length_scale_fwhm_factor,
|
197
172
|
smoothing_factor=smoothing_factor,
|
198
173
|
density_shape=density.shape,
|
199
|
-
density_metadata=
|
174
|
+
density_metadata=base.Density2DMetadata.from_density(density),
|
200
175
|
)
|
201
176
|
|
202
177
|
def step_fn(
|
203
178
|
_: int,
|
204
179
|
params_and_state: Tuple[PyTree, PyTree],
|
205
180
|
) -> Tuple[PyTree, PyTree]:
|
206
|
-
def loss_fn(
|
181
|
+
def loss_fn(latents: GaussianLevelsetLatents) -> jnp.ndarray:
|
182
|
+
params = GaussianLevelsetParams(latents, metadata=metadata)
|
207
183
|
density_from_params = to_density_fn(params, mask_gradient=False)
|
208
184
|
return jnp.mean((density_from_params.array - target_array) ** 2)
|
209
185
|
|
@@ -213,13 +189,14 @@ def gaussian_levelset(
|
|
213
189
|
params = optax.apply_updates(params, updates)
|
214
190
|
return params, state
|
215
191
|
|
216
|
-
state = init_optimizer.init(
|
217
|
-
|
218
|
-
0, init_steps, body_fun=step_fn, init_val=(
|
192
|
+
state = init_optimizer.init(latents)
|
193
|
+
latents, _ = jax.lax.fori_loop(
|
194
|
+
0, init_steps, body_fun=step_fn, init_val=(latents, state)
|
219
195
|
)
|
220
196
|
|
221
|
-
maxval = jnp.amax(jnp.abs(
|
222
|
-
|
197
|
+
maxval = jnp.amax(jnp.abs(latents.amplitude), axis=(-2, -1), keepdims=True)
|
198
|
+
latents = dataclasses.replace(latents, amplitude=latents.amplitude / maxval)
|
199
|
+
return GaussianLevelsetParams(latents=latents, metadata=metadata)
|
223
200
|
|
224
201
|
def to_density_fn(
|
225
202
|
params: GaussianLevelsetParams,
|
@@ -228,7 +205,7 @@ def gaussian_levelset(
|
|
228
205
|
"""Return a density from the latent parameters."""
|
229
206
|
array = _to_array(params, mask_gradient=mask_gradient, pad_pixels=0)
|
230
207
|
|
231
|
-
example_density = params
|
208
|
+
example_density = _example_density(params)
|
232
209
|
lb = example_density.lower_bound
|
233
210
|
ub = example_density.upper_bound
|
234
211
|
array = lb + array * (ub - lb)
|
@@ -263,7 +240,7 @@ def gaussian_levelset(
|
|
263
240
|
)
|
264
241
|
|
265
242
|
# Normalize constraints to make them (somewhat) resolution-independent.
|
266
|
-
example_density = params
|
243
|
+
example_density = _example_density(params)
|
267
244
|
length_scale = 0.5 * (
|
268
245
|
example_density.minimum_spacing + example_density.minimum_width
|
269
246
|
)
|
@@ -287,6 +264,15 @@ def gaussian_levelset(
|
|
287
264
|
# -----------------------------------------------------------------------------
|
288
265
|
|
289
266
|
|
267
|
+
def _example_density(params: GaussianLevelsetParams) -> types.Density2DArray:
|
268
|
+
"""Returns an example density with appropriate shape and metadata."""
|
269
|
+
with jax.ensure_compile_time_eval():
|
270
|
+
return types.Density2DArray(
|
271
|
+
array=jnp.zeros(params.metadata.density_shape),
|
272
|
+
**dataclasses.asdict(params.metadata.density_metadata),
|
273
|
+
)
|
274
|
+
|
275
|
+
|
290
276
|
def _to_array(
|
291
277
|
params: GaussianLevelsetParams,
|
292
278
|
mask_gradient: bool,
|
@@ -308,18 +294,11 @@ def _to_array(
|
|
308
294
|
Returns:
|
309
295
|
The array.
|
310
296
|
"""
|
311
|
-
example_density = params
|
297
|
+
example_density = _example_density(params)
|
312
298
|
periodic: Tuple[bool, bool] = example_density.periodic
|
313
|
-
phi = _phi_from_params(
|
314
|
-
|
315
|
-
|
316
|
-
)
|
317
|
-
array = _levelset_threshold(
|
318
|
-
phi=phi,
|
319
|
-
periodic=periodic,
|
320
|
-
mask_gradient=mask_gradient,
|
321
|
-
)
|
322
|
-
return _downsample_spatial_dims(array, params.smoothing_factor)
|
299
|
+
phi = _phi_from_params(params=params, pad_pixels=pad_pixels)
|
300
|
+
array = _levelset_threshold(phi=phi, periodic=periodic, mask_gradient=mask_gradient)
|
301
|
+
return _downsample_spatial_dims(array, params.metadata.smoothing_factor)
|
323
302
|
|
324
303
|
|
325
304
|
def _phi_from_params(
|
@@ -337,32 +316,35 @@ def _phi_from_params(
|
|
337
316
|
The levelset array `phi`.
|
338
317
|
"""
|
339
318
|
with jax.ensure_compile_time_eval():
|
340
|
-
example_density = params
|
319
|
+
example_density = _example_density(params)
|
341
320
|
length_scale = 0.5 * (
|
342
321
|
example_density.minimum_width + example_density.minimum_spacing
|
343
322
|
)
|
344
|
-
fwhm = length_scale * params.length_scale_fwhm_factor
|
323
|
+
fwhm = length_scale * params.metadata.length_scale_fwhm_factor
|
345
324
|
sigma = fwhm / (2 * jnp.sqrt(2 * jnp.log(2)))
|
346
325
|
|
326
|
+
s_factor = params.metadata.smoothing_factor
|
347
327
|
highres_i = (
|
348
328
|
0.5
|
349
329
|
+ jnp.arange(
|
350
|
-
|
351
|
-
|
330
|
+
s_factor * (-pad_pixels),
|
331
|
+
s_factor * (pad_pixels + example_density.shape[-2]),
|
352
332
|
)
|
353
|
-
) /
|
333
|
+
) / s_factor
|
354
334
|
highres_j = (
|
355
335
|
0.5
|
356
336
|
+ jnp.arange(
|
357
|
-
|
358
|
-
|
337
|
+
s_factor * (-pad_pixels),
|
338
|
+
s_factor * (pad_pixels + example_density.shape[-1]),
|
359
339
|
)
|
360
|
-
) /
|
340
|
+
) / s_factor
|
361
341
|
|
362
342
|
# Coordinates for the control points of the Gaussian radial basis functions.
|
363
343
|
levelset_i, levelset_j = _control_point_coords(
|
364
|
-
density_shape=params.density_shape[-2:], # type: ignore[arg-type]
|
365
|
-
levelset_shape=
|
344
|
+
density_shape=params.metadata.density_shape[-2:], # type: ignore[arg-type]
|
345
|
+
levelset_shape=(
|
346
|
+
params.latents.amplitude.shape[-2:] # type: ignore[arg-type]
|
347
|
+
),
|
366
348
|
periodic=example_density.periodic,
|
367
349
|
)
|
368
350
|
|
@@ -391,7 +373,7 @@ def _phi_from_params(
|
|
391
373
|
levelset_i = levelset_i.flatten()
|
392
374
|
levelset_j = levelset_j.flatten()
|
393
375
|
|
394
|
-
amplitude = params.amplitude
|
376
|
+
amplitude = params.latents.amplitude
|
395
377
|
if example_density.periodic[0]:
|
396
378
|
amplitude = jnp.concat([amplitude] * 3, axis=-2)
|
397
379
|
if example_density.periodic[1]:
|
@@ -410,8 +392,8 @@ def _phi_from_params(
|
|
410
392
|
_, array = jax.lax.scan(scan_fn, (), xs=highres_i)
|
411
393
|
array = jnp.moveaxis(array, 0, -2)
|
412
394
|
|
413
|
-
assert array.shape[-2] %
|
414
|
-
assert array.shape[-1] %
|
395
|
+
assert array.shape[-2] % s_factor == 0
|
396
|
+
assert array.shape[-1] % s_factor == 0
|
415
397
|
array = symmetry.symmetrize(array, tuple(example_density.symmetries))
|
416
398
|
return array
|
417
399
|
|
@@ -443,7 +425,7 @@ def _fixed_pixel_constraint(
|
|
443
425
|
"""
|
444
426
|
array = _to_array(params, mask_gradient=mask_gradient, pad_pixels=pad_pixels)
|
445
427
|
|
446
|
-
example_density = params
|
428
|
+
example_density = _example_density(params)
|
447
429
|
fixed_solid = jnp.zeros(example_density.shape[-2:], dtype=bool)
|
448
430
|
fixed_void = jnp.zeros(example_density.shape[-2:], dtype=bool)
|
449
431
|
if example_density.fixed_solid is not None:
|
@@ -491,9 +473,9 @@ def _levelset_constraints(
|
|
491
473
|
beyond the boundaries of the parameterized density.
|
492
474
|
|
493
475
|
Returns:
|
494
|
-
The minimum length scale and minimum curvature constraint arrays.
|
476
|
+
The minimum length scale and minimum curvature constraint arrays.s
|
495
477
|
"""
|
496
|
-
example_density = params
|
478
|
+
example_density = _example_density(params)
|
497
479
|
minimum_length_scale = 0.5 * (
|
498
480
|
example_density.minimum_width + example_density.minimum_spacing
|
499
481
|
)
|
@@ -512,9 +494,10 @@ def _levelset_constraints(
|
|
512
494
|
)
|
513
495
|
|
514
496
|
# Downsample so that constraints shape matches the density shape.
|
497
|
+
factor = params.metadata.smoothing_factor
|
515
498
|
return (
|
516
|
-
_downsample_spatial_dims(length_scale_constraint,
|
517
|
-
_downsample_spatial_dims(curvature_constraint,
|
499
|
+
_downsample_spatial_dims(length_scale_constraint, factor),
|
500
|
+
_downsample_spatial_dims(curvature_constraint, factor),
|
518
501
|
)
|
519
502
|
|
520
503
|
|
@@ -529,7 +512,7 @@ def _phi_derivatives_and_inverse_radius(
|
|
529
512
|
pad_pixels=pad_pixels,
|
530
513
|
)
|
531
514
|
|
532
|
-
d = 1 / params.smoothing_factor
|
515
|
+
d = 1 / params.metadata.smoothing_factor
|
533
516
|
phi_x, phi_y = jnp.gradient(phi, d, axis=(-2, -1))
|
534
517
|
phi_xx, phi_yx = jnp.gradient(phi_x, d, axis=(-2, -1))
|
535
518
|
phi_xy, phi_yy = jnp.gradient(phi_y, d, axis=(-2, -1))
|
@@ -13,26 +13,40 @@ from invrs_opt.parameterization import base
|
|
13
13
|
|
14
14
|
|
15
15
|
@dataclasses.dataclass
|
16
|
-
class PixelParams(base.
|
17
|
-
|
16
|
+
class PixelParams(base.ParameterizedDensity2DArray):
|
17
|
+
latents: "PixelLatents"
|
18
|
+
metadata: None = None
|
18
19
|
|
19
|
-
density: types.Density2DArray
|
20
20
|
|
21
|
+
@dataclasses.dataclass
|
22
|
+
class PixelLatents(base.LatentsBase):
|
23
|
+
"""Stores latent parameters for the direct pixel parameterization."""
|
21
24
|
|
22
|
-
|
25
|
+
density: types.Density2DArray
|
23
26
|
|
24
27
|
|
28
|
+
tree_util.register_dataclass(
|
29
|
+
PixelParams,
|
30
|
+
data_fields=["latents"],
|
31
|
+
meta_fields=[],
|
32
|
+
)
|
33
|
+
tree_util.register_dataclass(
|
34
|
+
PixelLatents,
|
35
|
+
data_fields=["density"],
|
36
|
+
meta_fields=[],
|
37
|
+
)
|
25
38
|
json_utils.register_custom_type(PixelParams)
|
39
|
+
json_utils.register_custom_type(PixelLatents)
|
26
40
|
|
27
41
|
|
28
42
|
def pixel() -> base.Density2DParameterization:
|
29
43
|
"""Return the direct pixel parameterization."""
|
30
44
|
|
31
45
|
def from_density_fn(density: types.Density2DArray) -> PixelParams:
|
32
|
-
return PixelParams(density=density)
|
46
|
+
return PixelParams(latents=PixelLatents(density=density))
|
33
47
|
|
34
48
|
def to_density_fn(params: PixelParams) -> types.Density2DArray:
|
35
|
-
return params.density
|
49
|
+
return params.latents.density
|
36
50
|
|
37
51
|
def constraints_fn(params: PixelParams) -> jnp.ndarray:
|
38
52
|
del params
|
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.1
|
2
2
|
Name: invrs_opt
|
3
|
-
Version: 0.
|
3
|
+
Version: 0.9.1
|
4
4
|
Summary: Algorithms for inverse design
|
5
5
|
Author-email: "Martin F. Schubert" <mfschubert@gmail.com>
|
6
6
|
Maintainer-email: "Martin F. Schubert" <mfschubert@gmail.com>
|
@@ -533,7 +533,7 @@ Requires-Dist: pytest-cov; extra == "tests"
|
|
533
533
|
Requires-Dist: pytest-subtests; extra == "tests"
|
534
534
|
|
535
535
|
# invrs-opt - Optimization algorithms for inverse design
|
536
|
-
`v0.
|
536
|
+
`v0.9.1`
|
537
537
|
|
538
538
|
## Overview
|
539
539
|
|
@@ -0,0 +1,20 @@
|
|
1
|
+
invrs_opt/__init__.py,sha256=ETNFCTO6rJkj597a5uqDRdrkOcB4Zq4d2vV0y3POYAM,585
|
2
|
+
invrs_opt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
|
+
invrs_opt/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
4
|
+
invrs_opt/experimental/client.py,sha256=t4XxnditYbM9DWZeyBPj0Sa2acvkikT0ybhUdmH2r-Y,4852
|
5
|
+
invrs_opt/experimental/labels.py,sha256=dQDAMPyFMV6mXnMy295z8Ap205DRdVzysXny_Be8FmY,562
|
6
|
+
invrs_opt/optimizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
7
|
+
invrs_opt/optimizers/base.py,sha256=-wNwH0J475y8FzB5aLAkc_1602LvYeF4Hddr9OiBkDY,1276
|
8
|
+
invrs_opt/optimizers/lbfgsb.py,sha256=8BPiEAqececL-zLnqrgN0CogGDkAd1tyAGndUB-kahc,36349
|
9
|
+
invrs_opt/optimizers/wrapped_optax.py,sha256=VXdCteT2kumqhP81l3p6QiEqwBffoUuJ3UjrAyX5ToA,13468
|
10
|
+
invrs_opt/parameterization/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
11
|
+
invrs_opt/parameterization/base.py,sha256=BObzbz6efT2nBjib0_5BSdkCmFi2f0mcZ9VJYpDzO6Q,5278
|
12
|
+
invrs_opt/parameterization/filter_project.py,sha256=7Jb8JVENmBTdx3-XmI-VRm4aMjxg_Wtin8tMKKKxWvQ,4309
|
13
|
+
invrs_opt/parameterization/gaussian_levelset.py,sha256=Uagx7k69SWmass0YirD5JN8O4QDbwwKTBBjRfkIXvv8,24793
|
14
|
+
invrs_opt/parameterization/pixel.py,sha256=CCSuWF_bDebwmwTG33vmZwmvZoDJXYAR1kCrz4KnLt8,1607
|
15
|
+
invrs_opt/parameterization/transforms.py,sha256=8GzaIsUuuXvMCLiqAEEfxmi9qE9KqHzbuTj_m0GjH3w,8216
|
16
|
+
invrs_opt-0.9.1.dist-info/LICENSE,sha256=IMF9i4xIpgCADf0U-V1cuf9HBmqWQd3qtI3FSuyW4zE,26526
|
17
|
+
invrs_opt-0.9.1.dist-info/METADATA,sha256=MBVhAw_cmI7gL7tYh4UuNdKqeujiecIuemCHr27T_R8,32633
|
18
|
+
invrs_opt-0.9.1.dist-info/WHEEL,sha256=HiCZjzuy6Dw0hdX5R3LCFPDmFS4BWl8H-8W39XfmgX4,91
|
19
|
+
invrs_opt-0.9.1.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
|
20
|
+
invrs_opt-0.9.1.dist-info/RECORD,,
|
invrs_opt-0.8.1.dist-info/RECORD
DELETED
@@ -1,20 +0,0 @@
|
|
1
|
-
invrs_opt/__init__.py,sha256=tjbabqQKp5y1I5eqDa7onsy7E58P9nTk_Ifn12oMfZM,585
|
2
|
-
invrs_opt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
|
-
invrs_opt/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
4
|
-
invrs_opt/experimental/client.py,sha256=t4XxnditYbM9DWZeyBPj0Sa2acvkikT0ybhUdmH2r-Y,4852
|
5
|
-
invrs_opt/experimental/labels.py,sha256=dQDAMPyFMV6mXnMy295z8Ap205DRdVzysXny_Be8FmY,562
|
6
|
-
invrs_opt/optimizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
7
|
-
invrs_opt/optimizers/base.py,sha256=-wNwH0J475y8FzB5aLAkc_1602LvYeF4Hddr9OiBkDY,1276
|
8
|
-
invrs_opt/optimizers/lbfgsb.py,sha256=2NcyFllUM5CrjOZLEBocYAlCrkWt1fiRGJrxg05waog,35149
|
9
|
-
invrs_opt/optimizers/wrapped_optax.py,sha256=c4iSsXHF2cJZaZhdxLpyRy26q5IxBTo4-6sY8IOLT38,13259
|
10
|
-
invrs_opt/parameterization/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
11
|
-
invrs_opt/parameterization/base.py,sha256=k2VXZorMnkmaIWWY_G2CvS0NtxYojfvKlPccDGZyGNA,3779
|
12
|
-
invrs_opt/parameterization/filter_project.py,sha256=D2w8xrg34V8ysIbbr1RPvegM5WoLdz8QKUCGi80ieOI,3466
|
13
|
-
invrs_opt/parameterization/gaussian_levelset.py,sha256=DllvpkBpVxuDFQFu941f8v_Xh2D0m9Eh-JxBH5afNHU,25221
|
14
|
-
invrs_opt/parameterization/pixel.py,sha256=yWdGpf0x6Om_-7CYOcHTgNK9UF-XGAPM7GbOBNijnBw,1305
|
15
|
-
invrs_opt/parameterization/transforms.py,sha256=8GzaIsUuuXvMCLiqAEEfxmi9qE9KqHzbuTj_m0GjH3w,8216
|
16
|
-
invrs_opt-0.8.1.dist-info/LICENSE,sha256=IMF9i4xIpgCADf0U-V1cuf9HBmqWQd3qtI3FSuyW4zE,26526
|
17
|
-
invrs_opt-0.8.1.dist-info/METADATA,sha256=4qLsuDLp3pNn5C1TTPnF0qwWps7MeN_Uc3rKsBLbPrk,32633
|
18
|
-
invrs_opt-0.8.1.dist-info/WHEEL,sha256=HiCZjzuy6Dw0hdX5R3LCFPDmFS4BWl8H-8W39XfmgX4,91
|
19
|
-
invrs_opt-0.8.1.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
|
20
|
-
invrs_opt-0.8.1.dist-info/RECORD,,
|
File without changes
|
File without changes
|
File without changes
|