rslearn 0.0.15__py3-none-any.whl → 0.0.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (49)
  1. rslearn/config/__init__.py +2 -10
  2. rslearn/config/dataset.py +414 -420
  3. rslearn/data_sources/__init__.py +8 -31
  4. rslearn/data_sources/aws_landsat.py +13 -24
  5. rslearn/data_sources/aws_open_data.py +21 -46
  6. rslearn/data_sources/aws_sentinel1.py +3 -14
  7. rslearn/data_sources/climate_data_store.py +21 -40
  8. rslearn/data_sources/copernicus.py +30 -91
  9. rslearn/data_sources/data_source.py +26 -0
  10. rslearn/data_sources/earthdaily.py +13 -38
  11. rslearn/data_sources/earthdata_srtm.py +14 -32
  12. rslearn/data_sources/eurocrops.py +5 -9
  13. rslearn/data_sources/gcp_public_data.py +46 -43
  14. rslearn/data_sources/google_earth_engine.py +31 -44
  15. rslearn/data_sources/local_files.py +91 -100
  16. rslearn/data_sources/openstreetmap.py +21 -51
  17. rslearn/data_sources/planet.py +12 -30
  18. rslearn/data_sources/planet_basemap.py +4 -25
  19. rslearn/data_sources/planetary_computer.py +58 -141
  20. rslearn/data_sources/usda_cdl.py +15 -26
  21. rslearn/data_sources/usgs_landsat.py +4 -29
  22. rslearn/data_sources/utils.py +9 -0
  23. rslearn/data_sources/worldcereal.py +47 -54
  24. rslearn/data_sources/worldcover.py +16 -14
  25. rslearn/data_sources/worldpop.py +15 -18
  26. rslearn/data_sources/xyz_tiles.py +11 -30
  27. rslearn/dataset/dataset.py +6 -6
  28. rslearn/dataset/manage.py +14 -20
  29. rslearn/dataset/materialize.py +9 -45
  30. rslearn/lightning_cli.py +370 -1
  31. rslearn/main.py +3 -3
  32. rslearn/models/concatenate_features.py +93 -0
  33. rslearn/tile_stores/__init__.py +0 -11
  34. rslearn/train/dataset.py +4 -12
  35. rslearn/train/prediction_writer.py +16 -32
  36. rslearn/train/tasks/classification.py +2 -1
  37. rslearn/utils/fsspec.py +20 -0
  38. rslearn/utils/jsonargparse.py +79 -0
  39. rslearn/utils/raster_format.py +1 -41
  40. rslearn/utils/vector_format.py +1 -38
  41. {rslearn-0.0.15.dist-info → rslearn-0.0.16.dist-info}/METADATA +1 -1
  42. {rslearn-0.0.15.dist-info → rslearn-0.0.16.dist-info}/RECORD +47 -48
  43. rslearn/data_sources/geotiff.py +0 -1
  44. rslearn/data_sources/raster_source.py +0 -23
  45. {rslearn-0.0.15.dist-info → rslearn-0.0.16.dist-info}/WHEEL +0 -0
  46. {rslearn-0.0.15.dist-info → rslearn-0.0.16.dist-info}/entry_points.txt +0 -0
  47. {rslearn-0.0.15.dist-info → rslearn-0.0.16.dist-info}/licenses/LICENSE +0 -0
  48. {rslearn-0.0.15.dist-info → rslearn-0.0.16.dist-info}/licenses/NOTICE +0 -0
  49. {rslearn-0.0.15.dist-info → rslearn-0.0.16.dist-info}/top_level.txt +0 -0
@@ -2,7 +2,6 @@
2
2
 
3
3
  import functools
4
4
  import json
5
- from collections.abc import Callable
6
5
  from typing import Any, Generic, TypeVar
7
6
 
8
7
  import fiona
@@ -12,58 +11,43 @@ from rasterio.crs import CRS
12
11
  from upath import UPath
13
12
 
14
13
  import rslearn.data_sources.utils
