dist-s1-enumerator 1.0.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- dist_s1_enumerator/__init__.py +57 -0
- dist_s1_enumerator/asf.py +328 -0
- dist_s1_enumerator/constants.py +50 -0
- dist_s1_enumerator/data/jpl_burst_geo.parquet +0 -0
- dist_s1_enumerator/data/mgrs.parquet +0 -0
- dist_s1_enumerator/data/mgrs_burst_lookup_table.parquet +0 -0
- dist_s1_enumerator/dist_enum.py +425 -0
- dist_s1_enumerator/dist_enum_inputs.py +138 -0
- dist_s1_enumerator/exceptions.py +2 -0
- dist_s1_enumerator/mgrs_burst_data.py +170 -0
- dist_s1_enumerator/param_models.py +100 -0
- dist_s1_enumerator/py.typed +0 -0
- dist_s1_enumerator/rtc_s1_io.py +142 -0
- dist_s1_enumerator/tabular_models.py +91 -0
- dist_s1_enumerator-1.0.8.dist-info/METADATA +295 -0
- dist_s1_enumerator-1.0.8.dist-info/RECORD +19 -0
- dist_s1_enumerator-1.0.8.dist-info/WHEEL +5 -0
- dist_s1_enumerator-1.0.8.dist-info/licenses/LICENSE +202 -0
- dist_s1_enumerator-1.0.8.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,170 @@
|
|
|
1
|
+
from functools import lru_cache
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
|
|
4
|
+
import geopandas as gpd
|
|
5
|
+
import pandas as pd
|
|
6
|
+
from shapely.geometry import Point, Polygon
|
|
7
|
+
|
|
8
|
+
from dist_s1_enumerator.exceptions import NoMGRSCoverage
|
|
9
|
+
from dist_s1_enumerator.tabular_models import burst_mgrs_lut_schema, burst_schema, mgrs_tile_schema, reorder_columns
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
# Absolute path of the package's bundled ``data`` directory (sibling of this module).
DATA_DIR = Path(__file__).resolve().parent.joinpath('data')
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def get_mgrs_burst_lut_path() -> Path:
    """Return the path of the bundled MGRS-tile/burst lookup-table parquet file."""
    return DATA_DIR / 'mgrs_burst_lookup_table.parquet'
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def get_mgrs_data_path() -> Path:
    """Return the path of the bundled MGRS tile parquet file."""
    return DATA_DIR / 'mgrs.parquet'
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def get_burst_data_path() -> Path:
    """Return the path of the bundled JPL burst footprint parquet file."""
    return DATA_DIR / 'jpl_burst_geo.parquet'
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def get_burst_table(burst_ids: list[str] | str | None = None) -> gpd.GeoDataFrame:
    """Load burst footprints from the bundled burst parquet file.

    Args:
        burst_ids: ``None`` loads every burst. A single JPL burst id or a list of ids restricts
            the read to those bursts via a parquet row filter on ``jpl_burst_id``.

    Returns:
        Validated burst rows, columns in ``burst_schema`` order, with a fresh 0..n-1 index.

    Raises:
        ValueError: If specific ids were requested and none of them are in the table.
    """
    source = get_burst_data_path()
    if burst_ids is None:
        df_bursts = gpd.read_parquet(source)
    else:
        requested = [burst_ids] if isinstance(burst_ids, str) else burst_ids
        df_bursts = gpd.read_parquet(source, filters=[('jpl_burst_id', 'in', requested)])
        if df_bursts.empty:
            burst_ids_str = ', '.join(str(burst_id) for burst_id in requested)
            raise ValueError(f'No burst data found for {burst_ids_str}.')
    burst_schema.validate(df_bursts)
    return reorder_columns(df_bursts, burst_schema).reset_index(drop=True)
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
@lru_cache
def get_mgrs_burst_lut() -> pd.DataFrame:
    """Load the full MGRS-tile/burst lookup table (LUT) from the bundled parquet file.

    The table is read with ``pd.read_parquet`` and ``burst_mgrs_lut_schema`` defines no geometry
    column, so the result is a plain pandas DataFrame rather than a GeoDataFrame.

    The result is memoized by ``lru_cache``: every call returns the *same* DataFrame object, so
    callers must not modify it in place (take a ``.copy()`` first).

    Returns:
        Validated LUT, columns in ``burst_mgrs_lut_schema`` order, with a fresh 0..n-1 index.
    """
    parquet_path = get_mgrs_burst_lut_path()
    df = pd.read_parquet(parquet_path)
    burst_mgrs_lut_schema.validate(df)
    df = reorder_columns(df, burst_mgrs_lut_schema)
    return df.reset_index(drop=True)
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def get_lut_by_mgrs_tile_ids(mgrs_tile_ids: str | list[str]) -> pd.DataFrame:
    """Read the MGRS-tile/burst lookup-table rows for the given MGRS tile ids.

    Only the requested tiles are read from the bundled parquet file (row filter on
    ``mgrs_tile_id``). The table is read with ``pd.read_parquet`` and ``burst_mgrs_lut_schema``
    has no geometry column, so the result is a plain pandas DataFrame.

    Args:
        mgrs_tile_ids: One MGRS tile id or a list of them.

    Returns:
        Validated LUT rows, columns in ``burst_mgrs_lut_schema`` order, with a fresh 0..n-1 index.

    Raises:
        ValueError: If none of the tile ids are present in the LUT.
    """
    if isinstance(mgrs_tile_ids, str):
        mgrs_tile_ids = [mgrs_tile_ids]
    parquet_path = get_mgrs_burst_lut_path()
    filters = [('mgrs_tile_id', 'in', mgrs_tile_ids)]
    df_mgrs_burst_lut = pd.read_parquet(parquet_path, filters=filters)
    if df_mgrs_burst_lut.empty:
        mgrs_tile_ids_str = ', '.join(map(str, mgrs_tile_ids))
        raise ValueError(f'No LUT data found for MGRS tile ids {mgrs_tile_ids_str}.')
    burst_mgrs_lut_schema.validate(df_mgrs_burst_lut)
    df_mgrs_burst_lut = reorder_columns(df_mgrs_burst_lut, burst_mgrs_lut_schema)
    return df_mgrs_burst_lut.reset_index(drop=True)
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
@lru_cache
def get_mgrs_table() -> gpd.GeoDataFrame:
    """Load the bundled MGRS tile table, validated and ordered per ``mgrs_tile_schema``.

    The result is memoized by ``lru_cache``: every call returns the *same* GeoDataFrame object,
    so callers must not modify it in place.
    """
    df_tiles = gpd.read_parquet(get_mgrs_data_path())
    mgrs_tile_schema.validate(df_tiles)
    return reorder_columns(df_tiles, mgrs_tile_schema)
|
|
77
|
+
|
|
78
|
+
|
|
79
|
+
def get_mgrs_tile_table_by_ids(mgrs_tile_ids: str | list[str]) -> gpd.GeoDataFrame:
    """Select rows of the bundled MGRS tile table by tile id.

    Args:
        mgrs_tile_ids: One MGRS tile id or a list of them. A bare string is treated as a single id.

    Returns:
        Matching tiles, in the order they appear in the MGRS table, with a fresh 0..n-1 index.
        Ids that are not present are silently ignored as long as at least one id matches.

    Raises:
        ValueError: If none of the ids are found.
    """
    if isinstance(mgrs_tile_ids, str):
        mgrs_tile_ids = [mgrs_tile_ids]
    df_mgrs = get_mgrs_table()
    ind = df_mgrs.mgrs_tile_id.isin(mgrs_tile_ids)
    if not ind.any():
        mgrs_tile_ids_str = ', '.join(map(str, mgrs_tile_ids))
        raise ValueError(f'No MGRS tile data found for {mgrs_tile_ids_str}.')
    # Boolean indexing yields a new frame, so the cached table from get_mgrs_table is not mutated.
    return df_mgrs[ind].reset_index(drop=True)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def get_mgrs_tiles_overlapping_geometry(geometry: Polygon | Point) -> gpd.GeoDataFrame:
    """Return the MGRS tiles whose footprints intersect ``geometry``.

    NOTE(review): ``geometry`` is compared directly against the MGRS table's geometries, so it is
    presumably expected in the same coordinate system as that table - confirm.

    Returns:
        Intersecting tiles, validated against ``mgrs_tile_schema`` and ordered by it, with a fresh
        0..n-1 index.

    Raises:
        NoMGRSCoverage: If no tile in the bundled MGRS table intersects ``geometry``.
    """
    df_mgrs = get_mgrs_table()
    intersecting = df_mgrs.intersects(geometry)
    if not intersecting.any():
        raise NoMGRSCoverage(
            'We only have MGRS tiles that overlap with DIST-HLS products (this is slightly less than Sentinel-2). '
        )
    df_hits = df_mgrs[intersecting].reset_index(drop=True)
    mgrs_tile_schema.validate(df_hits)
    return reorder_columns(df_hits, mgrs_tile_schema)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
