da-stdk 0.1.0__py3-none-any.whl

da_stdk/__init__.py ADDED
@@ -0,0 +1,5 @@
+ """
+ DA-STDK: Data-adaptive spatio-temporal distributional prediction.
+ Cluster-adaptive bases and conformal calibration for prediction intervals.
+ """
+ __version__ = "0.1.0"
@@ -0,0 +1,25 @@
+ """
+ Data I/O for STNF-XAttn
+ """
+ from .kaust_loader import (
+     load_kaust_csv,
+     sample_observed_sites,
+     KAUSTWindowDataset,
+     create_dataloaders,
+     prepare_test_context,
+     predictions_to_csv
+ )
+ from .obs_sampling import create_spatial_obs_prob_fn, sample_observations
+ from .splits import split_train_valid
+
+ __all__ = [
+     'load_kaust_csv',
+     'sample_observed_sites',
+     'KAUSTWindowDataset',
+     'create_dataloaders',
+     'prepare_test_context',
+     'predictions_to_csv',
+     'create_spatial_obs_prob_fn',
+     'sample_observations',
+     'split_train_valid',
+ ]
@@ -0,0 +1,660 @@
+ """
+ KAUST CSV data loader
+
+ Features:
+ 1. Load train.csv, test.csv (x, y, t, z format)
+ 2. Create site indices from (x, y) coordinates (train+test combined)
+ 3. Reconstruct time series matrix (T, S)
+ 4. Sample observed sites (Uniform/Biased)
+ 5. Sliding window Dataset (L-step context, H-step forecast)
+ """
+ import numpy as np
+ import pandas as pd
+ import torch
+ from torch.utils.data import Dataset, DataLoader
+ from typing import Tuple, Dict, Optional, List
+ from pathlib import Path
+
+
+ def load_kaust_csv_single(
+     data_path: str,
+     normalize: bool = True
+ ) -> Tuple[np.ndarray, np.ndarray, Dict]:
+     """
+     Load a single KAUST CSV file
+
+     Args:
+         data_path: CSV file path
+         normalize: Whether to normalize z values
+
+     Returns:
+         z_data: (T, S) - Complete time series
+         coords: (S, 2) - Site coordinates [x, y], already in [0,1]
+         metadata: dict - Normalization statistics, etc.
+     """
+     # Load CSV
+     df = pd.read_csv(data_path)
+     print(f"[INFO] Loaded data: {len(df)} rows")
+
+     # 1. Create site indices
+     all_coords = df[['x', 'y']].drop_duplicates().reset_index(drop=True)
+     S = len(all_coords)
+     print(f"[INFO] Total sites: {S}")
+
+     # Site mapping: (x, y) → index
+     site_to_idx = {
+         (row['x'], row['y']): idx
+         for idx, row in all_coords.iterrows()
+     }
+
+     # Coordinate array: (S, 2), already in [0,1]^2
+     coords = all_coords[['x', 'y']].values.astype(np.float32)
+
+     # 2. Time indices (assuming t starts from 1)
+     t_vals = df['t'].values
+     T = int(t_vals.max())
+     print(f"[INFO] Time range: 1 ~ {T}")
+
+     # 3. Reconstruct time series matrix: (T, S)
+     z_data = np.full((T, S), np.nan, dtype=np.float32)
+     for _, row in df.iterrows():
+         t_idx = int(row['t']) - 1  # 0-based indexing
+         site_idx = site_to_idx[(row['x'], row['y'])]
+         z_data[t_idx, site_idx] = row['z']
+
+     # 4. Normalize (z values only; epsilon guards against zero std)
+     metadata = {}
+     if normalize:
+         z_flat = z_data[~np.isnan(z_data)]
+         z_mean = float(z_flat.mean())
+         z_std = float(z_flat.std() + 1e-8)
+         z_data = (z_data - z_mean) / z_std
+         metadata['z_mean'] = z_mean
+         metadata['z_std'] = z_std
+         print(f"[INFO] Normalized z: mean={z_mean:.4f}, std={z_std:.4f}")
+
+     return z_data, coords, metadata
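
A minimal smoke test for load_kaust_csv_single, assuming the function above is in scope; the toy (x, y, t, z) values and the temporary file are hypothetical:

    import os
    import tempfile

    import pandas as pd

    # Two sites observed at two time steps; coordinates already lie in [0, 1].
    toy = pd.DataFrame({
        'x': [0.1, 0.5, 0.1, 0.5],
        'y': [0.2, 0.8, 0.2, 0.8],
        't': [1, 1, 2, 2],
        'z': [1.0, 2.0, 3.0, 4.0],
    })
    with tempfile.TemporaryDirectory() as tmp:
        path = os.path.join(tmp, 'toy.csv')
        toy.to_csv(path, index=False)
        z_data, coords, meta = load_kaust_csv_single(path, normalize=True)

    assert z_data.shape == (2, 2)   # (T, S)
    assert coords.shape == (2, 2)   # (S, 2)
    assert abs(float(z_data.mean())) < 1e-6  # normalized to zero mean
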
+
+
+ def load_kaust_csv(
+     train_path: str,
+     test_path: str,
+     normalize: bool = True
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, Dict, Dict]:
+     """
+     Load and preprocess KAUST CSV files
+
+     Args:
+         train_path: train.csv file path
+         test_path: test.csv file path
+         normalize: Whether to normalize z values
+
+     Returns:
+         z_train: (T_tr, S) - Training time series
+         z_test: (T_te, S) - Test time series (initialized with NaN)
+         coords: (S, 2) - Site coordinates [x, y]
+         site_to_idx: dict - (x, y) → site index mapping
+         metadata: dict - Normalization statistics, etc.
+     """
+     # Load CSV
+     df_train = pd.read_csv(train_path)
+     df_test = pd.read_csv(test_path)
+
+     print(f"[INFO] Loaded train: {len(df_train)} rows")
+     print(f"[INFO] Loaded test: {len(df_test)} rows")
+
+     # 1. Create site indices (train + test combined)
+     # Define sites by unique combinations of (x, y) coordinates
+     all_coords = pd.concat([
+         df_train[['x', 'y']],
+         df_test[['x', 'y']]
+     ]).drop_duplicates().reset_index(drop=True)
+
+     S = len(all_coords)
+     print(f"[INFO] Total sites: {S}")
+
+     # Site mapping: (x, y) → index
+     site_to_idx = {
+         (row['x'], row['y']): idx
+         for idx, row in all_coords.iterrows()
+     }
+
+     # Coordinate array: (S, 2)
+     coords = all_coords[['x', 'y']].values.astype(np.float32)
+
+     # 2. Time indices (assuming t starts from 1)
+     t_train = df_train['t'].values
+     t_test = df_test['t'].values
+
+     T_tr = int(t_train.max())
+     T_te_end = int(t_test.max())
+     T_te_start = int(t_test.min())
+
+     print(f"[INFO] Train time range: 1 ~ {T_tr}")
+     print(f"[INFO] Test time range: {T_te_start} ~ {T_te_end}")
+
+     # 3. Reconstruct time series matrix
+     # z_train: (T_tr, S)
+     z_train = np.full((T_tr, S), np.nan, dtype=np.float32)
+     for _, row in df_train.iterrows():
+         t_idx = int(row['t']) - 1  # 0-based indexing
+         site_idx = site_to_idx[(row['x'], row['y'])]
+         z_train[t_idx, site_idx] = row['z']
+
+     # z_test: (T_te, S) - initialized with NaN (prediction target);
+     # test.csv has no z values, so the NaNs are kept
+     T_te = T_te_end - T_te_start + 1
+     z_test = np.full((T_te, S), np.nan, dtype=np.float32)
+
+     # 4. Normalize (statistics from train data only)
+     metadata = {}
+     if normalize:
+         z_train_valid = z_train[~np.isnan(z_train)]
+         z_mean = z_train_valid.mean()
+         z_std = z_train_valid.std() + 1e-8
+
+         z_train = (z_train - z_mean) / z_std
+
+         metadata['z_mean'] = float(z_mean)
+         metadata['z_std'] = float(z_std)
+         print(f"[INFO] Normalized: mean={z_mean:.4f}, std={z_std:.4f}")
+     else:
+         metadata['z_mean'] = 0.0
+         metadata['z_std'] = 1.0
+
+     # 5. Metadata
+     metadata.update({
+         'S': S,
+         'T_tr': T_tr,
+         'T_te': T_te,
+         'T_te_start': T_te_start,
+         'coords': coords,
+         'site_to_idx': site_to_idx
+     })
+
+     return z_train, z_test, coords, site_to_idx, metadata
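
A sketch of the paired loader on synthetic files (contents hypothetical); note that the test CSV carries only (x, y, t), matching the NaN initialization above:

    import os
    import tempfile

    import numpy as np
    import pandas as pd

    train = pd.DataFrame({
        'x': [0.1, 0.5] * 3,
        'y': [0.2, 0.8] * 3,
        't': [1, 1, 2, 2, 3, 3],
        'z': [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],
    })
    test = pd.DataFrame({'x': [0.1, 0.5], 'y': [0.2, 0.8], 't': [4, 4]})

    with tempfile.TemporaryDirectory() as tmp:
        tr, te = os.path.join(tmp, 'train.csv'), os.path.join(tmp, 'test.csv')
        train.to_csv(tr, index=False)
        test.to_csv(te, index=False)
        z_train, z_test, coords, site_to_idx, meta = load_kaust_csv(tr, te)

    assert z_train.shape == (3, 2)  # (T_tr, S)
    assert z_test.shape == (1, 2) and np.isnan(z_test).all()
    assert meta['T_te_start'] == 4 and meta['S'] == 2
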
+
+
+ def load_test_ground_truth_from_full(
+     full_csv_path: str,
+     site_to_idx: dict,
+     T_te_start: int,
+     T_te: int,
+ ) -> np.ndarray:
+     """
+     Load test-period z values from the full CSV (e.g. 2b_8.csv) so we can
+     evaluate on the provider's test set. Uses site_to_idx from load_kaust_csv
+     so site order matches.
+
+     Args:
+         full_csv_path: path to full CSV (x, y, t, z) with all time steps
+         site_to_idx: (x, y) -> index from load_kaust_csv
+         T_te_start: first test time step (1-based, e.g. 91)
+         T_te: number of test time steps
+
+     Returns:
+         z_test_gt: (T_te, S) float32 array
+     """
+     df = pd.read_csv(full_csv_path)
+     S = len(site_to_idx)
+     z_test_gt = np.full((T_te, S), np.nan, dtype=np.float32)
+     for _, row in df.iterrows():
+         t_val = int(row['t'])
+         if t_val < T_te_start or t_val > T_te_start + T_te - 1:
+             continue
+         t_idx = t_val - T_te_start  # 0-based
+         key = (float(row['x']), float(row['y']))
+         if key not in site_to_idx:
+             continue
+         site_idx = site_to_idx[key]
+         z_test_gt[t_idx, site_idx] = row['z']
+     return z_test_gt
+
+
+ def load_kaust_csv_with_test_gt(
+     train_path: str,
+     test_path: str,
+     full_csv_path: str,
+     normalize: bool = False,
+ ) -> Tuple[np.ndarray, np.ndarray, Dict]:
+     """
+     Load the provider's train/test split and fill test z from the full CSV
+     for evaluation.
+
+     Returns:
+         z_full: (T_tr + T_te, S) - train then test time steps
+         coords: (S, 2)
+         metadata: includes z_mean, z_std (from train only), T_tr, T_te, T_te_start
+     """
+     z_train, z_test_empty, coords, site_to_idx, meta = load_kaust_csv(
+         train_path, test_path, normalize=False
+     )
+     T_tr = meta['T_tr']
+     T_te = meta['T_te']
+     T_te_start = meta['T_te_start']
+
+     z_test_gt = load_test_ground_truth_from_full(
+         full_csv_path, site_to_idx, T_te_start, T_te
+     )
+     z_full = np.concatenate([z_train, z_test_gt], axis=0).astype(np.float32)
+
+     if normalize:
+         z_train_valid = z_train[~np.isnan(z_train)]
+         z_mean = float(z_train_valid.mean())
+         z_std = float(z_train_valid.std() + 1e-8)
+         z_full = (z_full - z_mean) / z_std
+         meta['z_mean'] = z_mean
+         meta['z_std'] = z_std
+     else:
+         meta['z_mean'] = 0.0
+         meta['z_std'] = 1.0
+
+     return z_full, coords, meta
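
A synthetic end-to-end check (values hypothetical): the full CSV supplies z over the test period, so z_full stacks the observed train steps and the filled test steps:

    import os
    import tempfile

    import numpy as np
    import pandas as pd

    full = pd.DataFrame({
        'x': [0.1, 0.5] * 4,
        'y': [0.2, 0.8] * 4,
        't': [1, 1, 2, 2, 3, 3, 4, 4],
        'z': [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
    })
    train = full[full['t'] <= 3]
    test = full[full['t'] > 3][['x', 'y', 't']]

    with tempfile.TemporaryDirectory() as tmp:
        paths = {k: os.path.join(tmp, f'{k}.csv') for k in ('train', 'test', 'full')}
        train.to_csv(paths['train'], index=False)
        test.to_csv(paths['test'], index=False)
        full.to_csv(paths['full'], index=False)
        z_full, coords, meta = load_kaust_csv_with_test_gt(
            paths['train'], paths['test'], paths['full']
        )

    assert z_full.shape == (4, 2)          # (T_tr + T_te, S) = (3 + 1, 2)
    assert not np.isnan(z_full[-1]).any()  # test step filled from the full CSV
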
+
+
+ def sample_observed_sites(
+     coords: np.ndarray,
+     obs_fraction: float,
+     sampling_method: str = 'uniform',
+     bias_sigma: float = 0.15,
+     bias_temp: float = 1.0,
+     seed: Optional[int] = None
+ ) -> np.ndarray:
+     """
+     Sample observed sites
+
+     Args:
+         coords: (S, 2) - Site coordinates
+         obs_fraction: Observation ratio (0~1)
+         sampling_method: 'uniform' or 'biased'
+         bias_sigma: Biased sampling distance scale
+         bias_temp: Biased sampling temperature
+         seed: Random seed
+
+     Returns:
+         obs_indices: (n_obs,) - Sorted array of observed site indices
+     """
+     if seed is not None:
+         np.random.seed(seed)
+
+     S = len(coords)
+     n_obs = max(1, int(S * obs_fraction))
+
+     if sampling_method == 'uniform':
+         # Uniform sampling
+         obs_indices = np.random.choice(S, size=n_obs, replace=False)
+         print(f"[INFO] Sampled {n_obs}/{S} sites (uniform)")
+
+     elif sampling_method == 'biased':
+         # Biased sampling (weighted toward the origin). Uses Gaussian weights;
+         # this is NOT the same as the KAUST experiment / paper clustered formula
+         # p(s) ∝ (1+10||s||)^{-2} used in train_st_interp and visualize_obs_density.
+         distances = np.sqrt(coords[:, 0]**2 + coords[:, 1]**2)
+         weights = np.exp(-(distances**2) / (2 * bias_sigma**2))
+
+         # Temperature scaling
+         weights = weights ** (1.0 / bias_temp)
+
+         # Normalize to a probability distribution
+         probs = weights / weights.sum()
+
+         # Sample without replacement
+         obs_indices = np.random.choice(S, size=n_obs, replace=False, p=probs)
+
+         avg_dist = distances[obs_indices].mean()
+         print(f"[INFO] Sampled {n_obs}/{S} sites (biased, avg_dist={avg_dist:.4f})")
+
+     else:
+         raise ValueError(f"Unknown sampling method: {sampling_method}")
+
+     return np.sort(obs_indices)
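
The comment above distinguishes this Gaussian-weighted scheme from the clustered formula p(s) ∝ (1 + 10||s||)^{-2} that it attributes to train_st_interp and visualize_obs_density. A sketch of that variant for comparison (the function name here is hypothetical):

    import numpy as np

    def sample_observed_sites_clustered(coords, obs_fraction, seed=None):
        """Clustered site sampling with p(s) ∝ (1 + 10 * ||s||)^(-2)."""
        rng = np.random.default_rng(seed)
        S = len(coords)
        n_obs = max(1, int(S * obs_fraction))
        dist = np.sqrt((coords ** 2).sum(axis=1))  # distance from the origin
        weights = (1.0 + 10.0 * dist) ** -2.0      # heavier weight near the origin
        probs = weights / weights.sum()
        return np.sort(rng.choice(S, size=n_obs, replace=False, p=probs))

    # On a uniform grid, samples concentrate near (0, 0):
    grid = np.stack(np.meshgrid(np.linspace(0, 1, 20),
                                np.linspace(0, 1, 20)), -1).reshape(-1, 2)
    idx = sample_observed_sites_clustered(grid, obs_fraction=0.1, seed=0)
    print(grid[idx].mean(axis=0))  # typically well below (0.5, 0.5)
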
+
+
+ class KAUSTWindowDataset(Dataset):
+     """
+     Sliding window Dataset
+
+     During training:
+     - Input: observed-site data from the [t0-L, t0) interval
+     - Target: observed-site data from the [t0, t0+H) interval
+
+     Args:
+         z_full: (T, S) - Complete time series (train only)
+         coords: (S, 2) - Site coordinates
+         obs_indices: (n_obs,) - Observed site indices
+         L: context length
+         H: forecast horizon
+         stride: Sliding window stride (default 1)
+         t0_min: Minimum t0, inclusive (L if None)
+         t0_max: Maximum t0, exclusive (T-H+1 if None)
+         use_coords_cov: Use (x, y) as covariates
+         use_time_cov: Use t as a covariate
+         time_encoding: Time encoding method {linear, sinusoidal}
+     """
+     def __init__(
+         self,
+         z_full: np.ndarray,
+         coords: np.ndarray,
+         obs_indices: np.ndarray,
+         L: int,
+         H: int,
+         stride: int = 1,
+         t0_min: Optional[int] = None,
+         t0_max: Optional[int] = None,
+         use_coords_cov: bool = False,
+         use_time_cov: bool = False,
+         time_encoding: str = 'linear'
+     ):
+         self.z_full = z_full            # (T, S)
+         self.coords = coords            # (S, 2)
+         self.obs_indices = obs_indices  # (n_obs,)
+         self.L = L
+         self.H = H
+         self.stride = stride
+         self.use_coords_cov = use_coords_cov
+         self.use_time_cov = use_time_cov
+         self.time_encoding = time_encoding
+
+         self.T, self.S = z_full.shape
+         self.n_obs = len(obs_indices)
+
+         # Covariate dimension
+         self.p_covariates = 0
+         if use_coords_cov:
+             self.p_covariates += 2  # (x, y)
+         if use_time_cov:
+             if time_encoding == 'sinusoidal':
+                 self.p_covariates += 2  # (sin(t), cos(t))
+             else:  # linear
+                 self.p_covariates += 1  # t
+
+         # Valid window start points: t0-L >= 0 and t0+H <= T
+         if t0_min is None:
+             t0_min = L
+         if t0_max is None:
+             t0_max = self.T - H + 1
+
+         self.valid_t0 = list(range(t0_min, t0_max, stride))
+
+         cov_info = f", p_cov={self.p_covariates}" if self.p_covariates > 0 else ""
+         print(f"[INFO] Dataset: {len(self.valid_t0)} windows "
+               f"(L={L}, H={H}, stride={stride}, t0=[{t0_min}, {t0_max}){cov_info})")
+
+     def __len__(self):
+         return len(self.valid_t0)
+
+     def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
+         t0 = self.valid_t0[idx]
+
+         # 1. Context: observed sites over [t0-L, t0)
+         y_hist_obs = self.z_full[t0-self.L:t0, self.obs_indices]  # (L, n_obs)
+
+         # 2. Target: observed sites only over [t0, t0+H)
+         y_fut = self.z_full[t0:t0+self.H, self.obs_indices]  # (H, n_obs)
+
+         # 3. Coordinates (targets coincide with observed sites during training)
+         obs_coords = self.coords[self.obs_indices]     # (n_obs, 2)
+         target_coords = self.coords[self.obs_indices]  # (n_obs, 2)
+
+         # 4. Assemble tensors
+         result = {
+             'obs_coords': torch.from_numpy(obs_coords).float(),                # (n_obs, 2)
+             'target_coords': torch.from_numpy(target_coords).float(),          # (n_obs, 2)
+             'y_hist_obs': torch.from_numpy(y_hist_obs).float().unsqueeze(-1),  # (L, n_obs, 1)
+             'y_fut': torch.from_numpy(y_fut).float().unsqueeze(-1),            # (H, n_obs, 1)
+             't0': t0
+         }
+
+         # Covariates for history (observed sites)
+         if self.p_covariates > 0:
+             X_hist_list = []
+
+             # Coordinate covariates
+             if self.use_coords_cov:
+                 # Expand (n_obs, 2) to (L, n_obs, 2)
+                 coords_cov = np.tile(obs_coords[np.newaxis, :, :], (self.L, 1, 1))
+                 X_hist_list.append(coords_cov)
+
+             # Time covariates
+             if self.use_time_cov:
+                 # Time indices over [t0-L, t0), normalized to [0, 1)
+                 t_indices = np.arange(t0 - self.L, t0).astype(np.float32)
+                 t_normalized = t_indices / self.T
+
+                 if self.time_encoding == 'sinusoidal':
+                     # sin/cos encoding
+                     t_sin = np.sin(2 * np.pi * t_normalized)  # (L,)
+                     t_cos = np.cos(2 * np.pi * t_normalized)  # (L,)
+                     # (L, n_obs, 2)
+                     t_cov = np.stack([
+                         np.tile(t_sin[:, np.newaxis], (1, self.n_obs)),
+                         np.tile(t_cos[:, np.newaxis], (1, self.n_obs))
+                     ], axis=-1)
+                 else:  # linear
+                     # (L, n_obs, 1)
+                     t_cov = np.tile(t_normalized[:, np.newaxis, np.newaxis], (1, self.n_obs, 1))
+
+                 X_hist_list.append(t_cov)
+
+             # Concatenate: (L, n_obs, p)
+             X_hist_obs = np.concatenate(X_hist_list, axis=-1)
+             result['X_hist_obs'] = torch.from_numpy(X_hist_obs).float()
+
+         # Covariates for future (target sites)
+         if self.p_covariates > 0:
+             X_fut_list = []
+
+             # Coordinate covariates (targets share the observed coordinates)
+             if self.use_coords_cov:
+                 X_fut_list.append(target_coords)  # (n_obs, 2)
+
+             # Time covariates for the future window
+             if self.use_time_cov:
+                 # Use only the first future time point (t0), normalized
+                 t_future = float(t0) / self.T
+
+                 if self.time_encoding == 'sinusoidal':
+                     # sin/cos encoding: (n_obs, 2)
+                     t_sin = np.sin(2 * np.pi * t_future)
+                     t_cos = np.cos(2 * np.pi * t_future)
+                     t_fut_cov = np.tile(np.array([[t_sin, t_cos]]), (self.n_obs, 1))
+                 else:  # linear
+                     # (n_obs, 1)
+                     t_fut_cov = np.full((self.n_obs, 1), t_future, dtype=np.float32)
+
+                 X_fut_list.append(t_fut_cov)
+
+             # Concatenate: (n_obs, p)
+             X_fut_target = np.concatenate(X_fut_list, axis=-1)
+             result['X_fut_target'] = torch.from_numpy(X_fut_target).float()
+
+         return result
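
A quick shape check for the dataset on random arrays (sizes hypothetical), assuming the class above is in scope:

    import numpy as np

    T, S, n_obs = 50, 30, 10
    z = np.random.randn(T, S).astype(np.float32)
    coords = np.random.rand(S, 2).astype(np.float32)
    obs = np.sort(np.random.choice(S, size=n_obs, replace=False))

    ds = KAUSTWindowDataset(z, coords, obs, L=8, H=4, use_time_cov=True)
    item = ds[0]  # first valid window, t0 = 8
    assert item['y_hist_obs'].shape == (8, n_obs, 1)
    assert item['y_fut'].shape == (4, n_obs, 1)
    assert item['X_hist_obs'].shape == (8, n_obs, 1)  # linear time encoding
    assert item['X_fut_target'].shape == (n_obs, 1)
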
+
+
+ def create_dataloaders(
+     z_train: np.ndarray,
+     coords: np.ndarray,
+     obs_indices: np.ndarray,
+     config: dict,
+     val_ratio: float = 0.2
+ ) -> Tuple[DataLoader, DataLoader]:
+     """
+     Create train/val DataLoaders, split by target interval
+
+     Context may be drawn from anywhere in z_train, but the target
+     (prediction) intervals are split between train and valid.
+
+     Example: T=90, L=24, H=10, val_ratio=0.2 → t0_max=80, t0_split=64
+     - Train: t0 ∈ [24, 64), targets span [24, 73)
+     - Valid: t0 ∈ [64, 80], targets span [64, 90)
+
+     Args:
+         z_train: (T_tr, S) - Training time series
+         coords: (S, 2) - Site coordinates
+         obs_indices: (n_obs,) - Observed sites
+         config: kaust_data.yaml configuration
+         val_ratio: Validation ratio
+
+     Returns:
+         train_loader, val_loader
+     """
+     L = config['L']
+     H = config['H']
+     batch_size = config['batch_size']
+     num_workers = config.get('num_workers', 0)
+
+     # Covariate settings (if present)
+     use_coords_cov = config.get('use_coords_cov', False)
+     use_time_cov = config.get('use_time_cov', False)
+     time_encoding = config.get('time_encoding', 'linear')
+
+     T_tr = z_train.shape[0]
+
+     # Split train/val by target interval
+     # Maximum valid t0 is T_tr - H (since the target is [t0, t0+H))
+     t0_max = T_tr - H                         # T=90, H=10 → t0_max = 80
+     t0_split = int(t0_max * (1 - val_ratio))  # 80 * 0.8 → 64
+
+     # Both datasets share the full z_train; only the t0 range differs
+     train_dataset = KAUSTWindowDataset(
+         z_train, coords, obs_indices, L, H, stride=1,
+         t0_min=L, t0_max=t0_split,  # t0 = [L, t0_split)
+         use_coords_cov=use_coords_cov,
+         use_time_cov=use_time_cov,
+         time_encoding=time_encoding
+     )
+
+     val_dataset = KAUSTWindowDataset(
+         z_train, coords, obs_indices, L, H, stride=1,  # stride=1 for temporal split
+         t0_min=t0_split, t0_max=t0_max + 1,  # t0 = [t0_split, t0_max]
+         use_coords_cov=use_coords_cov,
+         use_time_cov=use_time_cov,
+         time_encoding=time_encoding
+     )
+
+     # DataLoaders
+     train_loader = DataLoader(
+         train_dataset,
+         batch_size=batch_size,
+         shuffle=True,
+         num_workers=num_workers,
+         pin_memory=True
+     )
+
+     val_loader = DataLoader(
+         val_dataset,
+         batch_size=batch_size,
+         shuffle=False,
+         num_workers=num_workers,
+         pin_memory=True
+     )
+
+     print(f"[INFO] Train: {len(train_dataset)} windows, Val: {len(val_dataset)} windows")
+
+     return train_loader, val_loader
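
A minimal config and call on synthetic arrays (keys mirror those read above; the sizes and batch_size are hypothetical):

    import numpy as np

    z = np.random.randn(50, 30).astype(np.float32)
    coords = np.random.rand(30, 2).astype(np.float32)
    obs = np.sort(np.random.choice(30, size=10, replace=False))

    cfg = {'L': 8, 'H': 4, 'batch_size': 16, 'use_time_cov': True}
    train_loader, val_loader = create_dataloaders(z, coords, obs, cfg, val_ratio=0.2)
    batch = next(iter(train_loader))
    print(batch['y_hist_obs'].shape)  # torch.Size([16, 8, 10, 1]) = (B, L, n_obs, 1)
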
+
+
+ def prepare_test_context(
+     z_train: np.ndarray,
+     coords: np.ndarray,
+     obs_indices: np.ndarray,
+     L: int
+ ) -> Dict[str, torch.Tensor]:
+     """
+     Prepare the context for test prediction, using the last L time
+     points of training data as context.
+
+     Args:
+         z_train: (T_tr, S)
+         coords: (S, 2)
+         obs_indices: (n_obs,)
+         L: context length
+
+     Returns:
+         context: dict with obs_coords, target_coords, y_hist_obs
+     """
+     # Last L time points at the observed sites
+     y_hist_obs = z_train[-L:, obs_indices]  # (L, n_obs)
+
+     obs_coords = coords[obs_indices]  # (n_obs, 2)
+     target_coords = coords            # (S, 2) - predict at every site
+
+     return {
+         'obs_coords': torch.from_numpy(obs_coords).float().unsqueeze(0),               # (1, n_obs, 2)
+         'target_coords': torch.from_numpy(target_coords).float().unsqueeze(0),         # (1, S, 2)
+         'y_hist_obs': torch.from_numpy(y_hist_obs).float().unsqueeze(0).unsqueeze(-1)  # (1, L, n_obs, 1)
+     }
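
Usage sketch on the same kind of synthetic arrays (hypothetical); unlike the training windows, target_coords here spans all S sites:

    import numpy as np

    z = np.random.randn(50, 30).astype(np.float32)
    coords = np.random.rand(30, 2).astype(np.float32)
    obs = np.sort(np.random.choice(30, size=10, replace=False))

    ctx = prepare_test_context(z, coords, obs, L=8)
    print(ctx['y_hist_obs'].shape)     # torch.Size([1, 8, 10, 1])
    print(ctx['target_coords'].shape)  # torch.Size([1, 30, 2]) - every site
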
+
+
+ def predictions_to_csv(
+     y_pred: np.ndarray,
+     test_csv_path: str,
+     output_path: str,
+     site_to_idx: dict,
+     z_mean: float,
+     z_std: float,
+     denormalize: bool = True
+ ):
+     """
+     Save prediction results to CSV for submission
+
+     Args:
+         y_pred: (H, S) - Predictions
+         test_csv_path: Original test.csv path (for row order reference)
+         output_path: Output CSV path
+         site_to_idx: (x, y) → site index mapping
+         z_mean, z_std: Normalization statistics
+         denormalize: Whether to denormalize
+     """
+     # Load test.csv
+     df_test = pd.read_csv(test_csv_path)
+
+     # Denormalize
+     if denormalize:
+         y_pred = y_pred * z_std + z_mean
+
+     # Map predictions to test rows; t is converted to a relative index,
+     # taking the first test time point as 0
+     t_min = int(df_test['t'].min())
+     z_hat_list = []
+     for _, row in df_test.iterrows():
+         t_rel = int(row['t']) - t_min
+         site_idx = site_to_idx[(row['x'], row['y'])]
+
+         if t_rel < len(y_pred):
+             z_hat = y_pred[t_rel, site_idx]
+         else:
+             z_hat = np.nan
+
+         z_hat_list.append(z_hat)
+
+     # Save CSV
+     df_output = pd.DataFrame({'z': z_hat_list})
+     df_output.to_csv(output_path, index=False)
+     print(f"[INFO] Saved predictions to {output_path}")
+
+
+ if __name__ == '__main__':
+     # Test code
+     train_path = 'data/2b/2b_7_train.csv'
+     test_path = 'data/2b/2b_7_test.csv'
+
+     # Load
+     z_train, z_test, coords, site_to_idx, metadata = load_kaust_csv(
+         train_path, test_path, normalize=True
+     )
+
+     # Sample observed sites
+     obs_indices = sample_observed_sites(
+         coords, obs_fraction=0.1, sampling_method='uniform', seed=42
+     )
+
+     print(f"Observed sites: {obs_indices[:10]}...")
+     print(f"z_train shape: {z_train.shape}")
+     print(f"coords shape: {coords.shape}")