da-stdk 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- da_stdk/__init__.py +5 -0
- da_stdk/dataio/__init__.py +25 -0
- da_stdk/dataio/kaust_loader.py +660 -0
- da_stdk/dataio/obs_sampling.py +84 -0
- da_stdk/dataio/splits.py +61 -0
- da_stdk/losses/__init__.py +34 -0
- da_stdk/losses/crps.py +180 -0
- da_stdk/losses/non_crossing.py +133 -0
- da_stdk/models/__init__.py +6 -0
- da_stdk/models/st_interp.py +928 -0
- da_stdk/training/__init__.py +7 -0
- da_stdk/training/config.py +160 -0
- da_stdk/training/trainer.py +480 -0
- da_stdk/utils/__init__.py +14 -0
- da_stdk/utils/conformal.py +171 -0
- da_stdk/utils/ema.py +105 -0
- da_stdk/utils/metrics.py +163 -0
- da_stdk/utils/seed.py +27 -0
- da_stdk/viz/__init__.py +52 -0
- da_stdk/viz/basis.py +242 -0
- da_stdk/viz/kaust_analysis.py +1011 -0
- da_stdk/viz/obs.py +120 -0
- da_stdk/viz/obs_density.py +133 -0
- da_stdk/viz/predictions.py +153 -0
- da_stdk/viz/spatial.py +324 -0
- da_stdk/viz/temporal.py +363 -0
- da_stdk/viz/training.py +63 -0
- da_stdk-0.1.0.dist-info/METADATA +145 -0
- da_stdk-0.1.0.dist-info/RECORD +30 -0
- da_stdk-0.1.0.dist-info/WHEEL +4 -0
da_stdk/dataio/__init__.py
ADDED
@@ -0,0 +1,25 @@
"""
Data I/O for STNF-XAttn
"""
from .kaust_loader import (
    load_kaust_csv,
    sample_observed_sites,
    KAUSTWindowDataset,
    create_dataloaders,
    prepare_test_context,
    predictions_to_csv
)
from .obs_sampling import create_spatial_obs_prob_fn, sample_observations
from .splits import split_train_valid

__all__ = [
    'load_kaust_csv',
    'sample_observed_sites',
    'KAUSTWindowDataset',
    'create_dataloaders',
    'prepare_test_context',
    'predictions_to_csv',
    'create_spatial_obs_prob_fn',
    'sample_observations',
    'split_train_valid',
]
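Taken together, these exports cover the whole data pipeline: loading, site subsampling, windowing, and submission. A minimal end-to-end sketch (paths follow the __main__ example in kaust_loader.py below; everything else is illustrative, not part of the wheel):

# --- usage sketch (editorial illustration, not part of the package) ---
from da_stdk.dataio import (
    load_kaust_csv, sample_observed_sites, create_dataloaders
)

# Provider split: train rows carry z values, test rows are withheld targets.
z_train, z_test, coords, site_to_idx, meta = load_kaust_csv(
    'data/2b/2b_7_train.csv', 'data/2b/2b_7_test.csv', normalize=True
)

# Keep 10% of sites as "observed" and build sliding-window loaders.
obs_idx = sample_observed_sites(coords, obs_fraction=0.1,
                                sampling_method='uniform', seed=42)
train_loader, val_loader = create_dataloaders(
    z_train, coords, obs_idx,
    config={'L': 24, 'H': 10, 'batch_size': 8}, val_ratio=0.2
)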
da_stdk/dataio/kaust_loader.py
ADDED
@@ -0,0 +1,660 @@
"""
KAUST CSV data loader

Features:
1. Load train.csv, test.csv (x, y, t, z format)
2. Create site indices from (x, y) coordinates (train+test combined)
3. Reconstruct time series matrix (T, S)
4. Sample observed sites (Uniform/Biased)
5. Sliding window Dataset (L-step context, H-step forecast)
"""
import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset, DataLoader
from typing import Tuple, Dict, Optional, List
from pathlib import Path

def load_kaust_csv_single(
    data_path: str,
    normalize: bool = True
) -> Tuple[np.ndarray, np.ndarray, Dict]:
    """
    Load KAUST CSV file (single file)

    Args:
        data_path: CSV file path
        normalize: Whether to normalize z values

    Returns:
        z_data: (T, S) - Complete time series
        coords: (S, 2) - Site coordinates [x, y], already in [0,1]
        metadata: dict - Normalization statistics, etc.
    """
    # Load CSV
    df = pd.read_csv(data_path)
    print(f"[INFO] Loaded data: {len(df)} rows")

    # 1. Create site indices
    all_coords = df[['x', 'y']].drop_duplicates().reset_index(drop=True)
    S = len(all_coords)
    print(f"[INFO] Total sites: {S}")

    # Site mapping: (x, y) → index
    site_to_idx = {
        (row['x'], row['y']): idx
        for idx, row in all_coords.iterrows()
    }

    # Coordinate array: (S, 2), already in [0,1]^2
    coords = all_coords[['x', 'y']].values.astype(np.float32)

    # 2. Time indices
    t_vals = df['t'].values
    T = int(t_vals.max())
    print(f"[INFO] Time range: 1 ~ {T}")

    # 3. Reconstruct time series matrix: (T, S)
    z_data = np.full((T, S), np.nan, dtype=np.float32)
    for _, row in df.iterrows():
        t_idx = int(row['t']) - 1  # 0-based indexing
        site_idx = site_to_idx[(row['x'], row['y'])]
        z_data[t_idx, site_idx] = row['z']

    # 4. Normalize (z values only)
    metadata = {}
    if normalize:
        z_flat = z_data[~np.isnan(z_data)]
        z_mean = z_flat.mean()
        z_std = z_flat.std() + 1e-8  # guard against zero variance (as in load_kaust_csv)
        z_data = (z_data - z_mean) / z_std
        metadata['z_mean'] = float(z_mean)
        metadata['z_std'] = float(z_std)
        print(f"[INFO] Normalized z: mean={z_mean:.4f}, std={z_std:.4f}")

    return z_data, coords, metadata

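A quick sketch of the single-file loader (the path is hypothetical, following the 2b_* naming used elsewhere in this module):

# --- usage sketch (illustrative) ---
from da_stdk.dataio.kaust_loader import load_kaust_csv_single

z_data, coords, meta = load_kaust_csv_single('data/2b/2b_8.csv', normalize=True)
print(z_data.shape)   # (T, S): one row per time step, NaN where a site has no reading
print(coords.shape)   # (S, 2): site (x, y), already in [0, 1]^2
print(meta['z_mean'], meta['z_std'])  # statistics over the non-NaN entries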
def load_kaust_csv(
    train_path: str,
    test_path: str,
    normalize: bool = True
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, Dict]:
    """
    Load and preprocess KAUST CSV files

    Args:
        train_path: train.csv file path
        test_path: test.csv file path
        normalize: Whether to normalize z values

    Returns:
        z_train: (T_tr, S) - Training time series
        z_test: (T_te, S) - Test time series (initialized with NaN)
        coords: (S, 2) - Site coordinates [x, y]
        site_to_idx: dict - (x, y) → site index mapping
        metadata: dict - Normalization statistics, etc.
    """
    # Load CSV
    df_train = pd.read_csv(train_path)
    df_test = pd.read_csv(test_path)

    print(f"[INFO] Loaded train: {len(df_train)} rows")
    print(f"[INFO] Loaded test: {len(df_test)} rows")

    # 1. Create site indices (train + test combined)
    # Define sites by unique combinations of (x, y) coordinates
    all_coords = pd.concat([
        df_train[['x', 'y']],
        df_test[['x', 'y']]
    ]).drop_duplicates().reset_index(drop=True)

    S = len(all_coords)
    print(f"[INFO] Total sites: {S}")

    # Site mapping: (x, y) → index
    site_to_idx = {
        (row['x'], row['y']): idx
        for idx, row in all_coords.iterrows()
    }

    # Coordinate array: (S, 2)
    coords = all_coords[['x', 'y']].values.astype(np.float32)

    # 2. Time indices (assuming t starts from 1)
    t_train = df_train['t'].values
    t_test = df_test['t'].values

    T_tr = int(t_train.max())        # cast: t may be parsed as float
    T_te_end = int(t_test.max())
    T_te_start = int(t_test.min())

    print(f"[INFO] Train time range: 1 ~ {T_tr}")
    print(f"[INFO] Test time range: {T_te_start} ~ {T_te_end}")

    # 3. Reconstruct time series matrix
    # z_train: (T_tr, S)
    z_train = np.full((T_tr, S), np.nan, dtype=np.float32)
    for _, row in df_train.iterrows():
        t_idx = int(row['t']) - 1  # 0-based indexing
        site_idx = site_to_idx[(row['x'], row['y'])]
        z_train[t_idx, site_idx] = row['z']

    # z_test: (T_te, S) - Initialized with NaN (prediction target)
    T_te = T_te_end - T_te_start + 1
    z_test = np.full((T_te, S), np.nan, dtype=np.float32)
    # test.csv doesn't have z values, so keep NaN

    # 4. Normalize (based on train data)
    metadata = {}
    if normalize:
        z_train_valid = z_train[~np.isnan(z_train)]
        z_mean = z_train_valid.mean()
        z_std = z_train_valid.std() + 1e-8

        z_train = (z_train - z_mean) / z_std

        metadata['z_mean'] = float(z_mean)
        metadata['z_std'] = float(z_std)
        print(f"[INFO] Normalized: mean={z_mean:.4f}, std={z_std:.4f}")
    else:
        metadata['z_mean'] = 0.0
        metadata['z_std'] = 1.0

    # 5. Metadata
    metadata.update({
        'S': S,
        'T_tr': T_tr,
        'T_te': T_te,
        'T_te_start': T_te_start,
        'coords': coords,
        'site_to_idx': site_to_idx
    })

    return z_train, z_test, coords, site_to_idx, metadata

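A sketch of the invariants the combined loader guarantees (paths hypothetical):

# --- usage sketch (illustrative) ---
import numpy as np
from da_stdk.dataio.kaust_loader import load_kaust_csv

z_train, z_test, coords, site_to_idx, meta = load_kaust_csv(
    'data/2b/2b_7_train.csv', 'data/2b/2b_7_test.csv', normalize=True
)
assert z_train.shape == (meta['T_tr'], meta['S'])
assert z_test.shape == (meta['T_te'], meta['S'])
assert np.isnan(z_test).all()        # z is withheld in test.csv, so all NaN
z_orig = z_train * meta['z_std'] + meta['z_mean']   # back to the original scale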
def load_test_ground_truth_from_full(
    full_csv_path: str,
    site_to_idx: dict,
    T_te_start: int,
    T_te: int,
) -> np.ndarray:
    """
    Load test-period z values from the full CSV (e.g. 2b_8.csv) so we can
    evaluate on the provider's test set. Uses site_to_idx from load_kaust_csv
    so site order matches.

    Args:
        full_csv_path: path to full CSV (x, y, t, z) with all time steps
        site_to_idx: (x, y) -> index from load_kaust_csv
        T_te_start: first test time step (1-based, e.g. 91)
        T_te: number of test time steps

    Returns:
        z_test_gt: (T_te, S) float32 array
    """
    df = pd.read_csv(full_csv_path)
    S = len(site_to_idx)
    z_test_gt = np.full((T_te, S), np.nan, dtype=np.float32)
    for _, row in df.iterrows():
        t_val = int(row['t'])
        if t_val < T_te_start or t_val > T_te_start + T_te - 1:
            continue
        t_idx = t_val - T_te_start  # 0-based
        key = (float(row['x']), float(row['y']))
        if key not in site_to_idx:
            continue
        site_idx = site_to_idx[key]
        z_test_gt[t_idx, site_idx] = row['z']
    return z_test_gt

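With the ground truth recovered, test-set evaluation reduces to NaN-aware error metrics. A minimal sketch (masked_rmse is a hypothetical helper, not part of da_stdk.utils.metrics):

# --- evaluation sketch (illustrative) ---
import numpy as np

def masked_rmse(y_pred: np.ndarray, z_test_gt: np.ndarray) -> float:
    """RMSE over the (t, site) cells actually present in the full CSV."""
    mask = ~np.isnan(z_test_gt)
    diff = y_pred[mask] - z_test_gt[mask]
    return float(np.sqrt(np.mean(diff ** 2)))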
def load_kaust_csv_with_test_gt(
    train_path: str,
    test_path: str,
    full_csv_path: str,
    normalize: bool = False,
) -> Tuple[np.ndarray, np.ndarray, Dict]:
    """
    Load provider train/test split and fill test z from full CSV for evaluation.

    Returns:
        z_full: (T_tr + T_te, S) - train then test time steps
        coords: (S, 2)
        metadata: includes z_mean, z_std (from train only), T_tr, T_te, T_te_start
    """
    z_train, z_test_empty, coords, site_to_idx, meta = load_kaust_csv(
        train_path, test_path, normalize=False
    )
    T_tr = meta['T_tr']
    T_te = meta['T_te']
    T_te_start = meta['T_te_start']

    z_test_gt = load_test_ground_truth_from_full(
        full_csv_path, site_to_idx, T_te_start, T_te
    )
    z_full = np.concatenate([z_train, z_test_gt], axis=0).astype(np.float32)

    if normalize:
        z_train_valid = z_train[~np.isnan(z_train)]
        z_mean = float(z_train_valid.mean())
        z_std = float(z_train_valid.std() + 1e-8)
        z_full = (z_full - z_mean) / z_std
        meta['z_mean'] = z_mean
        meta['z_std'] = z_std
    else:
        meta['z_mean'] = 0.0
        meta['z_std'] = 1.0

    return z_full, coords, meta

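A sketch of the evaluation-oriented loader (the full-CSV path is hypothetical):

# --- usage sketch (illustrative) ---
from da_stdk.dataio.kaust_loader import load_kaust_csv_with_test_gt

z_full, coords, meta = load_kaust_csv_with_test_gt(
    'data/2b/2b_7_train.csv', 'data/2b/2b_7_test.csv', 'data/2b/2b_7.csv',
    normalize=True
)
z_test_gt = z_full[meta['T_tr']:]   # the last T_te rows are the evaluation targets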
def sample_observed_sites(
    coords: np.ndarray,
    obs_fraction: float,
    sampling_method: str = 'uniform',
    bias_sigma: float = 0.15,
    bias_temp: float = 1.0,
    seed: Optional[int] = None
) -> np.ndarray:
    """
    Sample observed sites

    Args:
        coords: (S, 2) - Site coordinates
        obs_fraction: Observation ratio (0~1)
        sampling_method: 'uniform' or 'biased'
        bias_sigma: Biased sampling distance scale
        bias_temp: Biased sampling temperature
        seed: Random seed

    Returns:
        obs_indices: (n_obs,) - Observed site index array
    """
    if seed is not None:
        np.random.seed(seed)

    S = len(coords)
    n_obs = max(1, int(S * obs_fraction))

    if sampling_method == 'uniform':
        # Uniform sampling
        obs_indices = np.random.choice(S, size=n_obs, replace=False)
        print(f"[INFO] Sampled {n_obs}/{S} sites (uniform)")

    elif sampling_method == 'biased':
        # Biased sampling (weighted near origin). Uses Gaussian weights;
        # this is NOT the same as the KAUST experiment / paper clustered formula
        # p(s) ∝ (1+10||s||)^{-2} used in train_st_interp and visualize_obs_density.
        distances = np.sqrt(coords[:, 0]**2 + coords[:, 1]**2)
        weights = np.exp(-(distances**2) / (2 * bias_sigma**2))

        # Temperature scaling
        weights = weights ** (1.0 / bias_temp)

        # Normalize
        probs = weights / weights.sum()

        # Sampling
        obs_indices = np.random.choice(S, size=n_obs, replace=False, p=probs)

        avg_dist = distances[obs_indices].mean()
        print(f"[INFO] Sampled {n_obs}/{S} sites (biased, avg_dist={avg_dist:.4f})")

    else:
        raise ValueError(f"Unknown sampling method: {sampling_method}")

    return np.sort(obs_indices)

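A small comparison of the two modes on synthetic sites; since the Gaussian weights concentrate mass near the origin, the biased draw should show a visibly smaller mean distance:

# --- sampling sketch (illustrative, synthetic coordinates) ---
import numpy as np
from da_stdk.dataio.kaust_loader import sample_observed_sites

coords = np.random.rand(400, 2).astype(np.float32)   # stand-in for real site coords
idx_u = sample_observed_sites(coords, 0.1, 'uniform', seed=0)
idx_b = sample_observed_sites(coords, 0.1, 'biased', bias_sigma=0.15, seed=0)
print(np.linalg.norm(coords[idx_u], axis=1).mean())  # ≈ mean distance of all sites
print(np.linalg.norm(coords[idx_b], axis=1).mean())  # noticeably smaller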
class KAUSTWindowDataset(Dataset):
    """
    Sliding window Dataset

    During training:
    - Input: Observed-site data from the [t0-L, t0) interval
    - Target: Observed-site data from the [t0, t0+H) interval
      (targets share the observed sites' coordinates)

    Args:
        z_full: (T, S) - Complete time series (train only)
        coords: (S, 2) - Site coordinates
        obs_indices: (n_obs,) - Observed site indices
        L: context length
        H: forecast horizon
        stride: Sliding window stride (default 1)
        t0_min: Minimum t0 (use L if None)
        t0_max: Maximum t0 (use T-H+1 if None)
        use_coords_cov: Use (x, y) as covariates
        use_time_cov: Use t as covariates
        time_encoding: Time encoding method {linear, sinusoidal}
    """
    def __init__(
        self,
        z_full: np.ndarray,
        coords: np.ndarray,
        obs_indices: np.ndarray,
        L: int,
        H: int,
        stride: int = 1,
        t0_min: Optional[int] = None,
        t0_max: Optional[int] = None,
        use_coords_cov: bool = False,
        use_time_cov: bool = False,
        time_encoding: str = 'linear'
    ):
        self.z_full = z_full            # (T, S)
        self.coords = coords            # (S, 2)
        self.obs_indices = obs_indices  # (n_obs,)
        self.L = L
        self.H = H
        self.stride = stride
        self.use_coords_cov = use_coords_cov
        self.use_time_cov = use_time_cov
        self.time_encoding = time_encoding

        self.T, self.S = z_full.shape
        self.n_obs = len(obs_indices)

        # Calculate covariates dimension
        self.p_covariates = 0
        if use_coords_cov:
            self.p_covariates += 2  # (x, y)
        if use_time_cov:
            if time_encoding == 'sinusoidal':
                self.p_covariates += 2  # (sin(t), cos(t))
            else:  # linear
                self.p_covariates += 1  # t

        # Valid window start points
        # t0-L >= 0 and t0+H <= T
        if t0_min is None:
            t0_min = L
        if t0_max is None:
            t0_max = self.T - H + 1

        self.valid_t0 = list(range(t0_min, t0_max, stride))

        cov_info = f", p_cov={self.p_covariates}" if self.p_covariates > 0 else ""
        print(f"[INFO] Dataset: {len(self.valid_t0)} windows (L={L}, H={H}, stride={stride}, t0=[{t0_min}, {t0_max}){cov_info})")

    def __len__(self):
        return len(self.valid_t0)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        t0 = self.valid_t0[idx]

        # 1. Context: Observed sites from [t0-L, t0)
        y_hist_obs = self.z_full[t0-self.L:t0, self.obs_indices]  # (L, n_obs)

        # 2. Target: Only observed sites from [t0, t0+H)
        y_fut = self.z_full[t0:t0+self.H, self.obs_indices]  # (H, n_obs)

        # 3. Coordinates
        obs_coords = self.coords[self.obs_indices]     # (n_obs, 2)
        target_coords = self.coords[self.obs_indices]  # (n_obs, 2) - Same!

        # 4. Create covariates
        result = {
            'obs_coords': torch.from_numpy(obs_coords).float(),        # (n_obs, 2)
            'target_coords': torch.from_numpy(target_coords).float(),  # (n_obs, 2)
            'y_hist_obs': torch.from_numpy(y_hist_obs).float().unsqueeze(-1),  # (L, n_obs, 1)
            'y_fut': torch.from_numpy(y_fut).float().unsqueeze(-1),            # (H, n_obs, 1)
            't0': t0
        }

        # Covariates for history (observed sites)
        if self.p_covariates > 0:
            X_hist_list = []

            # Coordinate covariates: (n_obs, 2)
            if self.use_coords_cov:
                # Expand (n_obs, 2) to (L, n_obs, 2)
                coords_cov = np.tile(obs_coords[np.newaxis, :, :], (self.L, 1, 1))
                X_hist_list.append(coords_cov)

            # Time covariates
            if self.use_time_cov:
                # Time indices: [t0-L, t0) → normalized time
                t_indices = np.arange(t0 - self.L, t0).astype(np.float32)
                t_normalized = t_indices / self.T  # Normalize to [0, 1] range

                if self.time_encoding == 'sinusoidal':
                    # sin/cos encoding
                    t_sin = np.sin(2 * np.pi * t_normalized)  # (L,)
                    t_cos = np.cos(2 * np.pi * t_normalized)  # (L,)
                    # (L, n_obs, 2)
                    t_cov = np.stack([
                        np.tile(t_sin[:, np.newaxis], (1, self.n_obs)),
                        np.tile(t_cos[:, np.newaxis], (1, self.n_obs))
                    ], axis=-1)
                else:  # linear
                    # (L, n_obs, 1)
                    t_cov = np.tile(t_normalized[:, np.newaxis, np.newaxis], (1, self.n_obs, 1))

                X_hist_list.append(t_cov)

            # Concatenate: (L, n_obs, p)
            X_hist_obs = np.concatenate(X_hist_list, axis=-1)
            result['X_hist_obs'] = torch.from_numpy(X_hist_obs).float()

        # Covariates for future (target sites)
        if self.p_covariates > 0:
            X_fut_list = []

            # Coordinate covariates
            if self.use_coords_cov:
                # (n_obs, 2) - Target has same coordinates
                X_fut_list.append(target_coords)

            # Time covariates for future
            if self.use_time_cov:
                # Use only first time point of future (t0)
                t_future = float(t0) / self.T  # Normalize

                if self.time_encoding == 'sinusoidal':
                    # sin/cos encoding: (n_obs, 2)
                    t_sin = np.sin(2 * np.pi * t_future)
                    t_cos = np.cos(2 * np.pi * t_future)
                    t_fut_cov = np.tile(np.array([[t_sin, t_cos]]), (self.n_obs, 1))
                else:  # linear
                    # (n_obs, 1)
                    t_fut_cov = np.full((self.n_obs, 1), t_future, dtype=np.float32)

                X_fut_list.append(t_fut_cov)

            # Concatenate: (n_obs, p)
            if len(X_fut_list) > 0:
                X_fut_target = np.concatenate(X_fut_list, axis=-1)
                result['X_fut_target'] = torch.from_numpy(X_fut_target).float()

        return result

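The shapes the dataset emits can be checked on synthetic data; with coordinate plus sinusoidal time covariates, p_covariates = 4:

# --- dataset sketch (illustrative, synthetic data) ---
import numpy as np
from da_stdk.dataio.kaust_loader import KAUSTWindowDataset

T, S, L, H = 90, 100, 24, 10
z = np.random.randn(T, S).astype(np.float32)
coords = np.random.rand(S, 2).astype(np.float32)
obs_idx = np.arange(0, S, 10)              # every 10th site is "observed"
ds = KAUSTWindowDataset(z, coords, obs_idx, L=L, H=H,
                        use_coords_cov=True, use_time_cov=True,
                        time_encoding='sinusoidal')
item = ds[0]                       # first window: t0 = L = 24
print(item['y_hist_obs'].shape)    # torch.Size([24, 10, 1]) = (L, n_obs, 1)
print(item['y_fut'].shape)         # torch.Size([10, 10, 1]) = (H, n_obs, 1)
print(item['X_hist_obs'].shape)    # torch.Size([24, 10, 4]) = (L, n_obs, p)
print(item['X_fut_target'].shape)  # torch.Size([10, 4])     = (n_obs, p)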
def create_dataloaders(
    z_train: np.ndarray,
    coords: np.ndarray,
    obs_indices: np.ndarray,
    config: dict,
    val_ratio: float = 0.2
) -> Tuple[DataLoader, DataLoader]:
    """
    Create Train/Val DataLoaders (split by Target)

    Context is taken from the entire z_train,
    but the Target (prediction interval) is split into train/valid.

    Example: T=90, L=24, H=10, val_ratio=0.2 → t0_max=80, t0_split=64
    - Train: t0 = [24, 64), target = [24, 73)
    - Valid: t0 = [64, 80], target = [64, 90)

    Args:
        z_train: (T_tr, S) - Training time series
        coords: (S, 2) - Site coordinates
        obs_indices: (n_obs,) - Observed sites
        config: kaust_data.yaml configuration
        val_ratio: Validation ratio

    Returns:
        train_loader, val_loader
    """
    L = config['L']
    H = config['H']
    batch_size = config['batch_size']
    num_workers = config.get('num_workers', 0)

    # Extract covariates settings (if present)
    use_coords_cov = config.get('use_coords_cov', False)
    use_time_cov = config.get('use_time_cov', False)
    time_encoding = config.get('time_encoding', 'linear')

    T_tr = z_train.shape[0]

    # Split Train/Val by Target
    # Maximum t0: T_tr - H (since target is [t0, t0+H))
    t0_max = T_tr - H                         # T=90, H=10 → t0_max = 80
    t0_split = int(t0_max * (1 - val_ratio))  # 0.8 → 64

    # Create datasets (share entire z_train, only t0 range differs)
    train_dataset = KAUSTWindowDataset(
        z_train, coords, obs_indices, L, H, stride=1,
        t0_min=L, t0_max=t0_split,  # t0 = [L, t0_split)
        use_coords_cov=use_coords_cov,
        use_time_cov=use_time_cov,
        time_encoding=time_encoding
    )

    val_dataset = KAUSTWindowDataset(
        z_train, coords, obs_indices, L, H, stride=1,  # stride=1 for temporal split
        t0_min=t0_split, t0_max=t0_max + 1,  # t0 = [t0_split, t0_max]
        use_coords_cov=use_coords_cov,
        use_time_cov=use_time_cov,
        time_encoding=time_encoding
    )

    # DataLoader
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=num_workers,
        pin_memory=True
    )

    val_loader = DataLoader(
        val_dataset,
        batch_size=batch_size,
        shuffle=False,
        num_workers=num_workers,
        pin_memory=True
    )

    print(f"[INFO] Train: {len(train_dataset)} windows, Val: {len(val_dataset)} windows")

    return train_loader, val_loader

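And one batch from the loaders (continuing the synthetic setup; PyTorch's default collate stacks the per-window dicts and turns the int t0 into a tensor):

# --- dataloader sketch (illustrative, synthetic data) ---
import numpy as np
from da_stdk.dataio.kaust_loader import create_dataloaders

z = np.random.randn(90, 100).astype(np.float32)
coords = np.random.rand(100, 2).astype(np.float32)
obs_idx = np.arange(0, 100, 10)
config = {'L': 24, 'H': 10, 'batch_size': 8, 'use_time_cov': True}
train_loader, val_loader = create_dataloaders(z, coords, obs_idx, config)
batch = next(iter(train_loader))
print(batch['y_hist_obs'].shape)   # torch.Size([8, 24, 10, 1]) = (B, L, n_obs, 1)
print(batch['t0'].shape)           # torch.Size([8])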
def prepare_test_context(
    z_train: np.ndarray,
    coords: np.ndarray,
    obs_indices: np.ndarray,
    L: int
) -> Dict[str, torch.Tensor]:
    """
    Prepare context for test prediction

    Use last L time points as context

    Args:
        z_train: (T_tr, S)
        coords: (S, 2)
        obs_indices: (n_obs,)
        L: context length

    Returns:
        context: dict with obs_coords, target_coords, y_hist_obs
    """
    T_tr, S = z_train.shape

    # Last L time points
    y_hist_obs = z_train[-L:, obs_indices]  # (L, n_obs)

    obs_coords = coords[obs_indices]  # (n_obs, 2)
    target_coords = coords            # (S, 2)

    return {
        'obs_coords': torch.from_numpy(obs_coords).float().unsqueeze(0),        # (1, n_obs, 2)
        'target_coords': torch.from_numpy(target_coords).float().unsqueeze(0),  # (1, S, 2)
        'y_hist_obs': torch.from_numpy(y_hist_obs).float().unsqueeze(0).unsqueeze(-1)  # (1, L, n_obs, 1)
    }

def predictions_to_csv(
    y_pred: np.ndarray,
    test_csv_path: str,
    output_path: str,
    site_to_idx: dict,
    z_mean: float,
    z_std: float,
    denormalize: bool = True
):
    """
    Save prediction results to CSV for submission

    Args:
        y_pred: (H, S) - Predictions
        test_csv_path: Original test.csv path (for row order reference)
        output_path: Output CSV path
        site_to_idx: (x, y) → site index mapping
        z_mean, z_std: Normalization statistics
        denormalize: Whether to denormalize
    """
    # Load test.csv
    df_test = pd.read_csv(test_csv_path)

    # Denormalize
    if denormalize:
        y_pred = y_pred * z_std + z_mean

    # Convert t to a relative index in the test interval:
    # the first test time point maps to 0
    t_min = int(df_test['t'].min())

    # Map predictions
    z_hat_list = []
    for _, row in df_test.iterrows():
        t = int(row['t'])
        site_idx = site_to_idx[(row['x'], row['y'])]

        t_rel = t - t_min

        if t_rel < len(y_pred):
            z_hat = y_pred[t_rel, site_idx]
        else:
            z_hat = np.nan

        z_hat_list.append(z_hat)

    # Save CSV
    df_output = pd.DataFrame({'z': z_hat_list})
    df_output.to_csv(output_path, index=False)
    print(f"[INFO] Saved predictions to {output_path}")

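End-to-end inference ties prepare_test_context and predictions_to_csv together. A sketch reusing the names set up in the __main__ block below; `model` is a placeholder for the trained interpolator in da_stdk.models.st_interp, and its call signature and output shape are assumptions here:

# --- inference sketch (illustrative; `model` and its signature are assumed) ---
import torch

context = prepare_test_context(z_train, coords, obs_indices, L=24)
with torch.no_grad():
    y_pred = model(**context)                         # assumed to return (1, H, S, 1)
y_pred = y_pred.squeeze(0).squeeze(-1).cpu().numpy()  # (H, S)

predictions_to_csv(
    y_pred, 'data/2b/2b_7_test.csv', 'predictions.csv', site_to_idx,
    z_mean=metadata['z_mean'], z_std=metadata['z_std'], denormalize=True
)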
if __name__ == '__main__':
    # Test code
    train_path = 'data/2b/2b_7_train.csv'
    test_path = 'data/2b/2b_7_test.csv'

    # Load
    z_train, z_test, coords, site_to_idx, metadata = load_kaust_csv(
        train_path, test_path, normalize=True
    )

    # Sample observed sites
    obs_indices = sample_observed_sites(
        coords, obs_fraction=0.1, sampling_method='uniform', seed=42
    )

    print(f"Observed sites: {obs_indices[:10]}...")
    print(f"z_train shape: {z_train.shape}")
    print(f"coords shape: {coords.shape}")