flowMC 0.3.1__tar.gz → 0.3.2__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (33)
  1. {flowMC-0.3.1/src/flowMC.egg-info → flowmc-0.3.2}/PKG-INFO +1 -1
  2. {flowMC-0.3.1 → flowmc-0.3.2}/setup.cfg +1 -1
  3. {flowMC-0.3.1 → flowmc-0.3.2}/src/flowMC/Sampler.py +10 -2
  4. {flowMC-0.3.1 → flowmc-0.3.2}/src/flowMC/nfmodel/base.py +1 -0
  5. flowmc-0.3.2/src/flowMC/strategy/optimization.py +120 -0
  6. {flowMC-0.3.1 → flowmc-0.3.2/src/flowMC.egg-info}/PKG-INFO +1 -1
  7. {flowMC-0.3.1 → flowmc-0.3.2}/src/flowMC.egg-info/SOURCES.txt +1 -0
  8. {flowMC-0.3.1 → flowmc-0.3.2}/LICENSE +0 -0
  9. {flowMC-0.3.1 → flowmc-0.3.2}/README.md +0 -0
  10. {flowMC-0.3.1 → flowmc-0.3.2}/pyproject.toml +0 -0
  11. {flowMC-0.3.1 → flowmc-0.3.2}/src/flowMC/__init__.py +0 -0
  12. {flowMC-0.3.1 → flowmc-0.3.2}/src/flowMC/nfmodel/__init__.py +0 -0
  13. {flowMC-0.3.1 → flowmc-0.3.2}/src/flowMC/nfmodel/common.py +0 -0
  14. {flowMC-0.3.1 → flowmc-0.3.2}/src/flowMC/nfmodel/realNVP.py +0 -0
  15. {flowMC-0.3.1 → flowmc-0.3.2}/src/flowMC/nfmodel/rqSpline.py +0 -0
  16. {flowMC-0.3.1 → flowmc-0.3.2}/src/flowMC/proposal/Gaussian_random_walk.py +0 -0
  17. {flowMC-0.3.1 → flowmc-0.3.2}/src/flowMC/proposal/HMC.py +0 -0
  18. {flowMC-0.3.1 → flowmc-0.3.2}/src/flowMC/proposal/MALA.py +0 -0
  19. {flowMC-0.3.1 → flowmc-0.3.2}/src/flowMC/proposal/NF_proposal.py +0 -0
  20. {flowMC-0.3.1 → flowmc-0.3.2}/src/flowMC/proposal/__init__.py +0 -0
  21. {flowMC-0.3.1 → flowmc-0.3.2}/src/flowMC/proposal/base.py +0 -0
  22. {flowMC-0.3.1 → flowmc-0.3.2}/src/flowMC/proposal/flowHMC.py +0 -0
  23. {flowMC-0.3.1 → flowmc-0.3.2}/src/flowMC/strategy/__init__.py +0 -0
  24. {flowMC-0.3.1 → flowmc-0.3.2}/src/flowMC/strategy/base.py +0 -0
  25. {flowMC-0.3.1 → flowmc-0.3.2}/src/flowMC/strategy/global_tuning.py +0 -0
  26. {flowMC-0.3.1 → flowmc-0.3.2}/src/flowMC/strategy/importance_sampling.py +0 -0
  27. {flowMC-0.3.1 → flowmc-0.3.2}/src/flowMC/utils/EvolutionaryOptimizer.py +0 -0
  28. {flowMC-0.3.1 → flowmc-0.3.2}/src/flowMC/utils/PythonFunctionWrap.py +0 -0
  29. {flowMC-0.3.1 → flowmc-0.3.2}/src/flowMC/utils/__init__.py +0 -0
  30. {flowMC-0.3.1 → flowmc-0.3.2}/src/flowMC/utils/postprocessing.py +0 -0
  31. {flowMC-0.3.1 → flowmc-0.3.2}/src/flowMC.egg-info/dependency_links.txt +0 -0
  32. {flowMC-0.3.1 → flowmc-0.3.2}/src/flowMC.egg-info/requires.txt +0 -0
  33. {flowMC-0.3.1 → flowmc-0.3.2}/src/flowMC.egg-info/top_level.txt +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: flowMC
