rslearn 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78):
  1. rslearn/config/dataset.py +22 -13
  2. rslearn/data_sources/__init__.py +8 -0
  3. rslearn/data_sources/aws_landsat.py +27 -18
  4. rslearn/data_sources/aws_open_data.py +41 -42
  5. rslearn/data_sources/copernicus.py +148 -2
  6. rslearn/data_sources/data_source.py +17 -10
  7. rslearn/data_sources/gcp_public_data.py +177 -100
  8. rslearn/data_sources/geotiff.py +1 -0
  9. rslearn/data_sources/google_earth_engine.py +17 -15
  10. rslearn/data_sources/local_files.py +59 -32
  11. rslearn/data_sources/openstreetmap.py +27 -23
  12. rslearn/data_sources/planet.py +10 -9
  13. rslearn/data_sources/planet_basemap.py +303 -0
  14. rslearn/data_sources/raster_source.py +23 -13
  15. rslearn/data_sources/usgs_landsat.py +56 -27
  16. rslearn/data_sources/utils.py +13 -6
  17. rslearn/data_sources/vector_source.py +1 -0
  18. rslearn/data_sources/xyz_tiles.py +8 -9
  19. rslearn/dataset/add_windows.py +1 -1
  20. rslearn/dataset/dataset.py +16 -5
  21. rslearn/dataset/manage.py +9 -4
  22. rslearn/dataset/materialize.py +26 -5
  23. rslearn/dataset/window.py +5 -0
  24. rslearn/log_utils.py +24 -0
  25. rslearn/main.py +123 -59
  26. rslearn/models/clip.py +62 -0
  27. rslearn/models/conv.py +56 -0
  28. rslearn/models/faster_rcnn.py +2 -19
  29. rslearn/models/fpn.py +1 -1
  30. rslearn/models/module_wrapper.py +43 -0
  31. rslearn/models/molmo.py +65 -0
  32. rslearn/models/multitask.py +1 -1
  33. rslearn/models/pooling_decoder.py +4 -2
  34. rslearn/models/satlaspretrain.py +4 -7
  35. rslearn/models/simple_time_series.py +61 -55
  36. rslearn/models/ssl4eo_s12.py +9 -9
  37. rslearn/models/swin.py +22 -21
  38. rslearn/models/unet.py +4 -2
  39. rslearn/models/upsample.py +35 -0
  40. rslearn/tile_stores/file.py +6 -3
  41. rslearn/tile_stores/tile_store.py +19 -7
  42. rslearn/train/callbacks/freeze_unfreeze.py +3 -3
  43. rslearn/train/data_module.py +5 -4
  44. rslearn/train/dataset.py +79 -36
  45. rslearn/train/lightning_module.py +15 -11
  46. rslearn/train/prediction_writer.py +22 -11
  47. rslearn/train/tasks/classification.py +9 -8
  48. rslearn/train/tasks/detection.py +94 -37
  49. rslearn/train/tasks/multi_task.py +1 -1
  50. rslearn/train/tasks/regression.py +8 -4
  51. rslearn/train/tasks/segmentation.py +23 -19
  52. rslearn/train/transforms/__init__.py +1 -1
  53. rslearn/train/transforms/concatenate.py +6 -2
  54. rslearn/train/transforms/crop.py +6 -2
  55. rslearn/train/transforms/flip.py +5 -1
  56. rslearn/train/transforms/normalize.py +9 -5
  57. rslearn/train/transforms/pad.py +1 -1
  58. rslearn/train/transforms/transform.py +3 -3
  59. rslearn/utils/__init__.py +4 -5
  60. rslearn/utils/array.py +2 -2
  61. rslearn/utils/feature.py +1 -1
  62. rslearn/utils/fsspec.py +70 -1
  63. rslearn/utils/geometry.py +155 -3
  64. rslearn/utils/grid_index.py +5 -5
  65. rslearn/utils/mp.py +4 -3
  66. rslearn/utils/raster_format.py +81 -73
  67. rslearn/utils/rtree_index.py +64 -17
  68. rslearn/utils/sqlite_index.py +7 -1
  69. rslearn/utils/utils.py +11 -3
  70. rslearn/utils/vector_format.py +113 -17
  71. {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/METADATA +32 -27
  72. rslearn-0.0.2.dist-info/RECORD +94 -0
  73. {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/WHEEL +1 -1
  74. rslearn/utils/mgrs.py +0 -24
  75. rslearn-0.0.1.dist-info/RECORD +0 -88
  76. {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/LICENSE +0 -0
  77. {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/entry_points.txt +0 -0
  78. {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/top_level.txt +0 -0
@@ -186,6 +186,8 @@ class XyzTiles(DataSource):
186
186
  @staticmethod
187
187
  def from_config(config: LayerConfig, ds_path: UPath) -> "XyzTiles":
188
188
  """Creates a new XyzTiles instance from a configuration dictionary."""
189
+ if config.data_source is None:
190
+ raise ValueError("data_source is required")
189
191
  d = config.data_source.config_dict
190
192
  time_ranges = []
191
193
  for str1, str2 in d["time_ranges"]:
@@ -207,7 +209,7 @@ class XyzTiles(DataSource):
207
209
 
208
210
  def get_items(
209
211
  self, geometries: list[STGeometry], query_config: QueryConfig
210
- ) -> list[list[list[Item]]]:
212
+ ) -> list[list[list[XyzItem]]]:
211
213
  """Get a list of items in the data source intersecting the given geometries.
212
214
 
213
215
  In XyzTiles we treat the data source as containing a single item, i.e., the
@@ -278,7 +280,7 @@ class XyzTiles(DataSource):
278
280
  def materialize(
279
281
  self,
280
282
  window: Window,
281
- item_groups: list[list[Item]],
283
+ item_groups: list[list[XyzItem]],
282
284
  layer_name: str,
283
285
  layer_cfg: LayerConfig,
284
286
  ) -> None:
@@ -305,13 +307,10 @@ class XyzTiles(DataSource):
305
307
  window_projection, shapely.box(*window_bounds), None
306
308
  )
307
309
  projected_geometry = window_geometry.to_projection(self.projection)
308
- projected_bounds = [
309
- math.floor(projected_geometry.shp.bounds[0]),
310
- math.floor(projected_geometry.shp.bounds[1]),
311
- math.ceil(projected_geometry.shp.bounds[2]),
312
- math.ceil(projected_geometry.shp.bounds[3]),
313
- ]
314
- projected_raster = self.read_bounds(item.url_template, projected_bounds)
310
+ projected_bounds = tuple(
311
+ math.floor(projected_geometry.shp.bounds[i]) for i in range(4)
312
+ )
313
+ projected_raster = self.read_bounds(item.url_template, projected_bounds) # type: ignore
315
314
 
316
315
  # Attach the transform to the raster.
317
316
  src_transform = rasterio.transform.Affine(
@@ -25,7 +25,7 @@ def add_windows_from_geometries(
25
25
  window_size: int | None = None,
26
26
  time_range: tuple[datetime, datetime] | None = None,
27
27
  use_utm: bool = False,
28
- ):
28
+ ) -> list[Window]:
29
29
  """Create windows based on a list of STGeometry.
30
30
 
31
31
  Args:
@@ -7,10 +7,13 @@ import tqdm
7
7
  from upath import UPath
8
8
 
9
9
  from rslearn.config import TileStoreConfig, load_layer_config
10
+ from rslearn.log_utils import get_logger
10
11
  from rslearn.tile_stores import TileStore, load_tile_store
11
12
 
12
13
  from .window import Window
13
14
 
15
+ logger = get_logger(__name__)
16
+
14
17
 
15
18
  class Dataset:
16
19
  """A rslearn dataset.
@@ -37,21 +40,29 @@ class Dataset:
37
40
  materialize.
38
41
  """
39
42
 
40
- def __init__(self, path: UPath) -> None:
43
+ def __init__(self, path: UPath, disabled_layers: list[str] = []) -> None:
41
44
  """Initializes a new Dataset.
42
45
 
43
46
  Args:
44
47
  path: the root directory of the dataset
48
+ disabled_layers: list of layers to disable
45
49
  """
46
50
  self.path = path
47
51
 
48
52
  # Load dataset configuration.
53
+
49
54
  with (self.path / "config.json").open("r") as f:
50
55
  config = json.load(f)
51
- self.layers = {
52
- layer_name: load_layer_config(d)
53
- for layer_name, d in config["layers"].items()
54
- }
56
+ self.layers = {}
57
+ for layer_name, d in config["layers"].items():
58
+ # Layer names must not contain period, since we use period to
59
+ # distinguish different materialized groups within a layer.
60
+ assert "." not in layer_name, "layer names must not contain periods"
61
+ if layer_name in disabled_layers:
62
+ logger.warning(f"Layer {layer_name} is disabled")
63
+ continue
64
+ self.layers[layer_name] = load_layer_config(d)
65
+
55
66
  self.tile_store_config = TileStoreConfig.from_config(config["tile_store"])
56
67
  self.materializer_name = config.get("materialize")
57
68
 
rslearn/dataset/manage.py CHANGED
@@ -1,15 +1,17 @@
1
1
  """Functions to manage datasets."""
2
2
 
3
3
  import rslearn.data_sources
4
- from rslearn.config import LayerConfig, LayerType
4
+ from rslearn.config import LayerConfig, LayerType, RasterLayerConfig
5
5
  from rslearn.data_sources import DataSource, Item
6
+ from rslearn.log_utils import get_logger
6
7
  from rslearn.tile_stores import TileStore, get_tile_store_for_layer
7
- from rslearn.utils import logger
8
8
 
9
9
  from .dataset import Dataset
10
10
  from .materialize import Materializers
11
11
  from .window import Window, WindowLayerData
12
12
 
13
+ logger = get_logger(__name__)
14
+
13
15
 
14
16
  def prepare_dataset_windows(
15
17
  dataset: Dataset, windows: list[Window], force: bool = False
@@ -37,7 +39,7 @@ def prepare_dataset_windows(
37
39
  if layer_name in layer_datas and not force:
38
40
  continue
39
41
  needed_windows.append(window)
40
- print(f"Preparing {len(needed_windows)} windows for layer {layer_name}")
42
+ logger.info(f"Preparing {len(needed_windows)} windows for layer {layer_name}")
41
43
  if len(needed_windows) == 0:
42
44
  continue
43
45
 
@@ -101,7 +103,7 @@ def ingest_dataset_windows(dataset: Dataset, windows: list[Window]) -> None:
101
103
  layer_cfg, dataset.path
102
104
  )
103
105
 
104
- geometries_by_item = {}
106
+ geometries_by_item: dict = {}
105
107
  for window in windows:
106
108
  layer_datas = window.load_layer_datas()
107
109
  if layer_name not in layer_datas:
@@ -151,6 +153,7 @@ def is_window_ingested(
151
153
  item = Item.deserialize(serialized_item)
152
154
 
153
155
  if layer_cfg.layer_type == LayerType.RASTER:
156
+ assert isinstance(layer_cfg, RasterLayerConfig)
154
157
  for band_set in layer_cfg.band_sets:
155
158
  projection, _ = band_set.get_final_projection_and_bounds(
156
159
  window.projection, window.bounds
@@ -229,6 +232,8 @@ def materialize_window(
229
232
  item_group.append(item)
230
233
  item_groups.append(item_group)
231
234
 
235
+ if layer_cfg.data_source is None:
236
+ raise ValueError("data_source is required")
232
237
  if layer_cfg.data_source.ingest:
233
238
  if not is_window_ingested(dataset, window, check_layer_name=layer_name):
234
239
  logger.info(
@@ -1,6 +1,6 @@
1
1
  """Classes to implement dataset materialization."""
2
2
 
3
- from typing import Any
3
+ from typing import Any, Generic, TypeVar
4
4
 
5
5
  import numpy as np
6
6
  import numpy.typing as npt
@@ -24,8 +24,10 @@ from .window import Window
24
24
 
25
25
  Materializers = ClassRegistry()
26
26
 
27
+ LayerConfigType = TypeVar("LayerConfigType", bound=LayerConfig)
27
28
 
28
- class Materializer:
29
+
30
+ class Materializer(Generic[LayerConfigType]):
29
31
  """An abstract class that materializes data from a tile store."""
30
32
 
31
33
  def materialize(
@@ -33,7 +35,7 @@ class Materializer:
33
35
  tile_store: TileStore,
34
36
  window: Window,
35
37
  layer_name: str,
36
- layer_cfg: LayerConfig,
38
+ layer_cfg: LayerConfigType,
37
39
  item_groups: list[list[Item]],
38
40
  ) -> None:
39
41
  """Materialize portions of items corresponding to this window into the dataset.
@@ -82,6 +84,8 @@ def read_raster_window_from_tiles(
82
84
  dst_row_offset = intersection[1] - bounds[1]
83
85
 
84
86
  src = ts_layer.read_raster(intersection)
87
+ if src is None:
88
+ raise ValueError(f"No raster data found for bounds {intersection}")
85
89
  src = src[src_indexes, :, :]
86
90
  if remapper:
87
91
  src = remapper(src, dst.dtype)
@@ -97,7 +101,7 @@ def read_raster_window_from_tiles(
97
101
 
98
102
 
99
103
  @Materializers.register("raster")
100
- class RasterMaterializer(Materializer):
104
+ class RasterMaterializer(Materializer[RasterLayerConfig]):
101
105
  """A Materializer for raster data."""
102
106
 
103
107
  def materialize(
@@ -105,7 +109,7 @@ class RasterMaterializer(Materializer):
105
109
  tile_store: TileStore,
106
110
  window: Window,
107
111
  layer_name: str,
108
- layer_cfg: LayerConfig,
112
+ layer_cfg: RasterLayerConfig,
109
113
  item_groups: list[list[Item]],
110
114
  ) -> None:
111
115
  """Materialize portions of items corresponding to this window into the dataset.
@@ -142,6 +146,12 @@ class RasterMaterializer(Materializer):
142
146
  if band_cfg.remap:
143
147
  remapper = load_remapper(band_cfg.remap)
144
148
 
149
+ if band_cfg.format is None or band_cfg.bands is None or bounds is None:
150
+ raise ValueError(
151
+ f"No raster format or bands specified for {layer_name} \
152
+ with {band_cfg}"
153
+ )
154
+
145
155
  raster_format = load_raster_format(
146
156
  RasterFormatConfig(band_cfg.format["name"], band_cfg.format)
147
157
  )
@@ -182,6 +192,11 @@ class RasterMaterializer(Materializer):
182
192
  ts_layer = layer_tile_store.get_layer(
183
193
  (item.name, suffix, str(projection))
184
194
  )
195
+ if ts_layer is None:
196
+ raise ValueError(
197
+ f"No tile store layer found for {item.name} {suffix} \
198
+ {projection}"
199
+ )
185
200
  read_raster_window_from_tiles(
186
201
  dst, ts_layer, bounds, src_indexes, dst_indexes, remapper
187
202
  )
@@ -223,6 +238,8 @@ class VectorMaterializer(Materializer):
223
238
  projection, bounds = layer_cfg.get_final_projection_and_bounds(
224
239
  window.projection, window.bounds
225
240
  )
241
+ if bounds is None:
242
+ raise ValueError(f"No bounds specified for {layer_name}")
226
243
  vector_format = load_vector_format(layer_cfg.format)
227
244
 
228
245
  out_layer_dirs: list[UPath] = []
@@ -241,6 +258,10 @@ class VectorMaterializer(Materializer):
241
258
  ts_layer = get_tile_store_for_layer(
242
259
  tile_store, layer_name, layer_cfg
243
260
  ).get_layer((item.name, str(projection)))
261
+ if ts_layer is None:
262
+ raise ValueError(
263
+ f"No tile store layer found for {item.name} {projection}"
264
+ )
244
265
  cur_features = ts_layer.read_vector(bounds)
245
266
  features.extend(cur_features)
246
267
 
rslearn/dataset/window.py CHANGED
@@ -7,9 +7,12 @@ from typing import Any
7
7
  import shapely
8
8
  from upath import UPath
9
9
 
10
+ from rslearn.log_utils import get_logger
10
11
  from rslearn.utils import Projection, STGeometry
11
12
  from rslearn.utils.fsspec import open_atomic
12
13
 
14
+ logger = get_logger(__name__)
15
+
13
16
 
14
17
  class WindowLayerData:
15
18
  """Layer data for retrieved layers specifying relevant items in the data source.
@@ -115,6 +118,7 @@ class Window:
115
118
  "options": self.options,
116
119
  }
117
120
  metadata_path = self.path / "metadata.json"
121
+ logger.info(f"Saving window metadata to {metadata_path}")
118
122
  with open_atomic(metadata_path, "w") as f:
119
123
  json.dump(metadata, f)
120
124
 
@@ -141,6 +145,7 @@ class Window:
141
145
  """Save layer datas to items.json."""
142
146
  json_data = [layer_data.serialize() for layer_data in layer_datas.values()]
143
147
  items_fname = self.path / "items.json"
148
+ logger.info(f"Saving window items to {items_fname}")
144
149
  with open_atomic(items_fname, "w") as f:
145
150
  json.dump(json_data, f)
146
151
 
rslearn/log_utils.py ADDED
@@ -0,0 +1,24 @@
1
+ """Logging utilities."""
2
+
3
+ import logging
4
+ import os
5
+ import sys
6
+
7
+ LOG_FORMAT = "format=%(asctime)s loglevel=%(levelname)-6s logger=%(name)s %(funcName)s() L%(lineno)-4d %(message)s"
8
+ # DETAILED_LOG_FORMAT = "format=%(asctime)s loglevel=%(levelname)-6s logger=%(name)s %(funcName)s() L%(lineno)-4d %(message)s call_trace=%(pathname)s L%(lineno)-4d" # noqa
9
+
10
+
11
+ def get_logger(name: str) -> logging.Logger:
12
+ """Get a logger with a console handler."""
13
+ this_logger = logging.getLogger(name)
14
+ log_level = os.environ.get("RSLEARN_LOGLEVEL", "INFO")
15
+ if not this_logger.handlers:
16
+ console_handler = logging.StreamHandler(sys.stdout)
17
+ console_handler.setLevel(log_level)
18
+ console_formatter = logging.Formatter(LOG_FORMAT)
19
+ console_handler.setFormatter(console_formatter)
20
+ this_logger.addHandler(console_handler)
21
+
22
+ this_logger.setLevel(log_level)
23
+ this_logger.propagate = True
24
+ return this_logger