gpjax 0.10.2__tar.gz → 0.11.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {gpjax-0.10.2 → gpjax-0.11.0}/PKG-INFO +2 -61
- {gpjax-0.10.2 → gpjax-0.11.0}/README.md +0 -59
- {gpjax-0.10.2 → gpjax-0.11.0}/docs/scripts/sharp_bits_figure.py +3 -3
- {gpjax-0.10.2 → gpjax-0.11.0}/docs/sharp_bits.md +1 -1
- {gpjax-0.10.2 → gpjax-0.11.0}/examples/barycentres.py +10 -10
- {gpjax-0.10.2 → gpjax-0.11.0}/examples/classification.py +11 -15
- {gpjax-0.10.2 → gpjax-0.11.0}/examples/collapsed_vi.py +4 -4
- {gpjax-0.10.2 → gpjax-0.11.0}/examples/constructing_new_kernels.py +26 -5
- {gpjax-0.10.2 → gpjax-0.11.0}/examples/deep_kernels.py +3 -3
- {gpjax-0.10.2 → gpjax-0.11.0}/examples/graph_kernels.py +4 -4
- {gpjax-0.10.2 → gpjax-0.11.0}/examples/intro_to_gps.py +27 -20
- {gpjax-0.10.2 → gpjax-0.11.0}/examples/intro_to_kernels.py +12 -17
- {gpjax-0.10.2 → gpjax-0.11.0}/examples/likelihoods_guide.py +6 -9
- {gpjax-0.10.2 → gpjax-0.11.0}/examples/oceanmodelling.py +186 -98
- {gpjax-0.10.2 → gpjax-0.11.0}/examples/poisson.py +2 -4
- {gpjax-0.10.2 → gpjax-0.11.0}/examples/regression.py +6 -6
- {gpjax-0.10.2 → gpjax-0.11.0}/examples/uncollapsed_vi.py +4 -70
- {gpjax-0.10.2 → gpjax-0.11.0}/examples/yacht.py +3 -3
- {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/__init__.py +1 -1
- {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/distributions.py +101 -111
- {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/fit.py +2 -2
- {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/approximations/rff.py +1 -1
- {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/stationary/base.py +2 -2
- {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/stationary/matern12.py +2 -2
- {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/stationary/matern32.py +2 -2
- {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/stationary/matern52.py +2 -2
- {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/stationary/rbf.py +3 -3
- {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/stationary/utils.py +3 -5
- {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/likelihoods.py +36 -35
- {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/mean_functions.py +3 -2
- gpjax-0.11.0/gpjax/numpyro_extras.py +106 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/objectives.py +4 -6
- {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/parameters.py +15 -13
- {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/variational_families.py +5 -1
- {gpjax-0.10.2 → gpjax-0.11.0}/pyproject.toml +5 -2
- {gpjax-0.10.2 → gpjax-0.11.0}/tests/integration_tests.py +24 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/tests/test_dataset.py +30 -30
- {gpjax-0.10.2 → gpjax-0.11.0}/tests/test_gaussian_distribution.py +24 -43
- {gpjax-0.10.2 → gpjax-0.11.0}/tests/test_gps.py +9 -9
- {gpjax-0.10.2 → gpjax-0.11.0}/tests/test_likelihoods.py +13 -15
- gpjax-0.11.0/tests/test_numpyro_extras.py +127 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/tests/test_parameters.py +1 -5
- {gpjax-0.10.2 → gpjax-0.11.0}/tests/test_variational_families.py +8 -8
- gpjax-0.10.2/examples/oak_example.py +0 -216
- {gpjax-0.10.2 → gpjax-0.11.0}/.cursorrules +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/.github/CODE_OF_CONDUCT.md +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/.github/ISSUE_TEMPLATE/01_BUG_REPORT.md +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/.github/ISSUE_TEMPLATE/02_FEATURE_REQUEST.md +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/.github/ISSUE_TEMPLATE/03_CODEBASE_IMPROVEMENT.md +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/.github/ISSUE_TEMPLATE/04_DOCS_IMPROVEMENT.md +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/.github/ISSUE_TEMPLATE/config.yml +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/.github/codecov.yml +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/.github/labels.yml +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/.github/pull_request_template.md +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/.github/release-drafter.yml +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/.github/workflows/build_docs.yml +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/.github/workflows/integration.yml +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/.github/workflows/pr_greeting.yml +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/.github/workflows/ruff.yml +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/.github/workflows/stale_prs.yml +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/.github/workflows/test_docs.yml +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/.github/workflows/tests.yml +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/.gitignore +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/CITATION.bib +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/LICENSE.txt +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/Makefile +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/docs/CODE_OF_CONDUCT.md +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/docs/GOVERNANCE.md +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/docs/contributing.md +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/docs/design.md +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/docs/index.md +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/docs/index.rst +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/docs/installation.md +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/docs/javascripts/katex.js +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/docs/refs.bib +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/docs/scripts/gen_examples.py +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/docs/scripts/gen_pages.py +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/docs/scripts/notebook_converter.py +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/docs/static/GP.pdf +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/docs/static/GP.svg +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/docs/static/bijector_figure.svg +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/docs/static/css/gpjax_theme.css +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/docs/static/favicon.ico +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/docs/static/gpjax.mplstyle +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/docs/static/gpjax_logo.pdf +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/docs/static/gpjax_logo.svg +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/docs/static/jaxkern/lato.ttf +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/docs/static/jaxkern/logo.png +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/docs/static/jaxkern/logo.svg +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/docs/static/jaxkern/main.py +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/docs/static/step_size_figure.png +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/docs/static/step_size_figure.svg +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/docs/stylesheets/extra.css +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/docs/stylesheets/permalinks.css +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/examples/backend.py +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/examples/barycentres/barycentre_gp.gif +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/examples/data/max_tempeature_switzerland.csv +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/examples/data/yacht_hydrodynamics.data +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/examples/gpjax.mplstyle +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/examples/intro_to_gps/decomposed_mll.png +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/examples/intro_to_gps/generating_process.png +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/examples/utils.py +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/citation.py +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/dataset.py +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/gps.py +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/integrators.py +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/__init__.py +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/approximations/__init__.py +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/base.py +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/computations/__init__.py +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/computations/base.py +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/computations/basis_functions.py +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/computations/constant_diagonal.py +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/computations/dense.py +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/computations/diagonal.py +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/computations/eigen.py +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/non_euclidean/__init__.py +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/non_euclidean/graph.py +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/non_euclidean/utils.py +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/nonstationary/__init__.py +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/nonstationary/arccosine.py +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/nonstationary/linear.py +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/nonstationary/polynomial.py +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/stationary/__init__.py +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/stationary/periodic.py +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/stationary/powered_exponential.py +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/stationary/rational_quadratic.py +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/stationary/white.py +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/lower_cholesky.py +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/scan.py +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/typing.py +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/mkdocs.yml +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/static/CONTRIBUTING.md +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/static/paper.bib +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/static/paper.md +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/static/paper.pdf +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/tests/__init__.py +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/tests/conftest.py +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/tests/test_citations.py +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/tests/test_fit.py +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/tests/test_integrators.py +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/tests/test_kernels/__init__.py +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/tests/test_kernels/test_approximations.py +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/tests/test_kernels/test_base.py +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/tests/test_kernels/test_computation.py +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/tests/test_kernels/test_non_euclidean.py +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/tests/test_kernels/test_nonstationary.py +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/tests/test_kernels/test_stationary.py +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/tests/test_kernels/test_utils.py +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/tests/test_lower_cholesky.py +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/tests/test_markdown.py +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/tests/test_mean_functions.py +0 -0
- {gpjax-0.10.2 → gpjax-0.11.0}/tests/test_objectives.py +0 -0
{gpjax-0.10.2 → gpjax-0.11.0}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: gpjax
-Version: 0.10.2
+Version: 0.11.0
 Summary: Gaussian processes in JAX.
 Project-URL: Documentation, https://docs.jaxgaussianprocesses.com/
 Project-URL: Issues, https://github.com/JaxGaussianProcesses/GPJax/issues
@@ -24,8 +24,8 @@ Requires-Dist: jax>=0.5.0
 Requires-Dist: jaxlib>=0.5.0
 Requires-Dist: jaxtyping>0.2.10
 Requires-Dist: numpy>=2.0.0
+Requires-Dist: numpyro
 Requires-Dist: optax>0.2.1
-Requires-Dist: tensorflow-probability>=0.24.0
 Requires-Dist: tqdm>4.66.2
 Description-Content-Type: text/markdown

@@ -138,65 +138,6 @@ jupytext --to notebook example.py
 jupytext --to py:percent example.ipynb
 ```

-# Simple example
-
-Let us import some dependencies and simulate a toy dataset $\mathcal{D}$.
-
-```python
-from jax import config
-
-config.update("jax_enable_x64", True)
-
-import gpjax as gpx
-from jax import grad, jit
-import jax.numpy as jnp
-import jax.random as jr
-import optax as ox
-
-key = jr.key(123)
-
-f = lambda x: 10 * jnp.sin(x)
-
-n = 50
-x = jr.uniform(key=key, minval=-3.0, maxval=3.0, shape=(n,1)).sort()
-y = f(x) + jr.normal(key, shape=(n,1))
-D = gpx.Dataset(X=x, y=y)
-
-# Construct the prior
-meanf = gpx.mean_functions.Zero()
-kernel = gpx.kernels.RBF()
-prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel)
-
-# Define a likelihood
-likelihood = gpx.likelihoods.Gaussian(num_datapoints = n)
-
-# Construct the posterior
-posterior = prior * likelihood
-
-# Define an optimiser
-optimiser = ox.adam(learning_rate=1e-2)
-
-# Obtain Type 2 MLEs of the hyperparameters
-opt_posterior, history = gpx.fit(
-    model=posterior,
-    objective=lambda p, d: -gpx.objectives.conjugate_mll(p, d),
-    train_data=D,
-    optim=optimiser,
-    num_iters=500,
-    safe=True,
-    key=key,
-)
-
-# Infer the predictive posterior distribution
-xtest = jnp.linspace(-3., 3., 100).reshape(-1, 1)
-latent_dist = opt_posterior(xtest, D)
-predictive_dist = opt_posterior.likelihood(latent_dist)
-
-# Obtain the predictive mean and standard deviation
-pred_mean = predictive_dist.mean()
-pred_std = predictive_dist.stddev()
-```
-
 # Installation

 ## Stable version
{gpjax-0.10.2 → gpjax-0.11.0}/README.md

@@ -107,65 +107,6 @@ jupytext --to notebook example.py
(The "# Simple example" section removed here is identical to the one removed from the README text embedded in PKG-INFO above; the surrounding jupytext commands and the "# Installation" / "## Stable version" headings are unchanged context.)
{gpjax-0.10.2 → gpjax-0.11.0}/docs/scripts/sharp_bits_figure.py

@@ -69,12 +69,12 @@ ax.set_xlim(-0.07, 0.25)
 plt.savefig("../_static/step_size_figure.png", bbox_inches="tight")

 # %%
-import
+import numpyro.distributions.transforms as npt

-bij =
+bij = npt.ExpTransform()

 x = np.linspace(0.05, 3.0, 6)
-y = np.asarray(bij.
+y = np.asarray(bij.inv(x))
 lval = 0.5
 rval = 0.52

{gpjax-0.10.2 → gpjax-0.11.0}/docs/sharp_bits.md

@@ -80,7 +80,7 @@ this value that we apply gradient updates to. When we wish to recover the constr
 value, we apply the inverse of the bijector, which is the exponential function in this
 case. This gives us back the blue cross.

-In GPJax, we supply bijective functions using [
+In GPJax, we supply bijective functions using [Numpyro](https://num.pyro.ai/en/stable/distributions.html#transforms).


 ## Positive-definiteness
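The sharp-bits hunk above points the bijector discussion at Numpyro's transforms. For reference, a minimal standalone sketch (not taken from the diff) of the `ExpTransform` forward/inverse pair that the figure script now uses:

```python
import jax.numpy as jnp
import numpyro.distributions.transforms as npt

bij = npt.ExpTransform()
value = jnp.array(0.5)
forward = bij(value)     # exp(0.5), a positive value
back = bij.inv(forward)  # the log inverse recovers 0.5
```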
{gpjax-0.10.2 → gpjax-0.11.0}/examples/barycentres.py

@@ -8,7 +8,7 @@
 # extension: .py
 # format_name: percent
 # format_version: '1.3'
-# jupytext_version: 1.16.
+# jupytext_version: 1.16.7
 # kernelspec:
 # display_name: gpjax
 # language: python
@@ -41,7 +41,7 @@ import jax.random as jr
 import jax.scipy.linalg as jsl
 from jaxtyping import install_import_hook
 import matplotlib.pyplot as plt
-import
+import numpyro.distributions as npd

 from examples.utils import use_mpl_style

@@ -161,7 +161,7 @@ plt.show()


 # %%
-def fit_gp(x: jax.Array, y: jax.Array) ->
+def fit_gp(x: jax.Array, y: jax.Array) -> npd.MultivariateNormal:
     if y.ndim == 1:
         y = y.reshape(-1, 1)
     D = gpx.Dataset(X=x, y=y)
@@ -204,9 +204,9 @@ def sqrtm(A: jax.Array):


 def wasserstein_barycentres(
-    distributions: tp.List[
+    distributions: tp.List[npd.MultivariateNormal], weights: jax.Array
 ):
-    covariances = [d.
+    covariances = [d.covariance_matrix for d in distributions]
     cov_stack = jnp.stack(covariances)
     stack_sqrt = jax.vmap(sqrtm)(cov_stack)

@@ -231,7 +231,7 @@ def wasserstein_barycentres(
 # %%
 weights = jnp.ones((n_datasets,)) / n_datasets

-means = jnp.stack([d.mean
+means = jnp.stack([d.mean for d in posterior_preds])
 barycentre_mean = jnp.tensordot(weights, means, axes=1)

 step_fn = jax.jit(wasserstein_barycentres(posterior_preds, weights))
@@ -242,7 +242,7 @@ barycentre_covariance, sequence = jax.lax.scan(
 )
 L = jnp.linalg.cholesky(barycentre_covariance)

-barycentre_process =
+barycentre_process = npd.MultivariateNormal(barycentre_mean, scale_tril=L)

 # %% [markdown]
 # ## Plotting the result
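The hunk above (and the header context of the next one, which still shows the old `tfd.MultivariateNormalTriL` call) illustrates the recurring substitution in this release: numpyro exposes a single `MultivariateNormal` class, and a Cholesky factor is passed through the `scale_tril` keyword. A minimal sketch of that construction, assuming only numpyro and JAX:

```python
import jax.numpy as jnp
import numpyro.distributions as npd

mean = jnp.zeros(3)
cov = 0.5 * jnp.eye(3)
L = jnp.linalg.cholesky(cov)  # lower-triangular factor of the covariance
dist = npd.MultivariateNormal(mean, scale_tril=L)
```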
@@ -254,7 +254,7 @@ barycentre_process = tfd.MultivariateNormalTriL(barycentre_mean, L)

 # %%
 def plot(
-    dist:
+    dist: npd.MultivariateNormal,
     ax,
     color: str,
     label: str = None,
@@ -262,8 +262,8 @@ def plot(
     linewidth: float = 1.0,
     zorder: int = 0,
 ):
-    mu = dist.mean
-    sigma = dist.
+    mu = dist.mean
+    sigma = jnp.sqrt(dist.variance)
     ax.plot(xtest, mu, linewidth=linewidth, color=color, label=label, zorder=zorder)
     ax.fill_between(
         xtest.squeeze(),
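The `sigma = jnp.sqrt(dist.variance)` line above reflects another recurring change in this release: numpyro distributions expose `mean` and `variance` as properties rather than methods, so a standard deviation is computed explicitly. A small standalone sketch (not from the diff):

```python
import jax.numpy as jnp
import numpyro.distributions as npd

dist = npd.Normal(loc=0.0, scale=2.0)
mu = dist.mean                   # property access, no call
sigma = jnp.sqrt(dist.variance)  # 2.0; replaces the previous .stddev() call
```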
{gpjax-0.10.2 → gpjax-0.11.0}/examples/classification.py

@@ -8,7 +8,7 @@
 # extension: .py
 # format_name: percent
 # format_version: '1.3'
-# jupytext_version: 1.16.
+# jupytext_version: 1.16.7
 # kernelspec:
 # display_name: gpjax
 # language: python
@@ -37,8 +37,8 @@ from jaxtyping import (
     install_import_hook,
 )
 import matplotlib.pyplot as plt
+import numpyro.distributions as npd
 import optax as ox
-import tensorflow_probability.substrates.jax as tfp

 from examples.utils import use_mpl_style
 from gpjax.lower_cholesky import lower_cholesky
@@ -50,7 +50,6 @@ with install_import_hook("gpjax", "beartype.beartype"):
     import gpjax as gpx


-tfd = tfp.distributions
 identity_matrix = jnp.eye

 # set the default style for plotting
@@ -120,7 +119,6 @@ print(type(posterior))
 # Optax's optimisers.

 # %%
-
 optimiser = ox.adam(learning_rate=0.01)

 opt_posterior, history = gpx.fit(
@@ -140,8 +138,8 @@ opt_posterior, history = gpx.fit(
 map_latent_dist = opt_posterior.predict(xtest, train_data=D)
 predictive_dist = opt_posterior.likelihood(map_latent_dist)

-predictive_mean = predictive_dist.mean
-predictive_std = predictive_dist.
+predictive_mean = predictive_dist.mean
+predictive_std = jnp.sqrt(predictive_dist.variance)

 fig, ax = plt.subplots()
 ax.scatter(x, y, label="Observations", color=cols[0])
@@ -215,8 +213,6 @@ ax.legend()
 # datapoints below.

 # %%
-
-
 gram, cross_covariance = (kernel.gram, kernel.cross_covariance)
 jitter = 1e-6

@@ -246,7 +242,7 @@ L = jnp.linalg.cholesky(H + identity_matrix(D.n) * jitter)
 L_inv = jsp.linalg.solve_triangular(L, identity_matrix(D.n), lower=True)
 H_inv = jsp.linalg.solve_triangular(L.T, L_inv, lower=False)
 LH = jnp.linalg.cholesky(H_inv)
-laplace_approximation =
+laplace_approximation = npd.MultivariateNormal(f_hat.squeeze(), scale_tril=LH)


 # %% [markdown]
@@ -265,7 +261,7 @@ laplace_approximation = tfd.MultivariateNormalTriL(f_hat.squeeze(), LH)


 # %%
-def construct_laplace(test_inputs: Float[Array, "N D"]) ->
+def construct_laplace(test_inputs: Float[Array, "N D"]) -> npd.MultivariateNormal:
     map_latent_dist = opt_posterior.predict(xtest, train_data=D)

     Kxt = opt_posterior.prior.kernel.cross_covariance(x, test_inputs)
@@ -279,10 +275,10 @@ def construct_laplace(test_inputs: Float[Array, "N D"]) -> tfd.MultivariateNorma
     # Ktx Kxx⁻¹[ H⁻¹ ] Kxx⁻¹ Kxt
     laplace_cov_term = jnp.matmul(jnp.matmul(Kxx_inv_Kxt.T, H_inv), Kxx_inv_Kxt)

-    mean = map_latent_dist.mean
-    covariance = map_latent_dist.
+    mean = map_latent_dist.mean
+    covariance = map_latent_dist.covariance_matrix + laplace_cov_term
     L = jnp.linalg.cholesky(covariance)
-    return
+    return npd.MultivariateNormal(jnp.atleast_1d(mean.squeeze()), scale_tril=L)


 # %% [markdown]
@@ -291,8 +287,8 @@ def construct_laplace(test_inputs: Float[Array, "N D"]) -> tfd.MultivariateNorma
 laplace_latent_dist = construct_laplace(xtest)
 predictive_dist = opt_posterior.likelihood(laplace_latent_dist)

-predictive_mean = predictive_dist.mean
-predictive_std = predictive_dist.
+predictive_mean = predictive_dist.mean
+predictive_std = jnp.sqrt(predictive_dist.variance)

 fig, ax = plt.subplots()
 ax.scatter(x, y, label="Observations", color=cols[0])
{gpjax-0.10.2 → gpjax-0.11.0}/examples/collapsed_vi.py

@@ -7,7 +7,7 @@
 # extension: .py
 # format_name: percent
 # format_version: '1.3'
-# jupytext_version: 1.16.
+# jupytext_version: 1.16.7
 # kernelspec:
 # display_name: gpjax_beartype
 # language: python
@@ -161,10 +161,10 @@ predictive_dist = opt_posterior.posterior.likelihood(latent_dist)

 inducing_points = opt_posterior.inducing_inputs.value

-samples = latent_dist.sample(
+samples = latent_dist.sample(key=key, sample_shape=(20,))

-predictive_mean = predictive_dist.mean
-predictive_std = predictive_dist.
+predictive_mean = predictive_dist.mean
+predictive_std = jnp.sqrt(predictive_dist.variance)

 fig, ax = plt.subplots()

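The sampling call above shows the other signature change running through these examples: draws now take an explicit PRNG key plus a `sample_shape`. A minimal sketch against a plain numpyro distribution (illustration only, not from the diff):

```python
import jax.random as jr
import numpyro.distributions as npd

key = jr.key(0)
dist = npd.Normal(0.0, 1.0)
samples = dist.sample(key, sample_shape=(20,))  # array of shape (20,)
```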
{gpjax-0.10.2 → gpjax-0.11.0}/examples/constructing_new_kernels.py

@@ -8,7 +8,7 @@
 # extension: .py
 # format_name: percent
 # format_version: '1.3'
-# jupytext_version: 1.16.
+# jupytext_version: 1.16.7
 # kernelspec:
 # display_name: gpjax
 # language: python
@@ -24,6 +24,7 @@
 # %%
 # Enable Float64 for more stable matrix inversions.
 from jax import config
+from jax.nn import softplus
 import jax.numpy as jnp
 import jax.random as jr
 from jaxtyping import (
@@ -32,7 +33,9 @@ from jaxtyping import (
     install_import_hook,
 )
 import matplotlib.pyplot as plt
-import
+import numpyro.distributions as npd
+from numpyro.distributions import constraints
+import numpyro.distributions.transforms as npt

 from examples.utils import use_mpl_style
 from gpjax.kernels.computations import DenseKernelComputation
@@ -225,9 +228,27 @@ def angular_distance(x, y, c):
     return jnp.abs((x - y + c) % (c * 2) - c)


-
+class ShiftedSoftplusTransform(npt.ParameterFreeTransform):
+    r"""
+    Transform from unconstrained space to the domain [4, infinity) via
+    :math:`y = 4 + \log(1 + \exp(x))`. The inverse is computed as
+    :math:`x = \log(\exp(y - 4) - 1)`.
+    """

-
+    domain = constraints.real
+    codomain = constraints.interval(4.0, jnp.inf)  # updated codomain
+
+    def __call__(self, x):
+        return 4.0 + softplus(x)  # shift the softplus output by 4
+
+    def _inverse(self, y):
+        return npt._softplus_inv(y - 4.0)  # subtract the shift in the inverse
+
+    def log_abs_det_jacobian(self, x, y, intermediates=None):
+        return -softplus(-x)
+
+
+DEFAULT_BIJECTION["polar"] = ShiftedSoftplusTransform()


 class Polar(gpx.kernels.AbstractKernel):
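A quick numerical check (not part of the package) of the mapping the new `ShiftedSoftplusTransform` implements: `4 + softplus(x)` sends the real line to `[4, inf)`, and `log(expm1(y - 4))` undoes it:

```python
import jax.numpy as jnp
from jax.nn import softplus

x = jnp.array([-2.0, 0.0, 3.0])
y = 4.0 + softplus(x)                 # forward map into [4, inf)
x_back = jnp.log(jnp.expm1(y - 4.0))  # inverse map
print(jnp.allclose(x, x_back, atol=1e-5))  # True
```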
@@ -307,7 +328,7 @@ opt_posterior, history = gpx.fit_scipy(

 # %%
 posterior_rv = opt_posterior.likelihood(opt_posterior.predict(angles, train_data=D))
-mu = posterior_rv.mean
+mu = posterior_rv.mean
 one_sigma = posterior_rv.stddev()

 # %%
{gpjax-0.10.2 → gpjax-0.11.0}/examples/deep_kernels.py

@@ -8,7 +8,7 @@
 # extension: .py
 # format_name: percent
 # format_version: '1.3'
-# jupytext_version: 1.16.
+# jupytext_version: 1.16.7
 # kernelspec:
 # display_name: gpjax
 # language: python
@@ -238,8 +238,8 @@ opt_posterior, history = gpx.fit(
 latent_dist = opt_posterior(xtest, train_data=D)
 predictive_dist = opt_posterior.likelihood(latent_dist)

-predictive_mean = predictive_dist.mean
-predictive_std = predictive_dist.
+predictive_mean = predictive_dist.mean
+predictive_std = jnp.sqrt(predictive_dist.variance)

 fig, ax = plt.subplots()
 ax.plot(x, y, "o", label="Observations", color=cols[0])
{gpjax-0.10.2 → gpjax-0.11.0}/examples/graph_kernels.py

@@ -8,7 +8,7 @@
 # extension: .py
 # format_name: percent
 # format_version: '1.3'
-# jupytext_version: 1.16.
+# jupytext_version: 1.16.7
 # kernelspec:
 # display_name: gpjax
 # language: python
@@ -124,7 +124,7 @@ true_kernel = gpx.kernels.GraphKernel(
 prior = gpx.gps.Prior(mean_function=gpx.mean_functions.Zero(), kernel=true_kernel)

 fx = prior(x)
-y = fx.sample(
+y = fx.sample(key=key, sample_shape=(1,)).reshape(-1, 1)

 D = gpx.Dataset(X=x, y=y)

@@ -194,8 +194,8 @@ opt_posterior, training_history = gpx.fit_scipy(
 initial_dist = likelihood(posterior(x, D))
 predictive_dist = opt_posterior.likelihood(opt_posterior(x, D))

-initial_mean = initial_dist.mean
-learned_mean = predictive_dist.mean
+initial_mean = initial_dist.mean
+learned_mean = predictive_dist.mean

 rmse = lambda ytrue, ypred: jnp.sum(jnp.sqrt(jnp.square(ytrue - ypred)))

{gpjax-0.10.2 → gpjax-0.11.0}/examples/intro_to_gps.py

@@ -7,7 +7,7 @@
 # extension: .py
 # format_name: percent
 # format_version: '1.3'
-# jupytext_version: 1.16.
+# jupytext_version: 1.16.7
 # kernelspec:
 # display_name: gpjax
 # language: python
@@ -127,9 +127,9 @@ import jax.numpy as jnp
 import jax.random as jr
 import matplotlib as mpl
 import matplotlib.pyplot as plt
+import numpyro.distributions as npd
 import pandas as pd
 import seaborn as sns
-import tensorflow_probability.substrates.jax as tfp

 from examples.utils import (
     confidence_ellipse,
@@ -143,11 +143,10 @@ key = jr.key(42)


 cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"]
-tfd = tfp.distributions

-ud1 =
-ud2 =
-ud3 =
+ud1 = npd.Normal(0.0, 1.0)
+ud2 = npd.Normal(-1.0, 0.5)
+ud3 = npd.Normal(0.25, 1.5)

 xs = jnp.linspace(-5.0, 5.0, 500)

@@ -156,7 +155,7 @@ for d in [ud1, ud2, ud3]:
     ax.plot(
         xs,
         jnp.exp(d.log_prob(xs)),
-        label=f"$\\mathcal{{N}}({{{float(d.mean
+        label=f"$\\mathcal{{N}}({{{float(d.mean)}}},\\ {{{float(jnp.sqrt(d.variance))}}}^2)$",
     )
     ax.fill_between(xs, jnp.zeros_like(xs), jnp.exp(d.log_prob(xs)), alpha=0.2)
 ax.legend(loc="best")
@@ -190,12 +189,12 @@ ax.legend(loc="best")
 # %%
 key = jr.key(123)

-d1 =
-d2 =
-    jnp.zeros(2), jnp.linalg.cholesky(jnp.array([[1.0, 0.9], [0.9, 1.0]]))
+d1 = npd.MultivariateNormal(loc=jnp.zeros(2), covariance_matrix=jnp.diag(jnp.ones(2)))
+d2 = npd.MultivariateNormal(
+    jnp.zeros(2), scale_tril=jnp.linalg.cholesky(jnp.array([[1.0, 0.9], [0.9, 1.0]]))
 )
-d3 =
-    jnp.zeros(2), jnp.linalg.cholesky(jnp.array([[1.0, -0.5], [-0.5, 1.0]]))
+d3 = npd.MultivariateNormal(
+    jnp.zeros(2), scale_tril=jnp.linalg.cholesky(jnp.array([[1.0, -0.5], [-0.5, 1.0]]))
 )

 dists = [d1, d2, d3]
@@ -215,13 +214,21 @@ titles = [r"$\rho = 0$", r"$\rho = 0.9$", r"$\rho = -0.5$"]
 cmap = mpl.colors.LinearSegmentedColormap.from_list("custom", ["white", cols[1]], N=256)

 for a, t, d in zip([ax0, ax1, ax2], titles, dists, strict=False):
-    d_prob =
-        xx.
+    d_prob = jnp.exp(
+        d.log_prob(jnp.hstack([xx.reshape(-1, 1), yy.reshape(-1, 1)]))
+    ).reshape(xx.shape)
+    cntf = a.contourf(
+        xx,
+        yy,
+        jnp.exp(d_prob),
+        levels=20,
+        antialiased=True,
+        cmap=cmap,
+        edgecolor="face",
     )
-    cntf = a.contourf(xx, yy, jnp.exp(d_prob), levels=20, antialiased=True, cmap=cmap, edgecolor="face")
     a.set_xlim(-2.75, 2.75)
     a.set_ylim(-2.75, 2.75)
-    samples = d.sample(
+    samples = d.sample(key=key, sample_shape=(5000,))
     xsample, ysample = samples[:, 0], samples[:, 1]
     confidence_ellipse(
         xsample, ysample, a, edgecolor="#3f3f3f", n_std=1.0, linestyle="--", alpha=0.8
@@ -274,13 +281,13 @@ for a, t, d in zip([ax0, ax1, ax2], titles, dists, strict=False):

 # %%
 n = 1000
-x =
+x = npd.Normal(loc=0.0, scale=1.0).sample(key, sample_shape=(n,))
 key, subkey = jr.split(key)
-y =
+y = npd.Normal(loc=0.25, scale=0.5).sample(subkey, sample_shape=(n,))
 key, subkey = jr.split(subkey)
-xfull =
+xfull = npd.Normal(loc=0.0, scale=1.0).sample(subkey, sample_shape=(n * 10,))
 key, subkey = jr.split(subkey)
-yfull =
+yfull = npd.Normal(loc=0.25, scale=0.5).sample(subkey, sample_shape=(n * 10,))
 key, subkey = jr.split(subkey)
 df = pd.DataFrame({"x": x, "y": y, "idx": jnp.ones(n)})