deepsensor 0.3.8__tar.gz → 0.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46) hide show
  1. {deepsensor-0.3.8 → deepsensor-0.4.0}/PKG-INFO +3 -3
  2. {deepsensor-0.3.8 → deepsensor-0.4.0}/README.md +2 -2
  3. {deepsensor-0.3.8 → deepsensor-0.4.0}/deepsensor/data/loader.py +17 -13
  4. {deepsensor-0.3.8 → deepsensor-0.4.0}/deepsensor/data/sources.py +3 -3
  5. deepsensor-0.4.0/deepsensor/eval/__init__.py +1 -0
  6. deepsensor-0.4.0/deepsensor/eval/metrics.py +24 -0
  7. {deepsensor-0.3.8 → deepsensor-0.4.0}/deepsensor/model/model.py +74 -5
  8. {deepsensor-0.3.8 → deepsensor-0.4.0}/deepsensor/model/pred.py +86 -17
  9. {deepsensor-0.3.8 → deepsensor-0.4.0}/deepsensor.egg-info/PKG-INFO +3 -3
  10. {deepsensor-0.3.8 → deepsensor-0.4.0}/deepsensor.egg-info/SOURCES.txt +2 -0
  11. {deepsensor-0.3.8 → deepsensor-0.4.0}/setup.cfg +1 -1
  12. {deepsensor-0.3.8 → deepsensor-0.4.0}/tests/test_model.py +74 -6
  13. {deepsensor-0.3.8 → deepsensor-0.4.0}/tests/test_task_loader.py +17 -0
  14. {deepsensor-0.3.8 → deepsensor-0.4.0}/deepsensor/__init__.py +0 -0
  15. {deepsensor-0.3.8 → deepsensor-0.4.0}/deepsensor/active_learning/__init__.py +0 -0
  16. {deepsensor-0.3.8 → deepsensor-0.4.0}/deepsensor/active_learning/acquisition_fns.py +0 -0
  17. {deepsensor-0.3.8 → deepsensor-0.4.0}/deepsensor/active_learning/algorithms.py +0 -0
  18. {deepsensor-0.3.8 → deepsensor-0.4.0}/deepsensor/config.py +0 -0
  19. {deepsensor-0.3.8 → deepsensor-0.4.0}/deepsensor/data/__init__.py +0 -0
  20. {deepsensor-0.3.8 → deepsensor-0.4.0}/deepsensor/data/processor.py +0 -0
  21. {deepsensor-0.3.8 → deepsensor-0.4.0}/deepsensor/data/task.py +0 -0
  22. {deepsensor-0.3.8 → deepsensor-0.4.0}/deepsensor/data/utils.py +0 -0
  23. {deepsensor-0.3.8 → deepsensor-0.4.0}/deepsensor/errors.py +0 -0
  24. {deepsensor-0.3.8 → deepsensor-0.4.0}/deepsensor/model/__init__.py +0 -0
  25. {deepsensor-0.3.8 → deepsensor-0.4.0}/deepsensor/model/convnp.py +0 -0
  26. {deepsensor-0.3.8 → deepsensor-0.4.0}/deepsensor/model/defaults.py +0 -0
  27. {deepsensor-0.3.8 → deepsensor-0.4.0}/deepsensor/model/nps.py +0 -0
  28. {deepsensor-0.3.8 → deepsensor-0.4.0}/deepsensor/plot.py +0 -0
  29. {deepsensor-0.3.8 → deepsensor-0.4.0}/deepsensor/py.typed +0 -0
  30. {deepsensor-0.3.8 → deepsensor-0.4.0}/deepsensor/tensorflow/__init__.py +0 -0
  31. {deepsensor-0.3.8 → deepsensor-0.4.0}/deepsensor/torch/__init__.py +0 -0
  32. {deepsensor-0.3.8 → deepsensor-0.4.0}/deepsensor/train/__init__.py +0 -0
  33. {deepsensor-0.3.8 → deepsensor-0.4.0}/deepsensor/train/train.py +0 -0
  34. {deepsensor-0.3.8 → deepsensor-0.4.0}/deepsensor.egg-info/dependency_links.txt +0 -0
  35. {deepsensor-0.3.8 → deepsensor-0.4.0}/deepsensor.egg-info/not-zip-safe +0 -0
  36. {deepsensor-0.3.8 → deepsensor-0.4.0}/deepsensor.egg-info/requires.txt +0 -0
  37. {deepsensor-0.3.8 → deepsensor-0.4.0}/deepsensor.egg-info/top_level.txt +0 -0
  38. {deepsensor-0.3.8 → deepsensor-0.4.0}/pyproject.toml +0 -0
  39. {deepsensor-0.3.8 → deepsensor-0.4.0}/setup.py +0 -0
  40. {deepsensor-0.3.8 → deepsensor-0.4.0}/tests/__init__.py +0 -0
  41. {deepsensor-0.3.8 → deepsensor-0.4.0}/tests/test_active_learning.py +0 -0
  42. {deepsensor-0.3.8 → deepsensor-0.4.0}/tests/test_data_processor.py +0 -0
  43. {deepsensor-0.3.8 → deepsensor-0.4.0}/tests/test_plotting.py +0 -0
  44. {deepsensor-0.3.8 → deepsensor-0.4.0}/tests/test_task.py +0 -0
  45. {deepsensor-0.3.8 → deepsensor-0.4.0}/tests/test_training.py +0 -0
  46. {deepsensor-0.3.8 → deepsensor-0.4.0}/tests/utils.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: deepsensor
3
- Version: 0.3.8
3
+ Version: 0.4.0
4
4
  Summary: A Python package for modelling xarray and pandas data with neural processes.
5
5
  Home-page: https://github.com/alan-turing-institute/deepsensor
6
6
  Author: Tom R. Andersson
@@ -44,7 +44,7 @@ data with neural processes</p>
44
44
 
45
45
  -----------
46
46
 
