gpjax 0.8.0__py3-none-any.whl → 0.8.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- gpjax/__init__.py +1 -1
- gpjax/decision_making/test_functions/non_conjugate_functions.py +1 -1
- gpjax/fit.py +31 -29
- gpjax/flax_base/bijectors.py +8 -0
- gpjax/flax_base/param.py +16 -0
- gpjax/flax_base/types.py +15 -0
- gpjax/gps.py +69 -40
- gpjax/likelihoods.py +2 -2
- {gpjax-0.8.0.dist-info → gpjax-0.8.2.dist-info}/METADATA +6 -6
- {gpjax-0.8.0.dist-info → gpjax-0.8.2.dist-info}/RECORD +12 -9
- {gpjax-0.8.0.dist-info → gpjax-0.8.2.dist-info}/WHEEL +1 -1
- {gpjax-0.8.0.dist-info → gpjax-0.8.2.dist-info}/LICENSE +0 -0
gpjax/__init__.py
CHANGED

@@ -38,7 +38,7 @@ __license__ = "MIT"
 __description__ = "Didactic Gaussian processes in JAX"
 __url__ = "https://github.com/JaxGaussianProcesses/GPJax"
 __contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
-__version__ = "0.8.0"
+__version__ = "0.8.2"
 
 __all__ = [
     "base",
gpjax/fit.py
CHANGED

@@ -23,11 +23,16 @@ from beartype.typing import (
     Union,
 )
 import jax
+from jax import (
+    jit,
+    value_and_grad,
+)
 from jax._src.random import _check_prng_key
+from jax.flatten_util import ravel_pytree
 import jax.numpy as jnp
 import jax.random as jr
-import jaxopt
 import optax as ox
+import scipy
 
 from gpjax.base import Module
 from gpjax.dataset import Dataset

@@ -42,10 +47,6 @@ from gpjax.typing import (
 ModuleModel = TypeVar("ModuleModel", bound=Module)
 
 
-class FailedScipyFitError(Exception):
-    """Raised a model fit using Scipy fails"""
-
-
 def fit(  # noqa: PLR0913
     *,
     model: ModuleModel,

@@ -72,7 +73,7 @@ def fit(  # noqa: PLR0913
         >>>
         >>> # (1) Create a dataset:
        >>> X = jnp.linspace(0.0, 10.0, 100)[:, None]
-        >>> y = 2.0 * X + 1.0 + 10 * jr.normal(jr.PRNGKey(0), X.shape)
+        >>> y = 2.0 * X + 1.0 + 10 * jr.normal(jr.key(0), X.shape)
         >>> D = gpx.Dataset(X, y)
         >>>
         >>> # (2) Define your model:

@@ -110,7 +111,7 @@ def fit(  # noqa: PLR0913
         batch_size (Optional[int]): The size of the mini-batch to use. Defaults to -1
             (i.e. full batch).
         key (Optional[KeyArray]): The random key to use for the optimisation batch
-            selection. Defaults to jr.PRNGKey(42).
+            selection. Defaults to jr.key(42).
         log_rate (Optional[int]): How frequently the objective function's value should
             be printed. Defaults to 10.
         verbose (Optional[bool]): Whether to print the training loading bar. Defaults

@@ -130,7 +131,7 @@ def fit(  # noqa: PLR0913
     _check_optim(optim)
     _check_num_iters(num_iters)
     _check_batch_size(batch_size)
-    _check_prng_key(key)
+    _check_prng_key("fit", key)
     _check_log_rate(log_rate)
     _check_verbose(verbose)
 

@@ -214,30 +215,31 @@ def fit_scipy(  # noqa: PLR0913
     model = model.unconstrain()
 
     # Unconstrained space loss function with stop-gradient rule for non-trainable params.
-    def loss(model: Module
+    def loss(model: Module) -> ScalarFloat:
         model = model.stop_gradient()
-        return objective(model.constrain(),
-
-
-
-
+        return objective(model.constrain(), train_data)
+
+    # convert to numpy for interface with scipy
+    x0, scipy_to_jnp = ravel_pytree(model)
+
+    @jit
+    def scipy_wrapper(x0):
+        value, grads = value_and_grad(loss)(scipy_to_jnp(jnp.array(x0)))
+        scipy_grads = ravel_pytree(grads)[0]
+        return value, scipy_grads
+
+    history = [scipy_wrapper(x0)[0]]
+    result = scipy.optimize.minimize(
+        fun=scipy_wrapper,
+        x0=x0,
+        jac=True,
+        callback=lambda X: history.append(scipy_wrapper(X)[0]),
+        options={"maxiter": max_iters, "disp": verbose},
     )
+    history = jnp.array(history)
 
-
-    model
-    history = jnp.array([initial_loss, result.fun_val])
-
-    if verbose:
-        print(f"Initial loss is {initial_loss}")
-        if result.success:
-            print("Optimization was successful")
-        else:
-            raise FailedScipyFitError(
-                "Optimization failed, try increasing max_iters or using a different optimiser."
-            )
-        print(f"Final loss is {result.fun_val} after {result.num_fun_eval} iterations")
-
-    # Constrained space.
+    # convert back to pytree and reconstrain
+    model = scipy_to_jnp(result.x)
     model = model.constrain()
     return model, history
 
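The key change here is that fit_scipy no longer goes through jaxopt's ScipyMinimize wrapper: the model pytree is flattened to a single vector with ravel_pytree, and a jitted value_and_grad wrapper hands scipy both the loss and its flattened gradient via jac=True. A minimal, self-contained sketch of the same bridging pattern (the quadratic loss and parameter pytree below are illustrative stand-ins, not GPJax code):

import jax
import jax.numpy as jnp
import scipy.optimize
from jax.flatten_util import ravel_pytree

# An illustrative parameter pytree and loss; any pytree of arrays works.
params = {"a": jnp.array(3.0), "b": jnp.array([1.0, -2.0])}

def loss(p):
    return (p["a"] - 1.0) ** 2 + jnp.sum(p["b"] ** 2)

# Flatten the pytree to a flat vector; `unravel` maps a vector back.
x0, unravel = ravel_pytree(params)

@jax.jit
def scipy_objective(x):
    # scipy hands us a flat float64 vector: rebuild the pytree, then
    # return (value, flat gradient) so scipy can be called with jac=True.
    value, grads = jax.value_and_grad(loss)(unravel(jnp.asarray(x)))
    return value, ravel_pytree(grads)[0]

result = scipy.optimize.minimize(fun=scipy_objective, x0=x0, jac=True)
print(unravel(result.x))  # ≈ {'a': 1.0, 'b': [0.0, 0.0]}

Note that the callback in the new fit_scipy re-evaluates the (jitted) loss at every accepted iterate purely to record a loss history, since scipy's callback receives only the current parameter vector.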
gpjax/flax_base/param.py
ADDED

@@ -0,0 +1,16 @@
+import jax.numpy as jnp
+import typing as tp
+from flax.experimental import nnx
+from gpjax.flax_base.types import DomainType, A
+
+
+class AbstractParameter(nnx.Variable[A]):
+    domain: DomainType = "real"
+    static: bool = False
+
+    def __init__(self, value: A, *args, **kwargs):
+        super().__init__(jnp.asarray(value), *args, **kwargs)
+
+
+class PositiveParameter(AbstractParameter[A]):
+    domain: DomainType = "positive"
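These classes are the first step of a migration of GPJax parameters onto flax's experimental nnx variables: every parameter value is coerced to a jnp array on construction and tagged with a domain, so downstream code can pick the right bijector when constraining it. A rough usage sketch, assuming the flax.experimental.nnx API as pinned by this release:

import jax.numpy as jnp
from gpjax.flax_base.param import AbstractParameter, PositiveParameter

# Plain Python scalars are coerced to jnp arrays on construction.
lengthscale = PositiveParameter(1.0)

# The class-level tags describe how a parameter should be handled:
# "positive" parameters get a positivity-preserving bijector, "real"
# parameters an identity transform, and `static` marks non-trainables.
print(PositiveParameter.domain)  # "positive"
print(AbstractParameter.domain)  # "real"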
gpjax/flax_base/types.py
ADDED

@@ -0,0 +1,15 @@
+import typing as tp
+import tensorflow_probability.substrates.jax.bijectors as tfb
+
+DomainType = tp.Literal["real", "positive"]
+A = tp.TypeVar("A")
+
+
+# class BijectorLookup(tp.TypedDict):
+#     domain: DomainType
+#     bijector: tfb.Bijector
+
+class BijectorLookupType(tp.Dict[DomainType, tfb.Bijector]):
+    pass
+
+__all__ = ["DomainType", "A", "BijectorLookupType"]
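BijectorLookupType is simply a dict keyed by the DomainType literal, mapping each domain to a TFP bijector; the commented-out TypedDict shows an alternative shape that was considered. The sibling bijectors.py added in this release (its body is not shown in this diff) presumably populates such a lookup. A hypothetical instance, with the bijector choices as assumptions:

import tensorflow_probability.substrates.jax.bijectors as tfb
from gpjax.flax_base.types import BijectorLookupType

# Hypothetical lookup: "real" parameters pass through unchanged,
# "positive" parameters are mapped through softplus.
bijectors: BijectorLookupType = BijectorLookupType(
    real=tfb.Identity(), positive=tfb.Softplus()
)

# Map a constrained positive value to unconstrained space and back.
unconstrained = bijectors["positive"].inverse(1.0)
constrained = bijectors["positive"].forward(unconstrained)  # ≈ 1.0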
gpjax/gps.py
CHANGED

@@ -12,10 +12,15 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
-
+# from __future__ import annotations
 from abc import abstractmethod
 from dataclasses import dataclass
-from typing import
+from typing import (
+    TYPE_CHECKING,
+    Generic,
+    TypeVar,
+    overload,
+)
 
 from beartype.typing import (
     Any,

@@ -47,7 +52,7 @@ from gpjax.kernels.base import AbstractKernel
 from gpjax.likelihoods import (
     AbstractLikelihood,
     Gaussian,
-
+    NonGaussian,
 )
 from gpjax.lower_cholesky import lower_cholesky
 from gpjax.mean_functions import AbstractMeanFunction

@@ -57,13 +62,19 @@ from gpjax.typing import (
     KeyArray,
 )
 
+Kernel = TypeVar("Kernel", bound=AbstractKernel)
+MeanFunction = TypeVar("MeanFunction", bound=AbstractMeanFunction)
+Likelihood = TypeVar("Likelihood", bound=AbstractLikelihood)
+NonGaussianLikelihood = TypeVar("NonGaussianLikelihood", bound=NonGaussian)
+GaussianLikelihood = TypeVar("GaussianLikelihood", bound=Gaussian)
+
 
 @dataclass
-class AbstractPrior(Module):
+class AbstractPrior(Module, Generic[MeanFunction, Kernel]):
     r"""Abstract Gaussian process prior."""
 
-    kernel:
-    mean_function:
+    kernel: Kernel
+    mean_function: MeanFunction
     jitter: float = static_field(1e-6)
 
     def __call__(self, *args: Any, **kwargs: Any) -> GaussianDistribution:

@@ -113,7 +124,7 @@ class AbstractPrior(Module):
 # GP Priors
 #######################
 @dataclass
-class Prior(AbstractPrior):
+class Prior(AbstractPrior[MeanFunction, Kernel]):
     r"""A Gaussian process prior object.
 
     The GP is parameterised by a

@@ -137,17 +148,27 @@ class Prior(AbstractPrior):
     ```
     """
 
-
-    # def __mul__(self, other: Gaussian) -> "ConjugatePosterior":
-    #     ...
+    if TYPE_CHECKING:
 
-
-
-
+        @overload
+        def __mul__(
+            self, other: GaussianLikelihood
+        ) -> "ConjugatePosterior[Prior[MeanFunction, Kernel], GaussianLikelihood]":
+            ...
 
-
-
-
+        @overload
+        def __mul__(
+            self, other: NonGaussianLikelihood
+        ) -> (
+            "NonConjugatePosterior[Prior[MeanFunction, Kernel], NonGaussianLikelihood]"
+        ):
+            ...
+
+        @overload
+        def __mul__(
+            self, other: Likelihood
+        ) -> "AbstractPosterior[Prior[MeanFunction, Kernel], Likelihood]":
+            ...
 
     def __mul__(self, other):
         r"""Combine the prior with a likelihood to form a posterior distribution.

@@ -183,17 +204,27 @@ class Prior(AbstractPrior):
         """
         return construct_posterior(prior=self, likelihood=other)
 
-
-
-
+    if TYPE_CHECKING:
+
+        @overload
+        def __rmul__(
+            self, other: GaussianLikelihood
+        ) -> "ConjugatePosterior[Prior[MeanFunction, Kernel], GaussianLikelihood]":
+            ...
 
-
-
-
+        @overload
+        def __rmul__(
+            self, other: NonGaussianLikelihood
+        ) -> (
+            "NonConjugatePosterior[Prior[MeanFunction, Kernel], NonGaussianLikelihood]"
+        ):
+            ...
 
-
-
-
+        @overload
+        def __rmul__(
+            self, other: Likelihood
+        ) -> "AbstractPosterior[Prior[MeanFunction, Kernel], Likelihood]":
+            ...
 
     def __rmul__(self, other):
         r"""Combine the prior with a likelihood to form a posterior distribution.

@@ -285,7 +316,7 @@ class Prior(AbstractPrior):
         >>> import gpjax as gpx
         >>> import jax.numpy as jnp
         >>> import jax.random as jr
-        >>> key = jr.PRNGKey(123)
+        >>> key = jr.key(123)
         >>>
         >>> meanf = gpx.mean_functions.Zero()
         >>> kernel = gpx.kernels.RBF()

@@ -324,19 +355,22 @@ class Prior(AbstractPrior):
         return sample_fn
 
 
+PriorType = TypeVar("PriorType", bound=AbstractPrior)
+
+
 #######################
 # GP Posteriors
 #######################
 @dataclass
-class AbstractPosterior(Module):
+class AbstractPosterior(Module, Generic[PriorType, Likelihood]):
     r"""Abstract Gaussian process posterior.
 
     The base GP posterior object conditioned on an observed dataset. All
     posterior objects should inherit from this class.
     """
 
-    prior: AbstractPrior
-    likelihood:
+    prior: AbstractPrior[MeanFunction, Kernel]
+    likelihood: Likelihood
     jitter: float = static_field(1e-6)
 
     def __call__(self, *args: Any, **kwargs: Any) -> GaussianDistribution:

@@ -381,7 +415,7 @@ class AbstractPosterior(Module):
 
 
 @dataclass
-class ConjugatePosterior(AbstractPosterior):
+class ConjugatePosterior(AbstractPosterior[PriorType, GaussianLikelihood]):
     r"""A Conjuate Gaussian process posterior object.
 
     A Gaussian process posterior distribution when the constituent likelihood

@@ -600,7 +634,7 @@ class ConjugatePosterior(AbstractPosterior):
 
 
 @dataclass
-class NonConjugatePosterior(AbstractPosterior):
+class NonConjugatePosterior(AbstractPosterior[PriorType, NonGaussianLikelihood]):
     r"""A non-conjugate Gaussian process posterior object.
 
     A Gaussian process posterior object for models where the likelihood is

@@ -685,22 +719,17 @@ class NonConjugatePosterior(AbstractPosterior):
 #######################
 
 
-@overload
-def construct_posterior(prior: Prior, likelihood: Gaussian) -> ConjugatePosterior:
-    ...
-
-
 @overload
 def construct_posterior(
-    prior:
-) ->
+    prior: PriorType, likelihood: GaussianLikelihood
+) -> ConjugatePosterior[PriorType, GaussianLikelihood]:
     ...
 
 
 @overload
 def construct_posterior(
-    prior:
-) ->
+    prior: PriorType, likelihood: NonGaussianLikelihood
+) -> NonConjugatePosterior[PriorType, NonGaussianLikelihood]:
     ...
 
 
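None of this changes runtime behaviour: __mul__ and __rmul__ still defer to construct_posterior. The TYPE_CHECKING-guarded overloads exist so a static type checker can resolve `prior * likelihood` to a concrete, fully parameterised posterior type. For instance:

import gpjax as gpx

prior = gpx.gps.Prior(
    mean_function=gpx.mean_functions.Zero(), kernel=gpx.kernels.RBF()
)

# Statically resolves to a ConjugatePosterior[...] for a Gaussian likelihood.
regression = prior * gpx.likelihoods.Gaussian(num_datapoints=100)

# Statically resolves to a NonConjugatePosterior[...] instead.
classification = prior * gpx.likelihoods.Bernoulli(num_datapoints=100)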
gpjax/likelihoods.py
CHANGED

@@ -247,11 +247,11 @@ def inv_probit(x: Float[Array, " *N"]) -> Float[Array, " *N"]:
     return 0.5 * (1.0 + jsp.special.erf(x / jnp.sqrt(2.0))) * (1 - 2 * jitter) + jitter
 
 
-
+NonGaussian = Union[Poisson, Bernoulli]
 
 __all__ = [
     "AbstractLikelihood",
-    "
+    "NonGaussian",
     "Gaussian",
     "Bernoulli",
     "Poisson",
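NonGaussian is a plain typing.Union alias rather than a shared base class, so it supports the static overloads in gps.py above but is not itself a reliable isinstance target on every supported Python version; runtime dispatch should name the concrete classes. A small illustration of both uses:

from typing import Union

from gpjax.likelihoods import Bernoulli, Gaussian, NonGaussian, Poisson

def needs_approximate_inference(
    likelihood: Union[Gaussian, NonGaussian],
) -> bool:
    # Dispatch at runtime on the concrete classes; the alias itself
    # is intended for static annotations.
    return isinstance(likelihood, (Bernoulli, Poisson))

print(needs_approximate_inference(Gaussian(num_datapoints=10)))  # False
print(needs_approximate_inference(Poisson(num_datapoints=10)))   # True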
{gpjax-0.8.0.dist-info → gpjax-0.8.2.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: gpjax
-Version: 0.8.0
+Version: 0.8.2
 Summary: Gaussian processes in JAX.
 Home-page: https://github.com/JaxGaussianProcesses/GPJax
 License: Apache-2.0

@@ -14,14 +14,14 @@ Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
 Requires-Dist: beartype (>=0.16.2,<0.17.0)
 Requires-Dist: cola-ml (>=0.0.5,<0.0.6)
-Requires-Dist: jax (>=0.4.
-Requires-Dist: jaxlib (>=0.4.
-Requires-Dist: jaxopt (>=0.8.
+Requires-Dist: jax (>=0.4.16)
+Requires-Dist: jaxlib (>=0.4.16)
+Requires-Dist: jaxopt (>=0.8.3,<0.9.0)
 Requires-Dist: jaxtyping (>=0.2.15,<0.3.0)
 Requires-Dist: optax (>=0.1.4,<0.2.0)
 Requires-Dist: orbax-checkpoint (>=0.2.3)
 Requires-Dist: simple-pytree (>=0.1.7,<0.2.0)
-Requires-Dist: tensorflow-probability (>=0.
+Requires-Dist: tensorflow-probability (>=0.22.0,<0.23.0)
 Requires-Dist: tqdm (>=4.65.0,<5.0.0)
 Project-URL: Documentation, https://docs.jaxgaussianprocesses.com/
 Project-URL: Repository, https://github.com/JaxGaussianProcesses/GPJax

@@ -152,7 +152,7 @@ import jax.numpy as jnp
 import jax.random as jr
 import optax as ox
 
-key = jr.PRNGKey(123)
+key = jr.key(123)
 
 f = lambda x: 10 * jnp.sin(x)
 
{gpjax-0.8.0.dist-info → gpjax-0.8.2.dist-info}/RECORD
CHANGED

@@ -1,5 +1,5 @@
 LICENSE,sha256=tAkwu8-AdEyGxGoSvJ2gVmQdcicWw3j1ZZueVV74M-E,11357
-gpjax/__init__.py,sha256=
+gpjax/__init__.py,sha256=HAAaTQX1KSXatt4Pm-aFLAYucVc3gg60OoMemZ81LrY,1545
 gpjax/base/__init__.py,sha256=jLAm6UYta8Uo02IQSsJGpzZI110XnQdyuAnHEX3A-2s,1068
 gpjax/base/module.py,sha256=50TgoIOKv58T13SwCoBpgU-3EH7e8piC9cuXMEIs5_c,12842
 gpjax/base/param.py,sha256=D_EQYfOf4wE1Xr3RBb5lyleeHYw8HKxNwU7Sr2LRmYw,2271

@@ -11,15 +11,18 @@ gpjax/decision_making/posterior_handler.py,sha256=fTI-mHl_WTrjXk0Pu_a91cFhDDZlWF
 gpjax/decision_making/search_space.py,sha256=3GRfM3KquOmOlLCzUvtELtB3HqJyaUOhxW9Jj2BnTtA,3547
 gpjax/decision_making/test_functions/__init__.py,sha256=GDCY9_kaAnxDWwzo1FkdxnDx-80MErAHchbGybT9xYs,1109
 gpjax/decision_making/test_functions/continuous_functions.py,sha256=ATrMPX_hE-uzgyjvsm1t_xzXmIkOrcGmTLgCQUOHUyw,5088
-gpjax/decision_making/test_functions/non_conjugate_functions.py,sha256=
+gpjax/decision_making/test_functions/non_conjugate_functions.py,sha256=MXFbPtRAZ7nFYs8eF3F9OBhm1Y4NrMs8w7b6uigAMGs,2982
 gpjax/decision_making/utility_functions/__init__.py,sha256=pAjgstn2l_HL0yMXP_vi6YfFX53eTdwAVIJeTA2gIlI,1171
 gpjax/decision_making/utility_functions/base.py,sha256=Vq50tPzmxzE_d9edRx_H85EN_XjaTFY3AkudUtXh9wY,4053
 gpjax/decision_making/utility_functions/thompson_sampling.py,sha256=vERF8aLKan_hxhLU6w5rGMLFcp1HGb3jrQzYuksg7bw,4344
 gpjax/decision_making/utility_maximizer.py,sha256=7J_qFwMwgfYxHU7ye9H7g78wHK9qdZvgW0orE0upJnM,6041
 gpjax/decision_making/utils.py,sha256=pIl4t5lhnWsfnAqrgUaQi6Ua82ogyKAezwhQIWo-4AE,1838
 gpjax/distributions.py,sha256=BwvpCSdcsi2kIFVLXKlMO7TWOUojVJ1K_7ai8EcCbq8,9520
-gpjax/fit.py,sha256=
-gpjax/
+gpjax/fit.py,sha256=kiGL9uSvYDcyVjmAPj8_QGorkRqzxMN_N03pO3U4RKY,10638
+gpjax/flax_base/bijectors.py,sha256=LpgZma8_fo9Kqi3JdRVse8CROoCG9PuNP-GTbaV63gY,244
+gpjax/flax_base/param.py,sha256=_ZkGO5o7vgBwm_l2suV2orfP7XIgk7Cj2D7_XXBxdDQ,426
+gpjax/flax_base/types.py,sha256=4_bGquCOuWNfOR-O6GLjZJJgueNwxu27Wckmm6Ckz_o,365
+gpjax/gps.py,sha256=lIgoGqxNI4AbZ80bjm_c_rNSNSy6KG3YkOlPJlE1gKQ,30328
 gpjax/integrators.py,sha256=wThbb-9d1FW1_33tg6Z7GLwruwuMm1-spd3-TEUCSAI,5763
 gpjax/kernels/__init__.py,sha256=pCTWcgf1hkJGYnX9BS-AcWryyZhovZ2Cfb4Ttf12J6E,1845
 gpjax/kernels/approximations/__init__.py,sha256=bK9HlGd-PZeGrqtG5RpXxUTXNUrZTgfjH1dP626yNMA,68

@@ -49,7 +52,7 @@ gpjax/kernels/stationary/rational_quadratic.py,sha256=JzrFdm5enVYdD61TtTnxyVrGBB
 gpjax/kernels/stationary/rbf.py,sha256=rIDwkF8YwOM2NfT7_y5Mfhc4BJSBOO6MdMeWPolspZ0,2442
 gpjax/kernels/stationary/utils.py,sha256=KU4j4IAXS1MZbHPUOiW6rieWX7lMCMRtSRlxIH3Kybo,2226
 gpjax/kernels/stationary/white.py,sha256=ymTKpOQZ-AWJqiZv4CbWt9pdpsw0HGfbf3jNIV83f_8,2086
-gpjax/likelihoods.py,sha256=
+gpjax/likelihoods.py,sha256=MPZntqagPsQ4Y00-KW9yyUvm9C2UHGX-pqN6GubxcCk,8091
 gpjax/lower_cholesky.py,sha256=4uM_2JF_fTDr4qq88N7mljR0MG5G7UVOmiqvvcf7hec,1850
 gpjax/mean_functions.py,sha256=6qxvaIoHfNMCYIN_9R_9HUiCTzBGiZzTk2vNlb4RaU8,6437
 gpjax/objectives.py,sha256=HKgbeGNUOm9kzndX7vXHYvHzTsE21GgAmD2VU9ktHzg,19178

@@ -57,7 +60,7 @@ gpjax/progress_bar.py,sha256=ErXoJ-jjEPOediq_clgG2OC0UTptpP0MtVM1a2ZLC_o,4418
 gpjax/scan.py,sha256=RTfIISMqo4jznDFJ6m9fiaPEYgnINRPwGNRiz71gLFg,5506
 gpjax/typing.py,sha256=H_7je0XuJ1UP77zP5bJskdK2gJhm812xRfnRbGdqgkg,1688
 gpjax/variational_families.py,sha256=5Q7btVizm3Y3xcVVpoosS-0JM2JreaQEiFbgF2uNgXc,26466
-gpjax-0.8.0.dist-info/LICENSE,sha256=tAkwu8-AdEyGxGoSvJ2gVmQdcicWw3j1ZZueVV74M-E,11357
-gpjax-0.8.0.dist-info/METADATA,sha256=
-gpjax-0.8.0.dist-info/WHEEL,sha256=
-gpjax-0.8.0.dist-info/RECORD,,
+gpjax-0.8.2.dist-info/LICENSE,sha256=tAkwu8-AdEyGxGoSvJ2gVmQdcicWw3j1ZZueVV74M-E,11357
+gpjax-0.8.2.dist-info/METADATA,sha256=yI9KlzNn3IATrcpijURB4MvKwXeXPfQVMpZcHmKpz5I,10079
+gpjax-0.8.2.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
+gpjax-0.8.2.dist-info/RECORD,,
{gpjax-0.8.0.dist-info → gpjax-0.8.2.dist-info}/LICENSE
File without changes