PVNet 5.3.5__py3-none-any.whl → 5.3.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pvnet/utils.py +42 -48
- {pvnet-5.3.5.dist-info → pvnet-5.3.7.dist-info}/METADATA +1 -1
- {pvnet-5.3.5.dist-info → pvnet-5.3.7.dist-info}/RECORD +6 -6
- {pvnet-5.3.5.dist-info → pvnet-5.3.7.dist-info}/WHEEL +1 -1
- {pvnet-5.3.5.dist-info → pvnet-5.3.7.dist-info}/licenses/LICENSE +0 -0
- {pvnet-5.3.5.dist-info → pvnet-5.3.7.dist-info}/top_level.txt +0 -0
pvnet/utils.py
CHANGED
|
@@ -101,66 +101,60 @@ def validate_batch_against_config(
|
|
|
101
101
|
logger.info("Performing batch shape validation against model config.")
|
|
102
102
|
|
|
103
103
|
# NWP validation
|
|
104
|
-
if
|
|
104
|
+
if model.include_nwp:
|
|
105
105
|
if "nwp" not in batch:
|
|
106
|
-
raise ValueError(
|
|
107
|
-
"Model configured with 'nwp_encoders_dict' but 'nwp' data missing from batch."
|
|
108
|
-
)
|
|
106
|
+
raise ValueError("Model uses NWP data but 'nwp' missing from batch.")
|
|
109
107
|
|
|
110
|
-
for source
|
|
111
|
-
if source in
|
|
112
|
-
|
|
113
|
-
|
|
114
|
-
if model.add_image_embedding_channel:
|
|
115
|
-
expected_channels -= 1
|
|
116
|
-
|
|
117
|
-
expected = (
|
|
118
|
-
nwp_data["nwp"].shape[0],
|
|
119
|
-
enc.sequence_length,
|
|
120
|
-
expected_channels,
|
|
121
|
-
enc.image_size_pixels,
|
|
122
|
-
enc.image_size_pixels,
|
|
108
|
+
for source in model.nwp_encoders_dict:
|
|
109
|
+
if source not in batch["nwp"]:
|
|
110
|
+
raise ValueError(
|
|
111
|
+
f"Model uses NWP source '{source}' but it is missing from batch['nwp']."
|
|
123
112
|
)
|
|
124
|
-
if tuple(nwp_data["nwp"].shape) != expected:
|
|
125
|
-
actual_shape = tuple(nwp_data["nwp"].shape)
|
|
126
|
-
raise ValueError(
|
|
127
|
-
f"NWP.{source} shape mismatch: expected {expected}, got {actual_shape}"
|
|
128
|
-
)
|
|
129
113
|
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
"
|
|
114
|
+
enc = model.nwp_encoders_dict[source]
|
|
115
|
+
expected_channels = enc.in_channels - int(model.add_image_embedding_channel)
|
|
116
|
+
|
|
117
|
+
expected_shape = (
|
|
118
|
+
batch["nwp"][source]["nwp"].shape[0],
|
|
119
|
+
enc.sequence_length,
|
|
120
|
+
expected_channels,
|
|
121
|
+
enc.image_size_pixels,
|
|
122
|
+
enc.image_size_pixels,
|
|
135
123
|
)
|
|
124
|
+
actual_shape = tuple(batch["nwp"][source]["nwp"].shape)
|
|
125
|
+
if actual_shape != expected_shape:
|
|
126
|
+
raise ValueError(
|
|
127
|
+
f"NWP.{source} shape mismatch: expected {expected_shape}, got {actual_shape}"
|
|
128
|
+
)
|
|
129
|
+
|
|
130
|
+
# Satellite validation
|
|
131
|
+
if model.include_sat:
|
|
132
|
+
if (sat_data := batch.get("satellite_actual")) is None:
|
|
133
|
+
raise ValueError("Model uses sat data but 'satellite_actual' missing from batch.")
|
|
136
134
|
|
|
137
135
|
enc = model.sat_encoder
|
|
138
|
-
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
|
-
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
actual_shape = tuple(batch["satellite_actual"].shape)
|
|
151
|
-
raise ValueError(f"Satellite shape mismatch: expected {expected}, got {actual_shape}")
|
|
136
|
+
exp_ch = enc.in_channels - int(model.add_image_embedding_channel)
|
|
137
|
+
_, actual_seq, actual_ch, h, w = sat_data.shape
|
|
138
|
+
|
|
139
|
+
if actual_ch != exp_ch or h != enc.image_size_pixels or w != enc.image_size_pixels:
|
|
140
|
+
msg = (
|
|
141
|
+
f"Sat mismatch: Exp {exp_ch}ch, {enc.image_size_pixels}px. "
|
|
142
|
+
f"Got {actual_ch}ch, {h}x{w}px"
|
|
143
|
+
)
|
|
144
|
+
raise ValueError(msg)
|
|
145
|
+
|
|
146
|
+
if actual_seq < enc.sequence_length:
|
|
147
|
+
raise ValueError(f"Sat too short: exp {enc.sequence_length}, got {actual_seq}")
|
|
152
148
|
|
|
153
|
-
# generation validation
|
|
154
149
|
key = "generation"
|
|
155
150
|
if key in batch:
|
|
156
151
|
total_minutes = model.history_minutes + model.forecast_minutes
|
|
157
|
-
|
|
158
|
-
|
|
159
|
-
|
|
160
|
-
if
|
|
161
|
-
actual_shape = tuple(batch[key].shape)
|
|
152
|
+
expected_len = total_minutes // model.interval_minutes + 1
|
|
153
|
+
expected_shape = (batch[key].shape[0], expected_len)
|
|
154
|
+
actual_shape = tuple(batch[key].shape)
|
|
155
|
+
if actual_shape != expected_shape:
|
|
162
156
|
raise ValueError(
|
|
163
|
-
f"
|
|
157
|
+
f"Generation data shape mismatch: expected {expected_shape}, got {actual_shape}"
|
|
164
158
|
)
|
|
165
159
|
|
|
166
160
|
logger.info("Batch shape validation successful!")
|
|
@@ -2,7 +2,7 @@ pvnet/__init__.py,sha256=TAZm88TJ5ieL1XjEyRg1LciIGuSScEucdAruQLfM92I,25
|
|
|
2
2
|
pvnet/datamodule.py,sha256=wc1RQfFhgW9Hxyw7vrpFERhOd2FmjDsO1x49J2erOYk,5750
|
|
3
3
|
pvnet/load_model.py,sha256=P1QODX_mJRnKZ_kIll9BlOjK_A1W4YM3QG-mZd-2Mcc,3852
|
|
4
4
|
pvnet/optimizers.py,sha256=DZ74KcFQV226zwu7-qAzofTMTYeIyScox4Kqbq30WWY,6440
|
|
5
|
-
pvnet/utils.py,sha256=
|
|
5
|
+
pvnet/utils.py,sha256=Nnc-DaqeJJwv8WtIXc_mpJXkCxXSHN8RScQT1mZ28GA,5880
|
|
6
6
|
pvnet/models/__init__.py,sha256=owzZ9xkD0DRTT51mT2Dx_p96oJjwDz57xo_MaMIEosk,145
|
|
7
7
|
pvnet/models/base_model.py,sha256=V-vBqtzZc_c8Ho5hVo_ikq2wzZ7hsAIM7I4vhzGDfNc,16051
|
|
8
8
|
pvnet/models/ensemble.py,sha256=USpNQ0O5eiffapLPE9T6gR-uK9f_3E4pX3DK7Lmkn2U,2228
|
|
@@ -22,8 +22,8 @@ pvnet/training/__init__.py,sha256=FKxmPZ59Vuj5_mXomN4saJ3En5M-aDMxSs6OttTQOcg,49
|
|
|
22
22
|
pvnet/training/lightning_module.py,sha256=hmvne9DQauWpG61sRK-t8MTZRVwdywaEFCs0VFVRuMs,13522
|
|
23
23
|
pvnet/training/plots.py,sha256=7JtjA9zIotuoKZ2l0fbS-FZDB48TcIk_-XLA2EWVMv4,2448
|
|
24
24
|
pvnet/training/train.py,sha256=Sry2wYgggUmtIB-k_umFts7xMr2roEL76NCu9ySbLUY,4107
|
|
25
|
-
pvnet-5.3.
|
|
26
|
-
pvnet-5.3.
|
|
27
|
-
pvnet-5.3.
|
|
28
|
-
pvnet-5.3.
|
|
29
|
-
pvnet-5.3.
|
|
25
|
+
pvnet-5.3.7.dist-info/licenses/LICENSE,sha256=tKUnlSmcLBWMJWkHx3UjZGdrjs9LidGwLo0jsBUBAwU,1077
|
|
26
|
+
pvnet-5.3.7.dist-info/METADATA,sha256=Ce405a8g81_se67GP7MMHpiO2wjT63nc2HcGz8h1kaY,16479
|
|
27
|
+
pvnet-5.3.7.dist-info/WHEEL,sha256=wUyA8OaulRlbfwMtmQsvNngGrxQHAvkKcvRmdizlJi0,92
|
|
28
|
+
pvnet-5.3.7.dist-info/top_level.txt,sha256=4mg6WjeW05SR7pg3-Q4JRE2yAoutHYpspOsiUzYVNv0,6
|
|
29
|
+
pvnet-5.3.7.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|