flowMC 0.3.0__tar.gz → 0.3.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33) hide show
  1. {flowMC-0.3.0/src/flowMC.egg-info → flowmc-0.3.2}/PKG-INFO +1 -1
  2. {flowMC-0.3.0 → flowmc-0.3.2}/setup.cfg +1 -1
  3. {flowMC-0.3.0 → flowmc-0.3.2}/src/flowMC/Sampler.py +17 -9
  4. {flowMC-0.3.0 → flowmc-0.3.2}/src/flowMC/nfmodel/base.py +18 -11
  5. {flowMC-0.3.0 → flowmc-0.3.2}/src/flowMC/nfmodel/common.py +2 -7
  6. {flowMC-0.3.0 → flowmc-0.3.2}/src/flowMC/nfmodel/realNVP.py +7 -5
  7. {flowMC-0.3.0 → flowmc-0.3.2}/src/flowMC/nfmodel/rqSpline.py +7 -5
  8. {flowMC-0.3.0 → flowmc-0.3.2}/src/flowMC/proposal/Gaussian_random_walk.py +3 -1
  9. {flowMC-0.3.0 → flowmc-0.3.2}/src/flowMC/proposal/HMC.py +3 -1
  10. {flowMC-0.3.0 → flowmc-0.3.2}/src/flowMC/proposal/MALA.py +4 -2
  11. {flowMC-0.3.0 → flowmc-0.3.2}/src/flowMC/proposal/NF_proposal.py +5 -3
  12. {flowMC-0.3.0 → flowmc-0.3.2}/src/flowMC/proposal/base.py +2 -1
  13. {flowMC-0.3.0 → flowmc-0.3.2}/src/flowMC/proposal/flowHMC.py +7 -6
  14. {flowMC-0.3.0 → flowmc-0.3.2}/src/flowMC/strategy/base.py +5 -1
  15. {flowMC-0.3.0 → flowmc-0.3.2}/src/flowMC/strategy/global_tuning.py +24 -11
  16. flowmc-0.3.2/src/flowMC/strategy/optimization.py +120 -0
  17. {flowMC-0.3.0 → flowmc-0.3.2}/src/flowMC/utils/PythonFunctionWrap.py +2 -1
  18. {flowMC-0.3.0 → flowmc-0.3.2}/src/flowMC/utils/postprocessing.py +3 -2
  19. {flowMC-0.3.0 → flowmc-0.3.2/src/flowMC.egg-info}/PKG-INFO +1 -1
  20. {flowMC-0.3.0 → flowmc-0.3.2}/src/flowMC.egg-info/SOURCES.txt +1 -0
  21. {flowMC-0.3.0 → flowmc-0.3.2}/LICENSE +0 -0
  22. {flowMC-0.3.0 → flowmc-0.3.2}/README.md +0 -0
  23. {flowMC-0.3.0 → flowmc-0.3.2}/pyproject.toml +0 -0
  24. {flowMC-0.3.0 → flowmc-0.3.2}/src/flowMC/__init__.py +0 -0
  25. {flowMC-0.3.0 → flowmc-0.3.2}/src/flowMC/nfmodel/__init__.py +0 -0
  26. {flowMC-0.3.0 → flowmc-0.3.2}/src/flowMC/proposal/__init__.py +0 -0
  27. {flowMC-0.3.0 → flowmc-0.3.2}/src/flowMC/strategy/__init__.py +0 -0
  28. {flowMC-0.3.0 → flowmc-0.3.2}/src/flowMC/strategy/importance_sampling.py +0 -0
  29. {flowMC-0.3.0 → flowmc-0.3.2}/src/flowMC/utils/EvolutionaryOptimizer.py +2 -2
  30. {flowMC-0.3.0 → flowmc-0.3.2}/src/flowMC/utils/__init__.py +0 -0
  31. {flowMC-0.3.0 → flowmc-0.3.2}/src/flowMC.egg-info/dependency_links.txt +0 -0
  32. {flowMC-0.3.0 → flowmc-0.3.2}/src/flowMC.egg-info/requires.txt +0 -0
  33. {flowMC-0.3.0 → flowmc-0.3.2}/src/flowMC.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: flowMC
3
- Version: 0.3.0
3
+ Version: 0.3.2
4
4
  Summary: Normalizing flow exhanced sampler in jax
5
5
  Home-page: https://github.com/kazewong/flowMC
6
6
  Author: Kaze Wong, Marylou Gabrié, Dan Foreman-Mackey
@@ -1,6 +1,6 @@
1
1
  [metadata]
2
2
  name = flowMC
3
- version = 0.3.0
3
+ version = 0.3.2
4
4
  author = Kaze Wong, Marylou Gabrié, Dan Foreman-Mackey
5
5
  author_email = kazewong.physics@gmail.com
6
6
  url = https://github.com/kazewong/flowMC
@@ -1,15 +1,15 @@
1
1
  import pickle
2
- import jax
3
- import jax.numpy as jnp
4
- from jaxtyping import Array, Int, Float, PRNGKeyArray
5
- from tqdm import tqdm
2
+
6
3
  import equinox as eqx
4
+ import jax.numpy as jnp
7
5
  import optax
8
- from flowMC.proposal.NF_proposal import NFProposal
9
- from flowMC.proposal.base import ProposalBase
6
+ from jaxtyping import Array, Float, Int, PRNGKeyArray
7
+
10
8
  from flowMC.nfmodel.base import NFModel
9
+ from flowMC.proposal.base import ProposalBase
10
+ from flowMC.proposal.NF_proposal import NFProposal
11
11
  from flowMC.strategy.base import Strategy
12
- from flowMC.strategy.global_tuning import GlobalTuning, GlobalSampling
12
+ from flowMC.strategy.global_tuning import GlobalSampling, GlobalTuning
13
13
 
14
14
 
15
15
  class Sampler:
@@ -131,7 +131,7 @@ class Sampler:
131
131
  )
132
132
  self.optim_state = self.optim.init(eqx.filter(self.nf_model, eqx.is_array))
133
133
 