47
- [![release](https://img.shields.io/badge/release-v0.3.8-green?logo=github)](https://github.com/alan-turing-institute/deepsensor/releases)
47
+ [![release](https://img.shields.io/badge/release-v0.4.0-green?logo=github)](https://github.com/alan-turing-institute/deepsensor/releases)
48
48
  [![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://alan-turing-institute.github.io/deepsensor/)
49
49
  ![Tests](https://github.com/alan-turing-institute/deepsensor/actions/workflows/tests.yml/badge.svg)
50
50
  [![Coverage Status](https://coveralls.io/repos/github/alan-turing-institute/deepsensor/badge.svg?branch=main)](https://coveralls.io/github/alan-turing-institute/deepsensor?branch=main)
@@ -240,7 +240,7 @@ if you would like to join this list!
240
240
  <table>
241
241
  <tbody>
242
242
  <tr>
243
- <td align="center" valign="top" width="14.28%"><a href="https://github.com/acocac"><img src="https://avatars.githubusercontent.com/u/13321552?v=4?s=100" width="100px;" alt="Alejandro ©"/><br /><sub><b>Alejandro ©</b></sub></a><br /><a href="#userTesting-acocac" title="User Testing">📓</a> <a href="#bug-acocac" title="Bug reports">🐛</a> <a href="#mentoring-acocac" title="Mentoring">🧑‍🏫</a> <a href="#ideas-acocac" title="Ideas, Planning, & Feedback">🤔</a> <a href="#research-acocac" title="Research">🔬</a></td>
243
+ <td align="center" valign="top" width="14.28%"><a href="https://github.com/acocac"><img src="https://avatars.githubusercontent.com/u/13321552?v=4?s=100" width="100px;" alt="Alejandro ©"/><br /><sub><b>Alejandro ©</b></sub></a><br /><a href="#userTesting-acocac" title="User Testing">📓</a> <a href="#bug-acocac" title="Bug reports">🐛</a> <a href="#mentoring-acocac" title="Mentoring">🧑‍🏫</a> <a href="#ideas-acocac" title="Ideas, Planning, & Feedback">🤔</a> <a href="#research-acocac" title="Research">🔬</a> <a href="#code-acocac" title="Code">💻</a> <a href="#test-acocac" title="Tests">⚠️</a></td>
244
244
  <td align="center" valign="top" width="14.28%"><a href="https://github.com/annavaughan"><img src="https://avatars.githubusercontent.com/u/45528489?v=4?s=100" width="100px;" alt="Anna Vaughan"/><br /><sub><b>Anna Vaughan</b></sub></a><br /><a href="#research-annavaughan" title="Research">🔬</a></td>
245
245
  <td align="center" valign="top" width="14.28%"><a href="http://davidwilby.dev"><img src="https://avatars.githubusercontent.com/u/24752124?v=4?s=100" width="100px;" alt="David Wilby"/><br /><sub><b>David Wilby</b></sub></a><br /><a href="#doc-davidwilby" title="Documentation">📖</a> <a href="#test-davidwilby" title="Tests">⚠️</a> <a href="#maintenance-davidwilby" title="Maintenance">🚧</a></td>
246
246
  <td align="center" valign="top" width="14.28%"><a href="http://inconsistentrecords.co.uk"><img src="https://avatars.githubusercontent.com/u/731727?v=4?s=100" width="100px;" alt="Jim Circadian"/><br /><sub><b>Jim Circadian</b></sub></a><br /><a href="#ideas-JimCircadian" title="Ideas, Planning, & Feedback">🤔</a> <a href="#projectManagement-JimCircadian" title="Project Management">📆</a> <a href="#maintenance-JimCircadian" title="Maintenance">🚧</a></td>
@@ -11,7 +11,7 @@ data with neural processes</p>
11
11
 
12
12
  -----------
13
13
 
14
- [![release](https://img.shields.io/badge/release-v0.3.8-green?logo=github)](https://github.com/alan-turing-institute/deepsensor/releases)
14
+ [![release](https://img.shields.io/badge/release-v0.4.0-green?logo=github)](https://github.com/alan-turing-institute/deepsensor/releases)
15
15
  [![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://alan-turing-institute.github.io/deepsensor/)
16
16
  ![Tests](https://github.com/alan-turing-institute/deepsensor/actions/workflows/tests.yml/badge.svg)
17
17
  [![Coverage Status](https://coveralls.io/repos/github/alan-turing-institute/deepsensor/badge.svg?branch=main)](https://coveralls.io/github/alan-turing-institute/deepsensor?branch=main)
@@ -207,7 +207,7 @@ if you would like to join this list!
207
207
  <table>
208
208
  <tbody>
209
209
  <tr>
210
- <td align="center" valign="top" width="14.28%"><a href="https://github.com/acocac"><img src="https://avatars.githubusercontent.com/u/13321552?v=4?s=100" width="100px;" alt="Alejandro ©"/><br /><sub><b>Alejandro ©</b></sub></a><br /><a href="#userTesting-acocac" title="User Testing">📓</a> <a href="#bug-acocac" title="Bug reports">🐛</a> <a href="#mentoring-acocac" title="Mentoring">🧑‍🏫</a> <a href="#ideas-acocac" title="Ideas, Planning, & Feedback">🤔</a> <a href="#research-acocac" title="Research">🔬</a></td>
210
+ <td align="center" valign="top" width="14.28%"><a href="https://github.com/acocac"><img src="https://avatars.githubusercontent.com/u/13321552?v=4?s=100" width="100px;" alt="Alejandro ©"/><br /><sub><b>Alejandro ©</b></sub></a><br /><a href="#userTesting-acocac" title="User Testing">📓</a> <a href="#bug-acocac" title="Bug reports">🐛</a> <a href="#mentoring-acocac" title="Mentoring">🧑‍🏫</a> <a href="#ideas-acocac" title="Ideas, Planning, & Feedback">🤔</a> <a href="#research-acocac" title="Research">🔬</a> <a href="#code-acocac" title="Code">💻</a> <a href="#test-acocac" title="Tests">⚠️</a></td>
211
211
  <td align="center" valign="top" width="14.28%"><a href="https://github.com/annavaughan"><img src="https://avatars.githubusercontent.com/u/45528489?v=4?s=100" width="100px;" alt="Anna Vaughan"/><br /><sub><b>Anna Vaughan</b></sub></a><br /><a href="#research-annavaughan" title="Research">🔬</a></td>
212
212
  <td align="center" valign="top" width="14.28%"><a href="http://davidwilby.dev"><img src="https://avatars.githubusercontent.com/u/24752124?v=4?s=100" width="100px;" alt="David Wilby"/><br /><sub><b>David Wilby</b></sub></a><br /><a href="#doc-davidwilby" title="Documentation">📖</a> <a href="#test-davidwilby" title="Tests">⚠️</a> <a href="#maintenance-davidwilby" title="Maintenance">🚧</a></td>
213
213
  <td align="center" valign="top" width="14.28%"><a href="http://inconsistentrecords.co.uk"><img src="https://avatars.githubusercontent.com/u/731727?v=4?s=100" width="100px;" alt="Jim Circadian"/><br /><sub><b>Jim Circadian</b></sub></a><br /><a href="#ideas-JimCircadian" title="Ideas, Planning, & Feedback">🤔</a> <a href="#projectManagement-JimCircadian" title="Project Management">📆</a> <a href="#maintenance-JimCircadian" title="Maintenance">🚧</a></td>
@@ -678,11 +678,11 @@ class TaskLoader:
678
678
  seed: Optional[int] = None,
679
679
  ) -> (np.ndarray, np.ndarray):
680
680
  """
681
- Sample a DataArray according to a given strategy.
681
+ Sample a DataFrame according to a given strategy.
682
682
 
683
683
  Args:
684
684
  df (:class:`pandas.DataFrame` | :class:`pandas.Series`):
685
- DataArray to sample, assumed to be time-sliced for the task
685
+ DataFrame to sample, assumed to be time-sliced for the task
686
686
  already.
687
687
  sampling_strat (str | int | float | :class:`numpy:numpy.ndarray`):
688
688
  Sampling strategy, either "all" or an integer for random grid
@@ -720,20 +720,24 @@ class TaskLoader:
720
720
  X_c = df.reset_index()[["x1", "x2"]].values.T.astype(self.dtype)
721
721
  Y_c = df.values.T
722
722
  elif isinstance(sampling_strat, np.ndarray):
723
+ if df.index.get_level_values("x1").dtype != sampling_strat.dtype:
724
+ raise InvalidSamplingStrategyError(
725
+ "Passed a numpy coordinate array to sample pandas DataFrame, "
726
+ "but the coordinate array has a different dtype than the DataFrame. "
727
+ f"Got {sampling_strat.dtype} but expected {df.index.get_level_values('x1').dtype}."
728
+ )
723
729
  X_c = sampling_strat.astype(self.dtype)
724
- x1match = np.in1d(df.index.get_level_values("x1"), X_c[0])
725
- x2match = np.in1d(df.index.get_level_values("x2"), X_c[1])
726
- num_matches = np.sum(x1match & x2match)
727
-
728
- # Check that we got all the samples we asked for
729
- if num_matches != X_c.shape[1]:
730
+ try:
731
+ Y_c = df.loc[pd.IndexSlice[:, X_c[0], X_c[1]]].values.T
732
+ except KeyError:
730
733
  raise InvalidSamplingStrategyError(
731
- f"Passed a numpy coordinate array to sample pandas DataFrame, "
732
- f"but the DataFrame did not contain all the requested samples. "
733
- f"Requested {X_c.shape[1]} samples but only got {num_matches}."
734
+ "Passed a numpy coordinate array to sample pandas DataFrame, "
735
+ "but the DataFrame did not contain all the requested samples.\n"
736
+ f"Indexes: {df.index}\n"
737
+ f"Sampling coords: {X_c}\n"
738
+ "If this is unexpected, check that your numpy sampling array matches "
739
+ "the DataFrame index values *exactly*."
734
740
  )
735
-
736
- Y_c = df[x1match & x2match].values.T
737
741
  else:
738
742
  raise InvalidSamplingStrategyError(
739
743
  f"Unknown sampling strategy {sampling_strat}"
@@ -277,7 +277,7 @@ def get_era5_reanalysis_data(
277
277
  if num_processes == 1:
278
278
  # Just download in one go
279
279
  if verbose:
280
- print("Downloading ERA5 data in without parallelisation... ")
280
+ print("Downloading ERA5 data without parallelisation... ")
281
281
  era5_da = _get_era5_reanalysis_data_parallel(
282
282
  date_range=date_range,
283
283
  var_IDs=var_IDs,
@@ -432,7 +432,7 @@ def get_gldas_land_mask(
432
432
  with urllib.request.urlopen(req) as response:
433
433
  with open(fname, "wb") as f:
434
434
  f.write(response.read())
435
- da = xr.open_dataset(fname)["GLDAS_mask"].isel(time=0).drop("time").load()
435
+ da = xr.open_dataset(fname)["GLDAS_mask"].isel(time=0).drop_vars("time").load()
436
436
 
437
437
  if isinstance(extent, str):
438
438
  extent = extent_str_to_tuple(extent)
@@ -577,7 +577,7 @@ def get_earthenv_auxiliary_data(
577
577
  # Read data
578
578
  da = xr.open_dataset(fname).to_array().squeeze().load()
579
579
  da = da.rename({"y": "lat", "x": "lon"})
580
- da = da.drop(["band", "spatial_ref", "variable"])
580
+ da = da.drop_vars(["band", "spatial_ref", "variable"])
581
581
  da.name = var_ID
582
582
  da = da.sel(lat=slice(lat_max, lat_min), lon=slice(lon_min, lon_max))
583
583
  da_dict[var_ID] = da
@@ -0,0 +1 @@
1
+ from .metrics import *
@@ -0,0 +1,24 @@
1
+ import xarray as xr
2
+ from deepsensor.model.pred import Prediction
3
+
4
+
5
+ def compute_errors(pred: Prediction, target: xr.Dataset) -> xr.Dataset:
6
+ """
7
+ Compute errors between predictions and targets.
8
+
9
+ Args:
10
+ pred: Prediction object.
11
+ target: Target data.
12
+
13
+ Returns:
14
+ xr.Dataset: Dataset of pointwise differences between predictions and targets
15
+ at the same valid time in the predictions. Note, the difference is positive
16
+ when the prediction is greater than the target.
17
+ """
18
+ errors = {}
19
+ for var_ID, pred_var in pred.items():
20
+ target_var = target[var_ID]
21
+ error = pred_var["mean"] - target_var.sel(time=pred_var.time)
22
+ error.name = f"{var_ID}"
23
+ errors[var_ID] = error
24
+ return xr.Dataset(errors)
@@ -348,6 +348,34 @@ class DeepSensorModel(ProbabilisticModel):
348
348
  if ar_sample and n_samples < 1:
349
349
  raise ValueError("Must pass `n_samples` > 0 to use `ar_sample`.")
350
350
 
351
+ target_delta_t = self.task_loader.target_delta_t
352
+ dts = [pd.Timedelta(dt) for dt in target_delta_t]
353
+ dts_all_zero = all([dt == pd.Timedelta(seconds=0) for dt in dts])
354
+ if target_delta_t is not None and dts_all_zero:
355
+ forecasting_mode = False
356
+ lead_times = None
357
+ elif target_delta_t is not None and not dts_all_zero:
358
+ target_var_IDs_set = set(self.task_loader.target_var_IDs)
359
+ msg = f"""
360
+ Got more than one set of target variables in target sets,
361
+ but predictions can only be made with one set of target variables
362
+ to simplify implementation.
363
+ Got {target_var_IDs_set}.
364
+ """
365
+ assert len(target_var_IDs_set) == 1, msg
366
+ # Repeat lead_time for each variable in each target set
367
+ lead_times = []
368
+ for target_set_idx, dt in enumerate(target_delta_t):
369
+ target_set_dim = self.task_loader.target_dims[target_set_idx]
370
+ lead_times += [
371
+ pd.Timedelta(dt, unit=self.task_loader.time_freq)
372
+ for _ in range(target_set_dim)
373
+ ]
374
+ forecasting_mode = True
375
+ else:
376
+ forecasting_mode = False
377
+ lead_times = None
378
+
351
379
  if type(tasks) is Task:
352
380
  tasks = [tasks]
353
381
 
@@ -355,12 +383,14 @@ class DeepSensorModel(ProbabilisticModel):
355
383
  B.set_random_seed(seed)
356
384
  np.random.seed(seed)
357
385
 
358
- dates = [task["time"] for task in tasks]
386
+ init_dates = [task["time"] for task in tasks]
359
387
 
360
388
  # Flatten tuple of tuples to single list
361
389
  target_var_IDs = [
362
390
  var_ID for set in self.task_loader.target_var_IDs for var_ID in set
363
391
  ]
392
+ if lead_times is not None:
393
+ assert len(lead_times) == len(target_var_IDs)
364
394
 
365
395
  # TODO consider removing this logic, can we just depend on the dim names in X_t?
366
396
  if not unnormalise:
@@ -385,7 +415,7 @@ class DeepSensorModel(ProbabilisticModel):
385
415
  elif isinstance(X_t, (xr.DataArray, xr.Dataset)):
386
416
  # Remove time dimension if present
387
417
  if "time" in X_t.coords:
388
- X_t = X_t.isel(time=0).drop("time")
418
+ X_t = X_t.isel(time=0).drop_vars("time")
389
419
 
390
420
  if mode == "off-grid" and append_indexes is not None:
391
421
  # Check append_indexes are all same length as X_t
@@ -450,11 +480,13 @@ class DeepSensorModel(ProbabilisticModel):
450
480
  pred = Prediction(
451
481
  target_var_IDs,
452
482
  pred_params_to_store,
453
- dates,
483
+ init_dates,
454
484
  X_t,
455
485
  X_t_mask,
456
486
  coord_names,
457
487
  n_samples=n_samples,
488
+ forecasting_mode=forecasting_mode,
489
+ lead_times=lead_times,
458
490
  )
459
491
 
460
492
  def unnormalise_pred_array(arr, **kwargs):
@@ -605,14 +637,22 @@ class DeepSensorModel(ProbabilisticModel):
605
637
  # Assign predictions to Prediction object
606
638
  for param, arr in prediction_arrs.items():
607
639
  if param != "mixture_probs":
608
- pred.assign(param, task["time"], arr)
640
+ pred.assign(param, task["time"], arr, lead_times=lead_times)
609
641
  elif param == "mixture_probs":
610
642
  assert arr.shape[0] == self.N_mixture_components, (
611
643
  f"Number of mixture components ({arr.shape[0]}) does not match "
612
644
  f"model attribute N_mixture_components ({self.N_mixture_components})."
613
645
  )
614
646
  for component_i, probs in enumerate(arr):
615
- pred.assign(f"{param}_{component_i}", task["time"], probs)
647
+ pred.assign(
648
+ f"{param}_{component_i}",
649
+ task["time"],
650
+ probs,
651
+ lead_times=lead_times,
652
+ )
653
+
654
+ if forecasting_mode:
655
+ pred = add_valid_time_coord_to_pred(pred)
616
656
 
617
657
  if verbose:
618
658
  dur = time.time() - tic
@@ -621,6 +661,35 @@ class DeepSensorModel(ProbabilisticModel):
621
661
  return pred
622
662
 
623
663
 
664
+ def add_valid_time_coord_to_pred(pred: Prediction) -> Prediction:
665
+ """
666
+ Add a valid time coordinate "time" to a Prediction object based on the
667
+ initialisation times "init_time" and lead times "lead_time".
668
+
669
+ Args:
670
+ pred (:class:`~.model.pred.Prediction`):
671
+ Prediction object to add valid time coordinate to.
672
+
673
+ Returns:
674
+ :class:`~.model.pred.Prediction`:
675
+ Prediction object with valid time coordinate added.
676
+ """
677
+ for var_ID in pred.keys():
678
+ if isinstance(pred[var_ID], pd.DataFrame):
679
+ x = pred[var_ID].reset_index()
680
+ pred[var_ID]["time"] = (x["lead_time"] + x["init_time"]).values
681
+ print(f"{x}")
682
+ print(f"{x.dtypes}")
683
+ elif isinstance(pred[var_ID], xr.Dataset):
684
+ x = pred[var_ID]
685
+ pred[var_ID] = pred[var_ID].assign_coords(
686
+ time=x["lead_time"] + x["init_time"]
687
+ )
688
+ else:
689
+ raise ValueError(f"Unsupported prediction type {type(pred[var_ID])}.")
690
+ return pred
691
+
692
+
624
693
  def main(): # pragma: no cover
625
694
  import deepsensor.tensorflow
626
695
  from deepsensor.data.loader import TaskLoader
@@ -4,6 +4,8 @@ import numpy as np
4
4
  import pandas as pd
5
5
  import xarray as xr
6
6
 
7
+ Timestamp = Union[str, pd.Timestamp, np.datetime64]
8
+
7
9
 
8
10
  class Prediction(dict):
9
11
  """
@@ -32,13 +34,20 @@ class Prediction(dict):
32
34
  n_samples (int)
33
35
  Number of joint samples to draw from the model. If 0, will not
34
36
  draw samples. Default 0.
37
+ forecasting_mode (bool)
38
+ If True, stores forecast predictions with an init_time and lead_time dimension,
39
+ and a valid_time coordinate. If False, stores prediction at t=0 only
40
+ (i.e. spatial interpolation), with only a single time dimension. Default False.
41
+ lead_times (List[pd.Timedelta], optional)
42
+ List of lead times to store in predictions. Must be provided if
43
+ forecasting_mode is True. Default None.
35
44
  """
36
45
 
37
46
  def __init__(
38
47
  self,
39
48
  target_var_IDs: List[str],
40
49
  pred_params: List[str],
41
- dates: List[Union[str, pd.Timestamp]],
50
+ dates: List[Timestamp],
42
51
  X_t: Union[
43
52
  xr.Dataset,
44
53
  xr.DataArray,
@@ -50,6 +59,8 @@ class Prediction(dict):
50
59
  X_t_mask: Optional[Union[xr.Dataset, xr.DataArray]] = None,
51
60
  coord_names: dict = None,
52
61
  n_samples: int = 0,
62
+ forecasting_mode: bool = False,
63
+ lead_times: Optional[List[pd.Timedelta]] = None,
53
64
  ):
54
65
  self.target_var_IDs = target_var_IDs
55
66
  self.X_t_mask = X_t_mask
@@ -58,6 +69,13 @@ class Prediction(dict):
58
69
  self.x1_name = coord_names["x1"]
59
70
  self.x2_name = coord_names["x2"]
60
71
 
72
+ self.forecasting_mode = forecasting_mode
73
+ if forecasting_mode:
74
+ assert (
75
+ lead_times is not None
76
+ ), "If forecasting_mode is True, lead_times must be provided."
77
+ self.lead_times = lead_times
78
+
61
79
  self.mode = infer_prediction_modality_from_X_t(X_t)
62
80
 
63
81
  self.pred_params = pred_params
@@ -67,15 +85,25 @@ class Prediction(dict):
67
85
  *[f"sample_{i}" for i in range(n_samples)],
68
86
  ]
69
87
 
88
+ # Create empty xarray/pandas objects to store predictions
70
89
  if self.mode == "on-grid":
71
90
  for var_ID in self.target_var_IDs:
72
- # Create empty xarray/pandas objects to store predictions
91
+ if self.forecasting_mode:
92
+ prepend_dims = ["lead_time"]
93
+ prepend_coords = {"lead_time": lead_times}
94
+ else:
95
+ prepend_dims = None
96
+ prepend_coords = None
73
97
  self[var_ID] = create_empty_spatiotemporal_xarray(
74
98
  X_t,
75
99
  dates,
76
100
  data_vars=self.pred_params,
77
101
  coord_names=coord_names,
102
+ prepend_dims=prepend_dims,
103
+ prepend_coords=prepend_coords,
78
104
  )
105
+ if self.forecasting_mode:
106
+ self[var_ID] = self[var_ID].rename(time="init_time")
79
107
  if self.X_t_mask is None:
80
108
  # Create 2D boolean array of True values to simplify indexing
81
109
  self.X_t_mask = (
@@ -86,8 +114,18 @@ class Prediction(dict):
86
114
  )
87
115
  elif self.mode == "off-grid":
88
116
  # Repeat target locs for each date to create multiindex
89
- idxs = [(date, *idxs) for date in dates for idxs in X_t.index]
90
- index = pd.MultiIndex.from_tuples(idxs, names=["time", *X_t.index.names])
117
+ if self.forecasting_mode:
118
+ index_names = ["lead_time", "init_time", *X_t.index.names]
119
+ idxs = [
120
+ (lt, date, *idxs)
121
+ for lt in lead_times
122
+ for date in dates
123
+ for idxs in X_t.index
124
+ ]
125
+ else:
126
+ index_names = ["time", *X_t.index.names]
127
+ idxs = [(date, *idxs) for date in dates for idxs in X_t.index]
128
+ index = pd.MultiIndex.from_tuples(idxs, names=index_names)
91
129
  for var_ID in self.target_var_IDs:
92
130
  self[var_ID] = pd.DataFrame(index=index, columns=self.pred_params)
93
131
 
@@ -106,6 +144,7 @@ class Prediction(dict):
106
144
  prediction_parameter: str,
107
145
  date: Union[str, pd.Timestamp],
108
146
  data: np.ndarray,
147
+ lead_times: Optional[List[pd.Timedelta]] = None,
109
148
  ):
110
149
  """
111
150
 
@@ -117,11 +156,29 @@ class Prediction(dict):
117
156
  data (np.ndarray)
118
157
  If off-grid: Shape (N_var, N_targets) or (N_samples, N_var, N_targets).
119
158
  If on-grid: Shape (N_var, N_x1, N_x2) or (N_samples, N_var, N_x1, N_x2).
159
+ lead_times (List[pd.Timedelta], optional)
160
+ Lead times of the forecast, one per variable in data. Required if forecasting_mode is True. Default None.
120
161
  """
162
+ if self.forecasting_mode:
163
+ assert (
164
+ lead_times is not None
165
+ ), "If forecasting_mode is True, lead_times must be provided."
166
+
167
+ msg = f"""
168
+ If forecasting_mode is True, lead_times must be of equal length to the number of
169
+ variables in the data (the first dimension). Got {lead_times=} of length
170
+ {len(lead_times)} lead times and data shape {data.shape}.
171
+ """
172
+ assert len(lead_times) == data.shape[0], msg
173
+
121
174
  if self.mode == "on-grid":
122
175
  if prediction_parameter != "samples":
123
- for var_ID, pred in zip(self.target_var_IDs, data):
124
- self[var_ID][prediction_parameter].loc[date].data[
176
+ for i, (var_ID, pred) in enumerate(zip(self.target_var_IDs, data)):
177
+ if self.forecasting_mode:
178
+ index = (lead_times[i], date)
179
+ else:
180
+ index = date
181
+ self[var_ID][prediction_parameter].loc[index].data[
125
182
  self.X_t_mask.data
126
183
  ] = pred.ravel()
127
184
  elif prediction_parameter == "samples":
@@ -130,28 +187,44 @@ class Prediction(dict):
130
187
  f"have shape (N_samples, N_var, N_x1, N_x2). Got {data.shape}."
131
188
  )
132
189
  for sample_i, sample in enumerate(data):
133
- for var_ID, pred in zip(self.target_var_IDs, sample):
134
- self[var_ID][f"sample_{sample_i}"].loc[date].data[
190
+ for i, (var_ID, pred) in enumerate(
191
+ zip(self.target_var_IDs, sample)
192
+ ):
193
+ if self.forecasting_mode:
194
+ index = (lead_times[i], date)
195
+ else:
196
+ index = date
197
+ self[var_ID][f"sample_{sample_i}"].loc[index].data[
135
198
  self.X_t_mask.data
136
199
  ] = pred.ravel()
137
200
 
138
201
  elif self.mode == "off-grid":
139
202
  if prediction_parameter != "samples":
140
- for var_ID, pred in zip(self.target_var_IDs, data):
141
- self[var_ID][prediction_parameter].loc[date] = pred
203
+ for i, (var_ID, pred) in enumerate(zip(self.target_var_IDs, data)):
204
+ if self.forecasting_mode:
205
+ index = (lead_times[i], date)
206
+ else:
207
+ index = date
208
+ self[var_ID][prediction_parameter].loc[index] = pred
142
209
  elif prediction_parameter == "samples":
143
210
  assert len(data.shape) == 3, (
144
211
  f"If prediction_parameter is 'samples', and mode is 'off-grid', data must"
145
212
  f"have shape (N_samples, N_var, N_targets). Got {data.shape}."
146
213
  )
147
214
  for sample_i, sample in enumerate(data):
148
- for var_ID, pred in zip(self.target_var_IDs, sample):
149
- self[var_ID][f"sample_{sample_i}"].loc[date] = pred
215
+ for i, (var_ID, pred) in enumerate(
216
+ zip(self.target_var_IDs, sample)
217
+ ):
218
+ if self.forecasting_mode:
219
+ index = (lead_times[i], date)
220
+ else:
221
+ index = date
222
+ self[var_ID][f"sample_{sample_i}"].loc[index] = pred
150
223
 
151
224
 
152
225
  def create_empty_spatiotemporal_xarray(
153
226
  X: Union[xr.Dataset, xr.DataArray],
154
- dates: List,
227
+ dates: List[Timestamp],
155
228
  coord_names: dict = None,
156
229
  data_vars: List[str] = None,
157
230
  prepend_dims: Optional[List[str]] = None,
@@ -231,10 +304,6 @@ def create_empty_spatiotemporal_xarray(
231
304
  # Convert time coord to pandas timestamps
232
305
  pred_ds = pred_ds.assign_coords(time=pd.to_datetime(pred_ds.time.values))
233
306
 
234
- # TODO: Convert init time to forecast time?
235
- # pred_ds = pred_ds.assign_coords(
236
- # time=pred_ds['time'] + pd.Timedelta(days=task_loader.target_delta_t[0]))
237
-
238
307
  return pred_ds
239
308
 
240
309
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: deepsensor
3
- Version: 0.3.8
3
+ Version: 0.4.0
4
4
  Summary: A Python package for modelling xarray and pandas data with neural processes.
5
5
  Home-page: https://github.com/alan-turing-institute/deepsensor
6
6
  Author: Tom R. Andersson
@@ -44,7 +44,7 @@ data with neural processes</p>
44
44
 
45
45
  -----------
46
46
 
47
- [![release](https://img.shields.io/badge/release-v0.3.8-green?logo=github)](https://github.com/alan-turing-institute/deepsensor/releases)
47
+ [![release](https://img.shields.io/badge/release-v0.4.0-green?logo=github)](https://github.com/alan-turing-institute/deepsensor/releases)
48
48
  [![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://alan-turing-institute.github.io/deepsensor/)
49
49
  ![Tests](https://github.com/alan-turing-institute/deepsensor/actions/workflows/tests.yml/badge.svg)
50
50
  [![Coverage Status](https://coveralls.io/repos/github/alan-turing-institute/deepsensor/badge.svg?branch=main)](https://coveralls.io/github/alan-turing-institute/deepsensor?branch=main)
@@ -240,7 +240,7 @@ if you would like to join this list!
240
240
  <table>
241
241
  <tbody>
242
242
  <tr>
243
- <td align="center" valign="top" width="14.28%"><a href="https://github.com/acocac"><img src="https://avatars.githubusercontent.com/u/13321552?v=4?s=100" width="100px;" alt="Alejandro ©"/><br /><sub><b>Alejandro ©</b></sub></a><br /><a href="#userTesting-acocac" title="User Testing">📓</a> <a href="#bug-acocac" title="Bug reports">🐛</a> <a href="#mentoring-acocac" title="Mentoring">🧑‍🏫</a> <a href="#ideas-acocac" title="Ideas, Planning, & Feedback">🤔</a> <a href="#research-acocac" title="Research">🔬</a></td>
243
+ <td align="center" valign="top" width="14.28%"><a href="https://github.com/acocac"><img src="https://avatars.githubusercontent.com/u/13321552?v=4?s=100" width="100px;" alt="Alejandro ©"/><br /><sub><b>Alejandro ©</b></sub></a><br /><a href="#userTesting-acocac" title="User Testing">📓</a> <a href="#bug-acocac" title="Bug reports">🐛</a> <a href="#mentoring-acocac" title="Mentoring">🧑‍🏫</a> <a href="#ideas-acocac" title="Ideas, Planning, & Feedback">🤔</a> <a href="#research-acocac" title="Research">🔬</a> <a href="#code-acocac" title="Code">💻</a> <a href="#test-acocac" title="Tests">⚠️</a></td>
244
244
  <td align="center" valign="top" width="14.28%"><a href="https://github.com/annavaughan"><img src="https://avatars.githubusercontent.com/u/45528489?v=4?s=100" width="100px;" alt="Anna Vaughan"/><br /><sub><b>Anna Vaughan</b></sub></a><br /><a href="#research-annavaughan" title="Research">🔬</a></td>
245
245
  <td align="center" valign="top" width="14.28%"><a href="http://davidwilby.dev"><img src="https://avatars.githubusercontent.com/u/24752124?v=4?s=100" width="100px;" alt="David Wilby"/><br /><sub><b>David Wilby</b></sub></a><br /><a href="#doc-davidwilby" title="Documentation">📖</a> <a href="#test-davidwilby" title="Tests">⚠️</a> <a href="#maintenance-davidwilby" title="Maintenance">🚧</a></td>
246
246
  <td align="center" valign="top" width="14.28%"><a href="http://inconsistentrecords.co.uk"><img src="https://avatars.githubusercontent.com/u/731727?v=4?s=100" width="100px;" alt="Jim Circadian"/><br /><sub><b>Jim Circadian</b></sub></a><br /><a href="#ideas-JimCircadian" title="Ideas, Planning, & Feedback">🤔</a> <a href="#projectManagement-JimCircadian" title="Project Management">📆</a> <a href="#maintenance-JimCircadian" title="Maintenance">🚧</a></td>
@@ -22,6 +22,8 @@ deepsensor/data/processor.py
22
22
  deepsensor/data/sources.py
23
23
  deepsensor/data/task.py
24
24
  deepsensor/data/utils.py
25
+ deepsensor/eval/__init__.py
26
+ deepsensor/eval/metrics.py
25
27
  deepsensor/model/__init__.py
26
28
  deepsensor/model/convnp.py
27
29
  deepsensor/model/defaults.py
@@ -1,6 +1,6 @@
1
1
  [metadata]
2
2
  name = deepsensor
3
- version = 0.3.8
3
+ version = 0.4.0
4
4
  author = Tom R. Andersson
5
5
  author_email = tomandersson3@gmail.com
6
6
  description = A Python package for modelling xarray and pandas data with neural processes.
@@ -18,6 +18,7 @@ from deepsensor.data.processor import DataProcessor
18
18
  from deepsensor.data.loader import TaskLoader
19
19
  from deepsensor.model.convnp import ConvNP
20
20
  from deepsensor.train.train import Trainer
21
+ from deepsensor.eval.metrics import compute_errors
21
22
 
22
23
  from tests.utils import gen_random_data_xr, gen_random_data_pandas
23
24
 
@@ -55,8 +56,15 @@ class TestModel(unittest.TestCase):
55
56
  def setUpClass(cls):
56
57
  # super().__init__(*args, **kwargs)
57
58
  # It's safe to share data between tests because the TaskLoader does not modify data
59
+ cls.var_ID = "2m_temp"
58
60
  cls.da = _gen_data_xr()
61
+ cls.da.name = cls.var_ID
59
62
  cls.df = _gen_data_pandas()
63
+ cls.df.name = cls.var_ID
64
+ # Various tests assume we have a single target set with a single variable.
65
+ # If a test requires multiple target sets or variables, this is set up in the test.
66
+ assert isinstance(cls.da, xr.DataArray)
67
+ assert isinstance(cls.df, pd.Series)
60
68
 
61
69
  cls.dp = DataProcessor()
62
70
  _ = cls.dp([cls.da, cls.df]) # Compute normalisation parameters
@@ -417,10 +425,10 @@ class TestModel(unittest.TestCase):
417
425
  task = tl("2020-01-01")
418
426
  pred = model.predict(task, X_t=da_raw)
419
427
 
420
- assert np.array_equal(
428
+ np.testing.assert_array_equal(
421
429
  pred["dummy_data"]["mean"]["latitude"], da_raw["latitude"]
422
430
  )
423
- assert np.array_equal(
431
+ np.testing.assert_array_equal(
424
432
  pred["dummy_data"]["mean"]["longitude"], da_raw["longitude"]
425
433
  )
426
434
 
@@ -493,14 +501,14 @@ class TestModel(unittest.TestCase):
493
501
  # Check that nothing breaks and the correct parameters are returned
494
502
  pred = model.predict(task, X_t=X_t, pred_params=pred_params)
495
503
  for pred_param in pred_params:
496
- assert pred_param in pred["var"]
504
+ assert pred_param in pred[self.var_ID]
497
505
 
498
506
  # Test mixture probs special case
499
507
  pred_params = ["mixture_probs"]
500
508
  pred = model.predict(task, X_t=self.da, pred_params=pred_params)
501
509
  for component in range(model.N_mixture_components):
502
510
  pred_param = f"mixture_probs_{component}"
503
- assert pred_param in pred["var"]
511
+ assert pred_param in pred[self.var_ID]
504
512
 
505
513
  def test_highlevel_predict_with_pred_params_xarray(self):
506
514
  """
@@ -528,14 +536,14 @@ class TestModel(unittest.TestCase):
528
536
  # Check that nothing breaks and the correct parameters are returned
529
537
  pred = model.predict(task, X_t=self.da, pred_params=pred_params)
530
538
  for pred_param in pred_params:
531
- assert pred_param in pred["var"]
539
+ assert pred_param in pred[self.var_ID]
532
540
 
533
541
  # Test mixture probs special case
534
542
  pred_params = ["mixture_probs"]
535
543
  pred = model.predict(task, X_t=self.da, pred_params=pred_params)
536
544
  for component in range(model.N_mixture_components):
537
545
  pred_param = f"mixture_probs_{component}"
538
- assert pred_param in pred["var"]
546
+ assert pred_param in pred[self.var_ID]
539
547
 
540
548
  def test_highlevel_predict_with_invalid_pred_params(self):
541
549
  """Test that passing ``pred_params`` to ``.predict`` works."""
@@ -640,6 +648,66 @@ class TestModel(unittest.TestCase):
640
648
  ar_sample=True,
641
649
  )
642
650
 
651
+ def test_forecasting_model_predict_return_valid_times(self):
652
+ """Test that the times returned by a forecasting model are valid."""
653
+ init_dates = ["2020-01-01", "2020-01-02"]
654
+ expected_init_times = np.array(init_dates).astype(np.datetime64)
655
+
656
+ lead_times_days = [1, 2, 3]
657
+ expected_lead_times = np.array(
658
+ [np.timedelta64(lt, "D") for lt in lead_times_days]
659
+ )
660
+
661
+ expected_valid_times = np.array(
662
+ expected_lead_times[:, None] + expected_init_times[None, :]
663
+ )
664
+
665
+ tl = TaskLoader(
666
+ context=self.da,
667
+ target=[
668
+ self.da,
669
+ ]
670
+ * len(lead_times_days),
671
+ target_delta_t=lead_times_days,
672
+ time_freq="D",
673
+ )
674
+ model = ConvNP(self.dp, tl, unet_channels=(5, 5, 5), verbose=False)
675
+ tasks = tl(init_dates, context_sampling=10)
676
+
677
+ X_ts = [
678
+ # Gridded predictions (xarray)
679
+ self.da,
680
+ # Off-grid prediction (pandas)
681
+ np.array([[0.0, 0.5, 1.0], [0.0, 0.5, 1.0]]),
682
+ ]
683
+ for X_t in X_ts:
684
+ pred = model.predict(tasks, X_t=X_t)
685
+
686
+ pred_var = pred[self.var_ID]
687
+
688
+ if isinstance(pred_var, xr.Dataset):
689
+ # Check we can compute errors using the valid time coord ('time')
690
+ errors = compute_errors(pred, self.da.to_dataset())
691
+ for var_ID in errors.keys():
692
+ assert tuple(errors[var_ID].dims) == (
693
+ "lead_time",
694
+ "init_time",
695
+ "x1",
696
+ "x2",
697
+ )
698
+ assert errors[var_ID].shape == pred[var_ID]["mean"].shape
699
+ elif isinstance(pred_var, pd.DataFrame):
700
+ # Makes coordinate checking easier by avoiding repeat values
701
+ pred_var = pred_var.to_xarray().isel(x1=0, x2=0)
702
+
703
+ np.testing.assert_array_equal(
704
+ pred_var.lead_time.values, expected_lead_times
705
+ )
706
+ np.testing.assert_array_equal(
707
+ pred_var.init_time.values, expected_init_times
708
+ )
709
+ np.testing.assert_array_equal(pred_var.time.values, expected_valid_times)
710
+
643
711
 
644
712
  def assert_shape(x, shape: tuple):
645
713
  """Assert that the shape of ``x`` matches ``shape``."""
@@ -192,6 +192,23 @@ class TestTaskLoader(unittest.TestCase):
192
192
  with self.assertRaises(InvalidSamplingStrategyError):
193
193
  task = tl("2020-01-01", invalid_sampling_strategy)
194
194
 
195
+ def test_different_dtype_when_sampling_offgrid_data_at_specific_numpy_locs(self):
196
+ """Test different dtype when sampling off-grid data at specific numpy locations."""
197
+ sampling_strat = np.array(
198
+ [np.linspace(0, 1, 10), np.linspace(0, 1, 10)], dtype=np.float16
199
+ )
200
+
201
+ tl = TaskLoader(
202
+ context=self.df,
203
+ target=self.df,
204
+ )
205
+
206
+ assert sampling_strat.dtype != tl.context[0].index.get_level_values("x1").dtype
207
+ assert sampling_strat.dtype != tl.context[0].index.get_level_values("x2").dtype
208
+
209
+ with self.assertRaises(InvalidSamplingStrategyError):
210
+ task = tl("2020-01-01", sampling_strat, sampling_strat)
211
+
195
212
  def test_wrong_links(self):
196
213
  """Test link indexes out of range."""
197
214
  with self.assertRaises(ValueError):
File without changes
File without changes
File without changes
File without changes