rslearn 0.0.3__tar.gz → 0.0.4__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {rslearn-0.0.3/rslearn.egg-info → rslearn-0.0.4}/PKG-INFO +3 -3
- {rslearn-0.0.3 → rslearn-0.0.4}/pyproject.toml +3 -7
- rslearn-0.0.4/rslearn/arg_parser.py +59 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/copernicus.py +4 -4
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/earthdaily.py +21 -1
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/gcp_public_data.py +3 -3
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/utils.py +1 -17
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/main.py +10 -1
- rslearn-0.0.4/rslearn/models/trunk.py +136 -0
- rslearn-0.0.4/rslearn/train/callbacks/adapters.py +53 -0
- rslearn-0.0.4/rslearn/train/callbacks/freeze_unfreeze.py +410 -0
- rslearn-0.0.4/rslearn/train/callbacks/gradients.py +129 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/train/data_module.py +70 -41
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/train/dataset.py +232 -54
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/train/lightning_module.py +4 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/train/prediction_writer.py +7 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/train/scheduler.py +15 -0
- rslearn-0.0.4/rslearn/train/tasks/per_pixel_regression.py +259 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/train/tasks/regression.py +6 -4
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/train/tasks/segmentation.py +44 -14
- rslearn-0.0.4/rslearn/train/transforms/mask.py +69 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/utils/geometry.py +8 -8
- {rslearn-0.0.3 → rslearn-0.0.4/rslearn.egg-info}/PKG-INFO +3 -3
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn.egg-info/SOURCES.txt +4 -2
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn.egg-info/requires.txt +1 -1
- rslearn-0.0.3/rslearn/models/moe/distributed.py +0 -262
- rslearn-0.0.3/rslearn/models/moe/soft.py +0 -676
- rslearn-0.0.3/rslearn/models/trunk.py +0 -280
- rslearn-0.0.3/rslearn/train/callbacks/freeze_unfreeze.py +0 -91
- rslearn-0.0.3/rslearn/train/callbacks/gradients.py +0 -109
- {rslearn-0.0.3 → rslearn-0.0.4}/LICENSE +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/README.md +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/__init__.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/config/__init__.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/config/dataset.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/const.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/__init__.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/aws_landsat.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/aws_open_data.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/aws_sentinel1.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/climate_data_store.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/data_source.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/earthdata_srtm.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/geotiff.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/google_earth_engine.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/local_files.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/openstreetmap.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/planet.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/planet_basemap.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/planetary_computer.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/raster_source.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/usda_cdl.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/usgs_landsat.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/vector_source.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/worldcereal.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/worldcover.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/worldpop.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/xyz_tiles.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/dataset/__init__.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/dataset/add_windows.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/dataset/dataset.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/dataset/index.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/dataset/manage.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/dataset/materialize.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/dataset/remap.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/dataset/window.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/log_utils.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/__init__.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/clip.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/conv.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/croma.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/detr/__init__.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/detr/box_ops.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/detr/detr.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/detr/matcher.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/detr/position_encoding.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/detr/transformer.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/detr/util.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/faster_rcnn.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/fpn.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/module_wrapper.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/molmo.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/multitask.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/pick_features.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/pooling_decoder.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/registry.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/sam2_enc.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/satlaspretrain.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/simple_time_series.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/singletask.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/ssl4eo_s12.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/swin.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/task_embedding.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/terramind.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/unet.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/upsample.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/models/use_croma.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/py.typed +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/tile_stores/__init__.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/tile_stores/default.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/tile_stores/tile_store.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/train/__init__.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/train/callbacks/__init__.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/train/callbacks/peft.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/train/optimizer.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/train/tasks/__init__.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/train/tasks/classification.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/train/tasks/detection.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/train/tasks/multi_task.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/train/tasks/task.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/train/transforms/__init__.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/train/transforms/concatenate.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/train/transforms/crop.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/train/transforms/flip.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/train/transforms/normalize.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/train/transforms/pad.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/train/transforms/transform.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/utils/__init__.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/utils/array.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/utils/feature.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/utils/fsspec.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/utils/get_utm_ups_crs.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/utils/grid_index.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/utils/jsonargparse.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/utils/mp.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/utils/raster_format.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/utils/rtree_index.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/utils/spatial_index.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/utils/sqlite_index.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/utils/time.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn/utils/vector_format.py +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn.egg-info/dependency_links.txt +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn.egg-info/entry_points.txt +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/rslearn.egg-info/top_level.txt +0 -0
- {rslearn-0.0.3 → rslearn-0.0.4}/setup.cfg +0 -0

{rslearn-0.0.3/rslearn.egg-info → rslearn-0.0.4}/PKG-INFO

@@ -1,8 +1,8 @@
 Metadata-Version: 2.4
 Name: rslearn
-Version: 0.0.3
+Version: 0.0.4
 Summary: A library for developing remote sensing datasets and models
-Author
+Author: OlmoEarth Team
 License: Apache License
 Version 2.0, January 2004
 http://www.apache.org/licenses/
@@ -227,7 +227,7 @@ Requires-Dist: universal_pathlib>=0.2.6
 Provides-Extra: extra
 Requires-Dist: accelerate>=1.10; extra == "extra"
 Requires-Dist: cdsapi>=0.7.6; extra == "extra"
-Requires-Dist: earthdaily[platform]>=1.0.
+Requires-Dist: earthdaily[platform]>=1.0.7; extra == "extra"
 Requires-Dist: earthengine-api>=1.6.3; extra == "extra"
 Requires-Dist: einops>=0.8; extra == "extra"
 Requires-Dist: gcsfs==2025.3.0; extra == "extra"

{rslearn-0.0.3 → rslearn-0.0.4}/pyproject.toml

@@ -1,13 +1,9 @@
 [project]
 name = "rslearn"
-version = "0.0.3"
+version = "0.0.4"
 description = "A library for developing remote sensing datasets and models"
 authors = [
-    {name = "
-    {name = "Yawen Zhang", email = "yawenz@allenai.org"},
-    {name = "Patrick Beukema", email = "patrickb@allenai.org"},
-    {name = "Henry Herzog", email = "henryh@allenai.org"},
-    {name = "Piper Wolters", email = "piperw@allenai.org"},
+    { name = "OlmoEarth Team" },
 ]
 readme = "README.md"
 license = {file = "LICENSE"}
@@ -39,7 +35,7 @@ dependencies = [
 extra = [
     "accelerate>=1.10",
     "cdsapi>=0.7.6",
-    "earthdaily[platform]>=1.0.
+    "earthdaily[platform]>=1.0.7",
     "earthengine-api>=1.6.3",
     "einops>=0.8",
     "gcsfs==2025.3.0",

rslearn-0.0.4/rslearn/arg_parser.py

@@ -0,0 +1,59 @@
+"""Custom Lightning ArgumentParser with environment variable substitution support."""
+
+import os
+import re
+from typing import Any
+
+from jsonargparse import Namespace
+from lightning.pytorch.cli import LightningArgumentParser
+
+
+def substitute_env_vars_in_string(content: str) -> str:
+    """Substitute environment variables in a string.
+
+    Replaces ${VAR_NAME} patterns with os.getenv(VAR_NAME, "") values.
+    This works on raw string content before YAML parsing.
+
+    Args:
+        content: The string content containing template variables
+
+    Returns:
+        The string with environment variables substituted
+    """
+    pattern = r"\$\{([^}]+)\}"
+
+    def replace_variable(match_obj: re.Match[str]) -> str:
+        var_name = match_obj.group(1)
+        env_value = os.getenv(var_name, "")
+        return env_value if env_value is not None else ""
+
+    return re.sub(pattern, replace_variable, content)
+
+
+class RslearnArgumentParser(LightningArgumentParser):
+    """Custom ArgumentParser that substitutes environment variables in config files.
+
+    This parser extends LightningArgumentParser to automatically substitute
+    ${VAR_NAME} patterns with environment variable values before parsing
+    configuration content. This allows config files to use environment
+    variables while still passing Lightning's validation.
+    """
+
+    def parse_string(
+        self,
+        cfg_str: str,
+        cfg_path: str | os.PathLike = "",
+        ext_vars: dict | None = None,
+        env: bool | None = None,
+        defaults: bool = True,
+        with_meta: bool | None = None,
+        **kwargs: Any,
+    ) -> Namespace:
+        """Pre-processes string for environment variable substitution before parsing."""
+        # Substitute environment variables in the config string before parsing
+        substituted_cfg_str = substitute_env_vars_in_string(cfg_str)
+
+        # Call the parent method with the substituted config
+        return super().parse_string(
+            substituted_cfg_str, cfg_path, ext_vars, env, defaults, with_meta, **kwargs
+        )
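
As a usage sketch of the new helper (assuming rslearn 0.0.4 is installed): ${VAR} references in raw config text are resolved from the environment before YAML parsing, and unset variables collapse to empty strings. The variable names below are hypothetical.

import os
from rslearn.arg_parser import substitute_env_vars_in_string

os.environ["DATASET_PATH"] = "/data/rslearn_dataset"  # hypothetical variable
raw_cfg = "data:\n  init_args:\n    path: ${DATASET_PATH}\n    name: ${UNSET_VAR}\n"
print(substitute_env_vars_in_string(raw_cfg))
# data:
#   init_args:
#     path: /data/rslearn_dataset
#     name:   <- unset variables are replaced with the empty string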

{rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/copernicus.py

@@ -34,7 +34,7 @@ from rslearn.utils.geometry import (
     FloatBounds,
     STGeometry,
     flatten_shape,
-
+    split_shape_at_antimeridian,
 )
 from rslearn.utils.grid_index import GridIndex
 from rslearn.utils.raster_format import get_raster_projection_and_bounds
@@ -160,7 +160,7 @@ def get_sentinel2_tile_index() -> dict[str, list[FloatBounds]]:
     # issues where the tile bounds go from -180 to 180 longitude and thus match
     # with anything at the same latitude.
     union_shp = shapely.unary_union(shapes)
-    split_shapes = flatten_shape(
+    split_shapes = flatten_shape(split_shape_at_antimeridian(union_shp))
     bounds_list: list[FloatBounds] = []
     for shp in split_shapes:
         bounds_list.append(shp.bounds)
@@ -222,10 +222,10 @@ def get_sentinel2_tiles(geometry: STGeometry, cache_dir: UPath) -> list[str]:
     """
     tile_index = load_sentinel2_tile_index(cache_dir)
     wgs84_geometry = geometry.to_projection(WGS84_PROJECTION)
-    # If the shape is a collection, it could be cutting across
+    # If the shape is a collection, it could be cutting across antimeridian.
    # So we query each component shape separately and collect the results to avoid
     # issues.
-    # We assume the caller has already applied
+    # We assume the caller has already applied split_at_antimeridian.
     results = set()
     for shp in flatten_shape(wgs84_geometry.shp):
         for result in tile_index.query(shp.bounds):
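
For intuition on the comments above: left unsplit, a footprint that crosses the antimeridian becomes a polygon spanning nearly -180 to 180 longitude, so it (and its bounds) will match anything in the same latitude band. A minimal shapely illustration of the false positive (not rslearn code):

import shapely

# Naive single-polygon representation of a footprint from 179°E to 179°W:
tile = shapely.box(-179.0, 10.0, 179.0, 11.0)  # wrongly covers the long way around
window = shapely.box(0.0, 10.2, 0.5, 10.8)  # nowhere near the real footprint

print(tile.bounds)  # (-179.0, 10.0, 179.0, 11.0)
print(tile.intersects(window))  # True -- the spurious match that splitting avoids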

{rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/earthdaily.py

@@ -82,6 +82,8 @@ class EarthDaily(DataSource, TileStore):
         timeout: timedelta = timedelta(seconds=10),
         skip_items_missing_assets: bool = False,
         cache_dir: UPath | None = None,
+        max_retries: int = 3,
+        retry_backoff_factor: float = 5.0,
         service_name: Literal["platform"] = "platform",
     ):
         """Initialize a new EarthDaily instance.
@@ -99,6 +101,11 @@ class EarthDaily(DataSource, TileStore):
             cache_dir: optional directory to cache items by name, including asset URLs.
                 If not set, there will be no cache and instead STAC requests will be
                 needed each time.
+            max_retries: the maximum number of retry attempts for HTTP requests that fail
+                due to transient errors (e.g., 429, 500, 502, 503, 504 status codes).
+            retry_backoff_factor: backoff factor for exponential retry delays between HTTP
+                request attempts. The delay between retries is calculated using the formula:
+                `(retry_backoff_factor * (2 ** (retry_count - 1)))` seconds.
             service_name: the service name, only "platform" is supported, the other
                 services "legacy" and "internal" are not supported.
         """
@@ -110,6 +117,8 @@ class EarthDaily(DataSource, TileStore):
         self.timeout = timeout
         self.skip_items_missing_assets = skip_items_missing_assets
         self.cache_dir = cache_dir
+        self.max_retries = max_retries
+        self.retry_backoff_factor = retry_backoff_factor
         self.service_name = service_name

         if cache_dir is not None:
@@ -139,6 +148,12 @@ class EarthDaily(DataSource, TileStore):
         if "cache_dir" in d:
             kwargs["cache_dir"] = join_upath(ds_path, d["cache_dir"])

+        if "max_retries" in d:
+            kwargs["max_retries"] = d["max_retries"]
+
+        if "retry_backoff_factor" in d:
+            kwargs["retry_backoff_factor"] = d["retry_backoff_factor"]
+
         simple_optionals = ["query", "sort_by", "sort_ascending"]
         for k in simple_optionals:
             if k in d:
@@ -159,7 +174,12 @@ class EarthDaily(DataSource, TileStore):
         if self.eds_client is not None:
             return self.eds_client, self.client, self.collection

-        self.eds_client = EDSClient(
+        self.eds_client = EDSClient(
+            EDSConfig(
+                max_retries=self.max_retries,
+                retry_backoff_factor=self.retry_backoff_factor,
+            )
+        )

         if self.service_name == "platform":
             self.client = self.eds_client.platform.pystac_client
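
To make the documented backoff formula concrete, the delays with the default retry_backoff_factor of 5.0 work out as follows (plain arithmetic, not library code):

retry_backoff_factor = 5.0
for retry_count in (1, 2, 3):  # up to max_retries=3 by default
    delay = retry_backoff_factor * (2 ** (retry_count - 1))
    print(f"retry {retry_count}: wait {delay:.0f}s")
# retry 1: wait 5s
# retry 2: wait 10s
# retry 3: wait 20s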

{rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/gcp_public_data.py

@@ -26,7 +26,7 @@ from rslearn.data_sources.utils import match_candidate_items_to_window
 from rslearn.log_utils import get_logger
 from rslearn.tile_stores import TileStoreWithLayer
 from rslearn.utils.fsspec import join_upath, open_atomic
-from rslearn.utils.geometry import STGeometry, flatten_shape,
+from rslearn.utils.geometry import STGeometry, flatten_shape, split_at_antimeridian
 from rslearn.utils.raster_format import get_raster_projection_and_bounds

 from .copernicus import get_harmonize_callback, get_sentinel2_tiles
@@ -358,7 +358,7 @@ class Sentinel2(DataSource):
         shp = shapely.box(*bounds)
         sensing_time = row["sensing_time"]
         geometry = STGeometry(WGS84_PROJECTION, shp, (sensing_time, sensing_time))
-        geometry =
+        geometry = split_at_antimeridian(geometry)

         cloud_cover = float(row["cloud_cover"])

@@ -511,7 +511,7 @@ class Sentinel2(DataSource):

         time_range = (product_xml.start_time, product_xml.start_time)
         geometry = STGeometry(WGS84_PROJECTION, product_xml.shp, time_range)
-        geometry =
+        geometry = split_at_antimeridian(geometry)

         # Sometimes the geometry is not valid.
         # We just apply make_valid on it to correct issues.

{rslearn-0.0.3 → rslearn-0.0.4}/rslearn/data_sources/utils.py

@@ -256,23 +256,7 @@ def match_candidate_items_to_window(
         if item_geom.is_global():
             item_geom = geometry
         else:
-
-            # So we first clip the item to the window bounds in the item's
-            # projection, then re-project the item to the window's projection.
-            buffered_window_geom = STGeometry(
-                geometry.projection,
-                geometry.shp.buffer(1),
-                geometry.time_range,
-            )
-            window_shp_in_item_proj = buffered_window_geom.to_projection(
-                item_geom.projection
-            ).shp
-            clipped_item_geom = STGeometry(
-                item_geom.projection,
-                item_geom.shp.intersection(window_shp_in_item_proj),
-                item_geom.time_range,
-            )
-            item_geom = clipped_item_geom.to_projection(geometry.projection)
+            item_geom = item_geom.to_projection(geometry.projection)
         item_shps.append(item_geom.shp)

     if query_config.space_mode == SpaceMode.CONTAINS:

{rslearn-0.0.3 → rslearn-0.0.4}/rslearn/main.py

@@ -13,6 +13,7 @@ from lightning.pytorch.cli import LightningArgumentParser, LightningCLI
 from rasterio.crs import CRS
 from upath import UPath

+from rslearn.arg_parser import RslearnArgumentParser
 from rslearn.config import LayerConfig
 from rslearn.const import WGS84_EPSG
 from rslearn.data_sources import Item, data_source_from_config
@@ -779,7 +780,7 @@ def dataset_build_index() -> None:


 class RslearnLightningCLI(LightningCLI):
-    """LightningCLI that links data.tasks to model.tasks."""
+    """LightningCLI that links data.tasks to model.tasks and supports environment variables."""

     def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
         """Link data.tasks to model.tasks.
@@ -787,6 +788,7 @@ class RslearnLightningCLI(LightningCLI):
         Args:
             parser: the argument parser
         """
+        # Link data.tasks to model.tasks
         parser.link_arguments(
             "data.init_args.task", "model.init_args.task", apply_on="instantiate"
         )
@@ -815,6 +817,12 @@ class RslearnLightningCLI(LightningCLI):
         # sampler as needed.
         c.trainer.use_distributed_sampler = False

+        # For predict, make sure that return_predictions is False.
+        # Otherwise all the predictions would be stored in memory which can lead to
+        # high memory consumption.
+        if subcommand == "predict":
+            c.return_predictions = False
+

 def model_handler() -> None:
     """Handler for any rslearn model X commands."""
@@ -825,6 +833,7 @@ def model_handler() -> None:
         subclass_mode_model=True,
         subclass_mode_data=True,
         save_config_kwargs={"overwrite": True},
+        parser_class=RslearnArgumentParser,
     )
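
Combined with rslearn/arg_parser.py above, passing parser_class=RslearnArgumentParser means configs handed to the rslearn model subcommands can reference environment variables. A standalone sketch, assuming the parser can be constructed outside the CLI (the option and variable names are illustrative):

import os
from rslearn.arg_parser import RslearnArgumentParser

os.environ["RSLEARN_OUTPUT_DIR"] = "/tmp/rslearn-runs"  # hypothetical variable
parser = RslearnArgumentParser()
parser.add_argument("--trainer.default_root_dir", type=str)
cfg = parser.parse_string("trainer:\n  default_root_dir: ${RSLEARN_OUTPUT_DIR}\n")
print(cfg.trainer.default_root_dir)  # /tmp/rslearn-runs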

rslearn-0.0.4/rslearn/models/trunk.py

@@ -0,0 +1,136 @@
+"""Trunk module for decoder."""
+
+from abc import ABC, abstractmethod
+from typing import Any
+
+import torch
+
+from rslearn.log_utils import get_logger
+from rslearn.models.task_embedding import BaseTaskEmbedding
+
+logger = get_logger(__name__)
+
+
+class DecoderTrunkLayer(torch.nn.Module, ABC):
+    """Trunk layer for decoder."""
+
+    def __init__(self) -> None:
+        """Initialize the DecoderTrunkLayer module."""
+        super().__init__()
+
+    @abstractmethod
+    def forward(
+        self, x: torch.Tensor, task_embedding: torch.Tensor | None = None
+    ) -> dict[str, torch.Tensor]:
+        """Forward pass.
+
+        Args:
+            x: input tensor of shape (batch_size, seq_len, dim)
+            task_embedding: task embedding tensor of shape (batch_size, dim), or None
+
+        Returns:
+            dict with key "outputs" (output tensor of shape (batch_size, seq_len, dim))
+            and optionally other keys.
+        """
+
+    @abstractmethod
+    def apply_auxiliary_losses(
+        self, trunk_out: dict[str, Any], outs: dict[str, Any]
+    ) -> None:
+        """Apply auxiliary losses in-place.
+
+        Args:
+            trunk_out: The output of the trunk.
+            outs: The output of the decoders, with key "loss_dict" containing the losses.
+        """
+
+
+class DecoderTrunk(torch.nn.Module):
+    """Trunk module for decoder, including arbitrary layers plus an optional task embedding."""
+
+    def __init__(
+        self,
+        task_embedding: BaseTaskEmbedding | None = None,
+        layers: list[DecoderTrunkLayer] | None = None,
+    ) -> None:
+        """Initialize the DecoderTrunk module.
+
+        Args:
+            task_embedding: Task-specific embedding module, or None if not using task embedding.
+            layers: List of other shared layers. The first one should expect a
+                B x T x C tensor, and the last should output a B x T x C tensor.
+                All layers must output a dict with key "outputs" (output tensor of shape
+                (B, T, C)) and optionally other keys.
+        """
+        super().__init__()
+        self.layers = torch.nn.ModuleList(layers or [])
+        self.task_embedding = task_embedding
+
+        # If we have multiple instances of the same layer class, output keys will get overwritten
+        if layers is not None:
+            types = [type(layer) for layer in layers]
+            if len(set(types)) != len(types):
+                logger.warning(
+                    "Multiple instances of the same layer class found in trunk. "
+                    "Only the keys from the last instance will be used"
+                )
+
+    def register_tasks(self, task_names: list[str]) -> None:
+        """Register tasks.
+
+        Args:
+            task_names: list of task names
+        """
+        if self.task_embedding is not None:
+            self.task_embedding.register_tasks(task_names)
+
+    def forward(
+        self,
+        features: list[torch.tensor],
+        inputs: list[dict[str, Any]],
+    ) -> dict[str, Any]:
+        """Forward pass.
+
+        Args:
+            features: The encoder features, a 1-list of B x C x H x W features.
+            inputs: The original inputs to the encoder.
+
+        Returns:
+            dict with key "outputs" (output tensor of shape (batch_size, seq_len, dim))
+            and optionally other keys from the other layers.
+        """
+        embeds = None
+        if self.task_embedding is not None:
+            embeds = self.task_embedding.compute_embeds(features, inputs)
+            features = self.task_embedding(features, inputs, embeds=embeds)
+
+        if not self.layers:
+            return {"outputs": features}
+
+        assert len(features) == 1, "DecoderTrunk only supports one feature map"
+        x = torch.einsum("bchw->bhwc", features[0])
+        x = torch.flatten(x, start_dim=1, end_dim=2)  # B x T x C, T = HW
+        out = {}
+        for layer in self.layers:
+            layer_out = layer(x, task_embedding=embeds)
+            x = layer_out.pop("outputs")  # unspecified shape
+            out.update(layer_out)
+        x = torch.einsum("btc->bct", x)  # B x C x T
+        x = x.view(*features[0].shape)  # B x C x H x W
+
+        out["outputs"] = [x]
+        return out
+
+    def apply_auxiliary_losses(
+        self, trunk_out: dict[str, Any], outs: dict[str, Any]
+    ) -> None:
+        """Apply auxiliary losses in-place.
+
+        Each layer handles its own auxiliary losses, assuming the loss key is `loss_dict`.
+
+        Args:
+            trunk_out: The output of the trunk.
+            outs: The output of the decoders, with key "loss_dict" containing the losses.
+        """
+        for layer in self.layers:
+            layer.apply_auxiliary_losses(trunk_out, outs)
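
A minimal sketch of how the trunk composes (the ResidualMLP layer is hypothetical, not part of rslearn; it only exists to satisfy the DecoderTrunkLayer contract and show the B x C x H x W -> B x T x C -> B x C x H x W round trip):

import torch
from rslearn.models.trunk import DecoderTrunk, DecoderTrunkLayer

class ResidualMLP(DecoderTrunkLayer):
    """Hypothetical trunk layer: residual MLP applied per token."""

    def __init__(self, dim: int) -> None:
        super().__init__()
        self.mlp = torch.nn.Sequential(
            torch.nn.Linear(dim, dim), torch.nn.GELU(), torch.nn.Linear(dim, dim)
        )

    def forward(self, x, task_embedding=None):
        # Input and output are B x T x C, per the DecoderTrunkLayer contract.
        return {"outputs": x + self.mlp(x)}

    def apply_auxiliary_losses(self, trunk_out, outs):
        pass  # this toy layer adds no auxiliary losses

trunk = DecoderTrunk(layers=[ResidualMLP(dim=64)])
features = [torch.randn(2, 64, 8, 8)]  # 1-list of B x C x H x W encoder features
out = trunk(features, inputs=[{}, {}])
print(out["outputs"][0].shape)  # torch.Size([2, 64, 8, 8])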

rslearn-0.0.4/rslearn/train/callbacks/adapters.py

@@ -0,0 +1,53 @@
+"""Callback to activate/deactivate adapter layers."""
+
+from typing import Any
+
+from lightning.pytorch import LightningModule
+from lightning.pytorch.callbacks import Callback
+from lightning.pytorch.trainer import Trainer
+
+from rslearn.log_utils import get_logger
+
+logger = get_logger(__name__)
+
+
+class ActivateLayers(Callback):
+    """Activates adapter layers on a given epoch.
+
+    By default, at every epoch, every adapter layer is deactivated.
+    To activate an adapter layer, add a selector with the name of the adapter layer
+    and the epoch at which to activate it. Once an adapter layer is activated, it
+    remains active until the end of training.
+    """
+
+    def __init__(self, selectors: list[dict[str, Any]]) -> None:
+        """Initialize the callback.
+
+        Args:
+            selectors: List of selectors to activate.
+                Each selector is a dictionary with the following keys:
+                - "name": Substring selector of modules to activate (str).
+                - "at_epoch": The epoch at which to activate (int).
+        """
+        self.selectors = selectors
+
+    def on_train_epoch_start(
+        self,
+        trainer: Trainer,
+        pl_module: LightningModule,
+    ) -> None:
+        """Activate adapter layers on a given epoch.
+
+        Adapter layers are activated/deactivated by setting the `active` attribute.
+
+        Args:
+            trainer: The trainer object.
+            pl_module: The LightningModule object.
+        """
+        status = {}
+        for name, module in pl_module.named_modules():
+            for selector in self.selectors:
+                if selector["name"] in name:
+                    module.active = trainer.current_epoch >= selector["at_epoch"]
+                    status[selector["name"]] = "active" if module.active else "inactive"
+        logger.info(f"Updated adapter status: {status}")