aspire-inference 0.1.0a7__py3-none-any.whl

aspire/history.py ADDED
@@ -0,0 +1,148 @@
+ from __future__ import annotations
+
+ import copy
+ from dataclasses import dataclass, field
+
+ import matplotlib.pyplot as plt
+ from matplotlib.figure import Figure
+
+ from .utils import recursively_save_to_h5_file
+
+
+ @dataclass
+ class History:
+     """Base class for storing the history of a sampler."""
+
+     def save(self, h5_file, path="history"):
+         """Save the history to an HDF5 file."""
+         dictionary = copy.deepcopy(self.__dict__)
+         recursively_save_to_h5_file(h5_file, path, dictionary)
+
+
+ @dataclass
+ class FlowHistory(History):
+     """History of the training and validation losses of a flow."""
+
+     training_loss: list[float] = field(default_factory=list)
+     validation_loss: list[float] = field(default_factory=list)
+
+     def plot_loss(self) -> Figure:
+         """Plot the training and validation loss."""
+         fig = plt.figure()
+         plt.plot(self.training_loss, label="Training loss")
+         plt.plot(self.validation_loss, label="Validation loss")
+         plt.legend()
+         plt.xlabel("Epoch")
+         plt.ylabel("Loss")
+         return fig
+
+     def save(self, h5_file, path="flow_history"):
+         """Save the history to an HDF5 file."""
+         super().save(h5_file, path=path)
+
+
+ @dataclass
+ class SMCHistory(History):
+     """History of the quantities tracked at each SMC iteration."""
+
+     log_norm_ratio: list[float] = field(default_factory=list)
+     log_norm_ratio_var: list[float] = field(default_factory=list)
+     beta: list[float] = field(default_factory=list)
+     ess: list[float] = field(default_factory=list)
+     ess_target: list[float] = field(default_factory=list)
+     eff_target: list[float] = field(default_factory=list)
+     mcmc_autocorr: list[float] = field(default_factory=list)
+     mcmc_acceptance: list[float] = field(default_factory=list)
+
+     def save(self, h5_file, path="smc_history"):
+         """Save the history to an HDF5 file."""
+         super().save(h5_file, path=path)
+
+     def plot_beta(self, ax=None) -> Figure | None:
+         """Plot the inverse temperature beta per iteration."""
+         if ax is None:
+             fig, ax = plt.subplots()
+         else:
+             fig = None
+         ax.plot(self.beta)
+         ax.set_xlabel("Iteration")
+         ax.set_ylabel(r"$\beta$")
+         return fig
+
+     def plot_log_norm_ratio(self, ax=None) -> Figure | None:
+         """Plot the log evidence ratio per iteration."""
+         if ax is None:
+             fig, ax = plt.subplots()
+         else:
+             fig = None
+         ax.plot(self.log_norm_ratio)
+         ax.set_xlabel("Iteration")
+         ax.set_ylabel("Log evidence ratio")
+         return fig
+
+     def plot_ess(self, ax=None) -> Figure | None:
+         """Plot the effective sample size (ESS) per iteration."""
+         if ax is None:
+             fig, ax = plt.subplots()
+         else:
+             fig = None
+         ax.plot(self.ess)
+         ax.set_xlabel("Iteration")
+         ax.set_ylabel("ESS")
+         return fig
+
+     def plot_ess_target(self, ax=None) -> Figure | None:
+         """Plot the target ESS per iteration."""
+         if ax is None:
+             fig, ax = plt.subplots()
+         else:
+             fig = None
+         ax.plot(self.ess_target)
+         ax.set_xlabel("Iteration")
+         ax.set_ylabel("ESS target")
+         return fig
+
+     def plot_eff_target(self, ax=None) -> Figure | None:
+         """Plot the target efficiency per iteration."""
+         if ax is None:
+             fig, ax = plt.subplots()
+         else:
+             fig = None
+         ax.plot(self.eff_target)
+         ax.set_xlabel("Iteration")
+         ax.set_ylabel("Efficiency target")
+         return fig
+
+     def plot_mcmc_acceptance(self, ax=None) -> Figure | None:
+         """Plot the MCMC acceptance rate per iteration."""
+         if ax is None:
+             fig, ax = plt.subplots()
+         else:
+             fig = None
+         ax.plot(self.mcmc_acceptance)
+         ax.set_xlabel("Iteration")
+         ax.set_ylabel("MCMC Acceptance")
+         return fig
+
+     def plot_mcmc_autocorr(self, ax=None) -> Figure | None:
+         """Plot the MCMC autocorrelation per iteration."""
+         if ax is None:
+             fig, ax = plt.subplots()
+         else:
+             fig = None
+         ax.plot(self.mcmc_autocorr)
+         ax.set_xlabel("Iteration")
+         ax.set_ylabel("MCMC Autocorr")
+         return fig
+
+     def plot(self, fig: Figure | None = None) -> Figure:
+         """Plot all tracked quantities on stacked axes with a shared x-axis."""
+         methods = [
+             self.plot_beta,
+             self.plot_log_norm_ratio,
+             self.plot_ess,
+             self.plot_ess_target,
+             self.plot_eff_target,
+             self.plot_mcmc_acceptance,
+         ]
+
+         if fig is None:
+             fig, axs = plt.subplots(len(methods), 1, sharex=True)
+         else:
+             axs = fig.axes
+
+         for method, ax in zip(methods, axs):
+             method(ax)
+
+         # Only the bottom panel keeps its x-axis label
+         for ax in axs[:-1]:
+             ax.set_xlabel("")
+
+         return fig
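The histories are plain dataclasses, so a run can append to them incrementally and save the result once it finishes. A minimal usage sketch, assuming h5py and that recursively_save_to_h5_file accepts an open h5py.File:

    import h5py

    history = SMCHistory()
    # An SMC run would append one value per iteration; filled by hand here.
    history.beta.extend([0.1, 0.5, 1.0])
    history.ess.extend([950.0, 800.0, 640.0])

    fig = history.plot_beta()
    with h5py.File("result.h5", "w") as f:
        history.save(f)  # written under the "smc_history" group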
aspire/plot.py ADDED
@@ -0,0 +1,50 @@
+ import copy
+
+
+ def plot_comparison(
+     *samples, parameters=None, per_samples_kwargs=None, labels=None, **kwargs
+ ):
+     """Plot a corner-plot comparison of multiple sets of samples.
+
+     Each set of samples is drawn on the same figure with its own color;
+     ``per_samples_kwargs`` overrides the keyword arguments for individual
+     sets and ``labels`` adds a figure legend.
+     """
+     default_kwargs = dict(
+         density=True,
+         bins=30,
+         color="C0",
+         smooth=1.0,
+         plot_datapoints=True,
+         plot_density=False,
+         hist_kwargs=dict(density=True, color="C0"),
+     )
+     default_kwargs.update(kwargs)
+
+     if per_samples_kwargs is None:
+         per_samples_kwargs = [{}] * len(samples)
+
+     fig = None
+     for i, sample in enumerate(samples):
+         kwds = copy.deepcopy(default_kwargs)
+         # Copy so the pop below does not mutate the caller's dictionaries
+         sample_kwargs = dict(per_samples_kwargs[i])
+         color = sample_kwargs.pop("color", f"C{i}")
+         kwds["color"] = color
+         kwds["hist_kwargs"]["color"] = color
+         kwds.update(sample_kwargs)
+         fig = sample.plot_corner(fig=fig, parameters=parameters, **kwds)
+
+     if labels:
+         fig.legend(
+             labels=labels,
+             loc="upper right",
+             bbox_to_anchor=(0.9, 0.9),
+             bbox_transform=fig.transFigure,
+         )
+     return fig
+
+
+ def plot_history_comparison(*histories):
+     """Overlay the summary plots of multiple histories of the same type."""
+     # Check that all histories are of the same type
+     if not all(isinstance(h, histories[0].__class__) for h in histories):
+         raise ValueError("All histories must be of the same type")
+     fig = histories[0].plot()
+     for history in histories[1:]:
+         fig = history.plot(fig=fig)
+     return fig
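A usage sketch for plot_comparison, assuming two Samples objects that expose the plot_corner method used above (the variable names are placeholders):

    fig = plot_comparison(
        samples_is,
        samples_mcmc,
        parameters=["x0", "x1"],
        labels=["Importance", "MCMC"],
        per_samples_kwargs=[{"color": "C0"}, {"color": "C3"}],
    )
    fig.savefig("comparison.png")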
@@ -0,0 +1,94 @@
+ import logging
+ from typing import Any, Callable
+
+ from ..flows.base import Flow
+ from ..samples import Samples
+ from ..transforms import IdentityTransform
+ from ..utils import track_calls
+
+ logger = logging.getLogger(__name__)
+
+
+ class Sampler:
+     """Base class for all samplers.
+
+     Parameters
+     ----------
+     log_likelihood : Callable
+         The log-likelihood function.
+     log_prior : Callable
+         The log-prior function.
+     dims : int
+         The number of dimensions.
+     prior_flow : Flow
+         The flow used to propose samples from the prior.
+     xp : Any
+         The array namespace (e.g. numpy) to use as the backend.
+     dtype : Any | str | None
+         The dtype to use for samples.
+     parameters : list[str] | None
+         The list of parameter names. If None, any samples objects will not
+         have the parameter names specified.
+     preconditioning_transform : Callable | None
+         Transform applied to samples before sampling; defaults to the
+         identity transform.
+     """
+
+     def __init__(
+         self,
+         log_likelihood: Callable,
+         log_prior: Callable,
+         dims: int,
+         prior_flow: Flow,
+         xp: Any,
+         dtype: Any | str | None = None,
+         parameters: list[str] | None = None,
+         preconditioning_transform: Callable | None = None,
+     ):
+         self.prior_flow = prior_flow
+         self._log_likelihood = log_likelihood
+         self.log_prior = log_prior
+         self.dims = dims
+         self.xp = xp
+         self.dtype = dtype
+         self.parameters = parameters
+         self.history = None
+         self.n_likelihood_evaluations = 0
+         if preconditioning_transform is None:
+             self.preconditioning_transform = IdentityTransform(xp=self.xp)
+         else:
+             self.preconditioning_transform = preconditioning_transform
+
+     def fit_preconditioning_transform(self, x):
+         """Fit the preconditioning transform to the data."""
+         x = self.preconditioning_transform.xp.asarray(x)
+         return self.preconditioning_transform.fit(x)
+
+     @track_calls
+     def sample(self, n_samples: int) -> Samples:
+         raise NotImplementedError
+
+     def log_likelihood(self, samples: Samples):
+         """Compute the log-likelihood of the samples.
+
+         Also tracks the number of likelihood evaluations.
+         """
+         self.n_likelihood_evaluations += len(samples)
+         return self._log_likelihood(samples)
+
+     def config_dict(self, include_sample_calls: bool = True) -> dict:
+         """Return a dictionary with the configuration of the sampler.
+
+         Parameters
+         ----------
+         include_sample_calls : bool
+             Whether to include the sample calls in the configuration.
+             Default is True.
+         """
+         config = {}
+         if include_sample_calls:
+             if hasattr(self, "sample") and hasattr(self.sample, "calls"):
+                 config["sample_calls"] = self.sample.calls.to_dict(
+                     list_to_dict=True
+                 )
+             else:
+                 logger.warning(
+                     "Sampler does not have a sample method with calls attribute."
+                 )
+         return config
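Because sample is wrapped with track_calls, the arguments of every call are recorded and surfaced by config_dict, and log_likelihood counts evaluations as a side effect. A hedged sketch using the ImportanceSampler from the next hunk (the trained flow and the two log-density callables are assumed, not defined in this diff):

    import numpy as np

    sampler = ImportanceSampler(
        log_likelihood=log_likelihood,
        log_prior=log_prior,
        dims=2,
        prior_flow=flow,  # a trained Flow instance (assumed)
        xp=np,
    )
    samples = sampler.sample(1000)
    print(sampler.n_likelihood_evaluations)  # 1000, counted by log_likelihood()
    config = sampler.config_dict()  # includes the recorded sample() calls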
@@ -0,0 +1,22 @@
+ from ..samples import Samples
+ from ..utils import track_calls
+ from .base import Sampler
+
+
+ class ImportanceSampler(Sampler):
+     @track_calls
+     def sample(self, n_samples: int) -> Samples:
+         """Draw samples from the prior flow and compute importance weights."""
+         x, log_q = self.prior_flow.sample_and_log_prob(n_samples)
+         samples = Samples(
+             x,
+             log_q=log_q,
+             xp=self.xp,
+             parameters=self.parameters,
+             dtype=self.dtype,
+         )
+         samples.log_prior = samples.array_to_namespace(self.log_prior(samples))
+         samples.log_likelihood = samples.array_to_namespace(
+             self.log_likelihood(samples)
+         )
+         samples.compute_weights()
+         return samples
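compute_weights is not shown in this diff; in standard importance sampling the quantities set above combine as follows (a sketch of the usual identities, not necessarily the package's implementation):

    import numpy as np
    from scipy.special import logsumexp

    # log w_i = log L(x_i) + log pi(x_i) - log q(x_i)
    log_w = log_likelihood + log_prior - log_q
    # Evidence estimate: log Z ~ logsumexp(log_w) - log N
    log_evidence = logsumexp(log_w) - np.log(len(log_w))
    # Effective sample size of the weighted samples
    ess = np.exp(2 * logsumexp(log_w) - logsumexp(2 * log_w))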
@@ -0,0 +1,160 @@
+ import numpy as np
+
+ from ..samples import Samples, to_numpy
+ from ..utils import track_calls
+ from .base import Sampler
+
+
+ class MCMCSampler(Sampler):
+     def draw_initial_samples(self, n_samples: int) -> Samples:
+         """Draw initial samples from the prior flow."""
+         # The flow may propose samples outside the prior bounds, so we may
+         # need to try multiple times to get enough valid samples.
+         n_samples_drawn = 0
+         samples = None
+         while n_samples_drawn < n_samples:
+             n_to_draw = n_samples - n_samples_drawn
+             x, log_q = self.prior_flow.sample_and_log_prob(n_to_draw)
+             new_samples = Samples(x, xp=self.xp, log_q=log_q, dtype=self.dtype)
+             new_samples.log_prior = new_samples.array_to_namespace(
+                 self.log_prior(new_samples)
+             )
+             # Samples with finite log-prior lie within the prior bounds
+             valid = self.xp.isfinite(new_samples.log_prior)
+             n_valid = self.xp.sum(valid)
+             if n_valid > 0:
+                 if samples is None:
+                     samples = new_samples[valid]
+                 else:
+                     samples = Samples.concatenate(
+                         [samples, new_samples[valid]]
+                     )
+                 n_samples_drawn += n_valid
+
+         if n_samples_drawn > n_samples:
+             samples = samples[:n_samples]
+
+         samples.log_likelihood = samples.array_to_namespace(
+             self.log_likelihood(samples)
+         )
+         return samples
+
+     def log_prob(self, z):
+         """Compute the log-probability of the samples.
+
+         Input samples are in the transformed (preconditioned) space.
+         """
+         x, log_abs_det_jacobian = self.preconditioning_transform.inverse(z)
+         samples = Samples(x, xp=self.xp, dtype=self.dtype)
+         samples.log_prior = self.log_prior(samples)
+         samples.log_likelihood = self.log_likelihood(samples)
+         log_prob = (
+             samples.log_likelihood
+             + samples.log_prior
+             + samples.array_to_namespace(log_abs_det_jacobian)
+         )
+         return to_numpy(log_prob).flatten()
+
+
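The Jacobian term in log_prob is the standard change-of-variables correction: for a preconditioning transform z = T(x), the density targeted in z-space is log p(z) = log L(x) + log pi(x) + log|det dT^{-1}(z)/dz|. A toy sketch with a hypothetical affine transform (illustration only, not one of the package's transforms):

    import numpy as np

    mu, sigma = 1.0, 2.0  # hypothetical preconditioning z = (x - mu) / sigma

    def inverse(z):
        x = mu + sigma * z
        # log|det dx/dz| = dims * log(sigma) for this diagonal transform
        log_abs_det_jacobian = np.full(z.shape[0], z.shape[1] * np.log(sigma))
        return x, log_abs_det_jacobian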
+ class Emcee(MCMCSampler):
+     @track_calls
+     def sample(
+         self,
+         n_samples: int,
+         nwalkers: int | None = None,
+         nsteps: int = 500,
+         rng=None,
+         discard=0,
+         **kwargs,
+     ) -> Samples:
+         """Sample using the emcee ensemble sampler."""
+         from emcee import EnsembleSampler
+
+         nwalkers = nwalkers or n_samples
+         self.sampler = EnsembleSampler(
+             nwalkers,
+             self.dims,
+             log_prob_fn=self.log_prob,
+             vectorize=True,
+         )
+
+         rng = rng or np.random.default_rng()
+
+         samples = self.draw_initial_samples(nwalkers)
+         p0 = samples.x
+
+         # Fit the preconditioning transform on the initial positions and
+         # map them into the transformed space
+         z0 = to_numpy(self.preconditioning_transform.fit(p0))
+
+         self.sampler.run_mcmc(z0, nsteps, **kwargs)
+
+         z = self.sampler.get_chain(flat=True, discard=discard)
+         x = self.preconditioning_transform.inverse(z)[0]
+
+         # Separate set of flow samples used to estimate the evidence via
+         # importance sampling
+         x_evidence, log_q = self.prior_flow.sample_and_log_prob(n_samples)
+         samples_evidence = Samples(x_evidence, log_q=log_q, xp=self.xp)
+         samples_evidence.log_prior = self.log_prior(samples_evidence)
+         samples_evidence.log_likelihood = self.log_likelihood(samples_evidence)
+         samples_evidence.compute_weights()
+
+         samples_mcmc = Samples(
+             x, xp=self.xp, parameters=self.parameters, dtype=self.dtype
+         )
+         samples_mcmc.log_prior = samples_mcmc.array_to_namespace(
+             self.log_prior(samples_mcmc)
+         )
+         samples_mcmc.log_likelihood = samples_mcmc.array_to_namespace(
+             self.log_likelihood(samples_mcmc)
+         )
+         samples_mcmc.log_evidence = samples_mcmc.array_to_namespace(
+             samples_evidence.log_evidence
+         )
+         samples_mcmc.log_evidence_error = samples_mcmc.array_to_namespace(
+             samples_evidence.log_evidence_error
+         )
+
+         return samples_mcmc
+
+
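A hedged usage sketch for Emcee; the trained flow and the log-density callables operating on Samples objects are placeholders, not part of this diff:

    import numpy as np

    def log_prior(samples):
        # Flat prior on [-5, 5]^2 (illustrative)
        inside = np.all(np.abs(samples.x) < 5, axis=-1)
        return np.where(inside, 0.0, -np.inf)

    def log_likelihood(samples):
        return -0.5 * np.sum(samples.x**2, axis=-1)

    sampler = Emcee(
        log_likelihood=log_likelihood,
        log_prior=log_prior,
        dims=2,
        prior_flow=flow,  # a trained Flow instance (assumed)
        xp=np,
    )
    samples = sampler.sample(100, nwalkers=100, nsteps=500, discard=250)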
+ class MiniPCN(MCMCSampler):
+     @track_calls
+     def sample(
+         self,
+         n_samples,
+         rng=None,
+         target_acceptance_rate=0.234,
+         n_steps=100,
+         thin=1,
+         burnin=0,
+         last_step_only=False,
+         step_fn="tpcn",
+     ):
+         """Sample using the minipcn MCMC sampler."""
+         from minipcn import Sampler
+
+         rng = rng or np.random.default_rng()
+         p0 = self.draw_initial_samples(n_samples).x
+
+         z0 = to_numpy(self.preconditioning_transform.fit(p0))
+
+         self.sampler = Sampler(
+             log_prob_fn=self.log_prob,
+             step_fn=step_fn,
+             rng=rng,
+             dims=self.dims,
+             target_acceptance_rate=target_acceptance_rate,
+         )
+
+         chain, history = self.sampler.sample(z0, n_steps=n_steps)
+
+         if last_step_only:
+             z = chain[-1]
+         else:
+             # Discard burn-in, thin the chain and flatten the walkers
+             z = chain[burnin::thin].reshape(-1, self.dims)
+
+         x = self.preconditioning_transform.inverse(z)[0]
+
+         samples_mcmc = Samples(x, xp=self.xp, parameters=self.parameters)
+         samples_mcmc.log_prior = samples_mcmc.array_to_namespace(
+             self.log_prior(samples_mcmc)
+         )
+         samples_mcmc.log_likelihood = samples_mcmc.array_to_namespace(
+             self.log_likelihood(samples_mcmc)
+         )
+         return samples_mcmc
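For context, step_fn="tpcn" presumably selects minipcn's t-preconditioned Crank-Nicolson proposal; the plain pCN step it generalises is sketched below for a standard-normal reference measure (a sketch, not the minipcn implementation):

    import numpy as np

    def pcn_step(z, log_prob_fn, beta, rng):
        """One preconditioned Crank-Nicolson step, vectorized over walkers."""
        xi = rng.standard_normal(z.shape)
        z_new = np.sqrt(1.0 - beta**2) * z + beta * xi
        # The Gaussian reference N(0, I) cancels in the acceptance ratio,
        # leaving the non-Gaussian part phi(z) = log_prob(z) + 0.5 * ||z||^2
        def phi(u):
            return log_prob_fn(u) + 0.5 * np.sum(u**2, axis=-1)
        log_alpha = phi(z_new) - phi(z)
        accept = np.log(rng.uniform(size=len(z))) < log_alpha
        return np.where(accept[:, None], z_new, z)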