PVNet 5.3.4__py3-none-any.whl → 5.3.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -46,6 +46,7 @@ class LateFusionModel(BaseModel):
46
46
  include_generation_history: bool = False,
47
47
  include_sun: bool = True,
48
48
  include_time: bool = False,
49
+ t0_embedding_dim: int = 0,
49
50
  location_id_mapping: dict[Any, int] | None = None,
50
51
  embedding_dim: int = 16,
51
52
  forecast_minutes: int = 30,
@@ -85,6 +86,8 @@ class LateFusionModel(BaseModel):
85
86
  include_generation_history: Include generation yield data.
86
87
  include_sun: Include sun azimuth and altitude data.
87
88
  include_time: Include sine and cosine of dates and times.
89
+ t0_embedding_dim: Shape of the embedding of the init-time (t0) of the forecast. Not used
90
+ if set to 0.
88
91
  location_id_mapping: A dictionary mapping the location ID to an integer. ID embedding is
89
92
  not used if this is not provided.
90
93
  embedding_dim: Number of embedding dimensions to use for location ID.
@@ -119,6 +122,7 @@ class LateFusionModel(BaseModel):
119
122
  self.include_pv = pv_encoder is not None
120
123
  self.include_sun = include_sun
121
124
  self.include_time = include_time
125
+ self.t0_embedding_dim = t0_embedding_dim
122
126
  self.location_id_mapping = location_id_mapping
123
127
  self.embedding_dim = embedding_dim
124
128
  self.add_image_embedding_channel = add_image_embedding_channel
@@ -246,6 +250,8 @@ class LateFusionModel(BaseModel):
246
250
  # Update num features
247
251
  fusion_input_features += 32
248
252
 
253
+ fusion_input_features += self.t0_embedding_dim
254
+
249
255
  if include_generation_history:
250
256
  # Update num features
251
257
  fusion_input_features += self.history_len + 1
@@ -321,6 +327,9 @@ class LateFusionModel(BaseModel):
321
327
  time = self.time_fc1(time)
322
328
  modes["time"] = time
323
329
 
330
+ if self.t0_embedding_dim>0:
331
+ modes["t0_embed"] = x["t0_embedding"]
332
+
324
333
  out = self.output_network(modes)
325
334
 
326
335
  if self.use_quantile_regression:
