rslearn 0.0.3-py3-none-any.whl → 0.0.5-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/arg_parser.py +59 -0
- rslearn/data_sources/copernicus.py +10 -8
- rslearn/data_sources/earthdaily.py +21 -1
- rslearn/data_sources/eurocrops.py +246 -0
- rslearn/data_sources/gcp_public_data.py +3 -3
- rslearn/data_sources/local_files.py +11 -0
- rslearn/data_sources/openstreetmap.py +2 -4
- rslearn/data_sources/utils.py +1 -17
- rslearn/main.py +10 -1
- rslearn/models/copernicusfm.py +216 -0
- rslearn/models/copernicusfm_src/__init__.py +1 -0
- rslearn/models/copernicusfm_src/aurora/area.py +50 -0
- rslearn/models/copernicusfm_src/aurora/fourier.py +134 -0
- rslearn/models/copernicusfm_src/dynamic_hypernetwork.py +523 -0
- rslearn/models/copernicusfm_src/flexivit/patch_embed.py +260 -0
- rslearn/models/copernicusfm_src/flexivit/utils.py +69 -0
- rslearn/models/copernicusfm_src/model_vit.py +348 -0
- rslearn/models/copernicusfm_src/util/pos_embed.py +216 -0
- rslearn/models/panopticon.py +167 -0
- rslearn/models/presto/__init__.py +5 -0
- rslearn/models/presto/presto.py +247 -0
- rslearn/models/presto/single_file_presto.py +932 -0
- rslearn/models/trunk.py +0 -144
- rslearn/models/unet.py +15 -0
- rslearn/train/callbacks/adapters.py +53 -0
- rslearn/train/callbacks/freeze_unfreeze.py +319 -0
- rslearn/train/callbacks/gradients.py +54 -34
- rslearn/train/data_module.py +70 -41
- rslearn/train/dataset.py +232 -54
- rslearn/train/lightning_module.py +4 -0
- rslearn/train/prediction_writer.py +7 -0
- rslearn/train/scheduler.py +15 -0
- rslearn/train/tasks/per_pixel_regression.py +259 -0
- rslearn/train/tasks/regression.py +6 -4
- rslearn/train/tasks/segmentation.py +44 -14
- rslearn/train/transforms/mask.py +69 -0
- rslearn/utils/geometry.py +8 -8
- {rslearn-0.0.3.dist-info → rslearn-0.0.5.dist-info}/METADATA +6 -3
- {rslearn-0.0.3.dist-info → rslearn-0.0.5.dist-info}/RECORD +43 -27
- rslearn/models/moe/distributed.py +0 -262
- rslearn/models/moe/soft.py +0 -676
- {rslearn-0.0.3.dist-info → rslearn-0.0.5.dist-info}/WHEEL +0 -0
- {rslearn-0.0.3.dist-info → rslearn-0.0.5.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.3.dist-info → rslearn-0.0.5.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.3.dist-info → rslearn-0.0.5.dist-info}/top_level.txt +0 -0
rslearn/arg_parser.py
ADDED
@@ -0,0 +1,59 @@
+"""Custom Lightning ArgumentParser with environment variable substitution support."""
+
+import os
+import re
+from typing import Any
+
+from jsonargparse import Namespace
+from lightning.pytorch.cli import LightningArgumentParser
+
+
+def substitute_env_vars_in_string(content: str) -> str:
+    """Substitute environment variables in a string.
+
+    Replaces ${VAR_NAME} patterns with os.getenv(VAR_NAME, "") values.
+    This works on raw string content before YAML parsing.
+
+    Args:
+        content: The string content containing template variables
+
+    Returns:
+        The string with environment variables substituted
+    """
+    pattern = r"\$\{([^}]+)\}"
+
+    def replace_variable(match_obj: re.Match[str]) -> str:
+        var_name = match_obj.group(1)
+        env_value = os.getenv(var_name, "")
+        return env_value if env_value is not None else ""
+
+    return re.sub(pattern, replace_variable, content)
+
+
+class RslearnArgumentParser(LightningArgumentParser):
+    """Custom ArgumentParser that substitutes environment variables in config files.
+
+    This parser extends LightningArgumentParser to automatically substitute
+    ${VAR_NAME} patterns with environment variable values before parsing
+    configuration content. This allows config files to use environment
+    variables while still passing Lightning's validation.
+    """
+
+    def parse_string(
+        self,
+        cfg_str: str,
+        cfg_path: str | os.PathLike = "",
+        ext_vars: dict | None = None,
+        env: bool | None = None,
+        defaults: bool = True,
+        with_meta: bool | None = None,
+        **kwargs: Any,
+    ) -> Namespace:
+        """Pre-processes string for environment variable substitution before parsing."""
+        # Substitute environment variables in the config string before parsing
+        substituted_cfg_str = substitute_env_vars_in_string(cfg_str)
+
+        # Call the parent method with the substituted config
+        return super().parse_string(
+            substituted_cfg_str, cfg_path, ext_vars, env, defaults, with_meta, **kwargs
+        )
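As a quick illustration of the new module's behavior, here is a standalone sketch (not part of the diff; the variable names and values are made up): ${VAR} references resolve against the environment, and unset variables collapse to empty strings.

import os

from rslearn.arg_parser import substitute_env_vars_in_string

os.environ["DS_ROOT"] = "/data/rslearn"  # hypothetical value for the demo
cfg_str = "data:\n  path: ${DS_ROOT}/dataset\n  tag: ${UNSET_VAR}latest\n"
print(substitute_env_vars_in_string(cfg_str))
# data:
#   path: /data/rslearn/dataset
#   tag: latest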
rslearn/data_sources/copernicus.py
CHANGED

@@ -34,7 +34,7 @@ from rslearn.utils.geometry import (
     FloatBounds,
     STGeometry,
     flatten_shape,
-
+    split_shape_at_antimeridian,
 )
 from rslearn.utils.grid_index import GridIndex
 from rslearn.utils.raster_format import get_raster_projection_and_bounds
@@ -160,7 +160,7 @@ def get_sentinel2_tile_index() -> dict[str, list[FloatBounds]]:
     # issues where the tile bounds go from -180 to 180 longitude and thus match
     # with anything at the same latitude.
     union_shp = shapely.unary_union(shapes)
-    split_shapes = flatten_shape(
+    split_shapes = flatten_shape(split_shape_at_antimeridian(union_shp))
     bounds_list: list[FloatBounds] = []
     for shp in split_shapes:
         bounds_list.append(shp.bounds)
@@ -222,10 +222,10 @@ def get_sentinel2_tiles(geometry: STGeometry, cache_dir: UPath) -> list[str]:
     """
     tile_index = load_sentinel2_tile_index(cache_dir)
     wgs84_geometry = geometry.to_projection(WGS84_PROJECTION)
-    # If the shape is a collection, it could be cutting across
+    # If the shape is a collection, it could be cutting across antimeridian.
     # So we query each component shape separately and collect the results to avoid
     # issues.
-    # We assume the caller has already applied
+    # We assume the caller has already applied split_at_antimeridian.
     results = set()
     for shp in flatten_shape(wgs84_geometry.shp):
         for result in tile_index.query(shp.bounds):
@@ -319,7 +319,6 @@ class Copernicus(DataSource):
            then we attempt to read the username/password from COPERNICUS_USERNAME
            and COPERNICUS_PASSWORD (this is useful since access tokens are only
            valid for an hour).
-        password: set API username/password instead of access token.
         query_filter: filter string to include when searching for items. This will
            be appended to other name, geographic, and sensing time filters where
            applicable. For example, "Collection/Name eq 'SENTINEL-2'". See the API
@@ -368,6 +367,7 @@ class Copernicus(DataSource):
             "order_by",
             "sort_by",
             "sort_desc",
+            "timeout",
         ]
         for k in simple_optionals:
             if k in d:
@@ -709,6 +709,8 @@ class Sentinel2(Copernicus):
         "B12": ["B12"],
         "B8A": ["B8A"],
         "TCI": ["R", "G", "B"],
+        # L1C-only products.
+        "B10": ["B10"],
         # L2A-only products.
         "AOT": ["AOT"],
         "WVP": ["WVP"],
@@ -809,17 +811,16 @@ class Sentinel2(Copernicus):

         kwargs: dict[str, Any] = dict(
             assets=list(needed_assets),
+            product_type=Sentinel2ProductType[d["product_type"]],
         )

-        if "product_type" in d:
-            kwargs["product_type"] = Sentinel2ProductType(d["product_type"])
-
         simple_optionals = [
             "harmonize",
             "access_token",
             "order_by",
             "sort_by",
             "sort_desc",
+            "timeout",
         ]
         for k in simple_optionals:
             if k in d:
@@ -965,6 +966,7 @@ class Sentinel1(Copernicus):
             "order_by",
             "sort_by",
             "sort_desc",
+            "timeout",
         ]
         for k in simple_optionals:
             if k in d:
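The motivation for the antimeridian handling above is easier to see with a standalone shapely sketch (illustration only, not code from the package): a tile crossing 180° longitude, represented naively as one box, has bounds spanning the full -180..180 range and therefore matches anything at the same latitude.

import shapely

# Naive representation of a tile that crosses the antimeridian: its extent
# covers every longitude, so a point near Greenwich spuriously intersects it.
crossing = shapely.box(-180, 50, 180, 52)
print(crossing.intersects(shapely.Point(0, 51)))  # True (spurious match)

# After splitting at the antimeridian, each half has tight bounds.
west_half = shapely.box(-180, 50, -179, 52)
east_half = shapely.box(179, 50, 180, 52)
print(west_half.intersects(shapely.Point(0, 51)))  # False
print(east_half.intersects(shapely.Point(0, 51)))  # False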
rslearn/data_sources/earthdaily.py
CHANGED

@@ -82,6 +82,8 @@ class EarthDaily(DataSource, TileStore):
         timeout: timedelta = timedelta(seconds=10),
         skip_items_missing_assets: bool = False,
         cache_dir: UPath | None = None,
+        max_retries: int = 3,
+        retry_backoff_factor: float = 5.0,
         service_name: Literal["platform"] = "platform",
     ):
         """Initialize a new EarthDaily instance.
@@ -99,6 +101,11 @@ class EarthDaily(DataSource, TileStore):
             cache_dir: optional directory to cache items by name, including asset URLs.
                 If not set, there will be no cache and instead STAC requests will be
                 needed each time.
+            max_retries: the maximum number of retry attempts for HTTP requests that fail
+                due to transient errors (e.g., 429, 500, 502, 503, 504 status codes).
+            retry_backoff_factor: backoff factor for exponential retry delays between HTTP
+                request attempts. The delay between retries is calculated using the formula:
+                `(retry_backoff_factor * (2 ** (retry_count - 1)))` seconds.
             service_name: the service name, only "platform" is supported, the other
                 services "legacy" and "internal" are not supported.
         """
@@ -110,6 +117,8 @@ class EarthDaily(DataSource, TileStore):
         self.timeout = timeout
         self.skip_items_missing_assets = skip_items_missing_assets
         self.cache_dir = cache_dir
+        self.max_retries = max_retries
+        self.retry_backoff_factor = retry_backoff_factor
         self.service_name = service_name

         if cache_dir is not None:
@@ -139,6 +148,12 @@ class EarthDaily(DataSource, TileStore):
         if "cache_dir" in d:
             kwargs["cache_dir"] = join_upath(ds_path, d["cache_dir"])

+        if "max_retries" in d:
+            kwargs["max_retries"] = d["max_retries"]
+
+        if "retry_backoff_factor" in d:
+            kwargs["retry_backoff_factor"] = d["retry_backoff_factor"]
+
         simple_optionals = ["query", "sort_by", "sort_ascending"]
         for k in simple_optionals:
             if k in d:
@@ -159,7 +174,12 @@ class EarthDaily(DataSource, TileStore):
         if self.eds_client is not None:
             return self.eds_client, self.client, self.collection

-        self.eds_client = EDSClient(
+        self.eds_client = EDSClient(
+            EDSConfig(
+                max_retries=self.max_retries,
+                retry_backoff_factor=self.retry_backoff_factor,
+            )
+        )

         if self.service_name == "platform":
             self.client = self.eds_client.platform.pystac_client
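For reference, plugging the new defaults into the delay formula documented above gives waits of 5s, 10s, and 20s across the three retries (standalone arithmetic sketch, not package code):

retry_backoff_factor = 5.0  # default
max_retries = 3  # default

for retry_count in range(1, max_retries + 1):
    delay = retry_backoff_factor * (2 ** (retry_count - 1))
    print(f"retry {retry_count}: wait {delay:g}s")
# retry 1: wait 5s
# retry 2: wait 10s
# retry 3: wait 20s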
rslearn/data_sources/eurocrops.py
ADDED

@@ -0,0 +1,246 @@
+"""Data source for vector EuroCrops crop type data."""
+
+import glob
+import os
+import tempfile
+import zipfile
+from datetime import UTC, datetime, timedelta
+from typing import Any
+
+import fiona
+import requests
+from rasterio.crs import CRS
+from upath import UPath
+
+from rslearn.config import QueryConfig, VectorLayerConfig
+from rslearn.const import WGS84_PROJECTION
+from rslearn.data_sources import DataSource, Item
+from rslearn.data_sources.utils import match_candidate_items_to_window
+from rslearn.log_utils import get_logger
+from rslearn.tile_stores import TileStoreWithLayer
+from rslearn.utils.feature import Feature
+from rslearn.utils.geometry import Projection, STGeometry, get_global_geometry
+
+logger = get_logger(__name__)
+
+
+class EuroCropsItem(Item):
+    """An item in the EuroCrops data source.
+
+    For simplicity, we have just one item per year, so each item combines all of the
+    country-level files for that year.
+    """
+
+    def __init__(self, name: str, geometry: STGeometry, zip_fnames: list[str]):
+        """Creates a new EuroCropsItem.
+
+        Args:
+            name: unique name of the item. It is just the year that this item
+                corresponds to.
+            geometry: the spatial and temporal extent of the item
+            zip_fnames: the filenames of the zip files that contain country-level crop
+                type data for this year.
+        """
+        super().__init__(name, geometry)
+        self.zip_fnames = zip_fnames
+
+    def serialize(self) -> dict:
+        """Serializes the item to a JSON-encodable dictionary."""
+        d = super().serialize()
+        d["zip_fnames"] = self.zip_fnames
+        return d
+
+    @staticmethod
+    def deserialize(d: dict) -> "EuroCropsItem":
+        """Deserializes an item from a JSON-decoded dictionary."""
+        item = super(EuroCropsItem, EuroCropsItem).deserialize(d)
+        return EuroCropsItem(
+            name=item.name, geometry=item.geometry, zip_fnames=d["zip_fnames"]
+        )
+
+
+class EuroCrops(DataSource[EuroCropsItem]):
+    """A data source for EuroCrops vector data (v11).
+
+    See https://zenodo.org/records/14094196 for details.
+
+    While the source data is split into country-level files, this data source uses one
+    item per year for simplicity. So each item corresponds to all of the country-level
+    files for that year.
+
+    Note that the RO_ny.zip file is not used.
+    """
+
+    BASE_URL = "https://zenodo.org/records/14094196/files/"
+    FILENAMES_BY_YEAR = {
+        2018: [
+            "FR_2018.zip",
+        ],
+        2019: [
+            "DK_2019.zip",
+        ],
+        2020: [
+            "ES_NA_2020.zip",
+            "FI_2020.zip",
+            "HR_2020.zip",
+            "NL_2020.zip",
+        ],
+        2021: [
+            "AT_2021.zip",
+            "BE_VLG_2021.zip",
+            "BE_WAL_2021.zip",
+            "EE_2021.zip",
+            "LT_2021.zip",
+            "LV_2021.zip",
+            "PT_2021.zip",
+            "SE_2021.zip",
+            "SI_2021.zip",
+            "SK_2021.zip",
+        ],
+        2023: [
+            "CZ_2023.zip",
+            "DE_BB_2023.zip",
+            "DE_LS_2021.zip",
+            "DE_NRW_2021.zip",
+            "ES_2023.zip",
+            "IE_2023.zip",
+        ],
+    }
+    TIMEOUT = timedelta(seconds=10)
+
+    @staticmethod
+    def from_config(config: VectorLayerConfig, ds_path: UPath) -> "EuroCrops":
+        """Creates a new EuroCrops instance from a configuration dictionary."""
+        if config.data_source is None:
+            raise ValueError("data_source is required")
+        return EuroCrops()
+
+    def _get_all_items(self) -> list[EuroCropsItem]:
+        """Get a list of all available items in the data source."""
+        items: list[EuroCropsItem] = []
+        for year, fnames in self.FILENAMES_BY_YEAR.items():
+            items.append(
+                EuroCropsItem(
+                    str(year),
+                    get_global_geometry(
+                        time_range=(
+                            datetime(year, 1, 1, tzinfo=UTC),
+                            datetime(year + 1, 1, 1, tzinfo=UTC),
+                        ),
+                    ),
+                    fnames,
+                )
+            )
+        return items
+
+    def get_items(
+        self, geometries: list[STGeometry], query_config: QueryConfig
+    ) -> list[list[list[EuroCropsItem]]]:
+        """Get a list of items in the data source intersecting the given geometries.
+
+        Args:
+            geometries: the spatiotemporal geometries
+            query_config: the query configuration
+
+        Returns:
+            List of groups of items that should be retrieved for each geometry.
+        """
+        wgs84_geometries = [
+            geometry.to_projection(WGS84_PROJECTION) for geometry in geometries
+        ]
+        all_items = self._get_all_items()
+        groups = []
+        for geometry in wgs84_geometries:
+            cur_groups = match_candidate_items_to_window(
+                geometry, all_items, query_config
+            )
+            groups.append(cur_groups)
+        return groups
+
+    def deserialize_item(self, serialized_item: Any) -> EuroCropsItem:
+        """Deserializes an item from JSON-decoded data."""
+        return EuroCropsItem.deserialize(serialized_item)
+
+    def _extract_features(self, fname: str) -> list[Feature]:
+        """Download the given zip file, extract shapefile, and return list of features."""
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            # Download the zip file.
+            url = self.BASE_URL + fname
+            logger.debug(f"Downloading zip file from {url}")
+            response = requests.get(
+                url,
+                stream=True,
+                timeout=self.TIMEOUT.total_seconds(),
+                allow_redirects=False,
+            )
+            response.raise_for_status()
+            zip_fname = os.path.join(tmp_dir, "data.zip")
+            with open(zip_fname, "wb") as f:
+                for chunk in response.iter_content(chunk_size=8192):
+                    f.write(chunk)
+
+            # Extract all of the files and look for shapefile filename.
+            logger.debug(f"Extracting zip file {fname}")
+            with zipfile.ZipFile(zip_fname) as zip_f:
+                zip_f.extractall(path=tmp_dir)
+
+            # The shapefiles or geopackage files can appear at any level in the hierarchy.
+            # Most zip files contain one but some contain multiple (one per region).
+            shp_fnames = glob.glob(
+                "**/*.shp", root_dir=tmp_dir, recursive=True
+            ) + glob.glob("**/*.gpkg", root_dir=tmp_dir, recursive=True)
+            if len(shp_fnames) == 0:
+                tmp_dir_fnames = os.listdir(tmp_dir)
+                raise ValueError(
+                    f"expected {fname} to contain .shp file but none found (matches={shp_fnames}, ls={tmp_dir_fnames})"
+                )
+
+            # Load the features from the shapefile(s).
+            features = []
+            for shp_fname in shp_fnames:
+                logger.debug(f"Loading feature list from {shp_fname}")
+                with fiona.open(os.path.join(tmp_dir, shp_fname)) as src:
+                    crs = CRS.from_wkt(src.crs.to_wkt())
+                    # Normal GeoJSON should have coordinates in CRS coordinates, i.e. it
+                    # should be 1 projection unit/pixel.
+                    projection = Projection(crs, 1, 1)
+
+                    for feat in src:
+                        features.append(
+                            Feature.from_geojson(
+                                projection,
+                                {
+                                    "type": "Feature",
+                                    "geometry": dict(feat.geometry),
+                                    "properties": dict(feat.properties),
+                                },
+                            )
+                        )
+
+        return features
+
+    def ingest(
+        self,
+        tile_store: TileStoreWithLayer,
+        items: list[EuroCropsItem],
+        geometries: list[list[STGeometry]],
+    ) -> None:
+        """Ingest items into the given tile store.
+
+        Args:
+            tile_store: the tile store to ingest into
+            items: the items to ingest
+            geometries: a list of geometries needed for each item
+        """
+        for item in items:
+            if tile_store.is_vector_ready(item.name):
+                continue
+
+            # Get features across all shapefiles.
+            features: list[Feature] = []
+            for fname in item.zip_fnames:
+                logger.debug(f"Getting features from {fname} for item {item.name}")
+                features.extend(self._extract_features(fname))
+
+            logger.debug(f"Writing features for {item.name} to the tile store")
+            tile_store.write_vector(item.name, features)
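Since the data source exposes one item per year, item names are plain year strings and each item carries that year's country-level zip files. A standalone sketch of what _get_all_items produces (using only names defined in the added file; output comments abbreviated):

from rslearn.data_sources.eurocrops import EuroCrops

for item in EuroCrops()._get_all_items():
    print(item.name, item.zip_fnames)
# 2018 ['FR_2018.zip']
# 2019 ['DK_2019.zip']
# 2020 ['ES_NA_2020.zip', 'FI_2020.zip', 'HR_2020.zip', 'NL_2020.zip']
# ...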
rslearn/data_sources/gcp_public_data.py
CHANGED

@@ -26,7 +26,7 @@ from rslearn.data_sources.utils import match_candidate_items_to_window
 from rslearn.log_utils import get_logger
 from rslearn.tile_stores import TileStoreWithLayer
 from rslearn.utils.fsspec import join_upath, open_atomic
-from rslearn.utils.geometry import STGeometry, flatten_shape,
+from rslearn.utils.geometry import STGeometry, flatten_shape, split_at_antimeridian
 from rslearn.utils.raster_format import get_raster_projection_and_bounds

 from .copernicus import get_harmonize_callback, get_sentinel2_tiles
@@ -358,7 +358,7 @@ class Sentinel2(DataSource):
         shp = shapely.box(*bounds)
         sensing_time = row["sensing_time"]
         geometry = STGeometry(WGS84_PROJECTION, shp, (sensing_time, sensing_time))
-        geometry =
+        geometry = split_at_antimeridian(geometry)

         cloud_cover = float(row["cloud_cover"])

@@ -511,7 +511,7 @@ class Sentinel2(DataSource):

         time_range = (product_xml.start_time, product_xml.start_time)
         geometry = STGeometry(WGS84_PROJECTION, product_xml.shp, time_range)
-        geometry =
+        geometry = split_at_antimeridian(geometry)

         # Sometimes the geometry is not valid.
         # We just apply make_valid on it to correct issues.
rslearn/data_sources/local_files.py
CHANGED

@@ -232,6 +232,17 @@ class RasterImporter(Importer):
         projection = Projection(crs, x_resolution, y_resolution)
         geometry = STGeometry(projection, shp, None)

+        if geometry.is_too_large():
+            geometry = get_global_geometry(time_range=None)
+            logger.warning(
+                "Global geometry detected: this geometry will be matched against all "
+                "windows in the rslearn dataset. When using settings like "
+                "max_matches=1 and space_mode=MOSAIC, this may cause windows outside "
+                "the geometry's valid bounds to be materialized from the global raster "
+                "instead of a more appropriate source. Consider using COMPOSITE mode, "
+                "or increasing max_matches if this behavior is unintended."
+            )
+
         if spec.name:
             item_name = spec.name
         else:
rslearn/data_sources/openstreetmap.py
CHANGED

@@ -1,4 +1,4 @@
-"""Data source for
+"""Data source for OpenStreetMap vector features."""

 import json
 import shutil
@@ -392,7 +392,7 @@ class OpenStreetMap(DataSource[OsmItem]):
         bounds_fname: UPath,
         categories: dict[str, Filter],
     ):
-        """Initialize a new
+        """Initialize a new OpenStreetMap instance.

         Args:
             config: the configuration of this layer.
@@ -508,8 +508,6 @@ class OpenStreetMap(DataSource[OsmItem]):
            items: the items to ingest
            geometries: a list of geometries needed for each item
        """
-        item_names = [item.name for item in items]
-        item_names.sort()
         for cur_item, cur_geometries in zip(items, geometries):
             if tile_store.is_vector_ready(cur_item.name):
                 continue
rslearn/data_sources/utils.py
CHANGED
@@ -256,23 +256,7 @@ def match_candidate_items_to_window(
         if item_geom.is_global():
             item_geom = geometry
         else:
-
-            # So we first clip the item to the window bounds in the item's
-            # projection, then re-project the item to the window's projection.
-            buffered_window_geom = STGeometry(
-                geometry.projection,
-                geometry.shp.buffer(1),
-                geometry.time_range,
-            )
-            window_shp_in_item_proj = buffered_window_geom.to_projection(
-                item_geom.projection
-            ).shp
-            clipped_item_geom = STGeometry(
-                item_geom.projection,
-                item_geom.shp.intersection(window_shp_in_item_proj),
-                item_geom.time_range,
-            )
-            item_geom = clipped_item_geom.to_projection(geometry.projection)
+            item_geom = item_geom.to_projection(geometry.projection)
         item_shps.append(item_geom.shp)

         if query_config.space_mode == SpaceMode.CONTAINS:
rslearn/main.py
CHANGED
@@ -13,6 +13,7 @@ from lightning.pytorch.cli import LightningArgumentParser, LightningCLI
 from rasterio.crs import CRS
 from upath import UPath

+from rslearn.arg_parser import RslearnArgumentParser
 from rslearn.config import LayerConfig
 from rslearn.const import WGS84_EPSG
 from rslearn.data_sources import Item, data_source_from_config
@@ -779,7 +780,7 @@ def dataset_build_index() -> None:


 class RslearnLightningCLI(LightningCLI):
-    """LightningCLI that links data.tasks to model.tasks."""
+    """LightningCLI that links data.tasks to model.tasks and supports environment variables."""

     def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
         """Link data.tasks to model.tasks.
@@ -787,6 +788,7 @@ class RslearnLightningCLI(LightningCLI):
         Args:
             parser: the argument parser
         """
+        # Link data.tasks to model.tasks
         parser.link_arguments(
             "data.init_args.task", "model.init_args.task", apply_on="instantiate"
         )
@@ -815,6 +817,12 @@ class RslearnLightningCLI(LightningCLI):
         # sampler as needed.
         c.trainer.use_distributed_sampler = False

+        # For predict, make sure that return_predictions is False.
+        # Otherwise all the predictions would be stored in memory which can lead to
+        # high memory consumption.
+        if subcommand == "predict":
+            c.return_predictions = False
+

 def model_handler() -> None:
     """Handler for any rslearn model X commands."""
@@ -825,6 +833,7 @@ def model_handler() -> None:
         subclass_mode_model=True,
         subclass_mode_data=True,
         save_config_kwargs={"overwrite": True},
+        parser_class=RslearnArgumentParser,
     )