rslearn 0.0.2__py3-none-any.whl → 0.0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/config/__init__.py +2 -2
- rslearn/config/dataset.py +156 -99
- rslearn/const.py +9 -15
- rslearn/data_sources/aws_landsat.py +216 -70
- rslearn/data_sources/aws_open_data.py +64 -78
- rslearn/data_sources/aws_sentinel1.py +142 -0
- rslearn/data_sources/climate_data_store.py +303 -0
- rslearn/data_sources/copernicus.py +830 -45
- rslearn/data_sources/data_source.py +2 -4
- rslearn/data_sources/earthdaily.py +489 -0
- rslearn/data_sources/earthdata_srtm.py +300 -0
- rslearn/data_sources/gcp_public_data.py +435 -159
- rslearn/data_sources/google_earth_engine.py +443 -106
- rslearn/data_sources/local_files.py +97 -74
- rslearn/data_sources/openstreetmap.py +6 -16
- rslearn/data_sources/planet.py +7 -26
- rslearn/data_sources/planet_basemap.py +52 -59
- rslearn/data_sources/planetary_computer.py +764 -0
- rslearn/data_sources/raster_source.py +9 -305
- rslearn/data_sources/usda_cdl.py +206 -0
- rslearn/data_sources/usgs_landsat.py +84 -56
- rslearn/data_sources/utils.py +250 -62
- rslearn/data_sources/worldcereal.py +456 -0
- rslearn/data_sources/worldcover.py +142 -0
- rslearn/data_sources/worldpop.py +156 -0
- rslearn/data_sources/xyz_tiles.py +141 -78
- rslearn/dataset/__init__.py +8 -2
- rslearn/dataset/dataset.py +28 -3
- rslearn/dataset/index.py +173 -0
- rslearn/dataset/manage.py +129 -46
- rslearn/dataset/materialize.py +429 -109
- rslearn/dataset/window.py +221 -35
- rslearn/main.py +235 -78
- rslearn/models/clip.py +2 -2
- rslearn/models/conv.py +7 -7
- rslearn/models/croma.py +270 -0
- rslearn/models/detr/__init__.py +5 -0
- rslearn/models/detr/box_ops.py +103 -0
- rslearn/models/detr/detr.py +493 -0
- rslearn/models/detr/matcher.py +107 -0
- rslearn/models/detr/position_encoding.py +114 -0
- rslearn/models/detr/transformer.py +429 -0
- rslearn/models/detr/util.py +24 -0
- rslearn/models/faster_rcnn.py +8 -0
- rslearn/models/module_wrapper.py +48 -0
- rslearn/models/moe/distributed.py +262 -0
- rslearn/models/moe/soft.py +676 -0
- rslearn/models/molmo.py +2 -2
- rslearn/models/multitask.py +351 -24
- rslearn/models/pick_features.py +15 -2
- rslearn/models/simple_time_series.py +14 -4
- rslearn/models/singletask.py +8 -4
- rslearn/models/ssl4eo_s12.py +1 -1
- rslearn/models/task_embedding.py +250 -0
- rslearn/models/terramind.py +219 -0
- rslearn/models/trunk.py +280 -0
- rslearn/models/unet.py +17 -3
- rslearn/models/use_croma.py +508 -0
- rslearn/py.typed +0 -0
- rslearn/tile_stores/__init__.py +52 -18
- rslearn/tile_stores/default.py +382 -0
- rslearn/tile_stores/tile_store.py +241 -149
- rslearn/train/callbacks/freeze_unfreeze.py +29 -17
- rslearn/train/callbacks/gradients.py +109 -0
- rslearn/train/callbacks/peft.py +116 -0
- rslearn/train/data_module.py +402 -10
- rslearn/train/dataset.py +695 -192
- rslearn/train/lightning_module.py +151 -45
- rslearn/train/optimizer.py +31 -0
- rslearn/train/prediction_writer.py +224 -78
- rslearn/train/scheduler.py +62 -0
- rslearn/train/tasks/classification.py +4 -4
- rslearn/train/tasks/detection.py +27 -22
- rslearn/train/tasks/multi_task.py +23 -8
- rslearn/train/tasks/regression.py +105 -17
- rslearn/train/tasks/segmentation.py +330 -16
- rslearn/train/tasks/task.py +2 -2
- rslearn/train/transforms/concatenate.py +3 -3
- rslearn/train/transforms/crop.py +2 -2
- rslearn/train/transforms/normalize.py +25 -5
- rslearn/train/transforms/transform.py +73 -71
- rslearn/utils/__init__.py +0 -3
- rslearn/utils/feature.py +1 -1
- rslearn/utils/geometry.py +60 -5
- rslearn/utils/get_utm_ups_crs.py +2 -3
- rslearn/utils/jsonargparse.py +33 -0
- rslearn/utils/raster_format.py +207 -100
- rslearn/utils/vector_format.py +142 -80
- {rslearn-0.0.2.dist-info → rslearn-0.0.3.dist-info}/METADATA +366 -289
- rslearn-0.0.3.dist-info/RECORD +123 -0
- {rslearn-0.0.2.dist-info → rslearn-0.0.3.dist-info}/WHEEL +1 -1
- rslearn/tile_stores/file.py +0 -245
- rslearn/utils/utils.py +0 -30
- rslearn-0.0.2.dist-info/RECORD +0 -94
- {rslearn-0.0.2.dist-info → rslearn-0.0.3.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.2.dist-info → rslearn-0.0.3.dist-info/licenses}/LICENSE +0 -0
- {rslearn-0.0.2.dist-info → rslearn-0.0.3.dist-info}/top_level.txt +0 -0
rslearn/config/__init__.py
CHANGED
|
@@ -2,6 +2,7 @@
|
|
|
2
2
|
|
|
3
3
|
from .dataset import (
|
|
4
4
|
BandSetConfig,
|
|
5
|
+
CompositingMethod,
|
|
5
6
|
DataSourceConfig,
|
|
6
7
|
DType,
|
|
7
8
|
LayerConfig,
|
|
@@ -10,7 +11,6 @@ from .dataset import (
|
|
|
10
11
|
RasterFormatConfig,
|
|
11
12
|
RasterLayerConfig,
|
|
12
13
|
SpaceMode,
|
|
13
|
-
TileStoreConfig,
|
|
14
14
|
TimeMode,
|
|
15
15
|
VectorFormatConfig,
|
|
16
16
|
VectorLayerConfig,
|
|
@@ -19,6 +19,7 @@ from .dataset import (
|
|
|
19
19
|
|
|
20
20
|
__all__ = [
|
|
21
21
|
"BandSetConfig",
|
|
22
|
+
"CompositingMethod",
|
|
22
23
|
"DataSourceConfig",
|
|
23
24
|
"DType",
|
|
24
25
|
"LayerConfig",
|
|
@@ -27,7 +28,6 @@ __all__ = [
|
|
|
27
28
|
"RasterFormatConfig",
|
|
28
29
|
"RasterLayerConfig",
|
|
29
30
|
"SpaceMode",
|
|
30
|
-
"TileStoreConfig",
|
|
31
31
|
"TimeMode",
|
|
32
32
|
"VectorFormatConfig",
|
|
33
33
|
"VectorLayerConfig",
|
rslearn/config/dataset.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
"""Classes for storing configuration of a dataset."""
|
|
2
2
|
|
|
3
|
+
import json
|
|
3
4
|
from datetime import timedelta
|
|
4
5
|
from enum import Enum
|
|
5
6
|
from typing import Any
|
|
@@ -19,7 +20,11 @@ class DType(Enum):
|
|
|
19
20
|
UINT8 = "uint8"
|
|
20
21
|
UINT16 = "uint16"
|
|
21
22
|
UINT32 = "uint32"
|
|
23
|
+
UINT64 = "uint64"
|
|
24
|
+
INT8 = "int8"
|
|
25
|
+
INT16 = "int16"
|
|
22
26
|
INT32 = "int32"
|
|
27
|
+
INT64 = "int64"
|
|
23
28
|
FLOAT32 = "float32"
|
|
24
29
|
|
|
25
30
|
def get_numpy_dtype(self) -> npt.DTypeLike:
|
|
@@ -30,8 +35,16 @@ class DType(Enum):
|
|
|
30
35
|
return np.uint16
|
|
31
36
|
elif self == DType.UINT32:
|
|
32
37
|
return np.uint32
|
|
38
|
+
elif self == DType.UINT64:
|
|
39
|
+
return np.uint64
|
|
40
|
+
elif self == DType.INT8:
|
|
41
|
+
return np.int8
|
|
42
|
+
elif self == DType.INT16:
|
|
43
|
+
return np.int16
|
|
33
44
|
elif self == DType.INT32:
|
|
34
45
|
return np.int32
|
|
46
|
+
elif self == DType.INT64:
|
|
47
|
+
return np.int64
|
|
35
48
|
elif self == DType.FLOAT32:
|
|
36
49
|
return np.float32
|
|
37
50
|
raise ValueError(f"unable to handle numpy dtype {self}")
|
|
@@ -116,6 +129,8 @@ class BandSetConfig:
|
|
|
116
129
|
format: dict[str, Any] | None = None,
|
|
117
130
|
zoom_offset: int = 0,
|
|
118
131
|
remap: dict[str, Any] | None = None,
|
|
132
|
+
class_names: list[list[str]] | None = None,
|
|
133
|
+
nodata_vals: list[float] | None = None,
|
|
119
134
|
) -> None:
|
|
120
135
|
"""Creates a new BandSetConfig instance.
|
|
121
136
|
|
|
@@ -124,15 +139,34 @@ class BandSetConfig:
|
|
|
124
139
|
dtype: the pixel value type to store tiles in
|
|
125
140
|
bands: list of band names in this BandSetConfig
|
|
126
141
|
format: the format to store tiles in, defaults to geotiff
|
|
127
|
-
zoom_offset:
|
|
128
|
-
|
|
142
|
+
zoom_offset: store images at a resolution higher or lower than the window
|
|
143
|
+
resolution. This enables keeping source data at its native resolution,
|
|
144
|
+
either to save storage space (for lower resolution data) or to retain
|
|
145
|
+
details (for higher resolution data). If positive, store data at the
|
|
146
|
+
window resolution divided by 2^(zoom_offset) (higher resolution). If
|
|
147
|
+
negative, store data at the window resolution multiplied by
|
|
148
|
+
2^(-zoom_offset) (lower resolution).
|
|
129
149
|
remap: config dict for Remapper to remap pixel values
|
|
150
|
+
class_names: optional list of names for the different possible values of
|
|
151
|
+
each band. The length of this list must equal the number of bands. For
|
|
152
|
+
example, [["forest", "desert"]] means that it is a single-band raster
|
|
153
|
+
where values can be 0 (forest) or 1 (desert).
|
|
154
|
+
nodata_vals: the nodata values for this band set. This is used during
|
|
155
|
+
materialization when creating mosaics, to determine which parts of the
|
|
156
|
+
source images should be copied.
|
|
130
157
|
"""
|
|
158
|
+
if class_names is not None and len(bands) != len(class_names):
|
|
159
|
+
raise ValueError(
|
|
160
|
+
f"the number of class lists ({len(class_names)}) does not match the number of bands ({len(bands)})"
|
|
161
|
+
)
|
|
162
|
+
|
|
131
163
|
self.config_dict = config_dict
|
|
132
164
|
self.bands = bands
|
|
133
165
|
self.dtype = dtype
|
|
134
166
|
self.zoom_offset = zoom_offset
|
|
135
167
|
self.remap = remap
|
|
168
|
+
self.class_names = class_names
|
|
169
|
+
self.nodata_vals = nodata_vals
|
|
136
170
|
|
|
137
171
|
if format is None:
|
|
138
172
|
self.format = {"name": "geotiff"}
|
|
@@ -140,14 +174,8 @@ class BandSetConfig:
|
|
|
140
174
|
self.format = format
|
|
141
175
|
|
|
142
176
|
def serialize(self) -> dict[str, Any]:
|
|
143
|
-
"""Serialize this BandSetConfig to a config dict
|
|
144
|
-
return
|
|
145
|
-
"bands": self.bands,
|
|
146
|
-
"format": self.format,
|
|
147
|
-
"dtype": self.dtype,
|
|
148
|
-
"zoom_offset": self.zoom_offset,
|
|
149
|
-
"remap": self.remap,
|
|
150
|
-
}
|
|
177
|
+
"""Serialize this BandSetConfig to a config dict."""
|
|
178
|
+
return self.config_dict
|
|
151
179
|
|
|
152
180
|
@staticmethod
|
|
153
181
|
def from_config(config: dict[str, Any]) -> "BandSetConfig":
|
|
@@ -161,14 +189,14 @@ class BandSetConfig:
|
|
|
161
189
|
dtype=DType(config["dtype"]),
|
|
162
190
|
bands=config["bands"],
|
|
163
191
|
)
|
|
164
|
-
for k in ["format", "zoom_offset", "remap"]:
|
|
192
|
+
for k in ["format", "zoom_offset", "remap", "class_names", "nodata_vals"]:
|
|
165
193
|
if k in config:
|
|
166
194
|
kwargs[k] = config[k]
|
|
167
195
|
return BandSetConfig(**kwargs) # type: ignore
|
|
168
196
|
|
|
169
197
|
def get_final_projection_and_bounds(
|
|
170
|
-
self, projection: Projection, bounds: PixelBounds
|
|
171
|
-
) -> tuple[Projection, PixelBounds
|
|
198
|
+
self, projection: Projection, bounds: PixelBounds
|
|
199
|
+
) -> tuple[Projection, PixelBounds]:
|
|
172
200
|
"""Gets the final projection/bounds based on band set config.
|
|
173
201
|
|
|
174
202
|
The band set config may apply a non-zero zoom offset that modifies the window's
|
|
@@ -189,15 +217,14 @@ class BandSetConfig:
|
|
|
189
217
|
projection.x_resolution / (2**self.zoom_offset),
|
|
190
218
|
projection.y_resolution / (2**self.zoom_offset),
|
|
191
219
|
)
|
|
192
|
-
if
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
)
|
|
220
|
+
if self.zoom_offset > 0:
|
|
221
|
+
zoom_factor = 2**self.zoom_offset
|
|
222
|
+
bounds = tuple(x * zoom_factor for x in bounds) # type: ignore
|
|
223
|
+
else:
|
|
224
|
+
bounds = tuple(
|
|
225
|
+
x // (2 ** (-self.zoom_offset))
|
|
226
|
+
for x in bounds # type: ignore
|
|
227
|
+
)
|
|
201
228
|
return projection, bounds
|
|
202
229
|
|
|
203
230
|
|
|
@@ -217,6 +244,22 @@ class SpaceMode(Enum):
|
|
|
217
244
|
dataset.
|
|
218
245
|
"""
|
|
219
246
|
|
|
247
|
+
PER_PERIOD_MOSAIC = 4
|
|
248
|
+
"""Create one mosaic per sub-period of the time range.
|
|
249
|
+
|
|
250
|
+
The duration of the sub-periods is controlled by another option in QueryConfig.
|
|
251
|
+
"""
|
|
252
|
+
|
|
253
|
+
COMPOSITE = 5
|
|
254
|
+
"""Creates one composite covering the entire window.
|
|
255
|
+
|
|
256
|
+
During querying all items intersecting the window are placed in one group.
|
|
257
|
+
The compositing_method in the rasterlayer config specifies how these items are reduced
|
|
258
|
+
to a single item (e.g MEAN/MEDIAN/FIRST_VALID) during materialization.
|
|
259
|
+
"""
|
|
260
|
+
|
|
261
|
+
# TODO add PER_PERIOD_COMPOSITE
|
|
262
|
+
|
|
220
263
|
|
|
221
264
|
class TimeMode(Enum):
|
|
222
265
|
"""Temporal matching mode when looking up items corresponding to a window."""
|
|
@@ -228,10 +271,10 @@ class TimeMode(Enum):
|
|
|
228
271
|
"""Select items closest to the window time range, up to max_matches."""
|
|
229
272
|
|
|
230
273
|
BEFORE = 3
|
|
231
|
-
"""Select items before the
|
|
274
|
+
"""Select items before the end of the window time range, up to max_matches."""
|
|
232
275
|
|
|
233
276
|
AFTER = 4
|
|
234
|
-
"""Select items after the
|
|
277
|
+
"""Select items after the start of the window time range, up to max_matches."""
|
|
235
278
|
|
|
236
279
|
|
|
237
280
|
class QueryConfig:
|
|
@@ -241,7 +284,9 @@ class QueryConfig:
|
|
|
241
284
|
self,
|
|
242
285
|
space_mode: SpaceMode = SpaceMode.MOSAIC,
|
|
243
286
|
time_mode: TimeMode = TimeMode.WITHIN,
|
|
287
|
+
min_matches: int = 0,
|
|
244
288
|
max_matches: int = 1,
|
|
289
|
+
period_duration: timedelta = timedelta(days=30),
|
|
245
290
|
):
|
|
246
291
|
"""Creates a new query configuration.
|
|
247
292
|
|
|
@@ -251,19 +296,29 @@ class QueryConfig:
|
|
|
251
296
|
Args:
|
|
252
297
|
space_mode: specifies how items should be matched with windows spatially
|
|
253
298
|
time_mode: specifies how items should be matched with windows temporally
|
|
299
|
+
min_matches: the minimum number of item groups. If there are fewer than
|
|
300
|
+
this many matches, then no matches will be returned. This can be used
|
|
301
|
+
to prevent unnecessary data ingestion if the user plans to discard
|
|
302
|
+
windows that do not have a sufficient amount of data.
|
|
254
303
|
max_matches: the maximum number of items (or groups of items, if space_mode
|
|
255
304
|
is MOSAIC) to match
|
|
305
|
+
period_duration: the duration of the periods, if the space mode is
|
|
306
|
+
PER_PERIOD_MOSAIC.
|
|
256
307
|
"""
|
|
257
308
|
self.space_mode = space_mode
|
|
258
309
|
self.time_mode = time_mode
|
|
310
|
+
self.min_matches = min_matches
|
|
259
311
|
self.max_matches = max_matches
|
|
312
|
+
self.period_duration = period_duration
|
|
260
313
|
|
|
261
314
|
def serialize(self) -> dict[str, Any]:
|
|
262
|
-
"""Serialize this QueryConfig to a config dict
|
|
315
|
+
"""Serialize this QueryConfig to a config dict."""
|
|
263
316
|
return {
|
|
264
317
|
"space_mode": str(self.space_mode),
|
|
265
318
|
"time_mode": str(self.time_mode),
|
|
319
|
+
"min_matches": self.min_matches,
|
|
266
320
|
"max_matches": self.max_matches,
|
|
321
|
+
"period_duration": f"{self.period_duration.total_seconds()}s",
|
|
267
322
|
}
|
|
268
323
|
|
|
269
324
|
@staticmethod
|
|
@@ -273,11 +328,20 @@ class QueryConfig:
|
|
|
273
328
|
Args:
|
|
274
329
|
config: the config dict for this QueryConfig
|
|
275
330
|
"""
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
331
|
+
kwargs: dict[str, Any] = dict()
|
|
332
|
+
if "space_mode" in config:
|
|
333
|
+
kwargs["space_mode"] = SpaceMode[config["space_mode"]]
|
|
334
|
+
if "time_mode" in config:
|
|
335
|
+
kwargs["time_mode"] = TimeMode[config["time_mode"]]
|
|
336
|
+
if "period_duration" in config:
|
|
337
|
+
kwargs["period_duration"] = timedelta(
|
|
338
|
+
seconds=pytimeparse.parse(config["period_duration"])
|
|
339
|
+
)
|
|
340
|
+
for k in ["min_matches", "max_matches"]:
|
|
341
|
+
if k not in config:
|
|
342
|
+
continue
|
|
343
|
+
kwargs[k] = config[k]
|
|
344
|
+
return QueryConfig(**kwargs)
|
|
281
345
|
|
|
282
346
|
|
|
283
347
|
class DataSourceConfig:
|
|
@@ -313,16 +377,8 @@ class DataSourceConfig:
|
|
|
313
377
|
self.ingest = ingest
|
|
314
378
|
|
|
315
379
|
def serialize(self) -> dict[str, Any]:
|
|
316
|
-
"""Serialize this DataSourceConfig to a config dict
|
|
317
|
-
|
|
318
|
-
config_dict["name"] = self.name
|
|
319
|
-
config_dict["query_config"] = self.query_config.serialize()
|
|
320
|
-
config_dict["ingest"] = self.ingest
|
|
321
|
-
if self.time_offset:
|
|
322
|
-
config_dict["time_offset"] = str(self.time_offset)
|
|
323
|
-
if self.duration:
|
|
324
|
-
config_dict["duration"] = str(self.duration)
|
|
325
|
-
return config_dict
|
|
380
|
+
"""Serialize this DataSourceConfig to a config dict."""
|
|
381
|
+
return self.config_dict
|
|
326
382
|
|
|
327
383
|
@staticmethod
|
|
328
384
|
def from_config(config: dict[str, Any]) -> "DataSourceConfig":
|
|
@@ -377,13 +433,40 @@ class LayerConfig:
|
|
|
377
433
|
self.alias = alias
|
|
378
434
|
|
|
379
435
|
def serialize(self) -> dict[str, Any]:
|
|
380
|
-
"""Serialize this LayerConfig to a config dict
|
|
436
|
+
"""Serialize this LayerConfig to a config dict."""
|
|
381
437
|
return {
|
|
382
438
|
"layer_type": str(self.layer_type),
|
|
383
|
-
"data_source": self.data_source,
|
|
439
|
+
"data_source": self.data_source.serialize() if self.data_source else None,
|
|
384
440
|
"alias": self.alias,
|
|
385
441
|
}
|
|
386
442
|
|
|
443
|
+
def __hash__(self) -> int:
|
|
444
|
+
"""Return a hash of this LayerConfig."""
|
|
445
|
+
return hash(json.dumps(self.serialize(), sort_keys=True))
|
|
446
|
+
|
|
447
|
+
def __eq__(self, other: Any) -> bool:
|
|
448
|
+
"""Returns whether other is the same as this LayerConfig.
|
|
449
|
+
|
|
450
|
+
Args:
|
|
451
|
+
other: the other object to compare.
|
|
452
|
+
"""
|
|
453
|
+
if not isinstance(other, LayerConfig):
|
|
454
|
+
return False
|
|
455
|
+
return self.serialize() == other.serialize()
|
|
456
|
+
|
|
457
|
+
|
|
458
|
+
class CompositingMethod(Enum):
|
|
459
|
+
"""Method how to select pixels for the composite from corresponding items of a window."""
|
|
460
|
+
|
|
461
|
+
FIRST_VALID = 1
|
|
462
|
+
"""Select first valid pixel in order of corresponding items (might be sorted)"""
|
|
463
|
+
|
|
464
|
+
MEAN = 2
|
|
465
|
+
"""Select per-pixel mean value of corresponding items of a window"""
|
|
466
|
+
|
|
467
|
+
MEDIAN = 3
|
|
468
|
+
"""Select per-pixel median value of corresponding items of a window"""
|
|
469
|
+
|
|
387
470
|
|
|
388
471
|
class RasterLayerConfig(LayerConfig):
|
|
389
472
|
"""Configuration of a raster layer."""
|
|
@@ -395,6 +478,7 @@ class RasterLayerConfig(LayerConfig):
|
|
|
395
478
|
data_source: DataSourceConfig | None = None,
|
|
396
479
|
resampling_method: Resampling = Resampling.bilinear,
|
|
397
480
|
alias: str | None = None,
|
|
481
|
+
compositing_method: CompositingMethod = CompositingMethod.FIRST_VALID,
|
|
398
482
|
):
|
|
399
483
|
"""Initialize a new RasterLayerConfig.
|
|
400
484
|
|
|
@@ -404,10 +488,12 @@ class RasterLayerConfig(LayerConfig):
|
|
|
404
488
|
data_source: optional DataSourceConfig if this layer is retrievable
|
|
405
489
|
resampling_method: how to resample rasters (if needed), default bilinear resampling
|
|
406
490
|
alias: alias for this layer to use in the tile store
|
|
491
|
+
compositing_method: how to compute pixel values in the composite of each windows items
|
|
407
492
|
"""
|
|
408
493
|
super().__init__(layer_type, data_source, alias)
|
|
409
494
|
self.band_sets = band_sets
|
|
410
495
|
self.resampling_method = resampling_method
|
|
496
|
+
self.compositing_method = compositing_method
|
|
411
497
|
|
|
412
498
|
@staticmethod
|
|
413
499
|
def from_config(config: dict[str, Any]) -> "RasterLayerConfig":
|
|
@@ -428,6 +514,10 @@ class RasterLayerConfig(LayerConfig):
|
|
|
428
514
|
]
|
|
429
515
|
if "alias" in config:
|
|
430
516
|
kwargs["alias"] = config["alias"]
|
|
517
|
+
if "compositing_method" in config:
|
|
518
|
+
kwargs["compositing_method"] = CompositingMethod[
|
|
519
|
+
config["compositing_method"]
|
|
520
|
+
]
|
|
431
521
|
return RasterLayerConfig(**kwargs) # type: ignore
|
|
432
522
|
|
|
433
523
|
|
|
@@ -438,22 +528,28 @@ class VectorLayerConfig(LayerConfig):
|
|
|
438
528
|
self,
|
|
439
529
|
layer_type: LayerType,
|
|
440
530
|
data_source: DataSourceConfig | None = None,
|
|
441
|
-
zoom_offset: int = 0,
|
|
442
531
|
format: VectorFormatConfig = VectorFormatConfig("geojson"),
|
|
443
532
|
alias: str | None = None,
|
|
533
|
+
class_property_name: str | None = None,
|
|
534
|
+
class_names: list[str] | None = None,
|
|
444
535
|
):
|
|
445
536
|
"""Initialize a new VectorLayerConfig.
|
|
446
537
|
|
|
447
538
|
Args:
|
|
448
539
|
layer_type: the LayerType (must be vector)
|
|
449
540
|
data_source: optional DataSourceConfig if this layer is retrievable
|
|
450
|
-
zoom_offset: zoom offset at which to store the vector data
|
|
451
541
|
format: the VectorFormatConfig, default storing as GeoJSON
|
|
452
542
|
alias: alias for this layer to use in the tile store
|
|
543
|
+
class_property_name: optional metadata field indicating that the GeoJSON
|
|
544
|
+
features contain a property that corresponds to a class label, and this
|
|
545
|
+
is the name of that property.
|
|
546
|
+
class_names: the list of classes that the class_property_name property
|
|
547
|
+
could be set to.
|
|
453
548
|
"""
|
|
454
549
|
super().__init__(layer_type, data_source, alias)
|
|
455
|
-
self.zoom_offset = zoom_offset
|
|
456
550
|
self.format = format
|
|
551
|
+
self.class_property_name = class_property_name
|
|
552
|
+
self.class_names = class_names
|
|
457
553
|
|
|
458
554
|
@staticmethod
|
|
459
555
|
def from_config(config: dict[str, Any]) -> "VectorLayerConfig":
|
|
@@ -465,42 +561,26 @@ class VectorLayerConfig(LayerConfig):
|
|
|
465
561
|
kwargs: dict[str, Any] = {"layer_type": LayerType(config["type"])}
|
|
466
562
|
if "data_source" in config:
|
|
467
563
|
kwargs["data_source"] = DataSourceConfig.from_config(config["data_source"])
|
|
468
|
-
if "zoom_offset" in config:
|
|
469
|
-
kwargs["zoom_offset"] = config["zoom_offset"]
|
|
470
564
|
if "format" in config:
|
|
471
565
|
kwargs["format"] = VectorFormatConfig.from_config(config["format"])
|
|
472
|
-
if "alias" in config:
|
|
473
|
-
kwargs["alias"] = config["alias"]
|
|
474
|
-
return VectorLayerConfig(**kwargs) # type: ignore
|
|
475
566
|
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
567
|
+
simple_optionals = [
|
|
568
|
+
"alias",
|
|
569
|
+
"class_property_name",
|
|
570
|
+
"class_names",
|
|
571
|
+
]
|
|
572
|
+
for k in simple_optionals:
|
|
573
|
+
if k in config:
|
|
574
|
+
kwargs[k] = config[k]
|
|
480
575
|
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
|
|
576
|
+
# The "zoom_offset" option was removed.
|
|
577
|
+
# We should change how we create configuration so we can error on all
|
|
578
|
+
# non-existing config options, but for now we make sure to raise error if
|
|
579
|
+
# zoom_offset is set since it is no longer supported.
|
|
580
|
+
if "zoom_offset" in config:
|
|
581
|
+
raise ValueError("unsupported zoom_offset option in vector layer config")
|
|
484
582
|
|
|
485
|
-
|
|
486
|
-
tuple of updated projection and bounds with zoom offset applied
|
|
487
|
-
"""
|
|
488
|
-
if self.zoom_offset == 0:
|
|
489
|
-
return projection, bounds
|
|
490
|
-
projection = Projection(
|
|
491
|
-
projection.crs,
|
|
492
|
-
projection.x_resolution / (2**self.zoom_offset),
|
|
493
|
-
projection.y_resolution / (2**self.zoom_offset),
|
|
494
|
-
)
|
|
495
|
-
if bounds:
|
|
496
|
-
if self.zoom_offset > 0:
|
|
497
|
-
bounds = tuple(x * (2**self.zoom_offset) for x in bounds) # type: ignore
|
|
498
|
-
else:
|
|
499
|
-
bounds = tuple(
|
|
500
|
-
x // (2 ** (-self.zoom_offset))
|
|
501
|
-
for x in bounds # type: ignore
|
|
502
|
-
)
|
|
503
|
-
return projection, bounds
|
|
583
|
+
return VectorLayerConfig(**kwargs) # type: ignore
|
|
504
584
|
|
|
505
585
|
|
|
506
586
|
def load_layer_config(config: dict[str, Any]) -> LayerConfig:
|
|
@@ -511,26 +591,3 @@ def load_layer_config(config: dict[str, Any]) -> LayerConfig:
|
|
|
511
591
|
elif layer_type == LayerType.VECTOR:
|
|
512
592
|
return VectorLayerConfig.from_config(config)
|
|
513
593
|
raise ValueError(f"Unknown layer type {layer_type}")
|
|
514
|
-
|
|
515
|
-
|
|
516
|
-
class TileStoreConfig:
|
|
517
|
-
"""A configuration specifying a TileStore."""
|
|
518
|
-
|
|
519
|
-
def __init__(self, name: str, config_dict: dict[str, Any]) -> None:
|
|
520
|
-
"""Create a new TileStoreConfig.
|
|
521
|
-
|
|
522
|
-
Args:
|
|
523
|
-
name: the tile store implementation name to use
|
|
524
|
-
config_dict: configuration options
|
|
525
|
-
"""
|
|
526
|
-
self.name = name
|
|
527
|
-
self.config_dict = config_dict
|
|
528
|
-
|
|
529
|
-
@staticmethod
|
|
530
|
-
def from_config(config: dict[str, Any]) -> "TileStoreConfig":
|
|
531
|
-
"""Create a TileStoreConfig from config dict.
|
|
532
|
-
|
|
533
|
-
Args:
|
|
534
|
-
config: the config dict for this TileStoreConfig
|
|
535
|
-
"""
|
|
536
|
-
return TileStoreConfig(name=config["name"], config_dict=config)
|
rslearn/const.py
CHANGED
|
@@ -1,23 +1,17 @@
|
|
|
1
1
|
"""Constants."""
|
|
2
2
|
|
|
3
|
-
from
|
|
4
|
-
|
|
5
|
-
from rslearn.utils import PixelBounds, Projection
|
|
6
|
-
|
|
7
|
-
WGS84_EPSG = 4326
|
|
8
|
-
"""The EPSG code for WGS-84."""
|
|
9
|
-
|
|
10
|
-
WGS84_PROJECTION = Projection(CRS.from_epsg(WGS84_EPSG), 1, 1)
|
|
11
|
-
"""The Projection for WGS-84 assuming 1 degree per pixel.
|
|
12
|
-
|
|
13
|
-
This can be used to create STGeometry with shapes in longitude/latitude coordinates.
|
|
14
|
-
"""
|
|
15
|
-
|
|
16
|
-
WGS84_BOUNDS: PixelBounds = (-180, -90, 180, 90)
|
|
17
|
-
"""The bounds of the WGS-84 projection."""
|
|
3
|
+
from rslearn.utils.geometry import WGS84_BOUNDS, WGS84_EPSG, WGS84_PROJECTION
|
|
18
4
|
|
|
19
5
|
TILE_SIZE = 512
|
|
20
6
|
"""Default tile size. TODO: remove this or move it elsewhere."""
|
|
21
7
|
|
|
22
8
|
SHAPEFILE_AUX_EXTENSIONS = [".cpg", ".dbf", ".prj", ".sbn", ".sbx", ".shx", ".txt"]
|
|
23
9
|
"""Extensions of potential auxiliary files to .shp file."""
|
|
10
|
+
|
|
11
|
+
__all__ = (
|
|
12
|
+
"WGS84_PROJECTION",
|
|
13
|
+
"WGS84_EPSG",
|
|
14
|
+
"WGS84_BOUNDS",
|
|
15
|
+
"TILE_SIZE",
|
|
16
|
+
"SHAPEFILE_AUX_EXTENSIONS",
|
|
17
|
+
)
|