invrs-opt 0.8.1__py3-none-any.whl → 0.9.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
invrs_opt/__init__.py CHANGED
@@ -3,7 +3,7 @@
3
3
  Copyright (c) 2023 The INVRS-IO authors.
4
4
  """
5
5
 
6
- __version__ = "v0.8.1"
6
+ __version__ = "v0.9.0"
7
7
  __author__ = "Martin F. Schubert <mfschubert@gmail.com>"
8
8
 
9
9
  from invrs_opt import parameterization as parameterization
@@ -19,7 +19,7 @@ from totypes import types
19
19
 
20
20
  from invrs_opt.optimizers import base
21
21
  from invrs_opt.parameterization import (
22
- base as parameterization_base,
22
+ base as param_base,
23
23
  filter_project,
24
24
  gaussian_levelset,
25
25
  pixel,
@@ -252,7 +252,7 @@ def levelset_lbfgsb(
252
252
 
253
253
 
254
254
  def parameterized_lbfgsb(
255
- density_parameterization: Optional[parameterization_base.Density2DParameterization],
255
+ density_parameterization: Optional[param_base.Density2DParameterization],
256
256
  penalty: float,
257
257
  maxcor: int = DEFAULT_MAXCOR,
258
258
  line_search_max_steps: int = DEFAULT_LINE_SEARCH_MAX_STEPS,
@@ -296,59 +296,6 @@ def parameterized_lbfgsb(
296
296
  if density_parameterization is None:
297
297
  density_parameterization = pixel.pixel()
298
298
 
299
- def _init_latents(params: PyTree) -> PyTree:
300
- def _leaf_init_latents(leaf: Any) -> Any:
301
- leaf = _clip(leaf)
302
- if not _is_density(leaf) or density_parameterization is None:
303
- return leaf
304
- return density_parameterization.from_density(leaf)
305
-
306
- return tree_util.tree_map(_leaf_init_latents, params, is_leaf=_is_custom_type)
307
-
308
- def _params_from_latents(latent_params: PyTree) -> PyTree:
309
- def _leaf_params_from_latents(leaf: Any) -> Any:
310
- if not _is_parameterized_density(leaf) or density_parameterization is None:
311
- return leaf
312
- return density_parameterization.to_density(leaf)
313
-
314
- return tree_util.tree_map(
315
- _leaf_params_from_latents,
316
- latent_params,
317
- is_leaf=_is_parameterized_density,
318
- )
319
-
320
- def _constraint_loss(latent_params: PyTree) -> jnp.ndarray:
321
- def _constraint_loss_leaf(
322
- params: parameterization_base.ParameterizedDensity2DArrayBase,
323
- ) -> jnp.ndarray:
324
- constraints = density_parameterization.constraints(params)
325
- constraints = tree_util.tree_map(
326
- lambda x: jnp.sum(jnp.maximum(x, 0.0) ** 2),
327
- constraints,
328
- )
329
- return jnp.sum(jnp.asarray(constraints))
330
-
331
- losses = [0.0] + [
332
- _constraint_loss_leaf(p)
333
- for p in tree_util.tree_leaves(
334
- latent_params, is_leaf=_is_parameterized_density
335
- )
336
- if _is_parameterized_density(p)
337
- ]
338
- return penalty * jnp.sum(jnp.asarray(losses))
339
-
340
- def _update_parameterized_densities(latent_params: PyTree, step: int) -> PyTree:
341
- def _update_leaf(leaf: Any) -> Any:
342
- if not _is_parameterized_density(leaf):
343
- return leaf
344
- return density_parameterization.update(leaf, step)
345
-
346
- return tree_util.tree_map(
347
- _update_leaf,
348
- latent_params,
349
- is_leaf=_is_parameterized_density,
350
- )
351
-
352
299
  def init_fn(params: PyTree) -> LbfgsbState:
353
300
  """Initializes the optimization state."""
354
301
 
@@ -368,10 +315,17 @@ def parameterized_lbfgsb(
368
315
  return latent_params, scipy_lbfgsb_state.to_jax()
369
316
 
370
317
  latent_params = _init_latents(params)
371
- latent_params, jax_lbfgsb_state = jax.pure_callback(
372
- _init_state_pure, _example_state(latent_params, maxcor), latent_params
318
+ metadata, latents = param_base.partition_density_metadata(latent_params)
319
+ latents, jax_lbfgsb_state = jax.pure_callback(
320
+ _init_state_pure, _example_state(latents, maxcor), latents
321
+ )
322
+ latent_params = param_base.combine_density_metadata(metadata, latents)
323
+ return (
324
+ 0, # step
325
+ _params_from_latent_params(latent_params), # params
326
+ latent_params, # latent params
327
+ jax_lbfgsb_state, # opt state
373
328
  )
374
- return 0, _params_from_latents(latent_params), latent_params, jax_lbfgsb_state
375
329
 
376
330
  def params_fn(state: LbfgsbState) -> PyTree:
377
331
  """Returns the parameters for the given `state`."""
@@ -403,46 +357,118 @@ def parameterized_lbfgsb(
403
357
  return flat_latent_params, scipy_lbfgsb_state.to_jax()
404
358
 
405
359
  step, _, latent_params, jax_lbfgsb_state = state
406
- _, vjp_fn = jax.vjp(_params_from_latents, latent_params)
407
- (latent_grad,) = vjp_fn(grad)
360
+ metadata, latents = param_base.partition_density_metadata(latent_params)
361
+
362
+ def _params_from_latents(latents: PyTree) -> PyTree:
363
+ latent_params = param_base.combine_density_metadata(metadata, latents)
364
+ return _params_from_latent_params(latent_params)
365
+
366
+ def _constraint_loss_latents(latents: PyTree) -> jnp.ndarray:
367
+ latent_params = param_base.combine_density_metadata(metadata, latents)
368
+ return _constraint_loss(latent_params)
369
+
370
+ _, vjp_fn = jax.vjp(_params_from_latents, latents)
371
+ (latents_grad,) = vjp_fn(grad)
408
372
 
409
373
  if not (
410
- tree_util.tree_structure(latent_grad)
411
- == tree_util.tree_structure(latent_params) # type: ignore[operator]
374
+ tree_util.tree_structure(latents_grad)
375
+ == tree_util.tree_structure(latents) # type: ignore[operator]
412
376
  ):
413
377
  raise ValueError(
414
- f"Tree structure of `latent_grad` was different than expected, got \n"
415
- f"{tree_util.tree_structure(latent_grad)} but expected \n"
416
- f"{tree_util.tree_structure(latent_params)}."
378
+ f"Tree structure of `latents_grad` was different than expected, got \n"
379
+ f"{tree_util.tree_structure(latents_grad)} but expected \n"
380
+ f"{tree_util.tree_structure(latents)}."
417
381
  )
418
382
 
419
383
  (
420
384
  constraint_loss_value,
421
385
  constraint_loss_grad,
422
386
  ) = jax.value_and_grad(
423
- _constraint_loss
424
- )(latent_params)
387
+ _constraint_loss_latents
388
+ )(latents)
425
389
  value += constraint_loss_value
426
- latent_grad = tree_util.tree_map(
427
- lambda a, b: a + b, latent_grad, constraint_loss_grad
390
+ latents_grad = tree_util.tree_map(
391
+ lambda a, b: a + b, latents_grad, constraint_loss_grad
428
392
  )
429
393
 
430
- flat_latent_grad, unflatten_fn = flatten_util.ravel_pytree(
431
- latent_grad
394
+ flat_latents_grad, unflatten_fn = flatten_util.ravel_pytree(
395
+ latents_grad
432
396
  ) # type: ignore[no-untyped-call]
433
397
 
434
- flat_latent_params, jax_lbfgsb_state = jax.pure_callback(
398
+ flat_latents, jax_lbfgsb_state = jax.pure_callback(
435
399
  _update_pure,
436
- (flat_latent_grad, jax_lbfgsb_state),
437
- flat_latent_grad,
400
+ (flat_latents_grad, jax_lbfgsb_state),
401
+ flat_latents_grad,
438
402
  value,
439
403
  jax_lbfgsb_state,
440
404
  )
441
- latent_params = unflatten_fn(flat_latent_params)
405
+ latents = unflatten_fn(flat_latents)
406
+ latent_params = param_base.combine_density_metadata(metadata, latents)
442
407
  latent_params = _update_parameterized_densities(latent_params, step)
443
- params = _params_from_latents(latent_params)
408
+ params = _params_from_latent_params(latent_params)
444
409
  return step + 1, params, latent_params, jax_lbfgsb_state
445
410
 
411
+ # -------------------------------------------------------------------------
412
+ # Functions related to the density parameterization.
413
+ # -------------------------------------------------------------------------
414
+
415
+ def _init_latents(params: PyTree) -> PyTree:
416
+ def _leaf_init_latents(leaf: Any) -> Any:
417
+ leaf = _clip(leaf)
418
+ if not _is_density(leaf) or density_parameterization is None:
419
+ return leaf
420
+ return density_parameterization.from_density(leaf)
421
+
422
+ return tree_util.tree_map(_leaf_init_latents, params, is_leaf=_is_custom_type)
423
+
424
+ def _params_from_latent_params(latent_params: PyTree) -> PyTree:
425
+ def _leaf_params_from_latents(leaf: Any) -> Any:
426
+ if not _is_parameterized_density(leaf) or density_parameterization is None:
427
+ return leaf
428
+ return density_parameterization.to_density(leaf)
429
+
430
+ return tree_util.tree_map(
431
+ _leaf_params_from_latents,
432
+ latent_params,
433
+ is_leaf=_is_parameterized_density,
434
+ )
435
+
436
+ def _update_parameterized_densities(latent_params: PyTree, step: int) -> PyTree:
437
+ def _update_leaf(leaf: Any) -> Any:
438
+ if not _is_parameterized_density(leaf):
439
+ return leaf
440
+ return density_parameterization.update(leaf, step)
441
+
442
+ return tree_util.tree_map(
443
+ _update_leaf,
444
+ latent_params,
445
+ is_leaf=_is_parameterized_density,
446
+ )
447
+
448
+ # -------------------------------------------------------------------------
449
+ # Functions related to the constraints to be minimized.
450
+ # -------------------------------------------------------------------------
451
+
452
+ def _constraint_loss(latent_params: PyTree) -> jnp.ndarray:
453
+ def _constraint_loss_leaf(
454
+ leaf: param_base.ParameterizedDensity2DArray,
455
+ ) -> jnp.ndarray:
456
+ constraints = density_parameterization.constraints(leaf)
457
+ constraints = tree_util.tree_map(
458
+ lambda x: jnp.sum(jnp.maximum(x, 0.0) ** 2),
459
+ constraints,
460
+ )
461
+ return jnp.sum(jnp.asarray(constraints))
462
+
463
+ losses = [0.0] + [
464
+ _constraint_loss_leaf(p)
465
+ for p in tree_util.tree_leaves(
466
+ latent_params, is_leaf=_is_parameterized_density
467
+ )
468
+ if _is_parameterized_density(p)
469
+ ]
470
+ return penalty * jnp.sum(jnp.asarray(losses))
471
+
446
472
  return base.Optimizer(
447
473
  init=init_fn,
448
474
  params=params_fn,
@@ -467,7 +493,7 @@ def _is_density(leaf: Any) -> Any:
467
493
 
468
494
  def _is_parameterized_density(leaf: Any) -> Any:
469
495
  """Return `True` if `leaf` is a parameterized density array."""
470
- return isinstance(leaf, parameterization_base.ParameterizedDensity2DArrayBase)
496
+ return isinstance(leaf, param_base.ParameterizedDensity2DArray)
471
497
 
472
498
 
473
499
  def _is_custom_type(leaf: Any) -> bool:
@@ -13,7 +13,7 @@ from totypes import types
13
13
 
14
14
  from invrs_opt.optimizers import base
15
15
  from invrs_opt.parameterization import (
16
- base as parameterization_base,
16
+ base as param_base,
17
17
  filter_project,
18
18
  gaussian_levelset,
19
19
  pixel,
@@ -158,7 +158,7 @@ def levelset_wrapped_optax(
158
158
 
159
159
  def parameterized_wrapped_optax(
160
160
  opt: optax.GradientTransformation,
161
- density_parameterization: Optional[parameterization_base.Density2DParameterization],
161
+ density_parameterization: Optional[param_base.Density2DParameterization],
162
162
  penalty: float,
163
163
  ) -> base.Optimizer:
164
164
  """Wrapped optax optimizer with specified density parameterization.
