PVNet_summation 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of PVNet_summation might be problematic. Click here for more details.

@@ -0,0 +1 @@
1
+ """PVNet_summation"""
@@ -0,0 +1,2 @@
1
+ """Data module"""
2
+ from .datamodule import PresavedDataModule, StreamedDataModule
@@ -0,0 +1,213 @@
1
+ """Pytorch lightning datamodules for loading pre-saved samples and predictions."""
2
+
3
+ from glob import glob
4
+ from typing import TypeAlias
5
+
6
+ import numpy as np
7
+ import pandas as pd
8
+ import torch
9
+ from lightning.pytorch import LightningDataModule
10
+ from ocf_data_sampler.load.gsp import open_gsp
11
+ from ocf_data_sampler.numpy_sample.common_types import NumpyBatch, NumpySample
12
+ from ocf_data_sampler.torch_datasets.datasets.pvnet_uk import PVNetUKConcurrentDataset
13
+ from ocf_data_sampler.utils import minutes
14
+ from torch.utils.data import DataLoader, Dataset, default_collate
15
+ from typing_extensions import override
16
+
17
# Tensor form of a sample dictionary: each key maps to a torch tensor
# (e.g. "pvnet_outputs" and "relative_capacity" are read this way by the models)
SumTensorBatch: TypeAlias = dict[str, torch.Tensor]

# Numpy form of a sample dictionary, as returned by StreamedDataset: values are numpy arrays,
# except "pvnet_inputs" which holds a nested NumpyBatch of PVNet inputs
SumNumpySample: TypeAlias = dict[str, np.ndarray | NumpyBatch]
19
+
20
+
21
class StreamedDataset(PVNetUKConcurrentDataset):
    """A torch dataset for creating concurrent PVNet inputs and national targets."""

    def __init__(
        self,
        config_filename: str,
        start_time: str | None = None,
        end_time: str | None = None,
    ) -> None:
        """A torch dataset for creating concurrent PVNet inputs and national targets.

        Args:
            config_filename: Path to the configuration file
            start_time: Limit the init-times to be after this
            end_time: Limit the init-times to be before this
        """
        super().__init__(config_filename, start_time, end_time, gsp_ids=None)

        # Load the national GSP data (gsp_id=0) to use as target values, loading it into memory
        # with `.compute()`
        national_gsp_data = (
            open_gsp(
                zarr_path=self.config.input_data.gsp.zarr_path,
                boundaries_version=self.config.input_data.gsp.boundaries_version,
            )
            .sel(gsp_id=0)
            .compute()
        )
        # Normalise the outturns by the national effective capacity. The `effective_capacity_mwp`
        # coordinate is still read from this object in `_get_sample`.
        self.national_gsp_data = national_gsp_data / national_gsp_data.effective_capacity_mwp

    def _get_sample(self, t0: pd.Timestamp) -> SumNumpySample:
        """Generate a concurrent PVNet sample for given init-time.

        Args:
            t0: init-time for sample

        Returns:
            A dictionary containing:
              - "pvnet_inputs": the concurrent PVNet inputs for all locations
              - "target": normalised national outturns at each valid-time after t0
              - "valid_times": the target valid-times as int64 epoch-based timestamps
              - "last_outturn": the normalised national outturn at t0
              - "relative_capacity": each location's capacity divided by the national capacity
                at t0
        """
        pvnet_inputs: NumpySample = super()._get_sample(t0)

        gsp_config = self.config.input_data.gsp
        step = minutes(gsp_config.time_resolution_minutes)

        # Target times run from one step after t0 up to the end of the GSP interval
        valid_times = pd.date_range(
            t0 + step,
            t0 + minutes(gsp_config.interval_end_minutes),
            freq=step,
        )

        # Select the national data at t0 once and reuse it for both capacity and last outturn
        national_at_t0 = self.national_gsp_data.sel(time_utc=t0)

        total_outturns = self.national_gsp_data.sel(time_utc=valid_times).values
        total_capacity = national_at_t0.effective_capacity_mwp.item()

        location_capacities = pvnet_inputs["gsp_effective_capacity_mwp"]
        relative_capacities = location_capacities / total_capacity

        return {
            # NumpyBatch object with batch size = num_locations
            "pvnet_inputs": pvnet_inputs,
            # Shape: [time]
            "target": total_outturns,
            # Shape: [time]
            # Cast explicitly to int64: `astype(int)` is 32-bit on some platforms (e.g. Windows
            # with NumPy < 2) and would silently overflow high-resolution datetime values
            "valid_times": valid_times.values.astype(np.int64),
            # Shape: [] - the single value selected at t0
            "last_outturn": national_at_t0.values,
            # Shape: [num_locations]
            "relative_capacity": relative_capacities,
        }

    @override
    def __getitem__(self, idx: int) -> SumNumpySample:
        """Return the sample at index `idx` (re-declared to narrow the return type)."""
        return super().__getitem__(idx)

    @override
    def get_sample(self, t0: pd.Timestamp) -> SumNumpySample:
        """Return the sample for init-time `t0` (re-declared to narrow the return type)."""
        return super().get_sample(t0)
