satcube 0.1.16__py3-none-any.whl → 0.1.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of satcube has been flagged as possibly problematic.

satcube/__init__.py CHANGED
@@ -1,9 +1,7 @@
- from satcube.cloud_detection import cloud_masking
  from satcube.download import download
- from satcube.align import align
- import importlib.metadata
  from satcube.objects import SatCubeMetadata
+ import importlib.metadata

- __all__ = ["cloud_masking", "download", "align", "SatCubeMetadata"]
+ __all__ = ["download", "SatCubeMetadata"]
  # __version__ = importlib.metadata.version("satcube")

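For reference, a minimal caller-side sketch (not part of the package) of imports against the narrowed 0.1.18 surface; the align_fn import assumes the rename shown in satcube/align.py below:

# Hypothetical imports for 0.1.18.
from satcube import download, SatCubeMetadata

# "align" is no longer re-exported at the top level; the module-level
# function is now align_fn (see satcube/align.py below).
from satcube.align import align_fn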
satcube/align.py CHANGED
@@ -1,73 +1,91 @@
  from __future__ import annotations

  import pathlib
- from typing import List, Tuple
- import pickle
+ from typing import Tuple
  import pandas as pd
  import satalign
- import shutil
-
  import numpy as np
  import rasterio as rio
- import xarray as xr
- from affine import Affine
  from concurrent.futures import ThreadPoolExecutor, as_completed
  from tqdm import tqdm

+ def _process_image(
+     image: np.ndarray,
+     reference: np.ndarray,
+     profile: dict,
+     output_path: pathlib.Path,
+ ) -> Tuple[float, float]:
+
+     image_float = image.astype(np.float32) / 10000
+     image_float = image_float[np.newaxis, ...]
+
+     image, M = satalign.LGM(
+         datacube=image_float,
+         reference=reference
+     ).run_multicore()
+
+     image = (image * 10000).astype(np.uint16).squeeze()

- def process_row(row: pd.Series, reference: np.ndarray, input_dir: pathlib.Path, output_dir: pathlib.Path) -> None:
+     with rio.open(output_path, "w", **profile) as dst:
+         dst.write(image)
+
+     return M[0][0, 2], M[0][1, 2]
+
+ def _process_row(
+     row: pd.Series,
+     reference: np.ndarray,
+     input_dir: pathlib.Path,
+     output_dir: pathlib.Path
+ ) -> Tuple[str, float, float]:
+
      row_path = input_dir / (row["id"] + ".tif")
      output_path = output_dir / (row["id"] + ".tif")
-     with rio.open(row_path) as src:
-         row_image = src.read()
-         profile_image = src.profile

-     row_image_float = row_image.astype(np.float32) / 10000
-     row_image_float = row_image_float[np.newaxis, ...]
+     with rio.open(row_path) as src:
+         image = src.read()
+         profile = src.profile

-     pcc_model = satalign.LGM(
-         datacube = row_image_float,
-         reference = reference
+     dx_px, dy_px = _process_image(
+         image=image,
+         reference=reference,
+         profile=profile,
+         output_path=output_path
      )
-     image, _ = pcc_model.run_multicore()
-     image = (image * 10000).astype(np.uint16).squeeze()
-
-     with rio.open(output_path, "w", **profile_image) as dst:
-         dst.write(image)
+
+     return row["id"], dx_px, dy_px

- def align(
+ def align_fn(
+     metadata: pd.DataFrame | None = None,
      input_dir: str | pathlib.Path = "raw",
      output_dir: str | pathlib.Path = "aligned",
      nworks: int = 4,
      cache: bool = False
- ) -> None:
+ ) -> pd.DataFrame | None:

      input_dir = pathlib.Path(input_dir).expanduser().resolve()
      output_dir = pathlib.Path(output_dir).expanduser().resolve()
      output_dir.mkdir(parents=True, exist_ok=True)
-
-     metadata_path = input_dir / "metadata.csv"
-
-     if not metadata_path.exists():
+
+     if metadata is None:
          raise FileNotFoundError(
-             f"Metadata file not found: {metadata_path}. "
+             f"Add metadata file to do alignment."
              "Please run the download step first."
          )
-     else:
-         metadata = pd.read_csv(metadata_path)
-
-     if cache:
-         exist_files = [file.stem for file in output_dir.glob("*.tif")]
-         metadata = metadata[~metadata["id"].isin(exist_files)]

-     if metadata.empty:
-         return

      id_reference = metadata.sort_values(
-         by=["cs_cdf", "date"],
+         by=["cs_cdf"],
          ascending=False,
      ).iloc[0]["id"]

+     df = metadata.copy()
+
+     if cache:
+         exist_files = [file.stem for file in output_dir.glob("*.tif")]
+         df = df[~df["id"].isin(exist_files)]
+         if df.empty:
+             return metadata
+
      reference_path = input_dir / (id_reference + ".tif")

      with rio.open(reference_path) as ref_src:
@@ -75,11 +93,20 @@ def align(

      reference_float = reference.astype(np.float32) / 10000

+     results = []
+
      with ThreadPoolExecutor(max_workers=nworks) as executor:
          futures = {
-             executor.submit(process_row, row, reference_float, input_dir, output_dir)
-             for _, row in metadata.iterrows()
+             executor.submit(
+                 _process_row,
+                 row=row,
+                 reference=reference_float,
+                 input_dir=input_dir,
+                 output_dir=output_dir
+             ): row["id"]
+             for _, row in df.iterrows()
          }
+
          for future in tqdm(
              as_completed(futures),
              total=len(futures),
@@ -88,11 +115,25 @@ def align(
              leave=True
          ):
              try:
-                 future.result()
+                 img_id, dx_px, dy_px = future.result()
+                 results.append({"id": img_id,
+                                 "dx_px": dx_px,
+                                 "dy_px": dy_px})
              except Exception as e:
-                 print(f"Error processing image: {e}")
-
-     metadata = input_dir / "metadata.csv"
-     if metadata.exists():
-         metadata_dst = output_dir / "metadata.csv"
-         shutil.copy(metadata, metadata_dst)
+                 print(f"Error processing image: {e} {futures[future]}")
+
+     shift_df = pd.DataFrame(results)
+
+     metadata = metadata.drop(
+         columns=["dx_px","dy_px"],
+         errors="ignore"
+     )
+
+     metadata = metadata.merge(
+         shift_df,
+         on="id",
+         how="left",
+         suffixes=('', '')
+     )
+
+     return metadata
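For context, a minimal sketch (assumed usage, not from the package) of calling the reworked align_fn: it now takes the metadata DataFrame directly instead of reading metadata.csv, and returns that DataFrame with per-image pixel shifts merged in. The values below are placeholders; only the "id" and "cs_cdf" columns are required by the code above:

import pandas as pd
from satcube.align import align_fn

# Placeholder metadata; in practice this comes from the download step.
# Each "id" must match a <id>.tif file inside input_dir.
metadata = pd.DataFrame({
    "id": ["scene_001", "scene_002"],
    "cs_cdf": [0.98, 0.91],  # cloud score used to pick the reference image
})

# Returns metadata with dx_px/dy_px shift columns merged in,
# instead of copying metadata.csv to output_dir as 0.1.16 did.
metadata = align_fn(metadata=metadata, input_dir="raw", output_dir="aligned")
print(metadata[["id", "dx_px", "dy_px"]])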
@@ -0,0 +1,23 @@
+ import torch
+
+ class LandsatCloudDetector(torch.nn.Module):
+     def __init__(self):
+         super().__init__()
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         # Define bit flags for clouds based on the
+         # Landsat QA band documentation
+         cloud_flags = (1 << 3) | (1 << 4) | (1 << 1)
+
+         ## Get the QA band
+         qa_band = x[6]
+         mask_band = x[:6].mean(axis=0)
+         mask_band[~torch.isnan(mask_band)] = 1
+
+         ## Create a cloud mask
+         cloud_mask = torch.bitwise_and(qa_band.int(), cloud_flags) == 0
+         cloud_mask = cloud_mask.float()
+         cloud_mask[cloud_mask == 0] = torch.nan
+         cloud_mask[cloud_mask == 0] = 1
+         final_mask = cloud_mask * mask_band
+         return final_mask
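A minimal sketch (assumed usage) of the new module: the band layout, six reflectance bands followed by the Landsat QA bitmask at index 6, is inferred from the indexing above and not documented in the diff:

import torch

detector = LandsatCloudDetector()

# Placeholder 7-band input: bands 0-5 are reflectance, band 6 is the QA bitmask.
x = torch.rand(7, 256, 256)
x[6] = torch.randint(0, 2**16, (256, 256)).float()

mask = detector(x)  # 1.0 where clear, NaN where the QA bits flag clouds
print(mask.shape, torch.isnan(mask).float().mean())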
@@ -0,0 +1,39 @@
+ import pathlib
+ from datetime import datetime
+ from typing import List, Optional
+
+ import pydantic
+
+
+ class Sensor(pydantic.BaseModel):
+     start_date: str
+     end_date: str
+     edge_size: int
+     bands: List[str]
+
+
+ class Sentinel2(Sensor):
+     weight_path: pathlib.Path
+     start_date: Optional[str] = "2015-06-27"
+     end_date: Optional[str] = datetime.now().strftime("%Y-%m-%d")
+     resolution: Optional[int] = 10
+     edge_size: Optional[int] = 384
+     embedding_universal: Optional[str] = "s2_embedding_model_universal.pt"
+     cloud_model_universal: str = "s2_cloud_model_universal.pt"
+     cloud_model_specific: str = "s2_cloud_model_specific.pt"
+     super_model_specific: str = "s2_super_model_specific.pt"
+     bands: List[str] = [
+         "B01",
+         "B02",
+         "B03",
+         "B04",
+         "B05",
+         "B06",
+         "B07",
+         "B08",
+         "B8A",
+         "B09",
+         "B10",
+         "B11",
+         "B12",
+     ]
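A minimal sketch (assumed usage) of the new settings model; weight_path is the only field without a default, so it must be supplied, and the path below is a placeholder:

# Assuming Sentinel2 is importable from its module in the package.
# Everything except weight_path falls back to the defaults above.
s2 = Sentinel2(weight_path="weights/s2.pt")  # placeholder path
print(s2.start_date, s2.resolution, len(s2.bands))  # 2015-06-27 10 13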