def get_burst_ids_in_mgrs_tiles(mgrs_tile_ids: list[str] | str, track_numbers: list[int] | None = None) -> list[str]:
    """Get all the burst ids in the provided MGRS tiles.

    If track numbers are provided, then for each MGRS tile this finds the single
    ``acq_group_id_within_mgrs_tile`` (pass) matched by those tracks. It returns all burst ids of
    that group *within that tile*.

    Args:
        mgrs_tile_ids: One MGRS tile id or a list of them.
        track_numbers: Optional track numbers (at most 2) used to select a single pass per tile.

    Returns:
        Unique JPL burst ids, in order of first appearance.

    Raises:
        ValueError: If no LUT rows exist for the tiles, if more than 2 track numbers are given,
            if a tile has no LUT rows for the given tracks, or if the tracks map to more than one
            ``acq_group_id_within_mgrs_tile`` within a tile.
    """
    df_mgrs_burst_luts = get_lut_by_mgrs_tile_ids(mgrs_tile_ids)
    if isinstance(mgrs_tile_ids, str):
        mgrs_tile_ids = [mgrs_tile_ids]
    if track_numbers is not None:
        if len(track_numbers) > 2:
            raise ValueError(
                'More than 2 track numbers provided. When track numbers are provided, we select data from a single '
                'pass so this is an invalid input.'
            )
        tile_data = []
        for mgrs_tile_id in mgrs_tile_ids:
            in_tile = df_mgrs_burst_luts.mgrs_tile_id == mgrs_tile_id
            ind_temp = in_tile & df_mgrs_burst_luts.track_number.isin(track_numbers)
            df_lut_temp = df_mgrs_burst_luts[ind_temp].reset_index(drop=True)
            if df_lut_temp.empty:
                mgrs_tile_ids_str = ', '.join(map(str, mgrs_tile_ids))
                track_numbers_str = ', '.join(map(str, track_numbers))
                available_track_numbers = df_mgrs_burst_luts[in_tile].track_number.unique().tolist()
                available_track_numbers_str = ', '.join(map(str, available_track_numbers))
                # The available tracks are those of the tile being checked, so name that tile.
                raise ValueError(
                    f'Mismatch - no LUT data found for MGRS tile ids {mgrs_tile_ids_str} '
                    f'and track numbers {track_numbers_str}. '
                    f'Available track numbers for tile {mgrs_tile_id} are {available_track_numbers_str}.'
                )
            acq_ids = df_lut_temp.acq_group_id_within_mgrs_tile.unique().tolist()
            if len(acq_ids) != 1:
                track_numbers_str = ', '.join(map(str, track_numbers))
                raise ValueError(
                    f'Multiple acq_group_id_within_mgrs_tile found for mgrs_tile_id {mgrs_tile_id} and '
                    f'track_numbers {track_numbers_str}.'
                )
            acq_id = acq_ids[0]
            # acq_group_id_within_mgrs_tile is only meaningful within one tile. Restrict to the
            # current tile so bursts from other requested tiles sharing the same group number are
            # not pulled in.
            in_pass = in_tile & (df_mgrs_burst_luts.acq_group_id_within_mgrs_tile == acq_id)
            tile_data.append(df_mgrs_burst_luts[in_pass].reset_index(drop=True))
        df_mgrs_burst_luts = pd.concat(tile_data, axis=0)
        # Remove duplicates if sequential track numbers are provided.
        df_mgrs_burst_luts = df_mgrs_burst_luts.drop_duplicates().reset_index(drop=True)

    df_mgrs_burst_luts = df_mgrs_burst_luts.drop_duplicates(subset=['jpl_burst_id', 'mgrs_tile_id'])
    burst_ids = df_mgrs_burst_luts.jpl_burst_id.unique().tolist()
    return burst_ids
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def get_burst_table_from_mgrs_tiles(mgrs_tile_ids: str | list[str]) -> gpd.GeoDataFrame:
    """Return burst footprints for every burst in the given MGRS tiles, annotated with LUT columns.

    The burst table is left-joined on ``jpl_burst_id`` with the LUT's ``track_number``,
    ``acq_group_id_within_mgrs_tile`` and ``mgrs_tile_id`` columns. A burst listed in several LUT
    rows (e.g. several tiles) therefore appears once per matching LUT row.

    Args:
        mgrs_tile_ids: One MGRS tile id or a list of them.

    Returns:
        Validated GeoDataFrame, columns in ``burst_schema`` order, with a fresh 0..n-1 index.

    Raises:
        ValueError: Propagated from the LUT / burst table readers when no data is found.
    """
    df_mgrs_burst_luts = get_lut_by_mgrs_tile_ids(mgrs_tile_ids)
    burst_ids = df_mgrs_burst_luts.jpl_burst_id.unique().tolist()
    df_burst = get_burst_table(burst_ids)
    # Calling merge on the GeoDataFrame (left side) keeps the result a GeoDataFrame.
    df_burst = df_burst.merge(
        df_mgrs_burst_luts[['jpl_burst_id', 'track_number', 'acq_group_id_within_mgrs_tile', 'mgrs_tile_id']],
        how='left',
        on='jpl_burst_id',
    )
    burst_schema.validate(df_burst)
    df_burst = reorder_columns(df_burst, burst_schema)
    return df_burst.reset_index(drop=True)
|
|
@@ -0,0 +1,100 @@
|
|
|
1
|
+
from pydantic import BaseModel, ValidationInfo, field_validator
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
class LookbackStrategyParams(BaseModel):
    """Pydantic model that validates and normalizes lookback strategy parameters.

    Supported strategies are ``'immediate_lookback'`` and ``'multi_window'``. For ``'multi_window'``:
    - ``max_pre_imgs_per_burst`` is normalized to a tuple (an integer is repeated 3 times).
    - An integer ``delta_lookback_days`` ``d`` becomes ``(d, 2d, ..., kd)``, one entry per window.

    The validators read earlier fields from ``info.data``, so they rely on the field declaration
    order below.
    """

    lookback_strategy: str
    max_pre_imgs_per_burst: int | list[int] | tuple[int, ...]
    delta_lookback_days: int | list[int] | tuple[int, ...]
    min_pre_imgs_per_burst: int
    delta_window_days: int

    @field_validator('delta_window_days')
    @classmethod
    def validate_delta_window_days(cls, v: int) -> int:
        """Reject ``delta_window_days`` above 365 (a value of exactly 365 is accepted)."""
        if v <= 365:
            return v
        raise ValueError('delta_window_days must be less than 365 days.')

    @field_validator('lookback_strategy')
    @classmethod
    def validate_lookback_strategy(cls, v: str) -> str:
        """Accept only the supported strategy names."""
        allowed_strategies = ['immediate_lookback', 'multi_window']
        if v in allowed_strategies:
            return v
        raise ValueError(f'lookback_strategy must be one of {allowed_strategies}, got {v}')

    @field_validator('max_pre_imgs_per_burst')
    @classmethod
    def validate_max_pre_imgs_per_burst(
        cls, v: int | list[int] | tuple[int, ...], info: ValidationInfo
    ) -> int | tuple[int, ...]:
        """Check and normalize ``max_pre_imgs_per_burst`` for the chosen strategy.

        ``'immediate_lookback'`` requires a single integer. For ``'multi_window'``, an integer becomes
        a 3-tuple of that integer and a list becomes a tuple. Any other case is returned unchanged.
        """
        strategy = info.data.get('lookback_strategy')
        is_sequence = isinstance(v, (list, tuple))
        if strategy == 'immediate_lookback' and is_sequence:
            raise ValueError('max_pre_imgs_per_burst must be a single integer for immediate lookback strategy.')
        if strategy != 'multi_window':
            return v
        if isinstance(v, int):
            return (v, v, v)
        return tuple(v) if isinstance(v, list) else v

    @field_validator('delta_lookback_days')
    @classmethod
    def validate_delta_lookback_days(
        cls, v: int | list[int] | tuple[int, ...], info: ValidationInfo
    ) -> int | tuple[int, ...]:
        """Check and normalize ``delta_lookback_days`` against the strategy and window count.

        - ``'immediate_lookback'``: the value must be exactly 0.
        - ``'multi_window'``: an integer ``d`` expands to ``(d, 2d, ..., kd)``. Here ``k`` is the
          length of ``max_pre_imgs_per_burst`` when that is a sequence, else 3.
        - ``'multi_window'``: a list becomes a tuple, whose length must then match
          ``max_pre_imgs_per_burst`` when that is a sequence.
        """
        strategy = info.data.get('lookback_strategy')
        max_pre_imgs = info.data.get('max_pre_imgs_per_burst')
        max_is_sequence = isinstance(max_pre_imgs, (list, tuple))

        if strategy == 'immediate_lookback':
            if v != 0:
                raise ValueError('delta_lookback_days must be 0 for immediate lookback strategy.')
            return v

        if strategy != 'multi_window':
            return v

        if isinstance(v, int):
            # Fall back to 3 windows if max_pre_imgs_per_burst is not (yet) a sequence.
            n_windows = len(max_pre_imgs) if max_is_sequence else 3
            v = tuple(v * k for k in range(1, n_windows + 1))
        elif isinstance(v, list):
            v = tuple(v)

        if max_is_sequence and len(v) != len(max_pre_imgs):
            raise ValueError(
                'max_pre_imgs_per_burst and delta_lookback_days must have the same length. '
                'If max_pre_imgs_per_burst is a single integer, this is interpreted as the maximum '
                'number of pre-images on 3 anniversary dates so ensure that `delta_lookback_days` '
                'is a tuple of length 3 or an integer.'
            )
        return v

    @field_validator('min_pre_imgs_per_burst')
    @classmethod
    def validate_min_pre_imgs_per_burst(cls, v: int, info: ValidationInfo) -> int:
        """Check that no ``max_pre_imgs_per_burst`` value is below ``min_pre_imgs_per_burst``.

        A maximum equal to the minimum is accepted.
        """
        max_pre_imgs = info.data.get('max_pre_imgs_per_burst')
        strategy = info.data.get('lookback_strategy')
        scalar_too_small = isinstance(max_pre_imgs, int) and max_pre_imgs < v

        if strategy == 'multi_window' and isinstance(max_pre_imgs, (list, tuple)):
            if any(m < v for m in max_pre_imgs):
                raise ValueError('All values in max_pre_imgs_per_burst must be greater than min_pre_imgs_per_burst')
        if strategy in ('immediate_lookback', 'multi_window') and scalar_too_small:
            raise ValueError('max_pre_imgs_per_burst must be greater than min_pre_imgs_per_burst')
        return v
|
|
File without changes
|
|
@@ -0,0 +1,142 @@
|
|
|
1
|
+
import concurrent.futures
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
|
|
4
|
+
import geopandas as gpd
|
|
5
|
+
import requests
|
|
6
|
+
from pandera.pandas import check_input
|
|
7
|
+
from rasterio.errors import RasterioIOError
|
|
8
|
+
from requests.exceptions import HTTPError, RequestException, Timeout
|
|
9
|
+
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
|
|
10
|
+
from tqdm.auto import tqdm
|
|
11
|
+
|
|
12
|
+
from dist_s1_enumerator.tabular_models import rtc_s1_schema
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def generate_rtc_s1_local_paths(
|
|
16
|
+
urls: list[str], data_dir: Path | str, track_token: list, date_tokens: list[str], mgrs_tokens: list[str]
|
|
17
|
+
) -> list[Path]:
|
|
18
|
+
data_dir = Path(data_dir)
|
|
19
|
+
data_dir.mkdir(parents=True, exist_ok=True)
|
|
20
|
+
|
|
21
|
+
n = len(urls)
|
|
22
|
+
bad_data = [
|
|
23
|
+
(input_name, len(data))
|
|
24
|
+
for (input_name, data) in zip(
|
|
25
|
+
['urls', 'date_tokens', 'mgrs_tokens', 'track_token'], [urls, date_tokens, mgrs_tokens, track_token]
|
|
26
|
+
)
|
|
27
|
+
if len(data) != n
|
|
28
|
+
]
|
|
29
|
+
if bad_data:
|
|
30
|
+
raise ValueError(f'Number of {bad_data[0][0]} (which is {bad_data[0][1]}) must match the number of URLs ({n}).')
|
|
31
|
+
|
|
32
|
+
dst_dirs = [
|
|
33
|
+
data_dir / mgrs_token / track_token / date_token
|
|
34
|
+
for (mgrs_token, track_token, date_token) in zip(mgrs_tokens, track_token, date_tokens)
|
|
35
|
+
]
|
|
36
|
+
[dst_dir.mkdir(parents=True, exist_ok=True) for dst_dir in dst_dirs]
|
|
37
|
+
|
|
38
|
+
local_paths = [dst_dir / url.split('/')[-1] for (dst_dir, url) in zip(dst_dirs, urls)]
|
|
39
|
+
return local_paths
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def append_local_paths(df_rtc_ts: gpd.GeoDataFrame, data_dir: Path | str) -> gpd.GeoDataFrame:
    """Return a copy of ``df_rtc_ts`` with local destination paths for both polarizations.

    Paths are built by ``generate_rtc_s1_local_paths`` from the ``mgrs_tile_id``,
    ``track_token`` and ``acq_date_for_mgrs_pass`` columns, which also creates the destination
    directories. The input frame itself is not modified.

    Returns:
        Copy of ``df_rtc_ts`` with ``loc_path_copol`` and ``loc_path_crosspol`` columns holding
        ``Path`` objects.
    """
    copol_urls = df_rtc_ts['url_copol'].tolist()
    crosspol_urls = df_rtc_ts['url_crosspol'].tolist()
    track_tokens = df_rtc_ts['track_token'].tolist()
    date_tokens = df_rtc_ts['acq_date_for_mgrs_pass'].tolist()
    mgrs_tokens = df_rtc_ts['mgrs_tile_id'].tolist()

    out_paths_copol = generate_rtc_s1_local_paths(copol_urls, data_dir, track_tokens, date_tokens, mgrs_tokens)
    out_paths_crosspol = generate_rtc_s1_local_paths(crosspol_urls, data_dir, track_tokens, date_tokens, mgrs_tokens)
    df_out = df_rtc_ts.copy()
    df_out['loc_path_copol'] = out_paths_copol
    df_out['loc_path_crosspol'] = out_paths_crosspol
    return df_out
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def create_download_session(max_workers: int = 5) -> requests.Session:
    """Build a ``requests.Session`` configured for concurrent file downloads.

    The session sends a ``dist-s1-enumerator/1.0`` User-Agent and mounts a single ``HTTPAdapter``
    for both ``http://`` and ``https://``. The adapter's ``pool_maxsize`` is ``2 * max_workers``
    clamped to the range [10, 50]. Adapter-level retries are disabled (``max_retries=0``)
    because retries are handled with tenacity.

    Args:
        max_workers: Number of concurrent download threads (used to size connection pool)

    Returns:
        The configured session.
    """
    session = requests.Session()
    session.headers.update({'User-Agent': 'dist-s1-enumerator/1.0'})

    adapter = requests.adapters.HTTPAdapter(
        pool_connections=10,
        pool_maxsize=min(50, max(10, 2 * max_workers)),
        max_retries=0,  # retries are handled with tenacity
    )
    for scheme in ('http://', 'https://'):
        session.mount(scheme, adapter)
    return session
|
|
78
|
+
|
|
79
|
+
|
|
80
|
+
@retry(
    retry=retry_if_exception_type((ConnectionError, HTTPError, RasterioIOError, Timeout, RequestException)),
    stop=stop_after_attempt(5),
    wait=wait_exponential(multiplier=1, min=1, max=10),
    reraise=True,
)
def localize_one_rtc(url: str, out_path: Path, session: requests.Session | None = None) -> Path:
    """Download a single RTC file with retry logic.

    If ``out_path`` already exists, nothing is downloaded. Otherwise the file is streamed into a
    sibling ``<name>.part`` file, which is renamed to ``out_path`` only after the download
    completes.

    Retries (up to 5 attempts with exponential backoff) are applied by the ``@retry`` decorator
    for the listed exception types. The final exception is re-raised.

    Args:
        url: Remote URL to download.
        out_path: Destination path.
        session: Optional shared session. If ``None``, a session is created for this attempt and
            closed before returning.

    Returns:
        ``out_path``.
    """
    if out_path.exists():
        return out_path

    owns_session = session is None
    if owns_session:
        session = create_download_session()

    # Write to a temporary sibling and rename only on success. An interrupted download
    # (including KeyboardInterrupt, which `except Exception` would not catch) then never leaves a
    # truncated file at out_path for the exists() short-circuit above to treat as complete.
    tmp_path = out_path.with_name(out_path.name + '.part')
    try:
        with session.get(url, stream=True, timeout=30) as r:
            r.raise_for_status()
            out_path.parent.mkdir(parents=True, exist_ok=True)
            with tmp_path.open('wb') as f:
                for chunk in r.iter_content(chunk_size=16384):
                    if chunk:  # filter out keep-alive chunks
                        f.write(chunk)
        tmp_path.replace(out_path)
    finally:
        # Removes a partial download on failure; a no-op after a successful rename.
        tmp_path.unlink(missing_ok=True)
        if owns_session:
            session.close()
    return out_path
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
@check_input(rtc_s1_schema, 0)
def localize_rtc_s1_ts(
    df_rtc_ts: gpd.GeoDataFrame,
    data_dir: Path | str,
    max_workers: int = 5,
    tqdm_enabled: bool = True,
) -> gpd.GeoDataFrame:
    """Download the copol and crosspol files for every row of an RTC-S1 time-series table.

    The input is validated against ``rtc_s1_schema`` by the ``check_input`` decorator. Local
    paths come from ``append_local_paths``. Files are fetched concurrently with ``max_workers``
    threads sharing one download session, and files that already exist are skipped by
    ``localize_one_rtc``.

    Args:
        df_rtc_ts: RTC-S1 metadata table.
        data_dir: Root directory for the downloaded files.
        max_workers: Number of download threads (also sizes the session's connection pool).
        tqdm_enabled: Show a progress bar when True.

    Returns:
        Copy of ``df_rtc_ts`` with ``loc_path_copol`` and ``loc_path_crosspol`` string columns.
    """
    df_out = append_local_paths(df_rtc_ts, data_dir)
    urls = df_out['url_copol'].tolist() + df_out['url_crosspol'].tolist()
    out_paths = df_out['loc_path_copol'].tolist() + df_out['loc_path_crosspol'].tolist()

    # Shared session for connection pooling, sized for the concurrent workers. It is closed once
    # all downloads finish or one of them fails.
    with create_download_session(max_workers) as session:

        def localize_one_rtc_with_session(data: tuple) -> Path:
            url, out_path = data
            return localize_one_rtc(url, out_path, session)

        disable_tqdm = not tqdm_enabled
        with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
            # Consuming the iterator surfaces the first download exception, if any.
            _ = list(
                tqdm(
                    executor.map(localize_one_rtc_with_session, zip(urls, out_paths)),
                    total=len(urls),
                    disable=disable_tqdm,
                    desc='Downloading RTC-S1 burst data',
                    dynamic_ncols=True,
                )
            )
    # For serialization
    df_out['loc_path_copol'] = df_out['loc_path_copol'].astype(str)
    df_out['loc_path_crosspol'] = df_out['loc_path_crosspol'].astype(str)
    return df_out
|
|
@@ -0,0 +1,91 @@
|
|
|
1
|
+
import geopandas as gpd
|
|
2
|
+
from pandera.engines.pandas_engine import DateTime
|
|
3
|
+
from pandera.pandas import Column, DataFrameSchema
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
# Burst footprint table. Only jpl_burst_id and geometry are required; the optional columns are
# filled when burst rows are joined with the MGRS/burst LUT (see get_burst_table_from_mgrs_tiles).
# Column order in every schema below matters: reorder_columns emits columns in schema order.
burst_schema = DataFrameSchema(
    {
        'jpl_burst_id': Column(str, required=True),
        'track_number': Column(int, required=False),
        'acq_group_id_within_mgrs_tile': Column(int, required=False),
        'mgrs_tile_id': Column(str, required=False),
        'geometry': Column('geometry', required=True),
    }
)

# MGRS tile table, including each tile's UTM EPSG code and WKT.
mgrs_tile_schema = DataFrameSchema(
    {
        'mgrs_tile_id': Column(str, required=True),
        'utm_epsg': Column(int, required=True),
        'utm_wkt': Column(str, required=True),
        'geometry': Column('geometry', required=True),
    }
)

# Response schema from ASF DAAC API
rtc_s1_resp_schema = DataFrameSchema(
    {
        'opera_id': Column(str, required=True),
        'jpl_burst_id': Column(str, required=True),
        'acq_dt': Column(DateTime(tz='UTC'), required=True),
        'acq_date_for_mgrs_pass': Column(str, required=False),
        'polarizations': Column(str, required=True),
        'track_number': Column(int, required=True),
        # Integer number of 6 day periods since 2014-01-01
        'pass_id': Column(int, required=True),
        'url_crosspol': Column(str, required=True),
        'url_copol': Column(str, required=True),
        'geometry': Column('geometry', required=True),
    }
)

