rslearn 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. rslearn/config/dataset.py +22 -13
  2. rslearn/data_sources/__init__.py +8 -0
  3. rslearn/data_sources/aws_landsat.py +27 -18
  4. rslearn/data_sources/aws_open_data.py +41 -42
  5. rslearn/data_sources/copernicus.py +148 -2
  6. rslearn/data_sources/data_source.py +17 -10
  7. rslearn/data_sources/gcp_public_data.py +177 -100
  8. rslearn/data_sources/geotiff.py +1 -0
  9. rslearn/data_sources/google_earth_engine.py +17 -15
  10. rslearn/data_sources/local_files.py +59 -32
  11. rslearn/data_sources/openstreetmap.py +27 -23
  12. rslearn/data_sources/planet.py +10 -9
  13. rslearn/data_sources/planet_basemap.py +303 -0
  14. rslearn/data_sources/raster_source.py +23 -13
  15. rslearn/data_sources/usgs_landsat.py +56 -27
  16. rslearn/data_sources/utils.py +13 -6
  17. rslearn/data_sources/vector_source.py +1 -0
  18. rslearn/data_sources/xyz_tiles.py +8 -9
  19. rslearn/dataset/add_windows.py +1 -1
  20. rslearn/dataset/dataset.py +16 -5
  21. rslearn/dataset/manage.py +9 -4
  22. rslearn/dataset/materialize.py +26 -5
  23. rslearn/dataset/window.py +5 -0
  24. rslearn/log_utils.py +24 -0
  25. rslearn/main.py +123 -59
  26. rslearn/models/clip.py +62 -0
  27. rslearn/models/conv.py +56 -0
  28. rslearn/models/faster_rcnn.py +2 -19
  29. rslearn/models/fpn.py +1 -1
  30. rslearn/models/module_wrapper.py +43 -0
  31. rslearn/models/molmo.py +65 -0
  32. rslearn/models/multitask.py +1 -1
  33. rslearn/models/pooling_decoder.py +4 -2
  34. rslearn/models/satlaspretrain.py +4 -7
  35. rslearn/models/simple_time_series.py +61 -55
  36. rslearn/models/ssl4eo_s12.py +9 -9
  37. rslearn/models/swin.py +22 -21
  38. rslearn/models/unet.py +4 -2
  39. rslearn/models/upsample.py +35 -0
  40. rslearn/tile_stores/file.py +6 -3
  41. rslearn/tile_stores/tile_store.py +19 -7
  42. rslearn/train/callbacks/freeze_unfreeze.py +3 -3
  43. rslearn/train/data_module.py +5 -4
  44. rslearn/train/dataset.py +79 -36
  45. rslearn/train/lightning_module.py +15 -11
  46. rslearn/train/prediction_writer.py +22 -11
  47. rslearn/train/tasks/classification.py +9 -8
  48. rslearn/train/tasks/detection.py +94 -37
  49. rslearn/train/tasks/multi_task.py +1 -1
  50. rslearn/train/tasks/regression.py +8 -4
  51. rslearn/train/tasks/segmentation.py +23 -19
  52. rslearn/train/transforms/__init__.py +1 -1
  53. rslearn/train/transforms/concatenate.py +6 -2
  54. rslearn/train/transforms/crop.py +6 -2
  55. rslearn/train/transforms/flip.py +5 -1
  56. rslearn/train/transforms/normalize.py +9 -5
  57. rslearn/train/transforms/pad.py +1 -1
  58. rslearn/train/transforms/transform.py +3 -3
  59. rslearn/utils/__init__.py +4 -5
  60. rslearn/utils/array.py +2 -2
  61. rslearn/utils/feature.py +1 -1
  62. rslearn/utils/fsspec.py +70 -1
  63. rslearn/utils/geometry.py +155 -3
  64. rslearn/utils/grid_index.py +5 -5
  65. rslearn/utils/mp.py +4 -3
  66. rslearn/utils/raster_format.py +81 -73
  67. rslearn/utils/rtree_index.py +64 -17
  68. rslearn/utils/sqlite_index.py +7 -1
  69. rslearn/utils/utils.py +11 -3
  70. rslearn/utils/vector_format.py +113 -17
  71. {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/METADATA +32 -27
  72. rslearn-0.0.2.dist-info/RECORD +94 -0
  73. {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/WHEEL +1 -1
  74. rslearn/utils/mgrs.py +0 -24
  75. rslearn-0.0.1.dist-info/RECORD +0 -88
  76. {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/LICENSE +0 -0
  77. {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/entry_points.txt +0 -0
  78. {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/top_level.txt +0 -0
rslearn/data_sources/local_files.py

@@ -1,6 +1,6 @@
 """Data source for raster or vector data in local files."""

-from typing import Any
+from typing import Any, Generic, TypeVar

 import fiona
 import rasterio
@@ -11,8 +11,8 @@ from rasterio.crs import CRS
 from upath import UPath

 import rslearn.data_sources.utils
-from rslearn.config import LayerConfig, LayerType, VectorLayerConfig
-from rslearn.const import SHAPEFILE_AUX_EXTENSIONS, WGS84_PROJECTION
+from rslearn.config import LayerConfig, LayerType, RasterLayerConfig, VectorLayerConfig
+from rslearn.const import SHAPEFILE_AUX_EXTENSIONS
 from rslearn.tile_stores import LayerMetadata, PrefixedTileStore, TileStore
 from rslearn.utils import Feature, Projection, STGeometry
 from rslearn.utils.fsspec import get_upath_local, join_upath
@@ -22,11 +22,15 @@ from .raster_source import get_needed_projections, ingest_raster

 Importers = ClassRegistry()

+ItemType = TypeVar("ItemType", bound=Item)
+LayerConfigType = TypeVar("LayerConfigType", bound=LayerConfig)
+ImporterType = TypeVar("ImporterType", bound="Importer")

-class Importer:
-    """An abstract class for importing data from local files."""

-    def list_items(self, config: LayerConfig, src_dir: UPath) -> list[Item]:
+class Importer(Generic[ItemType, LayerConfigType]):
+    """An abstract base class for importing data from local files."""
+
+    def list_items(self, config: LayerConfigType, src_dir: UPath) -> list[ItemType]:
         """Extract a list of Items from the source directory.

         Args:
@@ -37,11 +41,11 @@ class Importer:

     def ingest_item(
         self,
-        config: LayerConfig,
+        config: LayerConfigType,
         tile_store: TileStore,
-        item: Item,
+        item: ItemType,
         cur_geometries: list[STGeometry],
-    ):
+    ) -> None:
         """Ingest the specified local file item.

         Args:
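
The hunks above make Importer generic over its item and layer-config types. As a rough, hedged illustration only (not code from the package; ExampleRasterImporter is hypothetical, and it assumes RasterItem and RasterLayerConfig are valid bindings for the new type variables), a subclass could be parameterized so that list_items and ingest_item are checked against concrete types:

# Hypothetical sketch only; not part of rslearn.
from upath import UPath

from rslearn.config import RasterLayerConfig
from rslearn.data_sources.local_files import Importer, RasterItem
from rslearn.tile_stores import TileStore
from rslearn.utils import STGeometry


class ExampleRasterImporter(Importer[RasterItem, RasterLayerConfig]):
    """Illustrative importer bound to concrete item and config types."""

    def list_items(self, config: RasterLayerConfig, src_dir: UPath) -> list[RasterItem]:
        # Scan src_dir and build RasterItem objects (omitted in this sketch).
        return []

    def ingest_item(
        self,
        config: RasterLayerConfig,
        tile_store: TileStore,
        item: RasterItem,
        cur_geometries: list[STGeometry],
    ) -> None:
        # Write the rasters referenced by item into tile_store (omitted in this sketch).
        pass
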
@@ -131,7 +135,7 @@ class RasterItem(Item):
         return d

     @staticmethod
-    def deserialize(d: dict) -> Item:
+    def deserialize(d: dict) -> "RasterItem":
         """Deserializes an item from a JSON-decoded dictionary."""
         item = super(RasterItem, RasterItem).deserialize(d)
         spec = RasterItemSpec.deserialize(d["spec"])
@@ -159,7 +163,7 @@ class VectorItem(Item):
         return d

     @staticmethod
-    def deserialize(d: dict) -> Item:
+    def deserialize(d: dict) -> "VectorItem":
         """Deserializes an item from a JSON-decoded dictionary."""
         item = super(VectorItem, VectorItem).deserialize(d)
         return VectorItem(
@@ -171,7 +175,7 @@ class VectorItem(Item):
 class RasterImporter(Importer):
     """An Importer for raster data."""

-    def list_items(self, config: LayerConfig, src_dir: UPath) -> list[Item]:
+    def list_items(self, config: LayerConfig, src_dir: UPath) -> list[RasterItem]:
         """Extract a list of Items from the source directory.

         Args:
@@ -179,6 +183,8 @@ class RasterImporter(Importer):
             src_dir: the source directory.
         """
         item_specs: list[RasterItemSpec] = []
+        if config.data_source is None:
+            raise ValueError("RasterImporter requires a data source config")
         # See if user has provided the item specs directly.
         if "item_specs" in config.data_source.config_dict:
             for spec_dict in config.data_source.config_dict["item_specs"]:
@@ -192,7 +198,7 @@
                 spec = RasterItemSpec(fnames=[path], bands=None)
                 item_specs.append(spec)

-        items: list[Item] = []
+        items = []
         for spec in item_specs:
             # Get geometry from the first raster file.
             # We assume files are readable with rasterio.
@@ -222,20 +228,19 @@

     def ingest_item(
         self,
-        config: LayerConfig,
+        config: RasterLayerConfig,
         tile_store: TileStore,
-        item: Item,
+        item: RasterItem,
         cur_geometries: list[STGeometry],
-    ):
+    ) -> None:
         """Ingest the specified local file item.

         Args:
             config: the configuration of the layer.
             tile_store: the TileStore to ingest the data into.
-            item: the Item to ingest
+            item: the RasterItem to ingest
             cur_geometries: the geometries where the item is needed.
         """
-        assert isinstance(item, RasterItem)
         for file_idx, fname in enumerate(item.spec.fnames):
             with fname.open("rb") as f:
                 with rasterio.open(f) as src:
@@ -264,7 +269,10 @@
 class VectorImporter(Importer):
     """An Importer for vector data."""

+    # We need some buffer around GeoJSON bounds in case it just contains one point.
+    item_buffer_epsilon = 1e-4
+
-    def list_items(self, config: LayerConfig, src_dir: UPath) -> list[Item]:
+    def list_items(self, config: LayerConfig, src_dir: UPath) -> list[VectorItem]:
         """Extract a list of Items from the source directory.

         Args:
@@ -272,7 +280,7 @@
             src_dir: the source directory.
         """
         file_paths = src_dir.glob("**/*.*")
-        items: list[Item] = []
+        items: list[VectorItem] = []

         for path in file_paths:
             # Get the bounds of the features in the vector file, which we assume fiona can
@@ -299,8 +307,14 @@
                        bounds[2] = max(bounds[2], cur_bounds[2])
                        bounds[3] = max(bounds[3], cur_bounds[3])

+            # Normal GeoJSON should have coordinates in CRS coordinates, i.e. it
+            # should be 1 projection unit/pixel.
             projection = Projection(crs, 1, 1)
-            geometry = STGeometry(projection, shapely.box(*bounds), None)
+            geometry = STGeometry(
+                projection,
+                shapely.box(*bounds).buffer(self.item_buffer_epsilon),
+                None,
+            )

             items.append(
                 VectorItem(path.name.split(".")[0], geometry, path.absolute().as_uri())
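
The buffer added above protects against degenerate bounds. A minimal standalone illustration (the coordinates are made up, not taken from the package) of why a vector file containing a single point needs the epsilon:

# Made-up bounds for illustration: a vector file with exactly one point.
import shapely

item_buffer_epsilon = 1e-4  # same constant the diff adds on VectorImporter
bounds = [10.0, 20.0, 10.0, 20.0]  # minx, miny, maxx, maxy collapse to a point

unbuffered = shapely.box(*bounds)
buffered = shapely.box(*bounds).buffer(item_buffer_epsilon)

print(unbuffered.area)      # 0.0: the item geometry has collapsed to a point
print(buffered.area > 0.0)  # True: the buffered geometry has nonzero extent
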
@@ -310,11 +324,11 @@

     def ingest_item(
         self,
-        config: LayerConfig,
+        config: VectorLayerConfig,
         tile_store: TileStore,
-        item: Item,
+        item: VectorItem,
         cur_geometries: list[STGeometry],
-    ):
+    ) -> None:
         """Ingest the specified local file item.

         Args:
@@ -323,7 +337,8 @@
             item: the Item to ingest
             cur_geometries: the geometries where the item is needed.
         """
-        assert isinstance(config, VectorLayerConfig)
+        if not isinstance(config, VectorLayerConfig):
+            raise ValueError("VectorImporter requires a VectorLayerConfig")

         needed_projections = set()
         for geometry in cur_geometries:
@@ -347,14 +362,18 @@
                 aux_files.append(path.parent / (prefix + ext))

         # TODO: move converting fiona file to list[Feature] to utility function.
-        # TODO: don't assume WGS-84 projection here.
         with get_upath_local(path, extra_paths=aux_files) as local_fname:
             with fiona.open(local_fname) as src:
+                crs = CRS.from_wkt(src.crs.to_wkt())
+                # Normal GeoJSON should have coordinates in CRS coordinates, i.e. it
+                # should be 1 projection unit/pixel.
+                projection = Projection(crs, 1, 1)
+
                 features = []
                 for feat in src:
                     features.append(
                         Feature.from_geojson(
-                            WGS84_PROJECTION,
+                            projection,
                             {
                                 "type": "Feature",
                                 "geometry": dict(feat.geometry),
@@ -376,7 +395,11 @@ class LocalFiles(DataSource):
 class LocalFiles(DataSource):
     """A data source for ingesting data from local files."""

-    def __init__(self, config: LayerConfig, src_dir: UPath) -> None:
+    def __init__(
+        self,
+        config: LayerConfig,
+        src_dir: UPath,
+    ) -> None:
         """Initialize a new LocalFiles instance.

         Args:
@@ -384,14 +407,17 @@
             src_dir: source directory to ingest
         """
         self.config = config
+
+        self.importer = Importers[config.layer_type.value]
         self.src_dir = src_dir

-        self.importer: Importer = Importers[config.layer_type.value]
-        self.items = self.importer.list_items(config, src_dir)
+        self.items = self.importer.list_items(self.config, src_dir)

     @staticmethod
     def from_config(config: LayerConfig, ds_path: UPath) -> "LocalFiles":
         """Creates a new LocalFiles instance from a configuration dictionary."""
+        if config.data_source is None:
+            raise ValueError("LocalFiles data source requires a data source config")
         d = config.data_source.config_dict
         return LocalFiles(config=config, src_dir=join_upath(ds_path, d["src_dir"]))

@@ -421,13 +447,14 @@
             groups.append(cur_groups)
         return groups

-    def deserialize_item(self, serialized_item: Any) -> Item:
+    def deserialize_item(self, serialized_item: Any) -> RasterItem | VectorItem:
         """Deserializes an item from JSON-decoded data."""
-        assert isinstance(serialized_item, dict)
         if self.config.layer_type == LayerType.RASTER:
             return RasterItem.deserialize(serialized_item)
         elif self.config.layer_type == LayerType.VECTOR:
             return VectorItem.deserialize(serialized_item)
+        else:
+            raise ValueError(f"Unknown layer type: {self.config.layer_type}")

     def ingest(
         self,
rslearn/data_sources/openstreetmap.py

@@ -7,10 +7,11 @@ from enum import Enum
 from typing import Any

 import osmium
+import osmium.osm.types
 import shapely
 from upath import UPath

-from rslearn.config import LayerConfig, QueryConfig, VectorLayerConfig
+from rslearn.config import QueryConfig, VectorLayerConfig
 from rslearn.const import WGS84_PROJECTION
 from rslearn.data_sources import DataSource, Item
 from rslearn.data_sources.utils import match_candidate_items_to_window
@@ -36,7 +37,7 @@ class Filter:
         tag_conditions: dict[str, list[str]] | None = None,
         tag_properties: dict[str, str] | None = None,
         to_geometry: str | None = None,
-    ):
+    ) -> None:
         """Create a new Filter instance.

         Args:
@@ -64,7 +65,7 @@
         Returns:
             the Filter object
         """
-        kwargs = {}
+        kwargs: dict[str, Any] = {}
         if "feature_types" in d:
             kwargs["feature_types"] = [FeatureType(el) for el in d["feature_types"]]
         if "tag_conditions" in d:
@@ -104,12 +105,12 @@
 class BoundsHandler(osmium.SimpleHandler):
     """An osmium handler for computing the bounds of an input file."""

-    def __init__(self):
+    def __init__(self) -> None:
         """Initialize a new BoundsHandler."""
         osmium.SimpleHandler.__init__(self)
-        self.bounds = (180, 90, -180, -90)
+        self.bounds: tuple[float, float, float, float] = (180, 90, -180, -90)

-    def node(self, n):
+    def node(self, n: osmium.osm.types.Node) -> None:
         """Handle nodes and update the computed bounds."""
         lon = n.location.lon
         lat = n.location.lat
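
The annotations above follow pyosmium's SimpleHandler callback pattern. As a self-contained sketch (NodeCounter and the input path are hypothetical, not part of rslearn), the same pattern looks like this:

# Hypothetical handler for illustration; not part of the package.
import osmium
import osmium.osm.types


class NodeCounter(osmium.SimpleHandler):
    """Counts nodes in a PBF file using a typed node() callback."""

    def __init__(self) -> None:
        osmium.SimpleHandler.__init__(self)
        self.num_nodes = 0

    def node(self, n: osmium.osm.types.Node) -> None:
        # Called once per node while the file is streamed.
        self.num_nodes += 1


handler = NodeCounter()
handler.apply_file("planet-subset.osm.pbf")  # hypothetical input path
print(handler.num_nodes)
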
@@ -130,7 +131,7 @@ class OsmHandler(osmium.SimpleHandler):
         geometries: list[STGeometry],
         grid_size: float = 0.03,
         padding: float = 0.03,
-    ):
+    ) -> None:
         """Initialize a new OsmHandler.

         Args:
@@ -163,12 +164,12 @@
             )
             self.grid_index.insert(bounds, 1)

-        self.cached_nodes = {}
-        self.cached_ways = {}
+        self.cached_nodes: dict = {}
+        self.cached_ways: dict = {}

-        self.features = []
+        self.features: list[Feature] = []

-    def node(self, n):
+    def node(self, n: osmium.osm.types.Node) -> None:
         """Handle nodes."""
         # Check if node is relevant to our geometries.
         lon = n.location.lon
@@ -193,7 +194,7 @@
             )
             self.features.append(feat)

-    def _get_way_coords(self, node_ids):
+    def _get_way_coords(self, node_ids: list[int]) -> list:
         coords = []
         for id in node_ids:
             if id not in self.cached_nodes:
@@ -201,7 +202,7 @@
             coords.append(self.cached_nodes[id])
         return coords

-    def way(self, w):
+    def way(self, w: osmium.osm.types.Way) -> None:
         """Handle ways."""
         # Collect nodes, skip if too few.
         node_ids = [member.ref for member in w.nodes]
@@ -235,7 +236,7 @@
             )
             self.features.append(feat)

-    def match_relation(self, r):
+    def match_relation(self, r: osmium.osm.types.Relation) -> None:
         """Handle relations."""
         # Collect ways and distinguish exterior vs holes, skip if none found.
         exterior_ways = []
@@ -267,7 +268,7 @@
         # Merge the ways in case some exterior/interior polygons are split into
         # multiple ways.
         # And convert them from node IDs to coordinates.
-        def get_polygons(ways):
+        def get_polygons(ways: list) -> list:
             polygons: list[list[int]] = []
             for way in ways:
                 # Attempt to match the way to an existing polygon.
@@ -366,13 +367,13 @@ class OsmItem(Item):
         return d

     @staticmethod
-    def deserialize(d: dict) -> Item:
+    def deserialize(d: dict) -> "OsmItem":
         """Deserializes an item from a JSON-decoded dictionary."""
         item = super(OsmItem, OsmItem).deserialize(d)
         return OsmItem(name=item.name, geometry=item.geometry, path_uri=d["path_uri"])


-class OpenStreetMap(DataSource):
+class OpenStreetMap(DataSource[OsmItem]):
     """A data source for OpenStreetMap data from PBF file.

     An existing local PBF file can be used, or if the provided path doesn't exist, then
@@ -420,9 +421,10 @@
         self.pbf_bounds = self._get_pbf_bounds()

     @staticmethod
-    def from_config(config: LayerConfig, ds_path: UPath) -> "OpenStreetMap":
+    def from_config(config: VectorLayerConfig, ds_path: UPath) -> "OpenStreetMap":
         """Creates a new OpenStreetMap instance from a configuration dictionary."""
-        assert isinstance(config, VectorLayerConfig)
+        if config.data_source is None:
+            raise ValueError("data_source is required")
         d = config.data_source.config_dict
         categories = {
             category_name: Filter.from_config(filter_config_dict)
@@ -437,7 +439,9 @@
             categories=categories,
         )

-    def _get_pbf_bounds(self):
+    def _get_pbf_bounds(self) -> list[tuple[float, float, float, float]]:
+        # Determine WGS84 bounds of each PBF file by processing them through
+        # BoundsHandler.
         if not self.bounds_fname.exists():
             pbf_bounds = []
             for pbf_fname in self.pbf_fnames:
@@ -458,7 +462,7 @@

     def get_items(
         self, geometries: list[STGeometry], query_config: QueryConfig
-    ) -> list[list[list[Item]]]:
+    ) -> list[list[list[OsmItem]]]:
         """Get a list of items in the data source intersecting the given geometries.

         Args:
@@ -487,14 +491,14 @@
             groups.append(cur_groups)
         return groups

-    def deserialize_item(self, serialized_item: Any) -> Item:
+    def deserialize_item(self, serialized_item: Any) -> OsmItem:
         """Deserializes an item from JSON-decoded data."""
         return OsmItem.deserialize(serialized_item)

     def ingest(
         self,
         tile_store: TileStore,
-        items: list[Item],
+        items: list[OsmItem],
         geometries: list[list[STGeometry]],
     ) -> None:
         """Ingest items into the given tile store.
rslearn/data_sources/planet.py

@@ -6,6 +6,7 @@ import pathlib
 import shutil
 import tempfile
 from datetime import datetime
+from pathlib import Path
 from typing import Any

 import planet
@@ -14,7 +15,7 @@ import shapely
 from fsspec.implementations.local import LocalFileSystem
 from upath import UPath

-from rslearn.config import LayerConfig, QueryConfig, RasterLayerConfig
+from rslearn.config import QueryConfig, RasterLayerConfig
 from rslearn.const import WGS84_PROJECTION
 from rslearn.data_sources import DataSource, Item
 from rslearn.data_sources.utils import match_candidate_items_to_window
@@ -33,7 +34,7 @@ class Planet(DataSource):

     def __init__(
         self,
-        config: LayerConfig,
+        config: RasterLayerConfig,
         item_type_id: str,
         cache_dir: UPath | None = None,
         asset_type_id: str = "ortho_analytic_sr",
@@ -73,9 +74,10 @@
         self.bands = bands

     @staticmethod
-    def from_config(config: LayerConfig, ds_path: UPath) -> "Planet":
+    def from_config(config: RasterLayerConfig, ds_path: UPath) -> "Planet":
         """Creates a new Planet instance from a configuration dictionary."""
-        assert isinstance(config, RasterLayerConfig)
+        if config.data_source is None:
+            raise ValueError("data_source is required")
         d = config.data_source.config_dict
         kwargs = dict(
             config=config,
@@ -101,11 +103,10 @@

         async with planet.Session() as session:
             client = session.client("data")
-
+            gte = geometry.time_range[0] if geometry.time_range is not None else None
+            lte = geometry.time_range[1] if geometry.time_range is not None else None
             filter_list = [
-                planet.data_filter.date_range_filter(
-                    "acquired", gte=geometry.time_range[0], lte=geometry.time_range[1]
-                ),
+                planet.data_filter.date_range_filter("acquired", gte=gte, lte=lte),
                 planet.data_filter.geometry_filter(geojson_data),
                 planet.data_filter.asset_filter([self.asset_type_id]),
             ]
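
The change above avoids indexing geometry.time_range when it is None; the old code would raise TypeError for a window without a time range. A tiny standalone sketch of the same None-safe unpacking, with made-up values:

# Made-up time range; mirrors the gte/lte handling added in planet.py.
from datetime import datetime

time_range: tuple[datetime, datetime] | None = None  # e.g. a window with no time filter

gte = time_range[0] if time_range is not None else None
lte = time_range[1] if time_range is not None else None
print(gte, lte)  # None None
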
@@ -265,7 +266,7 @@
             if not needed_projections:
                 continue

-            asset_path = asyncio.run(self._download_asset(item, tmp_dir))
+            asset_path = asyncio.run(self._download_asset(item, Path(tmp_dir)))
             with asset_path.open("rb") as f:
                 with rasterio.open(f) as raster:
                     for projection in needed_projections: