rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166) hide show
  1. rslearn/arg_parser.py +31 -0
  2. rslearn/config/__init__.py +6 -12
  3. rslearn/config/dataset.py +520 -401
  4. rslearn/const.py +9 -15
  5. rslearn/data_sources/__init__.py +8 -23
  6. rslearn/data_sources/aws_landsat.py +242 -98
  7. rslearn/data_sources/aws_open_data.py +111 -151
  8. rslearn/data_sources/aws_sentinel1.py +131 -0
  9. rslearn/data_sources/climate_data_store.py +471 -0
  10. rslearn/data_sources/copernicus.py +884 -12
  11. rslearn/data_sources/data_source.py +43 -12
  12. rslearn/data_sources/earthdaily.py +484 -0
  13. rslearn/data_sources/earthdata_srtm.py +282 -0
  14. rslearn/data_sources/eurocrops.py +242 -0
  15. rslearn/data_sources/gcp_public_data.py +578 -222
  16. rslearn/data_sources/google_earth_engine.py +461 -135
  17. rslearn/data_sources/local_files.py +219 -150
  18. rslearn/data_sources/openstreetmap.py +51 -89
  19. rslearn/data_sources/planet.py +24 -60
  20. rslearn/data_sources/planet_basemap.py +275 -0
  21. rslearn/data_sources/planetary_computer.py +798 -0
  22. rslearn/data_sources/usda_cdl.py +195 -0
  23. rslearn/data_sources/usgs_landsat.py +115 -83
  24. rslearn/data_sources/utils.py +249 -61
  25. rslearn/data_sources/vector_source.py +1 -0
  26. rslearn/data_sources/worldcereal.py +449 -0
  27. rslearn/data_sources/worldcover.py +144 -0
  28. rslearn/data_sources/worldpop.py +153 -0
  29. rslearn/data_sources/xyz_tiles.py +150 -107
  30. rslearn/dataset/__init__.py +8 -2
  31. rslearn/dataset/add_windows.py +2 -2
  32. rslearn/dataset/dataset.py +40 -51
  33. rslearn/dataset/handler_summaries.py +131 -0
  34. rslearn/dataset/manage.py +313 -74
  35. rslearn/dataset/materialize.py +431 -107
  36. rslearn/dataset/remap.py +29 -4
  37. rslearn/dataset/storage/__init__.py +1 -0
  38. rslearn/dataset/storage/file.py +202 -0
  39. rslearn/dataset/storage/storage.py +140 -0
  40. rslearn/dataset/window.py +181 -44
  41. rslearn/lightning_cli.py +454 -0
  42. rslearn/log_utils.py +24 -0
  43. rslearn/main.py +384 -181
  44. rslearn/models/anysat.py +215 -0
  45. rslearn/models/attention_pooling.py +177 -0
  46. rslearn/models/clay/clay.py +231 -0
  47. rslearn/models/clay/configs/metadata.yaml +295 -0
  48. rslearn/models/clip.py +68 -0
  49. rslearn/models/component.py +111 -0
  50. rslearn/models/concatenate_features.py +103 -0
  51. rslearn/models/conv.py +63 -0
  52. rslearn/models/croma.py +306 -0
  53. rslearn/models/detr/__init__.py +5 -0
  54. rslearn/models/detr/box_ops.py +103 -0
  55. rslearn/models/detr/detr.py +504 -0
  56. rslearn/models/detr/matcher.py +107 -0
  57. rslearn/models/detr/position_encoding.py +114 -0
  58. rslearn/models/detr/transformer.py +429 -0
  59. rslearn/models/detr/util.py +24 -0
  60. rslearn/models/dinov3.py +177 -0
  61. rslearn/models/faster_rcnn.py +30 -28
  62. rslearn/models/feature_center_crop.py +53 -0
  63. rslearn/models/fpn.py +19 -8
  64. rslearn/models/galileo/__init__.py +5 -0
  65. rslearn/models/galileo/galileo.py +595 -0
  66. rslearn/models/galileo/single_file_galileo.py +1678 -0
  67. rslearn/models/module_wrapper.py +65 -0
  68. rslearn/models/molmo.py +69 -0
  69. rslearn/models/multitask.py +384 -28
  70. rslearn/models/olmoearth_pretrain/__init__.py +1 -0
  71. rslearn/models/olmoearth_pretrain/model.py +421 -0
  72. rslearn/models/olmoearth_pretrain/norm.py +86 -0
  73. rslearn/models/panopticon.py +170 -0
  74. rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
  75. rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
  76. rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
  77. rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
  78. rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
  79. rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
  80. rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
  81. rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
  82. rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
  83. rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
  84. rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
  85. rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
  86. rslearn/models/pick_features.py +17 -10
  87. rslearn/models/pooling_decoder.py +60 -7
  88. rslearn/models/presto/__init__.py +5 -0
  89. rslearn/models/presto/presto.py +297 -0
  90. rslearn/models/presto/single_file_presto.py +926 -0
  91. rslearn/models/prithvi.py +1147 -0
  92. rslearn/models/resize_features.py +59 -0
  93. rslearn/models/sam2_enc.py +13 -9
  94. rslearn/models/satlaspretrain.py +38 -18
  95. rslearn/models/simple_time_series.py +188 -77
  96. rslearn/models/singletask.py +24 -13
  97. rslearn/models/ssl4eo_s12.py +40 -30
  98. rslearn/models/swin.py +44 -32
  99. rslearn/models/task_embedding.py +250 -0
  100. rslearn/models/terramind.py +256 -0
  101. rslearn/models/trunk.py +139 -0
  102. rslearn/models/unet.py +68 -22
  103. rslearn/models/upsample.py +48 -0
  104. rslearn/models/use_croma.py +508 -0
  105. rslearn/template_params.py +26 -0
  106. rslearn/tile_stores/__init__.py +41 -18
  107. rslearn/tile_stores/default.py +409 -0
  108. rslearn/tile_stores/tile_store.py +236 -132
  109. rslearn/train/all_patches_dataset.py +530 -0
  110. rslearn/train/callbacks/adapters.py +53 -0
  111. rslearn/train/callbacks/freeze_unfreeze.py +348 -17
  112. rslearn/train/callbacks/gradients.py +129 -0
  113. rslearn/train/callbacks/peft.py +116 -0
  114. rslearn/train/data_module.py +444 -20
  115. rslearn/train/dataset.py +588 -235
  116. rslearn/train/lightning_module.py +192 -62
  117. rslearn/train/model_context.py +88 -0
  118. rslearn/train/optimizer.py +31 -0
  119. rslearn/train/prediction_writer.py +319 -84
  120. rslearn/train/scheduler.py +92 -0
  121. rslearn/train/tasks/classification.py +55 -28
  122. rslearn/train/tasks/detection.py +132 -76
  123. rslearn/train/tasks/embedding.py +120 -0
  124. rslearn/train/tasks/multi_task.py +28 -14
  125. rslearn/train/tasks/per_pixel_regression.py +291 -0
  126. rslearn/train/tasks/regression.py +161 -44
  127. rslearn/train/tasks/segmentation.py +428 -53
  128. rslearn/train/tasks/task.py +6 -5
  129. rslearn/train/transforms/__init__.py +1 -1
  130. rslearn/train/transforms/concatenate.py +54 -10
  131. rslearn/train/transforms/crop.py +29 -11
  132. rslearn/train/transforms/flip.py +18 -6
  133. rslearn/train/transforms/mask.py +78 -0
  134. rslearn/train/transforms/normalize.py +101 -17
  135. rslearn/train/transforms/pad.py +19 -7
  136. rslearn/train/transforms/resize.py +83 -0
  137. rslearn/train/transforms/select_bands.py +76 -0
  138. rslearn/train/transforms/sentinel1.py +75 -0
  139. rslearn/train/transforms/transform.py +89 -70
  140. rslearn/utils/__init__.py +2 -6
  141. rslearn/utils/array.py +8 -6
  142. rslearn/utils/feature.py +2 -2
  143. rslearn/utils/fsspec.py +90 -1
  144. rslearn/utils/geometry.py +347 -7
  145. rslearn/utils/get_utm_ups_crs.py +2 -3
  146. rslearn/utils/grid_index.py +5 -5
  147. rslearn/utils/jsonargparse.py +178 -0
  148. rslearn/utils/mp.py +4 -3
  149. rslearn/utils/raster_format.py +268 -116
  150. rslearn/utils/rtree_index.py +64 -17
  151. rslearn/utils/sqlite_index.py +7 -1
  152. rslearn/utils/vector_format.py +252 -97
  153. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
  154. rslearn-0.0.21.dist-info/RECORD +167 -0
  155. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
  156. rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
  157. rslearn/data_sources/raster_source.py +0 -309
  158. rslearn/models/registry.py +0 -5
  159. rslearn/tile_stores/file.py +0 -242
  160. rslearn/utils/mgrs.py +0 -24
  161. rslearn/utils/utils.py +0 -22
  162. rslearn-0.0.1.dist-info/RECORD +0 -88
  163. /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
  164. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
  165. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
  166. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- """Data source for raster data on public Cloud Storage buckets."""