# Schema for RTC-S1 metadata with MGRS tile and acq group id appended
# Note: a single burst product may be associated with multiple MGRS tiles and acq group_ids
# NOTE(review): 'geometry' is re-declared in each add_columns extension below; presumably this is
# meant to keep it as the last column for reorder_columns - confirm against pandera semantics.
rtc_s1_schema = rtc_s1_resp_schema.add_columns(
    {
        'mgrs_tile_id': Column(str, required=True),
        'acq_group_id_within_mgrs_tile': Column(int, required=True),
        'track_token': Column(str, required=True),
        'geometry': Column('geometry', required=True),
    }
)

# Schema for inputs to dist-s1 workflow
dist_s1_input_schema = rtc_s1_schema.add_columns(
    {
        'input_category': Column(str, required=True),
        'product_id': Column(int, required=False),
        'geometry': Column('geometry', required=True),
    }
)

# Schema for localized inputs (adds the local file paths, stored as strings)
dist_s1_loc_input_schema = dist_s1_input_schema.add_columns(
    {
        'loc_path_copol': Column(str, required=True),
        'loc_path_crosspol': Column(str, required=True),
        'geometry': Column('geometry', required=True),
    }
)

# MGRS tile <-> burst lookup table. Note there is no geometry column, so this table is handled as a
# plain pandas DataFrame.
burst_mgrs_lut_schema = DataFrameSchema(
    {
        'jpl_burst_id': Column(str, required=True),
        'mgrs_tile_id': Column(str, required=True),
        'track_number': Column(int, required=True),
        'acq_group_id_within_mgrs_tile': Column(int, required=True),
        'orbit_pass': Column(str, required=True),
        'area_per_acq_group_km2': Column(int, required=True),
        'n_bursts_per_acq_group': Column(int, required=True),
    }
)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def reorder_columns(df: gpd.GeoDataFrame, schema: DataFrameSchema) -> gpd.GeoDataFrame:
    """Return ``df`` restricted to the schema's columns, in schema order.

    - Non-empty ``df``: columns not in ``schema`` are dropped, and schema columns missing from
      ``df`` are skipped.
    - Empty ``df``: it is replaced by an empty GeoDataFrame that has every schema column. If the
      schema has a ``geometry`` column, that empty frame gets CRS EPSG:4326.

    Args:
        df: Input (Geo)DataFrame.
        schema: Pandera schema whose column order is applied.

    Returns:
        The reordered frame.
    """
    # Materialize the keys; passing a dict_keys view as ``columns`` is fragile.
    schema_cols = list(schema.columns.keys())
    if not df.empty:
        return df[[col for col in schema_cols if col in df.columns]]
    df_empty = gpd.GeoDataFrame(columns=schema_cols)
    if 'geometry' in schema_cols:
        # set_crs returns a new frame by default (inplace=False), so the result must be kept;
        # discarding it (as before) left the empty frame without a CRS.
        df_empty = df_empty.set_crs(epsg=4326)
    return df_empty
|