deepsensor 0.3.7__tar.gz → 0.4.0__tar.gz
This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
- {deepsensor-0.3.7 → deepsensor-0.4.0}/PKG-INFO +3 -3
- {deepsensor-0.3.7 → deepsensor-0.4.0}/README.md +2 -2
- {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor/data/loader.py +17 -13
- {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor/data/processor.py +21 -26
- {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor/data/sources.py +3 -3
- deepsensor-0.4.0/deepsensor/eval/__init__.py +1 -0
- deepsensor-0.4.0/deepsensor/eval/metrics.py +24 -0
- {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor/model/convnp.py +78 -4
- {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor/model/model.py +74 -5
- {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor/model/pred.py +86 -17
- {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor.egg-info/PKG-INFO +3 -3
- {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor.egg-info/SOURCES.txt +2 -0
- {deepsensor-0.3.7 → deepsensor-0.4.0}/setup.cfg +1 -1
- {deepsensor-0.3.7 → deepsensor-0.4.0}/tests/test_active_learning.py +0 -3
- {deepsensor-0.3.7 → deepsensor-0.4.0}/tests/test_model.py +143 -45
- {deepsensor-0.3.7 → deepsensor-0.4.0}/tests/test_task_loader.py +17 -0
- {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor/__init__.py +0 -0
- {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor/active_learning/__init__.py +0 -0
- {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor/active_learning/acquisition_fns.py +0 -0
- {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor/active_learning/algorithms.py +0 -0
- {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor/config.py +0 -0
- {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor/data/__init__.py +0 -0
- {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor/data/task.py +0 -0
- {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor/data/utils.py +0 -0
- {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor/errors.py +0 -0
- {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor/model/__init__.py +0 -0
- {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor/model/defaults.py +0 -0
- {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor/model/nps.py +0 -0
- {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor/plot.py +0 -0
- {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor/py.typed +0 -0
- {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor/tensorflow/__init__.py +0 -0
- {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor/torch/__init__.py +0 -0
- {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor/train/__init__.py +0 -0
- {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor/train/train.py +0 -0
- {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor.egg-info/dependency_links.txt +0 -0
- {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor.egg-info/not-zip-safe +0 -0
- {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor.egg-info/requires.txt +0 -0
- {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor.egg-info/top_level.txt +0 -0
- {deepsensor-0.3.7 → deepsensor-0.4.0}/pyproject.toml +0 -0
- {deepsensor-0.3.7 → deepsensor-0.4.0}/setup.py +0 -0
- {deepsensor-0.3.7 → deepsensor-0.4.0}/tests/__init__.py +0 -0
- {deepsensor-0.3.7 → deepsensor-0.4.0}/tests/test_data_processor.py +0 -0
- {deepsensor-0.3.7 → deepsensor-0.4.0}/tests/test_plotting.py +0 -0
- {deepsensor-0.3.7 → deepsensor-0.4.0}/tests/test_task.py +0 -0
- {deepsensor-0.3.7 → deepsensor-0.4.0}/tests/test_training.py +0 -0
- {deepsensor-0.3.7 → deepsensor-0.4.0}/tests/utils.py +0 -0
{deepsensor-0.3.7 → deepsensor-0.4.0}/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: deepsensor
-Version: 0.3.7
+Version: 0.4.0
 Summary: A Python package for modelling xarray and pandas data with neural processes.
 Home-page: https://github.com/alan-turing-institute/deepsensor
 Author: Tom R. Andersson
@@ -44,7 +44,7 @@ data with neural processes</p>
 
 -----------
 
-[](https://github.com/alan-turing-institute/deepsensor/releases)
+[](https://github.com/alan-turing-institute/deepsensor/releases)
 [](https://alan-turing-institute.github.io/deepsensor/)
 
 [](https://coveralls.io/github/alan-turing-institute/deepsensor?branch=main)
@@ -240,7 +240,7 @@ if you would like to join this list!
 <table>
   <tbody>
     <tr>
-      <td align="center" valign="top" width="14.28%"><a href="https://github.com/acocac"><img src="https://avatars.githubusercontent.com/u/13321552?v=4?s=100" width="100px;" alt="Alejandro ©"/><br /><sub><b>Alejandro ©</b></sub></a><br /><a href="#userTesting-acocac" title="User Testing">📓</a> <a href="#bug-acocac" title="Bug reports">🐛</a> <a href="#mentoring-acocac" title="Mentoring">🧑🏫</a> <a href="#ideas-acocac" title="Ideas, Planning, & Feedback">🤔</a> <a href="#research-acocac" title="Research">🔬</a></td>
+      <td align="center" valign="top" width="14.28%"><a href="https://github.com/acocac"><img src="https://avatars.githubusercontent.com/u/13321552?v=4?s=100" width="100px;" alt="Alejandro ©"/><br /><sub><b>Alejandro ©</b></sub></a><br /><a href="#userTesting-acocac" title="User Testing">📓</a> <a href="#bug-acocac" title="Bug reports">🐛</a> <a href="#mentoring-acocac" title="Mentoring">🧑🏫</a> <a href="#ideas-acocac" title="Ideas, Planning, & Feedback">🤔</a> <a href="#research-acocac" title="Research">🔬</a> <a href="#code-acocac" title="Code">💻</a> <a href="#test-acocac" title="Tests">⚠️</a></td>
       <td align="center" valign="top" width="14.28%"><a href="https://github.com/annavaughan"><img src="https://avatars.githubusercontent.com/u/45528489?v=4?s=100" width="100px;" alt="Anna Vaughan"/><br /><sub><b>Anna Vaughan</b></sub></a><br /><a href="#research-annavaughan" title="Research">🔬</a></td>
       <td align="center" valign="top" width="14.28%"><a href="http://davidwilby.dev"><img src="https://avatars.githubusercontent.com/u/24752124?v=4?s=100" width="100px;" alt="David Wilby"/><br /><sub><b>David Wilby</b></sub></a><br /><a href="#doc-davidwilby" title="Documentation">📖</a> <a href="#test-davidwilby" title="Tests">⚠️</a> <a href="#maintenance-davidwilby" title="Maintenance">🚧</a></td>
       <td align="center" valign="top" width="14.28%"><a href="http://inconsistentrecords.co.uk"><img src="https://avatars.githubusercontent.com/u/731727?v=4?s=100" width="100px;" alt="Jim Circadian"/><br /><sub><b>Jim Circadian</b></sub></a><br /><a href="#ideas-JimCircadian" title="Ideas, Planning, & Feedback">🤔</a> <a href="#projectManagement-JimCircadian" title="Project Management">📆</a> <a href="#maintenance-JimCircadian" title="Maintenance">🚧</a></td>
{deepsensor-0.3.7 → deepsensor-0.4.0}/README.md

@@ -11,7 +11,7 @@ data with neural processes</p>
 
 -----------
 
-[](https://github.com/alan-turing-institute/deepsensor/releases)
+[](https://github.com/alan-turing-institute/deepsensor/releases)
 [](https://alan-turing-institute.github.io/deepsensor/)
 
 [](https://coveralls.io/github/alan-turing-institute/deepsensor?branch=main)
@@ -207,7 +207,7 @@ if you would like to join this list!
 <table>
   <tbody>
     <tr>
-      <td align="center" valign="top" width="14.28%"><a href="https://github.com/acocac"><img src="https://avatars.githubusercontent.com/u/13321552?v=4?s=100" width="100px;" alt="Alejandro ©"/><br /><sub><b>Alejandro ©</b></sub></a><br /><a href="#userTesting-acocac" title="User Testing">📓</a> <a href="#bug-acocac" title="Bug reports">🐛</a> <a href="#mentoring-acocac" title="Mentoring">🧑🏫</a> <a href="#ideas-acocac" title="Ideas, Planning, & Feedback">🤔</a> <a href="#research-acocac" title="Research">🔬</a></td>
+      <td align="center" valign="top" width="14.28%"><a href="https://github.com/acocac"><img src="https://avatars.githubusercontent.com/u/13321552?v=4?s=100" width="100px;" alt="Alejandro ©"/><br /><sub><b>Alejandro ©</b></sub></a><br /><a href="#userTesting-acocac" title="User Testing">📓</a> <a href="#bug-acocac" title="Bug reports">🐛</a> <a href="#mentoring-acocac" title="Mentoring">🧑🏫</a> <a href="#ideas-acocac" title="Ideas, Planning, & Feedback">🤔</a> <a href="#research-acocac" title="Research">🔬</a> <a href="#code-acocac" title="Code">💻</a> <a href="#test-acocac" title="Tests">⚠️</a></td>
      <td align="center" valign="top" width="14.28%"><a href="https://github.com/annavaughan"><img src="https://avatars.githubusercontent.com/u/45528489?v=4?s=100" width="100px;" alt="Anna Vaughan"/><br /><sub><b>Anna Vaughan</b></sub></a><br /><a href="#research-annavaughan" title="Research">🔬</a></td>
      <td align="center" valign="top" width="14.28%"><a href="http://davidwilby.dev"><img src="https://avatars.githubusercontent.com/u/24752124?v=4?s=100" width="100px;" alt="David Wilby"/><br /><sub><b>David Wilby</b></sub></a><br /><a href="#doc-davidwilby" title="Documentation">📖</a> <a href="#test-davidwilby" title="Tests">⚠️</a> <a href="#maintenance-davidwilby" title="Maintenance">🚧</a></td>
      <td align="center" valign="top" width="14.28%"><a href="http://inconsistentrecords.co.uk"><img src="https://avatars.githubusercontent.com/u/731727?v=4?s=100" width="100px;" alt="Jim Circadian"/><br /><sub><b>Jim Circadian</b></sub></a><br /><a href="#ideas-JimCircadian" title="Ideas, Planning, & Feedback">🤔</a> <a href="#projectManagement-JimCircadian" title="Project Management">📆</a> <a href="#maintenance-JimCircadian" title="Maintenance">🚧</a></td>
{deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor/data/loader.py

@@ -678,11 +678,11 @@ class TaskLoader:
         seed: Optional[int] = None,
     ) -> (np.ndarray, np.ndarray):
         """
-        Sample a
+        Sample a DataFrame according to a given strategy.
 
         Args:
             df (:class:`pandas.DataFrame` | :class:`pandas.Series`):
-
+                Dataframe to sample, assumed to be time-sliced for the task
                 already.
             sampling_strat (str | int | float | :class:`numpy:numpy.ndarray`):
                 Sampling strategy, either "all" or an integer for random grid
@@ -720,20 +720,24 @@ class TaskLoader:
             X_c = df.reset_index()[["x1", "x2"]].values.T.astype(self.dtype)
             Y_c = df.values.T
         elif isinstance(sampling_strat, np.ndarray):
+            if df.index.get_level_values("x1").dtype != sampling_strat.dtype:
+                raise InvalidSamplingStrategyError(
+                    "Passed a numpy coordinate array to sample pandas DataFrame, "
+                    "but the coordinate array has a different dtype than the DataFrame. "
+                    f"Got {sampling_strat.dtype} but expected {df.index.get_level_values('x1').dtype}."
+                )
             X_c = sampling_strat.astype(self.dtype)
-
-
-
-
-            # Check that we got all the samples we asked for
-            if num_matches != X_c.shape[1]:
+            try:
+                Y_c = df.loc[pd.IndexSlice[:, X_c[0], X_c[1]]].values.T
+            except KeyError:
                 raise InvalidSamplingStrategyError(
-
-
-                    f"
+                    "Passed a numpy coordinate array to sample pandas DataFrame, "
+                    "but the DataFrame did not contain all the requested samples.\n"
+                    f"Indexes: {df.index}\n"
+                    f"Sampling coords: {X_c}\n"
+                    "If this is unexpected, check that your numpy sampling array matches "
+                    "the DataFrame index values *exactly*."
                 )
-
-            Y_c = df[x1match & x2match].values.T
         else:
             raise InvalidSamplingStrategyError(
                 f"Unknown sampling strategy {sampling_strat}"
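
The loader change above adds an explicit dtype check before sampling a DataFrame at exact numpy coordinates, and performs the lookup with `pd.IndexSlice`. Below is a minimal sketch of the pattern the check protects, using a hypothetical Series with the ("time", "x1", "x2") index layout; the variable name and coordinate values are illustrative and not taken from the package.

```python
import numpy as np
import pandas as pd

# Toy Series with a ("time", "x1", "x2") MultiIndex
index = pd.MultiIndex.from_product(
    [pd.to_datetime(["2020-01-01"]), [0.0, 0.5, 1.0], [0.0, 0.5, 1.0]],
    names=["time", "x1", "x2"],
)
ser = pd.Series(np.random.rand(len(index)), index=index, name="temperature")

# Sampling array of shape (2, N): row 0 holds x1 coords, row 1 holds x2 coords
X_c = np.array([[0.0, 0.5], [0.0, 1.0]])  # float64, same dtype as the index levels
Y_c = ser.loc[pd.IndexSlice[:, X_c[0], X_c[1]]].values  # exact coordinate lookup

# A float16 sampling array would now fail the new dtype check with
# InvalidSamplingStrategyError before any lookup is attempted
assert X_c.astype(np.float16).dtype != ser.index.get_level_values("x1").dtype
```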
{deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor/data/processor.py

@@ -97,7 +97,7 @@ class DataProcessor:
         self.verbose = verbose
 
         # List of valid normalisation method names
-        self.valid_methods = ["mean_std", "min_max"]
+        self.valid_methods = ["mean_std", "min_max", "positive_semidefinite"]
 
     def save(self, folder: str):
         """Save DataProcessor config to JSON in `folder`"""
@@ -293,6 +293,8 @@ class DataProcessor:
             params = {"mean": float(data.mean()), "std": float(data.std())}
         elif method == "min_max":
             params = {"min": float(data.min()), "max": float(data.max())}
+        elif method == "positive_semidefinite":
+            params = {"min": float(data.min()), "std": float(data.std())}
         if self.verbose:
             print(f"Done. {var_ID} {method} params={params}")
         self.add_to_config(
@@ -498,33 +500,25 @@ class DataProcessor:
 
         params = self.get_config(var_ID, data, method)
 
+        # Linear transformation:
+        # - Inverse normalisation: y_unnorm = m * y_norm + c
+        # - Inverse normalisation: y_norm = (1/m) * y_unnorm - c/m
         if method == "mean_std":
-
-
-            if unnorm:
-                scale = std
-                offset = mean
-            else:
-                scale = 1 / std
-                offset = -mean / std
-            data = data * scale
-            if add_offset:
-                data = data + offset
-            return data
-
+            m = params["std"]
+            c = params["mean"]
         elif method == "min_max":
-
-
-
-
-
-
-
-
-
-
+            m = (params["max"] - params["min"]) / 2
+            c = (params["max"] + params["min"]) / 2
+        elif method == "positive_semidefinite":
+            m = params["std"]
+            c = params["min"]
+        if not unnorm:
+            c = -c / m
+            m = 1 / m
+        data = data * m
+        if add_offset:
+            data = data + c
+        return data
 
     def map(
         self,
@@ -610,6 +604,7 @@ class DataProcessor:
             method (str, optional): Normalisation method. Options include:
                 - "mean_std": Normalise to mean=0 and std=1 (default)
                 - "min_max": Normalise to min=-1 and max=1
+                - "positive_semidefinite": Normalise to min=0 and std=1
 
         Returns:
             :class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | List[:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame`]:
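
The rewritten branch above folds all three normalisation methods into a single linear map, unnormalising with y = m * y_norm + c and normalising with y_norm = (1/m) * y - c/m. A small illustrative sketch of the (m, c) parameterisation, assuming params dictionaries shaped like the ones the DataProcessor stores (this is not package code):

```python
# Sketch of the per-method linear-map parameters
def linear_norm_params(method: str, params: dict) -> tuple:
    if method == "mean_std":
        m, c = params["std"], params["mean"]
    elif method == "min_max":  # maps [min, max] onto [-1, 1]
        m = (params["max"] - params["min"]) / 2
        c = (params["max"] + params["min"]) / 2
    elif method == "positive_semidefinite":  # maps the minimum to 0, scales by std
        m, c = params["std"], params["min"]
    else:
        raise ValueError(f"Unknown method {method}")
    return m, c

m, c = linear_norm_params("min_max", {"min": 10.0, "max": 30.0})
assert ((10.0 - c) / m, (30.0 - c) / m) == (-1.0, 1.0)  # endpoints map to -1 and 1
```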
{deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor/data/sources.py

@@ -277,7 +277,7 @@ def get_era5_reanalysis_data(
     if num_processes == 1:
         # Just download in one go
         if verbose:
-            print("Downloading ERA5 data
+            print("Downloading ERA5 data without parallelisation... ")
         era5_da = _get_era5_reanalysis_data_parallel(
             date_range=date_range,
             var_IDs=var_IDs,
@@ -432,7 +432,7 @@ def get_gldas_land_mask(
     with urllib.request.urlopen(req) as response:
         with open(fname, "wb") as f:
             f.write(response.read())
-    da = xr.open_dataset(fname)["GLDAS_mask"].isel(time=0).
+    da = xr.open_dataset(fname)["GLDAS_mask"].isel(time=0).drop_vars("time").load()
 
     if isinstance(extent, str):
         extent = extent_str_to_tuple(extent)
@@ -577,7 +577,7 @@ def get_earthenv_auxiliary_data(
     # Read data
     da = xr.open_dataset(fname).to_array().squeeze().load()
     da = da.rename({"y": "lat", "x": "lon"})
-    da = da.
+    da = da.drop_vars(["band", "spatial_ref", "variable"])
     da.name = var_ID
     da = da.sel(lat=slice(lat_max, lat_min), lon=slice(lon_min, lon_max))
     da_dict[var_ID] = da
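
The updated lines in this file use xarray's `drop_vars` to remove leftover coordinates and variables after selection. A minimal illustrative sketch of that call on a scalar coordinate left behind by `.isel(time=0)` (toy data, not package code):

```python
import numpy as np
import xarray as xr

da = xr.DataArray(
    np.zeros((2, 2)),
    dims=("lat", "lon"),
    coords={"time": np.datetime64("2020-01-01")},  # scalar coord, e.g. left by .isel(time=0)
)
da = da.drop_vars("time")  # removes the coord; raises KeyError if it is absent
assert "time" not in da.coords
```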
deepsensor-0.4.0/deepsensor/eval/__init__.py

@@ -0,0 +1 @@
+from .metrics import *
deepsensor-0.4.0/deepsensor/eval/metrics.py

@@ -0,0 +1,24 @@
+import xarray as xr
+from deepsensor.model.pred import Prediction
+
+
+def compute_errors(pred: Prediction, target: xr.Dataset) -> xr.Dataset:
+    """
+    Compute errors between predictions and targets.
+
+    Args:
+        pred: Prediction object.
+        target: Target data.
+
+    Returns:
+        xr.Dataset: Dataset of pointwise differences between predictions and targets
+            at the same valid time in the predictions. Note, the difference is positive
+            when the prediction is greater than the target.
+    """
+    errors = {}
+    for var_ID, pred_var in pred.items():
+        target_var = target[var_ID]
+        error = pred_var["mean"] - target_var.sel(time=pred_var.time)
+        error.name = f"{var_ID}"
+        errors[var_ID] = error
+    return xr.Dataset(errors)
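
`compute_errors` only relies on the prediction mapping variable IDs to Datasets that carry a "mean" variable and a "time" coordinate. Below is a self-contained sketch of its semantics using a plain dict as a stand-in for a Prediction; in real use the Prediction comes from `model.predict`, and the variable name and grid here are illustrative.

```python
import numpy as np
import pandas as pd
import xarray as xr
from deepsensor.eval.metrics import compute_errors  # new in 0.4.0

times = pd.date_range("2020-01-02", periods=3)
coords = {"time": times, "x1": [0.0, 0.5], "x2": [0.0, 0.5]}
target = xr.Dataset(
    {"2m_temp": (("time", "x1", "x2"), np.random.rand(3, 2, 2))}, coords=coords
)

# Stand-in for a Prediction: compute_errors only needs .items() to yield
# (var_ID, Dataset with "mean" and "time") pairs, so a dict illustrates the semantics.
pred_like = {"2m_temp": xr.Dataset({"mean": target["2m_temp"] + 1.0})}

errors = compute_errors(pred_like, target)
np.testing.assert_allclose(errors["2m_temp"].values, 1.0)  # prediction minus target
```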
{deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor/model/convnp.py

@@ -539,10 +539,10 @@ class ConvNP(DeepSensorModel):
     def alpha(
         self, dist: AbstractMultiOutputDistribution
     ) -> Union[np.ndarray, List[np.ndarray]]:
-        if self.config["likelihood"] not in ["spikes-beta"
+        if self.config["likelihood"] not in ["spikes-beta"]:
             raise NotImplementedError(
                 f"ConvNP.alpha method not supported for likelihood {self.config['likelihood']}. "
-                f"
+                f"Valid likelihoods: 'spikes-beta'."
             )
         alpha = dist.slab.alpha
         alpha = self._cast_numpy_and_squeeze(alpha)
@@ -576,10 +576,10 @@ class ConvNP(DeepSensorModel):
     def beta(
         self, dist: AbstractMultiOutputDistribution
     ) -> Union[np.ndarray, List[np.ndarray]]:
-        if self.config["likelihood"] not in ["spikes-beta"
+        if self.config["likelihood"] not in ["spikes-beta"]:
             raise NotImplementedError(
                 f"ConvNP.beta method not supported for likelihood {self.config['likelihood']}. "
-                f"
+                f"Valid likelihoods: 'spikes-beta'."
             )
         beta = dist.slab.beta
         beta = self._cast_numpy_and_squeeze(beta)
@@ -608,6 +608,80 @@ class ConvNP(DeepSensorModel):
         dist = self(task)
         return self.beta(dist)
 
+    @dispatch
+    def k(
+        self, dist: AbstractMultiOutputDistribution
+    ) -> Union[np.ndarray, List[np.ndarray]]:
+        if self.config["likelihood"] not in ["bernoulli-gamma"]:
+            raise NotImplementedError(
+                f"ConvNP.k method not supported for likelihood {self.config['likelihood']}. "
+                f"Valid likelihoods: 'bernoulli-gamma'."
+            )
+        k = dist.slab.k
+        k = self._cast_numpy_and_squeeze(k)
+        return self._maybe_concat_multi_targets(k)
+
+    @dispatch
+    def k(self, task: Task) -> Union[np.ndarray, List[np.ndarray]]:
+        """
+        k parameter values of model's distribution at target locations in task.
+
+        Returned numpy arrays have shape ``(N_features, *N_targets)``.
+
+        .. note::
+            This method only works for models that return a distribution with
+            a ``dist.slab.k`` attribute, e.g. models with a Beta or
+            Bernoulli-Gamma likelihood, where it returns the k values of
+            the slab component of the mixture model.
+
+        Args:
+            task (:class:`~.data.task.Task`):
+                The task containing the context and target data.
+
+        Returns:
+            :class:`numpy:numpy.ndarray` | List[:class:`numpy:numpy.ndarray`]:
+                k values.
+        """
+        dist = self(task)
+        return self.k(dist)
+
+    @dispatch
+    def scale(
+        self, dist: AbstractMultiOutputDistribution
+    ) -> Union[np.ndarray, List[np.ndarray]]:
+        if self.config["likelihood"] not in ["bernoulli-gamma"]:
+            raise NotImplementedError(
+                f"ConvNP.scale method not supported for likelihood {self.config['likelihood']}. "
+                f"Valid likelihoods: 'bernoulli-gamma'."
+            )
+        scale = dist.slab.scale
+        scale = self._cast_numpy_and_squeeze(scale)
+        return self._maybe_concat_multi_targets(scale)
+
+    @dispatch
+    def scale(self, task: Task) -> Union[np.ndarray, List[np.ndarray]]:
+        """
+        Scale parameter values of model's distribution at target locations in task.
+
+        Returned numpy arrays have shape ``(N_features, *N_targets)``.
+
+        .. note::
+            This method only works for models that return a distribution with
+            a ``dist.slab.scale`` attribute, e.g. models with a Beta or
+            Bernoulli-Gamma likelihood, where it returns the scale values of
+            the slab component of the mixture model.
+
+        Args:
+            task (:class:`~.data.task.Task`):
+                The task containing the context and target data.
+
+        Returns:
+            :class:`numpy:numpy.ndarray` | List[:class:`numpy:numpy.ndarray`]:
+                Scale values.
+        """
+        dist = self(task)
+        return self.scale(dist)
+
     @dispatch
     def mixture_probs(self, dist: AbstractMultiOutputDistribution):
         if self.N_mixture_components == 1:
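
The new `ConvNP.k` and `ConvNP.scale` accessors mirror the existing `alpha`/`beta` methods, but for the Gamma slab of the "bernoulli-gamma" likelihood. Below is a hedged end-to-end sketch modelled on the package's own tests; the backend import, data sizes, date, and variable name are assumptions for illustration, not taken from this diff.

```python
import numpy as np
import pandas as pd
import xarray as xr
import deepsensor.torch  # noqa: F401 -- assumes the PyTorch backend is installed
from deepsensor.data.processor import DataProcessor
from deepsensor.data.loader import TaskLoader
from deepsensor.model.convnp import ConvNP

# Tiny synthetic dataset on normalised coordinates, mirroring the package's tests
da = xr.DataArray(
    np.random.rand(5, 30, 30),
    dims=("time", "x1", "x2"),
    coords={
        "time": pd.date_range("2020-01-01", periods=5),
        "x1": np.linspace(0, 1, 30),
        "x2": np.linspace(0, 1, 30),
    },
    name="precip",
)
dp = DataProcessor()
_ = dp(da)  # compute normalisation parameters

tl = TaskLoader(context=da, target=da)
model = ConvNP(dp, tl, unet_channels=(5, 5, 5), verbose=False, likelihood="bernoulli-gamma")

task = tl("2020-01-03", context_sampling=10, target_sampling=10)
k = model.k(task)                  # Gamma shape parameters at target locations
scale = model.scale(task)          # Gamma scale parameters at target locations
probs = model.mixture_probs(task)  # Bernoulli/Gamma mixture weights
```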
{deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor/model/model.py

@@ -348,6 +348,34 @@ class DeepSensorModel(ProbabilisticModel):
         if ar_sample and n_samples < 1:
             raise ValueError("Must pass `n_samples` > 0 to use `ar_sample`.")
 
+        target_delta_t = self.task_loader.target_delta_t
+        dts = [pd.Timedelta(dt) for dt in target_delta_t]
+        dts_all_zero = all([dt == pd.Timedelta(seconds=0) for dt in dts])
+        if target_delta_t is not None and dts_all_zero:
+            forecasting_mode = False
+            lead_times = None
+        elif target_delta_t is not None and not dts_all_zero:
+            target_var_IDs_set = set(self.task_loader.target_var_IDs)
+            msg = f"""
+            Got more than one set of target variables in target sets,
+            but predictions can only be made with one set of target variables
+            to simplify implementation.
+            Got {target_var_IDs_set}.
+            """
+            assert len(target_var_IDs_set) == 1, msg
+            # Repeat lead_tim for each variable in each target set
+            lead_times = []
+            for target_set_idx, dt in enumerate(target_delta_t):
+                target_set_dim = self.task_loader.target_dims[target_set_idx]
+                lead_times += [
+                    pd.Timedelta(dt, unit=self.task_loader.time_freq)
+                    for _ in range(target_set_dim)
+                ]
+            forecasting_mode = True
+        else:
+            forecasting_mode = False
+            lead_times = None
+
         if type(tasks) is Task:
             tasks = [tasks]
 
@@ -355,12 +383,14 @@ class DeepSensorModel(ProbabilisticModel):
         B.set_random_seed(seed)
         np.random.seed(seed)
 
-
+        init_dates = [task["time"] for task in tasks]
 
         # Flatten tuple of tuples to single list
         target_var_IDs = [
             var_ID for set in self.task_loader.target_var_IDs for var_ID in set
         ]
+        if lead_times is not None:
+            assert len(lead_times) == len(target_var_IDs)
 
         # TODO consider removing this logic, can we just depend on the dim names in X_t?
         if not unnormalise:
@@ -385,7 +415,7 @@ class DeepSensorModel(ProbabilisticModel):
         elif isinstance(X_t, (xr.DataArray, xr.Dataset)):
             # Remove time dimension if present
             if "time" in X_t.coords:
-                X_t = X_t.isel(time=0).
+                X_t = X_t.isel(time=0).drop_vars("time")
 
         if mode == "off-grid" and append_indexes is not None:
             # Check append_indexes are all same length as X_t
@@ -450,11 +480,13 @@ class DeepSensorModel(ProbabilisticModel):
         pred = Prediction(
             target_var_IDs,
             pred_params_to_store,
-
+            init_dates,
             X_t,
             X_t_mask,
             coord_names,
             n_samples=n_samples,
+            forecasting_mode=forecasting_mode,
+            lead_times=lead_times,
         )
 
         def unnormalise_pred_array(arr, **kwargs):
@@ -605,14 +637,22 @@ class DeepSensorModel(ProbabilisticModel):
             # Assign predictions to Prediction object
             for param, arr in prediction_arrs.items():
                 if param != "mixture_probs":
-                    pred.assign(param, task["time"], arr)
+                    pred.assign(param, task["time"], arr, lead_times=lead_times)
                 elif param == "mixture_probs":
                     assert arr.shape[0] == self.N_mixture_components, (
                         f"Number of mixture components ({arr.shape[0]}) does not match "
                         f"model attribute N_mixture_components ({self.N_mixture_components})."
                     )
                     for component_i, probs in enumerate(arr):
-                        pred.assign(
+                        pred.assign(
+                            f"{param}_{component_i}",
+                            task["time"],
+                            probs,
+                            lead_times=lead_times,
+                        )
+
+        if forecasting_mode:
+            pred = add_valid_time_coord_to_pred(pred)
 
         if verbose:
             dur = time.time() - tic
@@ -621,6 +661,35 @@ class DeepSensorModel(ProbabilisticModel):
         return pred
 
 
+def add_valid_time_coord_to_pred(pred: Prediction) -> Prediction:
+    """
+    Add a valid time coordinate "time" to a Prediction object based on the
+    initialisation times "init_time" and lead times "lead_time".
+
+    Args:
+        pred (:class:`~.model.pred.Prediction`):
+            Prediction object to add valid time coordinate to.
+
+    Returns:
+        :class:`~.model.pred.Prediction`:
+            Prediction object with valid time coordinate added.
+    """
+    for var_ID in pred.keys():
+        if isinstance(pred[var_ID], pd.DataFrame):
+            x = pred[var_ID].reset_index()
+            pred[var_ID]["time"] = (x["lead_time"] + x["init_time"]).values
+            print(f"{x}")
+            print(f"{x.dtypes}")
+        elif isinstance(pred[var_ID], xr.Dataset):
+            x = pred[var_ID]
+            pred[var_ID] = pred[var_ID].assign_coords(
+                time=x["lead_time"] + x["init_time"]
+            )
+        else:
+            raise ValueError(f"Unsupported prediction type {type(pred[var_ID])}.")
+    return pred
+
+
 def main():  # pragma: no cover
     import deepsensor.tensorflow
     from deepsensor.data.loader import TaskLoader
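
With the changes above, `DeepSensorModel.predict` switches into forecasting mode whenever the TaskLoader's `target_delta_t` values are non-zero, and returns predictions indexed by `lead_time` and `init_time` with a derived valid-time coordinate `time`. Below is a hedged sketch of that workflow, modelled on the new test in this release; the backend import, data, and the printed dimension order are assumptions for illustration rather than guarantees.

```python
import numpy as np
import pandas as pd
import xarray as xr
import deepsensor.torch  # noqa: F401 -- assumes the PyTorch backend is installed
from deepsensor.data.processor import DataProcessor
from deepsensor.data.loader import TaskLoader
from deepsensor.model.convnp import ConvNP

# Synthetic daily data on normalised coordinates (names/sizes illustrative)
da = xr.DataArray(
    np.random.rand(10, 20, 20),
    dims=("time", "x1", "x2"),
    coords={
        "time": pd.date_range("2020-01-01", periods=10),
        "x1": np.linspace(0, 1, 20),
        "x2": np.linspace(0, 1, 20),
    },
    name="2m_temp",
)
dp = DataProcessor()
_ = dp(da)  # compute normalisation parameters

# One target set per lead time puts .predict into forecasting mode
lead_times_days = [1, 2, 3]
tl = TaskLoader(
    context=da,
    target=[da] * len(lead_times_days),
    target_delta_t=lead_times_days,
    time_freq="D",
)
model = ConvNP(dp, tl, unet_channels=(5, 5, 5), verbose=False)

tasks = tl(["2020-01-01", "2020-01-02"], context_sampling=10)
pred = model.predict(tasks, X_t=da)  # gridded prediction -> xr.Dataset per variable

print(pred["2m_temp"]["mean"].dims)   # expected: ('lead_time', 'init_time', 'x1', 'x2')
print(pred["2m_temp"]["time"].shape)  # valid times = init_time + lead_time
```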
{deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor/model/pred.py

@@ -4,6 +4,8 @@ import numpy as np
 import pandas as pd
 import xarray as xr
 
+Timestamp = Union[str, pd.Timestamp, np.datetime64]
+
 
 class Prediction(dict):
     """
@@ -32,13 +34,20 @@ class Prediction(dict):
         n_samples (int)
             Number of joint samples to draw from the model. If 0, will not
             draw samples. Default 0.
+        forecasting_mode (bool)
+            If True, stored forecast predictions with an init_time and lead_time dimension,
+            and a valid_time coordinate. If False, stores prediction at t=0 only
+            (i.e. spatial interpolation), with only a single time dimension. Default False.
+        lead_times (List[pd.Timedelta], optional)
+            List of lead times to store in predictions. Must be provided if
+            forecasting_mode is True. Default None.
     """
 
     def __init__(
         self,
         target_var_IDs: List[str],
         pred_params: List[str],
-        dates: List[
+        dates: List[Timestamp],
         X_t: Union[
             xr.Dataset,
             xr.DataArray,
@@ -50,6 +59,8 @@ class Prediction(dict):
         X_t_mask: Optional[Union[xr.Dataset, xr.DataArray]] = None,
         coord_names: dict = None,
         n_samples: int = 0,
+        forecasting_mode: bool = False,
+        lead_times: Optional[List[pd.Timedelta]] = None,
     ):
         self.target_var_IDs = target_var_IDs
         self.X_t_mask = X_t_mask
@@ -58,6 +69,13 @@ class Prediction(dict):
         self.x1_name = coord_names["x1"]
         self.x2_name = coord_names["x2"]
 
+        self.forecasting_mode = forecasting_mode
+        if forecasting_mode:
+            assert (
+                lead_times is not None
+            ), "If forecasting_mode is True, lead_times must be provided."
+        self.lead_times = lead_times
+
         self.mode = infer_prediction_modality_from_X_t(X_t)
 
         self.pred_params = pred_params
@@ -67,15 +85,25 @@ class Prediction(dict):
             *[f"sample_{i}" for i in range(n_samples)],
         ]
 
+        # Create empty xarray/pandas objects to store predictions
         if self.mode == "on-grid":
             for var_ID in self.target_var_IDs:
-
+                if self.forecasting_mode:
+                    prepend_dims = ["lead_time"]
+                    prepend_coords = {"lead_time": lead_times}
+                else:
+                    prepend_dims = None
+                    prepend_coords = None
                 self[var_ID] = create_empty_spatiotemporal_xarray(
                     X_t,
                     dates,
                     data_vars=self.pred_params,
                     coord_names=coord_names,
+                    prepend_dims=prepend_dims,
+                    prepend_coords=prepend_coords,
                 )
+                if self.forecasting_mode:
+                    self[var_ID] = self[var_ID].rename(time="init_time")
             if self.X_t_mask is None:
                 # Create 2D boolean array of True values to simplify indexing
                 self.X_t_mask = (
@@ -86,8 +114,18 @@ class Prediction(dict):
                 )
         elif self.mode == "off-grid":
             # Repeat target locs for each date to create multiindex
-
-
+            if self.forecasting_mode:
+                index_names = ["lead_time", "init_time", *X_t.index.names]
+                idxs = [
+                    (lt, date, *idxs)
+                    for lt in lead_times
+                    for date in dates
+                    for idxs in X_t.index
+                ]
+            else:
+                index_names = ["time", *X_t.index.names]
+                idxs = [(date, *idxs) for date in dates for idxs in X_t.index]
+            index = pd.MultiIndex.from_tuples(idxs, names=index_names)
             for var_ID in self.target_var_IDs:
                 self[var_ID] = pd.DataFrame(index=index, columns=self.pred_params)
 
@@ -106,6 +144,7 @@ class Prediction(dict):
         prediction_parameter: str,
         date: Union[str, pd.Timestamp],
         data: np.ndarray,
+        lead_times: Optional[List[pd.Timedelta]] = None,
     ):
         """
 
@@ -117,11 +156,29 @@ class Prediction(dict):
             data (np.ndarray)
                 If off-grid: Shape (N_var, N_targets) or (N_samples, N_var, N_targets).
                 If on-grid: Shape (N_var, N_x1, N_x2) or (N_samples, N_var, N_x1, N_x2).
+            lead_time (pd.Timedelta, optional)
+                Lead time of the forecast. Required if forecasting_mode is True. Default None.
         """
+        if self.forecasting_mode:
+            assert (
+                lead_times is not None
+            ), "If forecasting_mode is True, lead_times must be provided."
+
+            msg = f"""
+            If forecasting_mode is True, lead_times must be of equal length to the number of
+            variables in the data (the first dimension). Got {lead_times=} of length
+            {len(lead_times)} lead times and data shape {data.shape}.
+            """
+            assert len(lead_times) == data.shape[0], msg
+
         if self.mode == "on-grid":
             if prediction_parameter != "samples":
-                for var_ID, pred in zip(self.target_var_IDs, data):
-                    self
+                for i, (var_ID, pred) in enumerate(zip(self.target_var_IDs, data)):
+                    if self.forecasting_mode:
+                        index = (lead_times[i], date)
+                    else:
+                        index = date
+                    self[var_ID][prediction_parameter].loc[index].data[
                         self.X_t_mask.data
                     ] = pred.ravel()
             elif prediction_parameter == "samples":
@@ -130,28 +187,44 @@ class Prediction(dict):
                     f"have shape (N_samples, N_var, N_x1, N_x2). Got {data.shape}."
                 )
                 for sample_i, sample in enumerate(data):
-                    for var_ID, pred in
-                        self
+                    for i, (var_ID, pred) in enumerate(
+                        zip(self.target_var_IDs, sample)
+                    ):
+                        if self.forecasting_mode:
+                            index = (lead_times[i], date)
+                        else:
+                            index = date
+                        self[var_ID][f"sample_{sample_i}"].loc[index].data[
                             self.X_t_mask.data
                         ] = pred.ravel()
 
         elif self.mode == "off-grid":
             if prediction_parameter != "samples":
-                for var_ID, pred in zip(self.target_var_IDs, data):
-                    self
+                for i, (var_ID, pred) in enumerate(zip(self.target_var_IDs, data)):
+                    if self.forecasting_mode:
+                        index = (lead_times[i], date)
+                    else:
+                        index = date
+                    self[var_ID][prediction_parameter].loc[index] = pred
             elif prediction_parameter == "samples":
                 assert len(data.shape) == 3, (
                     f"If prediction_parameter is 'samples', and mode is 'off-grid', data must"
                     f"have shape (N_samples, N_var, N_targets). Got {data.shape}."
                 )
                 for sample_i, sample in enumerate(data):
-                    for var_ID, pred in
-                        self
+                    for i, (var_ID, pred) in enumerate(
+                        zip(self.target_var_IDs, sample)
+                    ):
+                        if self.forecasting_mode:
+                            index = (lead_times[i], date)
+                        else:
+                            index = date
+                        self[var_ID][f"sample_{sample_i}"].loc[index] = pred
 
 
 def create_empty_spatiotemporal_xarray(
     X: Union[xr.Dataset, xr.DataArray],
-    dates: List,
+    dates: List[Timestamp],
     coord_names: dict = None,
     data_vars: List[str] = None,
     prepend_dims: Optional[List[str]] = None,
@@ -231,10 +304,6 @@ def create_empty_spatiotemporal_xarray(
     # Convert time coord to pandas timestamps
     pred_ds = pred_ds.assign_coords(time=pd.to_datetime(pred_ds.time.values))
 
-    # TODO: Convert init time to forecast time?
-    # pred_ds = pred_ds.assign_coords(
-    #     time=pred_ds['time'] + pd.Timedelta(days=task_loader.target_delta_t[0]))
-
     return pred_ds
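
In forecasting mode the off-grid Prediction is backed by a pandas MultiIndex over (lead_time, init_time, x1, x2), and the valid time is simply init_time + lead_time. A small illustrative sketch of that index layout, with toy values rather than package data, mirroring how `add_valid_time_coord_to_pred` recovers the valid time for DataFrames:

```python
import pandas as pd

# Illustrative index layout for an off-grid forecast Prediction
lead_times = [pd.Timedelta(d, unit="D") for d in (1, 2)]
init_times = pd.to_datetime(["2020-01-01", "2020-01-02"])
locs = [(0.0, 0.0), (0.5, 0.5)]

idxs = [(lt, t0, *loc) for lt in lead_times for t0 in init_times for loc in locs]
index = pd.MultiIndex.from_tuples(idxs, names=["lead_time", "init_time", "x1", "x2"])
df = pd.DataFrame(index=index, columns=["mean", "std"])

# Valid time per row: init_time shifted forward by the lead time
df["time"] = (
    df.index.get_level_values("lead_time") + df.index.get_level_values("init_time")
)
```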
{deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor.egg-info/PKG-INFO

@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: deepsensor
-Version: 0.3.7
+Version: 0.4.0
 Summary: A Python package for modelling xarray and pandas data with neural processes.
 Home-page: https://github.com/alan-turing-institute/deepsensor
 Author: Tom R. Andersson
@@ -44,7 +44,7 @@ data with neural processes</p>
 
 -----------
 
-[](https://github.com/alan-turing-institute/deepsensor/releases)
+[](https://github.com/alan-turing-institute/deepsensor/releases)
 [](https://alan-turing-institute.github.io/deepsensor/)
 
 [](https://coveralls.io/github/alan-turing-institute/deepsensor?branch=main)
@@ -240,7 +240,7 @@ if you would like to join this list!
 <table>
   <tbody>
     <tr>
-      <td align="center" valign="top" width="14.28%"><a href="https://github.com/acocac"><img src="https://avatars.githubusercontent.com/u/13321552?v=4?s=100" width="100px;" alt="Alejandro ©"/><br /><sub><b>Alejandro ©</b></sub></a><br /><a href="#userTesting-acocac" title="User Testing">📓</a> <a href="#bug-acocac" title="Bug reports">🐛</a> <a href="#mentoring-acocac" title="Mentoring">🧑🏫</a> <a href="#ideas-acocac" title="Ideas, Planning, & Feedback">🤔</a> <a href="#research-acocac" title="Research">🔬</a></td>
+      <td align="center" valign="top" width="14.28%"><a href="https://github.com/acocac"><img src="https://avatars.githubusercontent.com/u/13321552?v=4?s=100" width="100px;" alt="Alejandro ©"/><br /><sub><b>Alejandro ©</b></sub></a><br /><a href="#userTesting-acocac" title="User Testing">📓</a> <a href="#bug-acocac" title="Bug reports">🐛</a> <a href="#mentoring-acocac" title="Mentoring">🧑🏫</a> <a href="#ideas-acocac" title="Ideas, Planning, & Feedback">🤔</a> <a href="#research-acocac" title="Research">🔬</a> <a href="#code-acocac" title="Code">💻</a> <a href="#test-acocac" title="Tests">⚠️</a></td>
      <td align="center" valign="top" width="14.28%"><a href="https://github.com/annavaughan"><img src="https://avatars.githubusercontent.com/u/45528489?v=4?s=100" width="100px;" alt="Anna Vaughan"/><br /><sub><b>Anna Vaughan</b></sub></a><br /><a href="#research-annavaughan" title="Research">🔬</a></td>
      <td align="center" valign="top" width="14.28%"><a href="http://davidwilby.dev"><img src="https://avatars.githubusercontent.com/u/24752124?v=4?s=100" width="100px;" alt="David Wilby"/><br /><sub><b>David Wilby</b></sub></a><br /><a href="#doc-davidwilby" title="Documentation">📖</a> <a href="#test-davidwilby" title="Tests">⚠️</a> <a href="#maintenance-davidwilby" title="Maintenance">🚧</a></td>
      <td align="center" valign="top" width="14.28%"><a href="http://inconsistentrecords.co.uk"><img src="https://avatars.githubusercontent.com/u/731727?v=4?s=100" width="100px;" alt="Jim Circadian"/><br /><sub><b>Jim Circadian</b></sub></a><br /><a href="#ideas-JimCircadian" title="Ideas, Planning, & Feedback">🤔</a> <a href="#projectManagement-JimCircadian" title="Project Management">📆</a> <a href="#maintenance-JimCircadian" title="Maintenance">🚧</a></td>
{deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor.egg-info/SOURCES.txt

@@ -22,6 +22,8 @@ deepsensor/data/processor.py
 deepsensor/data/sources.py
 deepsensor/data/task.py
 deepsensor/data/utils.py
+deepsensor/eval/__init__.py
+deepsensor/eval/metrics.py
 deepsensor/model/__init__.py
 deepsensor/model/convnp.py
 deepsensor/model/defaults.py
{deepsensor-0.3.7 → deepsensor-0.4.0}/tests/test_active_learning.py

@@ -26,9 +26,6 @@ from deepsensor.data.processor import DataProcessor, xarray_to_coord_array_normalised
 from deepsensor.model.convnp import ConvNP
 
 
-# from deepsensor.active_learning.acquisition_fns import
-
-
 class TestActiveLearning(unittest.TestCase):
 
     @classmethod
{deepsensor-0.3.7 → deepsensor-0.4.0}/tests/test_model.py

@@ -18,6 +18,7 @@ from deepsensor.data.processor import DataProcessor
 from deepsensor.data.loader import TaskLoader
 from deepsensor.model.convnp import ConvNP
 from deepsensor.train.train import Trainer
+from deepsensor.eval.metrics import compute_errors
 
 from tests.utils import gen_random_data_xr, gen_random_data_pandas
 
@@ -55,8 +56,15 @@ class TestModel(unittest.TestCase):
     def setUpClass(cls):
         # super().__init__(*args, **kwargs)
         # It's safe to share data between tests because the TaskLoader does not modify data
+        cls.var_ID = "2m_temp"
         cls.da = _gen_data_xr()
+        cls.da.name = cls.var_ID
         cls.df = _gen_data_pandas()
+        cls.df.name = cls.var_ID
+        # Various tests assume we have a single target set with a single variable.
+        # If a test requires multiple target sets or variables, this is set up in the test.
+        assert isinstance(cls.da, xr.DataArray)
+        assert isinstance(cls.df, pd.Series)
 
         cls.dp = DataProcessor()
         _ = cls.dp([cls.da, cls.df])  # Compute normalisation parameters
@@ -193,7 +201,7 @@ class TestModel(unittest.TestCase):
                     n_targets * dim_y_combined * n_target_dims,
                 ),
             )
-        if likelihood in ["cnp-spikes-beta"]:
+        if likelihood in ["cnp-spikes-beta", "bernoulli-gamma"]:
             mixture_probs = model.mixture_probs(task)
             if isinstance(mixture_probs, (list, tuple)):
                 for p, dim_y in zip(mixture_probs, tl.target_dims):
@@ -215,6 +223,7 @@ class TestModel(unittest.TestCase):
                 ),
             )
 
+        if likelihood in ["cnp-spikes-beta"]:
             x = model.alpha(task)
             if isinstance(x, (list, tuple)):
                 for p, dim_y in zip(x, tl.target_dims):
@@ -229,6 +238,21 @@ class TestModel(unittest.TestCase):
             else:
                 assert_shape(x, (dim_y_combined, *expected_obs_shape))
 
+        if likelihood in ["bernoulli-gamma"]:
+            x = model.k(task)
+            if isinstance(x, (list, tuple)):
+                for p, dim_y in zip(x, tl.target_dims):
+                    assert_shape(p, (dim_y, *expected_obs_shape))
+            else:
+                assert_shape(x, (dim_y_combined, *expected_obs_shape))
+
+            x = model.scale(task)
+            if isinstance(x, (list, tuple)):
+                for p, dim_y in zip(x, tl.target_dims):
+                    assert_shape(p, (dim_y, *expected_obs_shape))
+            else:
+                assert_shape(x, (dim_y_combined, *expected_obs_shape))
+
         # Scalars
         if likelihood in ["cnp", "gnp"]:
             # Methods for Gaussian likelihoods only
@@ -401,10 +425,10 @@ class TestModel(unittest.TestCase):
         task = tl("2020-01-01")
         pred = model.predict(task, X_t=da_raw)
 
-
+        np.testing.assert_array_equal(
            pred["dummy_data"]["mean"]["latitude"], da_raw["latitude"]
        )
-
+        np.testing.assert_array_equal(
            pred["dummy_data"]["mean"]["longitude"], da_raw["longitude"]
        )
 
@@ -451,61 +475,75 @@ class TestModel(unittest.TestCase):
     def test_highlevel_predict_with_pred_params_pandas(self):
         """
         Test that passing ``pred_params`` to ``.predict`` works with
-
+        mixture model likelihoods for off-grid prediction to pandas.
         """
         tl = TaskLoader(context=self.da, target=self.da)
-        model = ConvNP(
-            self.dp,
-            tl,
-            unet_channels=(5, 5, 5),
-            verbose=False,
-            likelihood="cnp-spikes-beta",
-        )
-        task = tl("2020-01-01", context_sampling=10, target_sampling=10)
 
-
-
+        likelihoods = ["cnp-spikes-beta", "bernoulli-gamma"]
+        expected_pred_params = [
+            ["mean", "std", "variance", "alpha", "beta"],
+            ["mean", "std", "variance", "k", "scale"],
+        ]
+
+        for likelihood, pred_params in zip(likelihoods, expected_pred_params):
+            model = ConvNP(
+                self.dp,
+                tl,
+                unet_channels=(5, 5, 5),
+                verbose=False,
+                likelihood=likelihood,
+            )
+            task = tl("2020-01-01", context_sampling=10)
+
+            # Off-grid prediction
+            X_t = np.array([[0.0, 0.5, 1.0], [0.0, 0.5, 1.0]])
 
-
-
-
-
+            # Check that nothing breaks and the correct parameters are returned
+            pred = model.predict(task, X_t=X_t, pred_params=pred_params)
+            for pred_param in pred_params:
-                assert pred_param in pred["var"]
+                assert pred_param in pred[self.var_ID]
 
-
-
-
-
-
-
+            # Test mixture probs special case
+            pred_params = ["mixture_probs"]
+            pred = model.predict(task, X_t=self.da, pred_params=pred_params)
+            for component in range(model.N_mixture_components):
+                pred_param = f"mixture_probs_{component}"
+                assert pred_param in pred[self.var_ID]
 
     def test_highlevel_predict_with_pred_params_xarray(self):
         """
         Test that passing ``pred_params`` to ``.predict`` works with
-
+        mixture model likelihoods for gridded prediction to xarray.
         """
         tl = TaskLoader(context=self.da, target=self.da)
-        model = ConvNP(
-            self.dp,
-            tl,
-            unet_channels=(5, 5, 5),
-            verbose=False,
-            likelihood="cnp-spikes-beta",
-        )
-        task = tl("2020-01-01", context_sampling=10, target_sampling=10)
 
-
-
-
-
-
+        likelihoods = ["cnp-spikes-beta", "bernoulli-gamma"]
+        expected_pred_params = [
+            ["mean", "std", "variance", "alpha", "beta"],
+            ["mean", "std", "variance", "k", "scale"],
+        ]
+
+        for likelihood, pred_params in zip(likelihoods, expected_pred_params):
+            model = ConvNP(
+                self.dp,
+                tl,
+                unet_channels=(5, 5, 5),
+                verbose=False,
+                likelihood=likelihood,
+            )
+            task = tl("2020-01-01", context_sampling=10)
+
+            # Check that nothing breaks and the correct parameters are returned
+            pred = model.predict(task, X_t=self.da, pred_params=pred_params)
+            for pred_param in pred_params:
+                assert pred_param in pred[self.var_ID]
 
-
-
-
-
-
-
+            # Test mixture probs special case
+            pred_params = ["mixture_probs"]
+            pred = model.predict(task, X_t=self.da, pred_params=pred_params)
+            for component in range(model.N_mixture_components):
+                pred_param = f"mixture_probs_{component}"
+                assert pred_param in pred[self.var_ID]
 
     def test_highlevel_predict_with_invalid_pred_params(self):
         """Test that passing ``pred_params`` to ``.predict`` works."""
@@ -610,6 +648,66 @@ class TestModel(unittest.TestCase):
             ar_sample=True,
         )
 
+    def test_forecasting_model_predict_return_valid_times(self):
+        """Test that the times returned by a forecasting model are valid."""
+        init_dates = ["2020-01-01", "2020-01-02"]
+        expected_init_times = np.array(init_dates).astype(np.datetime64)
+
+        lead_times_days = [1, 2, 3]
+        expected_lead_times = np.array(
+            [np.timedelta64(lt, "D") for lt in lead_times_days]
+        )
+
+        expected_valid_times = np.array(
+            expected_lead_times[:, None] + expected_init_times[None, :]
+        )
+
+        tl = TaskLoader(
+            context=self.da,
+            target=[
+                self.da,
+            ]
+            * len(lead_times_days),
+            target_delta_t=lead_times_days,
+            time_freq="D",
+        )
+        model = ConvNP(self.dp, tl, unet_channels=(5, 5, 5), verbose=False)
+        tasks = tl(init_dates, context_sampling=10)
+
+        X_ts = [
+            # Gridded predictions (xarray)
+            self.da,
+            # Off-grid prediction (pandas)
+            np.array([[0.0, 0.5, 1.0], [0.0, 0.5, 1.0]]),
+        ]
+        for X_t in X_ts:
+            pred = model.predict(tasks, X_t=X_t)
+
+            pred_var = pred[self.var_ID]
+
+            if isinstance(pred_var, xr.Dataset):
+                # Check we can compute errors using the valid time coord ('time')
+                errors = compute_errors(pred, self.da.to_dataset())
+                for var_ID in errors.keys():
+                    assert tuple(errors[var_ID].dims) == (
+                        "lead_time",
+                        "init_time",
+                        "x1",
+                        "x2",
+                    )
+                    assert errors[var_ID].shape == pred[var_ID]["mean"].shape
+            elif isinstance(pred_var, pd.DataFrame):
+                # Makes coordinate checking easier by avoiding repeat values
+                pred_var = pred_var.to_xarray().isel(x1=0, x2=0)
+
+            np.testing.assert_array_equal(
+                pred_var.lead_time.values, expected_lead_times
+            )
+            np.testing.assert_array_equal(
+                pred_var.init_time.values, expected_init_times
+            )
+            np.testing.assert_array_equal(pred_var.time.values, expected_valid_times)
+
 
 def assert_shape(x, shape: tuple):
     """Assert that the shape of ``x`` matches ``shape``."""
{deepsensor-0.3.7 → deepsensor-0.4.0}/tests/test_task_loader.py

@@ -192,6 +192,23 @@ class TestTaskLoader(unittest.TestCase):
         with self.assertRaises(InvalidSamplingStrategyError):
             task = tl("2020-01-01", invalid_sampling_strategy)
 
+    def test_different_dtype_when_sampling_offgrid_data_at_specific_numpy_locs(self):
+        """Test different dtype when sampling off-grid data at specific numpy locations."""
+        sampling_strat = np.array(
+            [np.linspace(0, 1, 10), np.linspace(0, 1, 10)], dtype=np.float16
+        )
+
+        tl = TaskLoader(
+            context=self.df,
+            target=self.df,
+        )
+
+        assert sampling_strat.dtype != tl.context[0].index.get_level_values("x1").dtype
+        assert sampling_strat.dtype != tl.context[0].index.get_level_values("x2").dtype
+
+        with self.assertRaises(InvalidSamplingStrategyError):
+            task = tl("2020-01-01", sampling_strat, sampling_strat)
+
     def test_wrong_links(self):
         """Test link indexes out of range."""
         with self.assertRaises(ValueError):
The remaining files listed above with +0 -0 are unchanged between 0.3.7 and 0.4.0.