15
- from rslearn.config import LayerConfig, LayerType, RasterLayerConfig, VectorLayerConfig
14
+ from rslearn.config import LayerType
16
15
  from rslearn.const import SHAPEFILE_AUX_EXTENSIONS
17
16
  from rslearn.log_utils import get_logger
18
17
  from rslearn.tile_stores import TileStoreWithLayer
19
18
  from rslearn.utils.feature import Feature
20
- from rslearn.utils.fsspec import get_upath_local, join_upath, open_rasterio_upath_reader
19
+ from rslearn.utils.fsspec import (
20
+ get_relative_suffix,
21
+ get_upath_local,
22
+ join_upath,
23
+ open_rasterio_upath_reader,
24
+ )
21
25
  from rslearn.utils.geometry import Projection, STGeometry, get_global_geometry
22
26
 
23
- from .data_source import DataSource, Item, QueryConfig
27
+ from .data_source import DataSource, DataSourceContext, Item, QueryConfig
24
28
 
25
29
  logger = get_logger("__name__")
26
- _ImporterT = TypeVar("_ImporterT", bound="Importer")
27
-
28
-
29
- class _ImporterRegistry(dict[str, type["Importer"]]):
30
- """Registry for Importer classes."""
31
-
32
- def register(self, name: str) -> Callable[[type[_ImporterT]], type[_ImporterT]]:
33
- """Decorator to register an importer class."""
34
-
35
- def decorator(cls: type[_ImporterT]) -> type[_ImporterT]:
36
- self[name] = cls
37
- return cls
38
-
39
- return decorator
40
-
41
-
42
- Importers = _ImporterRegistry()
43
30
 
44
31
 
45
32
  ItemType = TypeVar("ItemType", bound=Item)
46
- LayerConfigType = TypeVar("LayerConfigType", bound=LayerConfig)
47
33
  ImporterType = TypeVar("ImporterType", bound="Importer")
48
34
 
49
35
  SOURCE_NAME = "rslearn.data_sources.local_files.LocalFiles"
50
36
 
51
37
 
52
- class Importer(Generic[ItemType, LayerConfigType]):
38
+ class Importer(Generic[ItemType]):
53
39
  """An abstract base class for importing data from local files."""
54
40
 
55
- def list_items(self, config: LayerConfigType, src_dir: UPath) -> list[ItemType]:
41
+ def list_items(self, src_dir: UPath) -> list[ItemType]:
56
42
  """Extract a list of Items from the source directory.
57
43
 
58
44
  Args:
59
- config: the configuration of the layer.
60
45
  src_dir: the source directory.
61
46
  """
62
47
  raise NotImplementedError
63
48
 
