dist_s1_enumerator-1.0.8-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,171 @@
+ from functools import lru_cache
+ from pathlib import Path
+
+ import geopandas as gpd
+ import pandas as pd
+ from shapely.geometry import Point, Polygon
+
+ from dist_s1_enumerator.exceptions import NoMGRSCoverage
+ from dist_s1_enumerator.tabular_models import burst_mgrs_lut_schema, burst_schema, mgrs_tile_schema, reorder_columns
+
+
+ DATA_DIR = Path(__file__).resolve().parent / 'data'
+
+
+ def get_mgrs_burst_lut_path() -> Path:
+     parquet_path = DATA_DIR / 'mgrs_burst_lookup_table.parquet'
+     return parquet_path
+
+
+ def get_mgrs_data_path() -> Path:
+     parquet_path = DATA_DIR / 'mgrs.parquet'
+     return parquet_path
+
+
+ def get_burst_data_path() -> Path:
+     parquet_path = DATA_DIR / 'jpl_burst_geo.parquet'
+     return parquet_path
+
+
+ def get_burst_table(burst_ids: list[str] | str | None = None) -> gpd.GeoDataFrame:
+     parquet_path = get_burst_data_path()
+     if burst_ids is None:
+         df = gpd.read_parquet(parquet_path)
+     else:
+         if isinstance(burst_ids, str):
+             burst_ids = [burst_ids]
+         filters = [('jpl_burst_id', 'in', burst_ids)]
+         df = gpd.read_parquet(parquet_path, filters=filters)
+         if df.empty:
+             burst_ids_str = ', '.join(map(str, burst_ids))
+             raise ValueError(f'No burst data found for {burst_ids_str}.')
+     burst_schema.validate(df)
+     df = reorder_columns(df, burst_schema)
+     return df.reset_index(drop=True)
+
+
+ @lru_cache
+ def get_mgrs_burst_lut() -> pd.DataFrame:
+     parquet_path = get_mgrs_burst_lut_path()
+     df = pd.read_parquet(parquet_path)
+     burst_mgrs_lut_schema.validate(df)
+     df = reorder_columns(df, burst_mgrs_lut_schema)
+     return df.reset_index(drop=True)
+
+
+ def get_lut_by_mgrs_tile_ids(mgrs_tile_ids: str | list[str]) -> pd.DataFrame:
+     if isinstance(mgrs_tile_ids, str):
+         mgrs_tile_ids = [mgrs_tile_ids]
+     parquet_path = get_mgrs_burst_lut_path()
+     filters = [('mgrs_tile_id', 'in', mgrs_tile_ids)]
+     df_mgrs_burst_lut = pd.read_parquet(parquet_path, filters=filters)
+     if df_mgrs_burst_lut.empty:
+         mgrs_tile_ids_str = ', '.join(map(str, mgrs_tile_ids))
+         raise ValueError(f'No LUT data found for MGRS tile ids {mgrs_tile_ids_str}.')
+     burst_mgrs_lut_schema.validate(df_mgrs_burst_lut)
+     df_mgrs_burst_lut = reorder_columns(df_mgrs_burst_lut, burst_mgrs_lut_schema)
+     return df_mgrs_burst_lut.reset_index(drop=True)
+
+
+ @lru_cache
+ def get_mgrs_table() -> gpd.GeoDataFrame:
+     path = get_mgrs_data_path()
+     df_mgrs = gpd.read_parquet(path)
+     mgrs_tile_schema.validate(df_mgrs)
+     df_mgrs = reorder_columns(df_mgrs, mgrs_tile_schema)
+     return df_mgrs
+
+
+ def get_mgrs_tile_table_by_ids(mgrs_tile_ids: list[str] | str) -> gpd.GeoDataFrame:
+     df_mgrs = get_mgrs_table()
+     if isinstance(mgrs_tile_ids, str):
+         mgrs_tile_ids = [mgrs_tile_ids]
+     ind = df_mgrs.mgrs_tile_id.isin(mgrs_tile_ids)
+     if not ind.any():
+         mgrs_tile_ids_str = ', '.join(map(str, mgrs_tile_ids))
+         raise ValueError(f'No MGRS tile data found for {mgrs_tile_ids_str}.')
+     df_mgrs_subset = df_mgrs[ind].reset_index(drop=True)
+     return df_mgrs_subset
+
+
+ def get_mgrs_tiles_overlapping_geometry(geometry: Polygon | Point) -> gpd.GeoDataFrame:
+     df_mgrs = get_mgrs_table()
+     ind = df_mgrs.intersects(geometry)
+     if not ind.any():
+         raise NoMGRSCoverage(
+             'No MGRS tiles intersect the provided geometry. We only have MGRS tiles that overlap with DIST-HLS '
+             'products (slightly less coverage than Sentinel-2).'
+         )
+     df_mgrs_overlapping = df_mgrs[ind].reset_index(drop=True)
+     mgrs_tile_schema.validate(df_mgrs_overlapping)
+     df_mgrs_overlapping = reorder_columns(df_mgrs_overlapping, mgrs_tile_schema)
+     return df_mgrs_overlapping
+
+
+ def get_burst_ids_in_mgrs_tiles(mgrs_tile_ids: list[str] | str, track_numbers: list[int] | None = None) -> list[str]:
+     """Get all the burst ids in the provided MGRS tiles.
+
+     If track numbers are provided, gets all the burst ids for the pass associated with the tracks
+     for each MGRS tile. Raises an error if there are multiple acq_group_id_within_mgrs_tile for a single MGRS tile.
+     """
+     df_mgrs_burst_luts = get_lut_by_mgrs_tile_ids(mgrs_tile_ids)
+     if isinstance(mgrs_tile_ids, str):
+         mgrs_tile_ids = [mgrs_tile_ids]
+     if track_numbers is not None:
+         if len(track_numbers) > 2:
+             raise ValueError(
+                 'More than 2 track numbers provided. When track numbers are provided, we select data from a single '
+                 'pass, so this is an invalid input.'
+             )
+         tile_data = []
+         for mgrs_tile_id in mgrs_tile_ids:
+             ind_temp = (df_mgrs_burst_luts.mgrs_tile_id == mgrs_tile_id) & (
+                 df_mgrs_burst_luts.track_number.isin(track_numbers)
+             )
+             df_lut_temp = df_mgrs_burst_luts[ind_temp].reset_index(drop=True)
+             if df_lut_temp.empty:
+                 mgrs_tile_ids_str = ', '.join(map(str, mgrs_tile_ids))
+                 track_numbers_str = ', '.join(map(str, track_numbers))
+                 available_track_numbers = (
+                     df_mgrs_burst_luts[df_mgrs_burst_luts.mgrs_tile_id == mgrs_tile_id].track_number.unique().tolist()
+                 )
+                 available_track_numbers_str = ', '.join(map(str, available_track_numbers))
+                 raise ValueError(
+                     f'Mismatch - no LUT data found for MGRS tile ids {mgrs_tile_ids_str} '
+                     f'and track numbers {track_numbers_str}. '
+                     f'Available track numbers for tile {mgrs_tile_id} are {available_track_numbers_str}.'
+                 )
+             acq_ids = df_lut_temp.acq_group_id_within_mgrs_tile.unique().tolist()
+             if len(acq_ids) != 1:
+                 track_numbers_str = ', '.join(map(str, track_numbers))
+                 raise ValueError(
+                     f'Multiple acq_group_id_within_mgrs_tile found for mgrs_tile_id {mgrs_tile_id} and '
+                     f'track_numbers {track_numbers_str}.'
+                 )
+             acq_id = acq_ids[0]
+             df_lut_pass = df_mgrs_burst_luts[df_mgrs_burst_luts.acq_group_id_within_mgrs_tile == acq_id].reset_index(
+                 drop=True
+             )
+             tile_data.append(df_lut_pass)
+         df_mgrs_burst_luts = pd.concat(tile_data, axis=0)
+         # Remove duplicates if sequential track numbers are provided.
+         df_mgrs_burst_luts = df_mgrs_burst_luts.drop_duplicates().reset_index(drop=True)
+
+     df_mgrs_burst_luts = df_mgrs_burst_luts.drop_duplicates(subset=['jpl_burst_id', 'mgrs_tile_id'])
+     burst_ids = df_mgrs_burst_luts.jpl_burst_id.unique().tolist()
+     return burst_ids
+
+
+ def get_burst_table_from_mgrs_tiles(mgrs_tile_ids: str | list[str]) -> gpd.GeoDataFrame:
+     df_mgrs_burst_luts = get_lut_by_mgrs_tile_ids(mgrs_tile_ids)
+     burst_ids = df_mgrs_burst_luts.jpl_burst_id.unique().tolist()
+     df_burst = get_burst_table(burst_ids)
+     df_burst = pd.merge(
+         df_burst,
+         df_mgrs_burst_luts[['jpl_burst_id', 'track_number', 'acq_group_id_within_mgrs_tile', 'mgrs_tile_id']],
+         how='left',
+         on='jpl_burst_id',
+     )
+     burst_schema.validate(df_burst)
+     df_burst = reorder_columns(df_burst, burst_schema)
+     return df_burst.reset_index(drop=True)
@@ -0,0 +1,100 @@
+ from pydantic import BaseModel, ValidationInfo, field_validator
+
+
+ class LookbackStrategyParams(BaseModel):
+     """Pydantic model for validating lookback strategy parameters."""
+
+     lookback_strategy: str
+     max_pre_imgs_per_burst: int | list[int] | tuple[int, ...]
+     delta_lookback_days: int | list[int] | tuple[int, ...]
+     min_pre_imgs_per_burst: int
+     delta_window_days: int
+
+     @field_validator('delta_window_days')
+     @classmethod
+     def validate_delta_window_days(cls, v: int) -> int:
+         """Validate that delta_window_days is at most 365 days."""
+         if v > 365:
+             raise ValueError('delta_window_days must be at most 365 days.')
+         return v
+
+     @field_validator('lookback_strategy')
+     @classmethod
+     def validate_lookback_strategy(cls, v: str) -> str:
+         """Validate that lookback_strategy is one of the supported values."""
+         allowed_strategies = ['immediate_lookback', 'multi_window']
+         if v not in allowed_strategies:
+             raise ValueError(f'lookback_strategy must be one of {allowed_strategies}, got {v}.')
+         return v
+
+     @field_validator('max_pre_imgs_per_burst')
+     @classmethod
+     def validate_max_pre_imgs_per_burst(
+         cls, v: int | list[int] | tuple[int, ...], info: ValidationInfo
+     ) -> int | tuple[int, ...]:
+         """Validate max_pre_imgs_per_burst based on lookback_strategy."""
+         lookback_strategy = info.data.get('lookback_strategy')
+
+         if lookback_strategy == 'immediate_lookback':
+             if isinstance(v, list | tuple):
+                 raise ValueError('max_pre_imgs_per_burst must be a single integer for the immediate lookback strategy.')
+
+         elif lookback_strategy == 'multi_window':
+             if isinstance(v, int):
+                 v = (v,) * 3
+             elif isinstance(v, list):
+                 v = tuple(v)
+
+         return v
+
+     @field_validator('delta_lookback_days')
+     @classmethod
+     def validate_delta_lookback_days(
+         cls, v: int | list[int] | tuple[int, ...], info: ValidationInfo
+     ) -> int | tuple[int, ...]:
+         """Validate delta_lookback_days based on lookback_strategy and max_pre_imgs_per_burst."""
+         lookback_strategy = info.data.get('lookback_strategy')
+         max_pre_imgs_per_burst = info.data.get('max_pre_imgs_per_burst')
+
+         if lookback_strategy == 'immediate_lookback':
+             if v != 0:
+                 raise ValueError('delta_lookback_days must be 0 for the immediate lookback strategy.')
+
+         elif lookback_strategy == 'multi_window':
+             if isinstance(v, int):
+                 if isinstance(max_pre_imgs_per_burst, list | tuple):
+                     v = tuple(v * i for i in range(1, len(max_pre_imgs_per_burst) + 1))
+                 else:
+                     v = tuple(v * i for i in range(1, 3 + 1))  # default to 3 windows if max_pre_imgs_per_burst is an int
+             elif isinstance(v, list):
+                 v = tuple(v)
+
+             if isinstance(max_pre_imgs_per_burst, list | tuple) and len(v) != len(max_pre_imgs_per_burst):
+                 raise ValueError(
+                     'max_pre_imgs_per_burst and delta_lookback_days must have the same length. '
+                     'If max_pre_imgs_per_burst is a single integer, it is interpreted as the maximum '
+                     'number of pre-images on 3 anniversary dates, so ensure that `delta_lookback_days` '
+                     'is a tuple of length 3 or an integer.'
+                 )
+
+         return v
+
+     @field_validator('min_pre_imgs_per_burst')
+     @classmethod
+     def validate_min_pre_imgs_per_burst(cls, v: int, info: ValidationInfo) -> int:
+         """Validate that all max_pre_imgs_per_burst values are at least min_pre_imgs_per_burst."""
+         max_pre_imgs_per_burst = info.data.get('max_pre_imgs_per_burst')
+         lookback_strategy = info.data.get('lookback_strategy')
+
+         if lookback_strategy == 'immediate_lookback':
+             if isinstance(max_pre_imgs_per_burst, int) and max_pre_imgs_per_burst < v:
+                 raise ValueError('max_pre_imgs_per_burst must be at least min_pre_imgs_per_burst.')
+
+         elif lookback_strategy == 'multi_window':
+             if isinstance(max_pre_imgs_per_burst, list | tuple):
+                 if any(m < v for m in max_pre_imgs_per_burst):
+                     raise ValueError('All values in max_pre_imgs_per_burst must be at least min_pre_imgs_per_burst.')
+             if isinstance(max_pre_imgs_per_burst, int) and max_pre_imgs_per_burst < v:
+                 raise ValueError('max_pre_imgs_per_burst must be at least min_pre_imgs_per_burst.')
+
+         return v
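The net effect of the validators above is that multi_window inputs are normalized to tuples: a scalar max_pre_imgs_per_burst is broadcast across three windows, and a scalar delta_lookback_days is scaled to anniversary multiples. A short sketch with illustrative values, assuming LookbackStrategyParams is imported from its module:

    params = LookbackStrategyParams(
        lookback_strategy='multi_window',
        max_pre_imgs_per_burst=4,  # broadcast to (4, 4, 4)
        delta_lookback_days=365,   # scaled to (365, 730, 1095)
        min_pre_imgs_per_burst=2,
        delta_window_days=60,
    )
    print(params.max_pre_imgs_per_burst)  # (4, 4, 4)
    print(params.delta_lookback_days)     # (365, 730, 1095)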
@@ -0,0 +1,143 @@
+ import concurrent.futures
+ from pathlib import Path
+
+ import geopandas as gpd
+ import requests
+ from pandera.pandas import check_input
+ from rasterio.errors import RasterioIOError
+ from requests.exceptions import HTTPError, RequestException, Timeout
+ from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
+ from tqdm.auto import tqdm
+
+ from dist_s1_enumerator.tabular_models import rtc_s1_schema
+
+
+ def generate_rtc_s1_local_paths(
+     urls: list[str], data_dir: Path | str, track_tokens: list[str], date_tokens: list[str], mgrs_tokens: list[str]
+ ) -> list[Path]:
+     data_dir = Path(data_dir)
+     data_dir.mkdir(parents=True, exist_ok=True)
+
+     n = len(urls)
+     bad_data = [
+         (input_name, len(data))
+         for (input_name, data) in zip(
+             ['urls', 'date_tokens', 'mgrs_tokens', 'track_tokens'], [urls, date_tokens, mgrs_tokens, track_tokens]
+         )
+         if len(data) != n
+     ]
+     if bad_data:
+         raise ValueError(f'Number of {bad_data[0][0]} (which is {bad_data[0][1]}) must match the number of URLs ({n}).')
+
+     dst_dirs = [
+         data_dir / mgrs_token / track_token / date_token
+         for (mgrs_token, track_token, date_token) in zip(mgrs_tokens, track_tokens, date_tokens)
+     ]
+     for dst_dir in dst_dirs:
+         dst_dir.mkdir(parents=True, exist_ok=True)
+
+     local_paths = [dst_dir / url.split('/')[-1] for (dst_dir, url) in zip(dst_dirs, urls)]
+     return local_paths
+
+
+ def append_local_paths(df_rtc_ts: gpd.GeoDataFrame, data_dir: Path | str) -> gpd.GeoDataFrame:
+     copol_urls = df_rtc_ts['url_copol'].tolist()
+     crosspol_urls = df_rtc_ts['url_crosspol'].tolist()
+     track_tokens = df_rtc_ts['track_token'].tolist()
+     date_tokens = df_rtc_ts['acq_date_for_mgrs_pass'].tolist()
+     mgrs_tokens = df_rtc_ts['mgrs_tile_id'].tolist()
+
+     out_paths_copol = generate_rtc_s1_local_paths(copol_urls, data_dir, track_tokens, date_tokens, mgrs_tokens)
+     out_paths_crosspol = generate_rtc_s1_local_paths(crosspol_urls, data_dir, track_tokens, date_tokens, mgrs_tokens)
+     df_out = df_rtc_ts.copy()
+     df_out['loc_path_copol'] = out_paths_copol
+     df_out['loc_path_crosspol'] = out_paths_crosspol
+     return df_out
+
+
+ def create_download_session(max_workers: int = 5) -> requests.Session:
+     """Create a requests session with appropriate settings for downloads.
+
+     Args:
+         max_workers: Number of concurrent download threads (used to size the connection pool).
+     """
+     session = requests.Session()
+     session.headers.update({'User-Agent': 'dist-s1-enumerator/1.0'})
+
+     # Size the connection pool based on the number of concurrent workers.
+     pool_maxsize = max(max_workers * 2, 10)
+     pool_maxsize = min(pool_maxsize, 50)
+
+     adapter = requests.adapters.HTTPAdapter(
+         pool_connections=10,
+         pool_maxsize=pool_maxsize,
+         max_retries=0,  # retries are handled with tenacity
+     )
+     session.mount('http://', adapter)
+     session.mount('https://', adapter)
+     return session
+
+
+ @retry(
+     retry=retry_if_exception_type((ConnectionError, HTTPError, RasterioIOError, Timeout, RequestException)),
+     stop=stop_after_attempt(5),
+     wait=wait_exponential(multiplier=1, min=1, max=10),
+     reraise=True,
+ )
+ def localize_one_rtc(url: str, out_path: Path, session: requests.Session | None = None) -> Path:
+     """Download a single RTC file with retry logic."""
+     if out_path.exists():
+         return out_path
+
+     if session is None:
+         session = create_download_session()
+
+     try:
+         with session.get(url, stream=True, timeout=30) as r:
+             r.raise_for_status()
+             out_path.parent.mkdir(parents=True, exist_ok=True)
+             with out_path.open('wb') as f:
+                 for chunk in r.iter_content(chunk_size=16384):
+                     if chunk:  # filter out keep-alive chunks
+                         f.write(chunk)
+     except Exception:
+         # Clean up the partial file on failure so retries start fresh.
+         if out_path.exists():
+             out_path.unlink()
+         raise
+     return out_path
+
+
+ @check_input(rtc_s1_schema, 0)
+ def localize_rtc_s1_ts(
+     df_rtc_ts: gpd.GeoDataFrame,
+     data_dir: Path | str,
+     max_workers: int = 5,
+     tqdm_enabled: bool = True,
+ ) -> gpd.GeoDataFrame:
+     df_out = append_local_paths(df_rtc_ts, data_dir)
+     urls = df_out['url_copol'].tolist() + df_out['url_crosspol'].tolist()
+     out_paths = df_out['loc_path_copol'].tolist() + df_out['loc_path_crosspol'].tolist()
+
+     # Create a shared session for connection pooling, sized for the concurrent workers.
+     session = create_download_session(max_workers)
+
+     def localize_one_rtc_with_session(data: tuple) -> Path:
+         url, out_path = data
+         return localize_one_rtc(url, out_path, session)
+
+     disable_tqdm = not tqdm_enabled
+     with concurrent.futures.ThreadPoolExecutor(max_workers=max_workers) as executor:
+         _ = list(
+             tqdm(
+                 executor.map(localize_one_rtc_with_session, zip(urls, out_paths)),
+                 total=len(urls),
+                 disable=disable_tqdm,
+                 desc='Downloading RTC-S1 burst data',
+                 dynamic_ncols=True,
+             )
+         )
+     # Cast local paths to str for serialization.
+     df_out['loc_path_copol'] = df_out['loc_path_copol'].astype(str)
+     df_out['loc_path_crosspol'] = df_out['loc_path_crosspol'].astype(str)
+     return df_out
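A sketch of the intended call pattern, assuming df_rtc_ts is an enumeration result that already satisfies rtc_s1_schema (building that frame happens elsewhere in the package) and that localize_rtc_s1_ts is imported from its module:

    from pathlib import Path

    # df_rtc_ts: GeoDataFrame satisfying rtc_s1_schema, produced by the enumeration step.
    df_local = localize_rtc_s1_ts(df_rtc_ts, data_dir=Path('rtc_data'), max_workers=5)
    # loc_path_copol / loc_path_crosspol now hold string paths to the downloaded rasters.
    print(df_local[['loc_path_copol', 'loc_path_crosspol']].head())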
@@ -0,0 +1,92 @@
+ import geopandas as gpd
+ from pandera.engines.pandas_engine import DateTime
+ from pandera.pandas import Column, DataFrameSchema
+
+
+ burst_schema = DataFrameSchema(
+     {
+         'jpl_burst_id': Column(str, required=True),
+         'track_number': Column(int, required=False),
+         'acq_group_id_within_mgrs_tile': Column(int, required=False),
+         'mgrs_tile_id': Column(str, required=False),
+         'geometry': Column('geometry', required=True),
+     }
+ )
+
+ mgrs_tile_schema = DataFrameSchema(
+     {
+         'mgrs_tile_id': Column(str, required=True),
+         'utm_epsg': Column(int, required=True),
+         'utm_wkt': Column(str, required=True),
+         'geometry': Column('geometry', required=True),
+     }
+ )
+
+ # Response schema from the ASF DAAC API
+ rtc_s1_resp_schema = DataFrameSchema(
+     {
+         'opera_id': Column(str, required=True),
+         'jpl_burst_id': Column(str, required=True),
+         'acq_dt': Column(DateTime(tz='UTC'), required=True),
+         'acq_date_for_mgrs_pass': Column(str, required=False),
+         'polarizations': Column(str, required=True),
+         'track_number': Column(int, required=True),
+         # Integer number of 6-day periods since 2014-01-01
+         'pass_id': Column(int, required=True),
+         'url_crosspol': Column(str, required=True),
+         'url_copol': Column(str, required=True),
+         'geometry': Column('geometry', required=True),
+     }
+ )
+
+ # Schema for RTC-S1 metadata with the MGRS tile and acq group id appended
+ # Note: a single burst product may be associated with multiple MGRS tiles and acq group ids
+ rtc_s1_schema = rtc_s1_resp_schema.add_columns(
+     {
+         'mgrs_tile_id': Column(str, required=True),
+         'acq_group_id_within_mgrs_tile': Column(int, required=True),
+         'track_token': Column(str, required=True),
+         'geometry': Column('geometry', required=True),
+     }
+ )
+
+ # Schema for inputs to the dist-s1 workflow
+ dist_s1_input_schema = rtc_s1_schema.add_columns(
+     {
+         'input_category': Column(str, required=True),
+         'product_id': Column(int, required=False),
+         'geometry': Column('geometry', required=True),
+     }
+ )
+
+ # Schema for localized inputs
+ dist_s1_loc_input_schema = dist_s1_input_schema.add_columns(
+     {
+         'loc_path_copol': Column(str, required=True),
+         'loc_path_crosspol': Column(str, required=True),
+         'geometry': Column('geometry', required=True),
+     }
+ )
+
+ burst_mgrs_lut_schema = DataFrameSchema(
+     {
+         'jpl_burst_id': Column(str, required=True),
+         'mgrs_tile_id': Column(str, required=True),
+         'track_number': Column(int, required=True),
+         'acq_group_id_within_mgrs_tile': Column(int, required=True),
+         'orbit_pass': Column(str, required=True),
+         'area_per_acq_group_km2': Column(int, required=True),
+         'n_bursts_per_acq_group': Column(int, required=True),
+     }
+ )
+
+
+ def reorder_columns(df: gpd.GeoDataFrame, schema: DataFrameSchema) -> gpd.GeoDataFrame:
+     if not df.empty:
+         df = df[[col for col in schema.columns.keys() if col in df.columns]]
+     else:
+         df = gpd.GeoDataFrame(columns=schema.columns.keys())
+         if 'geometry' in schema.columns.keys():
+             # set_crs returns a new frame; assign the result so the default CRS is actually applied.
+             df = df.set_crs(epsg=4326)
+     return df
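A small sketch of validating a frame against one of these schemas; the single-row frame and burst id below are toy values for illustration:

    import geopandas as gpd
    from shapely.geometry import box

    # Toy frame with a made-up burst id; real ids come from the packaged lookup tables.
    df = gpd.GeoDataFrame(
        {'jpl_burst_id': ['T071-151224-IW2']},
        geometry=[box(-120.0, 34.0, -119.9, 34.1)],
        crs='EPSG:4326',
    )
    burst_schema.validate(df)  # raises a pandera SchemaError if the frame does not conform
    df = reorder_columns(df, burst_schema)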