rslearn 0.0.1__py3-none-any.whl → 0.0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (117)
  1. rslearn/config/__init__.py +2 -2
  2. rslearn/config/dataset.py +164 -98
  3. rslearn/const.py +9 -15
  4. rslearn/data_sources/__init__.py +8 -0
  5. rslearn/data_sources/aws_landsat.py +235 -80
  6. rslearn/data_sources/aws_open_data.py +103 -118
  7. rslearn/data_sources/aws_sentinel1.py +142 -0
  8. rslearn/data_sources/climate_data_store.py +303 -0
  9. rslearn/data_sources/copernicus.py +943 -12
  10. rslearn/data_sources/data_source.py +17 -12
  11. rslearn/data_sources/earthdaily.py +489 -0
  12. rslearn/data_sources/earthdata_srtm.py +300 -0
  13. rslearn/data_sources/gcp_public_data.py +556 -203
  14. rslearn/data_sources/geotiff.py +1 -0
  15. rslearn/data_sources/google_earth_engine.py +454 -115
  16. rslearn/data_sources/local_files.py +153 -103
  17. rslearn/data_sources/openstreetmap.py +33 -39
  18. rslearn/data_sources/planet.py +17 -35
  19. rslearn/data_sources/planet_basemap.py +296 -0
  20. rslearn/data_sources/planetary_computer.py +764 -0
  21. rslearn/data_sources/raster_source.py +11 -297
  22. rslearn/data_sources/usda_cdl.py +206 -0
  23. rslearn/data_sources/usgs_landsat.py +130 -73
  24. rslearn/data_sources/utils.py +256 -61
  25. rslearn/data_sources/vector_source.py +1 -0
  26. rslearn/data_sources/worldcereal.py +456 -0
  27. rslearn/data_sources/worldcover.py +142 -0
  28. rslearn/data_sources/worldpop.py +156 -0
  29. rslearn/data_sources/xyz_tiles.py +141 -79
  30. rslearn/dataset/__init__.py +8 -2
  31. rslearn/dataset/add_windows.py +1 -1
  32. rslearn/dataset/dataset.py +43 -7
  33. rslearn/dataset/index.py +173 -0
  34. rslearn/dataset/manage.py +137 -49
  35. rslearn/dataset/materialize.py +436 -95
  36. rslearn/dataset/window.py +225 -34
  37. rslearn/log_utils.py +24 -0
  38. rslearn/main.py +351 -130
  39. rslearn/models/clip.py +62 -0
  40. rslearn/models/conv.py +56 -0
  41. rslearn/models/croma.py +270 -0
  42. rslearn/models/detr/__init__.py +5 -0
  43. rslearn/models/detr/box_ops.py +103 -0
  44. rslearn/models/detr/detr.py +493 -0
  45. rslearn/models/detr/matcher.py +107 -0
  46. rslearn/models/detr/position_encoding.py +114 -0
  47. rslearn/models/detr/transformer.py +429 -0
  48. rslearn/models/detr/util.py +24 -0
  49. rslearn/models/faster_rcnn.py +10 -19
  50. rslearn/models/fpn.py +1 -1
  51. rslearn/models/module_wrapper.py +91 -0
  52. rslearn/models/moe/distributed.py +262 -0
  53. rslearn/models/moe/soft.py +676 -0
  54. rslearn/models/molmo.py +65 -0
  55. rslearn/models/multitask.py +351 -24
  56. rslearn/models/pick_features.py +15 -2
  57. rslearn/models/pooling_decoder.py +4 -2
  58. rslearn/models/satlaspretrain.py +4 -7
  59. rslearn/models/simple_time_series.py +75 -59
  60. rslearn/models/singletask.py +8 -4
  61. rslearn/models/ssl4eo_s12.py +10 -10
  62. rslearn/models/swin.py +22 -21
  63. rslearn/models/task_embedding.py +250 -0
  64. rslearn/models/terramind.py +219 -0
  65. rslearn/models/trunk.py +280 -0
  66. rslearn/models/unet.py +21 -5
  67. rslearn/models/upsample.py +35 -0
  68. rslearn/models/use_croma.py +508 -0
  69. rslearn/py.typed +0 -0
  70. rslearn/tile_stores/__init__.py +52 -18
  71. rslearn/tile_stores/default.py +382 -0
  72. rslearn/tile_stores/tile_store.py +236 -132
  73. rslearn/train/callbacks/freeze_unfreeze.py +32 -20
  74. rslearn/train/callbacks/gradients.py +109 -0
  75. rslearn/train/callbacks/peft.py +116 -0
  76. rslearn/train/data_module.py +407 -14
  77. rslearn/train/dataset.py +746 -200
  78. rslearn/train/lightning_module.py +164 -54
  79. rslearn/train/optimizer.py +31 -0
  80. rslearn/train/prediction_writer.py +235 -78
  81. rslearn/train/scheduler.py +62 -0
  82. rslearn/train/tasks/classification.py +13 -12
  83. rslearn/train/tasks/detection.py +101 -39
  84. rslearn/train/tasks/multi_task.py +24 -9
  85. rslearn/train/tasks/regression.py +113 -21
  86. rslearn/train/tasks/segmentation.py +353 -35
  87. rslearn/train/tasks/task.py +2 -2
  88. rslearn/train/transforms/__init__.py +1 -1
  89. rslearn/train/transforms/concatenate.py +9 -5
  90. rslearn/train/transforms/crop.py +8 -4
  91. rslearn/train/transforms/flip.py +5 -1
  92. rslearn/train/transforms/normalize.py +34 -10
  93. rslearn/train/transforms/pad.py +1 -1
  94. rslearn/train/transforms/transform.py +75 -73
  95. rslearn/utils/__init__.py +2 -6
  96. rslearn/utils/array.py +2 -2
  97. rslearn/utils/feature.py +2 -2
  98. rslearn/utils/fsspec.py +70 -1
  99. rslearn/utils/geometry.py +214 -7
  100. rslearn/utils/get_utm_ups_crs.py +2 -3
  101. rslearn/utils/grid_index.py +5 -5
  102. rslearn/utils/jsonargparse.py +33 -0
  103. rslearn/utils/mp.py +4 -3
  104. rslearn/utils/raster_format.py +211 -96
  105. rslearn/utils/rtree_index.py +64 -17
  106. rslearn/utils/sqlite_index.py +7 -1
  107. rslearn/utils/vector_format.py +235 -77
  108. {rslearn-0.0.1.dist-info → rslearn-0.0.3.dist-info}/METADATA +366 -284
  109. rslearn-0.0.3.dist-info/RECORD +123 -0
  110. {rslearn-0.0.1.dist-info → rslearn-0.0.3.dist-info}/WHEEL +1 -1
  111. rslearn/tile_stores/file.py +0 -242
  112. rslearn/utils/mgrs.py +0 -24
  113. rslearn/utils/utils.py +0 -22
  114. rslearn-0.0.1.dist-info/RECORD +0 -88
  115. {rslearn-0.0.1.dist-info → rslearn-0.0.3.dist-info}/entry_points.txt +0 -0
  116. {rslearn-0.0.1.dist-info → rslearn-0.0.3.dist-info/licenses}/LICENSE +0 -0
  117. {rslearn-0.0.1.dist-info → rslearn-0.0.3.dist-info}/top_level.txt +0 -0
