gpjax 0.8.1__tar.gz → 0.9.0__tar.gz

This diff shows the content of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (86)
  1. {gpjax-0.8.1 → gpjax-0.9.0}/PKG-INFO +14 -17
  2. {gpjax-0.8.1 → gpjax-0.9.0}/README.md +2 -6
  3. {gpjax-0.8.1 → gpjax-0.9.0}/gpjax/__init__.py +9 -6
  4. {gpjax-0.8.1 → gpjax-0.9.0}/gpjax/citation.py +0 -47
  5. {gpjax-0.8.1 → gpjax-0.9.0}/gpjax/dataset.py +19 -13
  6. {gpjax-0.8.1 → gpjax-0.9.0}/gpjax/decision_making/decision_maker.py +16 -19
  7. {gpjax-0.8.1 → gpjax-0.9.0}/gpjax/decision_making/posterior_handler.py +20 -23
  8. {gpjax-0.8.1 → gpjax-0.9.0}/gpjax/decision_making/search_space.py +1 -1
  9. {gpjax-0.8.1 → gpjax-0.9.0}/gpjax/decision_making/test_functions/continuous_functions.py +22 -10
  10. {gpjax-0.8.1 → gpjax-0.9.0}/gpjax/decision_making/test_functions/non_conjugate_functions.py +8 -8
  11. {gpjax-0.8.1 → gpjax-0.9.0}/gpjax/decision_making/utility_functions/__init__.py +8 -0
  12. {gpjax-0.8.1 → gpjax-0.9.0}/gpjax/decision_making/utility_functions/base.py +9 -13
  13. gpjax-0.9.0/gpjax/decision_making/utility_functions/expected_improvement.py +112 -0
  14. gpjax-0.9.0/gpjax/decision_making/utility_functions/probability_of_improvement.py +125 -0
  15. {gpjax-0.8.1 → gpjax-0.9.0}/gpjax/decision_making/utility_functions/thompson_sampling.py +4 -4
  16. {gpjax-0.8.1 → gpjax-0.9.0}/gpjax/decision_making/utility_maximizer.py +7 -10
  17. {gpjax-0.8.1 → gpjax-0.9.0}/gpjax/decision_making/utils.py +16 -2
  18. {gpjax-0.8.1 → gpjax-0.9.0}/gpjax/distributions.py +53 -60
  19. gpjax-0.9.0/gpjax/fit.py +351 -0
  20. {gpjax-0.8.1 → gpjax-0.9.0}/gpjax/gps.py +197 -197
  21. {gpjax-0.8.1 → gpjax-0.9.0}/gpjax/integrators.py +39 -41
  22. {gpjax-0.8.1 → gpjax-0.9.0}/gpjax/kernels/__init__.py +3 -0
  23. gpjax-0.9.0/gpjax/kernels/approximations/rff.py +101 -0
  24. gpjax-0.9.0/gpjax/kernels/base.py +339 -0
  25. gpjax-0.9.0/gpjax/kernels/computations/base.py +110 -0
  26. gpjax-0.9.0/gpjax/kernels/computations/basis_functions.py +74 -0
  27. gpjax-0.9.0/gpjax/kernels/computations/constant_diagonal.py +55 -0
  28. {gpjax-0.8.1 → gpjax-0.9.0}/gpjax/kernels/computations/dense.py +4 -17
  29. {gpjax-0.8.1 → gpjax-0.9.0}/gpjax/kernels/computations/diagonal.py +6 -31
  30. {gpjax-0.8.1 → gpjax-0.9.0}/gpjax/kernels/computations/eigen.py +12 -22
  31. gpjax-0.9.0/gpjax/kernels/non_euclidean/graph.py +113 -0
  32. {gpjax-0.8.1 → gpjax-0.9.0}/gpjax/kernels/non_euclidean/utils.py +4 -5
  33. gpjax-0.9.0/gpjax/kernels/nonstationary/arccosine.py +160 -0
  34. gpjax-0.9.0/gpjax/kernels/nonstationary/linear.py +79 -0
  35. gpjax-0.9.0/gpjax/kernels/nonstationary/polynomial.py +91 -0
  36. {gpjax-0.8.1 → gpjax-0.9.0}/gpjax/kernels/stationary/__init__.py +2 -0
  37. gpjax-0.9.0/gpjax/kernels/stationary/base.py +194 -0
  38. {gpjax-0.8.1 → gpjax-0.9.0}/gpjax/kernels/stationary/matern12.py +13 -30
  39. {gpjax-0.8.1 → gpjax-0.9.0}/gpjax/kernels/stationary/matern32.py +20 -38
  40. gpjax-0.9.0/gpjax/kernels/stationary/matern52.py +57 -0
  41. gpjax-0.9.0/gpjax/kernels/stationary/periodic.py +91 -0
  42. gpjax-0.9.0/gpjax/kernels/stationary/powered_exponential.py +94 -0
  43. gpjax-0.9.0/gpjax/kernels/stationary/rational_quadratic.py +86 -0
  44. gpjax-0.9.0/gpjax/kernels/stationary/rbf.py +48 -0
  45. {gpjax-0.8.1 → gpjax-0.9.0}/gpjax/kernels/stationary/utils.py +0 -1
  46. gpjax-0.9.0/gpjax/kernels/stationary/white.py +64 -0
  47. {gpjax-0.8.1 → gpjax-0.9.0}/gpjax/likelihoods.py +69 -50
  48. {gpjax-0.8.1 → gpjax-0.9.0}/gpjax/lower_cholesky.py +23 -14
  49. {gpjax-0.8.1 → gpjax-0.9.0}/gpjax/mean_functions.py +44 -47
  50. gpjax-0.9.0/gpjax/objectives.py +417 -0
  51. gpjax-0.9.0/gpjax/parameters.py +167 -0
  52. {gpjax-0.8.1 → gpjax-0.9.0}/gpjax/scan.py +3 -7
  53. {gpjax-0.8.1 → gpjax-0.9.0}/gpjax/typing.py +5 -3
  54. {gpjax-0.8.1 → gpjax-0.9.0}/gpjax/variational_families.py +207 -172
  55. {gpjax-0.8.1 → gpjax-0.9.0}/pyproject.toml +42 -63
  56. gpjax-0.8.1/gpjax/base/__init__.py +0 -38
  57. gpjax-0.8.1/gpjax/base/module.py +0 -416
  58. gpjax-0.8.1/gpjax/base/param.py +0 -73
  59. gpjax-0.8.1/gpjax/fit.py +0 -321
  60. gpjax-0.8.1/gpjax/flax_base/bijectors.py +0 -8
  61. gpjax-0.8.1/gpjax/flax_base/param.py +0 -16
  62. gpjax-0.8.1/gpjax/flax_base/types.py +0 -15
  63. gpjax-0.8.1/gpjax/kernels/approximations/rff.py +0 -92
  64. gpjax-0.8.1/gpjax/kernels/base.py +0 -208
  65. gpjax-0.8.1/gpjax/kernels/computations/base.py +0 -89
  66. gpjax-0.8.1/gpjax/kernels/computations/basis_functions.py +0 -90
  67. gpjax-0.8.1/gpjax/kernels/computations/constant_diagonal.py +0 -92
  68. gpjax-0.8.1/gpjax/kernels/non_euclidean/graph.py +0 -107
  69. gpjax-0.8.1/gpjax/kernels/nonstationary/arccosine.py +0 -122
  70. gpjax-0.8.1/gpjax/kernels/nonstationary/linear.py +0 -59
  71. gpjax-0.8.1/gpjax/kernels/nonstationary/polynomial.py +0 -66
  72. gpjax-0.8.1/gpjax/kernels/stationary/matern52.py +0 -75
  73. gpjax-0.8.1/gpjax/kernels/stationary/periodic.py +0 -64
  74. gpjax-0.8.1/gpjax/kernels/stationary/powered_exponential.py +0 -69
  75. gpjax-0.8.1/gpjax/kernels/stationary/rational_quadratic.py +0 -62
  76. gpjax-0.8.1/gpjax/kernels/stationary/rbf.py +0 -66
  77. gpjax-0.8.1/gpjax/kernels/stationary/white.py +0 -62
  78. gpjax-0.8.1/gpjax/objectives.py +0 -498
  79. {gpjax-0.8.1 → gpjax-0.9.0}/LICENSE +0 -0
  80. {gpjax-0.8.1 → gpjax-0.9.0}/gpjax/decision_making/__init__.py +0 -0
  81. {gpjax-0.8.1 → gpjax-0.9.0}/gpjax/decision_making/test_functions/__init__.py +0 -0
  82. {gpjax-0.8.1 → gpjax-0.9.0}/gpjax/kernels/approximations/__init__.py +0 -0
  83. {gpjax-0.8.1 → gpjax-0.9.0}/gpjax/kernels/computations/__init__.py +0 -0
  84. {gpjax-0.8.1 → gpjax-0.9.0}/gpjax/kernels/non_euclidean/__init__.py +0 -0
  85. {gpjax-0.8.1 → gpjax-0.9.0}/gpjax/kernels/nonstationary/__init__.py +0 -0
  86. {gpjax-0.8.1 → gpjax-0.9.0}/gpjax/progress_bar.py +0 -0
{gpjax-0.8.1 → gpjax-0.9.0}/PKG-INFO

@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: gpjax
- Version: 0.8.1
+ Version: 0.9.0
  Summary: Gaussian processes in JAX.
  Home-page: https://github.com/JaxGaussianProcesses/GPJax
  License: Apache-2.0
@@ -12,16 +12,17 @@ Classifier: License :: OSI Approved :: Apache Software License
  Classifier: Programming Language :: Python :: 3
  Classifier: Programming Language :: Python :: 3.10
  Classifier: Programming Language :: Python :: 3.11
- Requires-Dist: beartype (>=0.16.2,<0.17.0)
- Requires-Dist: cola-ml (>=0.0.5,<0.0.6)
- Requires-Dist: jax (>=0.4.16)
- Requires-Dist: jaxlib (>=0.4.16)
- Requires-Dist: jaxtyping (>=0.2.15,<0.3.0)
- Requires-Dist: optax (>=0.1.4,<0.2.0)
- Requires-Dist: orbax-checkpoint (>=0.2.3)
- Requires-Dist: simple-pytree (>=0.1.7,<0.2.0)
- Requires-Dist: tensorflow-probability (>=0.22.0,<0.23.0)
- Requires-Dist: tqdm (>=4.65.0,<5.0.0)
+ Requires-Dist: beartype (>=0.16.1,<0.17.0)
+ Requires-Dist: cola-ml (==0.0.5)
+ Requires-Dist: flax (>=0.8.4,<0.9.0)
+ Requires-Dist: jax (<0.4.28)
+ Requires-Dist: jaxlib (<0.4.28)
+ Requires-Dist: jaxopt (>=0.8.3,<0.9.0)
+ Requires-Dist: jaxtyping (>=0.2.10,<0.3.0)
+ Requires-Dist: numpy (<2.0.0)
+ Requires-Dist: optax (>=0.2.1,<0.3.0)
+ Requires-Dist: tensorflow-probability (>=0.24.0,<0.25.0)
+ Requires-Dist: tqdm (>=4.66.2,<5.0.0)
  Project-URL: Documentation, https://docs.jaxgaussianprocesses.com/
  Project-URL: Repository, https://github.com/JaxGaussianProcesses/GPJax
  Description-Content-Type: text/markdown
@@ -100,10 +101,9 @@ helped to shape GPJax into the package it is today.
  ## Notebook examples

  > - [**Conjugate Inference**](https://docs.jaxgaussianprocesses.com/examples/regression/)
- > - [**Classification with MCMC**](https://docs.jaxgaussianprocesses.com/examples/classification/)
+ > - [**Classification**](https://docs.jaxgaussianprocesses.com/examples/classification/)
  > - [**Sparse Variational Inference**](https://docs.jaxgaussianprocesses.com/examples/collapsed_vi/)
  > - [**Stochastic Variational Inference**](https://docs.jaxgaussianprocesses.com/examples/uncollapsed_vi/)
- > - [**BlackJax Integration**](https://docs.jaxgaussianprocesses.com/examples/classification/#mcmc-inference)
  > - [**Laplace Approximation**](https://docs.jaxgaussianprocesses.com/examples/classification/#laplace-approximation)
  > - [**Inference on Non-Euclidean Spaces**](https://docs.jaxgaussianprocesses.com/examples/constructing_new_kernels/#custom-kernel)
  > - [**Inference on Graphs**](https://docs.jaxgaussianprocesses.com/examples/graph_kernels/)
@@ -174,13 +174,10 @@ posterior = prior * likelihood
  # Define an optimiser
  optimiser = ox.adam(learning_rate=1e-2)

- # Define the marginal log-likelihood
- negative_mll = jit(gpx.objectives.ConjugateMLL(negative=True))
-
  # Obtain Type 2 MLEs of the hyperparameters
  opt_posterior, history = gpx.fit(
      model=posterior,
-     objective=negative_mll,
+     objective=gpx.objectives.conjugate_mll,
      train_data=D,
      optim=optimiser,
      num_iters=500,
{gpjax-0.8.1 → gpjax-0.9.0}/README.md

@@ -72,10 +72,9 @@ helped to shape GPJax into the package it is today.
  ## Notebook examples

  > - [**Conjugate Inference**](https://docs.jaxgaussianprocesses.com/examples/regression/)
- > - [**Classification with MCMC**](https://docs.jaxgaussianprocesses.com/examples/classification/)
+ > - [**Classification**](https://docs.jaxgaussianprocesses.com/examples/classification/)
  > - [**Sparse Variational Inference**](https://docs.jaxgaussianprocesses.com/examples/collapsed_vi/)
  > - [**Stochastic Variational Inference**](https://docs.jaxgaussianprocesses.com/examples/uncollapsed_vi/)
- > - [**BlackJax Integration**](https://docs.jaxgaussianprocesses.com/examples/classification/#mcmc-inference)
  > - [**Laplace Approximation**](https://docs.jaxgaussianprocesses.com/examples/classification/#laplace-approximation)
  > - [**Inference on Non-Euclidean Spaces**](https://docs.jaxgaussianprocesses.com/examples/constructing_new_kernels/#custom-kernel)
  > - [**Inference on Graphs**](https://docs.jaxgaussianprocesses.com/examples/graph_kernels/)
@@ -146,13 +145,10 @@ posterior = prior * likelihood
  # Define an optimiser
  optimiser = ox.adam(learning_rate=1e-2)

- # Define the marginal log-likelihood
- negative_mll = jit(gpx.objectives.ConjugateMLL(negative=True))
-
  # Obtain Type 2 MLEs of the hyperparameters
  opt_posterior, history = gpx.fit(
      model=posterior,
-     objective=negative_mll,
+     objective=gpx.objectives.conjugate_mll,
      train_data=D,
      optim=optimiser,
      num_iters=500,
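Example: in 0.9.0 the objectives module exposes plain functions such as `gpx.objectives.conjugate_mll`, so there is no `ConjugateMLL(negative=True)` object to build and `jit`. A minimal sketch of the resulting training loop, with a hypothetical toy dataset standing in for the README's `D`; the `key` argument is an assumption, the rest follows the README excerpt above.

```python
import gpjax as gpx
import jax.numpy as jnp
import jax.random as jr
import optax as ox

# Hypothetical toy regression data standing in for the README's dataset D.
key = jr.PRNGKey(123)
x = jr.uniform(key, shape=(100, 1), minval=0.0, maxval=3.0)
y = jnp.sin(4 * x) + 0.3 * jr.normal(key, shape=x.shape)
D = gpx.Dataset(X=x, y=y)

# Prior, likelihood, and posterior, as in the README example.
prior = gpx.gps.Prior(
    mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.RBF()
)
likelihood = gpx.likelihoods.Gaussian(num_datapoints=D.n)
posterior = prior * likelihood

# The objective is now passed as a function; the explicit jit wrapper
# from the 0.8.1 README is gone in the 0.9.0 example.
opt_posterior, history = gpx.fit(
    model=posterior,
    objective=gpx.objectives.conjugate_mll,
    train_data=D,
    optim=ox.adam(learning_rate=1e-2),
    num_iters=500,
    key=key,
)
```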
{gpjax-0.8.1 → gpjax-0.9.0}/gpjax/__init__.py

@@ -12,8 +12,13 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
+ from warnings import filterwarnings
+
+ from beartype.roar import BeartypeDecorHintPep585DeprecationWarning
+
+ filterwarnings("ignore", category=BeartypeDecorHintPep585DeprecationWarning)
+
  from gpjax import (
-     base,
      decision_making,
      gps,
      integrators,
@@ -21,12 +26,9 @@ from gpjax import (
      likelihoods,
      mean_functions,
      objectives,
+     parameters,
      variational_families,
  )
- from gpjax.base import (
-     Module,
-     param_field,
- )
  from gpjax.citation import cite
  from gpjax.dataset import Dataset
  from gpjax.fit import (
@@ -38,7 +40,7 @@ __license__ = "MIT"
  __description__ = "Didactic Gaussian processes in JAX"
  __url__ = "https://github.com/JaxGaussianProcesses/GPJax"
  __contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
- __version__ = "0.8.1"
+ __version__ = "0.9.0"

  __all__ = [
      "base",
@@ -49,6 +51,7 @@ __all__ = [
      "likelihoods",
      "mean_functions",
      "objectives",
+     "parameters",
      "variational_families",
      "Dataset",
      "cite",
{gpjax-0.8.1 → gpjax-0.9.0}/gpjax/citation.py

@@ -23,13 +23,6 @@ from gpjax.kernels import (
      Matern32,
      Matern52,
  )
- from gpjax.objectives import (
-     ELBO,
-     CollapsedELBO,
-     ConjugateMLL,
-     LogPosteriorDensity,
-     NonConjugateMLL,
- )

  CitationType = Union[None, str, Dict[str, str]]

@@ -158,46 +151,6 @@ def _(tree) -> PaperCitation:
      )


- ####################
- # Objective citations
- ####################
- @cite.register(ConjugateMLL)
- @cite.register(NonConjugateMLL)
- @cite.register(LogPosteriorDensity)
- def _(tree) -> BookCitation:
-     return BookCitation(
-         citation_key="rasmussen2006gaussian",
-         title="Gaussian Processes for Machine Learning",
-         authors="Rasmussen, Carl Edward and Williams, Christopher K",
-         year="2006",
-         publisher="MIT press Cambridge, MA",
-         volume="2",
-     )
-
-
- @cite.register(CollapsedELBO)
- def _(tree) -> PaperCitation:
-     return PaperCitation(
-         citation_key="titsias2009variational",
-         title="Variational learning of inducing variables in sparse Gaussian processes",
-         authors="Titsias, Michalis",
-         year="2009",
-         booktitle="International Conference on Artificial Intelligence and Statistics",
-     )
-
-
- @cite.register(ELBO)
- def _(tree) -> PaperCitation:
-     return PaperCitation(
-         citation_key="hensman2013gaussian",
-         title="Gaussian Processes for Big Data",
-         authors="Hensman, James and Fusi, Nicolo and Lawrence, Neil D",
-         year="2013",
-         booktitle="Uncertainty in Artificial Intelligence",
-         citation_type="article",
-     )
-
-
  ####################
  # Decision making citations
  ####################
{gpjax-0.8.1 → gpjax-0.9.0}/gpjax/dataset.py

@@ -12,40 +12,39 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
-
  from dataclasses import dataclass
  import warnings

  from beartype.typing import Optional
+ import jax
  import jax.numpy as jnp
  from jaxtyping import Num
- from simple_pytree import Pytree

  from gpjax.typing import Array


  @dataclass
- class Dataset(Pytree):
+ @jax.tree_util.register_pytree_node_class
+ class Dataset:
      r"""Base class for datasets.

-     Attributes
-     ----------
-     X (Optional[Num[Array, "N D"]]): input data.
-     y (Optional[Num[Array, "N Q"]]): output data.
+     Args:
+         X: input data.
+         y: output data.
      """

      X: Optional[Num[Array, "N D"]] = None
      y: Optional[Num[Array, "N Q"]] = None

      def __post_init__(self) -> None:
-         r"""Checks that the shapes of $`X`$ and $`y`$ are compatible,
-         and provides warnings regarding the precision of $`X`$ and $`y`$."""
+         r"""Checks that the shapes of $X$ and $y$ are compatible,
+         and provides warnings regarding the precision of $X$ and $y$."""
          _check_shape(self.X, self.y)
          _check_precision(self.X, self.y)

      def __repr__(self) -> str:
          r"""Returns a string representation of the dataset."""
-         repr = f"- Number of observations: {self.n}\n- Input dimension: {self.in_dim}"
+         repr = f"Dataset(Number of observations: {self.n:=} - Input dimension: {self.in_dim})"
          return repr

      def is_supervised(self) -> bool:
@@ -76,14 +75,21 @@ class Dataset(Pytree):

      @property
      def in_dim(self) -> int:
-         r"""Dimension of the inputs, $`X`$."""
+         r"""Dimension of the inputs, $X$."""
          return self.X.shape[1]

+     def tree_flatten(self):
+         return (self.X, self.y), None
+
+     @classmethod
+     def tree_unflatten(cls, aux_data, children):
+         return cls(*children)
+

  def _check_shape(
      X: Optional[Num[Array, "..."]], y: Optional[Num[Array, "..."]]
  ) -> None:
-     r"""Checks that the shapes of $`X`$ and $`y`$ are compatible."""
+     r"""Checks that the shapes of $X$ and $y$ are compatible."""
      if X is not None and y is not None and X.shape[0] != y.shape[0]:
          raise ValueError(
              "Inputs, X, and outputs, y, must have the same number of rows."
@@ -104,7 +110,7 @@ def _check_shape(
  def _check_precision(
      X: Optional[Num[Array, "..."]], y: Optional[Num[Array, "..."]]
  ) -> None:
-     r"""Checks the precision of $`X`$ and $`y`."""
+     r"""Checks the precision of $X$ and $y`."""
      if X is not None and X.dtype != jnp.float64:
          warnings.warn(
              "X is not of type float64. "
{gpjax-0.8.1 → gpjax-0.9.0}/gpjax/decision_making/decision_maker.py

@@ -58,26 +58,23 @@ class AbstractDecisionMaker(ABC):
      the black-box function of interest at this point.

      Attributes:
-         search_space (AbstractSearchSpace): Search space over which we can evaluate the
-             function(s) of interest.
-         posterior_handlers (Dict[str, PosteriorHandler]): Dictionary of posterior
-             handlers, which are used to update posteriors throughout the decision making
-             loop. Note that the word `posteriors` is used for consistency with GPJax, but these
-             objects are typically referred to as `models` in the model-based decision
-             making literature. Tags are used to distinguish between posteriors. In a typical
-             Bayesian optimisation setup one of the tags will be `OBJECTIVE`, defined in
-             decision_making.utils.
-         datasets (Dict[str, Dataset]): Dictionary of datasets, which are augmented with
+         search_space: Search space over which we can evaluate the function(s) of interest.
+         posterior_handlers: dictionary of posterior handlers, which are used to update
+             posteriors throughout the decision making loop. Note that the word `posteriors`
+             is used for consistency with GPJax, but these objects are typically referred to
+             as `models` in the model-based decision making literature. Tags are used to
+             distinguish between posteriors. In a typical Bayesian optimisation setup one of
+             the tags will be `OBJECTIVE`, defined in `decision_making.utils`.
+         datasets: dictionary of datasets, which are augmented with
              observations throughout the decision making loop. In a typical setup they are
              also used to update the posteriors, using the `posterior_handlers`. Tags are used
              to distinguish datasets, and correspond to tags in `posterior_handlers`.
-         key (KeyArray): JAX random key, used to generate random numbers.
-         batch_size (int): Number of points to query at each step of the decision making
+         key: JAX random key, used to generate random numbers.
+         batch_size: Number of points to query at each step of the decision making
              loop. Note that `SinglePointUtilityFunction`s are only capable of generating
              one point to be queried at each iteration of the decision making loop.
-         post_ask (List[Callable]): List of functions to be executed after each ask step.
-         post_tell (List[Callable]): List of functions to be executed after each tell
-             step.
+         post_ask: List of functions to be executed after each ask step.
+         post_tell: List of functions to be executed after each tell step.
      """

      search_space: AbstractSearchSpace
@@ -140,10 +137,10 @@ class AbstractDecisionMaker(ABC):
          Add newly observed data to datasets and update the corresponding posteriors.

          Args:
-             observation_datasets (Mapping[str, Dataset]): Dictionary of datasets
-                 containing new observations. Tags are used to distinguish datasets, and
-                 correspond to tags in `posterior_handlers` and `self.datasets`.
-             key (KeyArray): JAX PRNG key for controlling random state.
+             observation_datasets: dictionary of datasets containing new observations.
+                 Tags are used to distinguish datasets, and correspond to tags in
+                 `posterior_handlers` and `self.datasets`.
+             key: JAX PRNG key for controlling random state.
          """
          if observation_datasets.keys() != self.datasets.keys():
              raise ValueError(
{gpjax-0.8.1 → gpjax-0.9.0}/gpjax/decision_making/posterior_handler.py

@@ -27,7 +27,7 @@ from gpjax.gps import (
      AbstractPosterior,
      AbstractPrior,
  )
- from gpjax.objectives import AbstractObjective
+ from gpjax.objectives import Objective
  from gpjax.typing import KeyArray

  LikelihoodBuilder = Callable[[int], AbstractLikelihood]
@@ -42,21 +42,18 @@ class PosteriorHandler:
      observed.

      Attributes:
-         prior (AbstractPrior): Prior to use when forming the posterior.
-         likelihood_builder (LikelihoodBuilder): Function which takes the number of
-             datapoints as input and returns a likelihood object initialised with the given
-             number of datapoints.
-         optimization_objective (AbstractObjective): Objective to use for optimizing the
-             posterior hyperparameters.
-         optimizer (ox.GradientTransformation): Optax optimizer to use for optimizing the
-             posterior hyperparameters.
-         num_optimization_iterations (int): Number of iterations to optimize
-             the posterior hyperparameters for.
+         prior: prior to use when forming the posterior.
+         likelihood_builder: function which takes the number of datapoints as input and
+             returns a likelihood object initialised with the given number of datapoints.
+         optimization_objective: objective to use for optimizing the posterior hyperparameters.
+         optimizer: an optax optimizer to use for optimizing the posterior hyperparameters.
+         num_optimization_iterations: the number of iterations to optimize
+             the posterior hyperparameters for.
      """

      prior: AbstractPrior
      likelihood_builder: LikelihoodBuilder
-     optimization_objective: AbstractObjective
+     optimization_objective: Objective
      optimizer: ox.GradientTransformation
      num_optimization_iters: int

@@ -71,10 +68,10 @@ class PosteriorHandler:
          Initialise (and optionally optimize) a posterior using the given dataset.

          Args:
-             dataset (Dataset): Dataset to get posterior for.
-             optimize (bool): Whether to optimize the posterior hyperparameters.
-             key (Optional[KeyArray]): A JAX PRNG key which is used for optimizing the posterior
-                 hyperparameters.
+             dataset: dataset to get posterior for.
+             optimize: whether to optimize the posterior hyperparameters.
+             key: a JAX PRNG key which is used for optimizing the posterior
+                 hyperparameters.

          Returns:
              Posterior for the given dataset.
@@ -108,14 +105,14 @@ class PosteriorHandler:
          set as in the `likelihood_builder` function.

          Args:
-             dataset: Dataset to get posterior for.
-             previous_posterior: Posterior being updated. This is supplied as one may
-                 wish to simply increase the number of datapoints in the likelihood, without
-                 optimizing the posterior hyperparameters, in which case the previous
-                 posterior can be used to obtain the previously set prior hyperparameters.
-             optimize: Whether to optimize the posterior hyperparameters.
+             dataset: dataset to get posterior for.
+             previous_posterior: posterior being updated. This is supplied as one may
+                 wish to simply increase the number of datapoints in the likelihood, without
+                 optimizing the posterior hyperparameters, in which case the previous
+                 posterior can be used to obtain the previously set prior hyperparameters.
+             optimize: whether to optimize the posterior hyperparameters.
              key: A JAX PRNG key which is used for optimizing the posterior
-                 hyperparameters.
+             hyperparameters.
          """
          posterior = previous_posterior.prior * self.likelihood_builder(dataset.n)
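Example: since the optimization objective is now a plain callable, a handler is wired up as below. A minimal sketch assuming a Gaussian-likelihood/conjugate-MLL pairing; the prior and optimiser choices are hypothetical.

```python
import optax as ox

import gpjax as gpx
from gpjax.decision_making.posterior_handler import PosteriorHandler

# Hypothetical prior; any GPJax prior can be used here.
prior = gpx.gps.Prior(
    mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.Matern52()
)

handler = PosteriorHandler(
    prior=prior,
    # Builds a likelihood sized to however many points have been observed.
    likelihood_builder=lambda n: gpx.likelihoods.Gaussian(num_datapoints=n),
    # A function objective, matching the new Objective type above.
    optimization_objective=gpx.objectives.conjugate_mll,
    optimizer=ox.adam(learning_rate=1e-2),
    num_optimization_iters=500,
)
```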
{gpjax-0.8.1 → gpjax-0.9.0}/gpjax/decision_making/search_space.py

@@ -56,7 +56,7 @@ class AbstractSearchSpace(ABC):

  @dataclass
  class ContinuousSearchSpace(AbstractSearchSpace):
-     """The `ContinuousSearchSpace` class is used to bound the domain of continuous real functions of dimension $`D`$."""
+     """The `ContinuousSearchSpace` class is used to bound the domain of continuous real functions of dimension $D$."""

      lower_bounds: Float[Array, " D"]
      upper_bounds: Float[Array, " D"]
{gpjax-0.8.1 → gpjax-0.9.0}/gpjax/decision_making/test_functions/continuous_functions.py

@@ -12,24 +12,24 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
  # ==============================================================================
- from abc import (
-     ABC,
-     abstractmethod,
- )
+ from abc import abstractmethod
  from dataclasses import dataclass

  import jax.numpy as jnp
  from jaxtyping import (
      Array,
      Float,
+     Num,
  )
+ import tensorflow_probability.substrates.jax as tfp

  from gpjax.dataset import Dataset
  from gpjax.decision_making.search_space import ContinuousSearchSpace
+ from gpjax.gps import AbstractMeanFunction
  from gpjax.typing import KeyArray


- class AbstractContinuousTestFunction(ABC):
+ class AbstractContinuousTestFunction(AbstractMeanFunction):
      """
      Abstract base class for continuous test functions.

@@ -43,19 +43,28 @@ class AbstractContinuousTestFunction(ABC):
      minimizer: Float[Array, "1 D"]
      minimum: Float[Array, "1 1"]

-     def generate_dataset(self, num_points: int, key: KeyArray) -> Dataset:
+     def generate_dataset(
+         self, num_points: int, key: KeyArray, obs_stddev: float = 0.0
+     ) -> Dataset:
          """
          Generate a toy dataset from the test function.

          Args:
              num_points (int): Number of points to sample.
              key (KeyArray): JAX PRNG key.
+             obs_stddev (float): (Optional) standard deviation of Gaussian distributed
+                 noise added to observations.

          Returns:
              Dataset: Dataset of points sampled from the test function.
          """
          X = self.search_space.sample(num_points=num_points, key=key)
-         y = self.evaluate(X)
+         gaussian_noise = tfp.distributions.Normal(
+             jnp.zeros(num_points), obs_stddev * jnp.ones(num_points)
+         )
+         y = self.evaluate(X) + jnp.transpose(
+             gaussian_noise.sample(sample_shape=[1], seed=key)
+         )
          return Dataset(X=X, y=y)

      def generate_test_points(
@@ -73,6 +82,9 @@ class AbstractContinuousTestFunction(ABC):
          """
          return self.search_space.sample(num_points=num_points, key=key)

+     def __call__(self, x: Num[Array, "N D"]) -> Float[Array, "N 1"]:
+         return self.evaluate(x)
+
      @abstractmethod
      def evaluate(self, x: Float[Array, "N D"]) -> Float[Array, "N 1"]:
          """
@@ -92,7 +104,7 @@ class Forrester(AbstractContinuousTestFunction):
      """
      Forrester function introduced in 'Engineering design via surrogate modelling: a
      practical guide' (Forrester et al. 2008), rescaled to have zero mean and unit
-     variance over $`[0, 1]`$.
+     variance over $[0, 1]$.
      """

      search_space = ContinuousSearchSpace(
@@ -113,7 +125,7 @@ class LogarithmicGoldsteinPrice(AbstractContinuousTestFunction):
      """
      Logarithmic Goldstein-Price function introduced in 'A benchmark of kriging-based
      infill criteria for noisy optimization' (Picheny et al. 2013), which has zero mean
-     and unit variance over $`[0, 1]^2`$.
+     and unit variance over $[0, 1]^2$.
      """

      search_space = ContinuousSearchSpace(
@@ -148,7 +160,7 @@ class LogarithmicGoldsteinPrice(AbstractContinuousTestFunction):
  @dataclass
  class Quadratic(AbstractContinuousTestFunction):
      """
-     Toy quadratic function defined over $`[0, 1]`$.
+     Toy quadratic function defined over $[0, 1]$.
      """

      search_space = ContinuousSearchSpace(
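Example: the `obs_stddev` argument added to `generate_dataset` above defaults to 0.0, so existing noise-free calls are unchanged. A minimal sketch of drawing clean and noisy samples from `Forrester`.

```python
import jax.random as jr

from gpjax.decision_making.test_functions.continuous_functions import Forrester

forrester = Forrester()
key = jr.key(0)

# Noise-free, matching the 0.8.1 behaviour.
clean = forrester.generate_dataset(num_points=20, key=key)

# Observations perturbed by N(0, 0.1^2) noise, per the new code path above.
noisy = forrester.generate_dataset(num_points=20, key=key, obs_stddev=0.1)

assert clean.X.shape == noisy.X.shape == (20, 1)
```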
{gpjax-0.8.1 → gpjax-0.9.0}/gpjax/decision_making/test_functions/non_conjugate_functions.py

@@ -17,15 +17,15 @@ from dataclasses import dataclass

  import jax.numpy as jnp
  import jax.random as jr
- from jaxtyping import (
-     Array,
-     Float,
-     Integer,
- )

  from gpjax.dataset import Dataset
  from gpjax.decision_making.search_space import ContinuousSearchSpace
- from gpjax.typing import KeyArray
+ from gpjax.typing import (
+     Array,
+     Float,
+     Int,
+     KeyArray,
+ )


  @dataclass
@@ -74,7 +74,7 @@ class PoissonTestFunction:
          return self.search_space.sample(num_points=num_points, key=key)

      @abstractmethod
-     def evaluate(self, x: Float[Array, "N 1"]) -> Integer[Array, "N 1"]:
+     def evaluate(self, x: Float[Array, "N 1"]) -> Int[Array, "N 1"]:
          """
          Evaluate the test function at a set of points. Function taken from
          https://docs.jaxgaussianprocesses.com/examples/poisson/#dataset.
@@ -83,7 +83,7 @@ class PoissonTestFunction:
              x (Float[Array, 'N D']): Points to evaluate the test function at.

          Returns:
-             Integer[Array, 'N 1']: Values of the test function at the points.
+             Float[Array, 'N 1']: Values of the test function at the points.
          """
          key = jr.key(42)
          f = lambda x: 2.0 * jnp.sin(3 * x) + 0.5 * x
{gpjax-0.8.1 → gpjax-0.9.0}/gpjax/decision_making/utility_functions/__init__.py

@@ -18,12 +18,20 @@ from gpjax.decision_making.utility_functions.base import (
      SinglePointUtilityFunction,
      UtilityFunction,
  )
+ from gpjax.decision_making.utility_functions.expected_improvement import (
+     ExpectedImprovement,
+ )
+ from gpjax.decision_making.utility_functions.probability_of_improvement import (
+     ProbabilityOfImprovement,
+ )
  from gpjax.decision_making.utility_functions.thompson_sampling import ThompsonSampling

  __all__ = [
      "UtilityFunction",
      "AbstractUtilityFunctionBuilder",
      "AbstractSinglePointUtilityFunctionBuilder",
+     "ExpectedImprovement",
      "SinglePointUtilityFunction",
      "ThompsonSampling",
+     "ProbabilityOfImprovement",
  ]
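Example: 0.9.0 ships Expected Improvement and Probability of Improvement alongside Thompson sampling. A minimal sketch of building and evaluating one of the new utility functions; the posterior and dataset are hypothetical, and the `build_utility_function` method name is assumed from the builder interface in `base.py` below.

```python
import jax.numpy as jnp
import jax.random as jr

import gpjax as gpx
from gpjax.decision_making.utility_functions import ExpectedImprovement
from gpjax.decision_making.utils import OBJECTIVE

# Hypothetical observations of the objective function.
X = jnp.linspace(0.0, 1.0, 10).reshape(-1, 1)
D = gpx.Dataset(X=X, y=jnp.sin(4 * X))
prior = gpx.gps.Prior(
    mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.Matern52()
)
posterior = prior * gpx.likelihoods.Gaussian(num_datapoints=D.n)

# Builders consume tagged dictionaries of posteriors and datasets.
ei = ExpectedImprovement().build_utility_function(
    posteriors={OBJECTIVE: posterior},
    datasets={OBJECTIVE: D},
    key=jr.key(0),
)

candidates = jnp.linspace(0.0, 1.0, 100).reshape(-1, 1)
utility = ei(candidates)  # shape [100, 1]; maximise to choose the next query
```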
{gpjax-0.8.1 → gpjax-0.9.0}/gpjax/decision_making/utility_functions/base.py

@@ -35,8 +35,8 @@ from gpjax.typing import (
  SinglePointUtilityFunction = Callable[[Float[Array, "N D"]], Float[Array, "N 1"]]
  """
  Type alias for utility functions which don't support batching, and instead characterise
- the utility of querying a single point, rather than a batch of points. They take an array of points of shape $`[N, D]`$
- and return the value of the utility function at each point in an array of shape $`[N, 1]`$.
+ the utility of querying a single point, rather than a batch of points. They take an array of points of shape $[N, D]$
+ and return the value of the utility function at each point in an array of shape $[N, 1]$.
  """


@@ -65,14 +65,12 @@ class AbstractSinglePointUtilityFunctionBuilder(ABC):
          datasets.

          Args:
-             posteriors (Mapping[str, AbstractPosterior]): Dictionary of posteriors to be
-                 used to form the utility function.
-             datasets (Mapping[str, Dataset]): Dictionary of datasets which may be used
-                 to form the utility function.
+             posteriors: dictionary of posteriors to be used to form the utility function.
+             datasets: dictionary of datasets which may be used to form the utility function.

          Raises:
              ValueError: If the objective posterior or dataset are not present in the
-                 posteriors or datasets.
+             posteriors or datasets.
          """
          if OBJECTIVE not in posteriors.keys():
              raise ValueError("Objective posterior not found in posteriors")
@@ -90,15 +88,13 @@ class AbstractSinglePointUtilityFunctionBuilder(ABC):
          Build a `UtilityFunction` from a set of posteriors and datasets.

          Args:
-             posteriors (Mapping[str, AbstractPosterior]): Dictionary of posteriors to be
-                 used to form the utility function.
-             datasets (Mapping[str, Dataset]): Dictionary of datasets which may be used
-                 to form the utility function.
-             key (KeyArray): JAX PRNG key used for random number generation.
+             posteriors: dictionary of posteriors to be used to form the utility function.
+             datasets: dictionary of datasets which may be used to form the utility function.
+             key: JAX PRNG key used for random number generation.

          Returns:
              SinglePointUtilityFunction: Utility function to be *maximised* in order to
-                 decide which point to query next.
+             decide which point to query next.
          """
          raise NotImplementedError