64
49
  def ingest_item(
65
50
  self,
66
- config: LayerConfigType,
67
51
  tile_store: TileStoreWithLayer,
68
52
  item: ItemType,
69
53
  cur_geometries: list[STGeometry],
@@ -71,7 +55,6 @@ class Importer(Generic[ItemType, LayerConfigType]):
71
55
  """Ingest the specified local file item.
72
56
 
73
57
  Args:
74
- config: the configuration of the layer.
75
58
  tile_store: the tile store to ingest the data into.
76
59
  item: the Item to ingest
77
60
  cur_geometries: the geometries where the item is needed.
@@ -84,7 +67,7 @@ class RasterItemSpec:
84
67
 
85
68
  def __init__(
86
69
  self,
87
- fnames: list[UPath],
70
+ fnames: list[str],
88
71
  bands: list[list[str]] | None = None,
89
72
  name: str | None = None,
90
73
  ):
@@ -99,25 +82,6 @@ class RasterItemSpec:
99
82
  self.bands = bands
100
83
  self.name = name
101
84
 
102
- @staticmethod
103
- def from_config(src_dir: UPath, d: dict[str, Any]) -> "RasterItemSpec":
104
- """Decode a dict into a RasterItemSpec.
105
-
106
- Args:
107
- src_dir: the source directory.
108
- d: the configuration dict.
109
-
110
- Returns:
111
- the RasterItemSpec.
112
- """
113
- kwargs = dict(
114
- fnames=[join_upath(src_dir, suffix) for suffix in d["fnames"]],
115
- bands=d["bands"],
116
- )
117
- if "name" in d:
118
- kwargs["name"] = d["name"]
119
- return RasterItemSpec(**kwargs)
120
-
121
85
  def serialize(self) -> dict[str, Any]:
122
86
  """Serializes the RasterItemSpec to a JSON-encodable dictionary."""
123
87
  return {
@@ -130,7 +94,7 @@ class RasterItemSpec:
130
94
  def deserialize(d: dict[str, Any]) -> "RasterItemSpec":
131
95
  """Deserializes a RasterItemSpec from a JSON-decoded dictionary."""
132
96
  return RasterItemSpec(
133
- fnames=[UPath(s) for s in d["fnames"]],
97
+ fnames=[s for s in d["fnames"]],
134
98
  bands=d["bands"],
135
99
  name=d["name"],
136
100
  )
@@ -139,20 +103,25 @@ class RasterItemSpec:
139
103
  class RasterItem(Item):
140
104
  """An item corresponding to a local file."""
141
105
 
142
- def __init__(self, name: str, geometry: STGeometry, spec: RasterItemSpec):
143
- """Creates a new LocalFileItem.
106
+ def __init__(
107
+ self, name: str, geometry: STGeometry, src_dir: str, spec: RasterItemSpec
108
+ ):
109
+ """Creates a new RasterItem.
144
110
 
145
111
  Args:
146
112
  name: unique name of the item
147
113
  geometry: the spatial and temporal extent of the item
114
+ src_dir: the source directory.
148
115
  spec: the RasterItemSpec that specifies the filename(s) and bands.
149
116
  """
150
117
  super().__init__(name, geometry)
118
+ self.src_dir = src_dir
151
119
  self.spec = spec
152
120
 
153
121
  def serialize(self) -> dict:
154
122
  """Serializes the item to a JSON-encodable dictionary."""
155
123
  d = super().serialize()
124
+ d["src_dir"] = str(self.src_dir)
156
125
  d["spec"] = self.spec.serialize()
157
126
  return d
158
127
 
@@ -160,8 +129,11 @@ class RasterItem(Item):
160
129
  def deserialize(d: dict) -> "RasterItem":
161
130
  """Deserializes an item from a JSON-decoded dictionary."""
162
131
  item = super(RasterItem, RasterItem).deserialize(d)
132
+ src_dir = d["src_dir"]
163
133
  spec = RasterItemSpec.deserialize(d["spec"])
164
- return RasterItem(name=item.name, geometry=item.geometry, spec=spec)
134
+ return RasterItem(
135
+ name=item.name, geometry=item.geometry, src_dir=src_dir, spec=spec
136
+ )
165
137
 
166
138
 
167
139
  class VectorItem(Item):
@@ -193,29 +165,34 @@ class VectorItem(Item):
193
165
  )
194
166
 
195
167
 
196
- @Importers.register("raster")
197
168
  class RasterImporter(Importer):
198
169
  """An Importer for raster data."""
199
170
 
200
- def list_items(self, config: LayerConfig, src_dir: UPath) -> list[RasterItem]:
171
+ def __init__(self, item_specs: list[RasterItemSpec] | None = None):
172
+ """Create a new RasterImporter.
173
+
174
+ Args:
175
+ item_specs: the specs to specify the raster items directly. If None, the
176
+ raster items are automatically detected from the files in the source
177
+ directory.
178
+ """
179
+ self.item_specs = item_specs
180
+
181
+ def list_items(self, src_dir: UPath) -> list[Item]:
201
182
  """Extract a list of Items from the source directory.
202
183
 
203
184
  Args:
204
- config: the configuration of the layer.
205
185
  src_dir: the source directory.
206
186
  """
207
- item_specs: list[RasterItemSpec] = []
187
+ item_specs: list[RasterItemSpec]
188
+
208
189
  # See if user has provided the item specs directly.
209
- if (
210
- config.data_source is not None
211
- and "item_specs" in config.data_source.config_dict
212
- ):
213
- for spec_dict in config.data_source.config_dict["item_specs"]:
214
- spec = RasterItemSpec.from_config(src_dir, spec_dict)
215
- item_specs.append(spec)
190
+ if self.item_specs is not None:
191
+ item_specs = self.item_specs
216
192
  else:
217
193
  # Otherwise we need to list files and assume each one is separate.
218
194
  # And we'll need to autodetect the bands later.
195
+ item_specs = []
219
196
  file_paths = src_dir.glob("**/*.*")
220
197
  for path in file_paths:
221
198
  # Ignore JSON files.
@@ -228,14 +205,17 @@ class RasterImporter(Importer):
228
205
  if len(parts) >= 4 and parts[-2] == "tmp" and parts[-1].isdigit():
229
206
  continue
230
207
 
231
- spec = RasterItemSpec(fnames=[path], bands=None)
208
+ spec = RasterItemSpec(
209
+ fnames=[get_relative_suffix(src_dir, path)], bands=None
210
+ )
232
211
  item_specs.append(spec)
233
212
 
234
- items = []
213
+ items: list[Item] = []
235
214
  for spec in item_specs:
236
215
  # Get geometry from the first raster file.
237
216
  # We assume files are readable with rasterio.
238
- with open_rasterio_upath_reader(spec.fnames[0]) as src:
217
+ fname = join_upath(src_dir, spec.fnames[0])
218
+ with open_rasterio_upath_reader(fname) as src:
239
219
  crs = src.crs
240
220
  left = src.transform.c
241
221
  top = src.transform.f
@@ -263,33 +243,33 @@ class RasterImporter(Importer):
263
243
  if spec.name:
264
244
  item_name = spec.name
265
245
  else:
266
- item_name = spec.fnames[0].name.split(".")[0]
246
+ item_name = fname.name.split(".")[0]
267
247
 
268
248
  logger.debug(
269
249
  "RasterImporter.list_items: got bounds of %s: %s", item_name, geometry
270
250
  )
271
- items.append(RasterItem(item_name, geometry, spec))
251
+ items.append(RasterItem(item_name, geometry, str(src_dir), spec))
272
252
 
273
253
  logger.debug("RasterImporter.list_items: discovered %d items", len(items))
274
254
  return items
275
255
 
276
256
  def ingest_item(
277
257
  self,
278
- config: RasterLayerConfig,
279
258
  tile_store: TileStoreWithLayer,
280
- item: RasterItem,
259
+ item: Item,
281
260
  cur_geometries: list[STGeometry],
282
261
  ) -> None:
283
262
  """Ingest the specified local file item.
284
263
 
285
264
  Args:
286
- config: the configuration of the layer.
287
265
  tile_store: the tile store to ingest the data into.
288
266
  item: the RasterItem to ingest
289
267
  cur_geometries: the geometries where the item is needed.
290
268
  """
269
+ assert isinstance(item, RasterItem)
291
270
  for file_idx, fname in enumerate(item.spec.fnames):
292
- with open_rasterio_upath_reader(fname) as src:
271
+ fname_upath = join_upath(UPath(item.src_dir), fname)
272
+ with open_rasterio_upath_reader(fname_upath) as src:
293
273
  if item.spec.bands:
294
274
  bands = item.spec.bands[file_idx]
295
275
  else:
@@ -297,25 +277,23 @@ class RasterImporter(Importer):
297
277
 
298
278
  if tile_store.is_raster_ready(item.name, bands):
299
279
  continue
300
- tile_store.write_raster_file(item.name, bands, fname)
280
+ tile_store.write_raster_file(item.name, bands, fname_upath)
301
281
 
302
282
 
303
- @Importers.register("vector")
304
283
  class VectorImporter(Importer):
305
284
  """An Importer for vector data."""
306
285
 
307
286
  # We need some buffer around GeoJSON bounds in case it just contains one point.
308
287
  item_buffer_epsilon = 1e-4
309
288
 
310
- def list_items(self, config: LayerConfig, src_dir: UPath) -> list[VectorItem]:
289
+ def list_items(self, src_dir: UPath) -> list[Item]:
311
290
  """Extract a list of Items from the source directory.
312
291
 
313
292
  Args:
314
- config: the configuration of the layer.
315
293
  src_dir: the source directory.
316
294
  """
317
295
  file_paths = src_dir.glob("**/*.*")
318
- items: list[VectorItem] = []
296
+ items: list[Item] = []
319
297
 
320
298
  for path in file_paths:
321
299
  # Ignore JSON files.
@@ -375,25 +353,21 @@ class VectorImporter(Importer):
375
353
 
376
354
  def ingest_item(
377
355
  self,
378
- config: VectorLayerConfig,
379
356
  tile_store: TileStoreWithLayer,
380
- item: VectorItem,
357
+ item: Item,
381
358
  cur_geometries: list[STGeometry],
382
359
  ) -> None:
383
360
  """Ingest the specified local file item.
384
361
 
385
362
  Args:
386
- config: the configuration of the layer.
387
363
  tile_store: the TileStore to ingest the data into.
388
364
  item: the Item to ingest
389
365
  cur_geometries: the geometries where the item is needed.
390
366
  """
391
- if not isinstance(config, VectorLayerConfig):
392
- raise ValueError("VectorImporter requires a VectorLayerConfig")
393
-
394
367
  if tile_store.is_vector_ready(item.name):
395
368
  return
396
369
 
370
+ assert isinstance(item, VectorItem)
397
371
  path = UPath(item.path_uri)
398
372
 
399
373
  aux_files: list[UPath] = []
@@ -431,27 +405,44 @@ class LocalFiles(DataSource):
431
405
 
432
406
  def __init__(
433
407
  self,
434
- config: LayerConfig,
435
- src_dir: UPath,
408
+ src_dir: str,
409
+ raster_item_specs: list[RasterItemSpec] | None = None,
410
+ layer_type: LayerType | None = None,
411
+ context: DataSourceContext = DataSourceContext(),
436
412
  ) -> None:
437
413
  """Initialize a new LocalFiles instance.
438
414
 
439
415
  Args:
440
- config: configuration for this layer.
441
416
  src_dir: source directory to ingest
417
+ raster_item_specs: the specs to specify the raster items directly. If None,
418
+ the raster items are automatically detected from the files in the
419
+ source directory.
420
+ layer_type: the layer type. It only needs to be set if the layer_config is
421
+ missing from the context.
422
+ context: the data source context. The layer config must be in the context.
442
423
  """
443
- self.config = config
424
+ if context.ds_path is not None:
425
+ self.src_dir = join_upath(context.ds_path, src_dir)
426
+ else:
427
+ self.src_dir = UPath(src_dir)
444
428
 
445
- self.importer = Importers[config.layer_type.value]()
446
- self.src_dir = src_dir
429
+ # Determine layer type.
430
+ if context.layer_config is not None:
431
+ self.layer_type = context.layer_config.type
432
+ elif layer_type is not None:
433
+ self.layer_type = layer_type
434
+ else:
435
+ raise ValueError(
436
+ "layer type must be specified if the layer config is not in the context"
437
+ )
447
438
 
448
- @staticmethod
449
- def from_config(config: LayerConfig, ds_path: UPath) -> "LocalFiles":
450
- """Creates a new LocalFiles instance from a configuration dictionary."""
451
- if config.data_source is None:
452
- raise ValueError("LocalFiles data source requires a data source config")
453
- d = config.data_source.config_dict
454
- return LocalFiles(config=config, src_dir=join_upath(ds_path, d["src_dir"]))
439
+ self.importer: Importer
440
+ if self.layer_type == LayerType.RASTER:
441
+ self.importer = RasterImporter(item_specs=raster_item_specs)
442
+ elif self.layer_type == LayerType.VECTOR:
443
+ self.importer = VectorImporter()
444
+ else:
445
+ raise ValueError(f"unknown layer type {self.layer_type}")
455
446
 
456
447
  @functools.cache
457
448
  def list_items(self) -> list[Item]:
@@ -459,7 +450,7 @@ class LocalFiles(DataSource):
459
450
  cache_fname = self.src_dir / "summary.json"
460
451
  if not cache_fname.exists():
461
452
  logger.debug("cache at %s does not exist, listing items", cache_fname)
462
- items = self.importer.list_items(self.config, self.src_dir)
453
+ items = self.importer.list_items(self.src_dir)
463
454
  serialized_items = [item.serialize() for item in items]
464
455
  with cache_fname.open("w") as f:
465
456
  json.dump(serialized_items, f)
@@ -501,12 +492,12 @@ class LocalFiles(DataSource):
501
492
 
502
493
  def deserialize_item(self, serialized_item: Any) -> RasterItem | VectorItem:
503
494
  """Deserializes an item from JSON-decoded data."""
504
- if self.config.layer_type == LayerType.RASTER:
495
+ if self.layer_type == LayerType.RASTER:
505
496
  return RasterItem.deserialize(serialized_item)
506
- elif self.config.layer_type == LayerType.VECTOR:
497
+ elif self.layer_type == LayerType.VECTOR:
507
498
  return VectorItem.deserialize(serialized_item)
508
499
  else:
509
- raise ValueError(f"Unknown layer type: {self.config.layer_type}")
500
+ raise ValueError(f"Unknown layer type: {self.layer_type}")
510
501
 
