gpjax 0.8.0__py3-none-any.whl → 0.8.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
gpjax/__init__.py CHANGED
@@ -38,7 +38,7 @@ __license__ = "MIT"
38
38
  __description__ = "Didactic Gaussian processes in JAX"
39
39
  __url__ = "https://github.com/JaxGaussianProcesses/GPJax"
40
40
  __contributors__ = "https://github.com/JaxGaussianProcesses/GPJax/graphs/contributors"
41
- __version__ = "0.8.0"
41
+ __version__ = "0.8.2"
42
42
 
43
43
  __all__ = [
44
44
  "base",
@@ -85,6 +85,6 @@ class PoissonTestFunction:
85
85
  Returns:
86
86
  Integer[Array, 'N 1']: Values of the test function at the points.
87
87
  """
88
- key = jr.PRNGKey(42)
88
+ key = jr.key(42)
89
89
  f = lambda x: 2.0 * jnp.sin(3 * x) + 0.5 * x
90
90
  return jr.poisson(key, jnp.exp(f(x)))
gpjax/fit.py CHANGED
@@ -23,11 +23,16 @@ from beartype.typing import (
23
23
  Union,
24
24
  )
25
25
  import jax
26
+ from jax import (
27
+ jit,
28
+ value_and_grad,
29
+ )
26
30
  from jax._src.random import _check_prng_key
31
+ from jax.flatten_util import ravel_pytree
27
32
  import jax.numpy as jnp
28
33
  import jax.random as jr
29
- import jaxopt
30
34
  import optax as ox
35
+ import scipy
31
36
 
32
37
  from gpjax.base import Module
33
38
  from gpjax.dataset import Dataset
@@ -42,10 +47,6 @@ from gpjax.typing import (
42
47
  ModuleModel = TypeVar("ModuleModel", bound=Module)
43
48
 
44
49
 
45
- class FailedScipyFitError(Exception):
46
- """Raised a model fit using Scipy fails"""
47
-
48
-
49
50
  def fit( # noqa: PLR0913
50
51
  *,
51
52
  model: ModuleModel,
@@ -72,7 +73,7 @@ def fit( # noqa: PLR0913
72
73
  >>>
73
74
  >>> # (1) Create a dataset:
74
75
  >>> X = jnp.linspace(0.0, 10.0, 100)[:, None]
75
- >>> y = 2.0 * X + 1.0 + 10 * jr.normal(jr.PRNGKey(0), X.shape)
76
+ >>> y = 2.0 * X + 1.0 + 10 * jr.normal(jr.key(0), X.shape)
76
77
  >>> D = gpx.Dataset(X, y)
77
78
  >>>
78
79
  >>> # (2) Define your model:
@@ -110,7 +111,7 @@ def fit( # noqa: PLR0913
110
111
  batch_size (Optional[int]): The size of the mini-batch to use. Defaults to -1
111
112
  (i.e. full batch).
112
113
  key (Optional[KeyArray]): The random key to use for the optimisation batch
113
- selection. Defaults to jr.PRNGKey(42).
114
+ selection. Defaults to jr.key(42).
114
115
  log_rate (Optional[int]): How frequently the objective function's value should
115
116
  be printed. Defaults to 10.
116
117
  verbose (Optional[bool]): Whether to print the training loading bar. Defaults
@@ -130,7 +131,7 @@ def fit( # noqa: PLR0913
130
131
  _check_optim(optim)
131
132
  _check_num_iters(num_iters)
132
133
  _check_batch_size(batch_size)
133
- _check_prng_key(key)
134
+ _check_prng_key("fit", key)
134
135
  _check_log_rate(log_rate)
135
136
  _check_verbose(verbose)
136
137
 
@@ -214,30 +215,31 @@ def fit_scipy( # noqa: PLR0913
214
215
  model = model.unconstrain()
215
216
 
216
217
  # Unconstrained space loss function with stop-gradient rule for non-trainable params.
217
- def loss(model: Module, data: Dataset) -> ScalarFloat:
218
+ def loss(model: Module) -> ScalarFloat:
218
219
  model = model.stop_gradient()
219
- return objective(model.constrain(), data)
220
-
221
- solver = jaxopt.ScipyMinimize(
222
- fun=loss,
223
- maxiter=max_iters,
220
+ return objective(model.constrain(), train_data)
221
+
222
+ # convert to numpy for interface with scipy
223
+ x0, scipy_to_jnp = ravel_pytree(model)
224
+
225
+ @jit
226
+ def scipy_wrapper(x0):
227
+ value, grads = value_and_grad(loss)(scipy_to_jnp(jnp.array(x0)))
228
+ scipy_grads = ravel_pytree(grads)[0]
229
+ return value, scipy_grads
230
+
231
+ history = [scipy_wrapper(x0)[0]]
232
+ result = scipy.optimize.minimize(
233
+ fun=scipy_wrapper,
234
+ x0=x0,
235
+ jac=True,
236
+ callback=lambda X: history.append(scipy_wrapper(X)[0]),
237
+ options={"maxiter": max_iters, "disp": verbose},
224
238
  )
239
+ history = jnp.array(history)
225
240
 
226
- initial_loss = solver.fun(model, train_data)
227
- model, result = solver.run(model, data=train_data)
228
- history = jnp.array([initial_loss, result.fun_val])
229
-
230
- if verbose:
231
- print(f"Initial loss is {initial_loss}")
232
- if result.success:
233
- print("Optimization was successful")
234
- else:
235
- raise FailedScipyFitError(
236
- "Optimization failed, try increasing max_iters or using a different optimiser."
237
- )
238
- print(f"Final loss is {result.fun_val} after {result.num_fun_eval} iterations")
239
-
240
- # Constrained space.
241
+ # convert back to pytree and reconstrain
242
+ model = scipy_to_jnp(result.x)
241
243
  model = model.constrain()
242
244
  return model, history
243
245
 
@@ -0,0 +1,8 @@
1
+ from gpjax.flax_base.types import BijectorLookupType, DomainType
2
+ import tensorflow_probability.substrates.jax.bijectors as tfb
3
+ import typing as tp
4
+
5
+ Bijectors: BijectorLookupType = {
6
+ "real": tfb.Identity(),
7
+ "positive": tfb.Softplus(),
8
+ }
@@ -0,0 +1,16 @@
1
+ import jax.numpy as jnp
2
+ import typing as tp
3
+ from flax.experimental import nnx
4
+ from gpjax.flax_base.types import DomainType, A
5
+
6
+
7
+ class AbstractParameter(nnx.Variable[A]):
8
+ domain: DomainType = "real"
9
+ static: bool = False
10
+
11
+ def __init__(self, value: A, *args, **kwargs):
12
+ super().__init__(jnp.asarray(value), *args, **kwargs)
13
+
14
+
15
+ class PositiveParameter(AbstractParameter[A]):
16
+ domain: DomainType = "positive"
@@ -0,0 +1,15 @@
1
+ import typing as tp
2
+ import tensorflow_probability.substrates.jax.bijectors as tfb
3
+
4
+ DomainType = tp.Literal["real", "positive"]
5
+ A = tp.TypeVar("A")
6
+
7
+
8
+ # class BijectorLookup(tp.TypedDict):
9
+ # domain: DomainType
10
+ # bijector: tfb.Bijector
11
+
12
+ class BijectorLookupType(tp.Dict[DomainType, tfb.Bijector]):
13
+ pass
14
+
15
+ __all__ = ["DomainType", "A", "BijectorLookupType"]
gpjax/gps.py CHANGED
@@ -12,10 +12,15 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
  # ==============================================================================
15
-
15
+ # from __future__ import annotations
16
16
  from abc import abstractmethod
17
17
  from dataclasses import dataclass
18
- from typing import overload
18
+ from typing import (
19
+ TYPE_CHECKING,
20
+ Generic,
21
+ TypeVar,
22
+ overload,
23
+ )
19
24
 
20
25
  from beartype.typing import (
21
26
  Any,
@@ -47,7 +52,7 @@ from gpjax.kernels.base import AbstractKernel
47
52
  from gpjax.likelihoods import (
48
53
  AbstractLikelihood,
49
54
  Gaussian,
50
- NonGaussianLikelihood,
55
+ NonGaussian,
51
56
  )
52
57
  from gpjax.lower_cholesky import lower_cholesky
53
58
  from gpjax.mean_functions import AbstractMeanFunction
@@ -57,13 +62,19 @@ from gpjax.typing import (
57
62
  KeyArray,
58
63
  )
59
64
 
65
+ Kernel = TypeVar("Kernel", bound=AbstractKernel)
66
+ MeanFunction = TypeVar("MeanFunction", bound=AbstractMeanFunction)
67
+ Likelihood = TypeVar("Likelihood", bound=AbstractLikelihood)
68
+ NonGaussianLikelihood = TypeVar("NonGaussianLikelihood", bound=NonGaussian)
69
+ GaussianLikelihood = TypeVar("GaussianLikelihood", bound=Gaussian)
70
+
60
71
 
61
72
  @dataclass
62
- class AbstractPrior(Module):
73
+ class AbstractPrior(Module, Generic[MeanFunction, Kernel]):
63
74
  r"""Abstract Gaussian process prior."""
64
75
 
65
- kernel: AbstractKernel
66
- mean_function: AbstractMeanFunction
76
+ kernel: Kernel
77
+ mean_function: MeanFunction
67
78
  jitter: float = static_field(1e-6)
68
79
 
69
80
  def __call__(self, *args: Any, **kwargs: Any) -> GaussianDistribution:
@@ -113,7 +124,7 @@ class AbstractPrior(Module):
113
124
  # GP Priors
114
125
  #######################
115
126
  @dataclass
116
- class Prior(AbstractPrior):
127
+ class Prior(AbstractPrior[MeanFunction, Kernel]):
117
128
  r"""A Gaussian process prior object.
118
129
 
119
130
  The GP is parameterised by a
@@ -137,17 +148,27 @@ class Prior(AbstractPrior):
137
148
  ```
138
149
  """
139
150
 
140
- # @overload
141
- # def __mul__(self, other: Gaussian) -> "ConjugatePosterior":
142
- # ...
151
+ if TYPE_CHECKING:
143
152
 
144
- # @overload
145
- # def __mul__(self, other: NonGaussianLikelihood) -> "NonConjugatePosterior":
146
- # ...
153
+ @overload
154
+ def __mul__(
155
+ self, other: GaussianLikelihood
156
+ ) -> "ConjugatePosterior[Prior[MeanFunction, Kernel], GaussianLikelihood]":
157
+ ...
147
158
 
148
- # @overload
149
- # def __mul__(self, other: AbstractLikelihood) -> "AbstractPosterior":
150
- # ...
159
+ @overload
160
+ def __mul__(
161
+ self, other: NonGaussianLikelihood
162
+ ) -> (
163
+ "NonConjugatePosterior[Prior[MeanFunction, Kernel], NonGaussianLikelihood]"
164
+ ):
165
+ ...
166
+
167
+ @overload
168
+ def __mul__(
169
+ self, other: Likelihood
170
+ ) -> "AbstractPosterior[Prior[MeanFunction, Kernel], Likelihood]":
171
+ ...
151
172
 
152
173
  def __mul__(self, other):
153
174
  r"""Combine the prior with a likelihood to form a posterior distribution.
@@ -183,17 +204,27 @@ class Prior(AbstractPrior):
183
204
  """
184
205
  return construct_posterior(prior=self, likelihood=other)
185
206
 
186
- # @overload
187
- # def __rmul__(self, other: Gaussian) -> "ConjugatePosterior":
188
- # ...
207
+ if TYPE_CHECKING:
208
+
209
+ @overload
210
+ def __rmul__(
211
+ self, other: GaussianLikelihood
212
+ ) -> "ConjugatePosterior[Prior[MeanFunction, Kernel], GaussianLikelihood]":
213
+ ...
189
214
 
190
- # @overload
191
- # def __rmul__(self, other: NonGaussianLikelihood) -> "NonConjugatePosterior":
192
- # ...
215
+ @overload
216
+ def __rmul__(
217
+ self, other: NonGaussianLikelihood
218
+ ) -> (
219
+ "NonConjugatePosterior[Prior[MeanFunction, Kernel], NonGaussianLikelihood]"
220
+ ):
221
+ ...
193
222
 
194
- # @overload
195
- # def __rmul__(self, other: AbstractLikelihood) -> "AbstractPosterior":
196
- # ...
223
+ @overload
224
+ def __rmul__(
225
+ self, other: Likelihood
226
+ ) -> "AbstractPosterior[Prior[MeanFunction, Kernel], Likelihood]":
227
+ ...
197
228
 
198
229
  def __rmul__(self, other):
199
230
  r"""Combine the prior with a likelihood to form a posterior distribution.
@@ -285,7 +316,7 @@ class Prior(AbstractPrior):
285
316
  >>> import gpjax as gpx
286
317
  >>> import jax.numpy as jnp
287
318
  >>> import jax.random as jr
288
- >>> key = jr.PRNGKey(123)
319
+ >>> key = jr.key(123)
289
320
  >>>
290
321
  >>> meanf = gpx.mean_functions.Zero()
291
322
  >>> kernel = gpx.kernels.RBF()
@@ -324,19 +355,22 @@ class Prior(AbstractPrior):
324
355
  return sample_fn
325
356
 
326
357
 
358
+ PriorType = TypeVar("PriorType", bound=AbstractPrior)
359
+
360
+
327
361
  #######################
328
362
  # GP Posteriors
329
363
  #######################
330
364
  @dataclass
331
- class AbstractPosterior(Module):
365
+ class AbstractPosterior(Module, Generic[PriorType, Likelihood]):
332
366
  r"""Abstract Gaussian process posterior.
333
367
 
334
368
  The base GP posterior object conditioned on an observed dataset. All
335
369
  posterior objects should inherit from this class.
336
370
  """
337
371
 
338
- prior: AbstractPrior
339
- likelihood: AbstractLikelihood
372
+ prior: AbstractPrior[MeanFunction, Kernel]
373
+ likelihood: Likelihood
340
374
  jitter: float = static_field(1e-6)
341
375
 
342
376
  def __call__(self, *args: Any, **kwargs: Any) -> GaussianDistribution:
@@ -381,7 +415,7 @@ class AbstractPosterior(Module):
381
415
 
382
416
 
383
417
  @dataclass
384
- class ConjugatePosterior(AbstractPosterior):
418
+ class ConjugatePosterior(AbstractPosterior[PriorType, GaussianLikelihood]):
385
419
  r"""A Conjuate Gaussian process posterior object.
386
420
 
387
421
  A Gaussian process posterior distribution when the constituent likelihood
@@ -600,7 +634,7 @@ class ConjugatePosterior(AbstractPosterior):
600
634
 
601
635
 
602
636
  @dataclass
603
- class NonConjugatePosterior(AbstractPosterior):
637
+ class NonConjugatePosterior(AbstractPosterior[PriorType, NonGaussianLikelihood]):
604
638
  r"""A non-conjugate Gaussian process posterior object.
605
639
 
606
640
  A Gaussian process posterior object for models where the likelihood is
@@ -685,22 +719,17 @@ class NonConjugatePosterior(AbstractPosterior):
685
719
  #######################
686
720
 
687
721
 
688
- @overload
689
- def construct_posterior(prior: Prior, likelihood: Gaussian) -> ConjugatePosterior:
690
- ...
691
-
692
-
693
722
  @overload
694
723
  def construct_posterior(
695
- prior: Prior, likelihood: NonGaussianLikelihood
696
- ) -> NonConjugatePosterior:
724
+ prior: PriorType, likelihood: GaussianLikelihood
725
+ ) -> ConjugatePosterior[PriorType, GaussianLikelihood]:
697
726
  ...
698
727
 
699
728
 
700
729
  @overload
701
730
  def construct_posterior(
702
- prior: Prior, likelihood: AbstractLikelihood
703
- ) -> AbstractPosterior:
731
+ prior: PriorType, likelihood: NonGaussianLikelihood
732
+ ) -> NonConjugatePosterior[PriorType, NonGaussianLikelihood]:
704
733
  ...
705
734
 
706
735
 
gpjax/likelihoods.py CHANGED
@@ -247,11 +247,11 @@ def inv_probit(x: Float[Array, " *N"]) -> Float[Array, " *N"]:
247
247
  return 0.5 * (1.0 + jsp.special.erf(x / jnp.sqrt(2.0))) * (1 - 2 * jitter) + jitter
248
248
 
249
249
 
250
- NonGaussianLikelihood = Union[Poisson, Bernoulli]
250
+ NonGaussian = Union[Poisson, Bernoulli]
251
251
 
252
252
  __all__ = [
253
253
  "AbstractLikelihood",
254
- "NonGaussianLikelihood",
254
+ "NonGaussian",
255
255
  "Gaussian",
256
256
  "Bernoulli",
257
257
  "Poisson",
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: gpjax
3
- Version: 0.8.0
3
+ Version: 0.8.2
4
4
  Summary: Gaussian processes in JAX.
5
5
  Home-page: https://github.com/JaxGaussianProcesses/GPJax
6
6
  License: Apache-2.0
@@ -14,14 +14,14 @@ Classifier: Programming Language :: Python :: 3.10
14
14
  Classifier: Programming Language :: Python :: 3.11
15
15
  Requires-Dist: beartype (>=0.16.2,<0.17.0)
16
16
  Requires-Dist: cola-ml (>=0.0.5,<0.0.6)
17
- Requires-Dist: jax (>=0.4.10)
18
- Requires-Dist: jaxlib (>=0.4.10)
19
- Requires-Dist: jaxopt (>=0.8.2,<0.9.0)
17
+ Requires-Dist: jax (>=0.4.16)
18
+ Requires-Dist: jaxlib (>=0.4.16)
19
+ Requires-Dist: jaxopt (>=0.8.3,<0.9.0)
20
20
  Requires-Dist: jaxtyping (>=0.2.15,<0.3.0)
21
21
  Requires-Dist: optax (>=0.1.4,<0.2.0)
22
22
  Requires-Dist: orbax-checkpoint (>=0.2.3)
23
23
  Requires-Dist: simple-pytree (>=0.1.7,<0.2.0)
24
- Requires-Dist: tensorflow-probability (>=0.19.0,<0.20.0)
24
+ Requires-Dist: tensorflow-probability (>=0.22.0,<0.23.0)
25
25
  Requires-Dist: tqdm (>=4.65.0,<5.0.0)
26
26
  Project-URL: Documentation, https://docs.jaxgaussianprocesses.com/
27
27
  Project-URL: Repository, https://github.com/JaxGaussianProcesses/GPJax
@@ -152,7 +152,7 @@ import jax.numpy as jnp
152
152
  import jax.random as jr
153
153
  import optax as ox
154
154
 
155
- key = jr.PRNGKey(123)
155
+ key = jr.key(123)
156
156
 
157
157
  f = lambda x: 10 * jnp.sin(x)
158
158
 
@@ -1,5 +1,5 @@
1
1
  LICENSE,sha256=tAkwu8-AdEyGxGoSvJ2gVmQdcicWw3j1ZZueVV74M-E,11357
2
- gpjax/__init__.py,sha256=UPKdFOtBDHq93YqdBSar8kXQkvk2g-mE_IAgd-MbqMM,1545
2
+ gpjax/__init__.py,sha256=HAAaTQX1KSXatt4Pm-aFLAYucVc3gg60OoMemZ81LrY,1545
3
3
  gpjax/base/__init__.py,sha256=jLAm6UYta8Uo02IQSsJGpzZI110XnQdyuAnHEX3A-2s,1068
4
4
  gpjax/base/module.py,sha256=50TgoIOKv58T13SwCoBpgU-3EH7e8piC9cuXMEIs5_c,12842
5
5
  gpjax/base/param.py,sha256=D_EQYfOf4wE1Xr3RBb5lyleeHYw8HKxNwU7Sr2LRmYw,2271
@@ -11,15 +11,18 @@ gpjax/decision_making/posterior_handler.py,sha256=fTI-mHl_WTrjXk0Pu_a91cFhDDZlWF
11
11
  gpjax/decision_making/search_space.py,sha256=3GRfM3KquOmOlLCzUvtELtB3HqJyaUOhxW9Jj2BnTtA,3547
12
12
  gpjax/decision_making/test_functions/__init__.py,sha256=GDCY9_kaAnxDWwzo1FkdxnDx-80MErAHchbGybT9xYs,1109
13
13
  gpjax/decision_making/test_functions/continuous_functions.py,sha256=ATrMPX_hE-uzgyjvsm1t_xzXmIkOrcGmTLgCQUOHUyw,5088
14
- gpjax/decision_making/test_functions/non_conjugate_functions.py,sha256=aDvnzZYmChY23nzOIETHrDoKvaJmZk81FMB0x9uYVfw,2986
14
+ gpjax/decision_making/test_functions/non_conjugate_functions.py,sha256=MXFbPtRAZ7nFYs8eF3F9OBhm1Y4NrMs8w7b6uigAMGs,2982
15
15
  gpjax/decision_making/utility_functions/__init__.py,sha256=pAjgstn2l_HL0yMXP_vi6YfFX53eTdwAVIJeTA2gIlI,1171
16
16
  gpjax/decision_making/utility_functions/base.py,sha256=Vq50tPzmxzE_d9edRx_H85EN_XjaTFY3AkudUtXh9wY,4053
17
17
  gpjax/decision_making/utility_functions/thompson_sampling.py,sha256=vERF8aLKan_hxhLU6w5rGMLFcp1HGb3jrQzYuksg7bw,4344
18
18
  gpjax/decision_making/utility_maximizer.py,sha256=7J_qFwMwgfYxHU7ye9H7g78wHK9qdZvgW0orE0upJnM,6041
19
19
  gpjax/decision_making/utils.py,sha256=pIl4t5lhnWsfnAqrgUaQi6Ua82ogyKAezwhQIWo-4AE,1838
20
20
  gpjax/distributions.py,sha256=BwvpCSdcsi2kIFVLXKlMO7TWOUojVJ1K_7ai8EcCbq8,9520
21
- gpjax/fit.py,sha256=FGvtkYSYehOtscsZinAiqXCdNctFMtICD5ZqjnhDZeI,10654
22
- gpjax/gps.py,sha256=lPO_zuwf3rUpA2FAoVtAoGQ-NzItsUUbiW3MxPETRsA,29206
21
+ gpjax/fit.py,sha256=kiGL9uSvYDcyVjmAPj8_QGorkRqzxMN_N03pO3U4RKY,10638
22
+ gpjax/flax_base/bijectors.py,sha256=LpgZma8_fo9Kqi3JdRVse8CROoCG9PuNP-GTbaV63gY,244
23
+ gpjax/flax_base/param.py,sha256=_ZkGO5o7vgBwm_l2suV2orfP7XIgk7Cj2D7_XXBxdDQ,426
24
+ gpjax/flax_base/types.py,sha256=4_bGquCOuWNfOR-O6GLjZJJgueNwxu27Wckmm6Ckz_o,365
25
+ gpjax/gps.py,sha256=lIgoGqxNI4AbZ80bjm_c_rNSNSy6KG3YkOlPJlE1gKQ,30328
23
26
  gpjax/integrators.py,sha256=wThbb-9d1FW1_33tg6Z7GLwruwuMm1-spd3-TEUCSAI,5763
24
27
  gpjax/kernels/__init__.py,sha256=pCTWcgf1hkJGYnX9BS-AcWryyZhovZ2Cfb4Ttf12J6E,1845
25
28
  gpjax/kernels/approximations/__init__.py,sha256=bK9HlGd-PZeGrqtG5RpXxUTXNUrZTgfjH1dP626yNMA,68
@@ -49,7 +52,7 @@ gpjax/kernels/stationary/rational_quadratic.py,sha256=JzrFdm5enVYdD61TtTnxyVrGBB
49
52
  gpjax/kernels/stationary/rbf.py,sha256=rIDwkF8YwOM2NfT7_y5Mfhc4BJSBOO6MdMeWPolspZ0,2442
50
53
  gpjax/kernels/stationary/utils.py,sha256=KU4j4IAXS1MZbHPUOiW6rieWX7lMCMRtSRlxIH3Kybo,2226
51
54
  gpjax/kernels/stationary/white.py,sha256=ymTKpOQZ-AWJqiZv4CbWt9pdpsw0HGfbf3jNIV83f_8,2086
52
- gpjax/likelihoods.py,sha256=j_l-AkgaYN_agnIehYYWc6ryUOxkoPp4CfftJ7nQmnc,8111
55
+ gpjax/likelihoods.py,sha256=MPZntqagPsQ4Y00-KW9yyUvm9C2UHGX-pqN6GubxcCk,8091
53
56
  gpjax/lower_cholesky.py,sha256=4uM_2JF_fTDr4qq88N7mljR0MG5G7UVOmiqvvcf7hec,1850
54
57
  gpjax/mean_functions.py,sha256=6qxvaIoHfNMCYIN_9R_9HUiCTzBGiZzTk2vNlb4RaU8,6437
55
58
  gpjax/objectives.py,sha256=HKgbeGNUOm9kzndX7vXHYvHzTsE21GgAmD2VU9ktHzg,19178
@@ -57,7 +60,7 @@ gpjax/progress_bar.py,sha256=ErXoJ-jjEPOediq_clgG2OC0UTptpP0MtVM1a2ZLC_o,4418
57
60
  gpjax/scan.py,sha256=RTfIISMqo4jznDFJ6m9fiaPEYgnINRPwGNRiz71gLFg,5506
58
61
  gpjax/typing.py,sha256=H_7je0XuJ1UP77zP5bJskdK2gJhm812xRfnRbGdqgkg,1688
59
62
  gpjax/variational_families.py,sha256=5Q7btVizm3Y3xcVVpoosS-0JM2JreaQEiFbgF2uNgXc,26466
60
- gpjax-0.8.0.dist-info/LICENSE,sha256=tAkwu8-AdEyGxGoSvJ2gVmQdcicWw3j1ZZueVV74M-E,11357
61
- gpjax-0.8.0.dist-info/WHEEL,sha256=vxFmldFsRN_Hx10GDvsdv1wroKq8r5Lzvjp6GZ4OO8c,88
62
- gpjax-0.8.0.dist-info/METADATA,sha256=RJk6PlMtLj1QM5qP1lmpP38UEfbBOkRLJoo2NWgEaVg,10083
63
- gpjax-0.8.0.dist-info/RECORD,,
63
+ gpjax-0.8.2.dist-info/LICENSE,sha256=tAkwu8-AdEyGxGoSvJ2gVmQdcicWw3j1ZZueVV74M-E,11357
64
+ gpjax-0.8.2.dist-info/METADATA,sha256=yI9KlzNn3IATrcpijURB4MvKwXeXPfQVMpZcHmKpz5I,10079
65
+ gpjax-0.8.2.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
66
+ gpjax-0.8.2.dist-info/RECORD,,
@@ -1,4 +1,4 @@
1
1
  Wheel-Version: 1.0
2
- Generator: poetry-core 1.3.2
2
+ Generator: poetry-core 1.9.0
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
File without changes