1
+ """Data source for OpenStreetMap vector features."""
2
2
 
3
3
  import json
4
4
  import shutil
@@ -7,17 +7,21 @@ from enum import Enum
7
7
  from typing import Any
8
8
 
9
9
  import osmium
10
+ import osmium.osm.types
10
11
  import shapely
11
12
  from upath import UPath
12
13
 
13
- from rslearn.config import LayerConfig, QueryConfig, VectorLayerConfig
14
+ from rslearn.config import QueryConfig
14
15
  from rslearn.const import WGS84_PROJECTION
15
- from rslearn.data_sources import DataSource, Item
16
+ from rslearn.data_sources import DataSource, DataSourceContext, Item
16
17
  from rslearn.data_sources.utils import match_candidate_items_to_window
17
- from rslearn.tile_stores import LayerMetadata, TileStore
18
+ from rslearn.log_utils import get_logger
19
+ from rslearn.tile_stores import TileStoreWithLayer
18
20
  from rslearn.utils import Feature, GridIndex, STGeometry
19
21
  from rslearn.utils.fsspec import get_upath_local, join_upath
20
22
 
23
+ logger = get_logger(__name__)
24
+
21
25
 
22
26
  class FeatureType(Enum):
23
27
  """OpenStreetMap feature type."""
@@ -36,7 +40,7 @@ class Filter:
36
40
  tag_conditions: dict[str, list[str]] | None = None,