93
+
94
+
95
class StreamedDataModule(LightningDataModule):
    """Datamodule for training pvnet_summation."""

    def __init__(
        self,
        configuration: str,
        train_period: list[str | None] | None = None,
        val_period: list[str | None] | None = None,
        num_workers: int = 0,
        prefetch_factor: int | None = None,
        persistent_workers: bool = False,
    ):
        """Datamodule for creating concurrent PVNet inputs and national targets.

        Args:
            configuration: Path to ocf-data-sampler configuration file.
            train_period: Date range filter [start, end] for train dataloader. Defaults to
                [None, None] (no filtering).
            val_period: Date range filter [start, end] for val dataloader. Defaults to
                [None, None] (no filtering).
            num_workers: Number of workers to use in multiprocess batch loading.
            prefetch_factor: Number of batches loaded in advance by each worker process.
            persistent_workers: If True, the data loader will not shut down the worker processes
                after a dataset has been consumed once. This keeps the workers' Dataset
                instances alive.
        """
        super().__init__()
        self.configuration = configuration
        # Build fresh default lists per instance rather than sharing a mutable default argument
        self.train_period = [None, None] if train_period is None else train_period
        self.val_period = [None, None] if val_period is None else val_period

        self._dataloader_kwargs = dict(
            # batch_size=None disables automatic batching: each dataset item (which already holds
            # inputs for all locations) is yielded as-is
            batch_size=None,
            batch_sampler=None,
            num_workers=num_workers,
            collate_fn=None,
            pin_memory=False,
            drop_last=False,
            timeout=0,
            worker_init_fn=None,
            prefetch_factor=prefetch_factor,
            persistent_workers=persistent_workers,
        )

    def train_dataloader(self, shuffle: bool = False) -> DataLoader:
        """Construct train dataloader

        Args:
            shuffle: Whether to shuffle the order of samples.
        """
        dataset = StreamedDataset(self.configuration, *self.train_period)
        return DataLoader(dataset, shuffle=shuffle, **self._dataloader_kwargs)

    def val_dataloader(self, shuffle: bool = False) -> DataLoader:
        """Construct val dataloader

        Args:
            shuffle: Whether to shuffle the order of samples.
        """
        dataset = StreamedDataset(self.configuration, *self.val_period)
        return DataLoader(dataset, shuffle=shuffle, **self._dataloader_kwargs)
146
+
147
+
148
class PresavedDataset(Dataset):
    """Dataset for loading pre-saved PVNet predictions from disk"""

    def __init__(self, sample_dir: str):
        """Dataset for loading pre-saved PVNet predictions from disk.

        Args:
            sample_dir: The directory containing the saved samples. Every `*.pt` file directly
                inside it is one sample; samples are indexed in sorted filepath order.
        """
        self.sample_filepaths = sorted(glob(f"{sample_dir}/*.pt"))

    def __len__(self) -> int:
        """Return the number of sample files found in the sample directory."""
        return len(self.sample_filepaths)

    def __getitem__(self, idx: int) -> dict:
        """Load the sample stored at position `idx`.

        `weights_only=True` restricts unpickling to tensors and primitive containers, so
        arbitrary Python objects in the file are not executed.
        """
        return torch.load(self.sample_filepaths[idx], weights_only=True)
164
+
165
+
166
class PresavedDataModule(LightningDataModule):
    """Datamodule serving pre-saved PVNet predictions from `<sample_dir>/train` and `/val`."""

    def __init__(
        self,
        sample_dir: str,
        batch_size: int = 16,
        num_workers: int = 0,
        prefetch_factor: int | None = None,
        persistent_workers: bool = False,
    ):
        """Datamodule for loading pre-saved PVNet predictions.

        Args:
            sample_dir: Directory holding `train` and `val` subdirectories of saved samples.
            batch_size: Number of samples per batch; None disables automatic batching.
            num_workers: Number of worker processes used to load batches.
            prefetch_factor: Number of batches loaded in advance by each worker process.
            persistent_workers: If True, worker processes (and their Dataset instances) are kept
                alive between passes over the dataset instead of being shut down.
        """
        super().__init__()
        self.sample_dir = sample_dir

        # Samples are stacked with default_collate only when automatic batching is enabled
        if batch_size is None:
            collate = None
        else:
            collate = default_collate

        self._dataloader_kwargs = {
            "batch_size": batch_size,
            "sampler": None,
            "batch_sampler": None,
            "num_workers": num_workers,
            "collate_fn": collate,
            "pin_memory": False,
            "drop_last": False,
            "timeout": 0,
            "worker_init_fn": None,
            "prefetch_factor": prefetch_factor,
            "persistent_workers": persistent_workers,
        }

    def _build_dataloader(self, split: str, shuffle: bool) -> DataLoader:
        """Create a dataloader over the samples saved in `<sample_dir>/<split>`."""
        split_dataset = PresavedDataset(f"{self.sample_dir}/{split}")
        return DataLoader(split_dataset, shuffle=shuffle, **self._dataloader_kwargs)

    def train_dataloader(self, shuffle: bool = True) -> DataLoader:
        """Construct train dataloader (shuffled by default)."""
        return self._build_dataloader("train", shuffle)

    def val_dataloader(self, shuffle: bool = False) -> DataLoader:
        """Construct val dataloader (unshuffled by default)."""
        return self._build_dataloader("val", shuffle)