rslearn/config/__init__.py CHANGED
@@ -2,6 +2,7 @@
2
2
 
3
3
  from .dataset import (
4
4
  BandSetConfig,
5
+ CompositingMethod,
5
6
  DataSourceConfig,
6
7
  DType,
7
8
  LayerConfig,
@@ -10,7 +11,6 @@ from .dataset import (
10
11
  RasterFormatConfig,
11
12
  RasterLayerConfig,
12
13
  SpaceMode,
13
- TileStoreConfig,
14
14
  TimeMode,
15
15
  VectorFormatConfig,
16
16
  VectorLayerConfig,
@@ -19,6 +19,7 @@ from .dataset import (
19
19
 
20
20
  __all__ = [
21
21
  "BandSetConfig",
22
+ "CompositingMethod",
22
23
  "DataSourceConfig",
23
24
  "DType",
24
25
  "LayerConfig",
@@ -27,7 +28,6 @@ __all__ = [
27
28
  "RasterFormatConfig",
28
29
  "RasterLayerConfig",
29
30
  "SpaceMode",
30
- "TileStoreConfig",
31
31
  "TimeMode",
32
32
  "VectorFormatConfig",
33
33
  "VectorLayerConfig",
rslearn/config/dataset.py CHANGED
@@ -1,5 +1,6 @@
1
1
  """Classes for storing configuration of a dataset."""
2
2
 
3
+ import json
3
4
  from datetime import timedelta
4
5
  from enum import Enum
5
6
  from typing import Any
@@ -19,7 +20,11 @@ class DType(Enum):
19
20
  UINT8 = "uint8"
20
21
  UINT16 = "uint16"
21
22
  UINT32 = "uint32"
23
+ UINT64 = "uint64"
24
+ INT8 = "int8"
25
+ INT16 = "int16"
22
26
  INT32 = "int32"
27
+ INT64 = "int64"
23
28
  FLOAT32 = "float32"
24
29
 
25
30
  def get_numpy_dtype(self) -> npt.DTypeLike:
@@ -30,8 +35,16 @@ class DType(Enum):
30
35
  return np.uint16
31
36
  elif self == DType.UINT32:
32
37
  return np.uint32
38
+ elif self == DType.UINT64:
39
+ return np.uint64
40
+ elif self == DType.INT8:
41
+ return np.int8
42
+ elif self == DType.INT16:
43
+ return np.int16
33
44
  elif self == DType.INT32:
34
45
  return np.int32
46
+ elif self == DType.INT64:
47
+ return np.int64
35
48
  elif self == DType.FLOAT32:
36
49
  return np.float32
37
50
  raise ValueError(f"unable to handle numpy dtype {self}")
@@ -112,10 +125,12 @@ class BandSetConfig:
112
125
  self,
113
126
  config_dict: dict[str, Any],
114
127
  dtype: DType,
115
- bands: list[str] | None = None,
128
+ bands: list[str],
116
129
  format: dict[str, Any] | None = None,
117
130
  zoom_offset: int = 0,
118
131
  remap: dict[str, Any] | None = None,
132
+ class_names: list[list[str]] | None = None,
133
+ nodata_vals: list[float] | None = None,
119
134
  ) -> None:
120
135
  """Creates a new BandSetConfig instance.
121
136
 
@@ -124,29 +139,43 @@ class BandSetConfig:
124
139
  dtype: the pixel value type to store tiles in
125
140
  bands: list of band names in this BandSetConfig
126
141
  format: the format to store tiles in, defaults to geotiff
127
- zoom_offset: non-negative integer, store images at window resolution
128
- divided by 2^(zoom_offset).
142
+ zoom_offset: store images at a resolution higher or lower than the window
143
+ resolution. This enables keeping source data at its native resolution,
144
+ either to save storage space (for lower resolution data) or to retain
145
+ details (for higher resolution data). If positive, store data at the
146
+ window resolution divided by 2^(zoom_offset) (higher resolution). If
147
+ negative, store data at the window resolution multiplied by
148
+ 2^(-zoom_offset) (lower resolution).
129
149
  remap: config dict for Remapper to remap pixel values
150
+ class_names: optional list of names for the different possible values of
151
+ each band. The length of this list must equal the number of bands. For
152
+ example, [["forest", "desert"]] means that it is a single-band raster
153
+ where values can be 0 (forest) or 1 (desert).
154
+ nodata_vals: the nodata values for this band set. This is used during
155
+ materialization when creating mosaics, to determine which parts of the
156
+ source images should be copied.
130
157
  """
158
+ if class_names is not None and len(bands) != len(class_names):
159
+ raise ValueError(
160
+ f"the number of class lists ({len(class_names)}) does not match the number of bands ({len(bands)})"
161
+ )
162
+
131
163
  self.config_dict = config_dict
132
164
  self.bands = bands
133
- self.format = format
134
165
  self.dtype = dtype
135
166
  self.zoom_offset = zoom_offset
136
167
  self.remap = remap
168
+ self.class_names = class_names
169
+ self.nodata_vals = nodata_vals
137
170
 
138
- if not self.format:
171
+ if format is None:
139
172
  self.format = {"name": "geotiff"}
173
+ else:
174
+ self.format = format
140
175
 
141
176
  def serialize(self) -> dict[str, Any]:
142
- """Serialize this BandSetConfig to a config dict, currently unused."""
143
- return {
144
- "bands": self.bands,
145
- "format": self.format,
146
- "dtype": self.dtype,
147
- "zoom_offset": self.zoom_offset,
148
- "remap": self.remap,
149
- }
177
+ """Serialize this BandSetConfig to a config dict."""
178
+ return self.config_dict
150
179
 
151
180
  @staticmethod
152
181
  def from_config(config: dict[str, Any]) -> "BandSetConfig":
@@ -158,15 +187,16 @@ class BandSetConfig:
158
187
  kwargs = dict(
159
188
  config_dict=config,
160
189
  dtype=DType(config["dtype"]),
190
+ bands=config["bands"],
161
191
  )
162
- for k in ["bands", "format", "zoom_offset", "remap"]:
192
+ for k in ["format", "zoom_offset", "remap", "class_names", "nodata_vals"]:
163
193
  if k in config:
164
194
  kwargs[k] = config[k]
165
- return BandSetConfig(**kwargs)
195
+ return BandSetConfig(**kwargs) # type: ignore
166
196
 
167
197
  def get_final_projection_and_bounds(
168
- self, projection: Projection, bounds: PixelBounds | None
169
- ) -> tuple[Projection, PixelBounds | None]:
198
+ self, projection: Projection, bounds: PixelBounds
199
+ ) -> tuple[Projection, PixelBounds]:
170
200
  """Gets the final projection/bounds based on band set config.
171
201
 
172
202
  The band set config may apply a non-zero zoom offset that modifies the window's
@@ -187,11 +217,14 @@ class BandSetConfig:
187
217
  projection.x_resolution / (2**self.zoom_offset),
188
218
  projection.y_resolution / (2**self.zoom_offset),
189
219
  )
190
- if bounds:
191
- if self.zoom_offset > 0:
192
- bounds = tuple(x * (2**self.zoom_offset) for x in bounds)
193
- else:
194
- bounds = tuple(x // (2 ** (-self.zoom_offset)) for x in bounds)
220
+ if self.zoom_offset > 0:
221
+ zoom_factor = 2**self.zoom_offset
222
+ bounds = tuple(x * zoom_factor for x in bounds) # type: ignore
223
+ else:
224
+ bounds = tuple(
225
+ x // (2 ** (-self.zoom_offset))
226
+ for x in bounds # type: ignore
227
+ )
195
228
  return projection, bounds
196
229
 
197
230
 
@@ -211,6 +244,22 @@ class SpaceMode(Enum):
211
244
  dataset.
212
245
  """
213
246
 
247
+ PER_PERIOD_MOSAIC = 4
248
+ """Create one mosaic per sub-period of the time range.
249
+
250
+ The duration of the sub-periods is controlled by another option in QueryConfig.
251
+ """
252
+
253
+ COMPOSITE = 5
254
+ """Creates one composite covering the entire window.
255
+
256
+ During querying all items intersecting the window are placed in one group.
257
+ The compositing_method in the raster layer config specifies how these items are reduced
258
+ to a single item (e.g. MEAN/MEDIAN/FIRST_VALID) during materialization.
259
+ """
260
+
261
+ # TODO add PER_PERIOD_COMPOSITE
262
+
214
263
 
215
264
  class TimeMode(Enum):
216
265
  """Temporal matching mode when looking up items corresponding to a window."""
@@ -222,10 +271,10 @@ class TimeMode(Enum):
222
271
  """Select items closest to the window time range, up to max_matches."""
223
272
 
224
273
  BEFORE = 3
225
- """Select items before the start of the window time range, up to max_matches."""
274
+ """Select items before the end of the window time range, up to max_matches."""
226
275
 
227
276
  AFTER = 4
228
- """Select items after the end of the window time range, up to max_matches."""
277
+ """Select items after the start of the window time range, up to max_matches."""
229
278
 
230
279
 
231
280
  class QueryConfig:
@@ -235,7 +284,9 @@ class QueryConfig:
235
284
  self,
236
285
  space_mode: SpaceMode = SpaceMode.MOSAIC,
237
286
  time_mode: TimeMode = TimeMode.WITHIN,
287
+ min_matches: int = 0,
238
288
  max_matches: int = 1,
289
+ period_duration: timedelta = timedelta(days=30),
239
290
  ):
240
291
  """Creates a new query configuration.
241
292
 
@@ -245,19 +296,29 @@ class QueryConfig:
245
296
  Args:
246
297
  space_mode: specifies how items should be matched with windows spatially
247
298
  time_mode: specifies how items should be matched with windows temporally
299
+ min_matches: the minimum number of item groups. If there are fewer than
300
+ this many matches, then no matches will be returned. This can be used
301
+ to prevent unnecessary data ingestion if the user plans to discard
302
+ windows that do not have a sufficient amount of data.
248
303
  max_matches: the maximum number of items (or groups of items, if space_mode
249
304
  is MOSAIC) to match
305
+ period_duration: the duration of the periods, if the space mode is
306
+ PER_PERIOD_MOSAIC.
250
307
  """
251
308
  self.space_mode = space_mode
252
309
  self.time_mode = time_mode
310
+ self.min_matches = min_matches
253
311
  self.max_matches = max_matches
312
+ self.period_duration = period_duration
254
313
 
255
314
  def serialize(self) -> dict[str, Any]:
256
- """Serialize this QueryConfig to a config dict, currently unused."""
315
+ """Serialize this QueryConfig to a config dict."""
257
316
  return {
258
317
  "space_mode": str(self.space_mode),
259
318
  "time_mode": str(self.time_mode),
319
+ "min_matches": self.min_matches,
260
320
  "max_matches": self.max_matches,
321
+ "period_duration": f"{self.period_duration.total_seconds()}s",
261
322
  }
262
323
 
263
324
  @staticmethod
@@ -267,11 +328,20 @@ class QueryConfig:
267
328
  Args:
268
329
  config: the config dict for this QueryConfig
269
330
  """
270
- return QueryConfig(
271
- space_mode=SpaceMode[config.get("space_mode", "MOSAIC")],
272
- time_mode=TimeMode[config.get("time_mode", "WITHIN")],
273
- max_matches=config.get("max_matches", 1),
274
- )
331
+ kwargs: dict[str, Any] = dict()
332
+ if "space_mode" in config:
333
+ kwargs["space_mode"] = SpaceMode[config["space_mode"]]
334
+ if "time_mode" in config:
335
+ kwargs["time_mode"] = TimeMode[config["time_mode"]]
336
+ if "period_duration" in config:
337
+ kwargs["period_duration"] = timedelta(
338
+ seconds=pytimeparse.parse(config["period_duration"])
339
+ )
340
+ for k in ["min_matches", "max_matches"]:
341
+ if k not in config:
342
+ continue
343
+ kwargs[k] = config[k]
344
+ return QueryConfig(**kwargs)
275
345
 
276
346
 
277
347
  class DataSourceConfig:
@@ -307,16 +377,8 @@ class DataSourceConfig:
307
377
  self.ingest = ingest
308
378
 
309
379
  def serialize(self) -> dict[str, Any]:
310
- """Serialize this DataSourceConfig to a config dict, currently unused."""
311
- config_dict = self.config_dict.copy()
312
- config_dict["name"] = self.name
313
- config_dict["query_config"] = self.query_config.serialize()
314
- config_dict["ingest"] = self.ingest
315
- if self.time_offset:
316
- config_dict["time_offset"] = str(self.time_offset)
317
- if self.duration:
318
- config_dict["duration"] = str(self.duration)
319
- return config_dict
380
+ """Serialize this DataSourceConfig to a config dict."""
381
+ return self.config_dict
320
382
 
321
383
  @staticmethod
322
384
  def from_config(config: dict[str, Any]) -> "DataSourceConfig":
@@ -371,13 +433,40 @@ class LayerConfig:
371
433
  self.alias = alias
372
434
 
373
435
  def serialize(self) -> dict[str, Any]:
374
- """Serialize this LayerConfig to a config dict, currently unused."""
436
+ """Serialize this LayerConfig to a config dict."""
375
437
  return {
376
438
  "layer_type": str(self.layer_type),
377
- "data_source": self.data_source,
439
+ "data_source": self.data_source.serialize() if self.data_source else None,
378
440
  "alias": self.alias,
379
441
  }
380
442
 
443
+ def __hash__(self) -> int:
444
+ """Return a hash of this LayerConfig."""
445
+ return hash(json.dumps(self.serialize(), sort_keys=True))
446
+
447
+ def __eq__(self, other: Any) -> bool:
448
+ """Returns whether other is the same as this LayerConfig.
449
+
450
+ Args:
451
+ other: the other object to compare.
452
+ """
453
+ if not isinstance(other, LayerConfig):
454
+ return False
455
+ return self.serialize() == other.serialize()
456
+
457
+
458
+ class CompositingMethod(Enum):
459
+ """Method for selecting pixels for the composite from the corresponding items of a window."""
460
+
461
+ FIRST_VALID = 1
462
+ """Select first valid pixel in order of corresponding items (might be sorted)"""
463
+
464
+ MEAN = 2
465
+ """Select per-pixel mean value of corresponding items of a window"""
466
+
467
+ MEDIAN = 3
468
+ """Select per-pixel median value of corresponding items of a window"""
469
+
381
470
 
382
471
  class RasterLayerConfig(LayerConfig):
383
472
  """Configuration of a raster layer."""
@@ -389,6 +478,7 @@ class RasterLayerConfig(LayerConfig):
389
478
  data_source: DataSourceConfig | None = None,
390
479
  resampling_method: Resampling = Resampling.bilinear,
391
480
  alias: str | None = None,
481
+ compositing_method: CompositingMethod = CompositingMethod.FIRST_VALID,
392
482
  ):
393
483
  """Initialize a new RasterLayerConfig.
394
484
 
@@ -398,10 +488,12 @@ class RasterLayerConfig(LayerConfig):
398
488
  data_source: optional DataSourceConfig if this layer is retrievable
399
489
  resampling_method: how to resample rasters (if needed), default bilinear resampling
400
490
  alias: alias for this layer to use in the tile store
491
+ compositing_method: how to compute pixel values in the composite of each window's items
401
492
  """
402
493
  super().__init__(layer_type, data_source, alias)
403
494
  self.band_sets = band_sets
404
495
  self.resampling_method = resampling_method
496
+ self.compositing_method = compositing_method
405
497
 
406
498
  @staticmethod
407
499
  def from_config(config: dict[str, Any]) -> "RasterLayerConfig":
@@ -422,7 +514,11 @@ class RasterLayerConfig(LayerConfig):
422
514
  ]
423
515
  if "alias" in config:
424
516
  kwargs["alias"] = config["alias"]
425
- return RasterLayerConfig(**kwargs)
517
+ if "compositing_method" in config:
518
+ kwargs["compositing_method"] = CompositingMethod[
519
+ config["compositing_method"]
520
+ ]
521
+ return RasterLayerConfig(**kwargs) # type: ignore
426
522
 
427
523
 
428
524
  class VectorLayerConfig(LayerConfig):
@@ -432,22 +528,28 @@ class VectorLayerConfig(LayerConfig):
432
528
  self,
433
529
  layer_type: LayerType,
434
530
  data_source: DataSourceConfig | None = None,
435
- zoom_offset: int = 0,
436
531
  format: VectorFormatConfig = VectorFormatConfig("geojson"),
437
532
  alias: str | None = None,
533
+ class_property_name: str | None = None,
534
+ class_names: list[str] | None = None,
438
535
  ):
439
536
  """Initialize a new VectorLayerConfig.
440
537
 
441
538
  Args:
442
539
  layer_type: the LayerType (must be vector)
443
540
  data_source: optional DataSourceConfig if this layer is retrievable
444
- zoom_offset: zoom offset at which to store the vector data
445
541
  format: the VectorFormatConfig, default storing as GeoJSON
446
542
  alias: alias for this layer to use in the tile store
543
+ class_property_name: optional metadata field indicating that the GeoJSON
544
+ features contain a property that corresponds to a class label, and this
545
+ is the name of that property.
546
+ class_names: the list of classes that the class_property_name property
547
+ could be set to.
447
548
  """
448
549
  super().__init__(layer_type, data_source, alias)
449
- self.zoom_offset = zoom_offset
450
550
  self.format = format
551
+ self.class_property_name = class_property_name
552
+ self.class_names = class_names
451
553
 
452
554
  @staticmethod
453
555
  def from_config(config: dict[str, Any]) -> "VectorLayerConfig":
@@ -456,42 +558,29 @@ class VectorLayerConfig(LayerConfig):
456
558
  Args:
457
559
  config: the config dict for this VectorLayerConfig
458
560
  """
459
- kwargs = {"layer_type": LayerType(config["type"])}
561
+ kwargs: dict[str, Any] = {"layer_type": LayerType(config["type"])}
460
562
  if "data_source" in config:
461
563
  kwargs["data_source"] = DataSourceConfig.from_config(config["data_source"])
462
- if "zoom_offset" in config:
463
- kwargs["zoom_offset"] = config["zoom_offset"]
464
564
  if "format" in config:
465
565
  kwargs["format"] = VectorFormatConfig.from_config(config["format"])
466
- if "alias" in config:
467
- kwargs["alias"] = config["alias"]
468
- return VectorLayerConfig(**kwargs)
469
566
 
470
- def get_final_projection_and_bounds(
471
- self, projection: Projection, bounds: PixelBounds | None
472
- ) -> tuple[Projection, PixelBounds | None]:
473
- """Gets the final projection/bounds based on zoom offset.
567
+ simple_optionals = [
568
+ "alias",
569
+ "class_property_name",
570
+ "class_names",
571
+ ]
572
+ for k in simple_optionals:
573
+ if k in config:
574
+ kwargs[k] = config[k]
474
575
 
475
- Args:
476
- projection: the window's projection
477
- bounds: the window's bounds (optional)
576
+ # The "zoom_offset" option was removed.
577
+ # We should change how we create configuration so we can error on all
578
+ # non-existing config options, but for now we make sure to raise error if
579
+ # zoom_offset is set since it is no longer supported.
580
+ if "zoom_offset" in config:
581
+ raise ValueError("unsupported zoom_offset option in vector layer config")
478
582
 
479
- Returns:
480
- tuple of updated projection and bounds with zoom offset applied
481
- """
482
- if self.zoom_offset == 0:
483
- return projection, bounds
484
- projection = Projection(
485
- projection.crs,
486
- projection.x_resolution / (2**self.zoom_offset),
487
- projection.y_resolution / (2**self.zoom_offset),
488
- )
489
- if bounds:
490
- if self.zoom_offset > 0:
491
- bounds = tuple(x * (2**self.zoom_offset) for x in bounds)
492
- else:
493
- bounds = tuple(x // (2 ** (-self.zoom_offset)) for x in bounds)
494
- return projection, bounds
583
+ return VectorLayerConfig(**kwargs) # type: ignore
495
584
 
496
585
 
497
586
  def load_layer_config(config: dict[str, Any]) -> LayerConfig:
@@ -502,26 +591,3 @@ def load_layer_config(config: dict[str, Any]) -> LayerConfig:
502
591
  elif layer_type == LayerType.VECTOR:
503
592
  return VectorLayerConfig.from_config(config)
504
593
  raise ValueError(f"Unknown layer type {layer_type}")
505
-
506
-
507
- class TileStoreConfig:
508
- """A configuration specifying a TileStore."""
509
-
510
- def __init__(self, name: str, config_dict: dict[str, Any]) -> None:
511
- """Create a new TileStoreConfig.
512
-
513
- Args:
514
- name: the tile store implementation name to use
515
- config_dict: configuration options
516
- """
517
- self.name = name
518
- self.config_dict = config_dict
519
-
520
- @staticmethod
521
- def from_config(config: dict[str, Any]) -> "TileStoreConfig":
522
- """Create a TileStoreConfig from config dict.
523
-
524
- Args:
525
- config: the config dict for this TileStoreConfig
526
- """
527
- return TileStoreConfig(name=config["name"], config_dict=config)
rslearn/const.py CHANGED
@@ -1,23 +1,17 @@
1
1
  """Constants."""
2
2
 
3
- from rasterio.crs import CRS
4
-
5
- from rslearn.utils import PixelBounds, Projection
6
-
7
- WGS84_EPSG = 4326
8
- """The EPSG code for WGS-84."""
9
-
10
- WGS84_PROJECTION = Projection(CRS.from_epsg(WGS84_EPSG), 1, 1)
11
- """The Projection for WGS-84 assuming 1 degree per pixel.
12
-
13
- This can be used to create STGeometry with shapes in longitude/latitude coordinates.
14
- """
15
-
16
- WGS84_BOUNDS: PixelBounds = (-180, -90, 180, 90)
17
- """The bounds of the WGS-84 projection."""
3
+ from rslearn.utils.geometry import WGS84_BOUNDS, WGS84_EPSG, WGS84_PROJECTION
18
4
 
19
5
  TILE_SIZE = 512
20
6
  """Default tile size. TODO: remove this or move it elsewhere."""
21
7
 
22
8
  SHAPEFILE_AUX_EXTENSIONS = [".cpg", ".dbf", ".prj", ".sbn", ".sbx", ".shx", ".txt"]
23
9
  """Extensions of potential auxiliary files to .shp file."""
10
+
11
+ __all__ = (
12
+ "WGS84_PROJECTION",
13
+ "WGS84_EPSG",
14
+ "WGS84_BOUNDS",
15
+ "TILE_SIZE",
16
+ "SHAPEFILE_AUX_EXTENSIONS",
17
+ )
rslearn/data_sources/__init__.py CHANGED
@@ -10,15 +10,20 @@ Each source supports operations to lookup items that match with spatiotemporal
10
10
  geometries, and ingest those items.
11
11
  """
12
12
 
13
+ import functools
13
14
  import importlib
14
15
 
15
16
  from upath import UPath
16
17
 
17
18
  from rslearn.config import LayerConfig
19
+ from rslearn.log_utils import get_logger
18
20
 
19
21
  from .data_source import DataSource, Item, ItemLookupDataSource, RetrieveItemDataSource
20
22
 
23
+ logger = get_logger(__name__)
21
24
 
25
+
26
+ @functools.cache
22
27
  def data_source_from_config(config: LayerConfig, ds_path: UPath) -> DataSource:
23
28
  """Loads a data source from config dict.
24
29
 
@@ -26,6 +31,9 @@ def data_source_from_config(config: LayerConfig, ds_path: UPath) -> DataSource:
26
31
  config: the LayerConfig containing this data source.
27
32
  ds_path: the dataset root directory.
28
33
  """
34
+ logger.debug("getting a data source for dataset at %s", ds_path)
35
+ if config.data_source is None:
36
+ raise ValueError("No data source specified")
29
37
  name = config.data_source.name
30
38
  module_name = ".".join(name.split(".")[:-1])
31
39
  class_name = name.split(".")[-1]