37
41
  tag_properties: dict[str, str] | None = None,
38
42
  to_geometry: str | None = None,
39
- ):
43
+ ) -> None:
40
44
  """Create a new Filter instance.
41
45
 
42
46
  Args:
@@ -54,27 +58,6 @@ class Filter:
54
58
  self.tag_properties = tag_properties
55
59
  self.to_geometry = to_geometry
56
60
 
57
- @staticmethod
58
- def from_config(d: dict[str, Any]) -> "Filter":
59
- """Creates a Filter from a config dict.
60
-
61
- Args:
62
- d: the config dict
63
-
64
- Returns:
65
- the Filter object
66
- """
67
- kwargs = {}
68
- if "feature_types" in d:
69
- kwargs["feature_types"] = [FeatureType(el) for el in d["feature_types"]]
70
- if "tag_conditions" in d:
71
- kwargs["tag_conditions"] = d["tag_conditions"]
72
- if "tag_properties" in d:
73
- kwargs["tag_properties"] = d["tag_properties"]
74
- if "to_geometry" in d:
75
- kwargs["to_geometry"] = d["to_geometry"]
76
- return Filter(**kwargs)
77
-
78
61
  def match_tags(self, tags: dict[str, str]) -> bool:
79
62
  """Returns whether this filter matches based on the tags."""
80
63
  if not self.tag_conditions:
@@ -104,12 +87,12 @@ class Filter:
104
87
  class BoundsHandler(osmium.SimpleHandler):
105
88
  """An osmium handler for computing the bounds of an input file."""
106
89
 
107
- def __init__(self):
90
+ def __init__(self) -> None:
108
91
  """Initialize a new BoundsHandler."""
109
92
  osmium.SimpleHandler.__init__(self)
110
- self.bounds = (180, 90, -180, -90)
93
+ self.bounds: tuple[float, float, float, float] = (180, 90, -180, -90)
111
94
 
112
- def node(self, n):
95
+ def node(self, n: osmium.osm.types.Node) -> None:
113
96
  """Handle nodes and update the computed bounds."""
114
97
  lon = n.location.lon
115
98
  lat = n.location.lat
@@ -130,7 +113,7 @@ class OsmHandler(osmium.SimpleHandler):
130
113
  geometries: list[STGeometry],
131
114
  grid_size: float = 0.03,
132
115
  padding: float = 0.03,
133
- ):
116
+ ) -> None:
134
117
  """Initialize a new OsmHandler.