@@ -0,0 +1,70 @@
1
+ """Load a model from its checkpoint directory"""
2
+
3
+ import glob
4
+ import os
5
+
6
+ import hydra
7
+ import torch
8
+ import yaml
9
+
10
+ from pvnet_summation.utils import (
11
+ DATAMODULE_CONFIG_NAME,
12
+ FULL_CONFIG_NAME,
13
+ MODEL_CONFIG_NAME,
14
+ )
15
+
16
+
17
def get_model_from_checkpoints(
    checkpoint_dir_path: str,
    val_best: bool = True,
) -> tuple[torch.nn.Module, dict, str | None, str | None]:
    """Load a model from its checkpoint directory

    Args:
        checkpoint_dir_path: Directory containing the lightning module config
            (`MODEL_CONFIG_NAME`) and the lightning checkpoint file(s).
        val_best: If True, load the single `epoch*.ckpt` checkpoint in the directory (the best
            epoch); otherwise load `last.ckpt`.

    Returns:
        tuple:
            model: nn.Module of pretrained model.
            model_config: the `model` section of the loaded config (a dict, not a path).
            datamodule_config: path to the datamodule config used to create samples, e.g.
                train/test split info, or None if that file does not exist.
            experiment_config: path to the full experimental config, or None if that file does
                not exist.

    Raises:
        ValueError: If `val_best` is True and the directory does not contain exactly one
            `epoch*.ckpt` file.
    """
    # Load lightning training module.
    # NOTE(review): yaml.FullLoader can construct some Python objects; only use this on trusted
    # checkpoint directories (yaml.safe_load would suffice for a plain YAML config).
    with open(f"{checkpoint_dir_path}/{MODEL_CONFIG_NAME}") as cfg:
        model_config = yaml.load(cfg, Loader=yaml.FullLoader)

    lightning_module = hydra.utils.instantiate(model_config)

    if val_best:
        # Only one epoch (best) saved per model
        files = glob.glob(f"{checkpoint_dir_path}/epoch*.ckpt")
        if len(files) != 1:
            raise ValueError(
                f"Found {len(files)} checkpoints @ {checkpoint_dir_path}/epoch*.ckpt. Expected one."
            )
        checkpoint_path = files[0]
    else:
        checkpoint_path = f"{checkpoint_dir_path}/last.ckpt"

    checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
    lightning_module.load_state_dict(state_dict=checkpoint["state_dict"])

    # Extract the model from the lightning module
    model = lightning_module.model
    model_config = model_config["model"]

    # Check for datamodule config
    # This only exists if the model was trained with presaved samples
    datamodule_config = f"{checkpoint_dir_path}/{DATAMODULE_CONFIG_NAME}"
    datamodule_config = datamodule_config if os.path.isfile(datamodule_config) else None

    # Check for experiment config
    # For backwards compatibility - this might not always exist
    experiment_config = f"{checkpoint_dir_path}/{FULL_CONFIG_NAME}"
    experiment_config = experiment_config if os.path.isfile(experiment_config) else None

    return model, model_config, datamodule_config, experiment_config
@@ -0,0 +1,3 @@
1
+ """Models for PVNet summation"""
2
+ from .base_model import BaseModel
3
+ from .dense_model import DenseModel
@@ -0,0 +1,345 @@
1
+ """Base model for all PVNet submodels"""
2
+ import logging
3
+ import os
4
+ import shutil
5
+ import time
6
+ from importlib.metadata import version
7
+ from pathlib import Path
8
+
9
+ import hydra
10
+ import torch
11
+ import yaml
12
+ from huggingface_hub import ModelCard, ModelCardData, snapshot_download
13
+ from huggingface_hub.hf_api import HfApi
14
+ from safetensors.torch import load_file, save_file
15
+
16
+ from pvnet_summation.data.datamodule import SumTensorBatch
17
+ from pvnet_summation.utils import (
18
+ DATAMODULE_CONFIG_NAME,
19
+ FULL_CONFIG_NAME,
20
+ MODEL_CARD_NAME,
21
+ MODEL_CONFIG_NAME,
22
+ PYTORCH_WEIGHTS_NAME,
23
+ )
24
+
25
+
26
def santize_datamodule(config: dict) -> dict:
    """Strip a datamodule config down to the entries needed at inference time.

    Args:
        config: Full datamodule config.

    Returns:
        A new dict containing only the `pvnet_model` entry of `config`.

    Raises:
        KeyError: If `config` has no `pvnet_model` entry.
    """
    inference_keys = ("pvnet_model",)
    return {key: config[key] for key in inference_keys}
