rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. rslearn/arg_parser.py +31 -0
  2. rslearn/config/__init__.py +6 -12
  3. rslearn/config/dataset.py +520 -401
  4. rslearn/const.py +9 -15
  5. rslearn/data_sources/__init__.py +8 -23
  6. rslearn/data_sources/aws_landsat.py +242 -98
  7. rslearn/data_sources/aws_open_data.py +111 -151
  8. rslearn/data_sources/aws_sentinel1.py +131 -0
  9. rslearn/data_sources/climate_data_store.py +471 -0
  10. rslearn/data_sources/copernicus.py +884 -12
  11. rslearn/data_sources/data_source.py +43 -12
  12. rslearn/data_sources/earthdaily.py +484 -0
  13. rslearn/data_sources/earthdata_srtm.py +282 -0
  14. rslearn/data_sources/eurocrops.py +242 -0
  15. rslearn/data_sources/gcp_public_data.py +578 -222
  16. rslearn/data_sources/google_earth_engine.py +461 -135
  17. rslearn/data_sources/local_files.py +219 -150
  18. rslearn/data_sources/openstreetmap.py +51 -89
  19. rslearn/data_sources/planet.py +24 -60
  20. rslearn/data_sources/planet_basemap.py +275 -0
  21. rslearn/data_sources/planetary_computer.py +798 -0
  22. rslearn/data_sources/usda_cdl.py +195 -0
  23. rslearn/data_sources/usgs_landsat.py +115 -83
  24. rslearn/data_sources/utils.py +249 -61
  25. rslearn/data_sources/vector_source.py +1 -0
  26. rslearn/data_sources/worldcereal.py +449 -0
  27. rslearn/data_sources/worldcover.py +144 -0
  28. rslearn/data_sources/worldpop.py +153 -0
  29. rslearn/data_sources/xyz_tiles.py +150 -107
  30. rslearn/dataset/__init__.py +8 -2
  31. rslearn/dataset/add_windows.py +2 -2
  32. rslearn/dataset/dataset.py +40 -51
  33. rslearn/dataset/handler_summaries.py +131 -0
  34. rslearn/dataset/manage.py +313 -74
  35. rslearn/dataset/materialize.py +431 -107
  36. rslearn/dataset/remap.py +29 -4
  37. rslearn/dataset/storage/__init__.py +1 -0
  38. rslearn/dataset/storage/file.py +202 -0
  39. rslearn/dataset/storage/storage.py +140 -0
  40. rslearn/dataset/window.py +181 -44
  41. rslearn/lightning_cli.py +454 -0
  42. rslearn/log_utils.py +24 -0
  43. rslearn/main.py +384 -181
  44. rslearn/models/anysat.py +215 -0
  45. rslearn/models/attention_pooling.py +177 -0
  46. rslearn/models/clay/clay.py +231 -0
  47. rslearn/models/clay/configs/metadata.yaml +295 -0
  48. rslearn/models/clip.py +68 -0
  49. rslearn/models/component.py +111 -0
  50. rslearn/models/concatenate_features.py +103 -0
  51. rslearn/models/conv.py +63 -0
  52. rslearn/models/croma.py +306 -0
  53. rslearn/models/detr/__init__.py +5 -0
  54. rslearn/models/detr/box_ops.py +103 -0
  55. rslearn/models/detr/detr.py +504 -0
  56. rslearn/models/detr/matcher.py +107 -0
  57. rslearn/models/detr/position_encoding.py +114 -0
  58. rslearn/models/detr/transformer.py +429 -0
  59. rslearn/models/detr/util.py +24 -0
  60. rslearn/models/dinov3.py +177 -0
  61. rslearn/models/faster_rcnn.py +30 -28
  62. rslearn/models/feature_center_crop.py +53 -0
  63. rslearn/models/fpn.py +19 -8
  64. rslearn/models/galileo/__init__.py +5 -0
  65. rslearn/models/galileo/galileo.py +595 -0
  66. rslearn/models/galileo/single_file_galileo.py +1678 -0
  67. rslearn/models/module_wrapper.py +65 -0
  68. rslearn/models/molmo.py +69 -0
  69. rslearn/models/multitask.py +384 -28
  70. rslearn/models/olmoearth_pretrain/__init__.py +1 -0
  71. rslearn/models/olmoearth_pretrain/model.py +421 -0
  72. rslearn/models/olmoearth_pretrain/norm.py +86 -0
  73. rslearn/models/panopticon.py +170 -0
  74. rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
  75. rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
  76. rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
  77. rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
  78. rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
  79. rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
  80. rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
  81. rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
  82. rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
  83. rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
  84. rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
  85. rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
  86. rslearn/models/pick_features.py +17 -10
  87. rslearn/models/pooling_decoder.py +60 -7
  88. rslearn/models/presto/__init__.py +5 -0
  89. rslearn/models/presto/presto.py +297 -0
  90. rslearn/models/presto/single_file_presto.py +926 -0
  91. rslearn/models/prithvi.py +1147 -0
  92. rslearn/models/resize_features.py +59 -0
  93. rslearn/models/sam2_enc.py +13 -9
  94. rslearn/models/satlaspretrain.py +38 -18
  95. rslearn/models/simple_time_series.py +188 -77
  96. rslearn/models/singletask.py +24 -13
  97. rslearn/models/ssl4eo_s12.py +40 -30
  98. rslearn/models/swin.py +44 -32
  99. rslearn/models/task_embedding.py +250 -0
  100. rslearn/models/terramind.py +256 -0
  101. rslearn/models/trunk.py +139 -0
  102. rslearn/models/unet.py +68 -22
  103. rslearn/models/upsample.py +48 -0
  104. rslearn/models/use_croma.py +508 -0
  105. rslearn/template_params.py +26 -0
  106. rslearn/tile_stores/__init__.py +41 -18
  107. rslearn/tile_stores/default.py +409 -0
  108. rslearn/tile_stores/tile_store.py +236 -132
  109. rslearn/train/all_patches_dataset.py +530 -0
  110. rslearn/train/callbacks/adapters.py +53 -0
  111. rslearn/train/callbacks/freeze_unfreeze.py +348 -17
  112. rslearn/train/callbacks/gradients.py +129 -0
  113. rslearn/train/callbacks/peft.py +116 -0
  114. rslearn/train/data_module.py +444 -20
  115. rslearn/train/dataset.py +588 -235
  116. rslearn/train/lightning_module.py +192 -62
  117. rslearn/train/model_context.py +88 -0
  118. rslearn/train/optimizer.py +31 -0
  119. rslearn/train/prediction_writer.py +319 -84
  120. rslearn/train/scheduler.py +92 -0
  121. rslearn/train/tasks/classification.py +55 -28
  122. rslearn/train/tasks/detection.py +132 -76
  123. rslearn/train/tasks/embedding.py +120 -0
  124. rslearn/train/tasks/multi_task.py +28 -14
  125. rslearn/train/tasks/per_pixel_regression.py +291 -0
  126. rslearn/train/tasks/regression.py +161 -44
  127. rslearn/train/tasks/segmentation.py +428 -53
  128. rslearn/train/tasks/task.py +6 -5
  129. rslearn/train/transforms/__init__.py +1 -1
  130. rslearn/train/transforms/concatenate.py +54 -10
  131. rslearn/train/transforms/crop.py +29 -11
  132. rslearn/train/transforms/flip.py +18 -6
  133. rslearn/train/transforms/mask.py +78 -0
  134. rslearn/train/transforms/normalize.py +101 -17
  135. rslearn/train/transforms/pad.py +19 -7
  136. rslearn/train/transforms/resize.py +83 -0
  137. rslearn/train/transforms/select_bands.py +76 -0
  138. rslearn/train/transforms/sentinel1.py +75 -0
  139. rslearn/train/transforms/transform.py +89 -70
  140. rslearn/utils/__init__.py +2 -6
  141. rslearn/utils/array.py +8 -6
  142. rslearn/utils/feature.py +2 -2
  143. rslearn/utils/fsspec.py +90 -1
  144. rslearn/utils/geometry.py +347 -7
  145. rslearn/utils/get_utm_ups_crs.py +2 -3
  146. rslearn/utils/grid_index.py +5 -5
  147. rslearn/utils/jsonargparse.py +178 -0
  148. rslearn/utils/mp.py +4 -3
  149. rslearn/utils/raster_format.py +268 -116
  150. rslearn/utils/rtree_index.py +64 -17
  151. rslearn/utils/sqlite_index.py +7 -1
  152. rslearn/utils/vector_format.py +252 -97
  153. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
  154. rslearn-0.0.21.dist-info/RECORD +167 -0
  155. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
  156. rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
  157. rslearn/data_sources/raster_source.py +0 -309
  158. rslearn/models/registry.py +0 -5
  159. rslearn/tile_stores/file.py +0 -242
  160. rslearn/utils/mgrs.py +0 -24
  161. rslearn/utils/utils.py +0 -22
  162. rslearn-0.0.1.dist-info/RECORD +0 -88
  163. /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
  164. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
  165. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
  166. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
@@ -1,52 +1,61 @@
1
1
  """Data source for raster or vector data in local files."""
2
2
 
3
- from typing import Any
3
+ import functools
4
+ import json
5
+ from typing import Any, Generic, TypeVar
4
6
 
5
7
  import fiona
6
- import rasterio
7
8
  import shapely
8
9
  import shapely.geometry
9
- from class_registry import ClassRegistry
10
10
  from rasterio.crs import CRS
11
11
  from upath import UPath
12
12
 
13
13
  import rslearn.data_sources.utils
14
- from rslearn.config import LayerConfig, LayerType, VectorLayerConfig
15
- from rslearn.const import SHAPEFILE_AUX_EXTENSIONS, WGS84_PROJECTION
16
- from rslearn.tile_stores import LayerMetadata, PrefixedTileStore, TileStore
17
- from rslearn.utils import Feature, Projection, STGeometry
18
- from rslearn.utils.fsspec import get_upath_local, join_upath
14
+ from rslearn.config import LayerType
15
+ from rslearn.const import SHAPEFILE_AUX_EXTENSIONS
16
+ from rslearn.log_utils import get_logger
17
+ from rslearn.tile_stores import TileStoreWithLayer
18
+ from rslearn.utils.feature import Feature
19
+ from rslearn.utils.fsspec import (
20
+ get_relative_suffix,
21
+ get_upath_local,
22
+ join_upath,
23
+ open_rasterio_upath_reader,
24
+ )
25
+ from rslearn.utils.geometry import Projection, STGeometry, get_global_geometry
19
26
 
20
- from .data_source import DataSource, Item, QueryConfig
21
- from .raster_source import get_needed_projections, ingest_raster
27
+ from .data_source import DataSource, DataSourceContext, Item, QueryConfig
22
28
 
23
- Importers = ClassRegistry()
29
+ logger = get_logger("__name__")
24
30
 
25
31
 
26
- class Importer:
27
- """An abstract class for importing data from local files."""
32
+ ItemType = TypeVar("ItemType", bound=Item)
33
+ ImporterType = TypeVar("ImporterType", bound="Importer")
28
34
 
29
- def list_items(self, config: LayerConfig, src_dir: UPath) -> list[Item]:
35
+ SOURCE_NAME = "rslearn.data_sources.local_files.LocalFiles"
36
+
37
+
38
+ class Importer(Generic[ItemType]):
39
+ """An abstract base class for importing data from local files."""
40
+
41
+ def list_items(self, src_dir: UPath) -> list[ItemType]:
30
42
  """Extract a list of Items from the source directory.
31
43
 
32
44
  Args:
33
- config: the configuration of the layer.
34
45
  src_dir: the source directory.
35
46
  """
36
47
  raise NotImplementedError
37
48
 
38
49
  def ingest_item(
39
50
  self,
40
- config: LayerConfig,
41
- tile_store: TileStore,
42
- item: Item,
51
+ tile_store: TileStoreWithLayer,
52
+ item: ItemType,
43
53
  cur_geometries: list[STGeometry],
44
- ):
54
+ ) -> None:
45
55
  """Ingest the specified local file item.
46
56
 
47
57
  Args:
48
- config: the configuration of the layer.
49
- tile_store: the TileStore to ingest the data into.
58
+ tile_store: the tile store to ingest the data into.
50
59
  item: the Item to ingest
51
60
  cur_geometries: the geometries where the item is needed.
52
61
  """
@@ -58,7 +67,7 @@ class RasterItemSpec:
58
67
 
59
68
  def __init__(
60
69
  self,
61
- fnames: list[UPath],
70
+ fnames: list[str],
62
71
  bands: list[list[str]] | None = None,
63
72
  name: str | None = None,
64
73
  ):
@@ -73,25 +82,6 @@ class RasterItemSpec:
73
82
  self.bands = bands
74
83
  self.name = name
75
84
 
76
- @staticmethod
77
- def from_config(src_dir: UPath, d: dict[str, Any]) -> "RasterItemSpec":
78
- """Decode a dict into a RasterItemSpec.
79
-
80
- Args:
81
- src_dir: the source directory.
82
- d: the configuration dict.
83
-
84
- Returns:
85
- the RasterItemSpec.
86
- """
87
- kwargs = dict(
88
- fnames=[join_upath(src_dir, suffix) for suffix in d["fnames"]],
89
- bands=d["bands"],
90
- )
91
- if "name" in d:
92
- kwargs["name"] = d["name"]
93
- return RasterItemSpec(**kwargs)
94
-
95
85
  def serialize(self) -> dict[str, Any]:
96
86
  """Serializes the RasterItemSpec to a JSON-encodable dictionary."""
97
87
  return {
@@ -104,7 +94,7 @@ class RasterItemSpec:
104
94
  def deserialize(d: dict[str, Any]) -> "RasterItemSpec":
105
95
  """Deserializes a RasterItemSpec from a JSON-decoded dictionary."""
106
96
  return RasterItemSpec(
107
- fnames=[UPath(s) for s in d["fnames"]],
97
+ fnames=[s for s in d["fnames"]],
108
98
  bands=d["bands"],
109
99
  name=d["name"],
110
100
  )
@@ -113,29 +103,37 @@ class RasterItemSpec:
113
103
  class RasterItem(Item):
114
104
  """An item corresponding to a local file."""
115
105
 
116
- def __init__(self, name: str, geometry: STGeometry, spec: RasterItemSpec):
117
- """Creates a new LocalFileItem.
106
+ def __init__(
107
+ self, name: str, geometry: STGeometry, src_dir: str, spec: RasterItemSpec
108
+ ):
109
+ """Creates a new RasterItem.
118
110
 
119
111
  Args:
120
112
  name: unique name of the item
121
113
  geometry: the spatial and temporal extent of the item
114
+ src_dir: the source directory.
122
115
  spec: the RasterItemSpec that specifies the filename(s) and bands.
123
116
  """
124
117
  super().__init__(name, geometry)
118
+ self.src_dir = src_dir
125
119
  self.spec = spec
126
120
 
127
121
  def serialize(self) -> dict:
128
122
  """Serializes the item to a JSON-encodable dictionary."""
129
123
  d = super().serialize()
124
+ d["src_dir"] = str(self.src_dir)
130
125
  d["spec"] = self.spec.serialize()
131
126
  return d
132
127
 
133
128
  @staticmethod
134
- def deserialize(d: dict) -> Item:
129
+ def deserialize(d: dict) -> "RasterItem":
135
130
  """Deserializes an item from a JSON-decoded dictionary."""
136
131
  item = super(RasterItem, RasterItem).deserialize(d)
132
+ src_dir = d["src_dir"]
137
133
  spec = RasterItemSpec.deserialize(d["spec"])
138
- return RasterItem(name=item.name, geometry=item.geometry, spec=spec)
134
+ return RasterItem(
135
+ name=item.name, geometry=item.geometry, src_dir=src_dir, spec=spec
136
+ )
139
137
 
140
138
 
141
139
  class VectorItem(Item):
@@ -159,7 +157,7 @@ class VectorItem(Item):
159
157
  return d
160
158
 
161
159
  @staticmethod
162
- def deserialize(d: dict) -> Item:
160
+ def deserialize(d: dict) -> "VectorItem":
163
161
  """Deserializes an item from a JSON-decoded dictionary."""
164
162
  item = super(VectorItem, VectorItem).deserialize(d)
165
163
  return VectorItem(
@@ -167,119 +165,146 @@ class VectorItem(Item):
167
165
  )
168
166
 
169
167
 
170
- @Importers.register("raster")
171
168
  class RasterImporter(Importer):
172
169
  """An Importer for raster data."""
173
170
 
174
- def list_items(self, config: LayerConfig, src_dir: UPath) -> list[Item]:
171
+ def __init__(self, item_specs: list[RasterItemSpec] | None = None):
172
+ """Create a new RasterImporter.
173
+
174
+ Args:
175
+ item_specs: the specs to specify the raster items directly. If None, the
176
+ raster items are automatically detected from the files in the source
177
+ directory.
178
+ """
179
+ self.item_specs = item_specs
180
+
181
+ def list_items(self, src_dir: UPath) -> list[Item]:
175
182
  """Extract a list of Items from the source directory.
176
183
 
177
184
  Args:
178
- config: the configuration of the layer.
179
185
  src_dir: the source directory.
180
186
  """
181
- item_specs: list[RasterItemSpec] = []
187
+ item_specs: list[RasterItemSpec]
188
+
182
189
  # See if user has provided the item specs directly.
183
- if "item_specs" in config.data_source.config_dict:
184
- for spec_dict in config.data_source.config_dict["item_specs"]:
185
- spec = RasterItemSpec.from_config(src_dir, spec_dict)
186
- item_specs.append(spec)
190
+ if self.item_specs is not None:
191
+ item_specs = self.item_specs
187
192
  else:
188
193
  # Otherwise we need to list files and assume each one is separate.
189
194
  # And we'll need to autodetect the bands later.
195
+ item_specs = []
190
196
  file_paths = src_dir.glob("**/*.*")
191
197
  for path in file_paths:
192
- spec = RasterItemSpec(fnames=[path], bands=None)
198
+ # Ignore JSON files.
199
+ if path.name.endswith(".json"):
200
+ continue
201
+
202
+ # Ignore temporary files that may be created by open_atomic.
203
+ # The suffix should be like "X.tif.tmp.1234".
204
+ parts = path.name.split(".")
205
+ if len(parts) >= 4 and parts[-2] == "tmp" and parts[-1].isdigit():
206
+ continue
207
+
208
+ spec = RasterItemSpec(
209
+ fnames=[get_relative_suffix(src_dir, path)], bands=None
210
+ )
193
211
  item_specs.append(spec)
194
212
 
195
213
  items: list[Item] = []
196
214
  for spec in item_specs:
197
215
  # Get geometry from the first raster file.
198
216
  # We assume files are readable with rasterio.
199
- with spec.fnames[0].open("rb") as f:
200
- with rasterio.open(f) as src:
201
- crs = src.crs
202
- left = src.transform.c
203
- top = src.transform.f
204
- # Resolutions in projection units per pixel.
205
- x_resolution = src.transform.a
206
- y_resolution = src.transform.e
207
- start = (int(left / x_resolution), int(top / y_resolution))
208
- shp = shapely.box(
209
- start[0], start[1], start[0] + src.width, start[1] + src.height
210
- )
211
- projection = Projection(crs, x_resolution, y_resolution)
212
- geometry = STGeometry(projection, shp, None)
217
+ fname = join_upath(src_dir, spec.fnames[0])
218
+ with open_rasterio_upath_reader(fname) as src:
219
+ crs = src.crs
220
+ left = src.transform.c
221
+ top = src.transform.f
222
+ # Resolutions in projection units per pixel.
223
+ x_resolution = src.transform.a
224
+ y_resolution = src.transform.e
225
+ start = (int(left / x_resolution), int(top / y_resolution))
226
+ shp = shapely.box(
227
+ start[0], start[1], start[0] + src.width, start[1] + src.height
228
+ )
229
+ projection = Projection(crs, x_resolution, y_resolution)
230
+ geometry = STGeometry(projection, shp, None)
231
+
232
+ if geometry.is_too_large():
233
+ geometry = get_global_geometry(time_range=None)
234
+ logger.warning(
235
+ "Global geometry detected: this geometry will be matched against all "
236
+ "windows in the rslearn dataset. When using settings like "
237
+ "max_matches=1 and space_mode=MOSAIC, this may cause windows outside "
238
+ "the geometry’s valid bounds to be materialized from the global raster "
239
+ "instead of a more appropriate source. Consider using COMPOSITE mode, "
240
+ "or increasing max_matches if this behavior is unintended."
241
+ )
213
242
 
214
243
  if spec.name:
215
244
  item_name = spec.name
216
245
  else:
217
- item_name = spec.fnames[0].name.split(".")[0]
246
+ item_name = fname.name.split(".")[0]
218
247
 
219
- items.append(RasterItem(item_name, geometry, spec))
248
+ logger.debug(
249
+ "RasterImporter.list_items: got bounds of %s: %s", item_name, geometry
250
+ )
251
+ items.append(RasterItem(item_name, geometry, str(src_dir), spec))
220
252
 
253
+ logger.debug("RasterImporter.list_items: discovered %d items", len(items))
221
254
  return items
222
255
 
223
256
  def ingest_item(
224
257
  self,
225
- config: LayerConfig,
226
- tile_store: TileStore,
258
+ tile_store: TileStoreWithLayer,
227
259
  item: Item,
228
260
  cur_geometries: list[STGeometry],
229
- ):
261
+ ) -> None:
230
262
  """Ingest the specified local file item.
231
263
 
232
264
  Args:
233
- config: the configuration of the layer.
234
- tile_store: the TileStore to ingest the data into.
235
- item: the Item to ingest
265
+ tile_store: the tile store to ingest the data into.
266
+ item: the RasterItem to ingest
236
267
  cur_geometries: the geometries where the item is needed.
237
268
  """
238
269
  assert isinstance(item, RasterItem)
239
270
  for file_idx, fname in enumerate(item.spec.fnames):
240
- with fname.open("rb") as f:
241
- with rasterio.open(f) as src:
242
- if item.spec.bands:
243
- bands = item.spec.bands[file_idx]
244
- else:
245
- bands = [f"B{band_idx+1}" for band_idx in range(src.count)]
246
- cur_tile_store = PrefixedTileStore(tile_store, ("_".join(bands),))
247
- needed_projections = get_needed_projections(
248
- cur_tile_store, bands, config.band_sets, cur_geometries
249
- )
250
- if not needed_projections:
251
- return
252
-
253
- for projection in needed_projections:
254
- ingest_raster(
255
- tile_store=cur_tile_store,
256
- raster=src,
257
- projection=projection,
258
- time_range=item.geometry.time_range,
259
- layer_config=config,
260
- )
271
+ fname_upath = join_upath(UPath(item.src_dir), fname)
272
+ with open_rasterio_upath_reader(fname_upath) as src:
273
+ if item.spec.bands:
274
+ bands = item.spec.bands[file_idx]
275
+ else:
276
+ bands = [f"B{band_idx + 1}" for band_idx in range(src.count)]
277
+
278
+ if tile_store.is_raster_ready(item.name, bands):
279
+ continue
280
+ tile_store.write_raster_file(item.name, bands, fname_upath)
261
281
 
262
282
 
263
- @Importers.register("vector")
264
283
  class VectorImporter(Importer):
265
284
  """An Importer for vector data."""
266
285
 
267
- def list_items(self, config: LayerConfig, src_dir: UPath) -> list[Item]:
286
+ # We need some buffer around GeoJSON bounds in case it just contains one point.
287
+ item_buffer_epsilon = 1e-4
288
+
289
+ def list_items(self, src_dir: UPath) -> list[Item]:
268
290
  """Extract a list of Items from the source directory.
269
291
 
270
292
  Args:
271
- config: the configuration of the layer.
272
293
  src_dir: the source directory.
273
294
  """
274
295
  file_paths = src_dir.glob("**/*.*")
275
296
  items: list[Item] = []
276
297
 
277
298
  for path in file_paths:
299
+ # Ignore JSON files.
300
+ if path.name.endswith(".json"):
301
+ continue
302
+
278
303
  # Get the bounds of the features in the vector file, which we assume fiona can
279
304
  # read.
280
305
  # For shapefile, to open it we need to copy all the aux files.
281
306
  aux_files: list[UPath] = []
282
- if path.name.split(".")[-1] == "shp":
307
+ if path.name.endswith(".shp"):
283
308
  prefix = ".".join(path.name.split(".")[:-1])
284
309
  for ext in SHAPEFILE_AUX_EXTENSIONS:
285
310
  aux_files.append(path.parent / (prefix + ext))
@@ -299,45 +324,50 @@ class VectorImporter(Importer):
299
324
  bounds[2] = max(bounds[2], cur_bounds[2])
300
325
  bounds[3] = max(bounds[3], cur_bounds[3])
301
326
 
327
+ # Normal GeoJSON should have coordinates in CRS coordinates, i.e. it
328
+ # should be 1 projection unit/pixel.
302
329
  projection = Projection(crs, 1, 1)
303
- geometry = STGeometry(projection, shapely.box(*bounds), None)
330
+ geometry = STGeometry(
331
+ projection,
332
+ shapely.box(*bounds).buffer(self.item_buffer_epsilon),
333
+ None,
334
+ )
304
335
 
336
+ # There can be problems with GeoJSON files that have large spatial
337
+ # coverage, since the bounds may not re-project correctly to match
338
+ # windows that are using projections with limited validity.
339
+ # We check if there is a large spatial coverage here, and mark the
340
+ # item's geometry as having global coverage if so.
341
+ if geometry.is_too_large():
342
+ geometry = get_global_geometry(time_range=None)
343
+
344
+ logger.debug(
345
+ "VectorImporter.list_items: got bounds of %s: %s", path, geometry
346
+ )
305
347
  items.append(
306
348
  VectorItem(path.name.split(".")[0], geometry, path.absolute().as_uri())
307
349
  )
308
350
 
351
+ logger.debug("VectorImporter.list_items: discovered %d items", len(items))
309
352
  return items
310
353
 
311
354
  def ingest_item(
312
355
  self,
313
- config: LayerConfig,
314
- tile_store: TileStore,
356
+ tile_store: TileStoreWithLayer,
315
357
  item: Item,
316
358
  cur_geometries: list[STGeometry],
317
- ):
359
+ ) -> None:
318
360
  """Ingest the specified local file item.
319
361
 
320
362
  Args:
321
- config: the configuration of the layer.
322
363
  tile_store: the TileStore to ingest the data into.
323
364
  item: the Item to ingest
324
365
  cur_geometries: the geometries where the item is needed.
325
366
  """
326
- assert isinstance(config, VectorLayerConfig)
327
-
328
- needed_projections = set()
329
- for geometry in cur_geometries:
330
- projection, _ = config.get_final_projection_and_bounds(
331
- geometry.projection, None
332
- )
333
- ts_layer = tile_store.get_layer((str(projection),))
334
- if ts_layer and ts_layer.get_metadata().properties.get("completed"):
335
- continue
336
- needed_projections.add(projection)
337
-
338
- if not needed_projections:
367
+ if tile_store.is_vector_ready(item.name):
339
368
  return
340
369
 
370
+ assert isinstance(item, VectorItem)
341
371
  path = UPath(item.path_uri)
342
372
 
343
373
  aux_files: list[UPath] = []
@@ -347,14 +377,18 @@ class VectorImporter(Importer):
347
377
  aux_files.append(path.parent / (prefix + ext))
348
378
 
349
379
  # TODO: move converting fiona file to list[Feature] to utility function.
350
- # TODO: don't assume WGS-84 projection here.
351
380
  with get_upath_local(path, extra_paths=aux_files) as local_fname:
352
381
  with fiona.open(local_fname) as src:
382
+ crs = CRS.from_wkt(src.crs.to_wkt())
383
+ # Normal GeoJSON should have coordinates in CRS coordinates, i.e. it
384
+ # should be 1 projection unit/pixel.
385
+ projection = Projection(crs, 1, 1)
386
+
353
387
  features = []
354
388
  for feat in src:
355
389
  features.append(
356
390
  Feature.from_geojson(
357
- WGS84_PROJECTION,
391
+ projection,
358
392
  {
359
393
  "type": "Feature",
360
394
  "geometry": dict(feat.geometry),
@@ -363,37 +397,72 @@ class VectorImporter(Importer):
363
397
  )
364
398
  )
365
399
 
366
- for projection in needed_projections:
367
- cur_features = [feat.to_projection(projection) for feat in features]
368
- layer = tile_store.create_layer(
369
- (str(projection),),
370
- LayerMetadata(projection, None, {}),
371
- )
372
- layer.write_vector(cur_features)
373
- layer.set_property("completed", True)
400
+ tile_store.write_vector(item.name, features)
374
401
 
375
402
 
376
403
  class LocalFiles(DataSource):
377
404
  """A data source for ingesting data from local files."""
378
405
 
379
- def __init__(self, config: LayerConfig, src_dir: UPath) -> None:
406
+ def __init__(
407
+ self,
408
+ src_dir: str,
409
+ raster_item_specs: list[RasterItemSpec] | None = None,
410
+ layer_type: LayerType | None = None,
411
+ context: DataSourceContext = DataSourceContext(),
412
+ ) -> None:
380
413
  """Initialize a new LocalFiles instance.
381
414
 
382
415
  Args:
383
- config: configuration for this layer.
384
416
  src_dir: source directory to ingest
417
+ raster_item_specs: the specs to specify the raster items directly. If None,
418
+ the raster items are automatically detected from the files in the
419
+ source directory.
420
+ layer_type: the layer type. It only needs to be set if the layer_config is
421
+ missing from the context.
422
+ context: the data source context. The layer config must be in the context.
385
423
  """
386
- self.config = config
387
- self.src_dir = src_dir
424
+ if context.ds_path is not None:
425
+ self.src_dir = join_upath(context.ds_path, src_dir)
426
+ else:
427
+ self.src_dir = UPath(src_dir)
388
428
 
389
- self.importer: Importer = Importers[config.layer_type.value]
390
- self.items = self.importer.list_items(config, src_dir)
429
+ # Determine layer type.
430
+ if context.layer_config is not None:
431
+ self.layer_type = context.layer_config.type
432
+ elif layer_type is not None:
433
+ self.layer_type = layer_type
434
+ else:
435
+ raise ValueError(
436
+ "layer type must be specified if the layer config is not in the context"
437
+ )
391
438
 
392
- @staticmethod
393
- def from_config(config: LayerConfig, ds_path: UPath) -> "LocalFiles":
394
- """Creates a new LocalFiles instance from a configuration dictionary."""
395
- d = config.data_source.config_dict
396
- return LocalFiles(config=config, src_dir=join_upath(ds_path, d["src_dir"]))
439
+ self.importer: Importer
440
+ if self.layer_type == LayerType.RASTER:
441
+ self.importer = RasterImporter(item_specs=raster_item_specs)
442
+ elif self.layer_type == LayerType.VECTOR:
443
+ self.importer = VectorImporter()
444
+ else:
445
+ raise ValueError(f"unknown layer type {self.layer_type}")
446
+
447
+ @functools.cache
448
+ def list_items(self) -> list[Item]:
449
+ """Lists items from the source directory while maintaining a cache file."""
450
+ cache_fname = self.src_dir / "summary.json"
451
+ if not cache_fname.exists():
452
+ logger.debug("cache at %s does not exist, listing items", cache_fname)
453
+ items = self.importer.list_items(self.src_dir)
454
+ serialized_items = [item.serialize() for item in items]
455
+ with cache_fname.open("w") as f:
456
+ json.dump(serialized_items, f)
457
+ return items
458
+
459
+ logger.debug("loading item list from cache at %s", cache_fname)
460
+ with cache_fname.open() as f:
461
+ serialized_items = json.load(f)
462
+ return [
463
+ self.deserialize_item(serialized_item)
464
+ for serialized_item in serialized_items
465
+ ]
397
466
 
398
467
  def get_items(
399
468
  self, geometries: list[STGeometry], query_config: QueryConfig
@@ -410,7 +479,7 @@ class LocalFiles(DataSource):
410
479
  groups = []
411
480
  for geometry in geometries:
412
481
  cur_items = []
413
- for item in self.items:
482
+ for item in self.list_items():
414
483
  if not item.geometry.intersects(geometry):
415
484
  continue
416
485
  cur_items.append(item)
@@ -421,17 +490,18 @@ class LocalFiles(DataSource):
421
490
  groups.append(cur_groups)
422
491
  return groups
423
492
 
424
- def deserialize_item(self, serialized_item: Any) -> Item:
493
+ def deserialize_item(self, serialized_item: Any) -> RasterItem | VectorItem:
425
494
  """Deserializes an item from JSON-decoded data."""
426
- assert isinstance(serialized_item, dict)
427
- if self.config.layer_type == LayerType.RASTER:
495
+ if self.layer_type == LayerType.RASTER:
428
496
  return RasterItem.deserialize(serialized_item)
429
- elif self.config.layer_type == LayerType.VECTOR:
497
+ elif self.layer_type == LayerType.VECTOR:
430
498
  return VectorItem.deserialize(serialized_item)
499
+ else:
500
+ raise ValueError(f"Unknown layer type: {self.layer_type}")
431
501
 
432
502
  def ingest(
433
503
  self,
434
- tile_store: TileStore,
504
+ tile_store: TileStoreWithLayer,
435
505
  items: list[Item],
436
506
  geometries: list[list[STGeometry]],
437
507
  ) -> None:
@@ -443,5 +513,4 @@ class LocalFiles(DataSource):
443
513
  geometries: a list of geometries needed for each item
444
514
  """
445
515
  for item, cur_geometries in zip(items, geometries):
446
- cur_tile_store = PrefixedTileStore(tile_store, (item.name,))
447
- self.importer.ingest_item(self.config, cur_tile_store, item, cur_geometries)
516
+ self.importer.ingest_item(tile_store, item, cur_geometries)