135
118
 
136
119
  Args:
@@ -163,12 +146,12 @@ class OsmHandler(osmium.SimpleHandler):
163
146
  )
164
147
  self.grid_index.insert(bounds, 1)
165
148
 
166
- self.cached_nodes = {}
167
- self.cached_ways = {}
149
+ self.cached_nodes: dict = {}
150
+ self.cached_ways: dict = {}
168
151
 
169
- self.features = []
152
+ self.features: list[Feature] = []
170
153
 
171
- def node(self, n):
154
+ def node(self, n: osmium.osm.types.Node) -> None:
172
155
  """Handle nodes."""
173
156
  # Check if node is relevant to our geometries.
174
157
  lon = n.location.lon
@@ -193,7 +176,7 @@ class OsmHandler(osmium.SimpleHandler):
193
176
  )
194
177
  self.features.append(feat)
195
178
 
196
- def _get_way_coords(self, node_ids):
179
+ def _get_way_coords(self, node_ids: list[int]) -> list:
197
180
  coords = []
198
181
  for id in node_ids:
199
182
  if id not in self.cached_nodes:
@@ -201,7 +184,7 @@ class OsmHandler(osmium.SimpleHandler):
201
184
  coords.append(self.cached_nodes[id])
202
185
  return coords
203
186
 
204
- def way(self, w):
187
+ def way(self, w: osmium.osm.types.Way) -> None:
205
188
  """Handle ways."""
206
189
  # Collect nodes, skip if too few.
207
190
  node_ids = [member.ref for member in w.nodes]
@@ -235,7 +218,7 @@ class OsmHandler(osmium.SimpleHandler):
235
218
  )
236
219
  self.features.append(feat)
237
220
 
238
- def match_relation(self, r):
221
+ def match_relation(self, r: osmium.osm.types.Relation) -> None:
239
222
  """Handle relations."""
240
223
  # Collect ways and distinguish exterior vs holes, skip if none found.
241
224
  exterior_ways = []
@@ -267,7 +250,7 @@ class OsmHandler(osmium.SimpleHandler):
267
250
  # Merge the ways in case some exterior/interior polygons are split into
268
251
  # multiple ways.
269
252
  # And convert them from node IDs to coordinates.
270
- def get_polygons(ways):
253
+ def get_polygons(ways: list) -> list:
271
254
  polygons: list[list[int]] = []
272
255
  for way in ways:
273
256
  # Attempt to match the way to an existing polygon.
@@ -366,13 +349,13 @@ class OsmItem(Item):
366
349
  return d
367
350
 
368
351
  @staticmethod
369
- def deserialize(d: dict) -> Item:
352
+ def deserialize(d: dict) -> "OsmItem":
370
353
  """Deserializes an item from a JSON-decoded dictionary."""
371
354
  item = super(OsmItem, OsmItem).deserialize(d)
372
355
  return OsmItem(name=item.name, geometry=item.geometry, path_uri=d["path_uri"])
373
356
 
374
357
 
375
- class OpenStreetMap(DataSource):
358
+ class OpenStreetMap(DataSource[OsmItem]):
376
359
  """A data source for OpenStreetMap data from PBF file.
377
360
 
378
361
  An existing local PBF file can be used, or if the provided path doesn't exist, then
@@ -386,12 +369,12 @@ class OpenStreetMap(DataSource):
386
369
 
387
370
  def __init__(
388
371
  self,
389
- config: VectorLayerConfig,
390
- pbf_fnames: list[UPath],
391
- bounds_fname: UPath,
372
+ pbf_fnames: list[str],
373
+ bounds_fname: str,
392
374
  categories: dict[str, Filter],
375
+ context: DataSourceContext = DataSourceContext(),
393
376
  ):
394
- """Initialize a new Sentinel2 instance.
377
+ """Initialize a new OpenStreetMap instance.
395
378
 
396
379
  Args:
397
380
  config: the configuration of this layer.
@@ -401,14 +384,21 @@ class OpenStreetMap(DataSource):
401
384
  bounds_fname: filename where the bounds of the PBF are cached.
402
385
  categories: dictionary of (category name, filter). Features that match the
403
386
  filter will be emitted under the corresponding category.
387
+ context: the data source context.
404
388
  """
