rslearn 0.0.8__py3-none-any.whl → 0.0.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/data_sources/local_files.py +20 -3
- rslearn/data_sources/planetary_computer.py +79 -14
- rslearn/dataset/manage.py +2 -2
- rslearn/dataset/materialize.py +21 -2
- rslearn/dataset/remap.py +29 -4
- rslearn/models/dinov3.py +12 -11
- rslearn/models/galileo/galileo.py +58 -12
- rslearn/models/galileo/single_file_galileo.py +7 -1
- rslearn/models/olmoearth_pretrain/__init__.py +1 -0
- rslearn/models/olmoearth_pretrain/model.py +203 -0
- rslearn/models/olmoearth_pretrain/norm.py +84 -0
- rslearn/models/pooling_decoder.py +43 -0
- rslearn/models/presto/presto.py +11 -0
- rslearn/models/prithvi.py +11 -0
- rslearn/models/registry.py +19 -2
- rslearn/tile_stores/default.py +3 -1
- rslearn/train/transforms/transform.py +23 -6
- rslearn/utils/raster_format.py +37 -4
- rslearn/utils/vector_format.py +35 -4
- {rslearn-0.0.8.dist-info → rslearn-0.0.11.dist-info}/METADATA +3 -2
- {rslearn-0.0.8.dist-info → rslearn-0.0.11.dist-info}/RECORD +25 -22
- {rslearn-0.0.8.dist-info → rslearn-0.0.11.dist-info}/WHEEL +0 -0
- {rslearn-0.0.8.dist-info → rslearn-0.0.11.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.8.dist-info → rslearn-0.0.11.dist-info}/licenses/LICENSE +0 -0
- {rslearn-0.0.8.dist-info → rslearn-0.0.11.dist-info}/top_level.txt +0 -0
rslearn/data_sources/local_files.py
CHANGED

@@ -2,12 +2,12 @@
 
 import functools
 import json
+from collections.abc import Callable
 from typing import Any, Generic, TypeVar
 
 import fiona
 import shapely
 import shapely.geometry
-from class_registry import ClassRegistry
 from rasterio.crs import CRS
 from upath import UPath
 

@@ -23,7 +23,24 @@ from rslearn.utils.geometry import Projection, STGeometry, get_global_geometry
 from .data_source import DataSource, Item, QueryConfig
 
 logger = get_logger("__name__")
-
+_ImporterT = TypeVar("_ImporterT", bound="Importer")
+
+
+class _ImporterRegistry(dict[str, type["Importer"]]):
+    """Registry for Importer classes."""
+
+    def register(self, name: str) -> Callable[[type[_ImporterT]], type[_ImporterT]]:
+        """Decorator to register an importer class."""
+
+        def decorator(cls: type[_ImporterT]) -> type[_ImporterT]:
+            self[name] = cls
+            return cls
+
+        return decorator
+
+
+Importers = _ImporterRegistry()
+
 
 ItemType = TypeVar("ItemType", bound=Item)
 LayerConfigType = TypeVar("LayerConfigType", bound=LayerConfig)

@@ -425,7 +442,7 @@ class LocalFiles(DataSource):
         """
         self.config = config
 
-        self.importer = Importers[config.layer_type.value]
+        self.importer = Importers[config.layer_type.value]()
         self.src_dir = src_dir
 
     @staticmethod
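Note: the change above (repeated in several files below) replaces the external class_registry dependency with a small in-repo registry: a dict subclass whose register() method is a decorator, with registered classes looked up by name and now instantiated at use time (Importers[...]() rather than Importers[...]). A minimal, self-contained sketch of the same pattern; the _Registry and RasterImporter names here are illustrative, not the actual rslearn classes:

from collections.abc import Callable
from typing import TypeVar

_T = TypeVar("_T")


class _Registry(dict[str, type]):
    """Map registered names to classes."""

    def register(self, name: str) -> Callable[[type[_T]], type[_T]]:
        def decorator(cls: type[_T]) -> type[_T]:
            self[name] = cls  # store the class itself, not an instance
            return cls

        return decorator


Importers = _Registry()


@Importers.register("raster")
class RasterImporter:  # illustrative stand-in for the real importer
    def import_item(self, path: str) -> str:
        return f"imported {path}"


# The lookup returns the class, so it must be instantiated before use,
# mirroring Importers[config.layer_type.value]() in the diff above.
importer = Importers["raster"]()
print(importer.import_item("data/scene.tif"))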
rslearn/data_sources/planetary_computer.py
CHANGED

@@ -83,6 +83,10 @@ class PlanetaryComputer(DataSource, TileStore):
 
     STAC_ENDPOINT = "https://planetarycomputer.microsoft.com/api/stac/v1"
 
+    # Default threshold for recreating the STAC client to prevent memory leaks
+    # from the pystac Catalog's resolved objects cache growing unbounded
+    DEFAULT_MAX_ITEMS_PER_CLIENT = 1000
+
     def __init__(
         self,
         collection_name: str,

@@ -93,6 +97,7 @@ class PlanetaryComputer(DataSource, TileStore):
         timeout: timedelta = timedelta(seconds=10),
         skip_items_missing_assets: bool = False,
         cache_dir: UPath | None = None,
+        max_items_per_client: int | None = None,
     ):
         """Initialize a new PlanetaryComputer instance.
 

@@ -109,6 +114,9 @@ class PlanetaryComputer(DataSource, TileStore):
             cache_dir: optional directory to cache items by name, including asset URLs.
                 If not set, there will be no cache and instead STAC requests will be
                 needed each time.
+            max_items_per_client: number of STAC items to process before recreating
+                the client to prevent memory leaks from the resolved objects cache.
+                Defaults to DEFAULT_MAX_ITEMS_PER_CLIENT.
         """
         self.collection_name = collection_name
         self.asset_bands = asset_bands

@@ -118,12 +126,15 @@ class PlanetaryComputer(DataSource, TileStore):
         self.timeout = timeout
         self.skip_items_missing_assets = skip_items_missing_assets
         self.cache_dir = cache_dir
+        self.max_items_per_client = (
+            max_items_per_client or self.DEFAULT_MAX_ITEMS_PER_CLIENT
+        )
 
         if self.cache_dir is not None:
             self.cache_dir.mkdir(parents=True, exist_ok=True)
 
         self.client: pystac_client.Client | None = None
-        self.
+        self._client_item_count = 0
 
     @staticmethod
     def from_config(config: RasterLayerConfig, ds_path: UPath) -> "PlanetaryComputer":

@@ -142,7 +153,12 @@ class PlanetaryComputer(DataSource, TileStore):
         if "cache_dir" in d:
             kwargs["cache_dir"] = join_upath(ds_path, d["cache_dir"])
 
-        simple_optionals = [
+        simple_optionals = [
+            "query",
+            "sort_by",
+            "sort_ascending",
+            "max_items_per_client",
+        ]
         for k in simple_optionals:
             if k in d:
                 kwargs[k] = d[k]

@@ -151,20 +167,40 @@ class PlanetaryComputer(DataSource, TileStore):
 
     def _load_client(
         self,
-    ) ->
+    ) -> pystac_client.Client:
         """Lazily load pystac client.
 
         We don't load it when creating the data source because it takes time and caller
         may not be calling get_items. Additionally, loading it during the get_items
         call enables leveraging the retry loop functionality in
         prepare_dataset_windows.
-        """
-        if self.client is not None:
-            return self.client, self.collection
 
+        Note: We periodically recreate the client to prevent memory leaks from the
+        pystac Catalog's resolved objects cache, which grows unbounded as STAC items
+        are deserialized and cached. The cache cannot be cleared or disabled.
+        """
+        if self.client is None:
+            logger.info("Creating initial STAC client")
+            self.client = pystac_client.Client.open(self.STAC_ENDPOINT)
+            return self.client
+
+        if self._client_item_count < self.max_items_per_client:
+            return self.client
+
+        # Recreate client to clear the resolved objects cache
+        current_client = self.client
+        logger.debug(
+            "Recreating STAC client after processing %d items (threshold: %d)",
+            self._client_item_count,
+            self.max_items_per_client,
+        )
+        client_root = current_client.get_root()
+        client_root.clear_links()
+        client_root.clear_items()
+        client_root.clear_children()
+        self._client_item_count = 0
         self.client = pystac_client.Client.open(self.STAC_ENDPOINT)
-
-        return self.client, self.collection
+        return self.client
 
     def _stac_item_to_item(self, stac_item: pystac.Item) -> PlanetaryComputerItem:
         shp = shapely.geometry.shape(stac_item.geometry)

@@ -210,10 +246,26 @@ class PlanetaryComputer(DataSource, TileStore):
 
         # No cache or not in cache, so we need to make the STAC request.
         logger.debug("Getting STAC item {name}")
-
-
+        client = self._load_client()
+
+        search_result = client.search(ids=[name], collections=[self.collection_name])
+        stac_items = list(search_result.items())
+
+        if not stac_items:
+            raise ValueError(
+                f"Item {name} not found in collection {self.collection_name}"
+            )
+        if len(stac_items) > 1:
+            raise ValueError(
+                f"Multiple items found for ID {name} in collection {self.collection_name}"
+            )
+
+        stac_item = stac_items[0]
         item = self._stac_item_to_item(stac_item)
 
+        # Track items processed for client recreation threshold (after deserialization)
+        self._client_item_count += 1
+
         # Finally we cache it if cache_dir is set.
         if cache_fname is not None:
             with cache_fname.open("w") as f:

@@ -233,7 +285,7 @@ class PlanetaryComputer(DataSource, TileStore):
         Returns:
             List of groups of items that should be retrieved for each geometry.
         """
-        client
+        client = self._load_client()
 
         groups = []
         for geometry in geometries:

@@ -247,7 +299,9 @@ class PlanetaryComputer(DataSource, TileStore):
                 datetime=wgs84_geometry.time_range,
                 query=self.query,
             )
-            stac_items = [item for item in result.
+            stac_items = [item for item in result.items()]
+            # Track items processed for client recreation threshold (after deserialization)
+            self._client_item_count += len(stac_items)
             logger.debug("STAC search yielded %d items", len(stac_items))
 
             if self.skip_items_missing_assets:
@@ -580,7 +634,13 @@ class Sentinel2(PlanetaryComputer):
         if "cache_dir" in d:
             kwargs["cache_dir"] = join_upath(ds_path, d["cache_dir"])
 
-        simple_optionals = [
+        simple_optionals = [
+            "harmonize",
+            "query",
+            "sort_by",
+            "sort_ascending",
+            "max_items_per_client",
+        ]
         for k in simple_optionals:
             if k in d:
                 kwargs[k] = d[k]

@@ -756,7 +816,12 @@ class Sentinel1(PlanetaryComputer):
         if "cache_dir" in d:
             kwargs["cache_dir"] = join_upath(ds_path, d["cache_dir"])
 
-        simple_optionals = [
+        simple_optionals = [
+            "query",
+            "sort_by",
+            "sort_ascending",
+            "max_items_per_client",
+        ]
         for k in simple_optionals:
             if k in d:
                 kwargs[k] = d[k]
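Note: the max_items_per_client logic above keeps a count of processed STAC items and rebuilds the client once the threshold is reached, so the cached resolved objects are released. A rough standalone sketch of that counting pattern, assuming only what the diff shows; make_client is a hypothetical factory standing in for pystac_client.Client.open:

class ThresholdedClient:
    """Recreate an underlying client every `max_items` processed items."""

    def __init__(self, make_client, max_items: int = 1000):
        self._make_client = make_client  # hypothetical factory, e.g. lambda: Client.open(url)
        self._max_items = max_items
        self._client = None
        self._item_count = 0

    def get(self):
        if self._client is None or self._item_count >= self._max_items:
            # Dropping the old client lets its internal item cache be garbage collected.
            self._client = self._make_client()
            self._item_count = 0
        return self._client

    def record_items(self, n: int) -> None:
        # Call after deserializing n items from a search or lookup.
        self._item_count += n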
rslearn/dataset/manage.py
CHANGED

@@ -396,9 +396,9 @@ def materialize_window(
     )
 
     if dataset.materializer_name:
-        materializer = Materializers[dataset.materializer_name]
+        materializer = Materializers[dataset.materializer_name]()
     else:
-        materializer = Materializers[layer_cfg.layer_type.value]
+        materializer = Materializers[layer_cfg.layer_type.value]()
 
     retry(
         fn=lambda: materializer.materialize(
rslearn/dataset/materialize.py
CHANGED

@@ -1,10 +1,10 @@
 """Classes to implement dataset materialization."""
 
+from collections.abc import Callable
 from typing import Any, Generic, TypeVar
 
 import numpy as np
 import numpy.typing as npt
-from class_registry import ClassRegistry
 from rasterio.enums import Resampling
 
 from rslearn.config import (

@@ -25,7 +25,26 @@ from rslearn.utils.vector_format import load_vector_format
 from .remap import Remapper, load_remapper
 from .window import Window
 
-
+_MaterializerT = TypeVar("_MaterializerT", bound="Materializer")
+
+
+class _MaterializerRegistry(dict[str, type["Materializer"]]):
+    """Registry for Materializer classes."""
+
+    def register(
+        self, name: str
+    ) -> Callable[[type[_MaterializerT]], type[_MaterializerT]]:
+        """Decorator to register a materializer class."""
+
+        def decorator(cls: type[_MaterializerT]) -> type[_MaterializerT]:
+            self[name] = cls
+            return cls
+
+        return decorator
+
+
+Materializers = _MaterializerRegistry()
+
 
 LayerConfigType = TypeVar("LayerConfigType", bound=LayerConfig)
 
rslearn/dataset/remap.py
CHANGED

@@ -1,18 +1,42 @@
 """Classes to remap raster values."""
 
-from
+from collections.abc import Callable
+from typing import Any, TypeVar
 
 import numpy as np
 import numpy.typing as npt
-from class_registry import ClassRegistry
 
-
+_RemapperT = TypeVar("_RemapperT", bound="Remapper")
+
+
+class _RemapperRegistry(dict[str, type["Remapper"]]):
+    """Registry for Remapper classes."""
+
+    def register(self, name: str) -> Callable[[type[_RemapperT]], type[_RemapperT]]:
+        """Decorator to register a remapper class."""
+
+        def decorator(cls: type[_RemapperT]) -> type[_RemapperT]:
+            self[name] = cls
+            return cls
+
+        return decorator
+
+
+Remappers = _RemapperRegistry()
 """Registry of Remapper implementations."""
 
 
 class Remapper:
     """An abstract class that remaps pixel values based on layer configuration."""
 
+    def __init__(self, config: dict[str, Any]) -> None:
+        """Initialize a Remapper.
+
+        Args:
+            config: the config dict for this remapper.
+        """
+        pass
+
     def __call__(
         self, array: npt.NDArray[Any], dtype: npt.DTypeLike
     ) -> npt.NDArray[Any]:

@@ -67,4 +91,5 @@ class LinearRemapper(Remapper):
 
 def load_remapper(config: dict[str, Any]) -> Remapper:
     """Load a remapper from a configuration dictionary."""
-
+    cls = Remappers[config["name"]]
+    return cls(config)
rslearn/models/dinov3.py
CHANGED

@@ -7,8 +7,8 @@ from typing import Any
 import torch
 import torchvision
 from einops import rearrange
-from torchvision.transforms import v2
 
+from rslearn.train.transforms.normalize import Normalize
 from rslearn.train.transforms.transform import Transform
 
 

@@ -139,15 +139,17 @@ class DinoV3Normalize(Transform):
         super().__init__()
         self.satellite = satellite
         if satellite:
-
-
-                std=(0.213, 0.156, 0.143),
-            )
+            mean = [0.430, 0.411, 0.296]
+            std = [0.213, 0.156, 0.143]
         else:
-
-
-
-
+            mean = [0.485, 0.456, 0.406]
+            std = [0.229, 0.224, 0.225]
+
+        self.normalize = Normalize(
+            [value * 255 for value in mean],
+            [value * 255 for value in std],
+            num_bands=3,
+        )
 
     def forward(
         self, input_dict: dict[str, Any], target_dict: dict[str, Any]

@@ -161,5 +163,4 @@ class DinoV3Normalize(Transform):
         Returns:
             normalized (input_dicts, target_dicts) tuple
         """
-
-        return input_dict, target_dict
+        return self.normalize(input_dict, target_dict)
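Note: DinoV3Normalize now stores means and standard deviations defined for inputs scaled to [0, 1] and multiplies them by 255 so they can be applied directly to 0-255 pixel values; the two formulations are equivalent. A quick numeric check of that equivalence (illustrative values only, not taken from the package):

import numpy as np

mean, std = 0.485, 0.229           # stats defined for inputs in [0, 1]
x = np.array([0.0, 127.5, 255.0])  # raw 0-255 pixel values

a = (x / 255.0 - mean) / std        # rescale the image to [0, 1] first
b = (x - mean * 255) / (std * 255)  # rescale the stats instead, as in the diff
assert np.allclose(a, b)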
rslearn/models/galileo/galileo.py
CHANGED

@@ -2,6 +2,7 @@
 
 import math
 import tempfile
+from contextlib import nullcontext
 from enum import StrEnum
 from typing import Any, cast
 

@@ -63,6 +64,11 @@ pretrained_weights: dict[GalileoSize, str] = {
 
 DEFAULT_NORMALIZER = Normalizer()
 
+AUTOCAST_DTYPE_MAP = {
+    "bfloat16": torch.bfloat16,
+    "float32": torch.float32,
+}
+
 
 class GalileoModel(nn.Module):
     """Galileo backbones."""

@@ -85,6 +91,7 @@ class GalileoModel(nn.Module):
         size: GalileoSize,
         patch_size: int = 4,
         pretrained_path: str | UPath | None = None,
+        autocast_dtype: str | None = "bfloat16",
     ) -> None:
         """Initialize the Galileo model.
 

@@ -93,6 +100,7 @@ class GalileoModel(nn.Module):
             patch_size: The patch size to use.
             pretrained_path: the local path to the pretrained weights. Otherwise it is
                 downloaded and cached in temp directory.
+            autocast_dtype: which dtype to use for autocasting, or set None to disable.
         """
         super().__init__()
         if pretrained_path is None:

@@ -128,8 +136,14 @@ class GalileoModel(nn.Module):
             idx for idx, key in enumerate(SPACE_TIME_BANDS_GROUPS_IDX) if "S1" in key
         ]
 
+        self.size = size
         self.patch_size = patch_size
 
+        if autocast_dtype is not None:
+            self.autocast_dtype = AUTOCAST_DTYPE_MAP[autocast_dtype]
+        else:
+            self.autocast_dtype = None
+
     @staticmethod
     def to_cartesian(
         lat: float | np.ndarray | torch.Tensor, lon: float | np.ndarray | torch.Tensor

@@ -484,18 +498,31 @@ class GalileoModel(nn.Module):
             patch_size = h
         else:
             patch_size = self.patch_size
-
-
-
-
-
-
-
-
-
-
-
-
+
+        # Decide context based on self.autocast_dtype.
+        device = galileo_input.s_t_x.device
+        if self.autocast_dtype is None:
+            context = nullcontext()
+        else:
+            assert device is not None
+            context = torch.amp.autocast(
+                device_type=device.type, dtype=self.autocast_dtype
+            )
+
+        with context:
+            outputs = self.model(
+                s_t_x=galileo_input.s_t_x,
+                s_t_m=galileo_input.s_t_m,
+                sp_x=galileo_input.sp_x,
+                sp_m=galileo_input.sp_m,
+                t_x=galileo_input.t_x,
+                t_m=galileo_input.t_m,
+                st_x=galileo_input.st_x,
+                st_m=galileo_input.st_m,
+                months=galileo_input.months,
+                patch_size=patch_size,
+            )
+
         if h == patch_size:
             # only one spatial patch, so we can just take an average
             # of all the tokens to output b c_g 1 1

@@ -515,3 +542,22 @@ class GalileoModel(nn.Module):
                 "b h w c_g d -> b c_g d h w",
             ).mean(dim=1)
         ]
+
+    def get_backbone_channels(self) -> list:
+        """Returns the output channels of this model when used as a backbone.
+
+        The output channels is a list of (patch_size, depth) that corresponds
+        to the feature maps that the backbone returns.
+
+        Returns:
+            the output channels of the backbone as a list of (patch_size, depth) tuples.
+        """
+        if self.size == GalileoSize.BASE:
+            depth = 768
+        elif self.model_size == GalileoSize.TINY:
+            depth = 192
+        elif self.model_size == GalileoSize.NANO:
+            depth = 128
+        else:
+            raise ValueError(f"Invalid model size: {self.size}")
+        return [(self.patch_size, depth)]
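Note: both the Galileo wrapper above and the OlmoEarth wrapper below select between torch.amp.autocast and a no-op context at runtime, depending on whether an autocast dtype was configured. A condensed sketch of that pattern; the tiny Linear model and run_with_optional_autocast helper are for illustration only:

from contextlib import nullcontext

import torch


def run_with_optional_autocast(model, x, autocast_dtype=torch.bfloat16):
    # Use a no-op context manager when autocasting is disabled.
    if autocast_dtype is None:
        context = nullcontext()
    else:
        context = torch.amp.autocast(device_type=x.device.type, dtype=autocast_dtype)
    with context:
        return model(x)


model = torch.nn.Linear(4, 2)
out = run_with_optional_autocast(model, torch.randn(1, 4), autocast_dtype=None)
print(out.shape)  # torch.Size([1, 2])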
rslearn/models/galileo/single_file_galileo.py
CHANGED

@@ -1469,7 +1469,13 @@ class Encoder(GalileoBase):
                # we take the inverse of the mask because a value
                # of True indicates the value *should* take part in
                # attention
-
+                temp_mask = ~new_m.bool()
+                if temp_mask.all():
+                    # if all the tokens are used in attention we can pass a None mask
+                    # to the attention block
+                    temp_mask = None
+
+                x = blk(x=x, y=None, attn_mask=temp_mask)
 
                 if exit_ids_seq is not None:
                     assert exited_tokens is not None
rslearn/models/olmoearth_pretrain/__init__.py
ADDED

@@ -0,0 +1 @@
+"""OlmoEarth model architecture."""
rslearn/models/olmoearth_pretrain/model.py
ADDED

@@ -0,0 +1,203 @@
+"""OlmoEarth model wrapper for fine-tuning in rslearn."""
+
+import json
+from contextlib import nullcontext
+from typing import Any
+
+import torch
+from einops import rearrange
+from olmo_core.config import Config
+from olmo_core.distributed.checkpoint import load_model_and_optim_state
+from olmoearth_pretrain.data.constants import Modality
+from olmoearth_pretrain.nn.flexihelios import Encoder, TokensAndMasks
+from olmoearth_pretrain.train.masking import MaskedOlmoEarthSample, MaskValue
+from upath import UPath
+
+from rslearn.log_utils import get_logger
+
+logger = get_logger(__name__)
+
+MODALITY_NAMES = [
+    "sentinel2_l2a",
+    "sentinel1",
+    "worldcover",
+    "openstreetmap_raster",
+    "landsat",
+]
+
+AUTOCAST_DTYPE_MAP = {
+    "bfloat16": torch.bfloat16,
+    "float16": torch.float16,
+    "float32": torch.float32,
+}
+
+
+class OlmoEarth(torch.nn.Module):
+    """A wrapper to support the OlmoEarth model."""
+
+    def __init__(
+        self,
+        # TODO: we should accept model ID instead of checkpoint_path once we are closer
+        # to being ready for release.
+        checkpoint_path: str,
+        selector: list[str | int] = [],
+        forward_kwargs: dict[str, Any] = {},
+        random_initialization: bool = False,
+        embedding_size: int | None = None,
+        patch_size: int | None = None,
+        autocast_dtype: str | None = "bfloat16",
+    ):
+        """Create a new OlmoEarth model.
+
+        Args:
+            checkpoint_path: the checkpoint directory to load. It should contain
+                config.json file as well as model_and_optim folder.
+            selector: an optional sequence of attribute names or list indices to select
+                the sub-module that should be applied on the input images.
+            forward_kwargs: additional arguments to pass to forward pass besides the
+                MaskedOlmoEarthSample.
+            random_initialization: whether to skip loading the checkpoint so the
+                weights are randomly initialized. In this case, the checkpoint is only
+                used to define the model architecture.
+            embedding_size: optional embedding size to report via
+                get_backbone_channels.
+            patch_size: optional patch size to report via get_backbone_channels.
+            autocast_dtype: which dtype to use for autocasting, or set None to disable.
+        """
+        super().__init__()
+        _checkpoint_path = UPath(checkpoint_path)
+        self.forward_kwargs = forward_kwargs
+        self.embedding_size = embedding_size
+        self.patch_size = patch_size
+
+        if autocast_dtype is not None:
+            self.autocast_dtype = AUTOCAST_DTYPE_MAP[autocast_dtype]
+        else:
+            self.autocast_dtype = None
+
+        # Load the model config and initialize it.
+        # We avoid loading the train module here because it depends on running within
+        # olmo_core.
+        with (_checkpoint_path / "config.json").open() as f:
+            config_dict = json.load(f)
+        model_config = Config.from_dict(config_dict["model"])
+
+        model = model_config.build()
+
+        # Load the checkpoint.
+        if not random_initialization:
+            train_module_dir = _checkpoint_path / "model_and_optim"
+            if train_module_dir.exists():
+                load_model_and_optim_state(str(train_module_dir), model)
+                logger.info(f"loaded OlmoEarth encoder from {train_module_dir}")
+            else:
+                logger.info(f"could not find OlmoEarth encoder at {train_module_dir}")
+        else:
+            logger.info("skipping loading OlmoEarth encoder")
+
+        # Select just the portion of the model that we actually want to use.
+        for part in selector:
+            if isinstance(part, str):
+                model = getattr(model, part)
+            else:
+                model = model[part]
+        self.model = model
+
+    def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
+        """Compute feature maps from the OlmoEarth backbone.
+
+        Inputs:
+            inputs: input dicts. It should include keys corresponding to the modalities
+                that should be passed to the OlmoEarth model.
+        """
+        kwargs = {}
+        present_modalities = []
+        device = None
+        # Handle the case where some modalities are multitemporal and some are not.
+        # We assume all multitemporal modalities have the same number of timesteps.
+        max_timesteps = 1
+        for modality in MODALITY_NAMES:
+            if modality not in inputs[0]:
+                continue
+            present_modalities.append(modality)
+            cur = torch.stack([inp[modality] for inp in inputs], dim=0)
+            device = cur.device
+            # Check if it's single or multitemporal, and reshape accordingly
+            num_bands = Modality.get(modality).num_bands
+            num_timesteps = cur.shape[1] // num_bands
+            max_timesteps = max(max_timesteps, num_timesteps)
+            cur = rearrange(cur, "b (t c) h w -> b h w t c", t=num_timesteps)
+            kwargs[modality] = cur
+            # Create mask array which is BHWTS (without channels but with band sets).
+            num_band_sets = len(Modality.get(modality).band_sets)
+            mask_shape = cur.shape[0:4] + (num_band_sets,)
+            mask = (
+                torch.ones(mask_shape, dtype=torch.int32, device=device)
+                * MaskValue.ONLINE_ENCODER.value
+            )
+            kwargs[f"{modality}_mask"] = mask
+
+        # Timestamps is required.
+        # Note that only months (0 to 11) are used in OlmoEarth position encoding.
+        # For now, we assign same timestamps to all inputs, but later we should handle varying timestamps per input.
+        timestamps = torch.zeros(
+            (len(inputs), max_timesteps, 3), dtype=torch.int32, device=device
+        )
+        timestamps[:, :, 0] = 1  # day
+        timestamps[:, :, 1] = torch.arange(max_timesteps, device=device)[
+            None, :
+        ]  # month
+        timestamps[:, :, 2] = 2024  # year
+        kwargs["timestamps"] = timestamps
+
+        sample = MaskedOlmoEarthSample(**kwargs)
+
+        # Decide context based on self.autocast_dtype.
+        if self.autocast_dtype is None:
+            context = nullcontext()
+        else:
+            assert device is not None
+            context = torch.amp.autocast(
+                device_type=device.type, dtype=self.autocast_dtype
+            )
+
+        with context:
+            # Currently we assume the provided model always returns a TokensAndMasks object.
+            tokens_and_masks: TokensAndMasks
+            if isinstance(self.model, Encoder):
+                # Encoder has a fast_pass argument to indicate mask is not needed.
+                tokens_and_masks = self.model(
+                    sample, fast_pass=True, **self.forward_kwargs
+                )["tokens_and_masks"]
+            else:
+                # Other models like STEncoder do not have this option supported.
+                tokens_and_masks = self.model(sample, **self.forward_kwargs)[
+                    "tokens_and_masks"
+                ]
+
+        # Apply temporal/modality pooling so we just have one feature per patch.
+        features = []
+        for modality in present_modalities:
+            modality_features = getattr(tokens_and_masks, modality)
+            # Pool over band sets and timesteps (BHWTSC -> BHWC).
+            pooled = modality_features.mean(dim=[3, 4])
+            # We want BHWC -> BCHW.
+            pooled = rearrange(pooled, "b h w c -> b c h w")
+            features.append(pooled)
+        # Pool over the modalities, so we get one BCHW feature map.
+        pooled = torch.stack(features, dim=0).mean(dim=0)
+        return [pooled]
+
+    def get_backbone_channels(self) -> list:
+        """Returns the output channels of this model when used as a backbone.
+
+        The output channels is a list of (downsample_factor, depth) that corresponds
+        to the feature maps that the backbone returns. For example, an element [2, 32]
+        indicates that the corresponding feature map is 1/2 the input resolution and
+        has 32 channels.
+
+        Returns:
+            the output channels of the backbone as a list of (downsample_factor, depth)
+                tuples.
+        """
+        return [(self.patch_size, self.embedding_size)]
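Note: the forward pass above reduces each modality's token grid from BHWTSC to a BCHW feature map by averaging over timesteps and band sets, then averages across modalities. A small shape-level sketch of that pooling, using random tensors in place of real encoder outputs:

import torch
from einops import rearrange

b, h, w, t, s, c = 2, 8, 8, 3, 2, 16
modality_outputs = [torch.randn(b, h, w, t, s, c) for _ in range(2)]

features = []
for tokens in modality_outputs:
    pooled = tokens.mean(dim=[3, 4])              # BHWTSC -> BHWC
    features.append(rearrange(pooled, "b h w c -> b c h w"))

fused = torch.stack(features, dim=0).mean(dim=0)  # average over modalities
print(fused.shape)  # torch.Size([2, 16, 8, 8])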
rslearn/models/olmoearth_pretrain/norm.py
ADDED

@@ -0,0 +1,84 @@
+"""Normalization transforms."""
+
+import json
+from typing import Any
+
+from olmoearth_pretrain.data.normalize import load_computed_config
+
+from rslearn.log_utils import get_logger
+from rslearn.train.transforms.transform import Transform
+
+logger = get_logger(__file__)
+
+
+class OlmoEarthNormalize(Transform):
+    """Normalize using OlmoEarth JSON config.
+
+    For Sentinel-1 data, the values should be converted to decibels before being passed
+    to this transform.
+    """
+
+    def __init__(
+        self,
+        band_names: dict[str, list[str]],
+        std_multiplier: float | None = 2,
+        config_fname: str | None = None,
+    ) -> None:
+        """Initialize a new OlmoEarthNormalize.
+
+        Args:
+            band_names: map from modality name to the list of bands in that modality in
+                the order they are being loaded. Note that this order must match the
+                expected order for the OlmoEarth model.
+            std_multiplier: the std multiplier matching the one used for the model
+                training in OlmoEarth.
+            config_fname: load the normalization configuration from this file, instead
+                of getting it from OlmoEarth.
+        """
+        super().__init__()
+        self.band_names = band_names
+        self.std_multiplier = std_multiplier
+
+        if config_fname is None:
+            self.norm_config = load_computed_config()
+        else:
+            logger.warning(
+                f"Loading normalization config from {config_fname}. This argument is deprecated and will be removed in a future version."
+            )
+            with open(config_fname) as f:
+                self.norm_config = json.load(f)
+
+    def forward(
+        self, input_dict: dict[str, Any], target_dict: dict[str, Any]
+    ) -> tuple[dict[str, Any], dict[str, Any]]:
+        """Apply normalization over the inputs and targets.
+
+        Args:
+            input_dict: the input
+            target_dict: the target
+
+        Returns:
+            normalized (input_dicts, target_dicts) tuple
+        """
+        for modality_name, cur_band_names in self.band_names.items():
+            band_norms = self.norm_config[modality_name]
+            image = input_dict[modality_name]
+            # Keep a set of indices to make sure that we normalize all of them.
+            needed_band_indices = set(range(image.shape[0]))
+            num_timesteps = image.shape[0] // len(cur_band_names)
+
+            for band, norm_dict in band_norms.items():
+                # If multitemporal, normalize each timestep separately.
+                for t in range(num_timesteps):
+                    band_idx = cur_band_names.index(band) + t * len(cur_band_names)
+                    min_val = norm_dict["mean"] - self.std_multiplier * norm_dict["std"]
+                    max_val = norm_dict["mean"] + self.std_multiplier * norm_dict["std"]
+                    image[band_idx] = (image[band_idx] - min_val) / (max_val - min_val)
+                    needed_band_indices.remove(band_idx)
+
+            if len(needed_band_indices) > 0:
+                raise ValueError(
+                    f"for modality {modality_name}, bands {needed_band_indices} were unexpectedly not normalized"
+                )
+
+        return input_dict, target_dict
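Note: OlmoEarthNormalize rescales each band linearly so that mean - k*std maps to 0 and mean + k*std maps to 1, where k is std_multiplier. A tiny numeric sketch of that formula with made-up statistics:

import numpy as np

mean, std, k = 1200.0, 300.0, 2.0
band = np.array([600.0, 1200.0, 1800.0])

min_val = mean - k * std
max_val = mean + k * std
normalized = (band - min_val) / (max_val - min_val)
print(normalized)  # [0.  0.5 1. ]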
rslearn/models/pooling_decoder.py
CHANGED

@@ -76,3 +76,46 @@ class PoolingDecoder(torch.nn.Module):
         features = torch.amax(features, dim=(2, 3))
         features = self.fc_layers(features)
         return self.output_layer(features)
+
+
+class SegmentationPoolingDecoder(PoolingDecoder):
+    """Like PoolingDecoder, but copy output to all pixels.
+
+    This allows for the model to produce a global output while still being compatible
+    with SegmentationTask. This only makes sense for very small windows, since the
+    output probabilities will be the same at all pixels. The main use case is to train
+    for a classification-like task on small windows, but still produce a raster during
+    inference on large windows.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        image_key: str = "image",
+        **kwargs: Any,
+    ):
+        """Create a new SegmentationPoolingDecoder.
+
+        Args:
+            in_channels: input channels (channels in the last feature map passed to
+                this module)
+            out_channels: channels for the output flat feature vector
+            image_key: the key in inputs for the image from which the expected width
+                and height is derived.
+            kwargs: other arguments to pass to PoolingDecoder.
+        """
+        super().__init__(in_channels=in_channels, out_channels=out_channels, **kwargs)
+        self.image_key = image_key
+
+    def forward(
+        self, features: list[torch.Tensor], inputs: list[dict[str, Any]]
+    ) -> torch.Tensor:
+        """Extend PoolingDecoder forward to upsample the output to a segmentation mask.
+
+        This only works when all of the pixels have the same segmentation target.
+        """
+        output_probs = super().forward(features, inputs)
+        # BC -> BCHW
+        h, w = inputs[0][self.image_key].shape[1:3]
+        return output_probs[:, :, None, None].repeat([1, 1, h, w])
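Note: SegmentationPoolingDecoder turns one pooled class vector per window into a constant-valued mask by broadcasting the BC output over the window's height and width. The broadcasting step in isolation, with dummy values:

import torch

output_probs = torch.tensor([[0.1, 0.9]])  # B=1, C=2 pooled prediction
h, w = 4, 4                                # window size taken from the input image
mask = output_probs[:, :, None, None].repeat([1, 1, h, w])
print(mask.shape)  # torch.Size([1, 2, 4, 4]); every pixel holds the same probabilities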
rslearn/models/presto/presto.py
CHANGED

@@ -248,3 +248,14 @@ class Presto(nn.Module):
             output_features[batch_idx : batch_idx + self.pixel_batch_size] = output_b
 
         return [rearrange(output_features, "(b h w) d -> b d h w", h=h, w=w, b=b)]
+
+    def get_backbone_channels(self) -> list:
+        """Returns the output channels of this model when used as a backbone.
+
+        The output channels is a list of (patch_size, depth) that corresponds
+        to the feature maps that the backbone returns.
+
+        Returns:
+            the output channels of the backbone as a list of (patch_size, depth) tuples.
+        """
+        return [(1, 128)]
rslearn/models/prithvi.py
CHANGED

@@ -173,6 +173,17 @@ class PrithviV2(nn.Module):
             features, num_timesteps
         )
 
+    def get_backbone_channels(self) -> list:
+        """Returns the output channels of this model when used as a backbone.
+
+        The output channels is a list of (patch_size, depth) that corresponds
+        to the feature maps that the backbone returns.
+
+        Returns:
+            the output channels of the backbone as a list of (patch_size, depth) tuples.
+        """
+        return [(1, 1024)]
+
 
 class PrithviNormalize(Transform):
     """Normalize inputs using Prithvi normalization.
rslearn/models/registry.py
CHANGED

@@ -1,5 +1,22 @@
 """Model registry."""
 
-from
+from collections.abc import Callable
+from typing import Any, TypeVar
 
-
+_ModelT = TypeVar("_ModelT")
+
+
+class _ModelRegistry(dict[str, type[Any]]):
+    """Registry for Model classes."""
+
+    def register(self, name: str) -> Callable[[type[_ModelT]], type[_ModelT]]:
+        """Decorator to register a model class."""
+
+        def decorator(cls: type[_ModelT]) -> type[_ModelT]:
+            self[name] = cls
+            return cls
+
+        return decorator
+
+
+Models = _ModelRegistry()
rslearn/tile_stores/default.py
CHANGED

@@ -130,10 +130,12 @@ class DefaultTileStore(TileStore):
         """
        raster_dir = self._get_raster_dir(layer_name, item_name, bands)
        for fname in raster_dir.iterdir():
-            # Ignore completed sentinel files as well as temporary files created by
+            # Ignore completed sentinel files, bands files, as well as temporary files created by
             # open_atomic (in case this tile store is on local filesystem).
             if fname.name == COMPLETED_FNAME:
                 continue
+            if fname.name == BANDS_FNAME:
+                continue
             if ".tmp." in fname.name:
                 continue
             return fname
rslearn/train/transforms/transform.py
CHANGED

@@ -54,7 +54,7 @@ def read_selector(
         the item specified by the selector
     """
     d, selector = get_dict_and_subselector(input_dict, target_dict, selector)
-    parts = selector.split("/")
+    parts = selector.split("/") if selector else []
     cur = d
     for part in parts:
         cur = cur[part]

@@ -76,11 +76,28 @@ def write_selector(
         v: the value to write
     """
     d, selector = get_dict_and_subselector(input_dict, target_dict, selector)
-
-
-
-
-
+    if selector:
+        parts = selector.split("/")
+        cur = d
+        for part in parts[:-1]:
+            cur = cur[part]
+        cur[parts[-1]] = v
+    else:
+        # If the selector references the input or target dictionary directly, then we
+        # have a special case where instead of overwriting with v, we replace the keys
+        # with those in v. v must be a dictionary here, not a tensor, since otherwise
+        # it wouldn't match the type of the input or target dictionary.
+        if not isinstance(v, dict):
+            raise ValueError(
+                "when directly specifying the input or target dict, expected the value to be a dict"
+            )
+        if d == v:
+            # This may happen if the writer did not make a copy of the dictionary. In
+            # this case the code below would not update d correctly since it would also
+            # clear v.
+            return
+        d.clear()
+        d.update(v)
 
 
 class Transform(torch.nn.Module):
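Note: write_selector walks a "/"-separated path into the input or target dict and writes the value at the final key; with an empty selector it now replaces the dict's contents in place instead of assigning. A simplified sketch of the same logic on plain dicts (write_path is an illustrative helper, and the empty-selector branch here uses an identity check rather than the equality check in the diff):

def write_path(d: dict, selector: str, v) -> None:
    if selector:
        parts = selector.split("/")
        cur = d
        for part in parts[:-1]:
            cur = cur[part]
        cur[parts[-1]] = v
    else:
        # Empty selector: replace the dict's keys in place with those of v.
        if d is not v:
            d.clear()
            d.update(v)

data = {"image": {"bands": [1, 2, 3]}}
write_path(data, "image/bands", [4, 5, 6])
print(data)  # {'image': {'bands': [4, 5, 6]}}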
rslearn/utils/raster_format.py
CHANGED

@@ -2,13 +2,13 @@
 
 import hashlib
 import json
-from
+from collections.abc import Callable
+from typing import Any, BinaryIO, TypeVar
 
 import affine
 import numpy as np
 import numpy.typing as npt
 import rasterio
-from class_registry import ClassRegistry
 from PIL import Image
 from rasterio.crs import CRS
 from rasterio.enums import Resampling

@@ -21,7 +21,27 @@ from rslearn.utils.fsspec import open_rasterio_upath_reader, open_rasterio_upath
 
 from .geometry import PixelBounds, Projection
 
-
+_RasterFormatT = TypeVar("_RasterFormatT", bound="RasterFormat")
+
+
+class _RasterFormatRegistry(dict[str, type["RasterFormat"]]):
+    """Registry for RasterFormat classes."""
+
+    def register(
+        self, name: str
+    ) -> Callable[[type[_RasterFormatT]], type[_RasterFormatT]]:
+        """Decorator to register a raster format class."""
+
+        def decorator(cls: type[_RasterFormatT]) -> type[_RasterFormatT]:
+            self[name] = cls
+            return cls
+
+        return decorator
+
+
+RasterFormats = _RasterFormatRegistry()
+
+
 logger = get_logger(__name__)
 
 

@@ -147,6 +167,19 @@ class RasterFormat:
         """
         raise NotImplementedError
 
+    @staticmethod
+    def from_config(name: str, config: dict[str, Any]) -> "RasterFormat":
+        """Create a RasterFormat from a config dict.
+
+        Args:
+            name: the name of this format
+            config: the config dict
+
+        Returns:
+            the RasterFormat instance
+        """
+        raise NotImplementedError
+
 
 @RasterFormats.register("image_tile")
 class ImageTileRasterFormat(RasterFormat):

@@ -716,5 +749,5 @@ def load_raster_format(config: RasterFormatConfig) -> RasterFormat:
     Returns:
         the loaded RasterFormat implementation
     """
-    cls = RasterFormats
+    cls = RasterFormats[config.name]
     return cls.from_config(config.name, config.config_dict)
rslearn/utils/vector_format.py
CHANGED

@@ -1,11 +1,11 @@
 """Classes for writing vector data to a UPath."""
 
 import json
+from collections.abc import Callable
 from enum import Enum
-from typing import Any
+from typing import Any, TypeVar
 
 import shapely
-from class_registry import ClassRegistry
 from rasterio.crs import CRS
 from upath import UPath
 

@@ -18,7 +18,25 @@ from .feature import Feature
 from .geometry import PixelBounds, Projection, STGeometry, safely_reproject_and_clip
 
 logger = get_logger(__name__)
-
+_VectorFormatT = TypeVar("_VectorFormatT", bound="VectorFormat")
+
+
+class _VectorFormatRegistry(dict[str, type["VectorFormat"]]):
+    """Registry for VectorFormat classes."""
+
+    def register(
+        self, name: str
+    ) -> Callable[[type[_VectorFormatT]], type[_VectorFormatT]]:
+        """Decorator to register a vector format class."""
+
+        def decorator(cls: type[_VectorFormatT]) -> type[_VectorFormatT]:
+            self[name] = cls
+            return cls
+
+        return decorator
+
+
+VectorFormats = _VectorFormatRegistry()
 
 
 class VectorFormat:

@@ -53,6 +71,19 @@ class VectorFormat:
         """
         raise NotImplementedError
 
+    @staticmethod
+    def from_config(name: str, config: dict[str, Any]) -> "VectorFormat":
+        """Create a VectorFormat from a config dict.
+
+        Args:
+            name: the name of this format
+            config: the config dict
+
+        Returns:
+            the VectorFormat instance
+        """
+        raise NotImplementedError
+
 
 @VectorFormats.register("tile")
 class TileVectorFormat(VectorFormat):

@@ -410,5 +441,5 @@ def load_vector_format(config: VectorFormatConfig) -> VectorFormat:
     Returns:
         the loaded VectorFormat implementation
     """
-    cls = VectorFormats
+    cls = VectorFormats[config.name]
     return cls.from_config(config.name, config.config_dict)
{rslearn-0.0.8.dist-info → rslearn-0.0.11.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: rslearn
-Version: 0.0.8
+Version: 0.0.11
 Summary: A library for developing remote sensing datasets and models
 Author: OlmoEarth Team
 License: Apache License

@@ -212,7 +212,6 @@ Requires-Python: >=3.11
 Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: boto3>=1.39
-Requires-Dist: class_registry>=2.1
 Requires-Dist: fiona>=1.10
 Requires-Dist: fsspec>=2025.9.0
 Requires-Dist: jsonargparse>=4.35.0

@@ -244,6 +243,7 @@ Requires-Dist: planetary_computer>=1.0; extra == "extra"
 Requires-Dist: pycocotools>=2.0; extra == "extra"
 Requires-Dist: pystac_client>=0.9; extra == "extra"
 Requires-Dist: rtree>=1.4; extra == "extra"
+Requires-Dist: termcolor>=3.0; extra == "extra"
 Requires-Dist: satlaspretrain_models>=0.3; extra == "extra"
 Requires-Dist: scipy>=1.16; extra == "extra"
 Requires-Dist: terratorch>=1.0.2; extra == "extra"

@@ -284,6 +284,7 @@ Quick links:
 - [Examples](docs/Examples.md) contains more examples, including customizing different
   stages of rslearn with additional code.
 - [DatasetConfig](docs/DatasetConfig.md) documents the dataset configuration file.
+- [ModelConfig](docs/ModelConfig.md) documents the model configuration file.
 
 
 Setup
{rslearn-0.0.8.dist-info → rslearn-0.0.11.dist-info}/RECORD
CHANGED

@@ -20,11 +20,11 @@ rslearn/data_sources/eurocrops.py,sha256=bH_ul45RvNLLvVKU9JjZH1XMenAmgn7Lakq0pXj
 rslearn/data_sources/gcp_public_data.py,sha256=kr9stYo7ZCvz8s4E3wmoY-yAGZoLa_9RCwjS-Q5k9dM,36128
 rslearn/data_sources/geotiff.py,sha256=sFUp919chaX4j6lQytNp__xnMLlDI3Ac3rfB6F8sgZ0,45
 rslearn/data_sources/google_earth_engine.py,sha256=hpkt74ly2lEwjRrDp8FBmGvB3MEw_mQ38Av4rQOR3_w,24246
-rslearn/data_sources/local_files.py,sha256=
+rslearn/data_sources/local_files.py,sha256=d08m6IzrUN_80VfvgpHahMJrv-n6_CI6EIocp6kyDRs,19490
 rslearn/data_sources/openstreetmap.py,sha256=qUSMFiIA_laJkO3meBXf9TmSI7OBD-o3i4JxqllUv3Q,19232
 rslearn/data_sources/planet.py,sha256=F2JoLaQ5Cb3k1cTm0hwSWTL2TPfbaAUMXZ8q4Dy7UlA,10109
 rslearn/data_sources/planet_basemap.py,sha256=wuWM9dHSJMdINfyWb78Zk9i-KvJHTrf9J0Q2gyEyiiA,10450
-rslearn/data_sources/planetary_computer.py,sha256=
+rslearn/data_sources/planetary_computer.py,sha256=uHNYxvnMkmo8zbqIiDRdnkz8LQ7TSs6K39Y1AXjboDI,30392
 rslearn/data_sources/raster_source.py,sha256=b8wo55GhVLxXwx1WYLzeRAlzD_ZkE_P9tnvUOdnsfQE,689
 rslearn/data_sources/usda_cdl.py,sha256=2_V11AhPRgLEGd4U5Pmx3UvE2HWBPbsFXhUIQVRVFeE,7138
 rslearn/data_sources/usgs_landsat.py,sha256=31GmOUfmxwTE6MTiVI4psb-ciVmunuA8cfvqDuvTHPE,19312

@@ -39,9 +39,9 @@ rslearn/dataset/add_windows.py,sha256=pwCEvwLE1jQCoqQxw6CJ-sP46ayWppFa2hGYIB6VVk
 rslearn/dataset/dataset.py,sha256=bjf9nI55j-MF0bIQWSNPjNbpfqnLK4jy-96TAcwO0MM,5214
 rslearn/dataset/handler_summaries.py,sha256=wGnbBpjLWTxVn3UT7j7nPoHlYsaWb9_MVJ5DhU0qWXY,2581
 rslearn/dataset/index.py,sha256=Wni5m6h4gisRB54fPLnCfUrRTEsJ5EvwS0fs9sYc2wg,6025
-rslearn/dataset/manage.py,sha256=
-rslearn/dataset/materialize.py,sha256
-rslearn/dataset/remap.py,sha256=
+rslearn/dataset/manage.py,sha256=mkdBHo1RFGxMx8f9zBT_VmRO_6y8Qb2KfWPPziKWYkg,18062
+rslearn/dataset/materialize.py,sha256=-z47svc_JqGhzkp8kq5Hd9fykWNqFEUCQezo887TWBw,22056
+rslearn/dataset/remap.py,sha256=6MaImsY02GNACpvRM81RvWmjZWRfAHxo_R3Ox6XLF6A,2723
 rslearn/dataset/window.py,sha256=I5RqZ12jlIXhohw4qews1x_I4tSDpml709DZRtLiN24,12546
 rslearn/models/__init__.py,sha256=_vWoF9d2Slah8-6XhYhdU4SRsy_CNxXjCGQTD2yvu3Q,22
 rslearn/models/anysat.py,sha256=3BnaiS1sYB4SnV6qRjHksiz_r9vUuZeGPUO2XUziFA0,7810

@@ -49,7 +49,7 @@ rslearn/models/clip.py,sha256=u5aqYnVB4Jag7o1h8EzPDAc1t2BAHeALA9FcUwP5tfo,2238
 rslearn/models/conv.py,sha256=fWyByeswIOKKzyPmP3erYUlZaKEV0huWHA4CyKTBbfY,1703
 rslearn/models/copernicusfm.py,sha256=3AiORuUre9sZYwydbrDgShwKtxeTLmExp7WQmJtBylg,7842
 rslearn/models/croma.py,sha256=cOazTp3l2PNJltKrmPqD5Gy4pi3CI03-X9G4T10cX2k,9529
-rslearn/models/dinov3.py,sha256=
+rslearn/models/dinov3.py,sha256=GKk5qXZPCEporATJdjaSWsDTfWDlAGRWBplFUJN5nRM,6146
 rslearn/models/faster_rcnn.py,sha256=uaxX6-E1f0BibaA9sorEg3be83C7kTdTc39pC5jRqwE,8286
 rslearn/models/fpn.py,sha256=s3cz29I14FaSuvBvLOcwCrqVsaRBxG5GjLlqap4WgPc,1603
 rslearn/models/module_wrapper.py,sha256=H2zb-8Au4t31kawW_4JEKHsaXFjpYDawb31ZEauKcxU,2728

@@ -57,9 +57,9 @@ rslearn/models/molmo.py,sha256=mVrARBhZciMzOgOOjGB5AHlPIf2iO9IBSJmdyKSl1L8,2061
 rslearn/models/multitask.py,sha256=j2Kiwj_dUiUp_CIUr25bS8HiyeoFlr1PGqjTfpgIGLc,14672
 rslearn/models/panopticon.py,sha256=woNEs53wVc5D-NxbSDEPRZ_mYe8vllnuldmADjvhfDQ,5806
 rslearn/models/pick_features.py,sha256=y8e4tJFhyG7ZuVSElWhQ5-Aer4ZKJCEH9wLGJU7WqGI,1551
-rslearn/models/pooling_decoder.py,sha256=
-rslearn/models/prithvi.py,sha256=
-rslearn/models/registry.py,sha256=
+rslearn/models/pooling_decoder.py,sha256=unr2fSE_QmJHPi3dKtopqMtb1Kn-2h94LgwwAVP9vZg,4437
+rslearn/models/prithvi.py,sha256=SVM3ypJlVTkXQ69pPhB4UeJr87VnmADTCuyV365dbkU,39961
+rslearn/models/registry.py,sha256=yCcrOvLkbn07Xtln1j7hAB_kmGw0MGsiR2TloJq9Bmk,504
 rslearn/models/resize_features.py,sha256=asKXWrLHIBrU6GaAV0Ory9YuK7IK104XjhkB4ljzI3A,1289
 rslearn/models/sam2_enc.py,sha256=gNlPokr7eNxO2KvnzDMXNxYM2WRO0YkQPjR4110n6cw,3508
 rslearn/models/satlaspretrain.py,sha256=YpjXl-uClhTZMDmyhN64Fg3AszzT-ymZgJB0fO9RyoY,2419

@@ -91,8 +91,11 @@ rslearn/models/detr/position_encoding.py,sha256=8FFoBT-Jtgqk7D4qDBTbVLOeAdmjdjtJ
 rslearn/models/detr/transformer.py,sha256=aK4HO7AkCZn7xGHP3Iq91w2iFPVshugOILYAjVjroCw,13971
 rslearn/models/detr/util.py,sha256=NMHhHbkIo7PoBUVbDqa2ZknJBTswmaxFCGHrPtFKnGg,676
 rslearn/models/galileo/__init__.py,sha256=QQa0C29nuPRva0KtGiMHQ2ZB02n9SSwj_wqTKPz18NM,112
-rslearn/models/galileo/galileo.py,sha256=
-rslearn/models/galileo/single_file_galileo.py,sha256=
+rslearn/models/galileo/galileo.py,sha256=jUHA64YvVC3Fz5fevc_9dFJfZaINODRDrhSGLIiOZcw,21115
+rslearn/models/galileo/single_file_galileo.py,sha256=l5tlmmdr2eieHNH-M7rVIvcptkv0Fuk3vKXFW691ezA,56143
+rslearn/models/olmoearth_pretrain/__init__.py,sha256=AjRvbjBdadCdPh-EdvySH76sVAQ8NGQaJt11Tsn1D5I,36
+rslearn/models/olmoearth_pretrain/model.py,sha256=F-B1ym9UZuTPJ0OY15Jwb1TkNtr_EtAUlqI-tr_Z2uo,8352
+rslearn/models/olmoearth_pretrain/norm.py,sha256=rHjFyWkpNLYMx9Ow7TsU-jGm9Sjx7FVf0p4R__ohx2c,3266
 rslearn/models/panopticon_data/sensors/drone.yaml,sha256=xqWS-_QMtJyRoWXJm-igoSur9hAmCFdqkPin8DT5qpw,431
 rslearn/models/panopticon_data/sensors/enmap.yaml,sha256=b2j6bSgYR2yKR9DRm3SPIzSVYlHf51ny_p-1B4B9sB4,13431
 rslearn/models/panopticon_data/sensors/goes.yaml,sha256=o00aoWCYqam0aB1rPmXq1MKe8hsKak_qyBG7BPL27Sc,152

@@ -106,10 +109,10 @@ rslearn/models/panopticon_data/sensors/sentinel2.yaml,sha256=qYJ92x-GHO0ZdCrTtCj
 rslearn/models/panopticon_data/sensors/superdove.yaml,sha256=QpIRyopdV4hAez_EIsDwhGFT4VtTk7UgzQveyc8t8fc,795
 rslearn/models/panopticon_data/sensors/wv23.yaml,sha256=SWYSlkka6UViKAz6YI8aqwQ-Ayo-S5kmNa9rO3iGW6o,1172
 rslearn/models/presto/__init__.py,sha256=eZrB-XKi_vYqZhpyAOwppJi4dRuMtYVAdbq7KRygze0,64
-rslearn/models/presto/presto.py,sha256=
+rslearn/models/presto/presto.py,sha256=8mZnc0jk_r_JikybHQNyyHg6t7JNPmoPmgoivyNf-U8,9177
 rslearn/models/presto/single_file_presto.py,sha256=Kbwp8V7pO8HHM2vlCPpjekQiFiDryW8zQkWmt1g05BY,30381
 rslearn/tile_stores/__init__.py,sha256=o_tWVKu6UwFzZbO9jn_3cmIDqc_Q3qDd6tA9If0T_Qk,2050
-rslearn/tile_stores/default.py,sha256=
+rslearn/tile_stores/default.py,sha256=PYaDNvBxhJTDKJGw0EjDTSE1OKajR7_iJpMbOjj-mE8,15054
 rslearn/tile_stores/tile_store.py,sha256=9AeYduDYPp_Ia2NMlq6osptpz_AFGIOQcLJrqZ_m-z0,10469
 rslearn/train/__init__.py,sha256=fnJyY4aHs5zQqbDKSfXsJZXY_M9fbTsf7dRYaPwZr2M,30
 rslearn/train/data_module.py,sha256=K-nQgnOZn-KGq_G2pVOQFtWRrlWih0212i_bkXZ2bEE,23515

@@ -140,7 +143,7 @@ rslearn/train/transforms/normalize.py,sha256=uyv2hE5hw5B2kCRHa4JIx0tfowm-C7bgumw
 rslearn/train/transforms/pad.py,sha256=EDswS9KYRSloM3DQlbCz6S0WYqFQJvI433qMqTtqrZw,4686
 rslearn/train/transforms/select_bands.py,sha256=uDfD9G8Z4VTt88QZsjj1FB20QEmzSefhKf7uDXYn77M,2441
 rslearn/train/transforms/sentinel1.py,sha256=FrLaYZs2AjqWQCun8DTFtgo1l0xLxqaFKtDNIehtpDg,1913
-rslearn/train/transforms/transform.py,sha256=
+rslearn/train/transforms/transform.py,sha256=n1Qzqix2dVvej-Q7iPzHeOQbqH79IBlvqPoymxhNVpE,4446
 rslearn/utils/__init__.py,sha256=GNvdTUmXakiEMnLdje7k1fe5aC7SFVqP757kbpN6Fzw,558
 rslearn/utils/array.py,sha256=JwZi7o0uj-dftREzJmqrRVR2joIwBikm3Er9KeHVIZU,2402
 rslearn/utils/feature.py,sha256=lsg0WThZDJzo1mrbaL04dXYI5G3x-n5FG9aEjj7uUaI,1649

@@ -150,15 +153,15 @@ rslearn/utils/get_utm_ups_crs.py,sha256=kUrcyjCK7KWvuP1XR-nURPeRqYeRO-3L8QUJ1QTF
 rslearn/utils/grid_index.py,sha256=hRmrtgpqN1pLa-djnZtgSXqKJlbgGyttGnCEmPLD0zo,2347
 rslearn/utils/jsonargparse.py,sha256=JcTKQoZ6jgwag-kSeTIEVBO9AsRj0X1oEJBsoaCazH4,658
 rslearn/utils/mp.py,sha256=XYmVckI5TOQuCKc49NJyirDJyFgvb4AI-gGypG2j680,1399
-rslearn/utils/raster_format.py,sha256=
+rslearn/utils/raster_format.py,sha256=dBTSa8l6Ms9Ndbx9Krgqm9z4RU7j2hwLBkw2w-KibU4,26009
 rslearn/utils/rtree_index.py,sha256=j0Zwrq3pXuAJ-hKpiRFQ7VNtvO3fZYk-Em2uBPAqfx4,6460
 rslearn/utils/spatial_index.py,sha256=eomJAUgzmjir8j9HZnSgQoJHwN9H0wGTjmJkMkLLfsU,762
 rslearn/utils/sqlite_index.py,sha256=YGOJi66544e6JNtfSft6YIlHklFdSJO2duxQ4TJ2iu4,2920
 rslearn/utils/time.py,sha256=2ilSLG94_sxLP3y5RSV5L5CG8CoND_dbdzYEHVtN-I8,387
-rslearn/utils/vector_format.py,sha256=
-rslearn-0.0.
-rslearn-0.0.
-rslearn-0.0.
-rslearn-0.0.
-rslearn-0.0.
-rslearn-0.0.
+rslearn/utils/vector_format.py,sha256=EIChYCL6GLOILS2TO2JBkca1TuaWsSubWv6iRS3P2ds,16139
+rslearn-0.0.11.dist-info/licenses/LICENSE,sha256=_99ZWPoLdlUbqZoSC5DF4ihiNwl5rTEmBaq2fACecdg,11352
+rslearn-0.0.11.dist-info/METADATA,sha256=jwB0ZZ-oLa1Y_1iuZRKCQoB4i3kOFDJ0xSeMTJP7zww,36297
+rslearn-0.0.11.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+rslearn-0.0.11.dist-info/entry_points.txt,sha256=doTBQ57NT7nq-dgYGgTTw6mafcGWb_4PWYtYR4rGm50,46
+rslearn-0.0.11.dist-info/top_level.txt,sha256=XDKo90WBH8P9RQumHxo0giLJsoufT4r9odv-WE6Ahk4,8
+rslearn-0.0.11.dist-info/RECORD,,
{rslearn-0.0.8.dist-info → rslearn-0.0.11.dist-info}/WHEEL
File without changes

{rslearn-0.0.8.dist-info → rslearn-0.0.11.dist-info}/entry_points.txt
File without changes

{rslearn-0.0.8.dist-info → rslearn-0.0.11.dist-info}/licenses/LICENSE
File without changes

{rslearn-0.0.8.dist-info → rslearn-0.0.11.dist-info}/top_level.txt
File without changes