pvnet/utils.py CHANGED
@@ -101,66 +101,64 @@ def validate_batch_against_config(
101
101
  logger.info("Performing batch shape validation against model config.")
102
102
 
103
103
  # NWP validation
104
- if hasattr(model, "nwp_encoders_dict"):
104
+ if model.include_nwp:
105
105
  if "nwp" not in batch:
106
- raise ValueError(
107
- "Model configured with 'nwp_encoders_dict' but 'nwp' data missing from batch."
108
- )
106
+ raise ValueError("Model uses NWP data but 'nwp' missing from batch.")
109
107
 
110
- for source, nwp_data in batch["nwp"].items():
111
- if source in model.nwp_encoders_dict:
112
- enc = model.nwp_encoders_dict[source]
113
- expected_channels = enc.in_channels
114
- if model.add_image_embedding_channel:
115
- expected_channels -= 1
116
-
117
- expected = (
118
- nwp_data["nwp"].shape[0],
119
- enc.sequence_length,
120
- expected_channels,
121
- enc.image_size_pixels,
122
- enc.image_size_pixels,
108
+ for source in model.nwp_encoders_dict:
109
+ if source not in batch["nwp"]:
110
+ raise ValueError(
111
+ f"Model uses NWP source '{source}' but it is missing from batch['nwp']."
112
+ )
113
+
114
+ enc = model.nwp_encoders_dict[source]
115
+ expected_channels = enc.in_channels - int(model.add_image_embedding_channel)
116
+
117
+ expected_shape = (
118
+ batch["nwp"][source]["nwp"].shape[0],
119
+ enc.sequence_length,
120
+ expected_channels,
121
+ enc.image_size_pixels,
122
+ enc.image_size_pixels,
123
+ )
124
+ actual_shape = tuple(batch["nwp"][source]["nwp"].shape)
125
+ if actual_shape != expected_shape:
126
+ raise ValueError(
127
+ f"NWP.{source} shape mismatch: expected {expected_shape}, got {actual_shape}"
123
128
  )
124
- if tuple(nwp_data["nwp"].shape) != expected:
125
- actual_shape = tuple(nwp_data["nwp"].shape)
126
- raise ValueError(
127
- f"NWP.{source} shape mismatch: expected {expected}, got {actual_shape}"
128
- )
129
129
 
130
130
  # Satellite validation
131
- if hasattr(model, "sat_encoder"):
131
+ if model.include_sat:
132
132
  if "satellite_actual" not in batch:
133
133
  raise ValueError(
134
- "Model configured with 'sat_encoder' but 'satellite_actual' missing from batch."
134
+ "Model uses satellite data but 'satellite_actual' missing from batch."
135
135
  )
136
136
 
137
137
  enc = model.sat_encoder
138
- expected_channels = enc.in_channels
139
- if model.add_image_embedding_channel:
140
- expected_channels -= 1
138
+ expected_channels = enc.in_channels - int(model.add_image_embedding_channel)
141
139
 
142
- expected = (
140
+ expected_shape = (
143
141
  batch["satellite_actual"].shape[0],
144
142
  enc.sequence_length,
145
143
  expected_channels,
146
144
  enc.image_size_pixels,
147
145
  enc.image_size_pixels,
148
146
  )
149
- if tuple(batch["satellite_actual"].shape) != expected:
150
- actual_shape = tuple(batch["satellite_actual"].shape)
151
- raise ValueError(f"Satellite shape mismatch: expected {expected}, got {actual_shape}")
147
+ actual_shape = tuple(batch["satellite_actual"].shape)
148
+ if actual_shape != expected_shape:
149
+ raise ValueError(
150
+ f"Satellite shape mismatch: expected {expected_shape}, got {actual_shape}"
151
+ )
152
152
 
153
- # generation validation
154
153
  key = "generation"
155
154
  if key in batch:
156
155
  total_minutes = model.history_minutes + model.forecast_minutes
157
- interval = model.interval_minutes
158
- expected_len = total_minutes // interval + 1
159
- expected = (batch[key].shape[0], expected_len)
160
- if tuple(batch[key].shape) != expected:
161
- actual_shape = tuple(batch[key].shape)
156
+ expected_len = total_minutes // model.interval_minutes + 1
157
+ expected_shape = (batch[key].shape[0], expected_len)
158
+ actual_shape = tuple(batch[key].shape)
159
+ if actual_shape != expected_shape:
162
160
  raise ValueError(
163
- f"{key.upper()} shape mismatch: expected {expected}, got {actual_shape}"
161
+ f"Generation data shape mismatch: expected {expected_shape}, got {actual_shape}"
164
162
  )
165
163
 
166
164
  logger.info("Batch shape validation successful!")
@@ -1,12 +1,12 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: PVNet
3
- Version: 5.3.4
3
+ Version: 5.3.6
4
4
  Summary: PVNet
5
5
  Author-email: Peter Dudfield <info@openclimatefix.org>
6
6
  Requires-Python: <3.14,>=3.11
7
7
  Description-Content-Type: text/markdown
8
8
  License-File: LICENSE
9
- Requires-Dist: ocf-data-sampler>=0.6.0
9
+ Requires-Dist: ocf-data-sampler>=1.0.9
10
10
  Requires-Dist: numpy
11
11
  Requires-Dist: pandas
12
12
  Requires-Dist: matplotlib
@@ -2,13 +2,13 @@ pvnet/__init__.py,sha256=TAZm88TJ5ieL1XjEyRg1LciIGuSScEucdAruQLfM92I,25
2
2
  pvnet/datamodule.py,sha256=wc1RQfFhgW9Hxyw7vrpFERhOd2FmjDsO1x49J2erOYk,5750
3
3
  pvnet/load_model.py,sha256=P1QODX_mJRnKZ_kIll9BlOjK_A1W4YM3QG-mZd-2Mcc,3852
4
4
  pvnet/optimizers.py,sha256=DZ74KcFQV226zwu7-qAzofTMTYeIyScox4Kqbq30WWY,6440
5
- pvnet/utils.py,sha256=L3MDF5m1Ez_btAZZ8t-T5wXLzFmyj7UZtorA91DEpFw,6003
5
+ pvnet/utils.py,sha256=aVcalRAUO7TIa6AepRGmt0zPx9e1h2Xed34uCF_yg50,5887
6
6
  pvnet/models/__init__.py,sha256=owzZ9xkD0DRTT51mT2Dx_p96oJjwDz57xo_MaMIEosk,145
7
7
  pvnet/models/base_model.py,sha256=V-vBqtzZc_c8Ho5hVo_ikq2wzZ7hsAIM7I4vhzGDfNc,16051
8
8
  pvnet/models/ensemble.py,sha256=USpNQ0O5eiffapLPE9T6gR-uK9f_3E4pX3DK7Lmkn2U,2228
9
9
  pvnet/models/late_fusion/__init__.py,sha256=Jf0B-E0_5IvSBFoj1wvnPtwYDxs4pRIFm5qHv--Bbps,26
10
10
  pvnet/models/late_fusion/basic_blocks.py,sha256=_cYGVyAIyEJS4wd-DEAXQXu0br66guZJn3ugoebWqZ0,1479
11
- pvnet/models/late_fusion/late_fusion.py,sha256=kQUnyqMykmwc0GdoFhNXYStJPrjr3hFSvUNe8FumVx4,15260
11
+ pvnet/models/late_fusion/late_fusion.py,sha256=r05RJvw2-ZQgWJobOGq1g4rlMJQjGM0UzG3syA4T0qo,15617
12
12
  pvnet/models/late_fusion/encoders/__init__.py,sha256=bLBQdnCeLYhwISW0t88ZZBz-ebS94m7ZwBcsofWMHR4,51
13
13
  pvnet/models/late_fusion/encoders/basic_blocks.py,sha256=DGkFFIZv4S4FLTaAIOrAngAFBpgZQHfkGM4dzezZLk4,3044
14
14
  pvnet/models/late_fusion/encoders/encoders3d.py,sha256=9fmqVHO73F-jN62w065cgEQI_icNFC2nQH6ZEGvTHxU,7116
@@ -22,8 +22,8 @@ pvnet/training/__init__.py,sha256=FKxmPZ59Vuj5_mXomN4saJ3En5M-aDMxSs6OttTQOcg,49
22
22
  pvnet/training/lightning_module.py,sha256=hmvne9DQauWpG61sRK-t8MTZRVwdywaEFCs0VFVRuMs,13522
23
23
  pvnet/training/plots.py,sha256=7JtjA9zIotuoKZ2l0fbS-FZDB48TcIk_-XLA2EWVMv4,2448
24
24
  pvnet/training/train.py,sha256=Sry2wYgggUmtIB-k_umFts7xMr2roEL76NCu9ySbLUY,4107
25
- pvnet-5.3.4.dist-info/licenses/LICENSE,sha256=tKUnlSmcLBWMJWkHx3UjZGdrjs9LidGwLo0jsBUBAwU,1077
26
- pvnet-5.3.4.dist-info/METADATA,sha256=9t2g_zlRqOHNgwvHm1IFoJh_Z458Waxkt2Ge1q-1IWc,16479
27
- pvnet-5.3.4.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
28
- pvnet-5.3.4.dist-info/top_level.txt,sha256=4mg6WjeW05SR7pg3-Q4JRE2yAoutHYpspOsiUzYVNv0,6
29
- pvnet-5.3.4.dist-info/RECORD,,
25
+ pvnet-5.3.6.dist-info/licenses/LICENSE,sha256=tKUnlSmcLBWMJWkHx3UjZGdrjs9LidGwLo0jsBUBAwU,1077
26
+ pvnet-5.3.6.dist-info/METADATA,sha256=emx0MAvTIzqiocttPDRoQp-7QTScaKD5ANzC95wWkKo,16479
27
+ pvnet-5.3.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
28
+ pvnet-5.3.6.dist-info/top_level.txt,sha256=4mg6WjeW05SR7pg3-Q4JRE2yAoutHYpspOsiUzYVNv0,6
29
+ pvnet-5.3.6.dist-info/RECORD,,
File without changes