deepsensor 0.3.7__tar.gz → 0.3.8__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {deepsensor-0.3.7 → deepsensor-0.3.8}/PKG-INFO +2 -2
- {deepsensor-0.3.7 → deepsensor-0.3.8}/README.md +1 -1
- {deepsensor-0.3.7 → deepsensor-0.3.8}/deepsensor/data/processor.py +21 -26
- {deepsensor-0.3.7 → deepsensor-0.3.8}/deepsensor/model/convnp.py +78 -4
- {deepsensor-0.3.7 → deepsensor-0.3.8}/deepsensor.egg-info/PKG-INFO +2 -2
- {deepsensor-0.3.7 → deepsensor-0.3.8}/setup.cfg +1 -1
- {deepsensor-0.3.7 → deepsensor-0.3.8}/tests/test_active_learning.py +0 -3
- {deepsensor-0.3.7 → deepsensor-0.3.8}/tests/test_model.py +74 -44
- {deepsensor-0.3.7 → deepsensor-0.3.8}/deepsensor/__init__.py +0 -0
- {deepsensor-0.3.7 → deepsensor-0.3.8}/deepsensor/active_learning/__init__.py +0 -0
- {deepsensor-0.3.7 → deepsensor-0.3.8}/deepsensor/active_learning/acquisition_fns.py +0 -0
- {deepsensor-0.3.7 → deepsensor-0.3.8}/deepsensor/active_learning/algorithms.py +0 -0
- {deepsensor-0.3.7 → deepsensor-0.3.8}/deepsensor/config.py +0 -0
- {deepsensor-0.3.7 → deepsensor-0.3.8}/deepsensor/data/__init__.py +0 -0
- {deepsensor-0.3.7 → deepsensor-0.3.8}/deepsensor/data/loader.py +0 -0
- {deepsensor-0.3.7 → deepsensor-0.3.8}/deepsensor/data/sources.py +0 -0
- {deepsensor-0.3.7 → deepsensor-0.3.8}/deepsensor/data/task.py +0 -0
- {deepsensor-0.3.7 → deepsensor-0.3.8}/deepsensor/data/utils.py +0 -0
- {deepsensor-0.3.7 → deepsensor-0.3.8}/deepsensor/errors.py +0 -0
- {deepsensor-0.3.7 → deepsensor-0.3.8}/deepsensor/model/__init__.py +0 -0
- {deepsensor-0.3.7 → deepsensor-0.3.8}/deepsensor/model/defaults.py +0 -0
- {deepsensor-0.3.7 → deepsensor-0.3.8}/deepsensor/model/model.py +0 -0
- {deepsensor-0.3.7 → deepsensor-0.3.8}/deepsensor/model/nps.py +0 -0
- {deepsensor-0.3.7 → deepsensor-0.3.8}/deepsensor/model/pred.py +0 -0
- {deepsensor-0.3.7 → deepsensor-0.3.8}/deepsensor/plot.py +0 -0
- {deepsensor-0.3.7 → deepsensor-0.3.8}/deepsensor/py.typed +0 -0
- {deepsensor-0.3.7 → deepsensor-0.3.8}/deepsensor/tensorflow/__init__.py +0 -0
- {deepsensor-0.3.7 → deepsensor-0.3.8}/deepsensor/torch/__init__.py +0 -0
- {deepsensor-0.3.7 → deepsensor-0.3.8}/deepsensor/train/__init__.py +0 -0
- {deepsensor-0.3.7 → deepsensor-0.3.8}/deepsensor/train/train.py +0 -0
- {deepsensor-0.3.7 → deepsensor-0.3.8}/deepsensor.egg-info/SOURCES.txt +0 -0
- {deepsensor-0.3.7 → deepsensor-0.3.8}/deepsensor.egg-info/dependency_links.txt +0 -0
- {deepsensor-0.3.7 → deepsensor-0.3.8}/deepsensor.egg-info/not-zip-safe +0 -0
- {deepsensor-0.3.7 → deepsensor-0.3.8}/deepsensor.egg-info/requires.txt +0 -0
- {deepsensor-0.3.7 → deepsensor-0.3.8}/deepsensor.egg-info/top_level.txt +0 -0
- {deepsensor-0.3.7 → deepsensor-0.3.8}/pyproject.toml +0 -0
- {deepsensor-0.3.7 → deepsensor-0.3.8}/setup.py +0 -0
- {deepsensor-0.3.7 → deepsensor-0.3.8}/tests/__init__.py +0 -0
- {deepsensor-0.3.7 → deepsensor-0.3.8}/tests/test_data_processor.py +0 -0
- {deepsensor-0.3.7 → deepsensor-0.3.8}/tests/test_plotting.py +0 -0
- {deepsensor-0.3.7 → deepsensor-0.3.8}/tests/test_task.py +0 -0
- {deepsensor-0.3.7 → deepsensor-0.3.8}/tests/test_task_loader.py +0 -0
- {deepsensor-0.3.7 → deepsensor-0.3.8}/tests/test_training.py +0 -0
- {deepsensor-0.3.7 → deepsensor-0.3.8}/tests/utils.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: deepsensor
|
|
3
|
-
Version: 0.3.7
|
|
3
|
+
Version: 0.3.8
|
|
4
4
|
Summary: A Python package for modelling xarray and pandas data with neural processes.
|
|
5
5
|
Home-page: https://github.com/alan-turing-institute/deepsensor
|
|
6
6
|
Author: Tom R. Andersson
|
|
@@ -44,7 +44,7 @@ data with neural processes</p>
|
|
|
44
44
|
|
|
45
45
|
-----------
|
|
46
46
|
|
|
47
|
-
[](https://github.com/alan-turing-institute/deepsensor/releases)
|
|
48
48
|
[](https://alan-turing-institute.github.io/deepsensor/)
|
|
49
49
|

|
|
50
50
|
[](https://coveralls.io/github/alan-turing-institute/deepsensor?branch=main)
|
|
@@ -11,7 +11,7 @@ data with neural processes</p>
|
|
|
11
11
|
|
|
12
12
|
-----------
|
|
13
13
|
|
|
14
|
-
[](https://github.com/alan-turing-institute/deepsensor/releases)
|
|
15
15
|
[](https://alan-turing-institute.github.io/deepsensor/)
|
|
16
16
|

|
|
17
17
|
[](https://coveralls.io/github/alan-turing-institute/deepsensor?branch=main)
|
|
@@ -97,7 +97,7 @@ class DataProcessor:
|
|
|
97
97
|
self.verbose = verbose
|
|
98
98
|
|
|
99
99
|
# List of valid normalisation method names
|
|
100
|
-
self.valid_methods = ["mean_std", "min_max"]
|
|
100
|
+
self.valid_methods = ["mean_std", "min_max", "positive_semidefinite"]
|
|
101
101
|
|
|
102
102
|
def save(self, folder: str):
|
|
103
103
|
"""Save DataProcessor config to JSON in `folder`"""
|
|
@@ -293,6 +293,8 @@ class DataProcessor:
|
|
|
293
293
|
params = {"mean": float(data.mean()), "std": float(data.std())}
|
|
294
294
|
elif method == "min_max":
|
|
295
295
|
params = {"min": float(data.min()), "max": float(data.max())}
|
|
296
|
+
elif method == "positive_semidefinite":
|
|
297
|
+
params = {"min": float(data.min()), "std": float(data.std())}
|
|
296
298
|
if self.verbose:
|
|
297
299
|
print(f"Done. {var_ID} {method} params={params}")
|
|
298
300
|
self.add_to_config(
|
|
@@ -498,33 +500,25 @@ class DataProcessor:
|
|
|
498
500
|
|
|
499
501
|
params = self.get_config(var_ID, data, method)
|
|
500
502
|
|
|
503
|
+
# Linear transformation:
|
|
504
|
+
# - Inverse normalisation: y_unnorm = m * y_norm + c
|
|
505
|
+
# - Inverse normalisation: y_norm = (1/m) * y_unnorm - c/m
|
|
501
506
|
if method == "mean_std":
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
if unnorm:
|
|
505
|
-
scale = std
|
|
506
|
-
offset = mean
|
|
507
|
-
else:
|
|
508
|
-
scale = 1 / std
|
|
509
|
-
offset = -mean / std
|
|
510
|
-
data = data * scale
|
|
511
|
-
if add_offset:
|
|
512
|
-
data = data + offset
|
|
513
|
-
return data
|
|
514
|
-
|
|
507
|
+
m = params["std"]
|
|
508
|
+
c = params["mean"]
|
|
515
509
|
elif method == "min_max":
|
|
516
|
-
|
|
517
|
-
|
|
518
|
-
|
|
519
|
-
|
|
520
|
-
|
|
521
|
-
|
|
522
|
-
|
|
523
|
-
|
|
524
|
-
|
|
525
|
-
|
|
526
|
-
|
|
527
|
-
|
|
510
|
+
m = (params["max"] - params["min"]) / 2
|
|
511
|
+
c = (params["max"] + params["min"]) / 2
|
|
512
|
+
elif method == "positive_semidefinite":
|
|
513
|
+
m = params["std"]
|
|
514
|
+
c = params["min"]
|
|
515
|
+
if not unnorm:
|
|
516
|
+
c = -c / m
|
|
517
|
+
m = 1 / m
|
|
518
|
+
data = data * m
|
|
519
|
+
if add_offset:
|
|
520
|
+
data = data + c
|
|
521
|
+
return data
|
|
528
522
|
|
|
529
523
|
def map(
|
|
530
524
|
self,
|
|
@@ -610,6 +604,7 @@ class DataProcessor:
|
|
|
610
604
|
method (str, optional): Normalisation method. Options include:
|
|
611
605
|
- "mean_std": Normalise to mean=0 and std=1 (default)
|
|
612
606
|
- "min_max": Normalise to min=-1 and max=1
|
|
607
|
+
- "positive_semidefinite": Normalise to min=0 and std=1
|
|
613
608
|
|
|
614
609
|
Returns:
|
|
615
610
|
:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | List[:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame`]:
|
|
@@ -539,10 +539,10 @@ class ConvNP(DeepSensorModel):
|
|
|
539
539
|
def alpha(
|
|
540
540
|
self, dist: AbstractMultiOutputDistribution
|
|
541
541
|
) -> Union[np.ndarray, List[np.ndarray]]:
|
|
542
|
-
if self.config["likelihood"] not in ["spikes-beta"
|
|
542
|
+
if self.config["likelihood"] not in ["spikes-beta"]:
|
|
543
543
|
raise NotImplementedError(
|
|
544
544
|
f"ConvNP.alpha method not supported for likelihood {self.config['likelihood']}. "
|
|
545
|
-
f"
|
|
545
|
+
f"Valid likelihoods: 'spikes-beta'."
|
|
546
546
|
)
|
|
547
547
|
alpha = dist.slab.alpha
|
|
548
548
|
alpha = self._cast_numpy_and_squeeze(alpha)
|
|
@@ -576,10 +576,10 @@ class ConvNP(DeepSensorModel):
|
|
|
576
576
|
def beta(
|
|
577
577
|
self, dist: AbstractMultiOutputDistribution
|
|
578
578
|
) -> Union[np.ndarray, List[np.ndarray]]:
|
|
579
|
-
if self.config["likelihood"] not in ["spikes-beta"
|
|
579
|
+
if self.config["likelihood"] not in ["spikes-beta"]:
|
|
580
580
|
raise NotImplementedError(
|
|
581
581
|
f"ConvNP.beta method not supported for likelihood {self.config['likelihood']}. "
|
|
582
|
-
f"
|
|
582
|
+
f"Valid likelihoods: 'spikes-beta'."
|
|
583
583
|
)
|
|
584
584
|
beta = dist.slab.beta
|
|
585
585
|
beta = self._cast_numpy_and_squeeze(beta)
|
|
@@ -608,6 +608,80 @@ class ConvNP(DeepSensorModel):
|
|
|
608
608
|
dist = self(task)
|
|
609
609
|
return self.beta(dist)
|
|
610
610
|
|
|
611
|
+
@dispatch
|
|
612
|
+
def k(
|
|
613
|
+
self, dist: AbstractMultiOutputDistribution
|
|
614
|
+
) -> Union[np.ndarray, List[np.ndarray]]:
|
|
615
|
+
if self.config["likelihood"] not in ["bernoulli-gamma"]:
|
|
616
|
+
raise NotImplementedError(
|
|
617
|
+
f"ConvNP.k method not supported for likelihood {self.config['likelihood']}. "
|
|
618
|
+
f"Valid likelihoods: 'bernoulli-gamma'."
|
|
619
|
+
)
|
|
620
|
+
k = dist.slab.k
|
|
621
|
+
k = self._cast_numpy_and_squeeze(k)
|
|
622
|
+
return self._maybe_concat_multi_targets(k)
|
|
623
|
+
|
|
624
|
+
@dispatch
|
|
625
|
+
def k(self, task: Task) -> Union[np.ndarray, List[np.ndarray]]:
|
|
626
|
+
"""
|
|
627
|
+
k parameter values of model's distribution at target locations in task.
|
|
628
|
+
|
|
629
|
+
Returned numpy arrays have shape ``(N_features, *N_targets)``.
|
|
630
|
+
|
|
631
|
+
.. note::
|
|
632
|
+
This method only works for models that return a distribution with
|
|
633
|
+
a ``dist.slab.k`` attribute, e.g. models with a Beta or
|
|
634
|
+
Bernoulli-Gamma likelihood, where it returns the k values of
|
|
635
|
+
the slab component of the mixture model.
|
|
636
|
+
|
|
637
|
+
Args:
|
|
638
|
+
task (:class:`~.data.task.Task`):
|
|
639
|
+
The task containing the context and target data.
|
|
640
|
+
|
|
641
|
+
Returns:
|
|
642
|
+
:class:`numpy:numpy.ndarray` | List[:class:`numpy:numpy.ndarray`]:
|
|
643
|
+
k values.
|
|
644
|
+
"""
|
|
645
|
+
dist = self(task)
|
|
646
|
+
return self.k(dist)
|
|
647
|
+
|
|
648
|
+
@dispatch
|
|
649
|
+
def scale(
|
|
650
|
+
self, dist: AbstractMultiOutputDistribution
|
|
651
|
+
) -> Union[np.ndarray, List[np.ndarray]]:
|
|
652
|
+
if self.config["likelihood"] not in ["bernoulli-gamma"]:
|
|
653
|
+
raise NotImplementedError(
|
|
654
|
+
f"ConvNP.scale method not supported for likelihood {self.config['likelihood']}. "
|
|
655
|
+
f"Valid likelihoods: 'bernoulli-gamma'."
|
|
656
|
+
)
|
|
657
|
+
scale = dist.slab.scale
|
|
658
|
+
scale = self._cast_numpy_and_squeeze(scale)
|
|
659
|
+
return self._maybe_concat_multi_targets(scale)
|
|
660
|
+
|
|
661
|
+
@dispatch
|
|
662
|
+
def scale(self, task: Task) -> Union[np.ndarray, List[np.ndarray]]:
|
|
663
|
+
"""
|
|
664
|
+
Scale parameter values of model's distribution at target locations in task.
|
|
665
|
+
|
|
666
|
+
Returned numpy arrays have shape ``(N_features, *N_targets)``.
|
|
667
|
+
|
|
668
|
+
.. note::
|
|
669
|
+
This method only works for models that return a distribution with
|
|
670
|
+
a ``dist.slab.scale`` attribute, e.g. models with a Beta or
|
|
671
|
+
Bernoulli-Gamma likelihood, where it returns the scale values of
|
|
672
|
+
the slab component of the mixture model.
|
|
673
|
+
|
|
674
|
+
Args:
|
|
675
|
+
task (:class:`~.data.task.Task`):
|
|
676
|
+
The task containing the context and target data.
|
|
677
|
+
|
|
678
|
+
Returns:
|
|
679
|
+
:class:`numpy:numpy.ndarray` | List[:class:`numpy:numpy.ndarray`]:
|
|
680
|
+
Scale values.
|
|
681
|
+
"""
|
|
682
|
+
dist = self(task)
|
|
683
|
+
return self.scale(dist)
|
|
684
|
+
|
|
611
685
|
@dispatch
|
|
612
686
|
def mixture_probs(self, dist: AbstractMultiOutputDistribution):
|
|
613
687
|
if self.N_mixture_components == 1:
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: deepsensor
|
|
3
|
-
Version: 0.3.7
|
|
3
|
+
Version: 0.3.8
|
|
4
4
|
Summary: A Python package for modelling xarray and pandas data with neural processes.
|
|
5
5
|
Home-page: https://github.com/alan-turing-institute/deepsensor
|
|
6
6
|
Author: Tom R. Andersson
|
|
@@ -44,7 +44,7 @@ data with neural processes</p>
|
|
|
44
44
|
|
|
45
45
|
-----------
|
|
46
46
|
|
|
47
|
-
[](https://github.com/alan-turing-institute/deepsensor/releases)
|
|
48
48
|
[](https://alan-turing-institute.github.io/deepsensor/)
|
|
49
49
|

|
|
50
50
|
[](https://coveralls.io/github/alan-turing-institute/deepsensor?branch=main)
|
|
@@ -26,9 +26,6 @@ from deepsensor.data.processor import DataProcessor, xarray_to_coord_array_norma
|
|
|
26
26
|
from deepsensor.model.convnp import ConvNP
|
|
27
27
|
|
|
28
28
|
|
|
29
|
-
# from deepsensor.active_learning.acquisition_fns import
|
|
30
|
-
|
|
31
|
-
|
|
32
29
|
class TestActiveLearning(unittest.TestCase):
|
|
33
30
|
|
|
34
31
|
@classmethod
|
|
@@ -193,7 +193,7 @@ class TestModel(unittest.TestCase):
|
|
|
193
193
|
n_targets * dim_y_combined * n_target_dims,
|
|
194
194
|
),
|
|
195
195
|
)
|
|
196
|
-
if likelihood in ["cnp-spikes-beta"]:
|
|
196
|
+
if likelihood in ["cnp-spikes-beta", "bernoulli-gamma"]:
|
|
197
197
|
mixture_probs = model.mixture_probs(task)
|
|
198
198
|
if isinstance(mixture_probs, (list, tuple)):
|
|
199
199
|
for p, dim_y in zip(mixture_probs, tl.target_dims):
|
|
@@ -215,6 +215,7 @@ class TestModel(unittest.TestCase):
|
|
|
215
215
|
),
|
|
216
216
|
)
|
|
217
217
|
|
|
218
|
+
if likelihood in ["cnp-spikes-beta"]:
|
|
218
219
|
x = model.alpha(task)
|
|
219
220
|
if isinstance(x, (list, tuple)):
|
|
220
221
|
for p, dim_y in zip(x, tl.target_dims):
|
|
@@ -229,6 +230,21 @@ class TestModel(unittest.TestCase):
|
|
|
229
230
|
else:
|
|
230
231
|
assert_shape(x, (dim_y_combined, *expected_obs_shape))
|
|
231
232
|
|
|
233
|
+
if likelihood in ["bernoulli-gamma"]:
|
|
234
|
+
x = model.k(task)
|
|
235
|
+
if isinstance(x, (list, tuple)):
|
|
236
|
+
for p, dim_y in zip(x, tl.target_dims):
|
|
237
|
+
assert_shape(p, (dim_y, *expected_obs_shape))
|
|
238
|
+
else:
|
|
239
|
+
assert_shape(x, (dim_y_combined, *expected_obs_shape))
|
|
240
|
+
|
|
241
|
+
x = model.scale(task)
|
|
242
|
+
if isinstance(x, (list, tuple)):
|
|
243
|
+
for p, dim_y in zip(x, tl.target_dims):
|
|
244
|
+
assert_shape(p, (dim_y, *expected_obs_shape))
|
|
245
|
+
else:
|
|
246
|
+
assert_shape(x, (dim_y_combined, *expected_obs_shape))
|
|
247
|
+
|
|
232
248
|
# Scalars
|
|
233
249
|
if likelihood in ["cnp", "gnp"]:
|
|
234
250
|
# Methods for Gaussian likelihoods only
|
|
@@ -451,61 +467,75 @@ class TestModel(unittest.TestCase):
|
|
|
451
467
|
def test_highlevel_predict_with_pred_params_pandas(self):
|
|
452
468
|
"""
|
|
453
469
|
Test that passing ``pred_params`` to ``.predict`` works with
|
|
454
|
-
|
|
470
|
+
mixture model likelihoods for off-grid prediction to pandas.
|
|
455
471
|
"""
|
|
456
472
|
tl = TaskLoader(context=self.da, target=self.da)
|
|
457
|
-
model = ConvNP(
|
|
458
|
-
self.dp,
|
|
459
|
-
tl,
|
|
460
|
-
unet_channels=(5, 5, 5),
|
|
461
|
-
verbose=False,
|
|
462
|
-
likelihood="cnp-spikes-beta",
|
|
463
|
-
)
|
|
464
|
-
task = tl("2020-01-01", context_sampling=10, target_sampling=10)
|
|
465
473
|
|
|
466
|
-
|
|
467
|
-
|
|
474
|
+
likelihoods = ["cnp-spikes-beta", "bernoulli-gamma"]
|
|
475
|
+
expected_pred_params = [
|
|
476
|
+
["mean", "std", "variance", "alpha", "beta"],
|
|
477
|
+
["mean", "std", "variance", "k", "scale"],
|
|
478
|
+
]
|
|
468
479
|
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
480
|
+
for likelihood, pred_params in zip(likelihoods, expected_pred_params):
|
|
481
|
+
model = ConvNP(
|
|
482
|
+
self.dp,
|
|
483
|
+
tl,
|
|
484
|
+
unet_channels=(5, 5, 5),
|
|
485
|
+
verbose=False,
|
|
486
|
+
likelihood=likelihood,
|
|
487
|
+
)
|
|
488
|
+
task = tl("2020-01-01", context_sampling=10)
|
|
489
|
+
|
|
490
|
+
# Off-grid prediction
|
|
491
|
+
X_t = np.array([[0.0, 0.5, 1.0], [0.0, 0.5, 1.0]])
|
|
474
492
|
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
493
|
+
# Check that nothing breaks and the correct parameters are returned
|
|
494
|
+
pred = model.predict(task, X_t=X_t, pred_params=pred_params)
|
|
495
|
+
for pred_param in pred_params:
|
|
496
|
+
assert pred_param in pred["var"]
|
|
497
|
+
|
|
498
|
+
# Test mixture probs special case
|
|
499
|
+
pred_params = ["mixture_probs"]
|
|
500
|
+
pred = model.predict(task, X_t=self.da, pred_params=pred_params)
|
|
501
|
+
for component in range(model.N_mixture_components):
|
|
502
|
+
pred_param = f"mixture_probs_{component}"
|
|
503
|
+
assert pred_param in pred["var"]
|
|
481
504
|
|
|
482
505
|
def test_highlevel_predict_with_pred_params_xarray(self):
|
|
483
506
|
"""
|
|
484
507
|
Test that passing ``pred_params`` to ``.predict`` works with
|
|
485
|
-
|
|
508
|
+
mixture model likelihoods for gridded prediction to xarray.
|
|
486
509
|
"""
|
|
487
510
|
tl = TaskLoader(context=self.da, target=self.da)
|
|
488
|
-
model = ConvNP(
|
|
489
|
-
self.dp,
|
|
490
|
-
tl,
|
|
491
|
-
unet_channels=(5, 5, 5),
|
|
492
|
-
verbose=False,
|
|
493
|
-
likelihood="cnp-spikes-beta",
|
|
494
|
-
)
|
|
495
|
-
task = tl("2020-01-01", context_sampling=10, target_sampling=10)
|
|
496
511
|
|
|
497
|
-
|
|
498
|
-
|
|
499
|
-
|
|
500
|
-
|
|
501
|
-
|
|
502
|
-
|
|
503
|
-
|
|
504
|
-
|
|
505
|
-
|
|
506
|
-
|
|
507
|
-
|
|
508
|
-
|
|
512
|
+
likelihoods = ["cnp-spikes-beta", "bernoulli-gamma"]
|
|
513
|
+
expected_pred_params = [
|
|
514
|
+
["mean", "std", "variance", "alpha", "beta"],
|
|
515
|
+
["mean", "std", "variance", "k", "scale"],
|
|
516
|
+
]
|
|
517
|
+
|
|
518
|
+
for likelihood, pred_params in zip(likelihoods, expected_pred_params):
|
|
519
|
+
model = ConvNP(
|
|
520
|
+
self.dp,
|
|
521
|
+
tl,
|
|
522
|
+
unet_channels=(5, 5, 5),
|
|
523
|
+
verbose=False,
|
|
524
|
+
likelihood=likelihood,
|
|
525
|
+
)
|
|
526
|
+
task = tl("2020-01-01", context_sampling=10)
|
|
527
|
+
|
|
528
|
+
# Check that nothing breaks and the correct parameters are returned
|
|
529
|
+
pred = model.predict(task, X_t=self.da, pred_params=pred_params)
|
|
530
|
+
for pred_param in pred_params:
|
|
531
|
+
assert pred_param in pred["var"]
|
|
532
|
+
|
|
533
|
+
# Test mixture probs special case
|
|
534
|
+
pred_params = ["mixture_probs"]
|
|
535
|
+
pred = model.predict(task, X_t=self.da, pred_params=pred_params)
|
|
536
|
+
for component in range(model.N_mixture_components):
|
|
537
|
+
pred_param = f"mixture_probs_{component}"
|
|
538
|
+
assert pred_param in pred["var"]
|
|
509
539
|
|
|
510
540
|
def test_highlevel_predict_with_invalid_pred_params(self):
|
|
511
541
|
"""Test that passing ``pred_params`` to ``.predict`` works."""
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|