deepsensor 0.3.7__tar.gz → 0.4.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. {deepsensor-0.3.7 → deepsensor-0.4.0}/PKG-INFO +3 -3
  2. {deepsensor-0.3.7 → deepsensor-0.4.0}/README.md +2 -2
  3. {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor/data/loader.py +17 -13
  4. {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor/data/processor.py +21 -26
  5. {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor/data/sources.py +3 -3
  6. deepsensor-0.4.0/deepsensor/eval/__init__.py +1 -0
  7. deepsensor-0.4.0/deepsensor/eval/metrics.py +24 -0
  8. {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor/model/convnp.py +78 -4
  9. {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor/model/model.py +74 -5
  10. {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor/model/pred.py +86 -17
  11. {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor.egg-info/PKG-INFO +3 -3
  12. {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor.egg-info/SOURCES.txt +2 -0
  13. {deepsensor-0.3.7 → deepsensor-0.4.0}/setup.cfg +1 -1
  14. {deepsensor-0.3.7 → deepsensor-0.4.0}/tests/test_active_learning.py +0 -3
  15. {deepsensor-0.3.7 → deepsensor-0.4.0}/tests/test_model.py +143 -45
  16. {deepsensor-0.3.7 → deepsensor-0.4.0}/tests/test_task_loader.py +17 -0
  17. {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor/__init__.py +0 -0
  18. {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor/active_learning/__init__.py +0 -0
  19. {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor/active_learning/acquisition_fns.py +0 -0
  20. {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor/active_learning/algorithms.py +0 -0
  21. {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor/config.py +0 -0
  22. {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor/data/__init__.py +0 -0
  23. {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor/data/task.py +0 -0
  24. {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor/data/utils.py +0 -0
  25. {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor/errors.py +0 -0
  26. {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor/model/__init__.py +0 -0
  27. {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor/model/defaults.py +0 -0
  28. {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor/model/nps.py +0 -0
  29. {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor/plot.py +0 -0
  30. {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor/py.typed +0 -0
  31. {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor/tensorflow/__init__.py +0 -0
  32. {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor/torch/__init__.py +0 -0
  33. {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor/train/__init__.py +0 -0
  34. {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor/train/train.py +0 -0
  35. {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor.egg-info/dependency_links.txt +0 -0
  36. {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor.egg-info/not-zip-safe +0 -0
  37. {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor.egg-info/requires.txt +0 -0
  38. {deepsensor-0.3.7 → deepsensor-0.4.0}/deepsensor.egg-info/top_level.txt +0 -0
  39. {deepsensor-0.3.7 → deepsensor-0.4.0}/pyproject.toml +0 -0
  40. {deepsensor-0.3.7 → deepsensor-0.4.0}/setup.py +0 -0
  41. {deepsensor-0.3.7 → deepsensor-0.4.0}/tests/__init__.py +0 -0
  42. {deepsensor-0.3.7 → deepsensor-0.4.0}/tests/test_data_processor.py +0 -0
  43. {deepsensor-0.3.7 → deepsensor-0.4.0}/tests/test_plotting.py +0 -0
  44. {deepsensor-0.3.7 → deepsensor-0.4.0}/tests/test_task.py +0 -0
  45. {deepsensor-0.3.7 → deepsensor-0.4.0}/tests/test_training.py +0 -0
  46. {deepsensor-0.3.7 → deepsensor-0.4.0}/tests/utils.py +0 -0
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: deepsensor
- Version: 0.3.7
+ Version: 0.4.0
  Summary: A Python package for modelling xarray and pandas data with neural processes.
  Home-page: https://github.com/alan-turing-institute/deepsensor
  Author: Tom R. Andersson
@@ -44,7 +44,7 @@ data with neural processes</p>

  -----------

- [![release](https://img.shields.io/badge/release-v0.3.7-green?logo=github)](https://github.com/alan-turing-institute/deepsensor/releases)
+ [![release](https://img.shields.io/badge/release-v0.4.0-green?logo=github)](https://github.com/alan-turing-institute/deepsensor/releases)
  [![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://alan-turing-institute.github.io/deepsensor/)
  ![Tests](https://github.com/alan-turing-institute/deepsensor/actions/workflows/tests.yml/badge.svg)
  [![Coverage Status](https://coveralls.io/repos/github/alan-turing-institute/deepsensor/badge.svg?branch=main)](https://coveralls.io/github/alan-turing-institute/deepsensor?branch=main)
@@ -240,7 +240,7 @@ if you would like to join this list!
  <table>
  <tbody>
  <tr>
- <td align="center" valign="top" width="14.28%"><a href="https://github.com/acocac"><img src="https://avatars.githubusercontent.com/u/13321552?v=4?s=100" width="100px;" alt="Alejandro ©"/><br /><sub><b>Alejandro ©</b></sub></a><br /><a href="#userTesting-acocac" title="User Testing">📓</a> <a href="#bug-acocac" title="Bug reports">🐛</a> <a href="#mentoring-acocac" title="Mentoring">🧑‍🏫</a> <a href="#ideas-acocac" title="Ideas, Planning, & Feedback">🤔</a> <a href="#research-acocac" title="Research">🔬</a></td>
+ <td align="center" valign="top" width="14.28%"><a href="https://github.com/acocac"><img src="https://avatars.githubusercontent.com/u/13321552?v=4?s=100" width="100px;" alt="Alejandro ©"/><br /><sub><b>Alejandro ©</b></sub></a><br /><a href="#userTesting-acocac" title="User Testing">📓</a> <a href="#bug-acocac" title="Bug reports">🐛</a> <a href="#mentoring-acocac" title="Mentoring">🧑‍🏫</a> <a href="#ideas-acocac" title="Ideas, Planning, & Feedback">🤔</a> <a href="#research-acocac" title="Research">🔬</a> <a href="#code-acocac" title="Code">💻</a> <a href="#test-acocac" title="Tests">⚠️</a></td>
  <td align="center" valign="top" width="14.28%"><a href="https://github.com/annavaughan"><img src="https://avatars.githubusercontent.com/u/45528489?v=4?s=100" width="100px;" alt="Anna Vaughan"/><br /><sub><b>Anna Vaughan</b></sub></a><br /><a href="#research-annavaughan" title="Research">🔬</a></td>
  <td align="center" valign="top" width="14.28%"><a href="http://davidwilby.dev"><img src="https://avatars.githubusercontent.com/u/24752124?v=4?s=100" width="100px;" alt="David Wilby"/><br /><sub><b>David Wilby</b></sub></a><br /><a href="#doc-davidwilby" title="Documentation">📖</a> <a href="#test-davidwilby" title="Tests">⚠️</a> <a href="#maintenance-davidwilby" title="Maintenance">🚧</a></td>
  <td align="center" valign="top" width="14.28%"><a href="http://inconsistentrecords.co.uk"><img src="https://avatars.githubusercontent.com/u/731727?v=4?s=100" width="100px;" alt="Jim Circadian"/><br /><sub><b>Jim Circadian</b></sub></a><br /><a href="#ideas-JimCircadian" title="Ideas, Planning, & Feedback">🤔</a> <a href="#projectManagement-JimCircadian" title="Project Management">📆</a> <a href="#maintenance-JimCircadian" title="Maintenance">🚧</a></td>
@@ -11,7 +11,7 @@ data with neural processes</p>

  -----------

- [![release](https://img.shields.io/badge/release-v0.3.7-green?logo=github)](https://github.com/alan-turing-institute/deepsensor/releases)
+ [![release](https://img.shields.io/badge/release-v0.4.0-green?logo=github)](https://github.com/alan-turing-institute/deepsensor/releases)
  [![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://alan-turing-institute.github.io/deepsensor/)
  ![Tests](https://github.com/alan-turing-institute/deepsensor/actions/workflows/tests.yml/badge.svg)
  [![Coverage Status](https://coveralls.io/repos/github/alan-turing-institute/deepsensor/badge.svg?branch=main)](https://coveralls.io/github/alan-turing-institute/deepsensor?branch=main)
@@ -207,7 +207,7 @@ if you would like to join this list!
  <table>
  <tbody>
  <tr>
- <td align="center" valign="top" width="14.28%"><a href="https://github.com/acocac"><img src="https://avatars.githubusercontent.com/u/13321552?v=4?s=100" width="100px;" alt="Alejandro ©"/><br /><sub><b>Alejandro ©</b></sub></a><br /><a href="#userTesting-acocac" title="User Testing">📓</a> <a href="#bug-acocac" title="Bug reports">🐛</a> <a href="#mentoring-acocac" title="Mentoring">🧑‍🏫</a> <a href="#ideas-acocac" title="Ideas, Planning, & Feedback">🤔</a> <a href="#research-acocac" title="Research">🔬</a></td>
+ <td align="center" valign="top" width="14.28%"><a href="https://github.com/acocac"><img src="https://avatars.githubusercontent.com/u/13321552?v=4?s=100" width="100px;" alt="Alejandro ©"/><br /><sub><b>Alejandro ©</b></sub></a><br /><a href="#userTesting-acocac" title="User Testing">📓</a> <a href="#bug-acocac" title="Bug reports">🐛</a> <a href="#mentoring-acocac" title="Mentoring">🧑‍🏫</a> <a href="#ideas-acocac" title="Ideas, Planning, & Feedback">🤔</a> <a href="#research-acocac" title="Research">🔬</a> <a href="#code-acocac" title="Code">💻</a> <a href="#test-acocac" title="Tests">⚠️</a></td>
  <td align="center" valign="top" width="14.28%"><a href="https://github.com/annavaughan"><img src="https://avatars.githubusercontent.com/u/45528489?v=4?s=100" width="100px;" alt="Anna Vaughan"/><br /><sub><b>Anna Vaughan</b></sub></a><br /><a href="#research-annavaughan" title="Research">🔬</a></td>
  <td align="center" valign="top" width="14.28%"><a href="http://davidwilby.dev"><img src="https://avatars.githubusercontent.com/u/24752124?v=4?s=100" width="100px;" alt="David Wilby"/><br /><sub><b>David Wilby</b></sub></a><br /><a href="#doc-davidwilby" title="Documentation">📖</a> <a href="#test-davidwilby" title="Tests">⚠️</a> <a href="#maintenance-davidwilby" title="Maintenance">🚧</a></td>
  <td align="center" valign="top" width="14.28%"><a href="http://inconsistentrecords.co.uk"><img src="https://avatars.githubusercontent.com/u/731727?v=4?s=100" width="100px;" alt="Jim Circadian"/><br /><sub><b>Jim Circadian</b></sub></a><br /><a href="#ideas-JimCircadian" title="Ideas, Planning, & Feedback">🤔</a> <a href="#projectManagement-JimCircadian" title="Project Management">📆</a> <a href="#maintenance-JimCircadian" title="Maintenance">🚧</a></td>
@@ -678,11 +678,11 @@ class TaskLoader:
  seed: Optional[int] = None,
  ) -> (np.ndarray, np.ndarray):
  """
- Sample a DataArray according to a given strategy.
+ Sample a DataFrame according to a given strategy.

  Args:
  df (:class:`pandas.DataFrame` | :class:`pandas.Series`):
- DataArray to sample, assumed to be time-sliced for the task
+ Dataframe to sample, assumed to be time-sliced for the task
  already.
  sampling_strat (str | int | float | :class:`numpy:numpy.ndarray`):
  Sampling strategy, either "all" or an integer for random grid
@@ -720,20 +720,24 @@ class TaskLoader:
  X_c = df.reset_index()[["x1", "x2"]].values.T.astype(self.dtype)
  Y_c = df.values.T
  elif isinstance(sampling_strat, np.ndarray):
+ if df.index.get_level_values("x1").dtype != sampling_strat.dtype:
+ raise InvalidSamplingStrategyError(
+ "Passed a numpy coordinate array to sample pandas DataFrame, "
+ "but the coordinate array has a different dtype than the DataFrame. "
+ f"Got {sampling_strat.dtype} but expected {df.index.get_level_values('x1').dtype}."
+ )
  X_c = sampling_strat.astype(self.dtype)
- x1match = np.in1d(df.index.get_level_values("x1"), X_c[0])
- x2match = np.in1d(df.index.get_level_values("x2"), X_c[1])
- num_matches = np.sum(x1match & x2match)
-
- # Check that we got all the samples we asked for
- if num_matches != X_c.shape[1]:
+ try:
+ Y_c = df.loc[pd.IndexSlice[:, X_c[0], X_c[1]]].values.T
+ except KeyError:
  raise InvalidSamplingStrategyError(
- f"Passed a numpy coordinate array to sample pandas DataFrame, "
- f"but the DataFrame did not contain all the requested samples. "
- f"Requested {X_c.shape[1]} samples but only got {num_matches}."
+ "Passed a numpy coordinate array to sample pandas DataFrame, "
+ "but the DataFrame did not contain all the requested samples.\n"
+ f"Indexes: {df.index}\n"
+ f"Sampling coords: {X_c}\n"
+ "If this is unexpected, check that your numpy sampling array matches "
+ "the DataFrame index values *exactly*."
  )
-
- Y_c = df[x1match & x2match].values.T
  else:
  raise InvalidSamplingStrategyError(
  f"Unknown sampling strategy {sampling_strat}"
@@ -97,7 +97,7 @@ class DataProcessor:
  self.verbose = verbose

  # List of valid normalisation method names
- self.valid_methods = ["mean_std", "min_max"]
+ self.valid_methods = ["mean_std", "min_max", "positive_semidefinite"]

  def save(self, folder: str):
  """Save DataProcessor config to JSON in `folder`"""
@@ -293,6 +293,8 @@ class DataProcessor:
  params = {"mean": float(data.mean()), "std": float(data.std())}
  elif method == "min_max":
  params = {"min": float(data.min()), "max": float(data.max())}
+ elif method == "positive_semidefinite":
+ params = {"min": float(data.min()), "std": float(data.std())}
  if self.verbose:
  print(f"Done. {var_ID} {method} params={params}")
  self.add_to_config(
@@ -498,33 +500,25 @@ class DataProcessor:

  params = self.get_config(var_ID, data, method)

+ # Linear transformation:
+ # - Inverse normalisation: y_unnorm = m * y_norm + c
+ # - Normalisation: y_norm = (1/m) * y_unnorm - c/m
  if method == "mean_std":
- std = params["std"]
- mean = params["mean"]
- if unnorm:
- scale = std
- offset = mean
- else:
- scale = 1 / std
- offset = -mean / std
- data = data * scale
- if add_offset:
- data = data + offset
- return data
-
+ m = params["std"]
+ c = params["mean"]
  elif method == "min_max":
- minimum = params["min"]
- maximum = params["max"]
- if unnorm:
- scale = (maximum - minimum) / 2
- offset = (maximum + minimum) / 2
- else:
- scale = 2 / (maximum - minimum)
- offset = -(maximum + minimum) / (maximum - minimum)
- data = data * scale
- if add_offset:
- data = data + offset
- return data
+ m = (params["max"] - params["min"]) / 2
+ c = (params["max"] + params["min"]) / 2
+ elif method == "positive_semidefinite":
+ m = params["std"]
+ c = params["min"]
+ if not unnorm:
+ c = -c / m
+ m = 1 / m
+ data = data * m
+ if add_offset:
+ data = data + c
+ return data

  def map(
  self,
@@ -610,6 +604,7 @@ class DataProcessor:
  method (str, optional): Normalisation method. Options include:
  - "mean_std": Normalise to mean=0 and std=1 (default)
  - "min_max": Normalise to min=-1 and max=1
+ - "positive_semidefinite": Normalise to min=0 and std=1

  Returns:
  :class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame` | List[:class:`xarray.DataArray` | :class:`xarray.Dataset` | :class:`pandas.DataFrame`]:
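
For reference, a self-contained sketch of the unified linear map that the refactor above implements; the helper name and the example values are illustrative, not part of the package:

```python
import numpy as np

def linear_norm_sketch(data, method, params, unnorm=False, add_offset=True):
    """Unnormalisation is y = m * x + c; normalisation inverts it: x = (1/m) * y - c/m."""
    if method == "mean_std":
        m, c = params["std"], params["mean"]
    elif method == "min_max":  # maps [min, max] onto [-1, 1]
        m = (params["max"] - params["min"]) / 2
        c = (params["max"] + params["min"]) / 2
    elif method == "positive_semidefinite":  # maps min onto 0, scales by std
        m, c = params["std"], params["min"]
    else:
        raise ValueError(f"Unknown method {method}")
    if not unnorm:
        c = -c / m
        m = 1 / m
    out = data * m
    return out + c if add_offset else out

x = np.array([273.0, 280.0, 290.0])
params = {"mean": float(x.mean()), "std": float(x.std())}
x_norm = linear_norm_sketch(x, "mean_std", params)
np.testing.assert_allclose(linear_norm_sketch(x_norm, "mean_std", params, unnorm=True), x)
```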
@@ -277,7 +277,7 @@ def get_era5_reanalysis_data(
  if num_processes == 1:
  # Just download in one go
  if verbose:
- print("Downloading ERA5 data in without parallelisation... ")
+ print("Downloading ERA5 data without parallelisation... ")
  era5_da = _get_era5_reanalysis_data_parallel(
  date_range=date_range,
  var_IDs=var_IDs,
@@ -432,7 +432,7 @@ def get_gldas_land_mask(
  with urllib.request.urlopen(req) as response:
  with open(fname, "wb") as f:
  f.write(response.read())
- da = xr.open_dataset(fname)["GLDAS_mask"].isel(time=0).drop("time").load()
+ da = xr.open_dataset(fname)["GLDAS_mask"].isel(time=0).drop_vars("time").load()

  if isinstance(extent, str):
  extent = extent_str_to_tuple(extent)
@@ -577,7 +577,7 @@ def get_earthenv_auxiliary_data(
  # Read data
  da = xr.open_dataset(fname).to_array().squeeze().load()
  da = da.rename({"y": "lat", "x": "lon"})
- da = da.drop(["band", "spatial_ref", "variable"])
+ da = da.drop_vars(["band", "spatial_ref", "variable"])
  da.name = var_ID
  da = da.sel(lat=slice(lat_max, lat_min), lon=slice(lon_min, lon_max))
  da_dict[var_ID] = da
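
These two changes swap the deprecated xarray ``.drop`` method for the explicit ``.drop_vars``. A minimal sketch of the pattern, using synthetic data rather than anything from the package:

```python
import numpy as np
import xarray as xr

da = xr.DataArray(
    np.zeros((1, 2, 2)),
    coords={"time": ["2020-01-01"], "lat": [0.0, 1.0], "lon": [0.0, 1.0]},
    dims=("time", "lat", "lon"),
)
# Selecting a single time leaves a scalar "time" coordinate behind;
# drop_vars removes it explicitly (the old .drop call is deprecated in xarray).
da = da.isel(time=0).drop_vars("time")
assert "time" not in da.coords
```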
@@ -0,0 +1 @@
+ from .metrics import *
@@ -0,0 +1,24 @@
+ import xarray as xr
+ from deepsensor.model.pred import Prediction
+
+
+ def compute_errors(pred: Prediction, target: xr.Dataset) -> xr.Dataset:
+ """
+ Compute errors between predictions and targets.
+
+ Args:
+ pred: Prediction object.
+ target: Target data.
+
+ Returns:
+ xr.Dataset: Dataset of pointwise differences between predictions and targets
+ at the same valid time in the predictions. Note, the difference is positive
+ when the prediction is greater than the target.
+ """
+ errors = {}
+ for var_ID, pred_var in pred.items():
+ target_var = target[var_ID]
+ error = pred_var["mean"] - target_var.sel(time=pred_var.time)
+ error.name = f"{var_ID}"
+ errors[var_ID] = error
+ return xr.Dataset(errors)
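
A hedged usage sketch of the new metric, assuming deepsensor 0.4.0 is installed; the variable name ``2m_temp`` and the random data are made up, and a plain dict of Datasets stands in for a real ``Prediction`` (which subclasses ``dict``):

```python
import numpy as np
import pandas as pd
import xarray as xr
from deepsensor.eval.metrics import compute_errors

times = pd.date_range("2020-01-01", periods=3)
x1 = np.linspace(0, 1, 4)
x2 = np.linspace(0, 1, 4)

# Stand-in for one gridded Prediction entry: a Dataset with a "mean" variable.
pred_var = xr.Dataset(
    {"mean": (("time", "x1", "x2"), np.random.rand(3, 4, 4))},
    coords={"time": times, "x1": x1, "x2": x2},
)
target = xr.Dataset(
    {"2m_temp": (("time", "x1", "x2"), np.random.rand(3, 4, 4))},
    coords={"time": times, "x1": x1, "x2": x2},
)

errors = compute_errors({"2m_temp": pred_var}, target)
print(float(errors["2m_temp"].mean()))  # mean signed error (prediction minus target)
```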
@@ -539,10 +539,10 @@ class ConvNP(DeepSensorModel):
  def alpha(
  self, dist: AbstractMultiOutputDistribution
  ) -> Union[np.ndarray, List[np.ndarray]]:
- if self.config["likelihood"] not in ["spikes-beta", "bernoulli-gamma"]:
+ if self.config["likelihood"] not in ["spikes-beta"]:
  raise NotImplementedError(
  f"ConvNP.alpha method not supported for likelihood {self.config['likelihood']}. "
- f"Try changing the likelihood to a mixture model, e.g. 'spikes-beta' or 'bernoulli-gamma'."
+ f"Valid likelihoods: 'spikes-beta'."
  )
  alpha = dist.slab.alpha
  alpha = self._cast_numpy_and_squeeze(alpha)
@@ -576,10 +576,10 @@ class ConvNP(DeepSensorModel):
  def beta(
  self, dist: AbstractMultiOutputDistribution
  ) -> Union[np.ndarray, List[np.ndarray]]:
- if self.config["likelihood"] not in ["spikes-beta", "bernoulli-gamma"]:
+ if self.config["likelihood"] not in ["spikes-beta"]:
  raise NotImplementedError(
  f"ConvNP.beta method not supported for likelihood {self.config['likelihood']}. "
- f"Try changing the likelihood to a mixture model, e.g. 'spikes-beta' or 'bernoulli-gamma'."
+ f"Valid likelihoods: 'spikes-beta'."
  )
  beta = dist.slab.beta
  beta = self._cast_numpy_and_squeeze(beta)
@@ -608,6 +608,80 @@ class ConvNP(DeepSensorModel):
  dist = self(task)
  return self.beta(dist)

+ @dispatch
+ def k(
+ self, dist: AbstractMultiOutputDistribution
+ ) -> Union[np.ndarray, List[np.ndarray]]:
+ if self.config["likelihood"] not in ["bernoulli-gamma"]:
+ raise NotImplementedError(
+ f"ConvNP.k method not supported for likelihood {self.config['likelihood']}. "
+ f"Valid likelihoods: 'bernoulli-gamma'."
+ )
+ k = dist.slab.k
+ k = self._cast_numpy_and_squeeze(k)
+ return self._maybe_concat_multi_targets(k)
+
+ @dispatch
+ def k(self, task: Task) -> Union[np.ndarray, List[np.ndarray]]:
+ """
+ k parameter values of model's distribution at target locations in task.
+
+ Returned numpy arrays have shape ``(N_features, *N_targets)``.
+
+ .. note::
+ This method only works for models that return a distribution with
+ a ``dist.slab.k`` attribute, e.g. models with a Beta or
+ Bernoulli-Gamma likelihood, where it returns the k values of
+ the slab component of the mixture model.
+
+ Args:
+ task (:class:`~.data.task.Task`):
+ The task containing the context and target data.
+
+ Returns:
+ :class:`numpy:numpy.ndarray` | List[:class:`numpy:numpy.ndarray`]:
+ k values.
+ """
+ dist = self(task)
+ return self.k(dist)
+
+ @dispatch
+ def scale(
+ self, dist: AbstractMultiOutputDistribution
+ ) -> Union[np.ndarray, List[np.ndarray]]:
+ if self.config["likelihood"] not in ["bernoulli-gamma"]:
+ raise NotImplementedError(
+ f"ConvNP.scale method not supported for likelihood {self.config['likelihood']}. "
+ f"Valid likelihoods: 'bernoulli-gamma'."
+ )
+ scale = dist.slab.scale
+ scale = self._cast_numpy_and_squeeze(scale)
+ return self._maybe_concat_multi_targets(scale)
+
+ @dispatch
+ def scale(self, task: Task) -> Union[np.ndarray, List[np.ndarray]]:
+ """
+ Scale parameter values of model's distribution at target locations in task.
+
+ Returned numpy arrays have shape ``(N_features, *N_targets)``.
+
+ .. note::
+ This method only works for models that return a distribution with
+ a ``dist.slab.scale`` attribute, e.g. models with a Beta or
+ Bernoulli-Gamma likelihood, where it returns the scale values of
+ the slab component of the mixture model.
+
+ Args:
+ task (:class:`~.data.task.Task`):
+ The task containing the context and target data.
+
+ Returns:
+ :class:`numpy:numpy.ndarray` | List[:class:`numpy:numpy.ndarray`]:
+ Scale values.
+ """
+ dist = self(task)
+ return self.scale(dist)
+
  @dispatch
  def mixture_probs(self, dist: AbstractMultiOutputDistribution):
  if self.N_mixture_components == 1:
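
A hedged sketch of how the new ``k``/``scale`` accessors are meant to be used, mirroring the package's own tests; it assumes deepsensor 0.4.0 with a PyTorch backend, and all data and names here are random and illustrative:

```python
import numpy as np
import pandas as pd
import xarray as xr
import deepsensor.torch  # selects the backend; deepsensor.tensorflow also works
from deepsensor.data.processor import DataProcessor
from deepsensor.data.loader import TaskLoader
from deepsensor.model.convnp import ConvNP

# Small random gridded dataset with normalised coordinate names.
time = pd.date_range("2020-01-01", "2020-01-31")
x = np.linspace(0, 1, 30)
da = xr.DataArray(
    np.random.rand(len(time), 30, 30),
    coords={"time": time, "x1": x, "x2": x},
    dims=("time", "x1", "x2"),
    name="precip",
)

dp = DataProcessor()
_ = dp(da)  # compute normalisation parameters
tl = TaskLoader(context=da, target=da)
model = ConvNP(dp, tl, unet_channels=(5, 5, 5), likelihood="bernoulli-gamma", verbose=False)

task = tl("2020-01-15", context_sampling=10)
k = model.k(task)          # Gamma shape parameter of the slab component
scale = model.scale(task)  # Gamma scale parameter of the slab component
```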
@@ -348,6 +348,34 @@ class DeepSensorModel(ProbabilisticModel):
  if ar_sample and n_samples < 1:
  raise ValueError("Must pass `n_samples` > 0 to use `ar_sample`.")

+ target_delta_t = self.task_loader.target_delta_t
+ dts = [pd.Timedelta(dt) for dt in target_delta_t]
+ dts_all_zero = all([dt == pd.Timedelta(seconds=0) for dt in dts])
+ if target_delta_t is not None and dts_all_zero:
+ forecasting_mode = False
+ lead_times = None
+ elif target_delta_t is not None and not dts_all_zero:
+ target_var_IDs_set = set(self.task_loader.target_var_IDs)
+ msg = f"""
+ Got more than one set of target variables in target sets,
+ but predictions can only be made with one set of target variables
+ to simplify implementation.
+ Got {target_var_IDs_set}.
+ """
+ assert len(target_var_IDs_set) == 1, msg
+ # Repeat lead_time for each variable in each target set
+ lead_times = []
+ for target_set_idx, dt in enumerate(target_delta_t):
+ target_set_dim = self.task_loader.target_dims[target_set_idx]
+ lead_times += [
+ pd.Timedelta(dt, unit=self.task_loader.time_freq)
+ for _ in range(target_set_dim)
+ ]
+ forecasting_mode = True
+ else:
+ forecasting_mode = False
+ lead_times = None
+
  if type(tasks) is Task:
  tasks = [tasks]

@@ -355,12 +383,14 @@ class DeepSensorModel(ProbabilisticModel):
  B.set_random_seed(seed)
  np.random.seed(seed)

- dates = [task["time"] for task in tasks]
+ init_dates = [task["time"] for task in tasks]

  # Flatten tuple of tuples to single list
  target_var_IDs = [
  var_ID for set in self.task_loader.target_var_IDs for var_ID in set
  ]
+ if lead_times is not None:
+ assert len(lead_times) == len(target_var_IDs)

  # TODO consider removing this logic, can we just depend on the dim names in X_t?
  if not unnormalise:
@@ -385,7 +415,7 @@ class DeepSensorModel(ProbabilisticModel):
  elif isinstance(X_t, (xr.DataArray, xr.Dataset)):
  # Remove time dimension if present
  if "time" in X_t.coords:
- X_t = X_t.isel(time=0).drop("time")
+ X_t = X_t.isel(time=0).drop_vars("time")

  if mode == "off-grid" and append_indexes is not None:
  # Check append_indexes are all same length as X_t
@@ -450,11 +480,13 @@ class DeepSensorModel(ProbabilisticModel):
  pred = Prediction(
  target_var_IDs,
  pred_params_to_store,
- dates,
+ init_dates,
  X_t,
  X_t_mask,
  coord_names,
  n_samples=n_samples,
+ forecasting_mode=forecasting_mode,
+ lead_times=lead_times,
  )

  def unnormalise_pred_array(arr, **kwargs):
@@ -605,14 +637,22 @@ class DeepSensorModel(ProbabilisticModel):
  # Assign predictions to Prediction object
  for param, arr in prediction_arrs.items():
  if param != "mixture_probs":
- pred.assign(param, task["time"], arr)
+ pred.assign(param, task["time"], arr, lead_times=lead_times)
  elif param == "mixture_probs":
  assert arr.shape[0] == self.N_mixture_components, (
  f"Number of mixture components ({arr.shape[0]}) does not match "
  f"model attribute N_mixture_components ({self.N_mixture_components})."
  )
  for component_i, probs in enumerate(arr):
- pred.assign(f"{param}_{component_i}", task["time"], probs)
+ pred.assign(
+ f"{param}_{component_i}",
+ task["time"],
+ probs,
+ lead_times=lead_times,
+ )
+
+ if forecasting_mode:
+ pred = add_valid_time_coord_to_pred(pred)

  if verbose:
  dur = time.time() - tic
@@ -621,6 +661,35 @@ class DeepSensorModel(ProbabilisticModel):
  return pred


+ def add_valid_time_coord_to_pred(pred: Prediction) -> Prediction:
+ """
+ Add a valid time coordinate "time" to a Prediction object based on the
+ initialisation times "init_time" and lead times "lead_time".
+
+ Args:
+ pred (:class:`~.model.pred.Prediction`):
+ Prediction object to add valid time coordinate to.
+
+ Returns:
+ :class:`~.model.pred.Prediction`:
+ Prediction object with valid time coordinate added.
+ """
+ for var_ID in pred.keys():
+ if isinstance(pred[var_ID], pd.DataFrame):
+ x = pred[var_ID].reset_index()
+ pred[var_ID]["time"] = (x["lead_time"] + x["init_time"]).values
+ print(f"{x}")
+ print(f"{x.dtypes}")
+ elif isinstance(pred[var_ID], xr.Dataset):
+ x = pred[var_ID]
+ pred[var_ID] = pred[var_ID].assign_coords(
+ time=x["lead_time"] + x["init_time"]
+ )
+ else:
+ raise ValueError(f"Unsupported prediction type {type(pred[var_ID])}.")
+ return pred
+
+
  def main(): # pragma: no cover
  import deepsensor.tensorflow
  from deepsensor.data.loader import TaskLoader
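
For intuition, a small self-contained illustration of the valid-time arithmetic performed above (values are made up): the valid "time" coordinate is the outer sum of "lead_time" and "init_time".

```python
import numpy as np
import pandas as pd
import xarray as xr

init_time = pd.to_datetime(["2020-01-01", "2020-01-02"])
lead_time = pd.to_timedelta([1, 2, 3], unit="D")

ds = xr.Dataset(
    {"mean": (("lead_time", "init_time"), np.random.rand(3, 2))},
    coords={"lead_time": lead_time, "init_time": init_time},
)
# Broadcasting a timedelta coord against a datetime coord gives a 2D valid-time coord.
ds = ds.assign_coords(time=ds["lead_time"] + ds["init_time"])
print(ds["time"].values)  # shape (3, 2): one valid time per (lead_time, init_time) pair
```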
@@ -4,6 +4,8 @@ import numpy as np
  import pandas as pd
  import xarray as xr

+ Timestamp = Union[str, pd.Timestamp, np.datetime64]
+

  class Prediction(dict):
  """
@@ -32,13 +34,20 @@ class Prediction(dict):
  n_samples (int)
  Number of joint samples to draw from the model. If 0, will not
  draw samples. Default 0.
+ forecasting_mode (bool)
+ If True, stores forecast predictions with an init_time and lead_time dimension,
+ and a valid_time coordinate. If False, stores prediction at t=0 only
+ (i.e. spatial interpolation), with only a single time dimension. Default False.
+ lead_times (List[pd.Timedelta], optional)
+ List of lead times to store in predictions. Must be provided if
+ forecasting_mode is True. Default None.
  """

  def __init__(
  self,
  target_var_IDs: List[str],
  pred_params: List[str],
- dates: List[Union[str, pd.Timestamp]],
+ dates: List[Timestamp],
  X_t: Union[
  xr.Dataset,
  xr.DataArray,
@@ -50,6 +59,8 @@ class Prediction(dict):
  X_t_mask: Optional[Union[xr.Dataset, xr.DataArray]] = None,
  coord_names: dict = None,
  n_samples: int = 0,
+ forecasting_mode: bool = False,
+ lead_times: Optional[List[pd.Timedelta]] = None,
  ):
  self.target_var_IDs = target_var_IDs
  self.X_t_mask = X_t_mask
@@ -58,6 +69,13 @@ class Prediction(dict):
  self.x1_name = coord_names["x1"]
  self.x2_name = coord_names["x2"]

+ self.forecasting_mode = forecasting_mode
+ if forecasting_mode:
+ assert (
+ lead_times is not None
+ ), "If forecasting_mode is True, lead_times must be provided."
+ self.lead_times = lead_times
+
  self.mode = infer_prediction_modality_from_X_t(X_t)

  self.pred_params = pred_params
@@ -67,15 +85,25 @@ class Prediction(dict):
  *[f"sample_{i}" for i in range(n_samples)],
  ]

+ # Create empty xarray/pandas objects to store predictions
  if self.mode == "on-grid":
  for var_ID in self.target_var_IDs:
- # Create empty xarray/pandas objects to store predictions
+ if self.forecasting_mode:
+ prepend_dims = ["lead_time"]
+ prepend_coords = {"lead_time": lead_times}
+ else:
+ prepend_dims = None
+ prepend_coords = None
  self[var_ID] = create_empty_spatiotemporal_xarray(
  X_t,
  dates,
  data_vars=self.pred_params,
  coord_names=coord_names,
+ prepend_dims=prepend_dims,
+ prepend_coords=prepend_coords,
  )
+ if self.forecasting_mode:
+ self[var_ID] = self[var_ID].rename(time="init_time")
  if self.X_t_mask is None:
  # Create 2D boolean array of True values to simplify indexing
  self.X_t_mask = (
@@ -86,8 +114,18 @@ class Prediction(dict):
  )
  elif self.mode == "off-grid":
  # Repeat target locs for each date to create multiindex
- idxs = [(date, *idxs) for date in dates for idxs in X_t.index]
- index = pd.MultiIndex.from_tuples(idxs, names=["time", *X_t.index.names])
+ if self.forecasting_mode:
+ index_names = ["lead_time", "init_time", *X_t.index.names]
+ idxs = [
+ (lt, date, *idxs)
+ for lt in lead_times
+ for date in dates
+ for idxs in X_t.index
+ ]
+ else:
+ index_names = ["time", *X_t.index.names]
+ idxs = [(date, *idxs) for date in dates for idxs in X_t.index]
+ index = pd.MultiIndex.from_tuples(idxs, names=index_names)
  for var_ID in self.target_var_IDs:
  self[var_ID] = pd.DataFrame(index=index, columns=self.pred_params)

@@ -106,6 +144,7 @@ class Prediction(dict):
  prediction_parameter: str,
  date: Union[str, pd.Timestamp],
  data: np.ndarray,
+ lead_times: Optional[List[pd.Timedelta]] = None,
  ):
  """

@@ -117,11 +156,29 @@ class Prediction(dict):
  data (np.ndarray)
  If off-grid: Shape (N_var, N_targets) or (N_samples, N_var, N_targets).
  If on-grid: Shape (N_var, N_x1, N_x2) or (N_samples, N_var, N_x1, N_x2).
+ lead_times (List[pd.Timedelta], optional)
+ Lead times of the forecast. Required if forecasting_mode is True. Default None.
  """
+ if self.forecasting_mode:
+ assert (
+ lead_times is not None
+ ), "If forecasting_mode is True, lead_times must be provided."
+
+ msg = f"""
+ If forecasting_mode is True, lead_times must be of equal length to the number of
+ variables in the data (the first dimension). Got {lead_times=} of length
+ {len(lead_times)} lead times and data shape {data.shape}.
+ """
+ assert len(lead_times) == data.shape[0], msg
+
  if self.mode == "on-grid":
  if prediction_parameter != "samples":
- for var_ID, pred in zip(self.target_var_IDs, data):
- self[var_ID][prediction_parameter].loc[date].data[
+ for i, (var_ID, pred) in enumerate(zip(self.target_var_IDs, data)):
+ if self.forecasting_mode:
+ index = (lead_times[i], date)
+ else:
+ index = date
+ self[var_ID][prediction_parameter].loc[index].data[
  self.X_t_mask.data
  ] = pred.ravel()
  elif prediction_parameter == "samples":
@@ -130,28 +187,44 @@ class Prediction(dict):
  f"have shape (N_samples, N_var, N_x1, N_x2). Got {data.shape}."
  )
  for sample_i, sample in enumerate(data):
- for var_ID, pred in zip(self.target_var_IDs, sample):
- self[var_ID][f"sample_{sample_i}"].loc[date].data[
+ for i, (var_ID, pred) in enumerate(
+ zip(self.target_var_IDs, sample)
+ ):
+ if self.forecasting_mode:
+ index = (lead_times[i], date)
+ else:
+ index = date
+ self[var_ID][f"sample_{sample_i}"].loc[index].data[
  self.X_t_mask.data
  ] = pred.ravel()

  elif self.mode == "off-grid":
  if prediction_parameter != "samples":
- for var_ID, pred in zip(self.target_var_IDs, data):
- self[var_ID][prediction_parameter].loc[date] = pred
+ for i, (var_ID, pred) in enumerate(zip(self.target_var_IDs, data)):
+ if self.forecasting_mode:
+ index = (lead_times[i], date)
+ else:
+ index = date
+ self[var_ID][prediction_parameter].loc[index] = pred
  elif prediction_parameter == "samples":
  assert len(data.shape) == 3, (
  f"If prediction_parameter is 'samples', and mode is 'off-grid', data must"
  f"have shape (N_samples, N_var, N_targets). Got {data.shape}."
  )
  for sample_i, sample in enumerate(data):
- for var_ID, pred in zip(self.target_var_IDs, sample):
- self[var_ID][f"sample_{sample_i}"].loc[date] = pred
+ for i, (var_ID, pred) in enumerate(
+ zip(self.target_var_IDs, sample)
+ ):
+ if self.forecasting_mode:
+ index = (lead_times[i], date)
+ else:
+ index = date
+ self[var_ID][f"sample_{sample_i}"].loc[index] = pred


  def create_empty_spatiotemporal_xarray(
  X: Union[xr.Dataset, xr.DataArray],
- dates: List,
+ dates: List[Timestamp],
  coord_names: dict = None,
  data_vars: List[str] = None,
  prepend_dims: Optional[List[str]] = None,
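
A small standalone sketch of the off-grid index layout in forecasting mode (example values are made up): each row corresponds to one (lead_time, init_time, x1, x2) combination.

```python
import pandas as pd

lead_times = pd.to_timedelta([1, 2], unit="D")
dates = pd.to_datetime(["2020-01-01"])
X_t_index = pd.MultiIndex.from_tuples([(0.0, 0.0), (0.5, 0.5)], names=["x1", "x2"])

# Mirrors the index construction above: lead_time varies slowest, then init_time.
idxs = [(lt, date, *idx) for lt in lead_times for date in dates for idx in X_t_index]
index = pd.MultiIndex.from_tuples(idxs, names=["lead_time", "init_time", "x1", "x2"])
df = pd.DataFrame(index=index, columns=["mean", "std"])
print(df.index.names, len(df))  # 4 rows with levels lead_time, init_time, x1, x2
```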
@@ -231,10 +304,6 @@ def create_empty_spatiotemporal_xarray(
  # Convert time coord to pandas timestamps
  pred_ds = pred_ds.assign_coords(time=pd.to_datetime(pred_ds.time.values))

- # TODO: Convert init time to forecast time?
- # pred_ds = pred_ds.assign_coords(
- #     time=pred_ds['time'] + pd.Timedelta(days=task_loader.target_delta_t[0]))
-
  return pred_ds


@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: deepsensor
- Version: 0.3.7
+ Version: 0.4.0
  Summary: A Python package for modelling xarray and pandas data with neural processes.
  Home-page: https://github.com/alan-turing-institute/deepsensor
  Author: Tom R. Andersson
@@ -44,7 +44,7 @@ data with neural processes</p>

  -----------

- [![release](https://img.shields.io/badge/release-v0.3.7-green?logo=github)](https://github.com/alan-turing-institute/deepsensor/releases)
+ [![release](https://img.shields.io/badge/release-v0.4.0-green?logo=github)](https://github.com/alan-turing-institute/deepsensor/releases)
  [![Latest Docs](https://img.shields.io/badge/docs-latest-blue.svg)](https://alan-turing-institute.github.io/deepsensor/)
  ![Tests](https://github.com/alan-turing-institute/deepsensor/actions/workflows/tests.yml/badge.svg)
  [![Coverage Status](https://coveralls.io/repos/github/alan-turing-institute/deepsensor/badge.svg?branch=main)](https://coveralls.io/github/alan-turing-institute/deepsensor?branch=main)
@@ -240,7 +240,7 @@ if you would like to join this list!
  <table>
  <tbody>
  <tr>
- <td align="center" valign="top" width="14.28%"><a href="https://github.com/acocac"><img src="https://avatars.githubusercontent.com/u/13321552?v=4?s=100" width="100px;" alt="Alejandro ©"/><br /><sub><b>Alejandro ©</b></sub></a><br /><a href="#userTesting-acocac" title="User Testing">📓</a> <a href="#bug-acocac" title="Bug reports">🐛</a> <a href="#mentoring-acocac" title="Mentoring">🧑‍🏫</a> <a href="#ideas-acocac" title="Ideas, Planning, & Feedback">🤔</a> <a href="#research-acocac" title="Research">🔬</a></td>
+ <td align="center" valign="top" width="14.28%"><a href="https://github.com/acocac"><img src="https://avatars.githubusercontent.com/u/13321552?v=4?s=100" width="100px;" alt="Alejandro ©"/><br /><sub><b>Alejandro ©</b></sub></a><br /><a href="#userTesting-acocac" title="User Testing">📓</a> <a href="#bug-acocac" title="Bug reports">🐛</a> <a href="#mentoring-acocac" title="Mentoring">🧑‍🏫</a> <a href="#ideas-acocac" title="Ideas, Planning, & Feedback">🤔</a> <a href="#research-acocac" title="Research">🔬</a> <a href="#code-acocac" title="Code">💻</a> <a href="#test-acocac" title="Tests">⚠️</a></td>
  <td align="center" valign="top" width="14.28%"><a href="https://github.com/annavaughan"><img src="https://avatars.githubusercontent.com/u/45528489?v=4?s=100" width="100px;" alt="Anna Vaughan"/><br /><sub><b>Anna Vaughan</b></sub></a><br /><a href="#research-annavaughan" title="Research">🔬</a></td>
  <td align="center" valign="top" width="14.28%"><a href="http://davidwilby.dev"><img src="https://avatars.githubusercontent.com/u/24752124?v=4?s=100" width="100px;" alt="David Wilby"/><br /><sub><b>David Wilby</b></sub></a><br /><a href="#doc-davidwilby" title="Documentation">📖</a> <a href="#test-davidwilby" title="Tests">⚠️</a> <a href="#maintenance-davidwilby" title="Maintenance">🚧</a></td>
  <td align="center" valign="top" width="14.28%"><a href="http://inconsistentrecords.co.uk"><img src="https://avatars.githubusercontent.com/u/731727?v=4?s=100" width="100px;" alt="Jim Circadian"/><br /><sub><b>Jim Circadian</b></sub></a><br /><a href="#ideas-JimCircadian" title="Ideas, Planning, & Feedback">🤔</a> <a href="#projectManagement-JimCircadian" title="Project Management">📆</a> <a href="#maintenance-JimCircadian" title="Maintenance">🚧</a></td>
@@ -22,6 +22,8 @@ deepsensor/data/processor.py
  deepsensor/data/sources.py
  deepsensor/data/task.py
  deepsensor/data/utils.py
+ deepsensor/eval/__init__.py
+ deepsensor/eval/metrics.py
  deepsensor/model/__init__.py
  deepsensor/model/convnp.py
  deepsensor/model/defaults.py
@@ -1,6 +1,6 @@
  [metadata]
  name = deepsensor
- version = 0.3.7
+ version = 0.4.0
  author = Tom R. Andersson
  author_email = tomandersson3@gmail.com
  description = A Python package for modelling xarray and pandas data with neural processes.
@@ -26,9 +26,6 @@ from deepsensor.data.processor import DataProcessor, xarray_to_coord_array_norma
  from deepsensor.model.convnp import ConvNP


- # from deepsensor.active_learning.acquisition_fns import
-
-
  class TestActiveLearning(unittest.TestCase):

  @classmethod
@@ -18,6 +18,7 @@ from deepsensor.data.processor import DataProcessor
  from deepsensor.data.loader import TaskLoader
  from deepsensor.model.convnp import ConvNP
  from deepsensor.train.train import Trainer
+ from deepsensor.eval.metrics import compute_errors

  from tests.utils import gen_random_data_xr, gen_random_data_pandas

@@ -55,8 +56,15 @@ class TestModel(unittest.TestCase):
  def setUpClass(cls):
  # super().__init__(*args, **kwargs)
  # It's safe to share data between tests because the TaskLoader does not modify data
+ cls.var_ID = "2m_temp"
  cls.da = _gen_data_xr()
+ cls.da.name = cls.var_ID
  cls.df = _gen_data_pandas()
+ cls.df.name = cls.var_ID
+ # Various tests assume we have a single target set with a single variable.
+ # If a test requires multiple target sets or variables, this is set up in the test.
+ assert isinstance(cls.da, xr.DataArray)
+ assert isinstance(cls.df, pd.Series)

  cls.dp = DataProcessor()
  _ = cls.dp([cls.da, cls.df]) # Compute normalisation parameters
@@ -193,7 +201,7 @@ class TestModel(unittest.TestCase):
  n_targets * dim_y_combined * n_target_dims,
  ),
  )
- if likelihood in ["cnp-spikes-beta"]:
+ if likelihood in ["cnp-spikes-beta", "bernoulli-gamma"]:
  mixture_probs = model.mixture_probs(task)
  if isinstance(mixture_probs, (list, tuple)):
  for p, dim_y in zip(mixture_probs, tl.target_dims):
@@ -215,6 +223,7 @@ class TestModel(unittest.TestCase):
  ),
  )

+ if likelihood in ["cnp-spikes-beta"]:
  x = model.alpha(task)
  if isinstance(x, (list, tuple)):
  for p, dim_y in zip(x, tl.target_dims):
@@ -229,6 +238,21 @@ class TestModel(unittest.TestCase):
  else:
  assert_shape(x, (dim_y_combined, *expected_obs_shape))

+ if likelihood in ["bernoulli-gamma"]:
+ x = model.k(task)
+ if isinstance(x, (list, tuple)):
+ for p, dim_y in zip(x, tl.target_dims):
+ assert_shape(p, (dim_y, *expected_obs_shape))
+ else:
+ assert_shape(x, (dim_y_combined, *expected_obs_shape))
+
+ x = model.scale(task)
+ if isinstance(x, (list, tuple)):
+ for p, dim_y in zip(x, tl.target_dims):
+ assert_shape(p, (dim_y, *expected_obs_shape))
+ else:
+ assert_shape(x, (dim_y_combined, *expected_obs_shape))
+
  # Scalars
  if likelihood in ["cnp", "gnp"]:
  # Methods for Gaussian likelihoods only
@@ -401,10 +425,10 @@ class TestModel(unittest.TestCase):
  task = tl("2020-01-01")
  pred = model.predict(task, X_t=da_raw)

- assert np.array_equal(
+ np.testing.assert_array_equal(
  pred["dummy_data"]["mean"]["latitude"], da_raw["latitude"]
  )
- assert np.array_equal(
+ np.testing.assert_array_equal(
  pred["dummy_data"]["mean"]["longitude"], da_raw["longitude"]
  )

@@ -451,61 +475,75 @@ class TestModel(unittest.TestCase):
  def test_highlevel_predict_with_pred_params_pandas(self):
  """
  Test that passing ``pred_params`` to ``.predict`` works with
- a spikes-beta likelihood for prediction to pandas.
+ mixture model likelihoods for off-grid prediction to pandas.
  """
  tl = TaskLoader(context=self.da, target=self.da)
- model = ConvNP(
- self.dp,
- tl,
- unet_channels=(5, 5, 5),
- verbose=False,
- likelihood="cnp-spikes-beta",
- )
- task = tl("2020-01-01", context_sampling=10, target_sampling=10)

- # Off-grid prediction
- X_t = np.array([[0.0, 0.5, 1.0], [0.0, 0.5, 1.0]])
+ likelihoods = ["cnp-spikes-beta", "bernoulli-gamma"]
+ expected_pred_params = [
+ ["mean", "std", "variance", "alpha", "beta"],
+ ["mean", "std", "variance", "k", "scale"],
+ ]
+
+ for likelihood, pred_params in zip(likelihoods, expected_pred_params):
+ model = ConvNP(
+ self.dp,
+ tl,
+ unet_channels=(5, 5, 5),
+ verbose=False,
+ likelihood=likelihood,
+ )
+ task = tl("2020-01-01", context_sampling=10)
+
+ # Off-grid prediction
+ X_t = np.array([[0.0, 0.5, 1.0], [0.0, 0.5, 1.0]])

- # Check that nothing breaks and the correct parameters are returned
- pred_params = ["mean", "std", "variance", "alpha", "beta"]
- pred = model.predict(task, X_t=X_t, pred_params=pred_params)
- for pred_param in pred_params:
- assert pred_param in pred["var"]
+ # Check that nothing breaks and the correct parameters are returned
+ pred = model.predict(task, X_t=X_t, pred_params=pred_params)
+ for pred_param in pred_params:
+ assert pred_param in pred[self.var_ID]

- # Test mixture probs special case
- pred_params = ["mixture_probs"]
- pred = model.predict(task, X_t=self.da, pred_params=pred_params)
- for component in range(model.N_mixture_components):
- pred_param = f"mixture_probs_{component}"
- assert pred_param in pred["var"]
+ # Test mixture probs special case
+ pred_params = ["mixture_probs"]
+ pred = model.predict(task, X_t=self.da, pred_params=pred_params)
+ for component in range(model.N_mixture_components):
+ pred_param = f"mixture_probs_{component}"
+ assert pred_param in pred[self.var_ID]

  def test_highlevel_predict_with_pred_params_xarray(self):
  """
  Test that passing ``pred_params`` to ``.predict`` works with
- a spikes-beta likelihood for prediction to xarray.
+ mixture model likelihoods for gridded prediction to xarray.
  """
  tl = TaskLoader(context=self.da, target=self.da)
- model = ConvNP(
- self.dp,
- tl,
- unet_channels=(5, 5, 5),
- verbose=False,
- likelihood="cnp-spikes-beta",
- )
- task = tl("2020-01-01", context_sampling=10, target_sampling=10)

- # Check that nothing breaks and the correct parameters are returned
- pred_params = ["mean", "std", "variance", "alpha", "beta"]
- pred = model.predict(task, X_t=self.da, pred_params=pred_params)
- for pred_param in pred_params:
- assert pred_param in pred["var"]
+ likelihoods = ["cnp-spikes-beta", "bernoulli-gamma"]
+ expected_pred_params = [
+ ["mean", "std", "variance", "alpha", "beta"],
+ ["mean", "std", "variance", "k", "scale"],
+ ]
+
+ for likelihood, pred_params in zip(likelihoods, expected_pred_params):
+ model = ConvNP(
+ self.dp,
+ tl,
+ unet_channels=(5, 5, 5),
+ verbose=False,
+ likelihood=likelihood,
+ )
+ task = tl("2020-01-01", context_sampling=10)
+
+ # Check that nothing breaks and the correct parameters are returned
+ pred = model.predict(task, X_t=self.da, pred_params=pred_params)
+ for pred_param in pred_params:
+ assert pred_param in pred[self.var_ID]

- # Test mixture probs special case
- pred_params = ["mixture_probs"]
- pred = model.predict(task, X_t=self.da, pred_params=pred_params)
- for component in range(model.N_mixture_components):
- pred_param = f"mixture_probs_{component}"
- assert pred_param in pred["var"]
+ # Test mixture probs special case
+ pred_params = ["mixture_probs"]
+ pred = model.predict(task, X_t=self.da, pred_params=pred_params)
+ for component in range(model.N_mixture_components):
+ pred_param = f"mixture_probs_{component}"
+ assert pred_param in pred[self.var_ID]

  def test_highlevel_predict_with_invalid_pred_params(self):
  """Test that passing ``pred_params`` to ``.predict`` works."""
@@ -610,6 +648,66 @@ class TestModel(unittest.TestCase):
  ar_sample=True,
  )

+ def test_forecasting_model_predict_return_valid_times(self):
+ """Test that the times returned by a forecasting model are valid."""
+ init_dates = ["2020-01-01", "2020-01-02"]
+ expected_init_times = np.array(init_dates).astype(np.datetime64)
+
+ lead_times_days = [1, 2, 3]
+ expected_lead_times = np.array(
+ [np.timedelta64(lt, "D") for lt in lead_times_days]
+ )
+
+ expected_valid_times = np.array(
+ expected_lead_times[:, None] + expected_init_times[None, :]
+ )
+
+ tl = TaskLoader(
+ context=self.da,
+ target=[
+ self.da,
+ ]
+ * len(lead_times_days),
+ target_delta_t=lead_times_days,
+ time_freq="D",
+ )
+ model = ConvNP(self.dp, tl, unet_channels=(5, 5, 5), verbose=False)
+ tasks = tl(init_dates, context_sampling=10)
+
+ X_ts = [
+ # Gridded predictions (xarray)
+ self.da,
+ # Off-grid prediction (pandas)
+ np.array([[0.0, 0.5, 1.0], [0.0, 0.5, 1.0]]),
+ ]
+ for X_t in X_ts:
+ pred = model.predict(tasks, X_t=X_t)
+
+ pred_var = pred[self.var_ID]
+
+ if isinstance(pred_var, xr.Dataset):
+ # Check we can compute errors using the valid time coord ('time')
+ errors = compute_errors(pred, self.da.to_dataset())
+ for var_ID in errors.keys():
+ assert tuple(errors[var_ID].dims) == (
+ "lead_time",
+ "init_time",
+ "x1",
+ "x2",
+ )
+ assert errors[var_ID].shape == pred[var_ID]["mean"].shape
+ elif isinstance(pred_var, pd.DataFrame):
+ # Makes coordinate checking easier by avoiding repeat values
+ pred_var = pred_var.to_xarray().isel(x1=0, x2=0)
+
+ np.testing.assert_array_equal(
+ pred_var.lead_time.values, expected_lead_times
+ )
+ np.testing.assert_array_equal(
+ pred_var.init_time.values, expected_init_times
+ )
+ np.testing.assert_array_equal(pred_var.time.values, expected_valid_times)
+

  def assert_shape(x, shape: tuple):
  """Assert that the shape of ``x`` matches ``shape``."""
@@ -192,6 +192,23 @@ class TestTaskLoader(unittest.TestCase):
  with self.assertRaises(InvalidSamplingStrategyError):
  task = tl("2020-01-01", invalid_sampling_strategy)

+ def test_different_dtype_when_sampling_offgrid_data_at_specific_numpy_locs(self):
+ """Test different dtype when sampling off-grid data at specific numpy locations."""
+ sampling_strat = np.array(
+ [np.linspace(0, 1, 10), np.linspace(0, 1, 10)], dtype=np.float16
+ )
+
+ tl = TaskLoader(
+ context=self.df,
+ target=self.df,
+ )
+
+ assert sampling_strat.dtype != tl.context[0].index.get_level_values("x1").dtype
+ assert sampling_strat.dtype != tl.context[0].index.get_level_values("x2").dtype
+
+ with self.assertRaises(InvalidSamplingStrategyError):
+ task = tl("2020-01-01", sampling_strat, sampling_strat)
+
  def test_wrong_links(self):
  """Test link indexes out of range."""
  with self.assertRaises(ValueError):