3
- Version: 0.3.1
3
+ Version: 0.3.2
4
4
  Summary: Normalizing flow exhanced sampler in jax
5
5
  Home-page: https://github.com/kazewong/flowMC
6
6
  Author: Kaze Wong, Marylou Gabrié, Dan Foreman-Mackey
@@ -1,6 +1,6 @@
1
1
  [metadata]
2
2
  name = flowMC
3
- version = 0.3.1
3
+ version = 0.3.2
4
4
  author = Kaze Wong, Marylou Gabrié, Dan Foreman-Mackey
5
5
  author_email = kazewong.physics@gmail.com
6
6
  url = https://github.com/kazewong/flowMC
@@ -131,7 +131,7 @@ class Sampler:
131
131
  )
132
132
  self.optim_state = self.optim.init(eqx.filter(self.nf_model, eqx.is_array))
133
133
 
134
- self.strategies = [
134
+ default_strategies = [
135
135
  GlobalTuning(
136
136
  n_dim=self.n_dim,
137
137
  n_chains=self.n_chains,
@@ -161,7 +161,15 @@ class Sampler:
161
161
  if kwargs.get("strategies") is not None:
162
162
  kwargs_strategies = kwargs.get("strategies")
163
163
  assert isinstance(kwargs_strategies, list)
164
- self.strategies = kwargs_strategies
164
+ self.strategies = []
165
+ for strategy in kwargs_strategies:
166
+ if isinstance(strategy, str):
167
+ if strategy.lower() == "default":
168
+ self.strategies += default_strategies
169
+ else:
170
+ self.strategies.append(strategy)
171
+ else:
172
+ self.strategies = default_strategies
165
173
 
166
174
  self.summary = {}
167
175
 
@@ -166,6 +166,7 @@ class NFModel(eqx.Module):
166
166
  pbar = range(num_epochs)
167
167
 
168
168
  best_model = model = self
169
+ best_state = state
169
170
  best_loss = 1e9
170
171
  for epoch in pbar:
171
172
  # Use a separate PRNG key to permute image data during shuffling
@@ -0,0 +1,120 @@
1
+ import jax
2
+ import jax.numpy as jnp
3
+ import optax
4
+ from jaxtyping import Array, Float, PRNGKeyArray, PyTree
5
+ from typing import Callable
6
+
7
+ from flowMC.proposal.base import ProposalBase
8
+ from flowMC.proposal.NF_proposal import NFProposal
9
+ from flowMC.strategy.base import Strategy
10
+
11
+
12
+ class optimization_Adam(Strategy):
13
+
14
+ """
15
+ Optimize a set of chains using Adam optimization.
16
+ Note that if the posterior can go to infinity, this optimization scheme is likely to return NaNs.
17
+
18
+ """
19
+
20
+ n_steps: int = 100
21
+ learning_rate: float = 1e-2
22
+ noise_level: float = 10
23
+
24
+ @property
25
+ def __name__(self):
26
+ return "AdamOptimization"
27
+
28
+ def __init__(
29
+ self,
30
+ **kwargs,
31
+ ):
32
+ class_keys = list(self.__class__.__annotations__.keys())
33
+ for key, value in kwargs.items():
34
+ if key in class_keys:
35
+ if not key.startswith("__"):
36
+ setattr(self, key, value)
37
+
38
+ self.solver = optax.chain(
39
+ optax.adam(learning_rate=self.learning_rate),
40
+ )
41
+
42
+ def __call__(
43
+ self,
44
+ rng_key: PRNGKeyArray,
45
+ local_sampler: ProposalBase,
46
+ global_sampler: NFProposal,
47
+ initial_position: Float[Array, " n_chain n_dim"],
48
+ data: dict,
49
+ ) -> tuple[
50
+ PRNGKeyArray, Float[Array, " n_chain n_dim"], ProposalBase, NFProposal, PyTree
51
+ ]:
52
+ def loss_fn(params: Float[Array, " n_dim"]) -> Float:
53
+ return -local_sampler.logpdf(params, data)
54
+
55
+ grad_fn = jax.jit(jax.grad(loss_fn))
56
+
57
+ def _kernel(carry, data):
58
+ key, params, opt_state = carry
59
+
60
+ key, subkey = jax.random.split(key)
61
+ grad = grad_fn(params) * (1 + jax.random.normal(subkey) * self.noise_level)
62
+ updates, opt_state = self.solver.update(grad, opt_state, params)
63
+ params = optax.apply_updates(params, updates)
64
+ return (key, params, opt_state), None
65
+
66
+ def _single_optimize(
67
+ key: PRNGKeyArray,
68
+ initial_position: Float[Array, " n_dim"],
69
+ ) -> Float[Array, " n_dim"]:
70
+
71
+ opt_state = self.solver.init(initial_position)
72
+
73
+ (key, params, opt_state), _ = jax.lax.scan(
74
+ _kernel,
75
+ (key, initial_position, opt_state),
76
+ jnp.arange(self.n_steps),
77
+ )
78
+
79
+ return params # type: ignore
80
+
81
+ print("Using Adam optimization")
82
+ rng_key, subkey = jax.random.split(rng_key)
83
+ keys = jax.random.split(subkey, initial_position.shape[0])
84
+ optimized_positions = jax.vmap(_single_optimize, in_axes=(0, 0))(
85
+ keys, initial_position
86
+ )
87
+
88
+ summary = {}
89
+ summary["initial_positions"] = initial_position
90
+ summary["initial_log_prob"] = local_sampler.logpdf_vmap(initial_position, data)
91
+ summary["final_positions"] = optimized_positions
92
+ summary["final_log_prob"] = local_sampler.logpdf_vmap(optimized_positions, data)
93
+
94
+ if jnp.isinf(summary['final_log_prob']).any() or jnp.isnan(summary['final_log_prob']).any():
95
+ print("Warning: Optimization accessed infinite or NaN log-probabilities.")
96
+
97
+ return rng_key, optimized_positions, local_sampler, global_sampler, summary
98
+
99
+
100
+ class Evosax_CMA_ES(Strategy):
101
+
102
+ def __init__(
103
+ self,
104
+ **kwargs,
105
+ ):
106
+ class_keys = list(self.__class__.__annotations__.keys())
107
+ for key, value in kwargs.items():
108
+ if key in class_keys:
109
+ if not key.startswith("__"):
110
+ setattr(self, key, value)
111
+
112
+ def __call__(
113
+ self,
114
+ rng_key: PRNGKeyArray,
115
+ local_sampler: ProposalBase,
116
+ global_sampler: NFProposal,
117
+ initial_position: Array,
118
+ data: dict,
119
+ ) -> tuple[PRNGKeyArray, Array, ProposalBase, NFProposal, PyTree]:
120
+ raise NotImplementedError
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: flowMC
3
- Version: 0.3.1
3
+ Version: 0.3.2
4
4
  Summary: Normalizing flow exhanced sampler in jax
5
5
  Home-page: https://github.com/kazewong/flowMC
6
6
  Author: Kaze Wong, Marylou Gabrié, Dan Foreman-Mackey
@@ -25,6 +25,7 @@ src/flowMC/strategy/__init__.py
25
25
  src/flowMC/strategy/base.py
26
26
  src/flowMC/strategy/global_tuning.py
27
27
  src/flowMC/strategy/importance_sampling.py
28
+ src/flowMC/strategy/optimization.py
28
29
  src/flowMC/utils/EvolutionaryOptimizer.py
29
30
  src/flowMC/utils/PythonFunctionWrap.py
30
31
  src/flowMC/utils/__init__.py
File without changes
File without changes
File without changes
File without changes