29
+
30
+
31
+ def download_from_hf(
32
+ repo_id: str,
33
+ filename: str | list[str],
34
+ revision: str,
35
+ cache_dir: str | None,
36
+ force_download: bool,
37
+ max_retries: int = 5,
38
+ wait_time: int = 10,
39
+ ) -> str | list[str]:
40
+ """Tries to download one or more files from HuggingFace up to max_retries times.
41
+
42
+ Args:
43
+ repo_id: HuggingFace repo ID
44
+ filename: Name of the file(s) to download
45
+ revision: Specific model revision
46
+ cache_dir: Cache directory
47
+ force_download: Whether to force a new download
48
+ max_retries: Maximum number of retry attempts
49
+ wait_time: Wait time (in seconds) before retrying
50
+
51
+ Returns:
52
+ The local file path of the downloaded file(s)
53
+ """
54
+ for attempt in range(1, max_retries + 1):
55
+ try:
56
+ save_dir = snapshot_download(
57
+ repo_id=repo_id,
58
+ allow_patterns=filename,
59
+ revision=revision,
60
+ cache_dir=cache_dir,
61
+ force_download=force_download,
62
+ )
63
+
64
+ if isinstance(filename, list):
65
+ return [f"{save_dir}/{f}" for f in filename]
66
+ else:
67
+ return f"{save_dir}/{filename}"
68
+
69
+ except Exception as e:
70
+ if attempt == max_retries:
71
+ raise Exception(
72
+ f"Failed to download {filename} from {repo_id} after {max_retries} attempts."
73
+ ) from e
74
+ logging.warning(
75
+ (
76
+ f"Attempt {attempt}/{max_retries} failed to download {filename} "
77
+ f"from {repo_id}. Retrying in {wait_time} seconds..."
78
+ )
79
+ )
80
+ time.sleep(wait_time)
81
+
82
+
83
class HuggingfaceMixin:
    """Mixin for saving and loading model to and from huggingface"""

    @classmethod
    def from_pretrained(
        cls,
        model_id: str,
        revision: str,
        cache_dir: str | None = None,
        force_download: bool = False,
        strict: bool = True,
    ) -> "BaseModel":
        """Load Pytorch pretrained weights and return the loaded model.

        Args:
            model_id: Either a local directory containing the weights and model config, or a
                HuggingFace repo ID to download them from.
            revision: Model revision to download (only used when downloading from HuggingFace).
            cache_dir: Cache directory for downloads.
            force_download: Whether to force a new download.
            strict: Passed to `load_state_dict`; whether the weight keys must match exactly.

        Returns:
            The model instantiated from its config, with weights loaded and in eval mode.
        """
        if os.path.isdir(model_id):
            print("Loading model from local directory")
            model_file = f"{model_id}/{PYTORCH_WEIGHTS_NAME}"
            config_file = f"{model_id}/{MODEL_CONFIG_NAME}"
        else:
            print("Loading model from huggingface repo")

            model_file, config_file = download_from_hf(
                repo_id=model_id,
                filename=[PYTORCH_WEIGHTS_NAME, MODEL_CONFIG_NAME],
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                max_retries=5,
                wait_time=10,
            )

        # Instantiate the model from its hydra config
        with open(config_file, "r") as f:
            model = hydra.utils.instantiate(yaml.safe_load(f))

        state_dict = load_file(model_file)
        model.load_state_dict(state_dict, strict=strict)  # type: ignore
        model.eval()  # type: ignore

        return model

    @classmethod
    def get_datamodule_config(
        cls,
        model_id: str,
        revision: str,
        cache_dir: str | None = None,
        force_download: bool = False,
    ) -> str:
        """Load data config file.

        Args:
            model_id: Either a local directory or a HuggingFace repo ID.
            revision: Model revision to download (only used when downloading from HuggingFace).
            cache_dir: Cache directory for downloads.
            force_download: Whether to force a new download.

        Returns:
            Local path to the datamodule config file.
        """
        if os.path.isdir(model_id):
            print("Loading datamodule config from local directory")
            datamodule_config_file = os.path.join(model_id, DATAMODULE_CONFIG_NAME)
        else:
            print("Loading datamodule config from huggingface repo")
            datamodule_config_file = download_from_hf(
                repo_id=model_id,
                filename=DATAMODULE_CONFIG_NAME,
                revision=revision,
                cache_dir=cache_dir,
                force_download=force_download,
                max_retries=5,
                wait_time=10,
            )

        return datamodule_config_file

    def _save_model_weights(self, save_directory: str | Path) -> None:
        """Save weights from a Pytorch model to a local directory in safetensors format."""
        save_file(self.state_dict(), f"{save_directory}/{PYTORCH_WEIGHTS_NAME}")

    def save_pretrained(
        self,
        save_directory: str,
        model_config: dict,
        wandb_repo: str,
        wandb_id: str,
        card_template_path: str,
        datamodule_config_path: str,
        experiment_config_path: str | None = None,
        hf_repo_id: str | None = None,
        push_to_hub: bool = False,
    ) -> None:
        """Save weights in local directory or upload to huggingface hub.

        Args:
            save_directory:
                Path to directory in which the model weights and configuration will be saved.
                Created if it does not exist.
            model_config (`dict`):
                Model configuration specified as a key/value dictionary.
            wandb_repo: Identifier of the repo on wandb.
            wandb_id: Identifier of the model on wandb.
            card_template_path: Path to the HuggingFace model card template.
            datamodule_config_path:
                The path to the datamodule config. Only the entries needed for inference are
                saved.
            experiment_config_path:
                The path to the full experimental config. Copied as-is if provided.
            hf_repo_id:
                ID of your repository on the Hub. Used only if `push_to_hub=True`. Defaults to
                the name of `save_directory` if not provided.
            push_to_hub (`bool`, *optional*, defaults to `False`):
                Whether or not to push your model to the HuggingFace Hub after saving it.
        """
        save_directory = Path(save_directory)
        save_directory.mkdir(parents=True, exist_ok=True)

        # Save model weights/files
        self._save_model_weights(save_directory)

        # Save the model config
        if isinstance(model_config, dict):
            with open(save_directory / MODEL_CONFIG_NAME, "w") as outfile:
                yaml.dump(model_config, outfile, sort_keys=False, default_flow_style=False)

        # Sanitize and save the datamodule config.
        # NOTE(review): yaml.FullLoader can construct some Python objects; only appropriate for
        # trusted config files.
        with open(datamodule_config_path) as cfg:
            datamodule_config = yaml.load(cfg, Loader=yaml.FullLoader)

        datamodule_config = santize_datamodule(datamodule_config)

        with open(save_directory / DATAMODULE_CONFIG_NAME, "w") as outfile:
            yaml.dump(datamodule_config, outfile, sort_keys=False, default_flow_style=False)

        # Save the full experimental config
        if experiment_config_path is not None:
            shutil.copyfile(experiment_config_path, save_directory / FULL_CONFIG_NAME)

        card = self.create_hugging_face_model_card(card_template_path, wandb_repo, wandb_id)

        (save_directory / MODEL_CARD_NAME).write_text(str(card))

        if push_to_hub:
            # Fall back to the folder name, as promised in the docstring, instead of passing None
            repo_id = hf_repo_id if hf_repo_id is not None else save_directory.name

            api = HfApi()

            api.upload_folder(
                repo_id=repo_id,
                folder_path=save_directory,
                repo_type="model",
                commit_message=f"Upload model - {wandb_id}",
            )

            # Print the most recent commit hash
            c = api.list_repo_commits(repo_id=repo_id, repo_type="model")[0]

            message = (
                f"The latest commit is now: \n"
                f"    date: {c.created_at} \n"
                f"    commit hash: {c.commit_id}\n"
                f"    by: {c.authors}\n"
                f"    title: {c.title}\n"
            )

            print(message)

    @staticmethod
    def create_hugging_face_model_card(
        card_template_path: str,
        wandb_repo: str,
        wandb_id: str,
    ) -> ModelCard:
        """
        Creates Hugging Face model card

        Args:
            card_template_path: Path to the HuggingFace model card template
            wandb_repo: Identifier of the repo on wandb.
            wandb_id: Identifier of the model on wandb.

        Returns:
            card: ModelCard - Hugging Face model card object
        """
        # Creating and saving model card.
        card_data = ModelCardData(language="en", license="mit", library_name="pytorch")

        link = f"https://wandb.ai/{wandb_repo}/runs/{wandb_id}"
        wandb_link = f" - [{link}]({link})\n"

        # Find package versions for OCF packages and list them as markdown bullet points
        packages_to_display = ["pvnet_summation", "ocf-data-sampler"]
        package_versions_markdown = "".join(
            f" - {package}=={version(package)}\n" for package in packages_to_display
        )

        return ModelCard.from_template(
            card_data,
            template_path=card_template_path,
            wandb_link=wandb_link,
            package_versions=package_versions_markdown,
        )
