deepsensor 0.3.8__tar.gz → 0.4.1__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {deepsensor-0.3.8 → deepsensor-0.4.1}/PKG-INFO +3 -3
- {deepsensor-0.3.8 → deepsensor-0.4.1}/README.md +2 -2
- {deepsensor-0.3.8 → deepsensor-0.4.1}/deepsensor/data/loader.py +17 -13
- {deepsensor-0.3.8 → deepsensor-0.4.1}/deepsensor/data/sources.py +3 -3
- deepsensor-0.4.1/deepsensor/eval/__init__.py +1 -0
- deepsensor-0.4.1/deepsensor/eval/metrics.py +24 -0
- {deepsensor-0.3.8 → deepsensor-0.4.1}/deepsensor/model/model.py +76 -5
- {deepsensor-0.3.8 → deepsensor-0.4.1}/deepsensor/model/pred.py +86 -17
- {deepsensor-0.3.8 → deepsensor-0.4.1}/deepsensor/plot.py +4 -0
- {deepsensor-0.3.8 → deepsensor-0.4.1}/deepsensor.egg-info/PKG-INFO +3 -3
- {deepsensor-0.3.8 → deepsensor-0.4.1}/deepsensor.egg-info/SOURCES.txt +2 -0
- {deepsensor-0.3.8 → deepsensor-0.4.1}/setup.cfg +1 -1
- {deepsensor-0.3.8 → deepsensor-0.4.1}/tests/test_model.py +74 -6
- {deepsensor-0.3.8 → deepsensor-0.4.1}/tests/test_task_loader.py +17 -0
- {deepsensor-0.3.8 → deepsensor-0.4.1}/deepsensor/__init__.py +0 -0
- {deepsensor-0.3.8 → deepsensor-0.4.1}/deepsensor/active_learning/__init__.py +0 -0
- {deepsensor-0.3.8 → deepsensor-0.4.1}/deepsensor/active_learning/acquisition_fns.py +0 -0
- {deepsensor-0.3.8 → deepsensor-0.4.1}/deepsensor/active_learning/algorithms.py +0 -0
- {deepsensor-0.3.8 → deepsensor-0.4.1}/deepsensor/config.py +0 -0
- {deepsensor-0.3.8 → deepsensor-0.4.1}/deepsensor/data/__init__.py +0 -0
- {deepsensor-0.3.8 → deepsensor-0.4.1}/deepsensor/data/processor.py +0 -0
- {deepsensor-0.3.8 → deepsensor-0.4.1}/deepsensor/data/task.py +0 -0
- {deepsensor-0.3.8 → deepsensor-0.4.1}/deepsensor/data/utils.py +0 -0
- {deepsensor-0.3.8 → deepsensor-0.4.1}/deepsensor/errors.py +0 -0
- {deepsensor-0.3.8 → deepsensor-0.4.1}/deepsensor/model/__init__.py +0 -0
- {deepsensor-0.3.8 → deepsensor-0.4.1}/deepsensor/model/convnp.py +0 -0
- {deepsensor-0.3.8 → deepsensor-0.4.1}/deepsensor/model/defaults.py +0 -0
- {deepsensor-0.3.8 → deepsensor-0.4.1}/deepsensor/model/nps.py +0 -0
- {deepsensor-0.3.8 → deepsensor-0.4.1}/deepsensor/py.typed +0 -0
- {deepsensor-0.3.8 → deepsensor-0.4.1}/deepsensor/tensorflow/__init__.py +0 -0
- {deepsensor-0.3.8 → deepsensor-0.4.1}/deepsensor/torch/__init__.py +0 -0
- {deepsensor-0.3.8 → deepsensor-0.4.1}/deepsensor/train/__init__.py +0 -0
- {deepsensor-0.3.8 → deepsensor-0.4.1}/deepsensor/train/train.py +0 -0
- {deepsensor-0.3.8 → deepsensor-0.4.1}/deepsensor.egg-info/dependency_links.txt +0 -0
- {deepsensor-0.3.8 → deepsensor-0.4.1}/deepsensor.egg-info/not-zip-safe +0 -0
- {deepsensor-0.3.8 → deepsensor-0.4.1}/deepsensor.egg-info/requires.txt +0 -0
- {deepsensor-0.3.8 → deepsensor-0.4.1}/deepsensor.egg-info/top_level.txt +0 -0
- {deepsensor-0.3.8 → deepsensor-0.4.1}/pyproject.toml +0 -0
- {deepsensor-0.3.8 → deepsensor-0.4.1}/setup.py +0 -0
- {deepsensor-0.3.8 → deepsensor-0.4.1}/tests/__init__.py +0 -0
- {deepsensor-0.3.8 → deepsensor-0.4.1}/tests/test_active_learning.py +0 -0
- {deepsensor-0.3.8 → deepsensor-0.4.1}/tests/test_data_processor.py +0 -0
- {deepsensor-0.3.8 → deepsensor-0.4.1}/tests/test_plotting.py +0 -0
- {deepsensor-0.3.8 → deepsensor-0.4.1}/tests/test_task.py +0 -0
- {deepsensor-0.3.8 → deepsensor-0.4.1}/tests/test_training.py +0 -0
- {deepsensor-0.3.8 → deepsensor-0.4.1}/tests/utils.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: deepsensor
|
|
3
|
-
Version: 0.3.8
|
|
3
|
+
Version: 0.4.1
|
|
4
4
|
Summary: A Python package for modelling xarray and pandas data with neural processes.
|
|
5
5
|
Home-page: https://github.com/alan-turing-institute/deepsensor
|
|
6
6
|
Author: Tom R. Andersson
|
|
@@ -44,7 +44,7 @@ data with neural processes</p>
|
|
|
44
44
|
|
|
45
45
|
-----------
|
|
46
46
|
|
|
47
|
-
[](https://github.com/alan-turing-institute/deepsensor/releases)
|
|
48
48
|
[](https://alan-turing-institute.github.io/deepsensor/)
|
|
49
49
|

|
|
50
50
|
[](https://coveralls.io/github/alan-turing-institute/deepsensor?branch=main)
|
|
@@ -240,7 +240,7 @@ if you would like to join this list!
|
|
|
240
240
|
<table>
|
|
241
241
|
<tbody>
|
|
242
242
|
<tr>
|
|
243
|
-
<td align="center" valign="top" width="14.28%"><a href="https://github.com/acocac"><img src="https://avatars.githubusercontent.com/u/13321552?v=4?s=100" width="100px;" alt="Alejandro ©"/><br /><sub><b>Alejandro ©</b></sub></a><br /><a href="#userTesting-acocac" title="User Testing">📓</a> <a href="#bug-acocac" title="Bug reports">🐛</a> <a href="#mentoring-acocac" title="Mentoring">🧑🏫</a> <a href="#ideas-acocac" title="Ideas, Planning, & Feedback">🤔</a> <a href="#research-acocac" title="Research">🔬</a></td>
|
|
243
|
+
<td align="center" valign="top" width="14.28%"><a href="https://github.com/acocac"><img src="https://avatars.githubusercontent.com/u/13321552?v=4?s=100" width="100px;" alt="Alejandro ©"/><br /><sub><b>Alejandro ©</b></sub></a><br /><a href="#userTesting-acocac" title="User Testing">📓</a> <a href="#bug-acocac" title="Bug reports">🐛</a> <a href="#mentoring-acocac" title="Mentoring">🧑🏫</a> <a href="#ideas-acocac" title="Ideas, Planning, & Feedback">🤔</a> <a href="#research-acocac" title="Research">🔬</a> <a href="#code-acocac" title="Code">💻</a> <a href="#test-acocac" title="Tests">⚠️</a></td>
|
|
244
244
|
<td align="center" valign="top" width="14.28%"><a href="https://github.com/annavaughan"><img src="https://avatars.githubusercontent.com/u/45528489?v=4?s=100" width="100px;" alt="Anna Vaughan"/><br /><sub><b>Anna Vaughan</b></sub></a><br /><a href="#research-annavaughan" title="Research">🔬</a></td>
|
|
245
245
|
<td align="center" valign="top" width="14.28%"><a href="http://davidwilby.dev"><img src="https://avatars.githubusercontent.com/u/24752124?v=4?s=100" width="100px;" alt="David Wilby"/><br /><sub><b>David Wilby</b></sub></a><br /><a href="#doc-davidwilby" title="Documentation">📖</a> <a href="#test-davidwilby" title="Tests">⚠️</a> <a href="#maintenance-davidwilby" title="Maintenance">🚧</a></td>
|
|
246
246
|
<td align="center" valign="top" width="14.28%"><a href="http://inconsistentrecords.co.uk"><img src="https://avatars.githubusercontent.com/u/731727?v=4?s=100" width="100px;" alt="Jim Circadian"/><br /><sub><b>Jim Circadian</b></sub></a><br /><a href="#ideas-JimCircadian" title="Ideas, Planning, & Feedback">🤔</a> <a href="#projectManagement-JimCircadian" title="Project Management">📆</a> <a href="#maintenance-JimCircadian" title="Maintenance">🚧</a></td>
|
|
@@ -11,7 +11,7 @@ data with neural processes</p>
|
|
|
11
11
|
|
|
12
12
|
-----------
|
|
13
13
|
|
|
14
|
-
[](https://github.com/alan-turing-institute/deepsensor/releases)
|
|
15
15
|
[](https://alan-turing-institute.github.io/deepsensor/)
|
|
16
16
|

|
|
17
17
|
[](https://coveralls.io/github/alan-turing-institute/deepsensor?branch=main)
|
|
@@ -207,7 +207,7 @@ if you would like to join this list!
|
|
|
207
207
|
<table>
|
|
208
208
|
<tbody>
|
|
209
209
|
<tr>
|
|
210
|
-
<td align="center" valign="top" width="14.28%"><a href="https://github.com/acocac"><img src="https://avatars.githubusercontent.com/u/13321552?v=4?s=100" width="100px;" alt="Alejandro ©"/><br /><sub><b>Alejandro ©</b></sub></a><br /><a href="#userTesting-acocac" title="User Testing">📓</a> <a href="#bug-acocac" title="Bug reports">🐛</a> <a href="#mentoring-acocac" title="Mentoring">🧑🏫</a> <a href="#ideas-acocac" title="Ideas, Planning, & Feedback">🤔</a> <a href="#research-acocac" title="Research">🔬</a></td>
|
|
210
|
+
<td align="center" valign="top" width="14.28%"><a href="https://github.com/acocac"><img src="https://avatars.githubusercontent.com/u/13321552?v=4?s=100" width="100px;" alt="Alejandro ©"/><br /><sub><b>Alejandro ©</b></sub></a><br /><a href="#userTesting-acocac" title="User Testing">📓</a> <a href="#bug-acocac" title="Bug reports">🐛</a> <a href="#mentoring-acocac" title="Mentoring">🧑🏫</a> <a href="#ideas-acocac" title="Ideas, Planning, & Feedback">🤔</a> <a href="#research-acocac" title="Research">🔬</a> <a href="#code-acocac" title="Code">💻</a> <a href="#test-acocac" title="Tests">⚠️</a></td>
|
|
211
211
|
<td align="center" valign="top" width="14.28%"><a href="https://github.com/annavaughan"><img src="https://avatars.githubusercontent.com/u/45528489?v=4?s=100" width="100px;" alt="Anna Vaughan"/><br /><sub><b>Anna Vaughan</b></sub></a><br /><a href="#research-annavaughan" title="Research">🔬</a></td>
|
|
212
212
|
<td align="center" valign="top" width="14.28%"><a href="http://davidwilby.dev"><img src="https://avatars.githubusercontent.com/u/24752124?v=4?s=100" width="100px;" alt="David Wilby"/><br /><sub><b>David Wilby</b></sub></a><br /><a href="#doc-davidwilby" title="Documentation">📖</a> <a href="#test-davidwilby" title="Tests">⚠️</a> <a href="#maintenance-davidwilby" title="Maintenance">🚧</a></td>
|
|
213
213
|
<td align="center" valign="top" width="14.28%"><a href="http://inconsistentrecords.co.uk"><img src="https://avatars.githubusercontent.com/u/731727?v=4?s=100" width="100px;" alt="Jim Circadian"/><br /><sub><b>Jim Circadian</b></sub></a><br /><a href="#ideas-JimCircadian" title="Ideas, Planning, & Feedback">🤔</a> <a href="#projectManagement-JimCircadian" title="Project Management">📆</a> <a href="#maintenance-JimCircadian" title="Maintenance">🚧</a></td>
|
|
@@ -678,11 +678,11 @@ class TaskLoader:
|
|
|
678
678
|
seed: Optional[int] = None,
|
|
679
679
|
) -> (np.ndarray, np.ndarray):
|
|
680
680
|
"""
|
|
681
|
-
Sample a
|
|
681
|
+
Sample a DataFrame according to a given strategy.
|
|
682
682
|
|
|
683
683
|
Args:
|
|
684
684
|
df (:class:`pandas.DataFrame` | :class:`pandas.Series`):
|
|
685
|
-
|
|
685
|
+
Dataframe to sample, assumed to be time-sliced for the task
|
|
686
686
|
already.
|
|
687
687
|
sampling_strat (str | int | float | :class:`numpy:numpy.ndarray`):
|
|
688
688
|
Sampling strategy, either "all" or an integer for random grid
|
|
@@ -720,20 +720,24 @@ class TaskLoader:
|
|
|
720
720
|
X_c = df.reset_index()[["x1", "x2"]].values.T.astype(self.dtype)
|
|
721
721
|
Y_c = df.values.T
|
|
722
722
|
elif isinstance(sampling_strat, np.ndarray):
|
|
723
|
+
if df.index.get_level_values("x1").dtype != sampling_strat.dtype:
|
|
724
|
+
raise InvalidSamplingStrategyError(
|
|
725
|
+
"Passed a numpy coordinate array to sample pandas DataFrame, "
|
|
726
|
+
"but the coordinate array has a different dtype than the DataFrame. "
|
|
727
|
+
f"Got {sampling_strat.dtype} but expected {df.index.get_level_values('x1').dtype}."
|
|
728
|
+
)
|
|
723
729
|
X_c = sampling_strat.astype(self.dtype)
|
|
724
|
-
|
|
725
|
-
|
|
726
|
-
|
|
727
|
-
|
|
728
|
-
# Check that we got all the samples we asked for
|
|
729
|
-
if num_matches != X_c.shape[1]:
|
|
730
|
+
try:
|
|
731
|
+
Y_c = df.loc[pd.IndexSlice[:, X_c[0], X_c[1]]].values.T
|
|
732
|
+
except KeyError:
|
|
730
733
|
raise InvalidSamplingStrategyError(
|
|
731
|
-
|
|
732
|
-
|
|
733
|
-
f"
|
|
734
|
+
"Passed a numpy coordinate array to sample pandas DataFrame, "
|
|
735
|
+
"but the DataFrame did not contain all the requested samples.\n"
|
|
736
|
+
f"Indexes: {df.index}\n"
|
|
737
|
+
f"Sampling coords: {X_c}\n"
|
|
738
|
+
"If this is unexpected, check that your numpy sampling array matches "
|
|
739
|
+
"the DataFrame index values *exactly*."
|
|
734
740
|
)
|
|
735
|
-
|
|
736
|
-
Y_c = df[x1match & x2match].values.T
|
|
737
741
|
else:
|
|
738
742
|
raise InvalidSamplingStrategyError(
|
|
739
743
|
f"Unknown sampling strategy {sampling_strat}"
|
|
@@ -277,7 +277,7 @@ def get_era5_reanalysis_data(
|
|
|
277
277
|
if num_processes == 1:
|
|
278
278
|
# Just download in one go
|
|
279
279
|
if verbose:
|
|
280
|
-
print("Downloading ERA5 data
|
|
280
|
+
print("Downloading ERA5 data without parallelisation... ")
|
|
281
281
|
era5_da = _get_era5_reanalysis_data_parallel(
|
|
282
282
|
date_range=date_range,
|
|
283
283
|
var_IDs=var_IDs,
|
|
@@ -432,7 +432,7 @@ def get_gldas_land_mask(
|
|
|
432
432
|
with urllib.request.urlopen(req) as response:
|
|
433
433
|
with open(fname, "wb") as f:
|
|
434
434
|
f.write(response.read())
|
|
435
|
-
da = xr.open_dataset(fname)["GLDAS_mask"].isel(time=0).
|
|
435
|
+
da = xr.open_dataset(fname)["GLDAS_mask"].isel(time=0).drop_vars("time").load()
|
|
436
436
|
|
|
437
437
|
if isinstance(extent, str):
|
|
438
438
|
extent = extent_str_to_tuple(extent)
|
|
@@ -577,7 +577,7 @@ def get_earthenv_auxiliary_data(
|
|
|
577
577
|
# Read data
|
|
578
578
|
da = xr.open_dataset(fname).to_array().squeeze().load()
|
|
579
579
|
da = da.rename({"y": "lat", "x": "lon"})
|
|
580
|
-
da = da.
|
|
580
|
+
da = da.drop_vars(["band", "spatial_ref", "variable"])
|
|
581
581
|
da.name = var_ID
|
|
582
582
|
da = da.sel(lat=slice(lat_max, lat_min), lon=slice(lon_min, lon_max))
|
|
583
583
|
da_dict[var_ID] = da
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .metrics import *
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
import xarray as xr
|
|
2
|
+
from deepsensor.model.pred import Prediction
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def compute_errors(pred: Prediction, target: xr.Dataset) -> xr.Dataset:
|
|
6
|
+
"""
|
|
7
|
+
Compute errors between predictions and targets.
|
|
8
|
+
|
|
9
|
+
Args:
|
|
10
|
+
pred: Prediction object.
|
|
11
|
+
target: Target data.
|
|
12
|
+
|
|
13
|
+
Returns:
|
|
14
|
+
xr.Dataset: Dataset of pointwise differences between predictions and targets
|
|
15
|
+
at the same valid time in the predictions. Note, the difference is positive
|
|
16
|
+
when the prediction is greater than the target.
|
|
17
|
+
"""
|
|
18
|
+
errors = {}
|
|
19
|
+
for var_ID, pred_var in pred.items():
|
|
20
|
+
target_var = target[var_ID]
|
|
21
|
+
error = pred_var["mean"] - target_var.sel(time=pred_var.time)
|
|
22
|
+
error.name = f"{var_ID}"
|
|
23
|
+
errors[var_ID] = error
|
|
24
|
+
return xr.Dataset(errors)
|
|
@@ -348,6 +348,34 @@ class DeepSensorModel(ProbabilisticModel):
|
|
|
348
348
|
if ar_sample and n_samples < 1:
|
|
349
349
|
raise ValueError("Must pass `n_samples` > 0 to use `ar_sample`.")
|
|
350
350
|
|
|
351
|
+
target_delta_t = self.task_loader.target_delta_t
|
|
352
|
+
dts = [pd.Timedelta(dt) for dt in target_delta_t]
|
|
353
|
+
dts_all_zero = all([dt == pd.Timedelta(seconds=0) for dt in dts])
|
|
354
|
+
if target_delta_t is not None and dts_all_zero:
|
|
355
|
+
forecasting_mode = False
|
|
356
|
+
lead_times = None
|
|
357
|
+
elif target_delta_t is not None and not dts_all_zero:
|
|
358
|
+
target_var_IDs_set = set(self.task_loader.target_var_IDs)
|
|
359
|
+
msg = f"""
|
|
360
|
+
Got more than one set of target variables in target sets,
|
|
361
|
+
but predictions can only be made with one set of target variables
|
|
362
|
+
to simplify implementation.
|
|
363
|
+
Got {target_var_IDs_set}.
|
|
364
|
+
"""
|
|
365
|
+
assert len(target_var_IDs_set) == 1, msg
|
|
366
|
+
# Repeat lead_time for each variable in each target set
|
|
367
|
+
lead_times = []
|
|
368
|
+
for target_set_idx, dt in enumerate(target_delta_t):
|
|
369
|
+
target_set_dim = self.task_loader.target_dims[target_set_idx]
|
|
370
|
+
lead_times += [
|
|
371
|
+
pd.Timedelta(dt, unit=self.task_loader.time_freq)
|
|
372
|
+
for _ in range(target_set_dim)
|
|
373
|
+
]
|
|
374
|
+
forecasting_mode = True
|
|
375
|
+
else:
|
|
376
|
+
forecasting_mode = False
|
|
377
|
+
lead_times = None
|
|
378
|
+
|
|
351
379
|
if type(tasks) is Task:
|
|
352
380
|
tasks = [tasks]
|
|
353
381
|
|
|
@@ -355,12 +383,14 @@ class DeepSensorModel(ProbabilisticModel):
|
|
|
355
383
|
B.set_random_seed(seed)
|
|
356
384
|
np.random.seed(seed)
|
|
357
385
|
|
|
358
|
-
|
|
386
|
+
init_dates = [task["time"] for task in tasks]
|
|
359
387
|
|
|
360
388
|
# Flatten tuple of tuples to single list
|
|
361
389
|
target_var_IDs = [
|
|
362
390
|
var_ID for set in self.task_loader.target_var_IDs for var_ID in set
|
|
363
391
|
]
|
|
392
|
+
if lead_times is not None:
|
|
393
|
+
assert len(lead_times) == len(target_var_IDs)
|
|
364
394
|
|
|
365
395
|
# TODO consider removing this logic, can we just depend on the dim names in X_t?
|
|
366
396
|
if not unnormalise:
|
|
@@ -385,7 +415,7 @@ class DeepSensorModel(ProbabilisticModel):
|
|
|
385
415
|
elif isinstance(X_t, (xr.DataArray, xr.Dataset)):
|
|
386
416
|
# Remove time dimension if present
|
|
387
417
|
if "time" in X_t.coords:
|
|
388
|
-
X_t = X_t.isel(time=0).
|
|
418
|
+
X_t = X_t.isel(time=0).drop_vars("time")
|
|
389
419
|
|
|
390
420
|
if mode == "off-grid" and append_indexes is not None:
|
|
391
421
|
# Check append_indexes are all same length as X_t
|
|
@@ -450,11 +480,13 @@ class DeepSensorModel(ProbabilisticModel):
|
|
|
450
480
|
pred = Prediction(
|
|
451
481
|
target_var_IDs,
|
|
452
482
|
pred_params_to_store,
|
|
453
|
-
|
|
483
|
+
init_dates,
|
|
454
484
|
X_t,
|
|
455
485
|
X_t_mask,
|
|
456
486
|
coord_names,
|
|
457
487
|
n_samples=n_samples,
|
|
488
|
+
forecasting_mode=forecasting_mode,
|
|
489
|
+
lead_times=lead_times,
|
|
458
490
|
)
|
|
459
491
|
|
|
460
492
|
def unnormalise_pred_array(arr, **kwargs):
|
|
@@ -605,14 +637,22 @@ class DeepSensorModel(ProbabilisticModel):
|
|
|
605
637
|
# Assign predictions to Prediction object
|
|
606
638
|
for param, arr in prediction_arrs.items():
|
|
607
639
|
if param != "mixture_probs":
|
|
608
|
-
pred.assign(param, task["time"], arr)
|
|
640
|
+
pred.assign(param, task["time"], arr, lead_times=lead_times)
|
|
609
641
|
elif param == "mixture_probs":
|
|
610
642
|
assert arr.shape[0] == self.N_mixture_components, (
|
|
611
643
|
f"Number of mixture components ({arr.shape[0]}) does not match "
|
|
612
644
|
f"model attribute N_mixture_components ({self.N_mixture_components})."
|
|
613
645
|
)
|
|
614
646
|
for component_i, probs in enumerate(arr):
|
|
615
|
-
pred.assign(
|
|
647
|
+
pred.assign(
|
|
648
|
+
f"{param}_{component_i}",
|
|
649
|
+
task["time"],
|
|
650
|
+
probs,
|
|
651
|
+
lead_times=lead_times,
|
|
652
|
+
)
|
|
653
|
+
|
|
654
|
+
if forecasting_mode:
|
|
655
|
+
pred = add_valid_time_coord_to_pred_and_move_time_dims(pred)
|
|
616
656
|
|
|
617
657
|
if verbose:
|
|
618
658
|
dur = time.time() - tic
|
|
@@ -621,6 +661,37 @@ class DeepSensorModel(ProbabilisticModel):
|
|
|
621
661
|
return pred
|
|
622
662
|
|
|
623
663
|
|
|
664
|
+
def add_valid_time_coord_to_pred_and_move_time_dims(pred: Prediction) -> Prediction:
|
|
665
|
+
"""
|
|
666
|
+
Add a valid time coordinate "time" to a Prediction object based on the
|
|
667
|
+
initialisation times "init_time" and lead times "lead_time", and
|
|
668
|
+
reorder the time dims from ("lead_time", "init_time") to ("init_time", "lead_time").
|
|
669
|
+
|
|
670
|
+
Args:
|
|
671
|
+
pred (:class:`~.model.pred.Prediction`):
|
|
672
|
+
Prediction object to add valid time coordinate to.
|
|
673
|
+
|
|
674
|
+
Returns:
|
|
675
|
+
:class:`~.model.pred.Prediction`:
|
|
676
|
+
Prediction object with valid time coordinate added.
|
|
677
|
+
"""
|
|
678
|
+
for var_ID in pred.keys():
|
|
679
|
+
if isinstance(pred[var_ID], pd.DataFrame):
|
|
680
|
+
x = pred[var_ID].reset_index()
|
|
681
|
+
pred[var_ID]["time"] = (x["lead_time"] + x["init_time"]).values
|
|
682
|
+
pred[var_ID] = pred[var_ID].swaplevel("init_time", "lead_time")
|
|
683
|
+
pred[var_ID] = pred[var_ID].sort_index()
|
|
684
|
+
elif isinstance(pred[var_ID], xr.Dataset):
|
|
685
|
+
x = pred[var_ID]
|
|
686
|
+
pred[var_ID] = pred[var_ID].assign_coords(
|
|
687
|
+
time=x["lead_time"] + x["init_time"]
|
|
688
|
+
)
|
|
689
|
+
pred[var_ID] = pred[var_ID].transpose("init_time", "lead_time", ...)
|
|
690
|
+
else:
|
|
691
|
+
raise ValueError(f"Unsupported prediction type {type(pred[var_ID])}.")
|
|
692
|
+
return pred
|
|
693
|
+
|
|
694
|
+
|
|
624
695
|
def main(): # pragma: no cover
|
|
625
696
|
import deepsensor.tensorflow
|
|
626
697
|
from deepsensor.data.loader import TaskLoader
|
|
@@ -4,6 +4,8 @@ import numpy as np
|
|
|
4
4
|
import pandas as pd
|
|
5
5
|
import xarray as xr
|
|
6
6
|
|
|
7
|
+
Timestamp = Union[str, pd.Timestamp, np.datetime64]
|
|
8
|
+
|
|
7
9
|
|
|
8
10
|
class Prediction(dict):
|
|
9
11
|
"""
|
|
@@ -32,13 +34,20 @@ class Prediction(dict):
|
|
|
32
34
|
n_samples (int)
|
|
33
35
|
Number of joint samples to draw from the model. If 0, will not
|
|
34
36
|
draw samples. Default 0.
|
|
37
|
+
forecasting_mode (bool)
|
|
38
|
+
If True, stores forecast predictions with init_time and lead_time dimensions,
|
|
39
|
+
and a valid_time coordinate. If False, stores prediction at t=0 only
|
|
40
|
+
(i.e. spatial interpolation), with only a single time dimension. Default False.
|
|
41
|
+
lead_times (List[pd.Timedelta], optional)
|
|
42
|
+
List of lead times to store in predictions. Must be provided if
|
|
43
|
+
forecasting_mode is True. Default None.
|
|
35
44
|
"""
|
|
36
45
|
|
|
37
46
|
def __init__(
|
|
38
47
|
self,
|
|
39
48
|
target_var_IDs: List[str],
|
|
40
49
|
pred_params: List[str],
|
|
41
|
-
dates: List[
|
|
50
|
+
dates: List[Timestamp],
|
|
42
51
|
X_t: Union[
|
|
43
52
|
xr.Dataset,
|
|
44
53
|
xr.DataArray,
|
|
@@ -50,6 +59,8 @@ class Prediction(dict):
|
|
|
50
59
|
X_t_mask: Optional[Union[xr.Dataset, xr.DataArray]] = None,
|
|
51
60
|
coord_names: dict = None,
|
|
52
61
|
n_samples: int = 0,
|
|
62
|
+
forecasting_mode: bool = False,
|
|
63
|
+
lead_times: Optional[List[pd.Timedelta]] = None,
|
|
53
64
|
):
|
|
54
65
|
self.target_var_IDs = target_var_IDs
|
|
55
66
|
self.X_t_mask = X_t_mask
|
|
@@ -58,6 +69,13 @@ class Prediction(dict):
|
|
|
58
69
|
self.x1_name = coord_names["x1"]
|
|
59
70
|
self.x2_name = coord_names["x2"]
|
|
60
71
|
|
|
72
|
+
self.forecasting_mode = forecasting_mode
|
|
73
|
+
if forecasting_mode:
|
|
74
|
+
assert (
|
|
75
|
+
lead_times is not None
|
|
76
|
+
), "If forecasting_mode is True, lead_times must be provided."
|
|
77
|
+
self.lead_times = lead_times
|
|
78
|
+
|
|
61
79
|
self.mode = infer_prediction_modality_from_X_t(X_t)
|
|
62
80
|
|
|
63
81
|
self.pred_params = pred_params
|
|
@@ -67,15 +85,25 @@ class Prediction(dict):
|
|
|
67
85
|
*[f"sample_{i}" for i in range(n_samples)],
|
|
68
86
|
]
|
|
69
87
|
|
|
88
|
+
# Create empty xarray/pandas objects to store predictions
|
|
70
89
|
if self.mode == "on-grid":
|
|
71
90
|
for var_ID in self.target_var_IDs:
|
|
72
|
-
|
|
91
|
+
if self.forecasting_mode:
|
|
92
|
+
prepend_dims = ["lead_time"]
|
|
93
|
+
prepend_coords = {"lead_time": lead_times}
|
|
94
|
+
else:
|
|
95
|
+
prepend_dims = None
|
|
96
|
+
prepend_coords = None
|
|
73
97
|
self[var_ID] = create_empty_spatiotemporal_xarray(
|
|
74
98
|
X_t,
|
|
75
99
|
dates,
|
|
76
100
|
data_vars=self.pred_params,
|
|
77
101
|
coord_names=coord_names,
|
|
102
|
+
prepend_dims=prepend_dims,
|
|
103
|
+
prepend_coords=prepend_coords,
|
|
78
104
|
)
|
|
105
|
+
if self.forecasting_mode:
|
|
106
|
+
self[var_ID] = self[var_ID].rename(time="init_time")
|
|
79
107
|
if self.X_t_mask is None:
|
|
80
108
|
# Create 2D boolean array of True values to simplify indexing
|
|
81
109
|
self.X_t_mask = (
|
|
@@ -86,8 +114,18 @@ class Prediction(dict):
|
|
|
86
114
|
)
|
|
87
115
|
elif self.mode == "off-grid":
|
|
88
116
|
# Repeat target locs for each date to create multiindex
|
|
89
|
-
|
|
90
|
-
|
|
117
|
+
if self.forecasting_mode:
|
|
118
|
+
index_names = ["lead_time", "init_time", *X_t.index.names]
|
|
119
|
+
idxs = [
|
|
120
|
+
(lt, date, *idxs)
|
|
121
|
+
for lt in lead_times
|
|
122
|
+
for date in dates
|
|
123
|
+
for idxs in X_t.index
|
|
124
|
+
]
|
|
125
|
+
else:
|
|
126
|
+
index_names = ["time", *X_t.index.names]
|
|
127
|
+
idxs = [(date, *idxs) for date in dates for idxs in X_t.index]
|
|
128
|
+
index = pd.MultiIndex.from_tuples(idxs, names=index_names)
|
|
91
129
|
for var_ID in self.target_var_IDs:
|
|
92
130
|
self[var_ID] = pd.DataFrame(index=index, columns=self.pred_params)
|
|
93
131
|
|
|
@@ -106,6 +144,7 @@ class Prediction(dict):
|
|
|
106
144
|
prediction_parameter: str,
|
|
107
145
|
date: Union[str, pd.Timestamp],
|
|
108
146
|
data: np.ndarray,
|
|
147
|
+
lead_times: Optional[List[pd.Timedelta]] = None,
|
|
109
148
|
):
|
|
110
149
|
"""
|
|
111
150
|
|
|
@@ -117,11 +156,29 @@ class Prediction(dict):
|
|
|
117
156
|
data (np.ndarray)
|
|
118
157
|
If off-grid: Shape (N_var, N_targets) or (N_samples, N_var, N_targets).
|
|
119
158
|
If on-grid: Shape (N_var, N_x1, N_x2) or (N_samples, N_var, N_x1, N_x2).
|
|
159
|
+
lead_times (List[pd.Timedelta], optional)
|
|
160
|
+
Lead times of the forecast, one per variable in `data`. Required if forecasting_mode is True. Default None.
|
|
120
161
|
"""
|
|
162
|
+
if self.forecasting_mode:
|
|
163
|
+
assert (
|
|
164
|
+
lead_times is not None
|
|
165
|
+
), "If forecasting_mode is True, lead_times must be provided."
|
|
166
|
+
|
|
167
|
+
msg = f"""
|
|
168
|
+
If forecasting_mode is True, lead_times must be of equal length to the number of
|
|
169
|
+
variables in the data (the first dimension). Got {lead_times=} of length
|
|
170
|
+
{len(lead_times)} lead times and data shape {data.shape}.
|
|
171
|
+
"""
|
|
172
|
+
assert len(lead_times) == data.shape[0], msg
|
|
173
|
+
|
|
121
174
|
if self.mode == "on-grid":
|
|
122
175
|
if prediction_parameter != "samples":
|
|
123
|
-
for var_ID, pred in zip(self.target_var_IDs, data):
|
|
124
|
-
self
|
|
176
|
+
for i, (var_ID, pred) in enumerate(zip(self.target_var_IDs, data)):
|
|
177
|
+
if self.forecasting_mode:
|
|
178
|
+
index = (lead_times[i], date)
|
|
179
|
+
else:
|
|
180
|
+
index = date
|
|
181
|
+
self[var_ID][prediction_parameter].loc[index].data[
|
|
125
182
|
self.X_t_mask.data
|
|
126
183
|
] = pred.ravel()
|
|
127
184
|
elif prediction_parameter == "samples":
|
|
@@ -130,28 +187,44 @@ class Prediction(dict):
|
|
|
130
187
|
f"have shape (N_samples, N_var, N_x1, N_x2). Got {data.shape}."
|
|
131
188
|
)
|
|
132
189
|
for sample_i, sample in enumerate(data):
|
|
133
|
-
for var_ID, pred in
|
|
134
|
-
self
|
|
190
|
+
for i, (var_ID, pred) in enumerate(
|
|
191
|
+
zip(self.target_var_IDs, sample)
|
|
192
|
+
):
|
|
193
|
+
if self.forecasting_mode:
|
|
194
|
+
index = (lead_times[i], date)
|
|
195
|
+
else:
|
|
196
|
+
index = date
|
|
197
|
+
self[var_ID][f"sample_{sample_i}"].loc[index].data[
|
|
135
198
|
self.X_t_mask.data
|
|
136
199
|
] = pred.ravel()
|
|
137
200
|
|
|
138
201
|
elif self.mode == "off-grid":
|
|
139
202
|
if prediction_parameter != "samples":
|
|
140
|
-
for var_ID, pred in zip(self.target_var_IDs, data):
|
|
141
|
-
self
|
|
203
|
+
for i, (var_ID, pred) in enumerate(zip(self.target_var_IDs, data)):
|
|
204
|
+
if self.forecasting_mode:
|
|
205
|
+
index = (lead_times[i], date)
|
|
206
|
+
else:
|
|
207
|
+
index = date
|
|
208
|
+
self[var_ID][prediction_parameter].loc[index] = pred
|
|
142
209
|
elif prediction_parameter == "samples":
|
|
143
210
|
assert len(data.shape) == 3, (
|
|
144
211
|
f"If prediction_parameter is 'samples', and mode is 'off-grid', data must"
|
|
145
212
|
f"have shape (N_samples, N_var, N_targets). Got {data.shape}."
|
|
146
213
|
)
|
|
147
214
|
for sample_i, sample in enumerate(data):
|
|
148
|
-
for var_ID, pred in
|
|
149
|
-
self
|
|
215
|
+
for i, (var_ID, pred) in enumerate(
|
|
216
|
+
zip(self.target_var_IDs, sample)
|
|
217
|
+
):
|
|
218
|
+
if self.forecasting_mode:
|
|
219
|
+
index = (lead_times[i], date)
|
|
220
|
+
else:
|
|
221
|
+
index = date
|
|
222
|
+
self[var_ID][f"sample_{sample_i}"].loc[index] = pred
|
|
150
223
|
|
|
151
224
|
|
|
152
225
|
def create_empty_spatiotemporal_xarray(
|
|
153
226
|
X: Union[xr.Dataset, xr.DataArray],
|
|
154
|
-
dates: List,
|
|
227
|
+
dates: List[Timestamp],
|
|
155
228
|
coord_names: dict = None,
|
|
156
229
|
data_vars: List[str] = None,
|
|
157
230
|
prepend_dims: Optional[List[str]] = None,
|
|
@@ -231,10 +304,6 @@ def create_empty_spatiotemporal_xarray(
|
|
|
231
304
|
# Convert time coord to pandas timestamps
|
|
232
305
|
pred_ds = pred_ds.assign_coords(time=pd.to_datetime(pred_ds.time.values))
|
|
233
306
|
|
|
234
|
-
# TODO: Convert init time to forecast time?
|
|
235
|
-
# pred_ds = pred_ds.assign_coords(
|
|
236
|
-
# time=pred_ds['time'] + pd.Timedelta(days=task_loader.target_delta_t[0]))
|
|
237
|
-
|
|
238
307
|
return pred_ds
|
|
239
308
|
|
|
240
309
|
|
|
@@ -954,6 +954,8 @@ def prediction(
|
|
|
954
954
|
ax = axes[row_i, col_i]
|
|
955
955
|
|
|
956
956
|
if pred.mode == "on-grid":
|
|
957
|
+
if "init_time" in pred[0].indexes:
|
|
958
|
+
raise ValueError("Plotting forecasts not currently supported.")
|
|
957
959
|
if param == "std":
|
|
958
960
|
vmin = 0
|
|
959
961
|
else:
|
|
@@ -1000,6 +1002,8 @@ def prediction(
|
|
|
1000
1002
|
# )
|
|
1001
1003
|
|
|
1002
1004
|
elif pred.mode == "off-grid":
|
|
1005
|
+
if "init_time" in pred[0].index.names:
|
|
1006
|
+
raise ValueError("Plotting forecasts not currently supported.")
|
|
1003
1007
|
import seaborn as sns
|
|
1004
1008
|
|
|
1005
1009
|
hue = (
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: deepsensor
|
|
3
|
-
Version: 0.3.8
|
|
3
|
+
Version: 0.4.1
|
|
4
4
|
Summary: A Python package for modelling xarray and pandas data with neural processes.
|
|
5
5
|
Home-page: https://github.com/alan-turing-institute/deepsensor
|
|
6
6
|
Author: Tom R. Andersson
|
|
@@ -44,7 +44,7 @@ data with neural processes</p>
|
|
|
44
44
|
|
|
45
45
|
-----------
|
|
46
46
|
|
|
47
|
-
[](https://github.com/alan-turing-institute/deepsensor/releases)
|
|
48
48
|
[](https://alan-turing-institute.github.io/deepsensor/)
|
|
49
49
|

|
|
50
50
|
[](https://coveralls.io/github/alan-turing-institute/deepsensor?branch=main)
|
|
@@ -240,7 +240,7 @@ if you would like to join this list!
|
|
|
240
240
|
<table>
|
|
241
241
|
<tbody>
|
|
242
242
|
<tr>
|
|
243
|
-
<td align="center" valign="top" width="14.28%"><a href="https://github.com/acocac"><img src="https://avatars.githubusercontent.com/u/13321552?v=4?s=100" width="100px;" alt="Alejandro ©"/><br /><sub><b>Alejandro ©</b></sub></a><br /><a href="#userTesting-acocac" title="User Testing">📓</a> <a href="#bug-acocac" title="Bug reports">🐛</a> <a href="#mentoring-acocac" title="Mentoring">🧑🏫</a> <a href="#ideas-acocac" title="Ideas, Planning, & Feedback">🤔</a> <a href="#research-acocac" title="Research">🔬</a></td>
|
|
243
|
+
<td align="center" valign="top" width="14.28%"><a href="https://github.com/acocac"><img src="https://avatars.githubusercontent.com/u/13321552?v=4?s=100" width="100px;" alt="Alejandro ©"/><br /><sub><b>Alejandro ©</b></sub></a><br /><a href="#userTesting-acocac" title="User Testing">📓</a> <a href="#bug-acocac" title="Bug reports">🐛</a> <a href="#mentoring-acocac" title="Mentoring">🧑🏫</a> <a href="#ideas-acocac" title="Ideas, Planning, & Feedback">🤔</a> <a href="#research-acocac" title="Research">🔬</a> <a href="#code-acocac" title="Code">💻</a> <a href="#test-acocac" title="Tests">⚠️</a></td>
|
|
244
244
|
<td align="center" valign="top" width="14.28%"><a href="https://github.com/annavaughan"><img src="https://avatars.githubusercontent.com/u/45528489?v=4?s=100" width="100px;" alt="Anna Vaughan"/><br /><sub><b>Anna Vaughan</b></sub></a><br /><a href="#research-annavaughan" title="Research">🔬</a></td>
|
|
245
245
|
<td align="center" valign="top" width="14.28%"><a href="http://davidwilby.dev"><img src="https://avatars.githubusercontent.com/u/24752124?v=4?s=100" width="100px;" alt="David Wilby"/><br /><sub><b>David Wilby</b></sub></a><br /><a href="#doc-davidwilby" title="Documentation">📖</a> <a href="#test-davidwilby" title="Tests">⚠️</a> <a href="#maintenance-davidwilby" title="Maintenance">🚧</a></td>
|
|
246
246
|
<td align="center" valign="top" width="14.28%"><a href="http://inconsistentrecords.co.uk"><img src="https://avatars.githubusercontent.com/u/731727?v=4?s=100" width="100px;" alt="Jim Circadian"/><br /><sub><b>Jim Circadian</b></sub></a><br /><a href="#ideas-JimCircadian" title="Ideas, Planning, & Feedback">🤔</a> <a href="#projectManagement-JimCircadian" title="Project Management">📆</a> <a href="#maintenance-JimCircadian" title="Maintenance">🚧</a></td>
|
|
@@ -22,6 +22,8 @@ deepsensor/data/processor.py
|
|
|
22
22
|
deepsensor/data/sources.py
|
|
23
23
|
deepsensor/data/task.py
|
|
24
24
|
deepsensor/data/utils.py
|
|
25
|
+
deepsensor/eval/__init__.py
|
|
26
|
+
deepsensor/eval/metrics.py
|
|
25
27
|
deepsensor/model/__init__.py
|
|
26
28
|
deepsensor/model/convnp.py
|
|
27
29
|
deepsensor/model/defaults.py
|
|
@@ -18,6 +18,7 @@ from deepsensor.data.processor import DataProcessor
|
|
|
18
18
|
from deepsensor.data.loader import TaskLoader
|
|
19
19
|
from deepsensor.model.convnp import ConvNP
|
|
20
20
|
from deepsensor.train.train import Trainer
|
|
21
|
+
from deepsensor.eval.metrics import compute_errors
|
|
21
22
|
|
|
22
23
|
from tests.utils import gen_random_data_xr, gen_random_data_pandas
|
|
23
24
|
|
|
@@ -55,8 +56,15 @@ class TestModel(unittest.TestCase):
|
|
|
55
56
|
def setUpClass(cls):
|
|
56
57
|
# super().__init__(*args, **kwargs)
|
|
57
58
|
# It's safe to share data between tests because the TaskLoader does not modify data
|
|
59
|
+
cls.var_ID = "2m_temp"
|
|
58
60
|
cls.da = _gen_data_xr()
|
|
61
|
+
cls.da.name = cls.var_ID
|
|
59
62
|
cls.df = _gen_data_pandas()
|
|
63
|
+
cls.df.name = cls.var_ID
|
|
64
|
+
# Various tests assume we have a single target set with a single variable.
|
|
65
|
+
# If a test requires multiple target sets or variables, this is set up in the test.
|
|
66
|
+
assert isinstance(cls.da, xr.DataArray)
|
|
67
|
+
assert isinstance(cls.df, pd.Series)
|
|
60
68
|
|
|
61
69
|
cls.dp = DataProcessor()
|
|
62
70
|
_ = cls.dp([cls.da, cls.df]) # Compute normalisation parameters
|
|
@@ -417,10 +425,10 @@ class TestModel(unittest.TestCase):
|
|
|
417
425
|
task = tl("2020-01-01")
|
|
418
426
|
pred = model.predict(task, X_t=da_raw)
|
|
419
427
|
|
|
420
|
-
|
|
428
|
+
np.testing.assert_array_equal(
|
|
421
429
|
pred["dummy_data"]["mean"]["latitude"], da_raw["latitude"]
|
|
422
430
|
)
|
|
423
|
-
|
|
431
|
+
np.testing.assert_array_equal(
|
|
424
432
|
pred["dummy_data"]["mean"]["longitude"], da_raw["longitude"]
|
|
425
433
|
)
|
|
426
434
|
|
|
@@ -493,14 +501,14 @@ class TestModel(unittest.TestCase):
|
|
|
493
501
|
# Check that nothing breaks and the correct parameters are returned
|
|
494
502
|
pred = model.predict(task, X_t=X_t, pred_params=pred_params)
|
|
495
503
|
for pred_param in pred_params:
|
|
496
|
-
assert pred_param in pred[
|
|
504
|
+
assert pred_param in pred[self.var_ID]
|
|
497
505
|
|
|
498
506
|
# Test mixture probs special case
|
|
499
507
|
pred_params = ["mixture_probs"]
|
|
500
508
|
pred = model.predict(task, X_t=self.da, pred_params=pred_params)
|
|
501
509
|
for component in range(model.N_mixture_components):
|
|
502
510
|
pred_param = f"mixture_probs_{component}"
|
|
503
|
-
assert pred_param in pred[
|
|
511
|
+
assert pred_param in pred[self.var_ID]
|
|
504
512
|
|
|
505
513
|
def test_highlevel_predict_with_pred_params_xarray(self):
|
|
506
514
|
"""
|
|
@@ -528,14 +536,14 @@ class TestModel(unittest.TestCase):
|
|
|
528
536
|
# Check that nothing breaks and the correct parameters are returned
|
|
529
537
|
pred = model.predict(task, X_t=self.da, pred_params=pred_params)
|
|
530
538
|
for pred_param in pred_params:
|
|
531
|
-
assert pred_param in pred[
|
|
539
|
+
assert pred_param in pred[self.var_ID]
|
|
532
540
|
|
|
533
541
|
# Test mixture probs special case
|
|
534
542
|
pred_params = ["mixture_probs"]
|
|
535
543
|
pred = model.predict(task, X_t=self.da, pred_params=pred_params)
|
|
536
544
|
for component in range(model.N_mixture_components):
|
|
537
545
|
pred_param = f"mixture_probs_{component}"
|
|
538
|
-
assert pred_param in pred[
|
|
546
|
+
assert pred_param in pred[self.var_ID]
|
|
539
547
|
|
|
540
548
|
def test_highlevel_predict_with_invalid_pred_params(self):
|
|
541
549
|
"""Test that passing ``pred_params`` to ``.predict`` works."""
|
|
@@ -640,6 +648,66 @@ class TestModel(unittest.TestCase):
|
|
|
640
648
|
ar_sample=True,
|
|
641
649
|
)
|
|
642
650
|
|
|
651
|
+
def test_forecasting_model_predict_return_valid_times(self):
|
|
652
|
+
"""Test that the times returned by a forecasting model are valid."""
|
|
653
|
+
init_dates = ["2020-01-01", "2020-01-02"]
|
|
654
|
+
expected_init_times = np.array(init_dates).astype(np.datetime64)
|
|
655
|
+
|
|
656
|
+
lead_times_days = [1, 2, 3]
|
|
657
|
+
expected_lead_times = np.array(
|
|
658
|
+
[np.timedelta64(lt, "D") for lt in lead_times_days]
|
|
659
|
+
)
|
|
660
|
+
|
|
661
|
+
expected_valid_times = np.array(
|
|
662
|
+
expected_init_times[:, None] + expected_lead_times[None, :]
|
|
663
|
+
)
|
|
664
|
+
|
|
665
|
+
tl = TaskLoader(
|
|
666
|
+
context=self.da,
|
|
667
|
+
target=[
|
|
668
|
+
self.da,
|
|
669
|
+
]
|
|
670
|
+
* len(lead_times_days),
|
|
671
|
+
target_delta_t=lead_times_days,
|
|
672
|
+
time_freq="D",
|
|
673
|
+
)
|
|
674
|
+
model = ConvNP(self.dp, tl, unet_channels=(5, 5, 5), verbose=False)
|
|
675
|
+
tasks = tl(init_dates, context_sampling=10)
|
|
676
|
+
|
|
677
|
+
X_ts = [
|
|
678
|
+
# Gridded predictions (xarray)
|
|
679
|
+
self.da,
|
|
680
|
+
# Off-grid prediction (pandas)
|
|
681
|
+
np.array([[0.0, 0.5, 1.0], [0.0, 0.5, 1.0]]),
|
|
682
|
+
]
|
|
683
|
+
for X_t in X_ts:
|
|
684
|
+
pred = model.predict(tasks, X_t=X_t)
|
|
685
|
+
|
|
686
|
+
pred_var = pred[self.var_ID]
|
|
687
|
+
|
|
688
|
+
if isinstance(pred_var, xr.Dataset):
|
|
689
|
+
# Check we can compute errors using the valid time coord ('time')
|
|
690
|
+
errors = compute_errors(pred, self.da.to_dataset())
|
|
691
|
+
for var_ID in errors.keys():
|
|
692
|
+
assert tuple(errors[var_ID].dims) == (
|
|
693
|
+
"init_time",
|
|
694
|
+
"lead_time",
|
|
695
|
+
"x1",
|
|
696
|
+
"x2",
|
|
697
|
+
)
|
|
698
|
+
assert errors[var_ID].shape == pred[var_ID]["mean"].shape
|
|
699
|
+
elif isinstance(pred_var, pd.DataFrame):
|
|
700
|
+
# Makes coordinate checking easier by avoiding repeat values
|
|
701
|
+
pred_var = pred_var.to_xarray().isel(x1=0, x2=0)
|
|
702
|
+
|
|
703
|
+
np.testing.assert_array_equal(
|
|
704
|
+
pred_var.lead_time.values, expected_lead_times
|
|
705
|
+
)
|
|
706
|
+
np.testing.assert_array_equal(
|
|
707
|
+
pred_var.init_time.values, expected_init_times
|
|
708
|
+
)
|
|
709
|
+
np.testing.assert_array_equal(pred_var.time.values, expected_valid_times)
|
|
710
|
+
|
|
643
711
|
|
|
644
712
|
def assert_shape(x, shape: tuple):
|
|
645
713
|
"""Assert that the shape of ``x`` matches ``shape``."""
|
|
@@ -192,6 +192,23 @@ class TestTaskLoader(unittest.TestCase):
|
|
|
192
192
|
with self.assertRaises(InvalidSamplingStrategyError):
|
|
193
193
|
task = tl("2020-01-01", invalid_sampling_strategy)
|
|
194
194
|
|
|
195
|
+
def test_different_dtype_when_sampling_offgrid_data_at_specific_numpy_locs(self):
|
|
196
|
+
"""Test different dtype when sampling off-grid data at specific numpy locations."""
|
|
197
|
+
sampling_strat = np.array(
|
|
198
|
+
[np.linspace(0, 1, 10), np.linspace(0, 1, 10)], dtype=np.float16
|
|
199
|
+
)
|
|
200
|
+
|
|
201
|
+
tl = TaskLoader(
|
|
202
|
+
context=self.df,
|
|
203
|
+
target=self.df,
|
|
204
|
+
)
|
|
205
|
+
|
|
206
|
+
assert sampling_strat.dtype != tl.context[0].index.get_level_values("x1").dtype
|
|
207
|
+
assert sampling_strat.dtype != tl.context[0].index.get_level_values("x2").dtype
|
|
208
|
+
|
|
209
|
+
with self.assertRaises(InvalidSamplingStrategyError):
|
|
210
|
+
task = tl("2020-01-01", sampling_strat, sampling_strat)
|
|
211
|
+
|
|
195
212
|
def test_wrong_links(self):
|
|
196
213
|
"""Test link indexes out of range."""
|
|
197
214
|
with self.assertRaises(ValueError):
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|