134
- self.strategies = [
134
+ default_strategies = [
135
135
  GlobalTuning(
136
136
  n_dim=self.n_dim,
137
137
  n_chains=self.n_chains,
@@ -161,7 +161,15 @@ class Sampler:
161
161
  if kwargs.get("strategies") is not None:
162
162
  kwargs_strategies = kwargs.get("strategies")
163
163
  assert isinstance(kwargs_strategies, list)
164
- self.strategies = kwargs_strategies
164
+ self.strategies = []
165
+ for strategy in kwargs_strategies:
166
+ if isinstance(strategy, str):
167
+ if strategy.lower() == "default":
168
+ self.strategies += default_strategies
169
+ else:
170
+ self.strategies.append(strategy)
171
+ else:
172
+ self.strategies = default_strategies
165
173
 
166
174
  self.summary = {}
167
175
 
@@ -1,13 +1,14 @@
1
- from abc import abstractmethod
2
1
  import copy
2
+ from abc import abstractmethod
3
+ from typing import Optional, overload
4
+
3
5
  import equinox as eqx
4
- from typing import overload, Optional
5
- from typing_extensions import Self
6
- from jaxtyping import Array, PRNGKeyArray, Float
7
- import optax
8
- from tqdm import trange, tqdm
9
- import jax.numpy as jnp
10
6
  import jax
7
+ import jax.numpy as jnp
8
+ import optax
9
+ from jaxtyping import Array, Float, PRNGKeyArray
10
+ from tqdm import tqdm, trange
11
+ from typing_extensions import Self
11
12
 
12
13
 
13
14
  class NFModel(eqx.Module):
@@ -102,9 +103,10 @@ class NFModel(eqx.Module):
102
103
  model (eqx.Model): Updated model.
103
104
  opt_state (optax.OptState): Updated optimizer state.
104
105
  """
105
- loss, grads = self.loss_fn(x)
106
+ model = self
107
+ loss, grads = model.loss_fn(x)
106
108
  updates, state = optim.update(grads, state)
107
- model = eqx.apply_updates(self, updates)
109
+ model = eqx.apply_updates(model, updates)
108
110
  return loss, model, state
109
111
 
110
112
  def train_epoch(
@@ -162,16 +164,21 @@ class NFModel(eqx.Module):
162
164
  pbar = trange(num_epochs, desc="Training NF", miniters=int(num_epochs / 10))
163
165
  else:
164
166
  pbar = range(num_epochs)
167
+
165
168
  best_model = model = self
169
+ best_state = state
166
170
  best_loss = 1e9
167
171
  for epoch in pbar:
168
172
  # Use a separate PRNG key to permute image data during shuffling
169
173
  rng, input_rng = jax.random.split(rng)
170
174
  # Run an optimization step over a training batch
171
- value, model, state = model.train_epoch(input_rng, optim, state, data, batch_size)
175
+ value, model, state = model.train_epoch(
176
+ input_rng, optim, state, data, batch_size
177
+ )
172
178
  loss_values = loss_values.at[epoch].set(value)
173
179
  if loss_values[epoch] < best_loss:
174
180
  best_model = model
181
+ best_state = state
175
182
  best_loss = loss_values[epoch]
176
183
  if verbose:
177
184
  assert isinstance(pbar, tqdm)
@@ -182,7 +189,7 @@ class NFModel(eqx.Module):
182
189
  if epoch == num_epochs:
183
190
  pbar.set_description(f"Training NF, current loss: {value:.3f}")
184
191
 
185
- return rng, best_model, state, loss_values
192
+ return rng, best_model, best_state, loss_values
186
193
 
187
194
 
188
195
  class Bijection(eqx.Module):
@@ -1,17 +1,12 @@
1
1
  from typing import Callable, List, Tuple
2
2
 
3
- import jax
4
- import jax.numpy as jnp
5
- from jaxtyping import Array
6
3
  import equinox as eqx
7
-
8
- from flowMC.nfmodel.base import Bijection, Distribution
9
-
10
4
  import jax
11
5
  import jax.numpy as jnp
12
- import equinox as eqx
13
6
  from jaxtyping import Array, Float, PRNGKeyArray
14
7
 
8
+ from flowMC.nfmodel.base import Bijection, Distribution
9
+
15
10
 
16
11
  class MLP(eqx.Module):
17
12
  r"""Multilayer perceptron.
@@ -1,13 +1,15 @@
1
+ from functools import partial
1
2
  from typing import List, Tuple
3
+
4
+ import equinox as eqx
2
5
  import jax
3
6
  import jax.numpy as jnp
4
7
  import numpy as np
5
- import equinox as eqx
6
- from flowMC.nfmodel.base import NFModel, Distribution
7
- from flowMC.nfmodel.common import MLP, MaskedCouplingLayer, MLPAffine, Gaussian
8
- from jaxtyping import Array, Float, PRNGKeyArray
9
- from functools import partial
10
8
  import optax
9
+ from jaxtyping import Array, Float, PRNGKeyArray
10
+
11
+ from flowMC.nfmodel.base import Distribution, NFModel
12
+ from flowMC.nfmodel.common import MLP, Gaussian, MaskedCouplingLayer, MLPAffine
11
13
 
12
14
 
13
15
  class AffineCoupling(eqx.Module):
@@ -1,11 +1,13 @@
1
+ from functools import partial
2
+
3
+ import equinox as eqx
1
4
  import jax
2
5
  import jax.numpy as jnp
3
- from jaxtyping import Array, PRNGKeyArray, Float
4
- import equinox as eqx
6
+ from jaxtyping import Array, Float, PRNGKeyArray
5
7
 
6
- from flowMC.nfmodel.base import NFModel, Bijection, Distribution
7
- from flowMC.nfmodel.common import MaskedCouplingLayer, ScalarAffine, MLP, Gaussian
8
- from functools import partial
8
+ from flowMC.nfmodel.base import Bijection, Distribution, NFModel
9
+ from flowMC.nfmodel.common import (MLP, Gaussian, MaskedCouplingLayer,
10
+ ScalarAffine)
9
11
 
10
12
 
11
13
  @partial(jax.vmap, in_axes=(0, None, None))
@@ -1,9 +1,11 @@
1
1
  from typing import Callable
2
+
2
3
  import jax
3
4
  import jax.numpy as jnp
5
+ from jaxtyping import Array, Float, Int, PRNGKeyArray, PyTree
4
6
  from tqdm import tqdm
7
+
5
8
  from flowMC.proposal.base import ProposalBase
6
- from jaxtyping import PyTree, Array, Float, Int, PRNGKeyArray
7
9
 
8
10
 
9
11
  class GaussianRandomWalk(ProposalBase):
@@ -1,9 +1,11 @@
1
1
  from typing import Callable
2
+
2
3
  import jax
3
4
  import jax.numpy as jnp
5
+ from jaxtyping import Array, Float, Int, PRNGKeyArray, PyTree
4
6
  from tqdm import tqdm
7
+
5
8
  from flowMC.proposal.base import ProposalBase
6
- from jaxtyping import PyTree, Array, Float, Int, PRNGKeyArray
7
9
 
8
10
 
9
11
  class HMC(ProposalBase):
@@ -1,11 +1,13 @@
1
+ from functools import partialmethod
1
2
  from typing import Callable
3
+
2
4
  import jax
3
5
  import jax.numpy as jnp
4
6
  from jax.scipy.stats import multivariate_normal
7
+ from jaxtyping import Array, Bool, Float, Int, PRNGKeyArray, PyTree
5
8
  from tqdm import tqdm
9
+
6
10
  from flowMC.proposal.base import ProposalBase
7
- from functools import partialmethod
8
- from jaxtyping import PyTree, Array, Float, Int, PRNGKeyArray, Bool
9
11
 
10
12
 
11
13
  class MALA(ProposalBase):
@@ -1,12 +1,14 @@
1
+ from math import ceil
2
+ from typing import Callable
3
+
1
4
  import jax
2
5
  import jax.numpy as jnp
3
6
  from jax import random
7
+ from jaxtyping import Array, Float, Int, PRNGKeyArray, PyTree
4
8
  from tqdm import tqdm
9
+
5
10
  from flowMC.nfmodel.base import NFModel
6
- from typing import Callable
7
11
  from flowMC.proposal.base import ProposalBase
8
- from jaxtyping import Array, Float, Int, PRNGKeyArray, PyTree
9
- from math import ceil
10
12
 
11
13
 
12
14
  @jax.tree_util.register_pytree_node_class
@@ -1,8 +1,9 @@
1
1
  from abc import abstractmethod
2
2
  from typing import Callable
3
+
3
4
  import jax
4
5
  import jax.numpy as jnp
5
- from jaxtyping import PyTree, Array, Float, Int, PRNGKeyArray
6
+ from jaxtyping import Array, Float, Int, PRNGKeyArray, PyTree
6
7
 
7
8
 
8
9
  @jax.tree_util.register_pytree_node_class
@@ -1,14 +1,15 @@
1
+ from math import ceil
2
+ from typing import Callable
3
+
1
4
  import jax
2
5
  import jax.numpy as jnp
6
+ from jax import random
7
+ from jaxtyping import Array, Float, Int, PRNGKeyArray, PyTree
8
+ from tqdm import tqdm
9
+
3
10
  from flowMC.nfmodel.base import NFModel
4
- from jaxtyping import Array, PRNGKeyArray, PyTree
5
- from typing import Callable
6
11
  from flowMC.proposal.HMC import HMC
7
12
  from flowMC.proposal.NF_proposal import NFProposal
8
- from jaxtyping import Array, Float, Int, PRNGKeyArray
9
- from math import ceil
10
- from jax import random
11
- from tqdm import tqdm
12
13
 
13
14
  ###################################
14
15
  # This is not in production yet
@@ -1,7 +1,11 @@
1
1
  from abc import abstractmethod
2
+
3
+ from jaxtyping import Array, Float, PRNGKeyArray, PyTree
4
+
2
5
  from flowMC.proposal.base import ProposalBase
3
6
  from flowMC.proposal.NF_proposal import NFProposal
4
- from jaxtyping import Array, Float, PRNGKeyArray, PyTree
7
+
8
+
5
9
  class Strategy:
6
10
  """
7
11
  Base class for strategies, which are basically wrapper blocks that modify the state of the sampler
@@ -1,13 +1,14 @@
1
- from flowMC.proposal.base import ProposalBase
2
- from flowMC.proposal.NF_proposal import NFProposal
3
- from flowMC.strategy.base import Strategy
1
+ import equinox as eqx
4
2
  import jax
5
3
  import jax.numpy as jnp
6
- from jaxtyping import Array, Float, PRNGKeyArray, PyTree
7
4
  import optax
8
- import equinox as eqx
5
+ from jaxtyping import Array, Float, PRNGKeyArray, PyTree
9
6
  from tqdm import tqdm
10
7
 
8
+ from flowMC.proposal.base import ProposalBase
9
+ from flowMC.proposal.NF_proposal import NFProposal
10
+ from flowMC.strategy.base import Strategy
11
+
11
12
 
12
13
  class GlobalTuning(Strategy):
13
14
 
@@ -76,6 +77,7 @@ class GlobalTuning(Strategy):
76
77
  summary["local_accs"] = jnp.empty((self.n_chains, 0))
77
78
  summary["global_accs"] = jnp.empty((self.n_chains, 0))
78
79
  summary["loss_vals"] = jnp.empty((0, self.n_epochs))
80
+ current_position = initial_position
79
81
  for _ in tqdm(
80
82
  range(self.n_loop),
81
83
  desc="Global Tuning",
@@ -90,7 +92,7 @@ class GlobalTuning(Strategy):
90
92
  ) = local_sampler.sample(
91
93
  rng_keys_mcmc,
92
94
  self.n_local_steps,
93
- initial_position,
95
+ current_position,
94
96
  data,
95
97
  verbose=self.verbose,
96
98
  )
@@ -112,6 +114,8 @@ class GlobalTuning(Strategy):
112
114
  axis=1,
113
115
  )
114
116
 
117
+ current_position = summary["chains"][:, -1]
118
+
115
119
  rng_key, rng_keys_nf = jax.random.split(rng_key)
116
120
  positions = summary["chains"][:, :: self.train_thinning]
117
121
  chain_size = positions.shape[0] * positions.shape[1]
@@ -142,8 +146,8 @@ class GlobalTuning(Strategy):
142
146
 
143
147
  (
144
148
  rng_keys_nf,
145
- global_sampler.model,
146
- self.optim_state,
149
+ model,
150
+ optim_state,
147
151
  loss_values,
148
152
  ) = global_sampler.model.train(
149
153
  rng_keys_nf,
@@ -154,6 +158,8 @@ class GlobalTuning(Strategy):
154
158
  self.batch_size,
155
159
  self.verbose,
156
160
  )
161
+ global_sampler.model = model
162
+ self.optim_state = optim_state
157
163
  summary["loss_vals"] = jnp.append(
158
164
  summary["loss_vals"],
159
165
  loss_values.reshape(1, -1),
@@ -168,7 +174,7 @@ class GlobalTuning(Strategy):
168
174
  ) = global_sampler.sample(
169
175
  rng_keys_nf,
170
176
  self.n_global_steps,
171
- positions[:, -1],
177
+ current_position,
172
178
  data,
173
179
  verbose=self.verbose,
174
180
  )
@@ -190,7 +196,9 @@ class GlobalTuning(Strategy):
190
196
  axis=1,
191
197
  )
192
198
 
193
- return rng_key, summary['chains'][:, -1], local_sampler, global_sampler, summary
199
+ current_position = summary["chains"][:, -1]
200
+
201
+ return rng_key, current_position, local_sampler, global_sampler, summary
194
202
 
195
203
 
196
204
  class GlobalSampling(Strategy):
@@ -239,6 +247,7 @@ class GlobalSampling(Strategy):
239
247
  summary["local_accs"] = jnp.empty((self.n_chains, 0))
240
248
  summary["global_accs"] = jnp.empty((self.n_chains, 0))
241
249
 
250
+ current_position = initial_position
242
251
  for _ in tqdm(
243
252
  range(self.n_loop),
244
253
  desc="Global Sampling",
@@ -253,7 +262,7 @@ class GlobalSampling(Strategy):
253
262
  ) = local_sampler.sample(
254
263
  rng_keys_mcmc,
255
264
  self.n_local_steps,
256
- initial_position,
265
+ current_position,
257
266
  data,
258
267
  verbose=self.verbose,
259
268
  )
@@ -275,6 +284,8 @@ class GlobalSampling(Strategy):
275
284
  axis=1,
276
285
  )
277
286
 
287
+ current_position = summary["chains"][:, -1]
288
+
278
289
  rng_key, rng_keys_nf = jax.random.split(rng_key)
279
290
  (
280
291
  rng_keys_nf,
@@ -306,4 +317,6 @@ class GlobalSampling(Strategy):
306
317
  axis=1,
307
318
  )
308
319
 
320
+ current_position = summary["chains"][:, -1]
321
+
309
322
  return rng_key, summary['chains'][:, -1], local_sampler, global_sampler, summary
@@ -0,0 +1,120 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ import optax
4
+ from jaxtyping import Array, Float, PRNGKeyArray, PyTree
5
+ from typing import Callable
6
+
7
+ from flowMC.proposal.base import ProposalBase
8
+ from flowMC.proposal.NF_proposal import NFProposal
9
+ from flowMC.strategy.base import Strategy
10
+
11
+
12
+ class optimization_Adam(Strategy):
13
+
14
+ """
15
+ Optimize a set of chains using Adam optimization.
16
+ Note that if the posterior can go to infinity, this optimization scheme is likely to return NaNs.
17
+
18
+ """
19
+
20
+ n_steps: int = 100
21
+ learning_rate: float = 1e-2
22
+ noise_level: float = 10
23
+
24
+ @property
25
+ def __name__(self):
26
+ return "AdamOptimization"
27
+
28
+ def __init__(
29
+ self,
30
+ **kwargs,
31
+ ):
32
+ class_keys = list(self.__class__.__annotations__.keys())
33
+ for key, value in kwargs.items():
34
+ if key in class_keys:
35
+ if not key.startswith("__"):
36
+ setattr(self, key, value)
37
+
38
+ self.solver = optax.chain(
39
+ optax.adam(learning_rate=self.learning_rate),
40
+ )
41
+
42
+ def __call__(
43
+ self,
44
+ rng_key: PRNGKeyArray,
45
+ local_sampler: ProposalBase,
46
+ global_sampler: NFProposal,
47
+ initial_position: Float[Array, " n_chain n_dim"],
48
+ data: dict,
49
+ ) -> tuple[
50
+ PRNGKeyArray, Float[Array, " n_chain n_dim"], ProposalBase, NFProposal, PyTree
51
+ ]:
52
+ def loss_fn(params: Float[Array, " n_dim"]) -> Float:
53
+ return -local_sampler.logpdf(params, data)
54
+
55
+ grad_fn = jax.jit(jax.grad(loss_fn))
56
+
57
+ def _kernel(carry, data):
58
+ key, params, opt_state = carry
59
+
60
+ key, subkey = jax.random.split(key)
61
+ grad = grad_fn(params) * (1 + jax.random.normal(subkey) * self.noise_level)
62
+ updates, opt_state = self.solver.update(grad, opt_state, params)
63
+ params = optax.apply_updates(params, updates)
64
+ return (key, params, opt_state), None
65
+
66
+ def _single_optimize(
67
+ key: PRNGKeyArray,
68
+ initial_position: Float[Array, " n_dim"],
69
+ ) -> Float[Array, " n_dim"]:
70
+
71
+ opt_state = self.solver.init(initial_position)
72
+
73
+ (key, params, opt_state), _ = jax.lax.scan(
74
+ _kernel,
75
+ (key, initial_position, opt_state),
76
+ jnp.arange(self.n_steps),
77
+ )
78
+
79
+ return params # type: ignore
80
+
81
+ print("Using Adam optimization")
82
+ rng_key, subkey = jax.random.split(rng_key)
83
+ keys = jax.random.split(subkey, initial_position.shape[0])
84
+ optimized_positions = jax.vmap(_single_optimize, in_axes=(0, 0))(
85
+ keys, initial_position
86
+ )
87
+
88
+ summary = {}
89
+ summary["initial_positions"] = initial_position
90
+ summary["initial_log_prob"] = local_sampler.logpdf_vmap(initial_position, data)
91
+ summary["final_positions"] = optimized_positions
92
+ summary["final_log_prob"] = local_sampler.logpdf_vmap(optimized_positions, data)
93
+
94
+ if jnp.isinf(summary['final_log_prob']).any() or jnp.isnan(summary['final_log_prob']).any():
95
+ print("Warning: Optimization accessed infinite or NaN log-probabilities.")
96
+
97
+ return rng_key, optimized_positions, local_sampler, global_sampler, summary
98
+
99
+
100
+ class Evosax_CMA_ES(Strategy):
101
+
102
+ def __init__(
103
+ self,
104
+ **kwargs,
105
+ ):
106
+ class_keys = list(self.__class__.__annotations__.keys())
107
+ for key, value in kwargs.items():
108
+ if key in class_keys:
109
+ if not key.startswith("__"):
110
+ setattr(self, key, value)
111
+
112
+ def __call__(
113
+ self,
114
+ rng_key: PRNGKeyArray,
115
+ local_sampler: ProposalBase,
116
+ global_sampler: NFProposal,
117
+ initial_position: Array,
118
+ data: dict,
119
+ ) -> tuple[PRNGKeyArray, Array, ProposalBase, NFProposal, PyTree]:
120
+ raise NotImplementedError
@@ -1,6 +1,7 @@
1
1
  import warnings
2
2
  from functools import wraps
3
- from typing import Any, List, Dict, Callable, Iterable, Tuple, NamedTuple, Union
3
+ from typing import (Any, Callable, Dict, Iterable, List, NamedTuple, Tuple,
4
+ Union)
4
5
 
5
6
  Array = Any
6
7
  PyTree = Union[Array, Iterable[Array], Dict[Any, Array], NamedTuple]
@@ -1,7 +1,8 @@
1
- import matplotlib.pyplot as plt
2
1
  import jax.numpy as jnp
2
+ import matplotlib.pyplot as plt
3
3
  # from flowMC.sampler.Sampler import Sampler
4
- from jaxtyping import Float, Array
4
+ from jaxtyping import Array, Float
5
+
5
6
 
6
7
  def plot_summary(sampler: object, training: bool = False, **plotkwargs) -> None:
7
8
  """
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: flowMC
3
- Version: 0.3.0
3
+ Version: 0.3.2
4
4
  Summary: Normalizing flow exhanced sampler in jax
5
5
  Home-page: https://github.com/kazewong/flowMC
6
6
  Author: Kaze Wong, Marylou Gabrié, Dan Foreman-Mackey
@@ -25,6 +25,7 @@ src/flowMC/strategy/__init__.py
25
25
  src/flowMC/strategy/base.py
26
26
  src/flowMC/strategy/global_tuning.py
27
27
  src/flowMC/strategy/importance_sampling.py
28
+ src/flowMC/strategy/optimization.py
28
29
  src/flowMC/utils/EvolutionaryOptimizer.py
29
30
  src/flowMC/utils/PythonFunctionWrap.py
30
31
  src/flowMC/utils/__init__.py
File without changes
File without changes
File without changes
File without changes
@@ -1,8 +1,8 @@
1
- from evosax import CMA_ES
2
1
  import jax
3
2
  import jax.numpy as jnp
4
- from jaxtyping import PRNGKeyArray
5
3
  import tqdm
4
+ from evosax import CMA_ES
5
+ from jaxtyping import PRNGKeyArray
6
6
 
7
7
 
8
8
  class EvolutionaryOptimizer: