satcube 0.1.17__py3-none-any.whl → 0.1.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of satcube might be problematic. Click here for more details.

@@ -0,0 +1,1087 @@
1
+ import importlib.util
2
+ import pathlib
3
+ import shutil
4
+ from typing import Callable, List, Tuple, Union, Any
5
+
6
+ import ee
7
+ import fastcubo
8
+ import matplotlib.pyplot as plt
9
+ import numpy as np
10
+ import pandas as pd
11
+ import phicloudmask
12
+ import rasterio as rio
13
+ import requests
14
+ import satalign
15
+ import segmentation_models_pytorch as smp
16
+ import torch
17
+ import torch.nn.functional as F
18
+ from sklearn.linear_model import LinearRegression
19
+ import xarray as xr
20
+
21
+ from satcube.archive_dataclass import Sentinel2
22
+
23
+
24
def metadata_s2(
    lon: float,
    lat: float,
    range_date: Tuple[str, str],
    edge_size: int,
    quiet: bool = False,
) -> pd.DataFrame:
    """Query the Sentinel-2 image collection around a point.

    The image list returned by ``fastcubo`` is joined with the Cloud Score+
    values (``cs`` and ``cs_cdf``) sampled at the point. Only images present
    in both sources are kept.

    Args:
        lon (float): The longitude of the point.
        lat (float): The latitude of the point.
        range_date (Tuple[str, str]): The start and end dates to query.
        edge_size (int): The edge size of the image.
        quiet (bool, optional): If True, the function does not print.
            Defaults to False.

    Returns:
        pd.DataFrame: The table with the images to download, including the
        Cloud Score+ columns and an ``mgrs_title`` column.
    """
    if not quiet:
        print(f"Querying Sentinel-2 image collection for {lon}, {lat}")

    collection = "COPERNICUS/S2_HARMONIZED"
    point = ee.Geometry.Point(lon, lat)

    # Query the image collection. All the bands are needed to run the
    # cloud mask algorithms.
    table = fastcubo.query_getPixels_imagecollection(
        point=(lon, lat),
        collection=collection,
        bands=[
            "B1", "B2", "B3", "B4", "B5", "B6", "B7",
            "B8", "B8A", "B9", "B10", "B11", "B12",
        ],
        data_range=(range_date[0], range_date[1]),
        edge_size=edge_size,
        resolution=10,
    )

    # Cloud Score+ values sampled at the point
    ic_cc = (
        ee.ImageCollection("GOOGLE/CLOUD_SCORE_PLUS/V1/S2_HARMONIZED")
        .filterDate(range_date[0], range_date[1])
        .filterBounds(point)
    )
    # NOTE(review): Earth Engine's `scale` is in meters, while edge_size
    # appears to be a pixel count - confirm this is intended.
    region = ic_cc.getRegion(geometry=point, scale=edge_size).getInfo()
    ic_cc_pd = pd.DataFrame(region[1:], columns=region[0])
    ic_cc_pd["img_id"] = ic_cc_pd["id"].apply(lambda x: collection + "/" + x)
    ic_cc_pd = ic_cc_pd.loc[:, ["img_id", "cs", "cs_cdf"]]

    # Right join (keeps the Cloud Score+ order and the column layout), but
    # drop rows without a matching Sentinel-2 image: they would carry NaN in
    # every column coming from the fastcubo table.
    merged = table.merge(ic_cc_pd, on="img_id", how="right", indicator=True)
    table = (
        merged.loc[merged["_merge"] == "both"]
        .drop(columns="_merge")
        .reset_index(drop=True)
    )

    # The MGRS tile is the third "_"-separated token of the image id stem
    table["mgrs_title"] = table["img_id"].apply(
        lambda x: pathlib.Path(x).stem.split("_")[2]
    )

    return table
95
+
96
+
97
def intermediate_process(
    table: pd.DataFrame,
    out_folder: str,
    process_function: Callable,
    process_function_args: dict,
    force: bool = False,
) -> pd.DataFrame:
    """Run one processing step on a table of images, caching its result table.

    ``process_function`` is called as
    ``process_function(table=..., out_folder=..., **process_function_args)``
    and must return a DataFrame. That table is saved next to the output
    folder as ``<out_folder>.csv``. On later calls the CSV is read back
    instead of re-running the step, unless ``force`` is True or the folder
    or the CSV is missing.

    Args:
        table (pd.DataFrame): The table with the images to process.
        out_folder (str): The output folder (str or pathlib.Path).
        process_function (Callable): The function to apply to the images.
        process_function_args (dict): Extra keyword arguments for the
            function.
        force (bool, optional): If True, the process is done again.
            Defaults to False.

    Returns:
        pd.DataFrame: The table returned by ``process_function``, or the
        cached copy read from the CSV.
    """
    out_folder = pathlib.Path(out_folder)

    # Work on a copy with a clean 0..n-1 index so label-based and positional
    # row lookups inside the processing functions agree.
    table = table.copy()
    table.reset_index(drop=True, inplace=True)

    # The cached result table lives next to the output folder
    out_table_name = out_folder.parent / (out_folder.name + ".csv")

    # Re-run when forced, or when either the folder or the CSV is missing
    # (e.g. a previous run was interrupted before the CSV was written).
    if force or not out_folder.exists() or not out_table_name.exists():
        out_folder.mkdir(parents=True, exist_ok=True)
        out_table = process_function(
            table=table, out_folder=out_folder, **process_function_args
        )
        out_table.to_csv(out_table_name, index=False)
    else:
        out_table = pd.read_csv(out_table_name)

    return out_table
142
+
143
+
144
def cloudmasking_s2(
    table: pd.DataFrame,
    out_folder: pathlib.Path,
    sensor: Sentinel2,
    device: Union[str, torch.device],
    quiet: bool = False,
) -> pd.DataFrame:
    """Write cloud-masked copies of the images of a Sentinel-2 dataset.

    Two models run on every image:
    - a ``phicloudmask.CloudMask`` model (universal embedding and cloud
      weights), whose classes are remapped;
    - a MobileNetV2 U-Net loaded with the sensor-specific weights.

    A pixel is kept only when both models give class 0. All other pixels are
    written as 65535, and so are pixels where more than 3 bands are zero.
    Band indices 0, 9 and 10 (B01, B09, B10) are dropped from the output.

    Args:
        table (pd.DataFrame): The table with the images to process. It must
            have the ``folder`` and ``outname`` columns.
        out_folder (pathlib.Path): The output folder.
        sensor (Sentinel2): Sensor description. Its ``weight_path``,
            ``embedding_universal``, ``cloud_model_universal`` and
            ``cloud_model_specific`` attributes locate the model weights.
        device (Union[str, torch.device]): The device to run the models on.
        quiet (bool, optional): If True, the function does not print.
            Defaults to False.

    Returns:
        pd.DataFrame: The input table with a ``cloud_cover`` column. It holds
        the fraction (0-1) of pixels flagged by at least one model.
    """
    out_folder = pathlib.Path(out_folder)

    # Spectral descriptor of the 13 input bands, in file order
    s2_descriptor = [
        {"name": "B01", "band_type": "TOA Reflectance", "min_wavelength": 425.0, "max_wavelength": 461.0},
        {"name": "B02", "band_type": "TOA Reflectance", "min_wavelength": 446.0, "max_wavelength": 542.0},
        {"name": "B03", "band_type": "TOA Reflectance", "min_wavelength": 537.5, "max_wavelength": 582.5},
        {"name": "B04", "band_type": "TOA Reflectance", "min_wavelength": 645.5, "max_wavelength": 684.5},
        {"name": "B05", "band_type": "TOA Reflectance", "min_wavelength": 694.0, "max_wavelength": 714.0},
        {"name": "B06", "band_type": "TOA Reflectance", "min_wavelength": 731.0, "max_wavelength": 749.0},
        {"name": "B07", "band_type": "TOA Reflectance", "min_wavelength": 767.0, "max_wavelength": 795.0},
        {"name": "B08", "band_type": "TOA Reflectance", "min_wavelength": 763.5, "max_wavelength": 904.5},
        {"name": "B8A", "band_type": "TOA Reflectance", "min_wavelength": 847.5, "max_wavelength": 880.5},
        {"name": "B09", "band_type": "TOA Reflectance", "min_wavelength": 930.5, "max_wavelength": 957.5},
        {"name": "B10", "band_type": "TOA Reflectance", "min_wavelength": 1337.0, "max_wavelength": 1413.0},
        {"name": "B11", "band_type": "TOA Reflectance", "min_wavelength": 1541.0, "max_wavelength": 1683.0},
        {"name": "B12", "band_type": "TOA Reflectance", "min_wavelength": 2074.0, "max_wavelength": 2314.0},
    ]

    # Remapping of the phicloudmask classes. Unknown values map to themselves.
    # NOTE(review): the previous literal listed key 3 twice (3: 0 and then
    # 3: 2). Python keeps the last value, so 3 -> 2 is what has always been
    # applied and is kept here. Confirm this is the intended mapping.
    cloudsen12_style = {0: 0, 1: 0, 2: 0, 3: 2, 4: 1, 5: 3, 6: 0}
    remap = np.vectorize(lambda x: cloudsen12_style.get(x, x))

    # Load the weights on the CPU so GPU-saved checkpoints also work on
    # CPU-only machines; load_state_dict copies them onto each model's device.
    embedding_weights = torch.load(
        sensor.weight_path / sensor.embedding_universal, map_location="cpu"
    )
    cloudmask01_weights = torch.load(
        sensor.weight_path / sensor.cloud_model_universal, map_location="cpu"
    )
    cloudmask02_weights = torch.load(
        sensor.weight_path / sensor.cloud_model_specific, map_location="cpu"
    )

    # Model 01: spectral embedding + universal cloud model
    segmodel01 = phicloudmask.CloudMask(descriptor=s2_descriptor, device=device)
    segmodel01.embedding_model.load_state_dict(embedding_weights)
    segmodel01.cloud_model.load_state_dict(cloudmask01_weights)
    segmodel01.eval()
    segmodel01.to(device)

    # Model 02: auxiliary sensor-specific U-Net
    segmodel02 = smp.Unet(
        encoder_name="mobilenet_v2", encoder_weights=None, in_channels=13, classes=4
    )
    segmodel02.load_state_dict(cloudmask02_weights)
    segmodel02.eval()
    segmodel02.to(device)

    all_raw_files = [
        path / name for path, name in zip(table["folder"], table["outname"])
    ]
    new_cloud_covers = []
    for idx, file in enumerate(all_raw_files):
        if not quiet:
            print(f"Processing {file.name} [{idx + 1}/{len(all_raw_files)}]")

        with rio.open(file) as src:
            s2_raw = src.read()
            metadata = src.profile
        metadata["nodata"] = 65535
        s2_raw_torch = torch.from_numpy(s2_raw[None] / 10000).float().to(device)

        with torch.no_grad():
            cloud_probs_all = segmodel01(s2_raw_torch)
            cloud_4class_all_01 = remap(cloud_probs_all.argmax(dim=0).cpu().numpy())

            cloud_probs_all_02 = segmodel02(s2_raw_torch).squeeze()
            cloud_4class_all_02 = cloud_probs_all_02.argmax(dim=0).cpu().numpy()

        # Both label maps are >= 0, so the sum is 0 only where both are 0
        cloud_4class_all = cloud_4class_all_01 + cloud_4class_all_02

        # Apply the cloud mask. The +1/-1 shift keeps genuine zeros apart from
        # masked pixels. NOTE(review): if s2_raw is uint16, 65535 + 1 wraps to 0,
        # which is how existing nodata also ends up as 65535 here.
        s2_clean = (s2_raw + 1) * (cloud_4class_all == 0)
        s2_clean[s2_clean == 0] = 65535
        s2_clean = s2_clean - 1
        s2_clean[s2_clean == 65534] = 65535

        # If more than 3 bands have zero values, then remove from all the bands
        outmask = (s2_clean == 0).sum(0) > 3
        s2_clean[:, outmask] = 65535

        # Remove the 60 m bands (B01, B09, B10)
        s2_clean = s2_clean[[1, 2, 3, 4, 5, 6, 7, 8, 11, 12]]

        # Fraction (0-1) of pixels flagged by at least one model
        cc_perc = (cloud_4class_all > 0).sum() / (
            cloud_4class_all.shape[0] * cloud_4class_all.shape[1]
        )
        new_cloud_covers.append(cc_perc)

        metadata["count"] = s2_clean.shape[0]
        with rio.open(out_folder / file.name, "w", **metadata) as dst:
            dst.write(s2_clean.astype(rio.uint16))

    table["cloud_cover"] = new_cloud_covers

    return table
339
+
340
+
341
def _gapfill_read(path):
    """Read a raster scaled by 1/10000, with the 65535 nodata value set to NaN.

    Returns the scaled array and the raster profile.
    """
    with rio.open(path) as src:
        data = src.read() / 10000
        profile = src.profile
    data[data == 6.5535] = np.nan
    return data, profile


def _gapfill_candidate(image1, cloudmask1, image2, cloudmask2, match_function):
    """Fill the gaps of ``image1`` with ``image2``, matched to it band by band.

    Args:
        image1 (np.ndarray): Image with gaps (NaN), bands first.
        cloudmask1 (np.ndarray): Per-pixel fraction of missing bands of image1.
        image2 (np.ndarray): Reference image, same shape as image1.
        cloudmask2 (np.ndarray): 1.0 where image2 is NaN, 0.0 elsewhere.
        match_function (Callable): Per-band matching function called with
            ``image1``, ``image2`` and ``image3`` keyword arguments.

    Returns:
        Tuple[np.ndarray, float]: The filled image (NaN and negatives set to
        0), and the mean normalized difference between the matched reference
        and image1 on bands 0-2.
    """
    # 1 where either image has a gap, 0 where both are valid
    full_mask = ((cloudmask1 + cloudmask2) > 0) * 1.0
    # NaN where either image has a gap, 1 elsewhere: keeps only shared pixels
    shared = np.where(full_mask == 1, np.nan, 1.0)

    image1_masked = image1 * shared
    image2_masked = image2 * shared

    # Match the full reference band by band, fitted on the shared pixels
    matched = np.zeros_like(image2)
    for band in range(image2.shape[0]):
        matched[band] = match_function(
            image1=image1_masked[band],
            image2=image2_masked[band],
            image3=image2[band],
        )

    # Match quality: normalized difference over bands 0-2
    a = matched[[2, 1, 0]].mean(0)
    b = image1[[2, 1, 0]].mean(0)
    metric = np.nanmean(np.abs(a - b) / (a + b))

    # Keep every valid pixel of image1 and use the matched reference only
    # where image1 is missing (valid bands of partially gapped pixels are
    # no longer summed with the reference).
    filled = np.where(np.isnan(image1), matched, image1)
    filled[np.isnan(filled)] = 0
    filled[filled < 0] = 0
    return filled, metric


def gapfilling_s2(
    table: pd.DataFrame, out_folder: pathlib.Path, method: str, quiet: bool
) -> pd.DataFrame:
    """Fill the nodata gaps of a Sentinel-2 dataset using nearby acquisitions.

    For each image with gaps (65535 values), the other images are tried in
    order of temporal distance. A reference is rejected when it has a gap at
    a pixel where the target is missing in every band. Up to 5 usable
    references are radiometrically matched to the target with ``method``.
    The one with the lowest normalized difference on bands 0-2 is used to
    fill the gaps. Images without gaps are copied unchanged.

    Args:
        table (pd.DataFrame): The table with the images to process. It must
            have the ``folder``, ``outname`` and ``img_date`` columns.
        out_folder (pathlib.Path): The output folder.
        method (str): The method to fill the gaps, either
            ``"histogram_matching"`` or ``"linear"``.
        quiet (bool): If True, the function does not print.

    Raises:
        ValueError: If ``method`` is not supported.

    Returns:
        pd.DataFrame: The input table with a ``match_error`` column. It is 0
        for images without gaps and NaN when no usable reference was found
        (the image is then copied with its gaps).
    """
    match_functions = {
        "histogram_matching": tripple_histogram_matching,
        "linear": linear_interpolation,
    }
    if method not in match_functions:
        raise ValueError(f"Invalid gap filling method: {method}")
    match_function = match_functions[method]

    # Maximum number of usable references evaluated per gapped image
    TOTAL_TRIES = 5

    out_folder = pathlib.Path(out_folder)
    all_raw_files = [
        path / name for path, name in zip(table["folder"], table["outname"])
    ]
    # Plain array so date lookups are positional, whatever the table index is
    all_raw_dates = pd.to_datetime(table["img_date"]).to_numpy()
    n_files = len(all_raw_files)

    match_error = []
    for index, s2_img in enumerate(all_raw_files):
        s2_data, s2_metadata = _gapfill_read(s2_img)
        # Per-pixel fraction of missing bands
        s2_cloudmask = np.isnan(s2_data).mean(0)

        if s2_cloudmask.sum() == 0:
            # No gaps: just copy the image
            if not quiet:
                print(f"Processing {s2_img.name} [{index + 1}/{n_files}]")
            shutil.copy(s2_img, out_folder / s2_img.name)
            match_error.append(0)
            continue

        # Candidate references, closest in time first
        candidates = np.argsort(
            np.abs(all_raw_dates - all_raw_dates[index]), kind="stable"
        )

        # Reset per image so a previous image's result can never be reused
        best_image, best_metric = None, np.nan
        tries = 0
        for ref_idx in candidates:
            if ref_idx == index:
                # An image cannot fill its own gaps
                continue

            s2_data_ref, _ = _gapfill_read(all_raw_files[ref_idx])
            s2_cloudmask_ref = np.isnan(s2_data_ref) * 1.0

            # Reject references with a gap where the target misses all bands
            if np.sum((s2_cloudmask_ref + s2_cloudmask) == 2) != 0:
                continue

            final_image, metric = _gapfill_candidate(
                s2_data, s2_cloudmask, s2_data_ref, s2_cloudmask_ref, match_function
            )
            if best_image is None or metric < best_metric:
                best_image, best_metric = final_image, metric

            # Every usable reference counts as one try
            tries += 1
            if tries == TOTAL_TRIES:
                break

        if best_image is None:
            # No usable reference: keep the image, gaps included
            if not quiet:
                print(
                    f"Processing {s2_img.name} [{index + 1}/{n_files}] "
                    "without a valid reference; gaps kept"
                )
            shutil.copy(s2_img, out_folder / s2_img.name)
            match_error.append(np.nan)
            continue

        if not quiet:
            print(
                f"Processing {s2_img.name} [{index + 1}/{n_files}] with error {best_metric}"
            )

        # Round (rather than truncate) so untouched pixels survive the
        # /10000 * 10000 round trip; clip to the uint16 range.
        best_image = np.clip(np.rint(best_image * 10000), 0, 65535).astype(np.uint16)
        with rio.open(out_folder / s2_img.name, "w", **s2_metadata) as dst:
            dst.write(best_image)

        match_error.append(best_metric)

    table["match_error"] = match_error

    return table
477
+
478
+
479
def aligned_s2(
    table: pd.DataFrame, out_folder: str, quiet: bool = False
) -> pd.DataFrame:
    """Co-register every image of the data cube to a common reference.

    The reference is the mean of the last 10 images of the table (values
    scaled by 1/10000). Each image is aligned to it with ``satalign.PCC``.

    Args:
        table (pd.DataFrame): The table with the images to align. It must
            have the ``folder`` and ``outname`` columns.
        out_folder (str): The output folder (str or pathlib.Path).
        quiet (bool, optional): If True, the function does not print.
            Defaults to False.

    Returns:
        pd.DataFrame: The input table with an ``align_error`` column: the
        magnitude of the translation terms of each estimated warp matrix.
    """
    out_folder = pathlib.Path(out_folder)

    all_raw_files = [
        path / name for path, name in zip(table["folder"], table["outname"])
    ]
    if not all_raw_files:
        # Nothing to align (and no reference can be built)
        table["align_error"] = []
        return table

    # Reference image: mean of the last 10 images
    reference_files = all_raw_files[-10:]
    s2_acc = None
    for file in reference_files:
        with rio.open(file) as src:
            current = src.read() / 10000
        s2_acc = current if s2_acc is None else s2_acc + current
    s2_mean = s2_acc / len(reference_files)

    align_error = []
    for idx, file in enumerate(all_raw_files):
        if not quiet:
            print(f"Processing {file.name} [{idx + 1}/{len(all_raw_files)}]")

        with rio.open(file) as src:
            s2_data = src.read() / 10000
            metadata = src.profile

        syncmodel = satalign.PCC(
            datacube=s2_data[None],  # T x C x H x W
            reference=s2_mean,  # C x H x W
            upsample_factor=200,
            channel="luminance",
            crop_center=s2_mean.shape[2] // 2,
        )
        news2cube, warps = syncmodel.run()

        # Back to the uint16 scale: round and clip so resampling overshoots
        # (e.g. small negatives) do not wrap around when cast.
        aligned = np.clip(np.rint(news2cube[0] * 10000), 0, 65535)
        with rio.open(out_folder / file.name, "w", **metadata) as dst:
            dst.write(aligned.astype(rio.uint16))

        # Translation terms of the warp matrix
        shift_x, shift_y = warps[0][0, 2], warps[0][1, 2]
        align_error.append(np.sqrt(shift_x**2 + shift_y**2))

    table["align_error"] = align_error

    return table
545
+
546
+
547
def tripple_histogram_matching(
    image1: np.ndarray,
    image2: np.ndarray,
    image3: np.ndarray,
    n_bins: int = 128,
    value_range: Tuple[float, float] = (0, 2),
) -> np.ndarray:
    """Map image3 from the value distribution of image2 onto that of image1.

    Cumulative histograms of the valid (non-NaN) pixels of image1 and image2
    define a lookup table. The table sends each value of image2's
    distribution to the value with the same cumulative frequency in image1.
    It is then applied to every pixel of image3.

    Args:
        image1 (np.ndarray): The target reference image (may contain NaN).
        image2 (np.ndarray): The source reference image (may contain NaN).
        image3 (np.ndarray): The image to be matched.
        n_bins (int, optional): Number of histogram bins. Defaults to 128.
        value_range (Tuple[float, float], optional): Histogram range.
            Defaults to (0, 2).

    Returns:
        np.ndarray: The matched image, with the shape of image3. A float copy
        of image3 is returned unchanged when either reference has no valid
        pixel inside ``value_range``, because the matching is undefined then.
    """
    values1 = image1[~np.isnan(image1)]
    values2 = image2[~np.isnan(image2)]

    hist1, bins = np.histogram(values1, n_bins, value_range)
    hist2, _ = np.histogram(values2, n_bins, value_range)

    if hist1.sum() == 0 or hist2.sum() == 0:
        return np.array(image3, dtype=float)

    # Cumulative distribution functions of both references
    cdf1 = hist1.cumsum() / hist1.sum()
    cdf2 = hist2.cumsum() / hist2.sum()

    # Lookup table: bin of image2 -> value of image1 with the same CDF
    lut = np.interp(cdf2, cdf1, bins[:-1])

    return np.interp(image3.ravel(), bins[:-1], lut).reshape(image3.shape)
589
+
590
+
591
def linear_interpolation(
    image1: np.ndarray, image2: np.ndarray, image3: np.ndarray
) -> np.ndarray:
    """Fit ``image1 ~ slope * image2 + intercept`` and apply it to image3.

    The ordinary least-squares fit (with intercept) uses only the pixels that
    are valid in both image1 and image2, so the x/y pairs stay aligned even
    when their NaN positions differ. When image2 is constant on those pixels,
    the slope is 0 and the intercept is the mean of image1.

    Args:
        image1 (np.ndarray): The target reference image (may contain NaN).
        image2 (np.ndarray): The source reference image, same shape as
            image1 (may contain NaN).
        image3 (np.ndarray): The image to be matched.

    Returns:
        np.ndarray: ``slope * image3 + intercept``. A float copy of image3 is
        returned unchanged when no pixel is valid in both references.
    """
    x = np.asarray(image2, dtype=float).ravel()
    y = np.asarray(image1, dtype=float).ravel()

    # Joint mask: keep only pixel pairs that are valid in both images
    valid = ~(np.isnan(x) | np.isnan(y))
    x, y = x[valid], y[valid]
    if x.size == 0:
        return np.array(image3, dtype=float)

    # Closed-form ordinary least squares on centered data
    x_mean, y_mean = x.mean(), y.mean()
    dx = x - x_mean
    denom = np.dot(dx, dx)
    slope = np.dot(dx, y - y_mean) / denom if denom > 0 else 0.0
    intercept = y_mean - slope * x_mean

    return slope * image3 + intercept
623
+
624
+
625
def display_images(
    table: pd.DataFrame,
    out_folder: pathlib.Path,
    bands: List[int],
    ratio: int,
):
    """Save one PNG frame per image of a dataset (e.g. to assemble a GIF).

    The files are visited in sorted path order. Frame ``i`` is written to
    ``out_folder / f"temp_{i:03d}.png"``.

    Args:
        table (pd.DataFrame): The table with the images. It must have the
            ``folder`` and ``outname`` columns.
        out_folder (pathlib.Path): The folder where the PNG frames are saved.
        bands (List[int]): Indices of the bands to plot, in display order.
        ratio (int): Divisor applied to the raw values before clipping them
            to [0, 1].

    Returns:
        None
    """
    out_folder = pathlib.Path(out_folder)

    all_raw_files = sorted(
        path / name for path, name in zip(table["folder"], table["outname"])
    )

    for index, file in enumerate(all_raw_files):
        with rio.open(file) as src:
            data = src.read() / ratio

        # Scale to the [0, 1] range expected by imshow for float data
        data = np.clip(data, 0, 1)

        fig, ax = plt.subplots(1, 1, figsize=(12, 12))
        ax.imshow(np.moveaxis(data[bands, :, :], 0, -1))
        ax.axis("off")
        ax.set_title("ID: " + str(index) + " - " + file.name, fontsize=20)
        fig.savefig(out_folder / f"temp_{index:03d}.png")

        # Close exactly this figure. Calling plt.clf() after plt.close()
        # would open a new empty figure that is never released.
        plt.close(fig)

    return None
663
+
664
+
665
def load_evoland(
    weights: str,
    device: Union[str, torch.device] = "cpu",
) -> Tuple[Any, Any]:
    """Create an ONNX Runtime session for the super-resolution model.

    Args:
        weights (str): Path to the ONNX model (str or pathlib.Path).
        device (Union[str, torch.device], optional): ``"cpu"``, ``"cuda"`` or
            ``"cuda:<index>"``; a ``torch.device`` is also accepted.
            Defaults to "cpu".

    Raises:
        ValueError: If the device is not supported.

    Returns:
        Tuple[Any, Any]: The ``InferenceSession`` and a ``RunOptions`` object.
    """
    import onnxruntime as ort

    # ONNX inference session options
    so = ort.SessionOptions()
    so.intra_op_num_threads = 10
    so.inter_op_num_threads = 10
    so.use_deterministic_compute = False

    # str() also normalizes torch.device objects ("cpu", "cuda", "cuda:1")
    device_name = str(device)
    if device_name == "cpu":
        ep_list = ["CPUExecutionProvider"]
    elif device_name == "cuda":
        ep_list = ["CUDAExecutionProvider"]
    elif device_name.startswith("cuda:") and device_name[5:].isdigit():
        ep_list = [("CUDAExecutionProvider", {"device_id": int(device_name[5:])})]
    else:
        raise ValueError("Invalid device")

    # Pass a plain str path; not every onnxruntime version accepts Path objects
    if isinstance(weights, pathlib.PurePath):
        weights = str(weights)

    # Providers are fixed at construction; no extra set_providers call needed
    ort_session = ort.InferenceSession(
        weights,
        sess_options=so,
        providers=ep_list,
    )
    ro = ort.RunOptions()

    return ort_session, ro
693
+
694
+
695
def super_s2(
    table: pd.DataFrame,
    out_folder: str,
    device: str,
    sensor: Sentinel2,
    quiet: bool = False
):
    """Apply the ONNX super-resolution model to every image of a dataset.

    Each image is read, passed to the model as a float32 batch of one, and
    written to ``out_folder`` under the same name. The output raster's size
    and pixel size follow the model's output shape.

    Args:
        table (pd.DataFrame): The table with the images. It must have the
            ``folder`` and ``outname`` columns.
        out_folder (str): The output folder (str or pathlib.Path).
        device (str): The device for ONNX Runtime (see ``load_evoland``).
        sensor (Sentinel2): Sensor description; ``weight_path`` and
            ``super_model_specific`` locate the model.
        quiet (bool, optional): If True, the function does not print.
            Defaults to False.

    Raises:
        ImportError: If onnxruntime is not installed.

    Returns:
        pd.DataFrame: The input table, unchanged.
    """
    # Super resolution requires onnxruntime
    if importlib.util.find_spec("onnxruntime") is None:
        raise ImportError(
            "onnxruntime is not installed. Please install it to use super resolution tools."
        )

    out_folder = pathlib.Path(out_folder)
    ort_session, ro = load_evoland(
        sensor.weight_path / sensor.super_model_specific,
        device=device,
    )

    all_raw_files = [
        path / name for path, name in zip(table["folder"], table["outname"])
    ]

    for idx, file in enumerate(all_raw_files):
        if not quiet:
            print(f"Processing {file.name} [{idx + 1}/{len(all_raw_files)}]")

        with rio.open(file) as src:
            data = src.read()
            metadata = src.profile

        # Stay in float32. float16 cannot represent every integer above 2048
        # and overflows above 65504, which corrupts uint16 values.
        sr = ort_session.run(
            None, {"input": data[None].astype(np.float32)}, run_options=ro
        )[0]
        # Drop the batch axis only (squeeze would also drop a single band)
        if sr.ndim == 4:
            sr = sr[0]

        # Rasterio: height = rows (axis 1), width = columns (axis 2)
        metadata["height"] = sr.shape[1]
        metadata["width"] = sr.shape[2]
        # Shrink the pixel size by the actual upscaling factor
        metadata["transform"] = metadata["transform"] * rio.Affine.scale(
            data.shape[2] / sr.shape[2], data.shape[1] / sr.shape[1]
        )

        with rio.open(out_folder / file.name, "w", **metadata) as dst:
            dst.write(np.clip(sr, 0, 65535).astype(rio.uint16))

    return table
756
+
757
+
758
def monthly_composites_s2(
    table: pd.DataFrame,
    out_folder: pathlib.Path,
    date_range: Tuple[str, str],
    agg_method: str,
    quiet: bool = False
):
    """Aggregate a dataset into one composite image per calendar month.

    Every month from the month of ``date_range[0]`` up to ``date_range[1]``
    produces ``out_folder / "YYYY-MM-15.tif"``. Pixels equal to 65535
    (nodata) are ignored by the aggregation. Pixels without any valid value,
    and months without images, are written as 65535.

    Args:
        table (pd.DataFrame): The table with the images. It must have the
            ``folder``, ``outname`` and ``img_date`` columns.
        out_folder (pathlib.Path): The output folder.
        date_range (Tuple[str, str]): The first and last dates to cover.
        agg_method (str): ``"mean"``, ``"median"``, ``"max"`` or ``"min"``.
        quiet (bool, optional): If True, the function does not print.
            Defaults to False.

    Raises:
        ValueError: If ``agg_method`` is not supported.

    Returns:
        pd.DataFrame: One row per month with the ``img_date``, ``folder``,
        ``outname`` and ``nodata`` columns (``nodata`` is 1 when the month
        had no images).
    """
    import warnings

    NODATA = 65535
    aggregators = {
        "mean": np.nanmean,
        "median": np.nanmedian,
        "max": np.nanmax,
        "min": np.nanmin,
    }
    if agg_method not in aggregators:
        raise ValueError("Invalid aggregation method")
    aggregate = aggregators[agg_method]

    out_folder = pathlib.Path(out_folder)
    all_raw_files = [
        path / name for path, name in zip(table["folder"], table["outname"])
    ]

    # The first image provides the metadata used for empty months
    with rio.open(all_raw_files[0]) as src:
        metadata = src.profile

    # Start on the first day of the month, so a mid-month start date still
    # produces a composite for that month.
    start = pd.to_datetime(date_range[0]).to_period("M").to_timestamp()
    end = pd.to_datetime(date_range[1])
    months = pd.date_range(start, end, freq="MS").strftime("%Y-%m-15")
    image_months = pd.to_datetime(table["img_date"]).dt.strftime("%Y-%m-15").to_numpy()

    new_table = []
    for idx, date in enumerate(months):
        if not quiet:
            print(f"Processing {date} [{idx + 1}/{len(months)}]")

        images = [all_raw_files[i] for i in np.flatnonzero(image_months == date)]

        if not images:
            data = NODATA * np.ones((metadata["count"], metadata["height"], metadata["width"]))
            nodata = 1
        else:
            container = []
            for image in images:
                with rio.open(image) as src:
                    container.append(src.read().astype(np.float64))
                    metadata = src.profile
            stack = np.stack(container)

            # Exclude nodata from the statistics
            stack[stack == NODATA] = np.nan
            with warnings.catch_warnings():
                # All-NaN pixels warn; they are filled with NODATA below
                warnings.simplefilter("ignore", category=RuntimeWarning)
                data = aggregate(stack, axis=0)
            data[np.isnan(data)] = NODATA
            nodata = 0

        with rio.open(out_folder / f"{date}.tif", "w", **metadata) as dst:
            dst.write(data.astype(rio.uint16))

        new_table.append(
            {
                "img_date": date,
                "folder": out_folder,
                "outname": f"{date}.tif",
                "nodata": nodata,
            }
        )

    return pd.DataFrame(new_table)
835
+
836
+
837
def interpolate_s2(
    table: pd.DataFrame,
    out_folder: pathlib.Path,
    quiet: bool = False
) -> pd.DataFrame:
    """Linearly interpolate, along time, the nodata pixels of a dataset.

    Pixels equal to 65535 are treated as missing and are linearly
    interpolated between the surrounding acquisitions. Each output file is
    named after the date of its own image (``YYYY-MM-DD.tif``). Pixels that
    cannot be interpolated (missing at the start or end of the series) are
    written back as 65535.

    Args:
        table (pd.DataFrame): The table with the images to interpolate. It
            must have the ``folder``, ``outname`` and ``img_date`` columns.
        out_folder (pathlib.Path): The output folder.
        quiet (bool, optional): If True, the function does not print.
            Defaults to False.

    Returns:
        pd.DataFrame: The input table, unchanged.
    """
    if not quiet:
        print("Interpolating the missing values started...")

    out_folder = pathlib.Path(out_folder)
    all_raw_files = [
        path / name for path, name in zip(table["folder"], table["outname"])
    ]
    all_raw_dates = pd.to_datetime(table["img_date"]).to_numpy()

    # Chronological order (interpolate_na needs a monotonic time axis). Each
    # file stays paired with its own date.
    order = np.argsort(all_raw_dates, kind="stable")

    with rio.open(all_raw_files[order[0]]) as src:
        metadata = src.profile

    cube = []
    for i in order:
        with rio.open(all_raw_files[i]) as src:
            cube.append(src.read())
    data_np = np.array(cube) / 10000
    data_np[data_np == 6.5535] = np.nan

    data_xr = xr.DataArray(
        data=data_np,
        dims=("time", "band", "y", "x"),
        coords={"time": all_raw_dates[order], "band": range(data_np.shape[1])},
    )
    data_xr = data_xr.interpolate_na(dim="time", method="linear")

    for i, current_data in zip(order, data_xr.values):
        date = pd.Timestamp(all_raw_dates[i]).strftime("%Y-%m-%d")
        # Remaining NaN -> nodata; round so untouched pixels keep their value
        values = np.where(
            np.isnan(current_data),
            65535,
            np.clip(np.rint(current_data * 10000), 0, 65535),
        )
        with rio.open(out_folder / f"{date}.tif", "w", **metadata) as dst:
            dst.write(values.astype(np.uint16))

    return table
886
+
887
+
888
def smooth_s2(
    table: pd.DataFrame,
    out_folder: pathlib.Path,
    smooth_w: int,
    smooth_p: int,
    device: Union[str, torch.device],
    quiet: bool
) -> pd.DataFrame:
    """Smooth a dataset along time around its monthly climatology.

    The per-month mean image (climatology) is subtracted from every image.
    The residuals are smoothed along time with ``gaussian_smooth``, and the
    climatology is added back. Outputs keep the input file names.

    Args:
        table (pd.DataFrame): The table with the images to smooth. It must
            have the ``folder``, ``outname`` and ``img_date`` columns.
        out_folder (pathlib.Path): The output folder.
        smooth_w (int): Size of the Gaussian kernel (number of time steps).
        smooth_p (int): Standard deviation of the Gaussian kernel.
        device (Union[str, torch.device]): The device used for smoothing.
        quiet (bool): If True, the function does not print.

    Returns:
        pd.DataFrame: A table with only the ``img_date`` and ``outname``
        columns.
    """
    import warnings

    if not quiet:
        print("Smoothing the values started...")

    out_folder = pathlib.Path(out_folder)
    all_raw_files = [
        path / name for path, name in zip(table["folder"], table["outname"])
    ]
    out_files = [out_folder / file.name for file in all_raw_files]

    with rio.open(all_raw_files[0]) as src:
        metadata = src.profile
    cube = []
    for file in all_raw_files:
        with rio.open(file) as src:
            cube.append(src.read())
    # NOTE(review): 65535 nodata values are not masked here; this step
    # presumably runs after interpolation - confirm.
    data_np = (np.array(cube) / 10000).astype(np.float32)

    # Monthly climatology. The data are already scaled by 1/10000, so no
    # further scaling is applied. Only months present in the data are filled.
    data_month = pd.to_datetime(table["img_date"]).dt.month.to_numpy()
    data_clim = np.zeros((12,) + data_np.shape[1:], dtype=data_np.dtype)
    for month in np.unique(data_month):
        data_clim[month - 1] = data_np[data_month == month].mean(axis=0)
    clim_per_image = data_clim[data_month - 1]

    # Smooth the residuals
    residuals = torch.from_numpy(data_np - clim_per_image).float().to(device)
    try:
        smoothed = gaussian_smooth(
            residuals, kernel_size=smooth_w, sigma=smooth_p
        ).cpu().numpy()
    except Exception as e:
        # Best effort: keep the unsmoothed residuals, but report the failure
        warnings.warn(f"Gaussian smoothing failed, residuals left unsmoothed: {e}")
        smoothed = residuals.cpu().numpy()

    # Add the climatology back
    data_np = smoothed + clim_per_image

    for idx, file in enumerate(out_files):
        values = np.clip(np.rint(data_np[idx] * 10000), 0, 65535)
        with rio.open(file, "w", **metadata) as dst:
            dst.write(values.astype(np.uint16))

    return pd.DataFrame(
        {
            "img_date": table["img_date"],
            "outname": table["outname"],
        }
    )
963
+
964
+
965
+
966
def gaussian_kernel1d(kernel_size: int, sigma: float):
    """Build a normalized 1D Gaussian kernel.

    ``kernel_size`` sample positions are spaced evenly from
    ``-(kernel_size // 2)`` to ``kernel_size // 2``. Each is weighted by
    ``exp(-x**2 / (2 * sigma**2))``, and the weights are rescaled to sum to 1.

    Args:
        kernel_size (int): Number of taps.
        sigma (float): Standard deviation of the Gaussian.

    Returns:
        torch.Tensor: 1D tensor of length ``kernel_size``.
    """
    half = kernel_size // 2
    positions = torch.linspace(-half, half, kernel_size)
    weights = torch.exp(-positions.pow(2) / (2 * sigma ** 2))
    return weights / weights.sum()
977
+
978
+
979
def gaussian_smooth(tensor, kernel_size: int, sigma: float):
    """Apply Gaussian smoothing along the first (time) dimension of a tensor.

    Each (band, y, x) time series is convolved with a Gaussian kernel after
    zero padding, so the output has the same length T as the input. For an
    even ``kernel_size``, the extra padding element goes on the right.

    Args:
        tensor (torch.Tensor): Input tensor of shape (T, C, H, W), where T is
            the time dimension.
        kernel_size (int): Size of the Gaussian kernel.
        sigma (float): Standard deviation of the Gaussian kernel.

    Returns:
        torch.Tensor: Smoothed tensor of shape (T, C, H, W).
    """
    # Kernel on the same device and dtype as the input, shaped (out, in, k)
    kernel = (
        gaussian_kernel1d(kernel_size, sigma)
        .to(device=tensor.device, dtype=tensor.dtype)
        .view(1, 1, -1)
    )

    # (T, C, H, W) -> (C*H*W, 1, T): one single-channel series per pixel/band.
    # reshape (not view) also accepts non-contiguous inputs.
    T, C, H, W = tensor.shape
    series = tensor.reshape(T, -1).permute(1, 0).unsqueeze(1)

    # Zero padding. For odd sizes this equals padding=kernel_size // 2; it
    # also keeps the output length at T for even sizes.
    pad_left = (kernel_size - 1) // 2
    pad_right = kernel_size - 1 - pad_left
    series = F.pad(series, (pad_left, pad_right))
    smoothed = F.conv1d(series, kernel)

    # Back to (T, C, H, W)
    return smoothed.squeeze(1).permute(1, 0).reshape(T, C, H, W)
1006
+
1007
+
1008
def monthly_calendar(df, year1, year2, data_char="X", no_data_char="."):
    """Print a year-by-month table marking the months that have images.

    Every column is 3 characters wide, so each marker sits under its month
    number.

    Args:
        df (pd.DataFrame): Table with an ``img_date`` column. It is not
            modified.
        year1 (int): First year to show.
        year2 (int): Last year to show (inclusive).
        data_char (str, optional): Marker for months with at least one image.
            Defaults to "X".
        no_data_char (str, optional): Marker for months without images.
            Defaults to ".".
    """
    dates = pd.to_datetime(df["img_date"])
    # (year, month) pairs with data: one pass instead of one filter per cell
    available = set(zip(dates.dt.year.tolist(), dates.dt.month.tolist()))

    print("Y/M".center(10) + "".join(f"{month:>3}" for month in range(1, 13)))
    for year in range(year1, year2 + 1):
        marks = (
            data_char if (year, month) in available else no_data_char
            for month in range(1, 13)
        )
        print(f"{year:<10}" + "".join(f"{mark:>3}" for mark in marks))
1035
+
1036
+
1037
def download_weights(
    path: Union[str, pathlib.Path],
    quiet: bool = False,
) -> pathlib.Path:
    """Download the satcube model weights into ``path``.

    Files that already exist are not downloaded again. Each file is streamed
    into a ``.part`` temporary file and renamed only after the download
    finished. An interrupted or failed download therefore never leaves a
    file that later runs would mistake for valid weights.

    Args:
        path (Union[str, pathlib.Path]): The folder to save the weights in.
            It is created if needed.
        quiet (bool, optional): If True, the function will not print
            the progress. Defaults to False.

    Raises:
        requests.HTTPError: If the server answers with an error status.

    Returns:
        pathlib.Path: The path to the weights.
    """
    if not quiet:
        print("Downloading the satcube weights...")

    URI = "https://github.com/JulioContrerasH/satcube/releases/download/weights-v1.0/"
    # Seconds to wait for the connection / for each chunk of data
    TIMEOUT = 60

    path = pathlib.Path(path)
    path.mkdir(parents=True, exist_ok=True)

    weight_files = [
        "s2_cloud_model_specific.pt",
        "s2_cloud_model_universal.pt",
        "s2_embedding_model_universal.pt",
        "s2_super_model_specific.pt",
    ]
    for name in weight_files:
        target = path / name
        if target.exists():
            continue

        partial = target.with_name(target.name + ".part")
        with requests.get(URI + name, stream=True, timeout=TIMEOUT) as r:
            # Never save an HTML error page as a weights file
            r.raise_for_status()
            with open(partial, "wb") as f:
                for chunk in r.iter_content(chunk_size=1 << 20):
                    f.write(chunk)
        partial.replace(target)

    return path