PVNet 5.3.5__py3-none-any.whl → 5.3.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pvnet/utils.py CHANGED
@@ -101,66 +101,60 @@ def validate_batch_against_config(
101
101
  logger.info("Performing batch shape validation against model config.")
102
102
 
103
103
  # NWP validation
104
- if hasattr(model, "nwp_encoders_dict"):
104
+ if model.include_nwp:
105
105
  if "nwp" not in batch:
106
- raise ValueError(
107
- "Model configured with 'nwp_encoders_dict' but 'nwp' data missing from batch."
108
- )
106
+ raise ValueError("Model uses NWP data but 'nwp' missing from batch.")
109
107
 
110
- for source, nwp_data in batch["nwp"].items():
111
- if source in model.nwp_encoders_dict:
112
- enc = model.nwp_encoders_dict[source]
113
- expected_channels = enc.in_channels
114
- if model.add_image_embedding_channel:
115
- expected_channels -= 1
116
-
117
- expected = (
118
- nwp_data["nwp"].shape[0],
119
- enc.sequence_length,
120
- expected_channels,
121
- enc.image_size_pixels,
122
- enc.image_size_pixels,
108
+ for source in model.nwp_encoders_dict:
109
+ if source not in batch["nwp"]:
110
+ raise ValueError(
111
+ f"Model uses NWP source '{source}' but it is missing from batch['nwp']."
123
112
  )
124
- if tuple(nwp_data["nwp"].shape) != expected:
125
- actual_shape = tuple(nwp_data["nwp"].shape)
126
- raise ValueError(
127
- f"NWP.{source} shape mismatch: expected {expected}, got {actual_shape}"
128
- )
129
113
 
130
- # Satellite validation
131
- if hasattr(model, "sat_encoder"):
132
- if "satellite_actual" not in batch:
133
- raise ValueError(
134
- "Model configured with 'sat_encoder' but 'satellite_actual' missing from batch."
114
+ enc = model.nwp_encoders_dict[source]
115
+ expected_channels = enc.in_channels - int(model.add_image_embedding_channel)
116
+
117
+ expected_shape = (
118
+ batch["nwp"][source]["nwp"].shape[0],
119
+ enc.sequence_length,
120
+ expected_channels,
121
+ enc.image_size_pixels,
122
+ enc.image_size_pixels,
135
123
  )
124
+ actual_shape = tuple(batch["nwp"][source]["nwp"].shape)
125
+ if actual_shape != expected_shape:
126
+ raise ValueError(
127
+ f"NWP.{source} shape mismatch: expected {expected_shape}, got {actual_shape}"
128
+ )
129
+
130
+ # Satellite validation
131
+ if model.include_sat:
132
+ if (sat_data := batch.get("satellite_actual")) is None:
133
+ raise ValueError("Model uses sat data but 'satellite_actual' missing from batch.")
136
134
 
137
135
  enc = model.sat_encoder
138
- expected_channels = enc.in_channels
139
- if model.add_image_embedding_channel:
140
- expected_channels -= 1
141
-
142
- expected = (
143
- batch["satellite_actual"].shape[0],
144
- enc.sequence_length,
145
- expected_channels,
146
- enc.image_size_pixels,
147
- enc.image_size_pixels,
148
- )
149
- if tuple(batch["satellite_actual"].shape) != expected:
150
- actual_shape = tuple(batch["satellite_actual"].shape)
151
- raise ValueError(f"Satellite shape mismatch: expected {expected}, got {actual_shape}")
136
+ exp_ch = enc.in_channels - int(model.add_image_embedding_channel)
137
+ _, actual_seq, actual_ch, h, w = sat_data.shape
138
+
139
+ if actual_ch != exp_ch or h != enc.image_size_pixels or w != enc.image_size_pixels:
140
+ msg = (
141
+ f"Sat mismatch: Exp {exp_ch}ch, {enc.image_size_pixels}px. "
142
+ f"Got {actual_ch}ch, {h}x{w}px"
143
+ )
144
+ raise ValueError(msg)
145
+
146
+ if actual_seq < enc.sequence_length:
147
+ raise ValueError(f"Sat too short: exp {enc.sequence_length}, got {actual_seq}")
152
148
 
153
- # generation validation
154
149
  key = "generation"
155
150
  if key in batch:
156
151
  total_minutes = model.history_minutes + model.forecast_minutes
157
- interval = model.interval_minutes
158
- expected_len = total_minutes // interval + 1
159
- expected = (batch[key].shape[0], expected_len)
160
- if tuple(batch[key].shape) != expected:
161
- actual_shape = tuple(batch[key].shape)
152
+ expected_len = total_minutes // model.interval_minutes + 1
153
+ expected_shape = (batch[key].shape[0], expected_len)
154
+ actual_shape = tuple(batch[key].shape)
155
+ if actual_shape != expected_shape:
162
156
  raise ValueError(
163
- f"{key.upper()} shape mismatch: expected {expected}, got {actual_shape}"
157
+ f"Generation data shape mismatch: expected {expected_shape}, got {actual_shape}"
164
158
  )
165
159
 
166
160
  logger.info("Batch shape validation successful!")
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: PVNet
3
- Version: 5.3.5
3
+ Version: 5.3.7
4
4
  Summary: PVNet
5
5
  Author-email: Peter Dudfield <info@openclimatefix.org>
6
6
  Requires-Python: <3.14,>=3.11
@@ -2,7 +2,7 @@ pvnet/__init__.py,sha256=TAZm88TJ5ieL1XjEyRg1LciIGuSScEucdAruQLfM92I,25
2
2
  pvnet/datamodule.py,sha256=wc1RQfFhgW9Hxyw7vrpFERhOd2FmjDsO1x49J2erOYk,5750
3
3
  pvnet/load_model.py,sha256=P1QODX_mJRnKZ_kIll9BlOjK_A1W4YM3QG-mZd-2Mcc,3852
4
4
  pvnet/optimizers.py,sha256=DZ74KcFQV226zwu7-qAzofTMTYeIyScox4Kqbq30WWY,6440
5
- pvnet/utils.py,sha256=L3MDF5m1Ez_btAZZ8t-T5wXLzFmyj7UZtorA91DEpFw,6003
5
+ pvnet/utils.py,sha256=Nnc-DaqeJJwv8WtIXc_mpJXkCxXSHN8RScQT1mZ28GA,5880
6
6
  pvnet/models/__init__.py,sha256=owzZ9xkD0DRTT51mT2Dx_p96oJjwDz57xo_MaMIEosk,145
7
7
  pvnet/models/base_model.py,sha256=V-vBqtzZc_c8Ho5hVo_ikq2wzZ7hsAIM7I4vhzGDfNc,16051
8
8
  pvnet/models/ensemble.py,sha256=USpNQ0O5eiffapLPE9T6gR-uK9f_3E4pX3DK7Lmkn2U,2228
@@ -22,8 +22,8 @@ pvnet/training/__init__.py,sha256=FKxmPZ59Vuj5_mXomN4saJ3En5M-aDMxSs6OttTQOcg,49
22
22
  pvnet/training/lightning_module.py,sha256=hmvne9DQauWpG61sRK-t8MTZRVwdywaEFCs0VFVRuMs,13522
23
23
  pvnet/training/plots.py,sha256=7JtjA9zIotuoKZ2l0fbS-FZDB48TcIk_-XLA2EWVMv4,2448
24
24
  pvnet/training/train.py,sha256=Sry2wYgggUmtIB-k_umFts7xMr2roEL76NCu9ySbLUY,4107
25
- pvnet-5.3.5.dist-info/licenses/LICENSE,sha256=tKUnlSmcLBWMJWkHx3UjZGdrjs9LidGwLo0jsBUBAwU,1077
26
- pvnet-5.3.5.dist-info/METADATA,sha256=rIlZGmFiIzkMpG_5U-6SrsdDW6fIke667JAG79g3KN4,16479
27
- pvnet-5.3.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
28
- pvnet-5.3.5.dist-info/top_level.txt,sha256=4mg6WjeW05SR7pg3-Q4JRE2yAoutHYpspOsiUzYVNv0,6
29
- pvnet-5.3.5.dist-info/RECORD,,
25
+ pvnet-5.3.7.dist-info/licenses/LICENSE,sha256=tKUnlSmcLBWMJWkHx3UjZGdrjs9LidGwLo0jsBUBAwU,1077
26
+ pvnet-5.3.7.dist-info/METADATA,sha256=Ce405a8g81_se67GP7MMHpiO2wjT63nc2HcGz8h1kaY,16479
27
+ pvnet-5.3.7.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
28
+ pvnet-5.3.7.dist-info/top_level.txt,sha256=4mg6WjeW05SR7pg3-Q4JRE2yAoutHYpspOsiUzYVNv0,6
29
+ pvnet-5.3.7.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (80.9.0)
2
+ Generator: setuptools (80.10.2)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5