277
+
278
+
279
class BaseModel(torch.nn.Module, HuggingfaceMixin):
    """Base class shared by all PVNet-summation submodels."""

    def __init__(
        self,
        output_quantiles: list[float] | None,
        num_input_locations: int,
        input_quantiles: list[float] | None,
        history_minutes: int,
        forecast_minutes: int,
        interval_minutes: int,
    ):
        """Base class shared by all PVNet-summation submodels.

        Args:
            output_quantiles: Quantiles the model predicts, or None to predict a single value per
                forecast step.
            num_input_locations: Number of locations whose PVNet predictions are inputs.
            input_quantiles: Quantiles present in the input PVNet predictions, or None if the
                inputs hold a single value per step.
            history_minutes: Length of the history window in minutes.
            forecast_minutes: Length of the forecast horizon in minutes.
            interval_minutes: Time step in minutes; the durations above are divided by this to
                get numbers of steps.
        """
        super().__init__()

        self.output_quantiles = output_quantiles
        self.num_input_locations = num_input_locations
        self.input_quantiles = input_quantiles

        self.history_minutes = history_minutes
        self.forecast_minutes = forecast_minutes
        self.interval_minutes = interval_minutes

        # Convert the durations into whole numbers of time steps
        self.history_len = history_minutes // interval_minutes
        self.forecast_len = forecast_minutes // interval_minutes

        # Quantile regression whenever output quantiles are given; otherwise one value per step
        self.use_quantile_regression = output_quantiles is not None

        # Number of output features: one per forecast step, times the number of quantiles
        n_steps = self.forecast_len
        self.num_output_features = (
            n_steps * len(output_quantiles) if self.use_quantile_regression else n_steps
        )

        # Expected input shape; a trailing quantile axis is present only when input_quantiles
        # is given. (Presumably excludes the batch dimension - confirm against callers.)
        per_location = (self.num_input_locations, self.forecast_len)
        self.input_shape = (
            per_location if input_quantiles is None else (*per_location, len(input_quantiles))
        )

    def _quantiles_to_prediction(self, y_quantiles: torch.Tensor) -> torch.Tensor:
        """Reduce a quantile prediction to a point prediction by taking the 0.5 quantile.

        Args:
            y_quantiles: Quantile prediction of network, quantiles along the last axis

        Returns:
            torch.Tensor: Point prediction

        Raises:
            ValueError: If 0.5 is not one of `self.output_quantiles`.
        """
        # y_quantiles Shape: [batch_size, seq_length, num_quantiles]
        median_position = self.output_quantiles.index(0.5)
        return y_quantiles[..., median_position]

    def sum_of_locations(self, x: SumTensorBatch) -> torch.Tensor:
        """Sum the location-level predictions, weighted by each location's relative capacity.

        When the inputs carry quantiles, the 0.5 quantile is used for each location.
        """
        location_preds = x["pvnet_outputs"]
        if self.input_quantiles is not None:
            location_preds = location_preds[..., self.input_quantiles.index(0.5)]

        # Broadcast the per-location weights over the time axis, then sum over locations (dim 1)
        capacity_weights = x["relative_capacity"].unsqueeze(-1)
        return (location_preds * capacity_weights).sum(dim=1)