@@ -181,11 +181,12 @@ def parameterized_wrapped_optax(
181
181
  def init_fn(params: PyTree) -> WrappedOptaxState:
182
182
  """Initializes the optimization state."""
183
183
  latent_params = _init_latents(params)
184
+ _, latents = param_base.partition_density_metadata(latent_params)
184
185
  return (
185
186
  0, # step
186
- _params_from_latents(latent_params), # params
187
+ _params_from_latent_params(latent_params), # params
187
188
  latent_params, # latent params
188
- opt.init(tree_util.tree_leaves(latent_params)), # opt state
189
+ opt.init(latents), # opt state
189
190
  )
190
191
 
191
192
  def params_fn(state: WrappedOptaxState) -> PyTree:
@@ -204,42 +205,41 @@ def parameterized_wrapped_optax(
204
205
  del value, params
205
206
 
206
207
  step, params, latent_params, opt_state = state
208
+ metadata, latents = param_base.partition_density_metadata(latent_params)
207
209
 
208
- _, vjp_fn = jax.vjp(_params_from_latents, latent_params)
209
- (latent_grad,) = vjp_fn(grad)
210
+ def _params_from_latents(latents: PyTree) -> PyTree:
211
+ latent_params = param_base.combine_density_metadata(metadata, latents)
212
+ return _params_from_latent_params(latent_params)
213
+
214
+ def _constraint_loss_latents(latents: PyTree) -> jnp.ndarray:
215
+ latent_params = param_base.combine_density_metadata(metadata, latents)
216
+ return _constraint_loss(latent_params)
217
+
218
+ _, vjp_fn = jax.vjp(_params_from_latents, latents)
219
+ (latents_grad,) = vjp_fn(grad)
210
220
 
211
221
  if not (
212
- tree_util.tree_structure(latent_grad)
213
- == tree_util.tree_structure(latent_params) # type: ignore[operator]
222
+ tree_util.tree_structure(latents_grad)
223
+ == tree_util.tree_structure(latents) # type: ignore[operator]
214
224
  ):
215
225
  raise ValueError(
216
- f"Tree structure of `latent_grad` was different than expected, got \n"
217
- f"{tree_util.tree_structure(latent_grad)} but expected \n"
218
- f"{tree_util.tree_structure(latent_params)}."
226
+ f"Tree structure of `latents_grad` was different than expected, got \n"
227
+ f"{tree_util.tree_structure(latents_grad)} but expected \n"
228
+ f"{tree_util.tree_structure(latents)}."
219
229
  )
220
230
 
221
- constraint_loss_grad = jax.grad(_constraint_loss)(latent_params)
222
- latent_grad = tree_util.tree_map(
223
- lambda a, b: a + b, latent_grad, constraint_loss_grad
231
+ constraint_loss_grad = jax.grad(_constraint_loss_latents)(latents)
232
+ latents_grad = tree_util.tree_map(
233
+ lambda a, b: a + b, latents_grad, constraint_loss_grad
224
234
  )
225
235
 
226
- updates_leaves, opt_state = opt.update(
227
- updates=tree_util.tree_leaves(latent_grad),
228
- state=opt_state,
229
- params=tree_util.tree_leaves(latent_params),
230
- )
231
- latent_params_leaves = optax.apply_updates(
232
- params=tree_util.tree_leaves(latent_params),
233
- updates=updates_leaves,
234
- )
235
- latent_params = tree_util.tree_unflatten(
236
- treedef=tree_util.tree_structure(latent_params),
237
- leaves=latent_params_leaves,
238
- )
236
+ updates, opt_state = opt.update(latents_grad, state=opt_state, params=latents)
237
+ latents = optax.apply_updates(params=latents, updates=updates)
239
238
 
239
+ latent_params = param_base.combine_density_metadata(metadata, latents)
240
240
  latent_params = _clip(latent_params)
241
241
  latent_params = _update_parameterized_densities(latent_params, step + 1)
242
- params = _params_from_latents(latent_params)
242
+ params = _params_from_latent_params(latent_params)
243
243
  return (step + 1, params, latent_params, opt_state)
244
244
 
245
245
  # -------------------------------------------------------------------------
@@ -255,7 +255,7 @@ def parameterized_wrapped_optax(
255
255
 
256
256
  return tree_util.tree_map(_leaf_init_latents, params, is_leaf=_is_custom_type)
257
257
 
258
- def _params_from_latents(params: PyTree) -> PyTree:
258
+ def _params_from_latent_params(params: PyTree) -> PyTree:
259
259
  def _leaf_params_from_latents(leaf: Any) -> Any:
260
260
  if not _is_parameterized_density(leaf):
261
261
  return leaf
@@ -285,9 +285,9 @@ def parameterized_wrapped_optax(
285
285
 
286
286
  def _constraint_loss(latent_params: PyTree) -> jnp.ndarray:
287
287
  def _constraint_loss_leaf(
288
- params: parameterization_base.ParameterizedDensity2DArrayBase,
288
+ leaf: param_base.ParameterizedDensity2DArray,
289
289
  ) -> jnp.ndarray:
290
- constraints = density_parameterization.constraints(params)
290
+ constraints = density_parameterization.constraints(leaf)
291
291
  constraints = tree_util.tree_map(
292
292
  lambda x: jnp.sum(jnp.maximum(x, 0.0) ** 2),
293
293
  constraints,
@@ -313,7 +313,7 @@ def _is_density(leaf: Any) -> Any:
313
313
 
314
314
  def _is_parameterized_density(leaf: Any) -> Any:
315
315
  """Return `True` if `leaf` is a parameterized density array."""
316
- return isinstance(leaf, parameterization_base.ParameterizedDensity2DArrayBase)
316
+ return isinstance(leaf, param_base.ParameterizedDensity2DArray)
317
317
 
318
318
 
319
319
  def _is_custom_type(leaf: Any) -> bool:
@@ -9,24 +9,74 @@ from typing import Any, Optional, Protocol, Sequence, Tuple
9
9
  import jax.numpy as jnp
10
10
  import numpy as onp
11
11
  from jax import tree_util
12
- from totypes import json_utils, types
12
+ from totypes import json_utils, partition_utils, types
13
13
 
14
14
  Array = jnp.ndarray | onp.ndarray[Any, Any]
15
15
  PyTree = Any
16
16
 
17
17
 
18
- class ParameterizedDensity2DArrayBase:
19
- """Base class for parameterized density arrays."""
18
+ @dataclasses.dataclass
19
+ class ParameterizedDensity2DArray:
20
+ """Stores latents and metadata for a parameterized density array."""
21
+
22
+ latents: "LatentsBase"
23
+ metadata: Optional["MetadataBase"]
24
+
25
+
26
+ class LatentsBase:
27
+ """Base class for latents of a parameterized density array."""
28
+
29
+ pass
30
+
31
+
32
+ class MetadataBase:
33
+ """Base class for metadata of a parameterized density array."""
20
34
 
21
35
  pass
22
36
 
23
37
 
38
+ tree_util.register_dataclass(
39
+ ParameterizedDensity2DArray,
40
+ data_fields=["latents", "metadata"],
41
+ meta_fields=[],
42
+ )
43
+ json_utils.register_custom_type(ParameterizedDensity2DArray)
44
+
45
+
46
+ def partition_density_metadata(tree: PyTree) -> Tuple[PyTree, PyTree]:
47
+ """Splits a pytree with parameterized densities into metadata from latents."""
48
+ metadata, latents = partition_utils.partition(
49
+ tree,
50
+ select_fn=lambda x: isinstance(x, MetadataBase),
51
+ is_leaf=_is_metadata_or_none,
52
+ )
53
+ return metadata, latents
54
+
55
+
56
+ def combine_density_metadata(metadata: PyTree, latents: PyTree) -> PyTree:
57
+ """Combines pytrees containing metadata and latents."""
58
+ return partition_utils.combine(metadata, latents, is_leaf=_is_metadata_or_none)
59
+
60
+
61
+ def _is_metadata_or_none(leaf: Any) -> bool:
62
+ """Return `True` if `leaf` is `None` or density metadata."""
63
+ return leaf is None or isinstance(leaf, MetadataBase)
64
+
65
+
66
+ @dataclasses.dataclass
67
+ class Density2DParameterization:
68
+ """Stores `(from_density, to_density, constraints, update)` function triple."""
69
+
70
+ from_density: "FromDensityFn"
71
+ to_density: "ToDensityFn"
72
+ constraints: "ConstraintsFn"
73
+ update: "UpdateFn"
74
+
75
+
24
76
  class FromDensityFn(Protocol):
25
77
  """Generate the latent representation of a density array."""
26
78
 
27
- def __call__(
28
- self, density: types.Density2DArray
29
- ) -> ParameterizedDensity2DArrayBase:
79
+ def __call__(self, density: types.Density2DArray) -> ParameterizedDensity2DArray:
30
80
  ...
31
81
 
32
82
 
@@ -51,16 +101,6 @@ class UpdateFn(Protocol):
51
101
  ...
52
102
 
53
103
 
54
- @dataclasses.dataclass
55
- class Density2DParameterization:
56
- """Stores `(from_density, to_density, constraints)` function triple."""
57
-
58
- from_density: FromDensityFn
59
- to_density: ToDensityFn
60
- constraints: ConstraintsFn
61
- update: UpdateFn
62
-
63
-
64
104
  @dataclasses.dataclass
65
105
  class Density2DMetadata:
66
106
  """Stores the metadata of a `Density2DArray`."""
@@ -78,6 +118,12 @@ class Density2DMetadata:
78
118
  self.periodic = tuple(self.periodic)
79
119
  self.symmetries = tuple(self.symmetries)
80
120
 
121
+ @classmethod
122
+ def from_density(self, density: types.Density2DArray) -> "Density2DMetadata":
123
+ density_metadata_dict = dataclasses.asdict(density)
124
+ del density_metadata_dict["array"]
125
+ return Density2DMetadata(**density_metadata_dict)
126
+
81
127
 
82
128
  def _flatten_density_2d_metadata(
83
129
  metadata: Density2DMetadata,
@@ -13,25 +13,53 @@ from invrs_opt.parameterization import base, transforms
13
13
 
14
14
 
15
15
  @dataclasses.dataclass
16
- class FilterAndProjectParams(base.ParameterizedDensity2DArrayBase):
17
- """Stores the latent parameters of the pixel parameterization.
16
+ class FilterProjectParams(base.ParameterizedDensity2DArray):
17
+ """Stores parameters for the filter-project parameterization."""
18
18
 
19
- Attributes:
19
+ latents: "FilterProjectLatents"
20
+ metadata: "FilterProjectMetadata"
21
+
22
+
23
+ @dataclasses.dataclass
24
+ class FilterProjectLatents(base.LatentsBase):
25
+ """Stores latent parameters for the filter-project parameterization.
26
+
27
+ Attributes:s
20
28
  latent_density: The latent variable from which the density is obtained.
21
- beta: Determines the sharpness of the thresholding operation.
22
29
  """
23
30
 
24
31
  latent_density: types.Density2DArray
32
+
33
+
34
+ @dataclasses.dataclass
35
+ class FilterProjectMetadata(base.MetadataBase):
36
+ """Stores metadata for the filter-project parameterization.
37
+
38
+ Attributes:
39
+ beta: Determines the sharpness of the thresholding operation.
40
+ """
41
+
25
42
  beta: float
26
43
 
27
44
 
28
45
  tree_util.register_dataclass(
29
- FilterAndProjectParams,
46
+ FilterProjectParams,
47
+ data_fields=["latents", "metadata"],
48
+ meta_fields=[],
49
+ )
50
+ tree_util.register_dataclass(
51
+ FilterProjectLatents,
30
52
  data_fields=["latent_density"],
53
+ meta_fields=[],
54
+ )
55
+ tree_util.register_dataclass(
56
+ FilterProjectMetadata,
57
+ data_fields=[],
31
58
  meta_fields=["beta"],
32
59
  )
33
-
34
- json_utils.register_custom_type(FilterAndProjectParams)
60
+ json_utils.register_custom_type(FilterProjectParams)
61
+ json_utils.register_custom_type(FilterProjectLatents)
62
+ json_utils.register_custom_type(FilterProjectMetadata)
35
63
 
36
64
 
37
65
  def filter_project(beta: float) -> base.Density2DParameterization:
@@ -55,24 +83,26 @@ def filter_project(beta: float) -> base.Density2DParameterization:
55
83
  The `Density2DParameterization`.
56
84
  """
57
85
 
58
- def from_density_fn(density: types.Density2DArray) -> FilterAndProjectParams:
86
+ def from_density_fn(density: types.Density2DArray) -> FilterProjectParams:
59
87
  """Return latent parameters for the given `density`."""
60
88
  array = transforms.normalized_array_from_density(density)
61
89
  array = jnp.clip(array, -1, 1)
62
90
  array *= jnp.tanh(beta)
63
91
  latent_array = jnp.arctanh(array) / beta
64
92
  latent_array = transforms.rescale_array_for_density(latent_array, density)
65
- return FilterAndProjectParams(
66
- latent_density=dataclasses.replace(density, array=latent_array),
67
- beta=beta,
93
+ latent_density = density = dataclasses.replace(density, array=latent_array)
94
+ return FilterProjectParams(
95
+ latents=FilterProjectLatents(latent_density=latent_density),
96
+ metadata=FilterProjectMetadata(beta=beta),
68
97
  )
69
98
 
70
- def to_density_fn(params: FilterAndProjectParams) -> types.Density2DArray:
99
+ def to_density_fn(params: FilterProjectParams) -> types.Density2DArray:
71
100
  """Return a density from the latent parameters."""
72
- transformed = types.symmetrize_density(params.latent_density)
73
- transformed = transforms.density_gaussian_filter_and_tanh(
74
- transformed, beta=params.beta
75
- )
101
+ latent_density = params.latents.latent_density
102
+ beta = params.metadata.beta
103
+
104
+ transformed = types.symmetrize_density(latent_density)
105
+ transformed = transforms.density_gaussian_filter_and_tanh(transformed, beta)
76
106
  # Scale to ensure that the full valid range of the density array is reachable.
77
107
  mid_value = (transformed.lower_bound + transformed.upper_bound) / 2
78
108
  transformed = tree_util.tree_map(
@@ -80,12 +110,12 @@ def filter_project(beta: float) -> base.Density2DParameterization:
80
110
  )
81
111
  return transforms.apply_fixed_pixels(transformed)
82
112
 
83
- def constraints_fn(params: FilterAndProjectParams) -> jnp.ndarray:
113
+ def constraints_fn(params: FilterProjectParams) -> jnp.ndarray:
84
114
  """Computes constraints associated with the params."""
85
115
  del params
86
116
  return jnp.asarray(0.0)
87
117
 
88
- def update_fn(params: FilterAndProjectParams, step: int) -> FilterAndProjectParams:
118
+ def update_fn(params: FilterProjectParams, step: int) -> FilterProjectParams:
89
119
  """Perform updates to `params` required for the given `step`."""
90
120
  del step
91
121
  return params
@@ -29,12 +29,30 @@ DEFAULT_INIT_OPTIMIZER: optax.GradientTransformation = optax.adam(1e-1)
29
29
 
30
30
 
31
31
  @dataclasses.dataclass
32
- class GaussianLevelsetParams(base.ParameterizedDensity2DArrayBase):
33
- """Parameters of a density represented by a Gaussian levelset.
32
+ class GaussianLevelsetParams(base.ParameterizedDensity2DArray):
33
+ """Stores parameters for the Gaussian levelset parameterization."""
34
+
35
+ latents: "GaussianLevelsetLatents"
36
+ metadata: "GaussianLevelsetMetadata"
37
+
38
+
39
+ @dataclasses.dataclass
40
+ class GaussianLevelsetLatents(base.LatentsBase):
41
+ """Stores latent parameters for the Gaussian levelset parameterization.
34
42
 
35
43
  Attributes:
36
44
  amplitude: Array giving the amplitude of the Gaussian basis function at
37
45
  levelset control points.
46
+ """
47
+
48
+ amplitude: jnp.ndarray
49
+
50
+
51
+ @dataclasses.dataclass
52
+ class GaussianLevelsetMetadata(base.MetadataBase):
53
+ """Stores metadata for the Gaussian levelset parameterization.
54
+
55
+ Attributes:
38
56
  length_scale_spacing_factor: The number of levelset control points per unit of
39
57
  minimum length scale (mean of density minimum width and minimum spacing).
40
58
  length_scale_fwhm_factor: The ratio of Gaussian full-width at half-maximum to
@@ -45,7 +63,6 @@ class GaussianLevelsetParams(base.ParameterizedDensity2DArrayBase):
45
63
  density_metadata: Metadata for the density array obtained from the parameters.
46
64
  """
47
65
 
48
- amplitude: jnp.ndarray
49
66
  length_scale_spacing_factor: float
50
67
  length_scale_fwhm_factor: float
51
68
  smoothing_factor: int
@@ -55,68 +72,31 @@ class GaussianLevelsetParams(base.ParameterizedDensity2DArrayBase):
55
72
  def __post_init__(self) -> None:
56
73
  self.density_shape = tuple(self.density_shape)
57
74
 
58
- def example_density(self) -> types.Density2DArray:
59
- """Returns an example density with appropriate shape and metadata."""
60
- with jax.ensure_compile_time_eval():
61
- return types.Density2DArray(
62
- array=jnp.zeros(self.density_shape),
63
- **dataclasses.asdict(self.density_metadata),
64
- )
65
-
66
-
67
- _GaussianLevelsetParamsAux = Tuple[
68
- float, float, int, Tuple[int, ...], tree_util.PyTreeDef
69
- ]
70
-
71
75
 
72
- def _flatten_gaussian_levelset_params(
73
- params: GaussianLevelsetParams,
74
- ) -> Tuple[Tuple[jnp.ndarray], _GaussianLevelsetParamsAux]:
75
- _, flat_metadata = tree_util.tree_flatten(params.density_metadata)
76
- return (
77
- (params.amplitude,),
78
- (
79
- params.length_scale_spacing_factor,
80
- params.length_scale_fwhm_factor,
81
- params.smoothing_factor,
82
- params.density_shape,
83
- flat_metadata,
84
- ),
85
- )
86
-
87
-
88
- def _unflatten_gaussian_levelset_params(
89
- aux: _GaussianLevelsetParamsAux,
90
- children: Tuple[jnp.ndarray],
91
- ) -> GaussianLevelsetParams:
92
- (amplitude,) = children
93
- (
94
- length_scale_spacing_factor,
95
- length_scale_fwhm_factor,
96
- smoothing_factor,
97
- density_shape,
98
- flat_metadata,
99
- ) = aux
100
-
101
- density_metadata = tree_util.tree_unflatten(flat_metadata, ())
102
- return GaussianLevelsetParams(
103
- amplitude=amplitude,
104
- length_scale_spacing_factor=length_scale_spacing_factor,
105
- length_scale_fwhm_factor=length_scale_fwhm_factor,
106
- smoothing_factor=smoothing_factor,
107
- density_shape=tuple(density_shape),
108
- density_metadata=density_metadata,
109
- )
110
-
111
-
112
- tree_util.register_pytree_node(
76
+ tree_util.register_dataclass(
113
77
  GaussianLevelsetParams,
114
- flatten_func=_flatten_gaussian_levelset_params,
115
- unflatten_func=_unflatten_gaussian_levelset_params,
78
+ data_fields=["latents", "metadata"],
79
+ meta_fields=[],
80
+ )
81
+ tree_util.register_dataclass(
82
+ GaussianLevelsetLatents,
83
+ data_fields=["amplitude"],
84
+ meta_fields=[],
85
+ )
86
+ tree_util.register_dataclass(
87
+ GaussianLevelsetMetadata,
88
+ data_fields=[
89
+ "length_scale_spacing_factor",
90
+ "length_scale_fwhm_factor",
91
+ "smoothing_factor",
92
+ "density_shape",
93
+ "density_metadata",
94
+ ],
95
+ meta_fields=[],
116
96
  )
117
-
118
-
119
97
  json_utils.register_custom_type(GaussianLevelsetParams)
98
+ json_utils.register_custom_type(GaussianLevelsetLatents)
99
+ json_utils.register_custom_type(GaussianLevelsetMetadata)
120
100
 
121
101
 
122
102
  def gaussian_levelset(
@@ -187,23 +167,21 @@ def gaussian_levelset(
187
167
  pad_width += ((0, 0),) if density.periodic[1] else ((1, 1),)
188
168
  amplitude = jnp.pad(amplitude, pad_width, mode="edge")
189
169
 
190
- density_metadata_dict = dataclasses.asdict(density)
191
- del density_metadata_dict["array"]
192
- density_metadata = base.Density2DMetadata(**density_metadata_dict)
193
- params = GaussianLevelsetParams(
194
- amplitude=amplitude,
170
+ latents = GaussianLevelsetLatents(amplitude=amplitude)
171
+ metadata = GaussianLevelsetMetadata(
195
172
  length_scale_spacing_factor=length_scale_spacing_factor,
196
173
  length_scale_fwhm_factor=length_scale_fwhm_factor,
197
174
  smoothing_factor=smoothing_factor,
198
175
  density_shape=density.shape,
199
- density_metadata=density_metadata,
176
+ density_metadata=base.Density2DMetadata.from_density(density),
200
177
  )
201
178
 
202
179
  def step_fn(
203
180
  _: int,
204
181
  params_and_state: Tuple[PyTree, PyTree],
205
182
  ) -> Tuple[PyTree, PyTree]:
206
- def loss_fn(params: GaussianLevelsetParams) -> jnp.ndarray:
183
+ def loss_fn(latents: GaussianLevelsetLatents) -> jnp.ndarray:
184
+ params = GaussianLevelsetParams(latents, metadata=metadata)
207
185
  density_from_params = to_density_fn(params, mask_gradient=False)
208
186
  return jnp.mean((density_from_params.array - target_array) ** 2)
209
187
 
@@ -213,13 +191,14 @@ def gaussian_levelset(
213
191
  params = optax.apply_updates(params, updates)
214
192
  return params, state
215
193
 
216
- state = init_optimizer.init(params)
217
- params, _ = jax.lax.fori_loop(
218
- 0, init_steps, body_fun=step_fn, init_val=(params, state)
194
+ state = init_optimizer.init(latents)
195
+ latents, _ = jax.lax.fori_loop(
196
+ 0, init_steps, body_fun=step_fn, init_val=(latents, state)
219
197
  )
220
198
 
221
- maxval = jnp.amax(jnp.abs(params.amplitude), axis=(-2, -1), keepdims=True)
222
- return dataclasses.replace(params, amplitude=params.amplitude / maxval)
199
+ maxval = jnp.amax(jnp.abs(latents.amplitude), axis=(-2, -1), keepdims=True)
200
+ latents = dataclasses.replace(latents, amplitude=latents.amplitude / maxval)
201
+ return GaussianLevelsetParams(latents=latents, metadata=metadata)
223
202
 
224
203
  def to_density_fn(
225
204
  params: GaussianLevelsetParams,
@@ -228,7 +207,7 @@ def gaussian_levelset(
228
207
  """Return a density from the latent parameters."""
229
208
  array = _to_array(params, mask_gradient=mask_gradient, pad_pixels=0)
230
209
 
231
- example_density = params.example_density()
210
+ example_density = _example_density(params)
232
211
  lb = example_density.lower_bound
233
212
  ub = example_density.upper_bound
234
213
  array = lb + array * (ub - lb)
@@ -263,7 +242,7 @@ def gaussian_levelset(
263
242
  )
264
243
 
265
244
  # Normalize constraints to make them (somewhat) resolution-independent.
266
- example_density = params.example_density()
245
+ example_density = _example_density(params)
267
246
  length_scale = 0.5 * (
268
247
  example_density.minimum_spacing + example_density.minimum_width
269
248
  )
@@ -287,6 +266,15 @@ def gaussian_levelset(
287
266
  # -----------------------------------------------------------------------------
288
267
 
289
268
 
269
+ def _example_density(params: GaussianLevelsetParams) -> types.Density2DArray:
270
+ """Returns an example density with appropriate shape and metadata."""
271
+ with jax.ensure_compile_time_eval():
272
+ return types.Density2DArray(
273
+ array=jnp.zeros(params.metadata.density_shape),
274
+ **dataclasses.asdict(params.metadata.density_metadata),
275
+ )
276
+
277
+
290
278
  def _to_array(
291
279
  params: GaussianLevelsetParams,
292
280
  mask_gradient: bool,
@@ -308,7 +296,7 @@ def _to_array(
308
296
  Returns:
309
297
  The array.
310
298
  """
311
- example_density = params.example_density()
299
+ example_density = _example_density(params)
312
300
  periodic: Tuple[bool, bool] = example_density.periodic
313
301
  phi = _phi_from_params(
314
302
  params=params,
@@ -319,7 +307,7 @@ def _to_array(
319
307
  periodic=periodic,
320
308
  mask_gradient=mask_gradient,
321
309
  )
322
- return _downsample_spatial_dims(array, params.smoothing_factor)
310
+ return _downsample_spatial_dims(array, params.metadata.smoothing_factor)
323
311
 
324
312
 
325
313
  def _phi_from_params(
@@ -337,32 +325,35 @@ def _phi_from_params(
337
325
  The levelset array `phi`.
338
326
  """
339
327
  with jax.ensure_compile_time_eval():
340
- example_density = params.example_density()
328
+ example_density = _example_density(params)
341
329
  length_scale = 0.5 * (
342
330
  example_density.minimum_width + example_density.minimum_spacing
343
331
  )
344
- fwhm = length_scale * params.length_scale_fwhm_factor
332
+ fwhm = length_scale * params.metadata.length_scale_fwhm_factor
345
333
  sigma = fwhm / (2 * jnp.sqrt(2 * jnp.log(2)))
346
334
 
335
+ s_factor = params.metadata.smoothing_factor
347
336
  highres_i = (
348
337
  0.5
349
338
  + jnp.arange(
350
- params.smoothing_factor * (-pad_pixels),
351
- params.smoothing_factor * (pad_pixels + example_density.shape[-2]),
339
+ s_factor * (-pad_pixels),
340
+ s_factor * (pad_pixels + example_density.shape[-2]),
352
341
  )
353
- ) / params.smoothing_factor
342
+ ) / s_factor
354
343
  highres_j = (
355
344
  0.5
356
345
  + jnp.arange(
357
- params.smoothing_factor * (-pad_pixels),
358
- params.smoothing_factor * (pad_pixels + example_density.shape[-1]),
346
+ s_factor * (-pad_pixels),
347
+ s_factor * (pad_pixels + example_density.shape[-1]),
359
348
  )
360
- ) / params.smoothing_factor
349
+ ) / s_factor
361
350
 
362
351
  # Coordinates for the control points of the Gaussian radial basis functions.
363
352
  levelset_i, levelset_j = _control_point_coords(
364
- density_shape=params.density_shape[-2:], # type: ignore[arg-type]
365
- levelset_shape=params.amplitude.shape[-2:], # type: ignore[arg-type]
353
+ density_shape=params.metadata.density_shape[-2:], # type: ignore[arg-type]
354
+ levelset_shape=(
355
+ params.latents.amplitude.shape[-2:] # type: ignore[arg-type]
356
+ ),
366
357
  periodic=example_density.periodic,
367
358
  )
368
359
 
@@ -391,7 +382,7 @@ def _phi_from_params(
391
382
  levelset_i = levelset_i.flatten()
392
383
  levelset_j = levelset_j.flatten()
393
384
 
394
- amplitude = params.amplitude
385
+ amplitude = params.latents.amplitude
395
386
  if example_density.periodic[0]:
396
387
  amplitude = jnp.concat([amplitude] * 3, axis=-2)
397
388
  if example_density.periodic[1]:
@@ -410,8 +401,8 @@ def _phi_from_params(
410
401
  _, array = jax.lax.scan(scan_fn, (), xs=highres_i)
411
402
  array = jnp.moveaxis(array, 0, -2)
412
403
 
413
- assert array.shape[-2] % params.smoothing_factor == 0
414
- assert array.shape[-1] % params.smoothing_factor == 0
404
+ assert array.shape[-2] % s_factor == 0
405
+ assert array.shape[-1] % s_factor == 0
415
406
  array = symmetry.symmetrize(array, tuple(example_density.symmetries))
416
407
  return array
417
408
 
@@ -443,7 +434,7 @@ def _fixed_pixel_constraint(
443
434
  """
444
435
  array = _to_array(params, mask_gradient=mask_gradient, pad_pixels=pad_pixels)
445
436
 
446
- example_density = params.example_density()
437
+ example_density = _example_density(params)
447
438
  fixed_solid = jnp.zeros(example_density.shape[-2:], dtype=bool)
448
439
  fixed_void = jnp.zeros(example_density.shape[-2:], dtype=bool)
449
440
  if example_density.fixed_solid is not None:
@@ -491,9 +482,9 @@ def _levelset_constraints(
491
482
  beyond the boundaries of the parameterized density.
492
483
 
493
484
  Returns:
494
- The minimum length scale and minimum curvature constraint arrays.
485
+ The minimum length scale and minimum curvature constraint arrays.s
495
486
  """
496
- example_density = params.example_density()
487
+ example_density = _example_density(params)
497
488
  minimum_length_scale = 0.5 * (
498
489
  example_density.minimum_width + example_density.minimum_spacing
499
490
  )
@@ -512,9 +503,10 @@ def _levelset_constraints(
512
503
  )
513
504
 
514
505
  # Downsample so that constraints shape matches the density shape.
506
+ factor = params.metadata.smoothing_factor
515
507
  return (
516
- _downsample_spatial_dims(length_scale_constraint, params.smoothing_factor),
517
- _downsample_spatial_dims(curvature_constraint, params.smoothing_factor),
508
+ _downsample_spatial_dims(length_scale_constraint, factor),
509
+ _downsample_spatial_dims(curvature_constraint, factor),
518
510
  )
519
511
 
520
512
 
@@ -529,7 +521,7 @@ def _phi_derivatives_and_inverse_radius(
529
521
  pad_pixels=pad_pixels,
530
522
  )
531
523
 
532
- d = 1 / params.smoothing_factor
524
+ d = 1 / params.metadata.smoothing_factor
533
525
  phi_x, phi_y = jnp.gradient(phi, d, axis=(-2, -1))
534
526
  phi_xx, phi_yx = jnp.gradient(phi_x, d, axis=(-2, -1))
535
527
  phi_xy, phi_yy = jnp.gradient(phi_y, d, axis=(-2, -1))
@@ -13,26 +13,40 @@ from invrs_opt.parameterization import base
13
13
 
14
14
 
15
15
  @dataclasses.dataclass
16
- class PixelParams(base.ParameterizedDensity2DArrayBase):
17
- """Stores latent parameters of the direct pixel parameterization."""
16
+ class PixelParams(base.ParameterizedDensity2DArray):
17
+ latents: "PixelLatents"
18
+ metadata: None = None
18
19
 
19
- density: types.Density2DArray
20
20
 
21
+ @dataclasses.dataclass
22
+ class PixelLatents(base.LatentsBase):
23
+ """Stores latent parameters for the direct pixel parameterization."""
21
24
 
22
- tree_util.register_dataclass(PixelParams, data_fields=["density"], meta_fields=[])
25
+ density: types.Density2DArray
23
26
 
24
27
 
28
+ tree_util.register_dataclass(
29
+ PixelParams,
30
+ data_fields=["latents"],
31
+ meta_fields=[],
32
+ )
33
+ tree_util.register_dataclass(
34
+ PixelLatents,
35
+ data_fields=["density"],
36
+ meta_fields=[],
37
+ )
25
38
  json_utils.register_custom_type(PixelParams)
39
+ json_utils.register_custom_type(PixelLatents)
26
40
 
27
41
 
28
42
  def pixel() -> base.Density2DParameterization:
29
43
  """Return the direct pixel parameterization."""
30
44
 
31
45
  def from_density_fn(density: types.Density2DArray) -> PixelParams:
32
- return PixelParams(density=density)
46
+ return PixelParams(latents=PixelLatents(density=density))
33
47
 
34
48
  def to_density_fn(params: PixelParams) -> types.Density2DArray:
35
- return params.density
49
+ return params.latents.density
36
50
 
37
51
  def constraints_fn(params: PixelParams) -> jnp.ndarray:
38
52
  del params
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: invrs_opt
3
- Version: 0.8.1
3
+ Version: 0.9.0
4
4
  Summary: Algorithms for inverse design
5
5
  Author-email: "Martin F. Schubert" <mfschubert@gmail.com>
6
6
  Maintainer-email: "Martin F. Schubert" <mfschubert@gmail.com>
@@ -533,7 +533,7 @@ Requires-Dist: pytest-cov; extra == "tests"
533
533
  Requires-Dist: pytest-subtests; extra == "tests"
534
534
 
535
535
  # invrs-opt - Optimization algorithms for inverse design
536
- `v0.8.1`
536
+ `v0.9.0`
537
537
 
538
538
  ## Overview
539
539
 
@@ -0,0 +1,20 @@
1
+ invrs_opt/__init__.py,sha256=LkjQEBq5HYuP2WgS265RPBoIe48tcG4dxGfMxPOXa68,585
2
+ invrs_opt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
+ invrs_opt/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
+ invrs_opt/experimental/client.py,sha256=t4XxnditYbM9DWZeyBPj0Sa2acvkikT0ybhUdmH2r-Y,4852
5
+ invrs_opt/experimental/labels.py,sha256=dQDAMPyFMV6mXnMy295z8Ap205DRdVzysXny_Be8FmY,562
6
+ invrs_opt/optimizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
7
+ invrs_opt/optimizers/base.py,sha256=-wNwH0J475y8FzB5aLAkc_1602LvYeF4Hddr9OiBkDY,1276
8
+ invrs_opt/optimizers/lbfgsb.py,sha256=8BPiEAqececL-zLnqrgN0CogGDkAd1tyAGndUB-kahc,36349
9
+ invrs_opt/optimizers/wrapped_optax.py,sha256=VXdCteT2kumqhP81l3p6QiEqwBffoUuJ3UjrAyX5ToA,13468
10
+ invrs_opt/parameterization/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
+ invrs_opt/parameterization/base.py,sha256=BObzbz6efT2nBjib0_5BSdkCmFi2f0mcZ9VJYpDzO6Q,5278
12
+ invrs_opt/parameterization/filter_project.py,sha256=7Jb8JVENmBTdx3-XmI-VRm4aMjxg_Wtin8tMKKKxWvQ,4309
13
+ invrs_opt/parameterization/gaussian_levelset.py,sha256=N10UeezN4LdrlACxTEfJgJfhqh79zhUAY8XiSYvFYRY,24865
14
+ invrs_opt/parameterization/pixel.py,sha256=CCSuWF_bDebwmwTG33vmZwmvZoDJXYAR1kCrz4KnLt8,1607
15
+ invrs_opt/parameterization/transforms.py,sha256=8GzaIsUuuXvMCLiqAEEfxmi9qE9KqHzbuTj_m0GjH3w,8216
16
+ invrs_opt-0.9.0.dist-info/LICENSE,sha256=IMF9i4xIpgCADf0U-V1cuf9HBmqWQd3qtI3FSuyW4zE,26526
17
+ invrs_opt-0.9.0.dist-info/METADATA,sha256=qS0kFxxwAsfBW1igDyJHel8xFDsQQWCiUh7f1QRRGEs,32633
18
+ invrs_opt-0.9.0.dist-info/WHEEL,sha256=HiCZjzuy6Dw0hdX5R3LCFPDmFS4BWl8H-8W39XfmgX4,91
19
+ invrs_opt-0.9.0.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
20
+ invrs_opt-0.9.0.dist-info/RECORD,,
@@ -1,20 +0,0 @@
1
- invrs_opt/__init__.py,sha256=tjbabqQKp5y1I5eqDa7onsy7E58P9nTk_Ifn12oMfZM,585
2
- invrs_opt/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
- invrs_opt/experimental/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
4
- invrs_opt/experimental/client.py,sha256=t4XxnditYbM9DWZeyBPj0Sa2acvkikT0ybhUdmH2r-Y,4852
5
- invrs_opt/experimental/labels.py,sha256=dQDAMPyFMV6mXnMy295z8Ap205DRdVzysXny_Be8FmY,562
6
- invrs_opt/optimizers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
7
- invrs_opt/optimizers/base.py,sha256=-wNwH0J475y8FzB5aLAkc_1602LvYeF4Hddr9OiBkDY,1276
8
- invrs_opt/optimizers/lbfgsb.py,sha256=2NcyFllUM5CrjOZLEBocYAlCrkWt1fiRGJrxg05waog,35149
9
- invrs_opt/optimizers/wrapped_optax.py,sha256=c4iSsXHF2cJZaZhdxLpyRy26q5IxBTo4-6sY8IOLT38,13259
10
- invrs_opt/parameterization/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
11
- invrs_opt/parameterization/base.py,sha256=k2VXZorMnkmaIWWY_G2CvS0NtxYojfvKlPccDGZyGNA,3779
12
- invrs_opt/parameterization/filter_project.py,sha256=D2w8xrg34V8ysIbbr1RPvegM5WoLdz8QKUCGi80ieOI,3466
13
- invrs_opt/parameterization/gaussian_levelset.py,sha256=DllvpkBpVxuDFQFu941f8v_Xh2D0m9Eh-JxBH5afNHU,25221
14
- invrs_opt/parameterization/pixel.py,sha256=yWdGpf0x6Om_-7CYOcHTgNK9UF-XGAPM7GbOBNijnBw,1305
15
- invrs_opt/parameterization/transforms.py,sha256=8GzaIsUuuXvMCLiqAEEfxmi9qE9KqHzbuTj_m0GjH3w,8216
16
- invrs_opt-0.8.1.dist-info/LICENSE,sha256=IMF9i4xIpgCADf0U-V1cuf9HBmqWQd3qtI3FSuyW4zE,26526
17
- invrs_opt-0.8.1.dist-info/METADATA,sha256=4qLsuDLp3pNn5C1TTPnF0qwWps7MeN_Uc3rKsBLbPrk,32633
18
- invrs_opt-0.8.1.dist-info/WHEEL,sha256=HiCZjzuy6Dw0hdX5R3LCFPDmFS4BWl8H-8W39XfmgX4,91
19
- invrs_opt-0.8.1.dist-info/top_level.txt,sha256=hOziS2uJ_NgwaW9yhtOfeuYnm1X7vobPBcp_6eVWKfM,10
20
- invrs_opt-0.8.1.dist-info/RECORD,,