gpjax 0.10.2__tar.gz → 0.11.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (153)
  1. {gpjax-0.10.2 → gpjax-0.11.0}/PKG-INFO +2 -61
  2. {gpjax-0.10.2 → gpjax-0.11.0}/README.md +0 -59
  3. {gpjax-0.10.2 → gpjax-0.11.0}/docs/scripts/sharp_bits_figure.py +3 -3
  4. {gpjax-0.10.2 → gpjax-0.11.0}/docs/sharp_bits.md +1 -1
  5. {gpjax-0.10.2 → gpjax-0.11.0}/examples/barycentres.py +10 -10
  6. {gpjax-0.10.2 → gpjax-0.11.0}/examples/classification.py +11 -15
  7. {gpjax-0.10.2 → gpjax-0.11.0}/examples/collapsed_vi.py +4 -4
  8. {gpjax-0.10.2 → gpjax-0.11.0}/examples/constructing_new_kernels.py +26 -5
  9. {gpjax-0.10.2 → gpjax-0.11.0}/examples/deep_kernels.py +3 -3
  10. {gpjax-0.10.2 → gpjax-0.11.0}/examples/graph_kernels.py +4 -4
  11. {gpjax-0.10.2 → gpjax-0.11.0}/examples/intro_to_gps.py +27 -20
  12. {gpjax-0.10.2 → gpjax-0.11.0}/examples/intro_to_kernels.py +12 -17
  13. {gpjax-0.10.2 → gpjax-0.11.0}/examples/likelihoods_guide.py +6 -9
  14. {gpjax-0.10.2 → gpjax-0.11.0}/examples/oceanmodelling.py +186 -98
  15. {gpjax-0.10.2 → gpjax-0.11.0}/examples/poisson.py +2 -4
  16. {gpjax-0.10.2 → gpjax-0.11.0}/examples/regression.py +6 -6
  17. {gpjax-0.10.2 → gpjax-0.11.0}/examples/uncollapsed_vi.py +4 -70
  18. {gpjax-0.10.2 → gpjax-0.11.0}/examples/yacht.py +3 -3
  19. {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/__init__.py +1 -1
  20. {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/distributions.py +101 -111
  21. {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/fit.py +2 -2
  22. {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/approximations/rff.py +1 -1
  23. {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/stationary/base.py +2 -2
  24. {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/stationary/matern12.py +2 -2
  25. {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/stationary/matern32.py +2 -2
  26. {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/stationary/matern52.py +2 -2
  27. {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/stationary/rbf.py +3 -3
  28. {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/stationary/utils.py +3 -5
  29. {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/likelihoods.py +36 -35
  30. {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/mean_functions.py +3 -2
  31. gpjax-0.11.0/gpjax/numpyro_extras.py +106 -0
  32. {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/objectives.py +4 -6
  33. {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/parameters.py +15 -13
  34. {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/variational_families.py +5 -1
  35. {gpjax-0.10.2 → gpjax-0.11.0}/pyproject.toml +5 -2
  36. {gpjax-0.10.2 → gpjax-0.11.0}/tests/integration_tests.py +24 -0
  37. {gpjax-0.10.2 → gpjax-0.11.0}/tests/test_dataset.py +30 -30
  38. {gpjax-0.10.2 → gpjax-0.11.0}/tests/test_gaussian_distribution.py +24 -43
  39. {gpjax-0.10.2 → gpjax-0.11.0}/tests/test_gps.py +9 -9
  40. {gpjax-0.10.2 → gpjax-0.11.0}/tests/test_likelihoods.py +13 -15
  41. gpjax-0.11.0/tests/test_numpyro_extras.py +127 -0
  42. {gpjax-0.10.2 → gpjax-0.11.0}/tests/test_parameters.py +1 -5
  43. {gpjax-0.10.2 → gpjax-0.11.0}/tests/test_variational_families.py +8 -8
  44. gpjax-0.10.2/examples/oak_example.py +0 -216
  45. {gpjax-0.10.2 → gpjax-0.11.0}/.cursorrules +0 -0
  46. {gpjax-0.10.2 → gpjax-0.11.0}/.github/CODE_OF_CONDUCT.md +0 -0
  47. {gpjax-0.10.2 → gpjax-0.11.0}/.github/ISSUE_TEMPLATE/01_BUG_REPORT.md +0 -0
  48. {gpjax-0.10.2 → gpjax-0.11.0}/.github/ISSUE_TEMPLATE/02_FEATURE_REQUEST.md +0 -0
  49. {gpjax-0.10.2 → gpjax-0.11.0}/.github/ISSUE_TEMPLATE/03_CODEBASE_IMPROVEMENT.md +0 -0
  50. {gpjax-0.10.2 → gpjax-0.11.0}/.github/ISSUE_TEMPLATE/04_DOCS_IMPROVEMENT.md +0 -0
  51. {gpjax-0.10.2 → gpjax-0.11.0}/.github/ISSUE_TEMPLATE/config.yml +0 -0
  52. {gpjax-0.10.2 → gpjax-0.11.0}/.github/codecov.yml +0 -0
  53. {gpjax-0.10.2 → gpjax-0.11.0}/.github/labels.yml +0 -0
  54. {gpjax-0.10.2 → gpjax-0.11.0}/.github/pull_request_template.md +0 -0
  55. {gpjax-0.10.2 → gpjax-0.11.0}/.github/release-drafter.yml +0 -0
  56. {gpjax-0.10.2 → gpjax-0.11.0}/.github/workflows/build_docs.yml +0 -0
  57. {gpjax-0.10.2 → gpjax-0.11.0}/.github/workflows/integration.yml +0 -0
  58. {gpjax-0.10.2 → gpjax-0.11.0}/.github/workflows/pr_greeting.yml +0 -0
  59. {gpjax-0.10.2 → gpjax-0.11.0}/.github/workflows/ruff.yml +0 -0
  60. {gpjax-0.10.2 → gpjax-0.11.0}/.github/workflows/stale_prs.yml +0 -0
  61. {gpjax-0.10.2 → gpjax-0.11.0}/.github/workflows/test_docs.yml +0 -0
  62. {gpjax-0.10.2 → gpjax-0.11.0}/.github/workflows/tests.yml +0 -0
  63. {gpjax-0.10.2 → gpjax-0.11.0}/.gitignore +0 -0
  64. {gpjax-0.10.2 → gpjax-0.11.0}/CITATION.bib +0 -0
  65. {gpjax-0.10.2 → gpjax-0.11.0}/LICENSE.txt +0 -0
  66. {gpjax-0.10.2 → gpjax-0.11.0}/Makefile +0 -0
  67. {gpjax-0.10.2 → gpjax-0.11.0}/docs/CODE_OF_CONDUCT.md +0 -0
  68. {gpjax-0.10.2 → gpjax-0.11.0}/docs/GOVERNANCE.md +0 -0
  69. {gpjax-0.10.2 → gpjax-0.11.0}/docs/contributing.md +0 -0
  70. {gpjax-0.10.2 → gpjax-0.11.0}/docs/design.md +0 -0
  71. {gpjax-0.10.2 → gpjax-0.11.0}/docs/index.md +0 -0
  72. {gpjax-0.10.2 → gpjax-0.11.0}/docs/index.rst +0 -0
  73. {gpjax-0.10.2 → gpjax-0.11.0}/docs/installation.md +0 -0
  74. {gpjax-0.10.2 → gpjax-0.11.0}/docs/javascripts/katex.js +0 -0
  75. {gpjax-0.10.2 → gpjax-0.11.0}/docs/refs.bib +0 -0
  76. {gpjax-0.10.2 → gpjax-0.11.0}/docs/scripts/gen_examples.py +0 -0
  77. {gpjax-0.10.2 → gpjax-0.11.0}/docs/scripts/gen_pages.py +0 -0
  78. {gpjax-0.10.2 → gpjax-0.11.0}/docs/scripts/notebook_converter.py +0 -0
  79. {gpjax-0.10.2 → gpjax-0.11.0}/docs/static/GP.pdf +0 -0
  80. {gpjax-0.10.2 → gpjax-0.11.0}/docs/static/GP.svg +0 -0
  81. {gpjax-0.10.2 → gpjax-0.11.0}/docs/static/bijector_figure.svg +0 -0
  82. {gpjax-0.10.2 → gpjax-0.11.0}/docs/static/css/gpjax_theme.css +0 -0
  83. {gpjax-0.10.2 → gpjax-0.11.0}/docs/static/favicon.ico +0 -0
  84. {gpjax-0.10.2 → gpjax-0.11.0}/docs/static/gpjax.mplstyle +0 -0
  85. {gpjax-0.10.2 → gpjax-0.11.0}/docs/static/gpjax_logo.pdf +0 -0
  86. {gpjax-0.10.2 → gpjax-0.11.0}/docs/static/gpjax_logo.svg +0 -0
  87. {gpjax-0.10.2 → gpjax-0.11.0}/docs/static/jaxkern/lato.ttf +0 -0
  88. {gpjax-0.10.2 → gpjax-0.11.0}/docs/static/jaxkern/logo.png +0 -0
  89. {gpjax-0.10.2 → gpjax-0.11.0}/docs/static/jaxkern/logo.svg +0 -0
  90. {gpjax-0.10.2 → gpjax-0.11.0}/docs/static/jaxkern/main.py +0 -0
  91. {gpjax-0.10.2 → gpjax-0.11.0}/docs/static/step_size_figure.png +0 -0
  92. {gpjax-0.10.2 → gpjax-0.11.0}/docs/static/step_size_figure.svg +0 -0
  93. {gpjax-0.10.2 → gpjax-0.11.0}/docs/stylesheets/extra.css +0 -0
  94. {gpjax-0.10.2 → gpjax-0.11.0}/docs/stylesheets/permalinks.css +0 -0
  95. {gpjax-0.10.2 → gpjax-0.11.0}/examples/backend.py +0 -0
  96. {gpjax-0.10.2 → gpjax-0.11.0}/examples/barycentres/barycentre_gp.gif +0 -0
  97. {gpjax-0.10.2 → gpjax-0.11.0}/examples/data/max_tempeature_switzerland.csv +0 -0
  98. {gpjax-0.10.2 → gpjax-0.11.0}/examples/data/yacht_hydrodynamics.data +0 -0
  99. {gpjax-0.10.2 → gpjax-0.11.0}/examples/gpjax.mplstyle +0 -0
  100. {gpjax-0.10.2 → gpjax-0.11.0}/examples/intro_to_gps/decomposed_mll.png +0 -0
  101. {gpjax-0.10.2 → gpjax-0.11.0}/examples/intro_to_gps/generating_process.png +0 -0
  102. {gpjax-0.10.2 → gpjax-0.11.0}/examples/utils.py +0 -0
  103. {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/citation.py +0 -0
  104. {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/dataset.py +0 -0
  105. {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/gps.py +0 -0
  106. {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/integrators.py +0 -0
  107. {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/__init__.py +0 -0
  108. {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/approximations/__init__.py +0 -0
  109. {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/base.py +0 -0
  110. {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/computations/__init__.py +0 -0
  111. {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/computations/base.py +0 -0
  112. {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/computations/basis_functions.py +0 -0
  113. {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/computations/constant_diagonal.py +0 -0
  114. {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/computations/dense.py +0 -0
  115. {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/computations/diagonal.py +0 -0
  116. {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/computations/eigen.py +0 -0
  117. {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/non_euclidean/__init__.py +0 -0
  118. {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/non_euclidean/graph.py +0 -0
  119. {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/non_euclidean/utils.py +0 -0
  120. {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/nonstationary/__init__.py +0 -0
  121. {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/nonstationary/arccosine.py +0 -0
  122. {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/nonstationary/linear.py +0 -0
  123. {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/nonstationary/polynomial.py +0 -0
  124. {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/stationary/__init__.py +0 -0
  125. {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/stationary/periodic.py +0 -0
  126. {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/stationary/powered_exponential.py +0 -0
  127. {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/stationary/rational_quadratic.py +0 -0
  128. {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/kernels/stationary/white.py +0 -0
  129. {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/lower_cholesky.py +0 -0
  130. {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/scan.py +0 -0
  131. {gpjax-0.10.2 → gpjax-0.11.0}/gpjax/typing.py +0 -0
  132. {gpjax-0.10.2 → gpjax-0.11.0}/mkdocs.yml +0 -0
  133. {gpjax-0.10.2 → gpjax-0.11.0}/static/CONTRIBUTING.md +0 -0
  134. {gpjax-0.10.2 → gpjax-0.11.0}/static/paper.bib +0 -0
  135. {gpjax-0.10.2 → gpjax-0.11.0}/static/paper.md +0 -0
  136. {gpjax-0.10.2 → gpjax-0.11.0}/static/paper.pdf +0 -0
  137. {gpjax-0.10.2 → gpjax-0.11.0}/tests/__init__.py +0 -0
  138. {gpjax-0.10.2 → gpjax-0.11.0}/tests/conftest.py +0 -0
  139. {gpjax-0.10.2 → gpjax-0.11.0}/tests/test_citations.py +0 -0
  140. {gpjax-0.10.2 → gpjax-0.11.0}/tests/test_fit.py +0 -0
  141. {gpjax-0.10.2 → gpjax-0.11.0}/tests/test_integrators.py +0 -0
  142. {gpjax-0.10.2 → gpjax-0.11.0}/tests/test_kernels/__init__.py +0 -0
  143. {gpjax-0.10.2 → gpjax-0.11.0}/tests/test_kernels/test_approximations.py +0 -0
  144. {gpjax-0.10.2 → gpjax-0.11.0}/tests/test_kernels/test_base.py +0 -0
  145. {gpjax-0.10.2 → gpjax-0.11.0}/tests/test_kernels/test_computation.py +0 -0
  146. {gpjax-0.10.2 → gpjax-0.11.0}/tests/test_kernels/test_non_euclidean.py +0 -0
  147. {gpjax-0.10.2 → gpjax-0.11.0}/tests/test_kernels/test_nonstationary.py +0 -0
  148. {gpjax-0.10.2 → gpjax-0.11.0}/tests/test_kernels/test_stationary.py +0 -0
  149. {gpjax-0.10.2 → gpjax-0.11.0}/tests/test_kernels/test_utils.py +0 -0
  150. {gpjax-0.10.2 → gpjax-0.11.0}/tests/test_lower_cholesky.py +0 -0
  151. {gpjax-0.10.2 → gpjax-0.11.0}/tests/test_markdown.py +0 -0
  152. {gpjax-0.10.2 → gpjax-0.11.0}/tests/test_mean_functions.py +0 -0
  153. {gpjax-0.10.2 → gpjax-0.11.0}/tests/test_objectives.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: gpjax
3
- Version: 0.10.2
3
+ Version: 0.11.0
4
4
  Summary: Gaussian processes in JAX.
5
5
  Project-URL: Documentation, https://docs.jaxgaussianprocesses.com/
6
6
  Project-URL: Issues, https://github.com/JaxGaussianProcesses/GPJax/issues
@@ -24,8 +24,8 @@ Requires-Dist: jax>=0.5.0
24
24
  Requires-Dist: jaxlib>=0.5.0
25
25
  Requires-Dist: jaxtyping>0.2.10
26
26
  Requires-Dist: numpy>=2.0.0
27
+ Requires-Dist: numpyro
27
28
  Requires-Dist: optax>0.2.1
28
- Requires-Dist: tensorflow-probability>=0.24.0
29
29
  Requires-Dist: tqdm>4.66.2
30
30
  Description-Content-Type: text/markdown
31
31
 
@@ -138,65 +138,6 @@ jupytext --to notebook example.py
138
138
  jupytext --to py:percent example.ipynb
139
139
  ```
140
140
 
141
- # Simple example
142
-
143
- Let us import some dependencies and simulate a toy dataset $\mathcal{D}$.
144
-
145
- ```python
146
- from jax import config
147
-
148
- config.update("jax_enable_x64", True)
149
-
150
- import gpjax as gpx
151
- from jax import grad, jit
152
- import jax.numpy as jnp
153
- import jax.random as jr
154
- import optax as ox
155
-
156
- key = jr.key(123)
157
-
158
- f = lambda x: 10 * jnp.sin(x)
159
-
160
- n = 50
161
- x = jr.uniform(key=key, minval=-3.0, maxval=3.0, shape=(n,1)).sort()
162
- y = f(x) + jr.normal(key, shape=(n,1))
163
- D = gpx.Dataset(X=x, y=y)
164
-
165
- # Construct the prior
166
- meanf = gpx.mean_functions.Zero()
167
- kernel = gpx.kernels.RBF()
168
- prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel)
169
-
170
- # Define a likelihood
171
- likelihood = gpx.likelihoods.Gaussian(num_datapoints = n)
172
-
173
- # Construct the posterior
174
- posterior = prior * likelihood
175
-
176
- # Define an optimiser
177
- optimiser = ox.adam(learning_rate=1e-2)
178
-
179
- # Obtain Type 2 MLEs of the hyperparameters
180
- opt_posterior, history = gpx.fit(
181
- model=posterior,
182
- objective=lambda p, d: -gpx.objectives.conjugate_mll(p, d),
183
- train_data=D,
184
- optim=optimiser,
185
- num_iters=500,
186
- safe=True,
187
- key=key,
188
- )
189
-
190
- # Infer the predictive posterior distribution
191
- xtest = jnp.linspace(-3., 3., 100).reshape(-1, 1)
192
- latent_dist = opt_posterior(xtest, D)
193
- predictive_dist = opt_posterior.likelihood(latent_dist)
194
-
195
- # Obtain the predictive mean and standard deviation
196
- pred_mean = predictive_dist.mean()
197
- pred_std = predictive_dist.stddev()
198
- ```
199
-
200
141
  # Installation
201
142
 
202
143
  ## Stable version
@@ -107,65 +107,6 @@ jupytext --to notebook example.py
107
107
  jupytext --to py:percent example.ipynb
108
108
  ```
109
109
 
110
- # Simple example
111
-
112
- Let us import some dependencies and simulate a toy dataset $\mathcal{D}$.
113
-
114
- ```python
115
- from jax import config
116
-
117
- config.update("jax_enable_x64", True)
118
-
119
- import gpjax as gpx
120
- from jax import grad, jit
121
- import jax.numpy as jnp
122
- import jax.random as jr
123
- import optax as ox
124
-
125
- key = jr.key(123)
126
-
127
- f = lambda x: 10 * jnp.sin(x)
128
-
129
- n = 50
130
- x = jr.uniform(key=key, minval=-3.0, maxval=3.0, shape=(n,1)).sort()
131
- y = f(x) + jr.normal(key, shape=(n,1))
132
- D = gpx.Dataset(X=x, y=y)
133
-
134
- # Construct the prior
135
- meanf = gpx.mean_functions.Zero()
136
- kernel = gpx.kernels.RBF()
137
- prior = gpx.gps.Prior(mean_function=meanf, kernel = kernel)
138
-
139
- # Define a likelihood
140
- likelihood = gpx.likelihoods.Gaussian(num_datapoints = n)
141
-
142
- # Construct the posterior
143
- posterior = prior * likelihood
144
-
145
- # Define an optimiser
146
- optimiser = ox.adam(learning_rate=1e-2)
147
-
148
- # Obtain Type 2 MLEs of the hyperparameters
149
- opt_posterior, history = gpx.fit(
150
- model=posterior,
151
- objective=lambda p, d: -gpx.objectives.conjugate_mll(p, d),
152
- train_data=D,
153
- optim=optimiser,
154
- num_iters=500,
155
- safe=True,
156
- key=key,
157
- )
158
-
159
- # Infer the predictive posterior distribution
160
- xtest = jnp.linspace(-3., 3., 100).reshape(-1, 1)
161
- latent_dist = opt_posterior(xtest, D)
162
- predictive_dist = opt_posterior.likelihood(latent_dist)
163
-
164
- # Obtain the predictive mean and standard deviation
165
- pred_mean = predictive_dist.mean()
166
- pred_std = predictive_dist.stddev()
167
- ```
168
-
169
110
  # Installation
170
111
 
171
112
  ## Stable version
@@ -69,12 +69,12 @@ ax.set_xlim(-0.07, 0.25)
69
69
  plt.savefig("../_static/step_size_figure.png", bbox_inches="tight")
70
70
 
71
71
  # %%
72
- import tensorflow_probability.substrates.jax.bijectors as tfb
72
+ import numpyro.distributions.transforms as npt
73
73
 
74
- bij = tfb.Exp()
74
+ bij = npt.ExpTransform()
75
75
 
76
76
  x = np.linspace(0.05, 3.0, 6)
77
- y = np.asarray(bij.inverse(x))
77
+ y = np.asarray(bij.inv(x))
78
78
  lval = 0.5
79
79
  rval = 0.52
80
80
 
@@ -80,7 +80,7 @@ this value that we apply gradient updates to. When we wish to recover the constr
80
80
  value, we apply the inverse of the bijector, which is the exponential function in this
81
81
  case. This gives us back the blue cross.
82
82
 
83
- In GPJax, we supply bijective functions using [Tensorflow Probability](https://www.tensorflow.org/probability/api_docs/python/tfp/substrates/jax/bijectors).
83
+ In GPJax, we supply bijective functions using [Numpyro](https://num.pyro.ai/en/stable/distributions.html#transforms).
84
84
 
85
85
 
86
86
  ## Positive-definiteness
@@ -8,7 +8,7 @@
8
8
  # extension: .py
9
9
  # format_name: percent
10
10
  # format_version: '1.3'
11
- # jupytext_version: 1.16.6
11
+ # jupytext_version: 1.16.7
12
12
  # kernelspec:
13
13
  # display_name: gpjax
14
14
  # language: python
@@ -41,7 +41,7 @@ import jax.random as jr
41
41
  import jax.scipy.linalg as jsl
42
42
  from jaxtyping import install_import_hook
43
43
  import matplotlib.pyplot as plt
44
- import tensorflow_probability.substrates.jax.distributions as tfd
44
+ import numpyro.distributions as npd
45
45
 
46
46
  from examples.utils import use_mpl_style
47
47
 
@@ -161,7 +161,7 @@ plt.show()
161
161
 
162
162
 
163
163
  # %%
164
- def fit_gp(x: jax.Array, y: jax.Array) -> tfd.MultivariateNormalFullCovariance:
164
+ def fit_gp(x: jax.Array, y: jax.Array) -> npd.MultivariateNormal:
165
165
  if y.ndim == 1:
166
166
  y = y.reshape(-1, 1)
167
167
  D = gpx.Dataset(X=x, y=y)
@@ -204,9 +204,9 @@ def sqrtm(A: jax.Array):
204
204
 
205
205
 
206
206
  def wasserstein_barycentres(
207
- distributions: tp.List[tfd.MultivariateNormalFullCovariance], weights: jax.Array
207
+ distributions: tp.List[npd.MultivariateNormal], weights: jax.Array
208
208
  ):
209
- covariances = [d.covariance() for d in distributions]
209
+ covariances = [d.covariance_matrix for d in distributions]
210
210
  cov_stack = jnp.stack(covariances)
211
211
  stack_sqrt = jax.vmap(sqrtm)(cov_stack)
212
212
 
@@ -231,7 +231,7 @@ def wasserstein_barycentres(
231
231
  # %%
232
232
  weights = jnp.ones((n_datasets,)) / n_datasets
233
233
 
234
- means = jnp.stack([d.mean() for d in posterior_preds])
234
+ means = jnp.stack([d.mean for d in posterior_preds])
235
235
  barycentre_mean = jnp.tensordot(weights, means, axes=1)
236
236
 
237
237
  step_fn = jax.jit(wasserstein_barycentres(posterior_preds, weights))
@@ -242,7 +242,7 @@ barycentre_covariance, sequence = jax.lax.scan(
242
242
  )
243
243
  L = jnp.linalg.cholesky(barycentre_covariance)
244
244
 
245
- barycentre_process = tfd.MultivariateNormalTriL(barycentre_mean, L)
245
+ barycentre_process = npd.MultivariateNormal(barycentre_mean, scale_tril=L)
246
246
 
247
247
  # %% [markdown]
248
248
  # ## Plotting the result
@@ -254,7 +254,7 @@ barycentre_process = tfd.MultivariateNormalTriL(barycentre_mean, L)
254
254
 
255
255
  # %%
256
256
  def plot(
257
- dist: tfd.MultivariateNormalTriL,
257
+ dist: npd.MultivariateNormal,
258
258
  ax,
259
259
  color: str,
260
260
  label: str = None,
@@ -262,8 +262,8 @@ def plot(
262
262
  linewidth: float = 1.0,
263
263
  zorder: int = 0,
264
264
  ):
265
- mu = dist.mean()
266
- sigma = dist.stddev()
265
+ mu = dist.mean
266
+ sigma = jnp.sqrt(dist.variance)
267
267
  ax.plot(xtest, mu, linewidth=linewidth, color=color, label=label, zorder=zorder)
268
268
  ax.fill_between(
269
269
  xtest.squeeze(),
@@ -8,7 +8,7 @@
8
8
  # extension: .py
9
9
  # format_name: percent
10
10
  # format_version: '1.3'
11
- # jupytext_version: 1.16.6
11
+ # jupytext_version: 1.16.7
12
12
  # kernelspec:
13
13
  # display_name: gpjax
14
14
  # language: python
@@ -37,8 +37,8 @@ from jaxtyping import (
37
37
  install_import_hook,
38
38
  )
39
39
  import matplotlib.pyplot as plt
40
+ import numpyro.distributions as npd
40
41
  import optax as ox
41
- import tensorflow_probability.substrates.jax as tfp
42
42
 
43
43
  from examples.utils import use_mpl_style
44
44
  from gpjax.lower_cholesky import lower_cholesky
@@ -50,7 +50,6 @@ with install_import_hook("gpjax", "beartype.beartype"):
50
50
  import gpjax as gpx
51
51
 
52
52
 
53
- tfd = tfp.distributions
54
53
  identity_matrix = jnp.eye
55
54
 
56
55
  # set the default style for plotting
@@ -120,7 +119,6 @@ print(type(posterior))
120
119
  # Optax's optimisers.
121
120
 
122
121
  # %%
123
-
124
122
  optimiser = ox.adam(learning_rate=0.01)
125
123
 
126
124
  opt_posterior, history = gpx.fit(
@@ -140,8 +138,8 @@ opt_posterior, history = gpx.fit(
140
138
  map_latent_dist = opt_posterior.predict(xtest, train_data=D)
141
139
  predictive_dist = opt_posterior.likelihood(map_latent_dist)
142
140
 
143
- predictive_mean = predictive_dist.mean()
144
- predictive_std = predictive_dist.stddev()
141
+ predictive_mean = predictive_dist.mean
142
+ predictive_std = jnp.sqrt(predictive_dist.variance)
145
143
 
146
144
  fig, ax = plt.subplots()
147
145
  ax.scatter(x, y, label="Observations", color=cols[0])
@@ -215,8 +213,6 @@ ax.legend()
215
213
  # datapoints below.
216
214
 
217
215
  # %%
218
-
219
-
220
216
  gram, cross_covariance = (kernel.gram, kernel.cross_covariance)
221
217
  jitter = 1e-6
222
218
 
@@ -246,7 +242,7 @@ L = jnp.linalg.cholesky(H + identity_matrix(D.n) * jitter)
246
242
  L_inv = jsp.linalg.solve_triangular(L, identity_matrix(D.n), lower=True)
247
243
  H_inv = jsp.linalg.solve_triangular(L.T, L_inv, lower=False)
248
244
  LH = jnp.linalg.cholesky(H_inv)
249
- laplace_approximation = tfd.MultivariateNormalTriL(f_hat.squeeze(), LH)
245
+ laplace_approximation = npd.MultivariateNormal(f_hat.squeeze(), scale_tril=LH)
250
246
 
251
247
 
252
248
  # %% [markdown]
@@ -265,7 +261,7 @@ laplace_approximation = tfd.MultivariateNormalTriL(f_hat.squeeze(), LH)
265
261
 
266
262
 
267
263
  # %%
268
- def construct_laplace(test_inputs: Float[Array, "N D"]) -> tfd.MultivariateNormalTriL:
264
+ def construct_laplace(test_inputs: Float[Array, "N D"]) -> npd.MultivariateNormal:
269
265
  map_latent_dist = opt_posterior.predict(xtest, train_data=D)
270
266
 
271
267
  Kxt = opt_posterior.prior.kernel.cross_covariance(x, test_inputs)
@@ -279,10 +275,10 @@ def construct_laplace(test_inputs: Float[Array, "N D"]) -> tfd.MultivariateNorma
279
275
  # Ktx Kxx⁻¹[ H⁻¹ ] Kxx⁻¹ Kxt
280
276
  laplace_cov_term = jnp.matmul(jnp.matmul(Kxx_inv_Kxt.T, H_inv), Kxx_inv_Kxt)
281
277
 
282
- mean = map_latent_dist.mean()
283
- covariance = map_latent_dist.covariance() + laplace_cov_term
278
+ mean = map_latent_dist.mean
279
+ covariance = map_latent_dist.covariance_matrix + laplace_cov_term
284
280
  L = jnp.linalg.cholesky(covariance)
285
- return tfd.MultivariateNormalTriL(jnp.atleast_1d(mean.squeeze()), L)
281
+ return npd.MultivariateNormal(jnp.atleast_1d(mean.squeeze()), scale_tril=L)
286
282
 
287
283
 
288
284
  # %% [markdown]
@@ -291,8 +287,8 @@ def construct_laplace(test_inputs: Float[Array, "N D"]) -> tfd.MultivariateNorma
291
287
  laplace_latent_dist = construct_laplace(xtest)
292
288
  predictive_dist = opt_posterior.likelihood(laplace_latent_dist)
293
289
 
294
- predictive_mean = predictive_dist.mean()
295
- predictive_std = predictive_dist.stddev()
290
+ predictive_mean = predictive_dist.mean
291
+ predictive_std = jnp.sqrt(predictive_dist.variance)
296
292
 
297
293
  fig, ax = plt.subplots()
298
294
  ax.scatter(x, y, label="Observations", color=cols[0])
@@ -7,7 +7,7 @@
7
7
  # extension: .py
8
8
  # format_name: percent
9
9
  # format_version: '1.3'
10
- # jupytext_version: 1.16.6
10
+ # jupytext_version: 1.16.7
11
11
  # kernelspec:
12
12
  # display_name: gpjax_beartype
13
13
  # language: python
@@ -161,10 +161,10 @@ predictive_dist = opt_posterior.posterior.likelihood(latent_dist)
161
161
 
162
162
  inducing_points = opt_posterior.inducing_inputs.value
163
163
 
164
- samples = latent_dist.sample(seed=key, sample_shape=(20,))
164
+ samples = latent_dist.sample(key=key, sample_shape=(20,))
165
165
 
166
- predictive_mean = predictive_dist.mean()
167
- predictive_std = predictive_dist.stddev()
166
+ predictive_mean = predictive_dist.mean
167
+ predictive_std = jnp.sqrt(predictive_dist.variance)
168
168
 
169
169
  fig, ax = plt.subplots()
170
170
 
@@ -8,7 +8,7 @@
8
8
  # extension: .py
9
9
  # format_name: percent
10
10
  # format_version: '1.3'
11
- # jupytext_version: 1.16.6
11
+ # jupytext_version: 1.16.7
12
12
  # kernelspec:
13
13
  # display_name: gpjax
14
14
  # language: python
@@ -24,6 +24,7 @@
24
24
  # %%
25
25
  # Enable Float64 for more stable matrix inversions.
26
26
  from jax import config
27
+ from jax.nn import softplus
27
28
  import jax.numpy as jnp
28
29
  import jax.random as jr
29
30
  from jaxtyping import (
@@ -32,7 +33,9 @@ from jaxtyping import (
32
33
  install_import_hook,
33
34
  )
34
35
  import matplotlib.pyplot as plt
35
- import tensorflow_probability.substrates.jax as tfp
36
+ import numpyro.distributions as npd
37
+ from numpyro.distributions import constraints
38
+ import numpyro.distributions.transforms as npt
36
39
 
37
40
  from examples.utils import use_mpl_style
38
41
  from gpjax.kernels.computations import DenseKernelComputation
@@ -225,9 +228,27 @@ def angular_distance(x, y, c):
225
228
  return jnp.abs((x - y + c) % (c * 2) - c)
226
229
 
227
230
 
228
- bij = tfb.SoftClip(low=jnp.array(4.0, dtype=jnp.float64))
231
+ class ShiftedSoftplusTransform(npt.ParameterFreeTransform):
232
+ r"""
233
+ Transform from unconstrained space to the domain [4, infinity) via
234
+ :math:`y = 4 + \log(1 + \exp(x))`. The inverse is computed as
235
+ :math:`x = \log(\exp(y - 4) - 1)`.
236
+ """
229
237
 
230
- DEFAULT_BIJECTION["polar"] = bij
238
+ domain = constraints.real
239
+ codomain = constraints.interval(4.0, jnp.inf) # updated codomain
240
+
241
+ def __call__(self, x):
242
+ return 4.0 + softplus(x) # shift the softplus output by 4
243
+
244
+ def _inverse(self, y):
245
+ return npt._softplus_inv(y - 4.0) # subtract the shift in the inverse
246
+
247
+ def log_abs_det_jacobian(self, x, y, intermediates=None):
248
+ return -softplus(-x)
249
+
250
+
251
+ DEFAULT_BIJECTION["polar"] = ShiftedSoftplusTransform()
231
252
 
232
253
 
233
254
  class Polar(gpx.kernels.AbstractKernel):
@@ -307,7 +328,7 @@ opt_posterior, history = gpx.fit_scipy(
307
328
 
308
329
  # %%
309
330
  posterior_rv = opt_posterior.likelihood(opt_posterior.predict(angles, train_data=D))
310
- mu = posterior_rv.mean()
331
+ mu = posterior_rv.mean
311
332
  one_sigma = posterior_rv.stddev()
312
333
 
313
334
  # %%
@@ -8,7 +8,7 @@
8
8
  # extension: .py
9
9
  # format_name: percent
10
10
  # format_version: '1.3'
11
- # jupytext_version: 1.16.6
11
+ # jupytext_version: 1.16.7
12
12
  # kernelspec:
13
13
  # display_name: gpjax
14
14
  # language: python
@@ -238,8 +238,8 @@ opt_posterior, history = gpx.fit(
238
238
  latent_dist = opt_posterior(xtest, train_data=D)
239
239
  predictive_dist = opt_posterior.likelihood(latent_dist)
240
240
 
241
- predictive_mean = predictive_dist.mean()
242
- predictive_std = predictive_dist.stddev()
241
+ predictive_mean = predictive_dist.mean
242
+ predictive_std = jnp.sqrt(predictive_dist.variance)
243
243
 
244
244
  fig, ax = plt.subplots()
245
245
  ax.plot(x, y, "o", label="Observations", color=cols[0])
@@ -8,7 +8,7 @@
8
8
  # extension: .py
9
9
  # format_name: percent
10
10
  # format_version: '1.3'
11
- # jupytext_version: 1.16.6
11
+ # jupytext_version: 1.16.7
12
12
  # kernelspec:
13
13
  # display_name: gpjax
14
14
  # language: python
@@ -124,7 +124,7 @@ true_kernel = gpx.kernels.GraphKernel(
124
124
  prior = gpx.gps.Prior(mean_function=gpx.mean_functions.Zero(), kernel=true_kernel)
125
125
 
126
126
  fx = prior(x)
127
- y = fx.sample(seed=key, sample_shape=(1,)).reshape(-1, 1)
127
+ y = fx.sample(key=key, sample_shape=(1,)).reshape(-1, 1)
128
128
 
129
129
  D = gpx.Dataset(X=x, y=y)
130
130
 
@@ -194,8 +194,8 @@ opt_posterior, training_history = gpx.fit_scipy(
194
194
  initial_dist = likelihood(posterior(x, D))
195
195
  predictive_dist = opt_posterior.likelihood(opt_posterior(x, D))
196
196
 
197
- initial_mean = initial_dist.mean()
198
- learned_mean = predictive_dist.mean()
197
+ initial_mean = initial_dist.mean
198
+ learned_mean = predictive_dist.mean
199
199
 
200
200
  rmse = lambda ytrue, ypred: jnp.sum(jnp.sqrt(jnp.square(ytrue - ypred)))
201
201
 
@@ -7,7 +7,7 @@
7
7
  # extension: .py
8
8
  # format_name: percent
9
9
  # format_version: '1.3'
10
- # jupytext_version: 1.16.6
10
+ # jupytext_version: 1.16.7
11
11
  # kernelspec:
12
12
  # display_name: gpjax
13
13
  # language: python
@@ -127,9 +127,9 @@ import jax.numpy as jnp
127
127
  import jax.random as jr
128
128
  import matplotlib as mpl
129
129
  import matplotlib.pyplot as plt
130
+ import numpyro.distributions as npd
130
131
  import pandas as pd
131
132
  import seaborn as sns
132
- import tensorflow_probability.substrates.jax as tfp
133
133
 
134
134
  from examples.utils import (
135
135
  confidence_ellipse,
@@ -143,11 +143,10 @@ key = jr.key(42)
143
143
 
144
144
 
145
145
  cols = mpl.rcParams["axes.prop_cycle"].by_key()["color"]
146
- tfd = tfp.distributions
147
146
 
148
- ud1 = tfd.Normal(0.0, 1.0)
149
- ud2 = tfd.Normal(-1.0, 0.5)
150
- ud3 = tfd.Normal(0.25, 1.5)
147
+ ud1 = npd.Normal(0.0, 1.0)
148
+ ud2 = npd.Normal(-1.0, 0.5)
149
+ ud3 = npd.Normal(0.25, 1.5)
151
150
 
152
151
  xs = jnp.linspace(-5.0, 5.0, 500)
153
152
 
@@ -156,7 +155,7 @@ for d in [ud1, ud2, ud3]:
156
155
  ax.plot(
157
156
  xs,
158
157
  jnp.exp(d.log_prob(xs)),
159
- label=f"$\\mathcal{{N}}({{{float(d.mean())}}},\\ {{{float(d.stddev())}}}^2)$",
158
+ label=f"$\\mathcal{{N}}({{{float(d.mean)}}},\\ {{{float(jnp.sqrt(d.variance))}}}^2)$",
160
159
  )
161
160
  ax.fill_between(xs, jnp.zeros_like(xs), jnp.exp(d.log_prob(xs)), alpha=0.2)
162
161
  ax.legend(loc="best")
@@ -190,12 +189,12 @@ ax.legend(loc="best")
190
189
  # %%
191
190
  key = jr.key(123)
192
191
 
193
- d1 = tfd.MultivariateNormalDiag(loc=jnp.zeros(2), scale_diag=jnp.ones(2))
194
- d2 = tfd.MultivariateNormalTriL(
195
- jnp.zeros(2), jnp.linalg.cholesky(jnp.array([[1.0, 0.9], [0.9, 1.0]]))
192
+ d1 = npd.MultivariateNormal(loc=jnp.zeros(2), covariance_matrix=jnp.diag(jnp.ones(2)))
193
+ d2 = npd.MultivariateNormal(
194
+ jnp.zeros(2), scale_tril=jnp.linalg.cholesky(jnp.array([[1.0, 0.9], [0.9, 1.0]]))
196
195
  )
197
- d3 = tfd.MultivariateNormalTriL(
198
- jnp.zeros(2), jnp.linalg.cholesky(jnp.array([[1.0, -0.5], [-0.5, 1.0]]))
196
+ d3 = npd.MultivariateNormal(
197
+ jnp.zeros(2), scale_tril=jnp.linalg.cholesky(jnp.array([[1.0, -0.5], [-0.5, 1.0]]))
199
198
  )
200
199
 
201
200
  dists = [d1, d2, d3]
@@ -215,13 +214,21 @@ titles = [r"$\rho = 0$", r"$\rho = 0.9$", r"$\rho = -0.5$"]
215
214
  cmap = mpl.colors.LinearSegmentedColormap.from_list("custom", ["white", cols[1]], N=256)
216
215
 
217
216
  for a, t, d in zip([ax0, ax1, ax2], titles, dists, strict=False):
218
- d_prob = d.prob(jnp.hstack([xx.reshape(-1, 1), yy.reshape(-1, 1)])).reshape(
219
- xx.shape
217
+ d_prob = jnp.exp(
218
+ d.log_prob(jnp.hstack([xx.reshape(-1, 1), yy.reshape(-1, 1)]))
219
+ ).reshape(xx.shape)
220
+ cntf = a.contourf(
221
+ xx,
222
+ yy,
223
+ jnp.exp(d_prob),
224
+ levels=20,
225
+ antialiased=True,
226
+ cmap=cmap,
227
+ edgecolor="face",
220
228
  )
221
- cntf = a.contourf(xx, yy, jnp.exp(d_prob), levels=20, antialiased=True, cmap=cmap, edgecolor="face")
222
229
  a.set_xlim(-2.75, 2.75)
223
230
  a.set_ylim(-2.75, 2.75)
224
- samples = d.sample(seed=key, sample_shape=(5000,))
231
+ samples = d.sample(key=key, sample_shape=(5000,))
225
232
  xsample, ysample = samples[:, 0], samples[:, 1]
226
233
  confidence_ellipse(
227
234
  xsample, ysample, a, edgecolor="#3f3f3f", n_std=1.0, linestyle="--", alpha=0.8
@@ -274,13 +281,13 @@ for a, t, d in zip([ax0, ax1, ax2], titles, dists, strict=False):
274
281
 
275
282
  # %%
276
283
  n = 1000
277
- x = tfd.Normal(loc=0.0, scale=1.0).sample(seed=key, sample_shape=(n,))
284
+ x = npd.Normal(loc=0.0, scale=1.0).sample(key, sample_shape=(n,))
278
285
  key, subkey = jr.split(key)
279
- y = tfd.Normal(loc=0.25, scale=0.5).sample(seed=subkey, sample_shape=(n,))
286
+ y = npd.Normal(loc=0.25, scale=0.5).sample(subkey, sample_shape=(n,))
280
287
  key, subkey = jr.split(subkey)
281
- xfull = tfd.Normal(loc=0.0, scale=1.0).sample(seed=subkey, sample_shape=(n * 10,))
288
+ xfull = npd.Normal(loc=0.0, scale=1.0).sample(subkey, sample_shape=(n * 10,))
282
289
  key, subkey = jr.split(subkey)
283
- yfull = tfd.Normal(loc=0.25, scale=0.5).sample(seed=subkey, sample_shape=(n * 10,))
290
+ yfull = npd.Normal(loc=0.25, scale=0.5).sample(subkey, sample_shape=(n * 10,))
284
291
  key, subkey = jr.split(subkey)
285
292
  df = pd.DataFrame({"x": x, "y": y, "idx": jnp.ones(n)})
286
293