PVNet 5.3.6__py3-none-any.whl → 5.3.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
pvnet/utils.py CHANGED
@@ -102,53 +102,43 @@ def validate_batch_against_config(
102
102
 
103
103
  # NWP validation
104
104
  if model.include_nwp:
105
- if "nwp" not in batch:
105
+ if (nwp_dict := batch.get("nwp")) is None:
106
106
  raise ValueError("Model uses NWP data but 'nwp' missing from batch.")
107
107
 
108
- for source in model.nwp_encoders_dict:
109
- if source not in batch["nwp"]:
110
- raise ValueError(
111
- f"Model uses NWP source '{source}' but it is missing from batch['nwp']."
112
- )
108
+ for source, enc in model.nwp_encoders_dict.items():
109
+ if (src_data := nwp_dict.get(source)) is None:
110
+ raise ValueError(f"NWP source '{source}' missing from batch['nwp'].")
113
111
 
114
- enc = model.nwp_encoders_dict[source]
115
- expected_channels = enc.in_channels - int(model.add_image_embedding_channel)
116
-
117
- expected_shape = (
118
- batch["nwp"][source]["nwp"].shape[0],
119
- enc.sequence_length,
120
- expected_channels,
121
- enc.image_size_pixels,
122
- enc.image_size_pixels,
123
- )
124
- actual_shape = tuple(batch["nwp"][source]["nwp"].shape)
125
- if actual_shape != expected_shape:
126
- raise ValueError(
127
- f"NWP.{source} shape mismatch: expected {expected_shape}, got {actual_shape}"
112
+ nwp_tensor = src_data["nwp"]
113
+ exp_ch = enc.in_channels - int(model.add_image_embedding_channel)
114
+ _, actual_seq, actual_ch, h, w = nwp_tensor.shape
115
+
116
+ if (actual_seq != enc.sequence_length or actual_ch != exp_ch or
117
+ h != enc.image_size_pixels or w != enc.image_size_pixels):
118
+ msg = (
119
+ f"NWP.{source} mismatch: Exp {enc.sequence_length}seq, {exp_ch}ch. "
120
+ f"Got {actual_seq}seq, {actual_ch}ch"
128
121
  )
122
+ raise ValueError(msg)
129
123
 
130
124
  # Satellite validation
131
125
  if model.include_sat:
132
- if "satellite_actual" not in batch:
133
- raise ValueError(
134
- "Model uses satellite data but 'satellite_actual' missing from batch."
135
- )
126
+ if (sat_data := batch.get("satellite_actual")) is None:
127
+ raise ValueError("Model uses sat data but 'satellite_actual' missing from batch.")
136
128
 
137
129
  enc = model.sat_encoder
138
- expected_channels = enc.in_channels - int(model.add_image_embedding_channel)
139
-
140
- expected_shape = (
141
- batch["satellite_actual"].shape[0],
142
- enc.sequence_length,
143
- expected_channels,
144
- enc.image_size_pixels,
145
- enc.image_size_pixels,
146
- )
147
- actual_shape = tuple(batch["satellite_actual"].shape)
148
- if actual_shape != expected_shape:
149
- raise ValueError(
150
- f"Satellite shape mismatch: expected {expected_shape}, got {actual_shape}"
130
+ exp_ch = enc.in_channels - int(model.add_image_embedding_channel)
131
+ _, actual_seq, actual_ch, h, w = sat_data.shape
132
+
133
+ if actual_ch != exp_ch or h != enc.image_size_pixels or w != enc.image_size_pixels:
134
+ msg = (
135
+ f"Sat mismatch: Exp {exp_ch}ch, {enc.image_size_pixels}px. "
136
+ f"Got {actual_ch}ch, {h}x{w}px"
151
137
  )
138
+ raise ValueError(msg)
139
+
140
+ if actual_seq < enc.sequence_length:
141
+ raise ValueError(f"Sat too short: exp {enc.sequence_length}, got {actual_seq}")
152
142
 
153
143
  key = "generation"
154
144
  if key in batch:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: PVNet
3
- Version: 5.3.6
3
+ Version: 5.3.8
4
4
  Summary: PVNet
5
5
  Author-email: Peter Dudfield <info@openclimatefix.org>
6
6
  Requires-Python: <3.14,>=3.11
@@ -2,7 +2,7 @@ pvnet/__init__.py,sha256=TAZm88TJ5ieL1XjEyRg1LciIGuSScEucdAruQLfM92I,25
2
2
  pvnet/datamodule.py,sha256=wc1RQfFhgW9Hxyw7vrpFERhOd2FmjDsO1x49J2erOYk,5750
3
3
  pvnet/load_model.py,sha256=P1QODX_mJRnKZ_kIll9BlOjK_A1W4YM3QG-mZd-2Mcc,3852
4
4
  pvnet/optimizers.py,sha256=DZ74KcFQV226zwu7-qAzofTMTYeIyScox4Kqbq30WWY,6440
5
- pvnet/utils.py,sha256=aVcalRAUO7TIa6AepRGmt0zPx9e1h2Xed34uCF_yg50,5887
5
+ pvnet/utils.py,sha256=lalHiFzqsJyje3pJp0_rhRJ7w3e5av6W8dYn9I5O7oc,5765
6
6
  pvnet/models/__init__.py,sha256=owzZ9xkD0DRTT51mT2Dx_p96oJjwDz57xo_MaMIEosk,145
7
7
  pvnet/models/base_model.py,sha256=V-vBqtzZc_c8Ho5hVo_ikq2wzZ7hsAIM7I4vhzGDfNc,16051
8
8
  pvnet/models/ensemble.py,sha256=USpNQ0O5eiffapLPE9T6gR-uK9f_3E4pX3DK7Lmkn2U,2228
@@ -22,8 +22,8 @@ pvnet/training/__init__.py,sha256=FKxmPZ59Vuj5_mXomN4saJ3En5M-aDMxSs6OttTQOcg,49
22
22
  pvnet/training/lightning_module.py,sha256=hmvne9DQauWpG61sRK-t8MTZRVwdywaEFCs0VFVRuMs,13522
23
23
  pvnet/training/plots.py,sha256=7JtjA9zIotuoKZ2l0fbS-FZDB48TcIk_-XLA2EWVMv4,2448
24
24
  pvnet/training/train.py,sha256=Sry2wYgggUmtIB-k_umFts7xMr2roEL76NCu9ySbLUY,4107
25
- pvnet-5.3.6.dist-info/licenses/LICENSE,sha256=tKUnlSmcLBWMJWkHx3UjZGdrjs9LidGwLo0jsBUBAwU,1077
26
- pvnet-5.3.6.dist-info/METADATA,sha256=emx0MAvTIzqiocttPDRoQp-7QTScaKD5ANzC95wWkKo,16479
27
- pvnet-5.3.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
28
- pvnet-5.3.6.dist-info/top_level.txt,sha256=4mg6WjeW05SR7pg3-Q4JRE2yAoutHYpspOsiUzYVNv0,6
29
- pvnet-5.3.6.dist-info/RECORD,,
25
+ pvnet-5.3.8.dist-info/licenses/LICENSE,sha256=tKUnlSmcLBWMJWkHx3UjZGdrjs9LidGwLo0jsBUBAwU,1077
26
+ pvnet-5.3.8.dist-info/METADATA,sha256=dbj7rAWRN7ZAdNvPr2prEHGwGpA5wNA66UvFrAAJKI4,16479
27
+ pvnet-5.3.8.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
28
+ pvnet-5.3.8.dist-info/top_level.txt,sha256=4mg6WjeW05SR7pg3-Q4JRE2yAoutHYpspOsiUzYVNv0,6
29
+ pvnet-5.3.8.dist-info/RECORD,,
@@ -1,5 +1,5 @@
1
1
  Wheel-Version: 1.0
2
- Generator: setuptools (80.9.0)
2
+ Generator: setuptools (80.10.2)
3
3
  Root-Is-Purelib: true
4
4
  Tag: py3-none-any
5
5