511
502
  def ingest(
512
503
  self,
@@ -522,4 +513,4 @@ class LocalFiles(DataSource):
522
513
  geometries: a list of geometries needed for each item
523
514
  """
524
515
  for item, cur_geometries in zip(items, geometries):
525
- self.importer.ingest_item(self.config, tile_store, item, cur_geometries)
516
+ self.importer.ingest_item(tile_store, item, cur_geometries)
@@ -11,14 +11,17 @@ import osmium.osm.types
11
11
  import shapely
12
12
  from upath import UPath
13
13
 
14
- from rslearn.config import QueryConfig, VectorLayerConfig
14
+ from rslearn.config import QueryConfig
15
15
  from rslearn.const import WGS84_PROJECTION
16
- from rslearn.data_sources import DataSource, Item
16
+ from rslearn.data_sources import DataSource, DataSourceContext, Item
17
17
  from rslearn.data_sources.utils import match_candidate_items_to_window
18
+ from rslearn.log_utils import get_logger
18
19
  from rslearn.tile_stores import TileStoreWithLayer
19
20
  from rslearn.utils import Feature, GridIndex, STGeometry
20
21
  from rslearn.utils.fsspec import get_upath_local, join_upath
21
22
 
23
+ logger = get_logger(__name__)
24
+
22
25
 
23
26
  class FeatureType(Enum):
24
27
  """OpenStreetMap feature type."""
@@ -55,27 +58,6 @@ class Filter:
55
58
  self.tag_properties = tag_properties
56
59
  self.to_geometry = to_geometry
57
60
 
58
- @staticmethod
59
- def from_config(d: dict[str, Any]) -> "Filter":
60
- """Creates a Filter from a config dict.
61
-
62
- Args:
63
- d: the config dict
64
-
65
- Returns:
66
- the Filter object
67
- """
68
- kwargs: dict[str, Any] = {}
69
- if "feature_types" in d:
70
- kwargs["feature_types"] = [FeatureType(el) for el in d["feature_types"]]
71
- if "tag_conditions" in d:
72
- kwargs["tag_conditions"] = d["tag_conditions"]
73
- if "tag_properties" in d:
74
- kwargs["tag_properties"] = d["tag_properties"]
75
- if "to_geometry" in d:
76
- kwargs["to_geometry"] = d["to_geometry"]
77
- return Filter(**kwargs)
78
-
79
61
  def match_tags(self, tags: dict[str, str]) -> bool:
80
62
  """Returns whether this filter matches based on the tags."""
81
63
  if not self.tag_conditions:
@@ -387,10 +369,10 @@ class OpenStreetMap(DataSource[OsmItem]):
387
369
 
388
370
  def __init__(
389
371
  self,
390
- config: VectorLayerConfig,
391
- pbf_fnames: list[UPath],
392
- bounds_fname: UPath,
372
+ pbf_fnames: list[str],
373
+ bounds_fname: str,
393
374
  categories: dict[str, Filter],
375
+ context: DataSourceContext = DataSourceContext(),
394
376
  ):
395
377
  """Initialize a new OpenStreetMap instance.
396
378
 
@@ -402,14 +384,21 @@ class OpenStreetMap(DataSource[OsmItem]):
402
384
  bounds_fname: filename where the bounds of the PBF are cached.
403
385
  categories: dictionary of (category name, filter). Features that match the
404
386
  filter will be emitted under the corresponding category.
387
+ context: the data source context.
405
388
  """
406
- self.config = config
407
- self.pbf_fnames = pbf_fnames
408
- self.bounds_fname = bounds_fname
409
389
  self.categories = categories
410
390
 
391
+ if context.ds_path is not None:
392
+ self.pbf_fnames = [
393
+ join_upath(context.ds_path, pbf_fname) for pbf_fname in pbf_fnames
394
+ ]
395
+ self.bounds_fname = join_upath(context.ds_path, bounds_fname)
396
+ else:
397
+ self.pbf_fnames = [UPath(pbf_fname) for pbf_fname in pbf_fnames]
398
+ self.bounds_fname = UPath(bounds_fname)
399
+
411
400
  if len(self.pbf_fnames) == 1 and not self.pbf_fnames[0].exists():
412
- print(
401
+ logger.info(
413
402
  "Downloading planet.osm.pbf from "
414
403
  + f"{self.planet_pbf_url} to {self.pbf_fnames[0]}"
415
404
  )
@@ -420,32 +409,13 @@ class OpenStreetMap(DataSource[OsmItem]):
420
409
  # Detect bounds of each pbf file if needed.
421
410
  self.pbf_bounds = self._get_pbf_bounds()
422
411
 
423
- @staticmethod
424
- def from_config(config: VectorLayerConfig, ds_path: UPath) -> "OpenStreetMap":
425
- """Creates a new OpenStreetMap instance from a configuration dictionary."""
426
- if config.data_source is None:
427
- raise ValueError("data_source is required")
428
- d = config.data_source.config_dict
429
- categories = {
430
- category_name: Filter.from_config(filter_config_dict)
431
- for category_name, filter_config_dict in d["categories"].items()
432
- }
433
- pbf_fnames = [join_upath(ds_path, pbf_fname) for pbf_fname in d["pbf_fnames"]]
434
- bounds_fname = join_upath(ds_path, d["bounds_fname"])
435
- return OpenStreetMap(
436
- config=config,
437
- pbf_fnames=pbf_fnames,
438
- bounds_fname=bounds_fname,
439
- categories=categories,
440
- )
441
-
442
412
  def _get_pbf_bounds(self) -> list[tuple[float, float, float, float]]:
443
413
  # Determine WGS84 bounds of each PBF file by processing them through
444
414
  # BoundsHandler.
445
415
  if not self.bounds_fname.exists():
446
416
  pbf_bounds = []
447
417
  for pbf_fname in self.pbf_fnames:
448
- print(f"detecting bounds of {pbf_fname}")
418
+ logger.info(f"detecting bounds of {pbf_fname}")
449
419
  handler = BoundsHandler()
450
420
  with get_upath_local(pbf_fname) as local_fname:
451
421
  handler.apply_file(local_fname)
@@ -512,7 +482,7 @@ class OpenStreetMap(DataSource[OsmItem]):
512
482
  if tile_store.is_vector_ready(cur_item.name):
513
483
  continue
514
484
 
515
- print(
485
+ logger.info(
516
486
  f"ingesting osm item {cur_item.name} "
517
487
  + f"with {len(cur_geometries)} geometries"
518
488
  )
@@ -14,9 +14,9 @@ import shapely
14
14
  from fsspec.implementations.local import LocalFileSystem
15
15
  from upath import UPath
16
16
 
17
- from rslearn.config import QueryConfig, RasterLayerConfig
17
+ from rslearn.config import QueryConfig
18
18
  from rslearn.const import WGS84_PROJECTION
19
- from rslearn.data_sources import DataSource, Item
19
+ from rslearn.data_sources import DataSource, DataSourceContext, Item
20
20
  from rslearn.data_sources.utils import match_candidate_items_to_window
21
21
  from rslearn.tile_stores import TileStoreWithLayer
22
22
  from rslearn.utils import STGeometry
@@ -31,19 +31,18 @@ class Planet(DataSource):
31
31
 
32
32
  def __init__(
33
33
  self,
34
- config: RasterLayerConfig,
35
34
  item_type_id: str,
36
- cache_dir: UPath | None = None,
35
+ cache_dir: str | None = None,
37
36
  asset_type_id: str = "ortho_analytic_sr",
38
37
  range_filters: dict[str, dict[str, Any]] = {},
39
38
  use_permission_filter: bool = True,
40
39
  sort_by: str | None = None,
41
40
  bands: list[str] = ["b01", "b02", "b03", "b04"],
41
+ context: DataSourceContext = DataSourceContext(),
42
42
  ):
43
43
  """Initialize a new Planet instance.
44
44
 
45
45
  Args:
46
- config: the LayerConfig of the layer containing this data source
47
46
  item_type_id: the item type ID, like "PSScene" or "SkySatCollect".
48
47
  cache_dir: where to store downloaded assets, or None to just store it in
49
48
  temporary directory before putting into tile store.
@@ -60,39 +59,22 @@ class Planet(DataSource):
60
59
  "-clear_percent" or "cloud_cover" (if it starts with minus sign then we
61
60
  sort descending.)
62
61
  bands: what to call the bands in the asset.
62
+ context: the data source context.
63
63
  """
64
- self.config = config
65
64
  self.item_type_id = item_type_id
66
- self.cache_dir = cache_dir
67
65
  self.asset_type_id = asset_type_id
68
66
  self.range_filters = range_filters
69
67
  self.use_permission_filter = use_permission_filter
70
68
  self.sort_by = sort_by
71
69
  self.bands = bands
72
70
 
73
- @staticmethod
74
- def from_config(config: RasterLayerConfig, ds_path: UPath) -> "Planet":
75
- """Creates a new Planet instance from a configuration dictionary."""
76
- if config.data_source is None:
77
- raise ValueError("data_source is required")
78
- d = config.data_source.config_dict
79
- kwargs = dict(
80
- config=config,
81
- item_type_id=d["item_type_id"],
82
- )
83
- optional_keys = [
84
- "asset_type_id",
85
- "range_filters",
86
- "use_permission_filter",
87
- "sort_by",
88
- "bands",
89
- ]
90
- for optional_key in optional_keys:
91
- if optional_key in d:
92
- kwargs[optional_key] = d[optional_key]
93
- if "cache_dir" in d:
94
- kwargs["cache_dir"] = join_upath(ds_path, d["cache_dir"])
95
- return Planet(**kwargs)
71
+ if cache_dir is None:
72
+ self.cache_dir = None
73
+ else:
74
+ if context.ds_path is not None:
75
+ self.cache_dir = join_upath(context.ds_path, cache_dir)
76
+ else:
77
+ self.cache_dir = UPath(cache_dir)
96
78
 
97
79
  async def _search_items(self, geometry: STGeometry) -> list[dict[str, Any]]:
98
80
  wgs84_geometry = geometry.to_projection(WGS84_PROJECTION)