satcube 0.1.0-py3-none-any.whl → 0.1.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of satcube might be problematic.

@@ -1,24 +1,160 @@
+ """Predict cloud masks for Sentinel-2 GeoTIFFs with the SEN2CloudEnsemble model.
+
+ The callable :pyfunc:`cloud_masking` accepts **either** a single ``.tif`` file
+ or a directory tree; in both cases it writes a masked copy of every image (and,
+ optionally, the binary mask) to *output*.
+
+ Example
+ -------
+ >>> from satcube.cloud_detection import cloud_masking
+ >>> cloud_masking("~/s2/input", "~/s2/output", device="cuda")
+ """
+
+ from __future__ import annotations
+
+ import time
+ from pathlib import Path
+ from typing import List
+
+ import mlstac
+ import numpy as np
+ import rasterio as rio
  import torch
 
- class LandsatCloudDetector(torch.nn.Module):
-     def __init__(self):
-         super().__init__()
+ from satcube.utils import DeviceManager, _reset_gpu
+
+
+ def cloud_masking(
+     input: str | Path,  # noqa: A002 (shadowing built-in is OK here)
+     output: str | Path,
+     *,
+     tile: int = 512,
+     pad: int = 64,
+     save_mask: bool = False,
+     device: str = "cpu",
+     max_pix_cpu: float = 7.0e7,
+ ) -> List[Path]:
+     """Write cloud-masked Sentinel-2 images.
+
+     Parameters
+     ----------
+     input
+         Path to a single ``.tif`` file **or** a directory containing them.
+     output
+         Destination directory (created if missing).
+     tile, pad
+         Tile size and padding (pixels) when tiling is required.
+     save_mask
+         If *True*, store the binary mask alongside the masked image.
+     device
+         Torch device for inference, e.g. ``"cpu"`` or ``"cuda:0"``.
+     max_pix_cpu
+         Tile images larger than this when running on CPU.
+
+     Returns
+     -------
+     list[pathlib.Path]
+         Paths to the generated masked images.
+     """
+     t_start = time.perf_counter()
+
+     src = Path(input).expanduser().resolve()
+     dst_dir = Path(output).expanduser().resolve()
+     dst_dir.mkdir(parents=True, exist_ok=True)
+
+     # Collect files to process -------------------------------------------------
+     tif_paths: list[Path]
+     if src.is_dir():
+         tif_paths = [p for p in src.rglob("*.tif")]
+     elif src.is_file() and src.suffix.lower() == ".tif":
+         tif_paths = [src]
+         src = src.parent  # for relative-path bookkeeping below
+     else:
+         raise ValueError(f"Input must be a .tif or directory, got: {src}")
+
+     if not tif_paths:
+         print(f"[cloud_masking] No .tif files found in {src}")
+         return []
+
+     experiment = mlstac.load("SEN2CloudEnsemble")
+     dm = DeviceManager(experiment, init_device=device)
+
+     masked_paths: list[Path] = []
+
+     # -------------------------------------------------------------------------
+     for idx, tif_path in enumerate(tif_paths, 1):
+         rel = tif_path.relative_to(src)
+         out_dir = dst_dir / rel.parent
+         out_dir.mkdir(parents=True, exist_ok=True)
+
+         mask_path = out_dir / f"{tif_path.stem}_cloudmask.tif"
+         masked_path = out_dir / f"{tif_path.stem}_masked.tif"
+
+         with rio.open(tif_path) as src_img:
+             profile = src_img.profile
+             h, w = src_img.height, src_img.width
+
+         mask_prof = profile.copy()
+         mask_prof.update(driver="GTiff", count=1, dtype="uint8", nodata=255)
+
+         do_tiling = (dm.device == "cuda") or (h * w > max_pix_cpu)
+         full_mask = np.full((h, w), 255, np.uint8)
+
+         t0 = time.perf_counter()
+
+         # ----------------------- inference -----------------------------------
+         if not do_tiling:  # full frame
+             with rio.open(tif_path) as src_img, torch.inference_mode():
+                 img = src_img.read().astype(np.float32) / 1e4
+                 h32, w32 = (h + 31) // 32 * 32, (w + 31) // 32 * 32
+                 pad_b, pad_r = h32 - h, w32 - w
+                 tensor = torch.from_numpy(img).unsqueeze(0)
+                 if pad_b or pad_r:
+                     tensor = torch.nn.functional.pad(tensor, (0, pad_r, 0, pad_b))
+                 mask = dm.model(tensor.to(dm.device)).squeeze(0)
+                 full_mask[:] = mask[..., :h, :w].cpu().numpy().astype(np.uint8)
+         else:  # tiled
+             with rio.open(tif_path) as src_img, torch.inference_mode():
+                 for y0 in range(0, h, tile):
+                     for x0 in range(0, w, tile):
+                         y0r, x0r = max(0, y0 - pad), max(0, x0 - pad)
+                         y1r, x1r = min(h, y0 + tile + pad), min(w, x0 + tile + pad)
+                         win = rio.windows.Window(x0r, y0r, x1r - x0r, y1r - y0r)
+
+                         patch = src_img.read(window=win).astype(np.float32) / 1e4
+                         tensor = torch.from_numpy(patch).unsqueeze(0).to(dm.device)
+                         mask = dm.model(tensor).squeeze(0).cpu().numpy().astype(np.uint8)
+
+                         y_in0 = pad if y0r else 0
+                         x_in0 = pad if x0r else 0
+                         y_in1 = mask.shape[0] - (pad if y1r < h else 0)
+                         x_in1 = mask.shape[1] - (pad if x1r < w else 0)
+                         core = mask[y_in0:y_in1, x_in0:x_in1]
+                         full_mask[y0 : y0 + core.shape[0], x0 : x0 + core.shape[1]] = core
+
+         # ----------------------- output --------------------------------------
+         if save_mask:
+             with rio.open(mask_path, "w", **mask_prof) as dst:
+                 dst.write(full_mask, 1)
+
+         with rio.open(tif_path) as src_img:
+             data = src_img.read()
+             img_prof = src_img.profile.copy()
+
+         masked = data.copy()
+         masked[:, full_mask != 0] = 65535
+         img_prof.update(dtype="uint16", nodata=65535)
 
-     def forward(self, x: torch.Tensor) -> torch.Tensor:
-         # Define bit flags for clouds based on the
-         # Landsat QA band documentation
-         cloud_flags = (1 << 3) | (1 << 4) | (1 << 1)
+         with rio.open(masked_path, "w", **img_prof) as dst:
+             dst.write(masked)
 
-         ## Get the QA band
-         qa_band = x[6]
-         mask_band = x[:6].mean(axis=0)
-         mask_band[~torch.isnan(mask_band)] = 1
+         masked_paths.append(masked_path)
+         dt = time.perf_counter() - t0
+         print(f"[{idx}/{len(tif_paths)}] {rel} → done in {dt:.1f}s")
 
-         ## Create a cloud mask
-         cloud_mask = torch.bitwise_and(qa_band.int(), cloud_flags) == 0
-         cloud_mask = cloud_mask.float()
-         cloud_mask[cloud_mask == 0] = torch.nan
-         cloud_mask[cloud_mask == 0] = 1
-         final_mask = cloud_mask * mask_band
-         return final_mask
+         if dm.device == "cuda":
+             _reset_gpu()
 
+     total_time = time.perf_counter() - t_start
+     print(f"Processed {len(masked_paths)} image(s) in {total_time:.1f}s.")
+     return masked_paths
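
A minimal usage sketch for the cloud_masking entry point added above; the paths are placeholders, and save_mask=True simply enables the optional mask output described in the docstring:

from satcube.cloud_detection import cloud_masking

# Placeholder paths; the output directory is created if missing.
masked = cloud_masking(
    "~/s2/scenes",      # a single .tif or a directory tree
    "~/s2/masked",
    save_mask=True,     # also write the binary cloud mask per image
    device="cpu",
)
print(f"wrote {len(masked)} masked image(s)")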
@@ -0,0 +1,24 @@
+ import torch
+
+ class LandsatCloudDetector(torch.nn.Module):
+     def __init__(self):
+         super().__init__()
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # Define bit flags for clouds based on the
+         # Landsat QA band documentation
+         cloud_flags = (1 << 3) | (1 << 4) | (1 << 1)
+
+         ## Get the QA band
+         qa_band = x[6]
+         mask_band = x[:6].mean(axis=0)
+         mask_band[~torch.isnan(mask_band)] = 1
+
+         ## Create a cloud mask
+         cloud_mask = torch.bitwise_and(qa_band.int(), cloud_flags) == 0
+         cloud_mask = cloud_mask.float()
+         cloud_mask[cloud_mask == 0] = torch.nan
+         cloud_mask[cloud_mask == 0] = 1
+         final_mask = cloud_mask * mask_band
+         return final_mask
+
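
A minimal sketch of exercising the relocated LandsatCloudDetector, assuming the input tensor stacks six spectral bands plus the Landsat QA band at index 6 (the random tensor below is placeholder data):

import torch

detector = LandsatCloudDetector()
x = torch.rand(7, 256, 256)   # 6 spectral bands + QA band (placeholder values)
mask = detector(x)            # 1.0 where clear and valid, NaN where cloudy or nodata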
satcube/download.py ADDED
@@ -0,0 +1,65 @@
+ import pathlib
+ import ee
+ import cubexpress
+ import pandas as pd
+
+
+ def download_data(
+     *,  # keyword-only
+     lon: float,
+     lat: float,
+     cloud_max: int = 40,
+     edge_size: int = 2_048,
+     start: str,
+     end: str,
+     output: str = "raw",
+     scale: int = 10,
+     nworks: int = 4,
+     mosaic: bool = True,
+     auto_init_gee: bool = True,
+ ) -> pd.DataFrame:
+     """
+     Download a Sentinel cube for (lon, lat) and return its metadata.
+
+     Parameters
+     ----------
+     lon, lat        Center point in degrees.
+     cloud_max       Max cloud cover (%).
+     edge_size       Square side length (m).
+     start, end      YYYY-MM-DD date range.
+     output          Folder for GeoTIFFs.
+     scale           Pixel size (m).
+     nworks          Parallel workers.
+     mosaic          Merge scenes per date.
+     auto_init_gee   Call ee.Initialize() if needed.
+
+     Returns
+     -------
+     pandas.DataFrame
+         Scene catalogue used for the request.
+     """
+     # EE ready
+     if auto_init_gee:
+         try:
+             ee.Initialize()
+         except ee.EEException:
+             ee.Authenticate(); ee.Initialize()
+
+     # Filter scenes
+     df = cubexpress.cloud_table(
+         lon=lon,
+         lat=lat,
+         edge_size=edge_size,
+         scale=scale,
+         cloud_max=cloud_max,
+         start=start,
+         end=end,
+     )
+
+     # Build requests + ensure dir
+     requests = cubexpress.table_to_requestset(df, mosaic=mosaic)
+     pathlib.Path(output).mkdir(parents=True, exist_ok=True)
+
+     # Download cube
+     cubexpress.get_cube(requests, output, nworks)
+     return df
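
A usage sketch for the keyword-only download_data added in satcube/download.py; the coordinates and date range are placeholders, the returned DataFrame is the scene catalogue, and the GeoTIFFs land in the output folder:

from satcube.download import download_data

# Placeholder point and date range; every argument is keyword-only.
catalogue = download_data(
    lon=-76.95,
    lat=-12.05,
    start="2022-01-01",
    end="2022-12-31",
    cloud_max=30,
    output="raw",
)
print(catalogue.head())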
@@ -0,0 +1,82 @@
+ import ee
+ import cubexpress
+ import pathlib
+ from typing import Optional
+ from datetime import datetime
+
+ def download_data(
+     lon: float,
+     lat: float,
+     cs_cdf: Optional[float] = 0.6,
+     buffer_size: Optional[int] = 1280,
+     start_date: Optional[str] = "2015-01-01",
+     end_date: Optional[str] = datetime.today().strftime('%Y-%m-%d'),
+     outfolder: Optional[str] = "raw/"
+ ) -> pathlib.Path:
+     """
+     Download Sentinel-2 imagery data using cubexpress and Earth Engine API.
+
+     Args:
+         lon (float): Longitude of the point of interest.
+         lat (float): Latitude of the point of interest.
+         cs_cdf (Optional[float]): Cloud mask threshold (default 0.6).
+         buffer_size (Optional[int]): Buffer size for image extraction (default 1280).
+         start_date (Optional[str]): Start date for image filtering (default "2015-01-01").
+         end_date (Optional[str]): End date for image filtering (default today's date).
+         outfolder (Optional[str]): Output folder to save images (default "raw/").
+
+     Returns:
+         pathlib.Path: Path to the folder where the data is stored.
+     """
+
+     # Initialize Earth Engine
+     ee.Initialize(project="ee-julius013199")
+
+     # Define point of interest
+     point = ee.Geometry.Point([lon, lat])
+
+     # Filter image collection by location and date
+     collection = ee.ImageCollection("COPERNICUS/S2_SR_HARMONIZED") \
+         .filterBounds(point) \
+         .filterDate(start_date, end_date)
+
+     # Get image IDs
+     image_ids = collection.aggregate_array('system:id').getInfo()
+
+     # Cloud mask function
+     def cloud_mask(image) -> ee.Image:
+         """Apply cloud mask to the image."""
+         return image.select('MSK_CLDPRB').lt(20)
+
+     # Apply cloud mask
+     collection = collection.map(cloud_mask)
+
+     # Generate geotransform for cubexpress
+     geotransform = cubexpress.lonlat2rt(lon=lon, lat=lat, edge_size=buffer_size, scale=10)
+
+     # Prepare requests for cubexpress
+     requests = [
+         cubexpress.Request(
+             id=f"s2test_{i}",
+             raster_transform=geotransform,
+             bands=["B4", "B3", "B2"],  # RGB bands
+             image=ee.Image(image_id).divide(10000)  # Adjust image scaling
+         )
+         for i, image_id in enumerate(image_ids)
+     ]
+
+     # Create request set
+     cube_requests = cubexpress.RequestSet(requestset=requests)
+
+     # Set output folder
+     output_path = pathlib.Path(outfolder)
+
+     # Download the data
+     cubexpress.getcube(
+         request=cube_requests,
+         output_path=output_path,
+         nworkers=4,
+         max_deep_level=5
+     )
+
+     return output_path
satcube/main.py CHANGED
@@ -7,7 +7,7 @@ import pandas as pd
  import torch
 
  from satcube.dataclass import Sensor
- from satcube.utils import (aligned_s2, cloudmasking_s2, display_images,
+ from satcube.utils_old import (aligned_s2, cloudmasking_s2, display_images,
                              gapfilling_s2, intermediate_process, interpolate_s2,
                              metadata_s2, monthly_composites_s2, smooth_s2, super_s2)
 
@@ -252,7 +252,7 @@ class SatCube:
          out_table["folder"] = out_folder
 
          return out_table
-
+
      def monthly_composites_s2(
          self,
          table: Optional[pd.DataFrame],