rslearn 0.0.14__py3-none-any.whl → 0.0.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. rslearn/config/__init__.py +2 -10
  2. rslearn/config/dataset.py +414 -420
  3. rslearn/data_sources/__init__.py +8 -31
  4. rslearn/data_sources/aws_landsat.py +13 -24
  5. rslearn/data_sources/aws_open_data.py +21 -46
  6. rslearn/data_sources/aws_sentinel1.py +3 -14
  7. rslearn/data_sources/climate_data_store.py +21 -40
  8. rslearn/data_sources/copernicus.py +30 -91
  9. rslearn/data_sources/data_source.py +26 -0
  10. rslearn/data_sources/earthdaily.py +13 -38
  11. rslearn/data_sources/earthdata_srtm.py +14 -32
  12. rslearn/data_sources/eurocrops.py +5 -9
  13. rslearn/data_sources/gcp_public_data.py +46 -43
  14. rslearn/data_sources/google_earth_engine.py +31 -44
  15. rslearn/data_sources/local_files.py +91 -100
  16. rslearn/data_sources/openstreetmap.py +21 -51
  17. rslearn/data_sources/planet.py +12 -30
  18. rslearn/data_sources/planet_basemap.py +4 -25
  19. rslearn/data_sources/planetary_computer.py +58 -141
  20. rslearn/data_sources/usda_cdl.py +15 -26
  21. rslearn/data_sources/usgs_landsat.py +4 -29
  22. rslearn/data_sources/utils.py +9 -0
  23. rslearn/data_sources/worldcereal.py +47 -54
  24. rslearn/data_sources/worldcover.py +16 -14
  25. rslearn/data_sources/worldpop.py +15 -18
  26. rslearn/data_sources/xyz_tiles.py +11 -30
  27. rslearn/dataset/dataset.py +6 -6
  28. rslearn/dataset/manage.py +28 -26
  29. rslearn/dataset/materialize.py +9 -45
  30. rslearn/lightning_cli.py +370 -1
  31. rslearn/main.py +3 -3
  32. rslearn/models/clay/clay.py +14 -1
  33. rslearn/models/concatenate_features.py +93 -0
  34. rslearn/models/croma.py +26 -3
  35. rslearn/models/satlaspretrain.py +18 -4
  36. rslearn/models/terramind.py +19 -0
  37. rslearn/tile_stores/__init__.py +0 -11
  38. rslearn/train/dataset.py +4 -12
  39. rslearn/train/prediction_writer.py +16 -32
  40. rslearn/train/tasks/classification.py +2 -1
  41. rslearn/utils/fsspec.py +20 -0
  42. rslearn/utils/jsonargparse.py +79 -0
  43. rslearn/utils/raster_format.py +1 -41
  44. rslearn/utils/vector_format.py +1 -38
  45. {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/METADATA +1 -1
  46. {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/RECORD +51 -52
  47. rslearn/data_sources/geotiff.py +0 -1
  48. rslearn/data_sources/raster_source.py +0 -23
  49. {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/WHEEL +0 -0
  50. {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/entry_points.txt +0 -0
  51. {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/licenses/LICENSE +0 -0
  52. {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/licenses/NOTICE +0 -0
  53. {rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/top_level.txt +0 -0
rslearn/models/satlaspretrain.py CHANGED
@@ -4,15 +4,14 @@ from typing import Any
4
4
 
5
5
  import satlaspretrain_models
6
6
  import torch
7
+ import torch.nn.functional as F
7
8
 
8
9
 
9
10
  class SatlasPretrain(torch.nn.Module):
10
11
  """SatlasPretrain backbones."""
11
12
 
12
13
  def __init__(
13
- self,
14
- model_identifier: str,
15
- fpn: bool = False,
14
+ self, model_identifier: str, fpn: bool = False, resize_to_pretrain: bool = False
16
15
  ) -> None:
17
16
  """Instantiate a new SatlasPretrain instance.
18
17
 
@@ -21,6 +20,8 @@ class SatlasPretrain(torch.nn.Module):
21
20
  https://github.com/allenai/satlaspretrain_models
22
21
  fpn: whether to include the feature pyramid network, otherwise only the
23
22
  Swin-v2-Transformer is used.
23
+ resize_to_pretrain: whether to resize inputs to the pretraining input
24
+ size (512 x 512)
24
25
  """
25
26
  super().__init__()
26
27
  weights_manager = satlaspretrain_models.Weights()
@@ -49,6 +50,19 @@ class SatlasPretrain(torch.nn.Module):
49
50
  [16, 1024],
50
51
  [32, 2048],
51
52
  ]
53
+ self.resize_to_pretrain = resize_to_pretrain
54
+
55
+ def maybe_resize(self, data: torch.Tensor) -> list[torch.Tensor]:
56
+ """Resize to pretraining sizes if resize_to_pretrain == True."""
57
+ if self.resize_to_pretrain:
58
+ return F.interpolate(
59
+ data,
60
+ size=(512, 512),
61
+ mode="bilinear",
62
+ align_corners=False,
63
+ )
64
+ else:
65
+ return data
52
66
 
53
67
  def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
54
68
  """Compute feature maps from the SatlasPretrain backbone.
@@ -58,7 +72,7 @@ class SatlasPretrain(torch.nn.Module):
58
72
  process.
59
73
  """
60
74
  images = torch.stack([inp["image"] for inp in inputs], dim=0)
61
- return self.model(images)
75
+ return self.model(self.maybe_resize(images))
62
76
 
63
77
  def get_backbone_channels(self) -> list:
64
78
  """Returns the output channels of this model when used as a backbone.
rslearn/models/terramind.py CHANGED
@@ -4,6 +4,7 @@ from enum import Enum
4
4
  from typing import Any
5
5
 
6
6
  import torch
7
+ import torch.nn.functional as F
7
8
  from einops import rearrange
8
9
  from terratorch.registry import BACKBONE_REGISTRY
9
10
 
@@ -18,6 +19,8 @@ class TerramindSize(str, Enum):
18
19
  LARGE = "large"
19
20
 
20
21
 
22
+ # Pretraining image size for Terramind
23
+ IMAGE_SIZE = 224
21
24
  # Default patch size for Terramind
22
25
  PATCH_SIZE = 16
23
26
 
@@ -89,12 +92,14 @@ class Terramind(torch.nn.Module):
89
92
  self,
90
93
  model_size: TerramindSize,
91
94
  modalities: list[str] = ["S2L2A"],
95
+ do_resizing: bool = False,
92
96
  ) -> None:
93
97
  """Initialize the Terramind model.
94
98
 
95
99
  Args:
96
100
  model_size: The size of the Terramind model.
97
101
  modalities: The modalities to use.
102
+ do_resizing: Whether to resize the input images to the pretraining resolution.
98
103
  """
99
104
  super().__init__()
100
105
 
@@ -116,6 +121,7 @@ class Terramind(torch.nn.Module):
116
121
 
117
122
  self.model_size = model_size
118
123
  self.modalities = modalities
124
+ self.do_resizing = do_resizing
119
125
 
120
126
  def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
121
127
  """Forward pass for the Terramind model.
@@ -132,6 +138,19 @@ class Terramind(torch.nn.Module):
132
138
  if modality not in inputs[0]:
133
139
  continue
134
140
  cur = torch.stack([inp[modality] for inp in inputs], dim=0) # (B, C, H, W)
141
+ if self.do_resizing and (
142
+ cur.shape[2] != IMAGE_SIZE or cur.shape[3] != IMAGE_SIZE
143
+ ):
144
+ if cur.shape[2] == 1 and cur.shape[3] == 1:
145
+ new_height, new_width = PATCH_SIZE, PATCH_SIZE
146
+ else:
147
+ new_height, new_width = IMAGE_SIZE, IMAGE_SIZE
148
+ cur = F.interpolate(
149
+ cur,
150
+ size=(new_height, new_width),
151
+ mode="bilinear",
152
+ align_corners=False,
153
+ )
135
154
  model_inputs[modality] = cur
136
155
 
137
156
  # By default, the patch embeddings are averaged over all modalities to reduce output tokens
rslearn/tile_stores/__init__.py CHANGED
@@ -22,17 +22,6 @@ def load_tile_store(config: dict[str, Any], ds_path: UPath) -> TileStore:
22
22
  Returns:
23
23
  the TileStore
24
24
  """
25
- if config is None:
26
- tile_store = DefaultTileStore()
27
- tile_store.set_dataset_path(ds_path)
28
- return tile_store
29
-
30
- # Backwards compatability.
31
- if "name" in config and "root_dir" in config and config["name"] == "file":
32
- tile_store = DefaultTileStore(config["root_dir"])
33
- tile_store.set_dataset_path(ds_path)
34
- return tile_store
35
-
36
25
  init_jsonargparse()
37
26
  parser = jsonargparse.ArgumentParser()
38
27
  parser.add_argument("--tile_store", type=TileStore)
rslearn/train/dataset.py CHANGED
@@ -17,9 +17,7 @@ from rasterio.warp import Resampling
17
17
  import rslearn.train.transforms.transform
18
18
  from rslearn.config import (
19
19
  DType,
20
- RasterFormatConfig,
21
- RasterLayerConfig,
22
- VectorLayerConfig,
20
+ LayerConfig,
23
21
  )
24
22
  from rslearn.dataset.dataset import Dataset
25
23
  from rslearn.dataset.window import Window, get_layer_and_group_from_dir_name
@@ -28,8 +26,6 @@ from rslearn.train.tasks import Task
28
26
  from rslearn.utils.feature import Feature
29
27
  from rslearn.utils.geometry import PixelBounds
30
28
  from rslearn.utils.mp import star_imap_unordered
31
- from rslearn.utils.raster_format import load_raster_format
32
- from rslearn.utils.vector_format import load_vector_format
33
29
 
34
30
  from .transforms import Sequential
35
31
 
@@ -185,7 +181,7 @@ def read_raster_layer_for_data_input(
185
181
  bounds: PixelBounds,
186
182
  layer_name: str,
187
183
  group_idx: int,
188
- layer_config: RasterLayerConfig,
184
+ layer_config: LayerConfig,
189
185
  data_input: DataInput,
190
186
  ) -> torch.Tensor:
191
187
  """Read a raster layer for a DataInput.
@@ -246,9 +242,7 @@ def read_raster_layer_for_data_input(
246
242
  )
247
243
  if band_set.format is None:
248
244
  raise ValueError(f"No format specified for {layer_name}")
249
- raster_format = load_raster_format(
250
- RasterFormatConfig(band_set.format["name"], band_set.format)
251
- )
245
+ raster_format = band_set.instantiate_raster_format()
252
246
  raster_dir = window.get_raster_dir(
253
247
  layer_name, band_set.bands, group_idx=group_idx
254
248
  )
@@ -349,7 +343,6 @@ def read_data_input(
349
343
  images: list[torch.Tensor] = []
350
344
  for layer_name, group_idx in layers_to_read:
351
345
  layer_config = dataset.layers[layer_name]
352
- assert isinstance(layer_config, RasterLayerConfig)
353
346
  images.append(
354
347
  read_raster_layer_for_data_input(
355
348
  window, bounds, layer_name, group_idx, layer_config, data_input
@@ -363,8 +356,7 @@ def read_data_input(
363
356
  features: list[Feature] = []
364
357
  for layer_name, group_idx in layers_to_read:
365
358
  layer_config = dataset.layers[layer_name]
366
- assert isinstance(layer_config, VectorLayerConfig)
367
- vector_format = load_vector_format(layer_config.format)
359
+ vector_format = layer_config.instantiate_vector_format()
368
360
  layer_dir = window.get_layer_dir(layer_name, group_idx=group_idx)
369
361
  cur_features = vector_format.decode_vector(
370
362
  layer_dir, window.projection, window.bounds
rslearn/train/prediction_writer.py CHANGED
@@ -12,10 +12,8 @@ from lightning.pytorch.callbacks import BasePredictionWriter
12
12
  from upath import UPath
13
13
 
14
14
  from rslearn.config import (
15
+ LayerConfig,
15
16
  LayerType,
16
- RasterFormatConfig,
17
- RasterLayerConfig,
18
- VectorLayerConfig,
19
17
  )
20
18
  from rslearn.dataset import Dataset, Window
21
19
  from rslearn.log_utils import get_logger
@@ -25,9 +23,8 @@ from rslearn.utils.geometry import PixelBounds
25
23
  from rslearn.utils.raster_format import (
26
24
  RasterFormat,
27
25
  adjust_projection_and_bounds_for_array,
28
- load_raster_format,
29
26
  )
30
- from rslearn.utils.vector_format import VectorFormat, load_vector_format
27
+ from rslearn.utils.vector_format import VectorFormat
31
28
 
32
29
  from .lightning_module import RslearnLightningModule
33
30
  from .tasks.task import Task
@@ -150,7 +147,7 @@ class RslearnWriter(BasePredictionWriter):
150
147
  selector: list[str] | None = None,
151
148
  merger: PatchPredictionMerger | None = None,
152
149
  output_path: str | Path | None = None,
153
- layer_config: RasterLayerConfig | VectorLayerConfig | None = None,
150
+ layer_config: LayerConfig | None = None,
154
151
  ):
155
152
  """Create a new RslearnWriter.
156
153
 
@@ -178,43 +175,31 @@ class RslearnWriter(BasePredictionWriter):
178
175
  )
179
176
 
180
177
  # Handle dataset and layer config
181
- self.layer_config: RasterLayerConfig | VectorLayerConfig
178
+ self.layer_config: LayerConfig
182
179
  if layer_config:
183
180
  self.layer_config = layer_config
184
- self.dataset = None if self.output_path else Dataset(self.path)
185
181
  else:
186
- self.dataset = Dataset(self.path)
187
- if self.output_layer not in self.dataset.layers:
182
+ dataset = Dataset(self.path)
183
+ if self.output_layer not in dataset.layers:
188
184
  raise KeyError(
189
185
  f"Output layer '{self.output_layer}' not found in dataset layers."
190
186
  )
191
- raw_layer_config = self.dataset.layers[self.output_layer]
192
- # Type narrowing to ensure compatibility
193
- if isinstance(raw_layer_config, (RasterLayerConfig | VectorLayerConfig)):
194
- self.layer_config = raw_layer_config
195
- else:
196
- raise ValueError(
197
- f"Layer config must be RasterLayerConfig or VectorLayerConfig, got {type(raw_layer_config)}"
198
- )
187
+ self.layer_config = dataset.layers[self.output_layer]
199
188
 
200
189
  self.format: RasterFormat | VectorFormat
201
- if self.layer_config.layer_type == LayerType.RASTER:
202
- assert isinstance(self.layer_config, RasterLayerConfig)
190
+ if self.layer_config.type == LayerType.RASTER:
203
191
  band_cfg = self.layer_config.band_sets[0]
204
- self.format = load_raster_format(
205
- RasterFormatConfig(band_cfg.format["name"], band_cfg.format)
206
- )
207
- elif self.layer_config.layer_type == LayerType.VECTOR:
208
- assert isinstance(self.layer_config, VectorLayerConfig)
209
- self.format = load_vector_format(self.layer_config.format)
192
+ self.format = band_cfg.instantiate_raster_format()
193
+ elif self.layer_config.type == LayerType.VECTOR:
194
+ self.format = self.layer_config.instantiate_vector_format()
210
195
  else:
211
- raise ValueError(f"invalid layer type {self.layer_config.layer_type}")
196
+ raise ValueError(f"invalid layer type {self.layer_config.type}")
212
197
 
213
198
  if merger is not None:
214
199
  self.merger = merger
215
- elif self.layer_config.layer_type == LayerType.RASTER:
200
+ elif self.layer_config.type == LayerType.RASTER:
216
201
  self.merger = RasterMerger()
217
- elif self.layer_config.layer_type == LayerType.VECTOR:
202
+ elif self.layer_config.type == LayerType.VECTOR:
218
203
  self.merger = VectorMerger()
219
204
 
220
205
  # Map from window name to pending data to write.
@@ -337,8 +322,7 @@ class RslearnWriter(BasePredictionWriter):
337
322
  logger.debug(f"Merging and writing for window {window.name}")
338
323
  merged_output = self.merger.merge(window, pending_output)
339
324
 
340
- if self.layer_config.layer_type == LayerType.RASTER:
341
- assert isinstance(self.layer_config, RasterLayerConfig)
325
+ if self.layer_config.type == LayerType.RASTER:
342
326
  raster_dir = window.get_raster_dir(
343
327
  self.output_layer, self.layer_config.band_sets[0].bands
344
328
  )
@@ -351,7 +335,7 @@ class RslearnWriter(BasePredictionWriter):
351
335
  )
352
336
  self.format.encode_raster(raster_dir, projection, bounds, merged_output)
353
337
 
354
- elif self.layer_config.layer_type == LayerType.VECTOR:
338
+ elif self.layer_config.type == LayerType.VECTOR:
355
339
  layer_dir = window.get_layer_dir(self.output_layer)
356
340
  assert isinstance(self.format, VectorFormat)
357
341
  self.format.encode_vector(layer_dir, merged_output)
rslearn/train/tasks/classification.py CHANGED
@@ -26,7 +26,7 @@ class ClassificationTask(BasicTask):
26
26
  def __init__(
27
27
  self,
28
28
  property_name: str,
29
- classes: list, # TODO: Should this be a list of str or int or can it be both?
29
+ classes: list[str],
30
30
  filters: list[tuple[str, str]] = [],
31
31
  read_class_id: bool = False,
32
32
  allow_invalid: bool = False,
@@ -176,6 +176,7 @@ class ClassificationTask(BasicTask):
176
176
  # For multiclass classification or when using the default threshold
177
177
  class_idx = probs.argmax().item()
178
178
 
179
+ value: str | int
179
180
  if not self.read_class_id:
180
181
  value = self.classes[class_idx] # type: ignore
181
182
  else:
rslearn/utils/fsspec.py CHANGED
@@ -156,3 +156,23 @@ def open_rasterio_upath_writer(
156
156
  with path.open("wb") as f:
157
157
  with rasterio.open(f, "w", **kwargs) as raster:
158
158
  yield raster
159
+
160
+
161
+ def get_relative_suffix(base_dir: UPath, fname: UPath) -> str:
162
+ """Get the suffix of fname relative to base_dir.
163
+
164
+ Args:
165
+ base_dir: the base directory.
166
+ fname: a filename within the base directory.
167
+
168
+ Returns:
169
+ the suffix on base_dir that would yield the given filename.
170
+ """
171
+ if not fname.path.startswith(base_dir.path):
172
+ raise ValueError(
173
+ f"filename {fname.path} must start with base directory {base_dir.path}"
174
+ )
175
+ suffix = fname.path[len(base_dir.path) :]
176
+ if suffix.startswith("/"):
177
+ suffix = suffix[1:]
178
+ return suffix
rslearn/utils/jsonargparse.py CHANGED
@@ -1,7 +1,18 @@
1
1
  """Custom serialization for jsonargparse."""
2
2
 
3
+ from datetime import datetime
4
+ from typing import TYPE_CHECKING, Any
5
+
3
6
  import jsonargparse
4
7
  from rasterio.crs import CRS
8
+ from upath import UPath
9
+
10
+ from rslearn.config.dataset import LayerConfig
11
+
12
+ if TYPE_CHECKING:
13
+ from rslearn.data_sources.data_source import DataSourceContext
14
+
15
+ INITIALIZED = False
5
16
 
6
17
 
7
18
  def crs_serializer(v: CRS) -> str:
@@ -28,6 +39,74 @@ def crs_deserializer(v: str) -> CRS:
28
39
  return CRS.from_string(v)
29
40
 
30
41
 
42
+ def datetime_serializer(v: datetime) -> str:
43
+ """Serialize datetime for jsonargparse.
44
+
45
+ Args:
46
+ v: the datetime object.
47
+
48
+ Returns:
49
+ the datetime encoded to string
50
+ """
51
+ return v.isoformat()
52
+
53
+
54
+ def datetime_deserializer(v: str) -> datetime:
55
+ """Deserialize datetime for jsonargparse.
56
+
57
+ Args:
58
+ v: the encoded datetime.
59
+
60
+ Returns:
61
+ the decoded datetime object
62
+ """
63
+ return datetime.fromisoformat(v)
64
+
65
+
66
+ def data_source_context_serializer(v: "DataSourceContext") -> dict[str, Any]:
67
+ """Serialize DataSourceContext for jsonargparse."""
68
+ x = {
69
+ "ds_path": (str(v.ds_path) if v.ds_path is not None else None),
70
+ "layer_config": (
71
+ v.layer_config.model_dump(mode="json")
72
+ if v.layer_config is not None
73
+ else None
74
+ ),
75
+ }
76
+ return x
77
+
78
+
79
+ def data_source_context_deserializer(v: dict[str, Any]) -> "DataSourceContext":
80
+ """Deserialize DataSourceContext for jsonargparse."""
81
+ # We lazily import these to avoid cyclic dependency.
82
+ from rslearn.data_sources.data_source import DataSourceContext
83
+
84
+ return DataSourceContext(
85
+ ds_path=(UPath(v["ds_path"]) if v["ds_path"] is not None else None),
86
+ layer_config=(
87
+ LayerConfig.model_validate(v["layer_config"])
88
+ if v["layer_config"] is not None
89
+ else None
90
+ ),
91
+ )
92
+
93
+
31
94
  def init_jsonargparse() -> None:
32
95
  """Initialize custom jsonargparse serializers."""
96
+ global INITIALIZED
97
+ if INITIALIZED:
98
+ return
33
99
  jsonargparse.typing.register_type(CRS, crs_serializer, crs_deserializer)
100
+ jsonargparse.typing.register_type(
101
+ datetime, datetime_serializer, datetime_deserializer
102
+ )
103
+
104
+ from rslearn.data_sources.data_source import DataSourceContext
105
+
106
+ jsonargparse.typing.register_type(
107
+ DataSourceContext,
108
+ data_source_context_serializer,
109
+ data_source_context_deserializer,
110
+ )
111
+
112
+ INITIALIZED = True
rslearn/utils/raster_format.py CHANGED
@@ -2,8 +2,7 @@
2
2
 
3
3
  import hashlib
4
4
  import json
5
- from collections.abc import Callable
6
- from typing import Any, BinaryIO, TypeVar
5
+ from typing import Any, BinaryIO
7
6
 
8
7
  import affine
9
8
  import numpy as np
@@ -14,34 +13,12 @@ from rasterio.crs import CRS
14
13
  from rasterio.enums import Resampling
15
14
  from upath import UPath
16
15
 
17
- from rslearn.config import RasterFormatConfig
18
16
  from rslearn.const import TILE_SIZE
19
17
  from rslearn.log_utils import get_logger
20
18
  from rslearn.utils.fsspec import open_rasterio_upath_reader, open_rasterio_upath_writer
21
19
 
22
20
  from .geometry import PixelBounds, Projection
23
21
 
24
- _RasterFormatT = TypeVar("_RasterFormatT", bound="RasterFormat")
25
-
26
-
27
- class _RasterFormatRegistry(dict[str, type["RasterFormat"]]):
28
- """Registry for RasterFormat classes."""
29
-
30
- def register(
31
- self, name: str
32
- ) -> Callable[[type[_RasterFormatT]], type[_RasterFormatT]]:
33
- """Decorator to register a raster format class."""
34
-
35
- def decorator(cls: type[_RasterFormatT]) -> type[_RasterFormatT]:
36
- self[name] = cls
37
- return cls
38
-
39
- return decorator
40
-
41
-
42
- RasterFormats = _RasterFormatRegistry()
43
-
44
-
45
22
  logger = get_logger(__name__)
46
23
 
47
24
 
@@ -219,7 +196,6 @@ class RasterFormat:
219
196
  raise NotImplementedError
220
197
 
221
198
 
222
- @RasterFormats.register("image_tile")
223
199
  class ImageTileRasterFormat(RasterFormat):
224
200
  """A RasterFormat that stores data in image tiles corresponding to grid cells.
225
201
 
@@ -468,7 +444,6 @@ class ImageTileRasterFormat(RasterFormat):
468
444
  )
469
445
 
470
446
 
471
- @RasterFormats.register("geotiff")
472
447
  class GeotiffRasterFormat(RasterFormat):
473
448
  """A raster format that uses one big, tiled GeoTIFF with small block size."""
474
449
 
@@ -623,7 +598,6 @@ class GeotiffRasterFormat(RasterFormat):
623
598
  return GeotiffRasterFormat(**kwargs)
624
599
 
625
600
 
626
- @RasterFormats.register("single_image")
627
601
  class SingleImageRasterFormat(RasterFormat):
628
602
  """A raster format that produces a single image called image.png/jpg.
629
603
 
@@ -775,17 +749,3 @@ class SingleImageRasterFormat(RasterFormat):
775
749
  if "format" in config:
776
750
  kwargs["format"] = config["format"]
777
751
  return SingleImageRasterFormat(**kwargs)
778
-
779
-
780
- def load_raster_format(config: RasterFormatConfig) -> RasterFormat:
781
- """Loads a RasterFormat from a RasterFormatConfig.
782
-
783
- Args:
784
- config: the RasterFormatConfig configuration object specifying the
785
- RasterFormat.
786
-
787
- Returns:
788
- the loaded RasterFormat implementation
789
- """
790
- cls = RasterFormats[config.name]
791
- return cls.from_config(config.name, config.config_dict)
rslearn/utils/vector_format.py CHANGED
@@ -1,15 +1,13 @@
1
1
  """Classes for writing vector data to a UPath."""
2
2
 
3
3
  import json
4
- from collections.abc import Callable
5
4
  from enum import Enum
6
- from typing import Any, TypeVar
5
+ from typing import Any
7
6
 
8
7
  import shapely
9
8
  from rasterio.crs import CRS
10
9
  from upath import UPath
11
10
 
12
- from rslearn.config import VectorFormatConfig
13
11
  from rslearn.const import WGS84_PROJECTION
14
12
  from rslearn.log_utils import get_logger
15
13
  from rslearn.utils.fsspec import open_atomic
@@ -18,25 +16,6 @@ from .feature import Feature
18
16
  from .geometry import PixelBounds, Projection, STGeometry, safely_reproject_and_clip
19
17
 
20
18
  logger = get_logger(__name__)
21
- _VectorFormatT = TypeVar("_VectorFormatT", bound="VectorFormat")
22
-
23
-
24
- class _VectorFormatRegistry(dict[str, type["VectorFormat"]]):
25
- """Registry for VectorFormat classes."""
26
-
27
- def register(
28
- self, name: str
29
- ) -> Callable[[type[_VectorFormatT]], type[_VectorFormatT]]:
30
- """Decorator to register a vector format class."""
31
-
32
- def decorator(cls: type[_VectorFormatT]) -> type[_VectorFormatT]:
33
- self[name] = cls
34
- return cls
35
-
36
- return decorator
37
-
38
-
39
- VectorFormats = _VectorFormatRegistry()
40
19
 
41
20
 
42
21
  class VectorFormat:
@@ -85,7 +64,6 @@ class VectorFormat:
85
64
  raise NotImplementedError
86
65
 
87
66
 
88
- @VectorFormats.register("tile")
89
67
  class TileVectorFormat(VectorFormat):
90
68
  """TileVectorFormat stores data in GeoJSON files corresponding to grid cells.
91
69
 
@@ -275,7 +253,6 @@ class GeojsonCoordinateMode(Enum):
275
253
  WGS84 = "wgs84"
276
254
 
277
255
 
278
- @VectorFormats.register("geojson")
279
256
  class GeojsonVectorFormat(VectorFormat):
280
257
  """A vector format that uses one big GeoJSON."""
281
258
 
@@ -429,17 +406,3 @@ class GeojsonVectorFormat(VectorFormat):
429
406
  if "coordinate_mode" in config:
430
407
  kwargs["coordinate_mode"] = GeojsonCoordinateMode(config["coordinate_mode"])
431
408
  return GeojsonVectorFormat(**kwargs)
432
-
433
-
434
- def load_vector_format(config: VectorFormatConfig) -> VectorFormat:
435
- """Loads a VectorFormat from a VectorFormatConfig.
436
-
437
- Args:
438
- config: the VectorFormatConfig configuration object specifying the
439
- VectorFormat.
440
-
441
- Returns:
442
- the loaded VectorFormat implementation
443
- """
444
- cls = VectorFormats[config.name]
445
- return cls.from_config(config.name, config.config_dict)
{rslearn-0.0.14.dist-info → rslearn-0.0.16.dist-info}/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: rslearn
3
- Version: 0.0.14
3
+ Version: 0.0.16
4
4
  Summary: A library for developing remote sensing datasets and models
5
5
  Author: OlmoEarth Team
6
6
  License: Apache License