405
- self.config = config
406
- self.pbf_fnames = pbf_fnames
407
- self.bounds_fname = bounds_fname
408
389
  self.categories = categories
409
390
 
391
+ if context.ds_path is not None:
392
+ self.pbf_fnames = [
393
+ join_upath(context.ds_path, pbf_fname) for pbf_fname in pbf_fnames
394
+ ]
395
+ self.bounds_fname = join_upath(context.ds_path, bounds_fname)
396
+ else:
397
+ self.pbf_fnames = [UPath(pbf_fname) for pbf_fname in pbf_fnames]
398
+ self.bounds_fname = UPath(bounds_fname)
399
+
410
400
  if len(self.pbf_fnames) == 1 and not self.pbf_fnames[0].exists():
411
- print(
401
+ logger.info(
412
402
  "Downloading planet.osm.pbf from "
413
403
  + f"{self.planet_pbf_url} to {self.pbf_fnames[0]}"
414
404
  )
@@ -419,29 +409,13 @@ class OpenStreetMap(DataSource):
419
409
  # Detect bounds of each pbf file if needed.
420
410
  self.pbf_bounds = self._get_pbf_bounds()
421
411
 
422
- @staticmethod
423
- def from_config(config: LayerConfig, ds_path: UPath) -> "OpenStreetMap":
424
- """Creates a new OpenStreetMap instance from a configuration dictionary."""
425
- assert isinstance(config, VectorLayerConfig)
426
- d = config.data_source.config_dict
427
- categories = {
428
- category_name: Filter.from_config(filter_config_dict)
429
- for category_name, filter_config_dict in d["categories"].items()
430
- }
431
- pbf_fnames = [join_upath(ds_path, pbf_fname) for pbf_fname in d["pbf_fnames"]]
432
- bounds_fname = join_upath(ds_path, d["bounds_fname"])
433
- return OpenStreetMap(
434
- config=config,
435
- pbf_fnames=pbf_fnames,
436
- bounds_fname=bounds_fname,
437
- categories=categories,
438
- )
439
-
440
- def _get_pbf_bounds(self):
412
+ def _get_pbf_bounds(self) -> list[tuple[float, float, float, float]]:
413
+ # Determine WGS84 bounds of each PBF file by processing them through
414
+ # BoundsHandler.
441
415
  if not self.bounds_fname.exists():
442
416
  pbf_bounds = []
443
417
  for pbf_fname in self.pbf_fnames:
444
- print(f"detecting bounds of {pbf_fname}")
418
+ logger.info(f"detecting bounds of {pbf_fname}")
445
419
  handler = BoundsHandler()
446
420
  with get_upath_local(pbf_fname) as local_fname:
447
421
  handler.apply_file(local_fname)
@@ -458,7 +432,7 @@ class OpenStreetMap(DataSource):
458
432
 
459
433
  def get_items(
460
434
  self, geometries: list[STGeometry], query_config: QueryConfig
461
- ) -> list[list[list[Item]]]:
435
+ ) -> list[list[list[OsmItem]]]:
462
436
  """Get a list of items in the data source intersecting the given geometries.
463
437
 
464
438
  Args:
@@ -487,14 +461,14 @@ class OpenStreetMap(DataSource):
487
461
  groups.append(cur_groups)
488
462
  return groups
489
463
 
490
- def deserialize_item(self, serialized_item: Any) -> Item:
464
+ def deserialize_item(self, serialized_item: Any) -> OsmItem:
491
465
  """Deserializes an item from JSON-decoded data."""
492
466
  return OsmItem.deserialize(serialized_item)
493
467
 
494
468
  def ingest(
495
469
  self,
496
- tile_store: TileStore,
497
- items: list[Item],
470
+ tile_store: TileStoreWithLayer,
471
+ items: list[OsmItem],
498
472
  geometries: list[list[STGeometry]],
499
473
  ) -> None:
500
474
  """Ingest items into the given tile store.
@@ -504,10 +478,11 @@ class OpenStreetMap(DataSource):
504
478
  items: the items to ingest
505
479
  geometries: a list of geometries needed for each item
506
480
  """
507
- item_names = [item.name for item in items]
508
- item_names.sort()
509
481
  for cur_item, cur_geometries in zip(items, geometries):
510
- print(
482
+ if tile_store.is_vector_ready(cur_item.name):
483
+ continue
484
+
485
+ logger.info(
511
486
  f"ingesting osm item {cur_item.name} "
512
487
  + f"with {len(cur_geometries)} geometries"
513
488
  )
@@ -515,17 +490,4 @@ class OpenStreetMap(DataSource):
515
490
  with get_upath_local(UPath(cur_item.path_uri)) as local_fname:
516
491
  handler.apply_file(local_fname)
517
492
 
518
- projections = set()
519
- for geometry in cur_geometries:
520
- projection, _ = self.config.get_final_projection_and_bounds(
521
- geometry.projection, None
522
- )
523
- projections.add(projection)
524
-
525
- for projection in projections:
526
- features = [feat.to_projection(projection) for feat in handler.features]
527
- layer = tile_store.create_layer(
528
- (cur_item.name, str(projection)),
529
- LayerMetadata(projection, None, {}),
530
- )
531
- layer.write_vector(features)
493
+ tile_store.write_vector(cur_item.name, handler.features)
@@ -6,24 +6,22 @@ import pathlib
6
6
  import shutil
7
7
  import tempfile
8
8
  from datetime import datetime
9
+ from pathlib import Path
9
10
  from typing import Any
10
11
 
11
12
  import planet
12
- import rasterio
13
13
  import shapely
14
14
  from fsspec.implementations.local import LocalFileSystem
15
15
  from upath import UPath
16
16
 
17
- from rslearn.config import LayerConfig, QueryConfig, RasterLayerConfig
17
+ from rslearn.config import QueryConfig
18
18
  from rslearn.const import WGS84_PROJECTION
19
- from rslearn.data_sources import DataSource, Item
19
+ from rslearn.data_sources import DataSource, DataSourceContext, Item
20
20
  from rslearn.data_sources.utils import match_candidate_items_to_window
21
- from rslearn.tile_stores import PrefixedTileStore, TileStore
21
+ from rslearn.tile_stores import TileStoreWithLayer
22
22
  from rslearn.utils import STGeometry
23
23
  from rslearn.utils.fsspec import join_upath
24
24
 
25
- from .raster_source import get_needed_projections, ingest_raster
26
-
27
25
 
28
26
  class Planet(DataSource):
29
27
  """A data source for Planet Labs API.
@@ -33,19 +31,18 @@ class Planet(DataSource):
33
31
 
34
32
  def __init__(
35
33
  self,
36
- config: LayerConfig,
37
34
  item_type_id: str,
38
- cache_dir: UPath | None = None,
35
+ cache_dir: str | None = None,
39
36
  asset_type_id: str = "ortho_analytic_sr",
40
37
  range_filters: dict[str, dict[str, Any]] = {},
41
38
  use_permission_filter: bool = True,
42
39
  sort_by: str | None = None,
43
40
  bands: list[str] = ["b01", "b02", "b03", "b04"],
41
+ context: DataSourceContext = DataSourceContext(),
44
42
  ):
45
43
  """Initialize a new Planet instance.
46
44
 
47
45
  Args:
48
- config: the LayerConfig of the layer containing this data source
49
46
  item_type_id: the item type ID, like "PSScene" or "SkySatCollect".
50
47
  cache_dir: where to store downloaded assets, or None to just store it in
51
48
  temporary directory before putting into tile store.
@@ -62,38 +59,22 @@ class Planet(DataSource):
62
59
  "-clear_percent" or "cloud_cover" (if it starts with minus sign then we
63
60
  sort descending.)
64
61
  bands: what to call the bands in the asset.
62
+ context: the data source context.
65
63
  """
66
- self.config = config
67
64
  self.item_type_id = item_type_id
68
- self.cache_dir = cache_dir
69
65
  self.asset_type_id = asset_type_id
70
66
  self.range_filters = range_filters
71
67
  self.use_permission_filter = use_permission_filter
72
68
  self.sort_by = sort_by
73
69
  self.bands = bands
74
70
 
75
- @staticmethod
76
- def from_config(config: LayerConfig, ds_path: UPath) -> "Planet":
77
- """Creates a new Planet instance from a configuration dictionary."""
78
- assert isinstance(config, RasterLayerConfig)
79
- d = config.data_source.config_dict
80
- kwargs = dict(
81
- config=config,
82
- item_type_id=d["item_type_id"],
83
- )
84
- optional_keys = [
85
- "asset_type_id",
86
- "range_filters",
87
- "use_permission_filter",
88
- "sort_by",
89
- "bands",
90
- ]
91
- for optional_key in optional_keys:
92
- if optional_key in d:
93
- kwargs[optional_key] = d[optional_key]
94
- if "cache_dir" in d:
95
- kwargs["cache_dir"] = join_upath(ds_path, d["cache_dir"])
96
- return Planet(**kwargs)
71
+ if cache_dir is None:
72
+ self.cache_dir = None
73
+ else:
74
+ if context.ds_path is not None:
75
+ self.cache_dir = join_upath(context.ds_path, cache_dir)
76
+ else:
77
+ self.cache_dir = UPath(cache_dir)
97
78
 
98
79
  async def _search_items(self, geometry: STGeometry) -> list[dict[str, Any]]:
99
80
  wgs84_geometry = geometry.to_projection(WGS84_PROJECTION)
@@ -101,11 +82,10 @@ class Planet(DataSource):
101
82
 
102
83
  async with planet.Session() as session:
103
84
  client = session.client("data")
104
-
85
+ gte = geometry.time_range[0] if geometry.time_range is not None else None
86
+ lte = geometry.time_range[1] if geometry.time_range is not None else None
105
87
  filter_list = [
106
- planet.data_filter.date_range_filter(
107
- "acquired", gte=geometry.time_range[0], lte=geometry.time_range[1]
108
- ),
88
+ planet.data_filter.date_range_filter("acquired", gte=gte, lte=lte),
109
89
  planet.data_filter.geometry_filter(geojson_data),
110
90
  planet.data_filter.asset_filter([self.asset_type_id]),
111
91
  ]
@@ -242,7 +222,7 @@ class Planet(DataSource):
242
222
 
243
223
  def ingest(
244
224
  self,
245
- tile_store: TileStore,
225
+ tile_store: TileStoreWithLayer,
246
226
  items: list[Item],
247
227
  geometries: list[list[STGeometry]],
248
228
  ) -> None:
@@ -253,26 +233,10 @@ class Planet(DataSource):
253
233
  items: the items to ingest
254
234
  geometries: a list of geometries needed for each item
255
235
  """
256
- for item, cur_geometries in zip(items, geometries):
236
+ for item in items:
237
+ if tile_store.is_raster_ready(item.name, self.bands):
238
+ continue
239
+
257
240
  with tempfile.TemporaryDirectory() as tmp_dir:
258
- band_names = self.bands
259
- cur_tile_store = PrefixedTileStore(
260
- tile_store, (item.name, "_".join(band_names))
261
- )
262
- needed_projections = get_needed_projections(
263
- cur_tile_store, band_names, self.config.band_sets, cur_geometries
264
- )
265
- if not needed_projections:
266
- continue
267
-
268
- asset_path = asyncio.run(self._download_asset(item, tmp_dir))
269
- with asset_path.open("rb") as f:
270
- with rasterio.open(f) as raster:
271
- for projection in needed_projections:
272
- ingest_raster(
273
- tile_store=cur_tile_store,
274
- raster=raster,
275
- projection=projection,
276
- time_range=item.geometry.time_range,
277
- layer_config=self.config,
278
- )
241
+ asset_path = asyncio.run(self._download_asset(item, Path(tmp_dir)))
242
+ tile_store.write_raster_file(item.name, self.bands, asset_path)