rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. rslearn/arg_parser.py +31 -0
  2. rslearn/config/__init__.py +6 -12
  3. rslearn/config/dataset.py +520 -401
  4. rslearn/const.py +9 -15
  5. rslearn/data_sources/__init__.py +8 -23
  6. rslearn/data_sources/aws_landsat.py +242 -98
  7. rslearn/data_sources/aws_open_data.py +111 -151
  8. rslearn/data_sources/aws_sentinel1.py +131 -0
  9. rslearn/data_sources/climate_data_store.py +471 -0
  10. rslearn/data_sources/copernicus.py +884 -12
  11. rslearn/data_sources/data_source.py +43 -12
  12. rslearn/data_sources/earthdaily.py +484 -0
  13. rslearn/data_sources/earthdata_srtm.py +282 -0
  14. rslearn/data_sources/eurocrops.py +242 -0
  15. rslearn/data_sources/gcp_public_data.py +578 -222
  16. rslearn/data_sources/google_earth_engine.py +461 -135
  17. rslearn/data_sources/local_files.py +219 -150
  18. rslearn/data_sources/openstreetmap.py +51 -89
  19. rslearn/data_sources/planet.py +24 -60
  20. rslearn/data_sources/planet_basemap.py +275 -0
  21. rslearn/data_sources/planetary_computer.py +798 -0
  22. rslearn/data_sources/usda_cdl.py +195 -0
  23. rslearn/data_sources/usgs_landsat.py +115 -83
  24. rslearn/data_sources/utils.py +249 -61
  25. rslearn/data_sources/vector_source.py +1 -0
  26. rslearn/data_sources/worldcereal.py +449 -0
  27. rslearn/data_sources/worldcover.py +144 -0
  28. rslearn/data_sources/worldpop.py +153 -0
  29. rslearn/data_sources/xyz_tiles.py +150 -107
  30. rslearn/dataset/__init__.py +8 -2
  31. rslearn/dataset/add_windows.py +2 -2
  32. rslearn/dataset/dataset.py +40 -51
  33. rslearn/dataset/handler_summaries.py +131 -0
  34. rslearn/dataset/manage.py +313 -74
  35. rslearn/dataset/materialize.py +431 -107
  36. rslearn/dataset/remap.py +29 -4
  37. rslearn/dataset/storage/__init__.py +1 -0
  38. rslearn/dataset/storage/file.py +202 -0
  39. rslearn/dataset/storage/storage.py +140 -0
  40. rslearn/dataset/window.py +181 -44
  41. rslearn/lightning_cli.py +454 -0
  42. rslearn/log_utils.py +24 -0
  43. rslearn/main.py +384 -181
  44. rslearn/models/anysat.py +215 -0
  45. rslearn/models/attention_pooling.py +177 -0
  46. rslearn/models/clay/clay.py +231 -0
  47. rslearn/models/clay/configs/metadata.yaml +295 -0
  48. rslearn/models/clip.py +68 -0
  49. rslearn/models/component.py +111 -0
  50. rslearn/models/concatenate_features.py +103 -0
  51. rslearn/models/conv.py +63 -0
  52. rslearn/models/croma.py +306 -0
  53. rslearn/models/detr/__init__.py +5 -0
  54. rslearn/models/detr/box_ops.py +103 -0
  55. rslearn/models/detr/detr.py +504 -0
  56. rslearn/models/detr/matcher.py +107 -0
  57. rslearn/models/detr/position_encoding.py +114 -0
  58. rslearn/models/detr/transformer.py +429 -0
  59. rslearn/models/detr/util.py +24 -0
  60. rslearn/models/dinov3.py +177 -0
  61. rslearn/models/faster_rcnn.py +30 -28
  62. rslearn/models/feature_center_crop.py +53 -0
  63. rslearn/models/fpn.py +19 -8
  64. rslearn/models/galileo/__init__.py +5 -0
  65. rslearn/models/galileo/galileo.py +595 -0
  66. rslearn/models/galileo/single_file_galileo.py +1678 -0
  67. rslearn/models/module_wrapper.py +65 -0
  68. rslearn/models/molmo.py +69 -0
  69. rslearn/models/multitask.py +384 -28
  70. rslearn/models/olmoearth_pretrain/__init__.py +1 -0
  71. rslearn/models/olmoearth_pretrain/model.py +421 -0
  72. rslearn/models/olmoearth_pretrain/norm.py +86 -0
  73. rslearn/models/panopticon.py +170 -0
  74. rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
  75. rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
  76. rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
  77. rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
  78. rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
  79. rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
  80. rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
  81. rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
  82. rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
  83. rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
  84. rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
  85. rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
  86. rslearn/models/pick_features.py +17 -10
  87. rslearn/models/pooling_decoder.py +60 -7
  88. rslearn/models/presto/__init__.py +5 -0
  89. rslearn/models/presto/presto.py +297 -0
  90. rslearn/models/presto/single_file_presto.py +926 -0
  91. rslearn/models/prithvi.py +1147 -0
  92. rslearn/models/resize_features.py +59 -0
  93. rslearn/models/sam2_enc.py +13 -9
  94. rslearn/models/satlaspretrain.py +38 -18
  95. rslearn/models/simple_time_series.py +188 -77
  96. rslearn/models/singletask.py +24 -13
  97. rslearn/models/ssl4eo_s12.py +40 -30
  98. rslearn/models/swin.py +44 -32
  99. rslearn/models/task_embedding.py +250 -0
  100. rslearn/models/terramind.py +256 -0
  101. rslearn/models/trunk.py +139 -0
  102. rslearn/models/unet.py +68 -22
  103. rslearn/models/upsample.py +48 -0
  104. rslearn/models/use_croma.py +508 -0
  105. rslearn/template_params.py +26 -0
  106. rslearn/tile_stores/__init__.py +41 -18
  107. rslearn/tile_stores/default.py +409 -0
  108. rslearn/tile_stores/tile_store.py +236 -132
  109. rslearn/train/all_patches_dataset.py +530 -0
  110. rslearn/train/callbacks/adapters.py +53 -0
  111. rslearn/train/callbacks/freeze_unfreeze.py +348 -17
  112. rslearn/train/callbacks/gradients.py +129 -0
  113. rslearn/train/callbacks/peft.py +116 -0
  114. rslearn/train/data_module.py +444 -20
  115. rslearn/train/dataset.py +588 -235
  116. rslearn/train/lightning_module.py +192 -62
  117. rslearn/train/model_context.py +88 -0
  118. rslearn/train/optimizer.py +31 -0
  119. rslearn/train/prediction_writer.py +319 -84
  120. rslearn/train/scheduler.py +92 -0
  121. rslearn/train/tasks/classification.py +55 -28
  122. rslearn/train/tasks/detection.py +132 -76
  123. rslearn/train/tasks/embedding.py +120 -0
  124. rslearn/train/tasks/multi_task.py +28 -14
  125. rslearn/train/tasks/per_pixel_regression.py +291 -0
  126. rslearn/train/tasks/regression.py +161 -44
  127. rslearn/train/tasks/segmentation.py +428 -53
  128. rslearn/train/tasks/task.py +6 -5
  129. rslearn/train/transforms/__init__.py +1 -1
  130. rslearn/train/transforms/concatenate.py +54 -10
  131. rslearn/train/transforms/crop.py +29 -11
  132. rslearn/train/transforms/flip.py +18 -6
  133. rslearn/train/transforms/mask.py +78 -0
  134. rslearn/train/transforms/normalize.py +101 -17
  135. rslearn/train/transforms/pad.py +19 -7
  136. rslearn/train/transforms/resize.py +83 -0
  137. rslearn/train/transforms/select_bands.py +76 -0
  138. rslearn/train/transforms/sentinel1.py +75 -0
  139. rslearn/train/transforms/transform.py +89 -70
  140. rslearn/utils/__init__.py +2 -6
  141. rslearn/utils/array.py +8 -6
  142. rslearn/utils/feature.py +2 -2
  143. rslearn/utils/fsspec.py +90 -1
  144. rslearn/utils/geometry.py +347 -7
  145. rslearn/utils/get_utm_ups_crs.py +2 -3
  146. rslearn/utils/grid_index.py +5 -5
  147. rslearn/utils/jsonargparse.py +178 -0
  148. rslearn/utils/mp.py +4 -3
  149. rslearn/utils/raster_format.py +268 -116
  150. rslearn/utils/rtree_index.py +64 -17
  151. rslearn/utils/sqlite_index.py +7 -1
  152. rslearn/utils/vector_format.py +252 -97
  153. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
  154. rslearn-0.0.21.dist-info/RECORD +167 -0
  155. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
  156. rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
  157. rslearn/data_sources/raster_source.py +0 -309
  158. rslearn/models/registry.py +0 -5
  159. rslearn/tile_stores/file.py +0 -242
  160. rslearn/utils/mgrs.py +0 -24
  161. rslearn/utils/utils.py +0 -22
  162. rslearn-0.0.1.dist-info/RECORD +0 -88
  163. /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
  164. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
  165. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
  166. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/utils/sqlite_index.py
@@ -1,4 +1,10 @@
- """Contains a SpatialIndex implementation that uses an sqlite database."""
+ """Contains a SpatialIndex implementation that uses an sqlite database.
+
+ # TODO: This is not yet complete decide to either complete it or remove this file.
+ """
+
+ # Ignoring Mypy until we determine if we want to keep this file.
+ # mypy: ignore-errors

  import json
  import sqlite3
rslearn/utils/vector_format.py
@@ -1,20 +1,21 @@
  """Classes for writing vector data to a UPath."""

  import json
+ from enum import Enum
  from typing import Any

- import numpy as np
  import shapely
- from class_registry import ClassRegistry
+ from rasterio.crs import CRS
  from upath import UPath

- from rslearn.config import VectorFormatConfig
  from rslearn.const import WGS84_PROJECTION
+ from rslearn.log_utils import get_logger
+ from rslearn.utils.fsspec import open_atomic

  from .feature import Feature
- from .geometry import PixelBounds, Projection, STGeometry
+ from .geometry import PixelBounds, Projection, STGeometry, safely_reproject_and_clip

- VectorFormats = ClassRegistry()
+ logger = get_logger(__name__)


  class VectorFormat:
@@ -24,32 +25,45 @@ class VectorFormat:
      a UPath. Vector data is a list of GeoJSON-like features.
      """

-     def encode_vector(
-         self, path: UPath, projection: Projection, features: list[Feature]
-     ) -> None:
+     def encode_vector(self, path: UPath, features: list[Feature]) -> None:
          """Encodes vector data.

          Args:
              path: the directory to write to
-             projection: the projection of the raster data
              features: the vector data
          """
          raise NotImplementedError

-     def decode_vector(self, path: UPath, bounds: PixelBounds) -> list[Feature]:
+     def decode_vector(
+         self, path: UPath, projection: Projection, bounds: PixelBounds
+     ) -> list[Feature]:
          """Decodes vector data.

          Args:
              path: the directory to read from
-             bounds: the bounds of the vector data to read
+             projection: the projection to read the data in
+             bounds: the bounds to read under the given projection. Only features that
+                 intersect the bounds should be returned.

          Returns:
              the vector data
          """
          raise NotImplementedError

+     @staticmethod
+     def from_config(name: str, config: dict[str, Any]) -> "VectorFormat":
+         """Create a VectorFormat from a config dict.
+
+         Args:
+             name: the name of this format
+             config: the config dict
+
+         Returns:
+             the VectorFormat instance
+         """
+         raise NotImplementedError
+

- @VectorFormats.register("tile")
  class TileVectorFormat(VectorFormat):
      """TileVectorFormat stores data in GeoJSON files corresponding to grid cells.

@@ -58,29 +72,62 @@ class TileVectorFormat(VectorFormat):
      intersect.
      """

-     def __init__(self, tile_size: int = 512):
+     def __init__(
+         self,
+         tile_size: int = 512,
+         projection: Projection | None = None,
+         index_property_name: str = "tvf_index",
+     ):
          """Initialize a new TileVectorFormat instance.

          Args:
              tile_size: the tile size (grid size in pixels), default 512
+             projection: if set, store features under this projection. Otherwise, the
+                 output projection is taken from the first feature in an encode_vector
+                 call.
+             index_property_name: property name used to store an index integer that
+                 identifies the same feature across different tiles.
          """
          self.tile_size = tile_size
+         self.projection = projection
+         self.index_property_name = index_property_name

-     def encode_vector(
-         self, path: UPath, projection: Projection, features: list[Feature]
-     ) -> None:
+     def encode_vector(self, path: UPath, features: list[Feature]) -> None:
          """Encodes vector data.

          Args:
              path: the directory to write to
-             projection: the projection of the raster data
              features: the vector data
          """
-         tile_data = {}
-         for feat in features:
+         # Determine the output projection to write in.
+         if len(features) == 0:
+             # We won't actually write any features but still setting output_projection
+             # to write to projection.json.
+             # We just fallback to WGS84 here.
+             output_projection = WGS84_PROJECTION
+         elif self.projection is not None:
+             output_projection = self.projection
+         else:
+             output_projection = features[0].geometry.projection
+
+         # Save metadata file containing the serialized projection so we can load it
+         # when decoding.
+         with open_atomic(path / "projection.json", "w") as f:
+             json.dump(output_projection.serialize(), f)
+
+         # Dictionary from tile (col, row) to the list of features intersecting that
+         # tile. We iterate over the features to populate tile_data, then write each
+         # tile as a separate file.
+         tile_data: dict[tuple[int, int], list[dict]] = {}
+
+         for feat_idx, feat in enumerate(features):
+             # Skip invalid features since they can cause errors.
              if not feat.geometry.shp.is_valid:
                  continue
-             bounds = feat.geometry.shp.bounds
+
+             # Identify each grid cell that this feature intersects.
+             geometry = feat.geometry.to_projection(output_projection)
+             bounds = geometry.shp.bounds
              start_tile = (
                  int(bounds[0]) // self.tile_size,
                  int(bounds[1]) // self.tile_size,
@@ -89,58 +136,73 @@ class TileVectorFormat(VectorFormat):
                  int(bounds[2]) // self.tile_size + 1,
                  int(bounds[3]) // self.tile_size + 1,
              )
+
+             # We add an index property to the features so when reading we can
+             # de-duplicate (in case we read multiple tiles that contain the same
+             # feature).
+             properties = {self.index_property_name: feat_idx}
+             properties.update(feat.properties)
+             # Use the re-projected geometry here.
+             output_feat = Feature(geometry, properties)
+             output_geojson = output_feat.to_geojson()
+
+             # Now we add the feature to each tile that it intersects.
              for col in range(start_tile[0], end_tile[0]):
                  for row in range(start_tile[1], end_tile[1]):
-                     cur_shp = feat.geometry.shp.intersection(
-                         shapely.box(
-                             col * self.tile_size,
-                             row * self.tile_size,
-                             (col + 1) * self.tile_size,
-                             (row + 1) * self.tile_size,
-                         )
+                     tile_box = shapely.box(
+                         col * self.tile_size,
+                         row * self.tile_size,
+                         (col + 1) * self.tile_size,
+                         (row + 1) * self.tile_size,
                      )
-                     cur_shp = shapely.transform(
-                         cur_shp,
-                         lambda array: array
-                         - np.array([[col * self.tile_size, row * self.tile_size]]),
-                     )
-                     cur_feat = Feature(
-                         STGeometry(projection, cur_shp, None), feat.properties
-                     )
-                     try:
-                         cur_geojson = cur_feat.to_geojson()
-                     except Exception as e:
-                         print(e)
+                     if not geometry.shp.intersects(tile_box):
                          continue
                      tile = (col, row)
                      if tile not in tile_data:
                          tile_data[tile] = []
-                     tile_data[tile].append(cur_geojson)
+                     tile_data[tile].append(output_geojson)

          path.mkdir(parents=True, exist_ok=True)
+
+         # Now save each tile.
          for (col, row), geojson_features in tile_data.items():
              fc = {
                  "type": "FeatureCollection",
                  "features": [geojson_feat for geojson_feat in geojson_features],
-                 "properties": projection.serialize(),
+                 "properties": output_projection.serialize(),
              }
-             with (path / f"{col}_{row}.geojson").open("w") as f:
+             cur_fname = path / f"{col}_{row}.geojson"
+             logger.debug("writing tile (%d, %d) to %s", col, row, cur_fname)
+             with open_atomic(cur_fname, "w") as f:
                  json.dump(fc, f)

-     def decode_vector(self, path: UPath, bounds: PixelBounds) -> list[Feature]:
+     def decode_vector(
+         self, path: UPath, projection: Projection, bounds: PixelBounds
+     ) -> list[Feature]:
          """Decodes vector data.

          Args:
              path: the directory to read from
-             bounds: the bounds of the vector data to read
+             projection: the projection to read the data in
+             bounds: the bounds to read under the given projection. Only features that
+                 intersect the bounds should be returned.

          Returns:
              the vector data
          """
-         start_tile = (bounds[0] // self.tile_size, bounds[1] // self.tile_size)
+         # Convert the bounds to the projection of the stored data.
+         with (path / "projection.json").open() as f:
+             storage_projection = Projection.deserialize(json.load(f))
+         bounds_geom = STGeometry(projection, shapely.box(*bounds), None)
+         storage_bounds = bounds_geom.to_projection(storage_projection).shp.bounds
+
+         start_tile = (
+             int(storage_bounds[0]) // self.tile_size,
+             int(storage_bounds[1]) // self.tile_size,
+         )
          end_tile = (
-             (bounds[2] - 1) // self.tile_size + 1,
-             (bounds[3] - 1) // self.tile_size + 1,
+             (int(storage_bounds[2]) - 1) // self.tile_size + 1,
+             (int(storage_bounds[3]) - 1) // self.tile_size + 1,
          )
          features = []
          for col in range(start_tile[0], end_tile[0]):
@@ -148,22 +210,12 @@ class TileVectorFormat(VectorFormat):
                  cur_fname = path / f"{col}_{row}.geojson"
                  if not cur_fname.exists():
                      continue
-                 with cur_fname.open("r") as f:
+                 with cur_fname.open() as f:
                      fc = json.load(f)
-                 if "properties" in fc and "crs" in fc["properties"]:
-                     projection = Projection.deserialize(fc["properties"])
-                 else:
-                     projection = WGS84_PROJECTION
-
-                 for feat in fc["features"]:
-                     shp = shapely.geometry.shape(feat["geometry"])
-                     shp = shapely.transform(
-                         shp,
-                         lambda array: array
-                         + np.array([[col * self.tile_size, row * self.tile_size]]),
-                     )
-                     feat["geometry"] = json.loads(shapely.to_geojson(shp))
-                     features.append(Feature.from_geojson(projection, feat))
+
+                 for geojson_feat in fc["features"]:
+                     feat = Feature.from_geojson(storage_projection, geojson_feat)
+                     features.append(feat.to_projection(projection))
          return features

      @staticmethod
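
The hunks above make TileVectorFormat projection-aware: encode_vector now writes a projection.json sidecar plus per-tile GeoJSONs in a single storage projection, and decode_vector re-projects stored features into whatever projection the caller requests. The following is a minimal round-trip sketch against these new 0.0.21 signatures; the output directory and feature contents are hypothetical, and the import locations follow this file's import hunk.

import shapely
from upath import UPath

from rslearn.const import WGS84_PROJECTION
from rslearn.utils.feature import Feature
from rslearn.utils.geometry import STGeometry
from rslearn.utils.vector_format import TileVectorFormat

# Pin the storage projection; otherwise the first feature's projection is used.
fmt = TileVectorFormat(tile_size=512, projection=WGS84_PROJECTION)

# One box feature expressed in the pixel coordinates of WGS84_PROJECTION.
features = [
    Feature(
        STGeometry(WGS84_PROJECTION, shapely.box(10, 20, 11, 21), None),
        {"name": "example"},
    )
]

out_dir = UPath("/tmp/tile_layer")  # hypothetical output directory
fmt.encode_vector(out_dir, features)  # writes projection.json plus 0_0.geojson etc.

# Read back in the same projection; bounds are pixel bounds under that projection.
# Decoded features also carry the extra tvf_index property added on write.
decoded = fmt.decode_vector(out_dir, WGS84_PROJECTION, (0, 0, 512, 512))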
@@ -177,54 +229,168 @@ class TileVectorFormat(VectorFormat):
          Returns:
              the TileVectorFormat
          """
-         return TileVectorFormat(tile_size=config.get("tile_size", 512))
+         kwargs = {}
+         if "tile_size" in config:
+             kwargs["tile_size"] = config["tile_size"]
+         if "projection" in config:
+             kwargs["projection"] = Projection.deserialize(config["projection"])
+         if "index_property_name" in config:
+             kwargs["index_property_name"] = config["index_property_name"]
+         return TileVectorFormat(**kwargs)
+
+
+ class GeojsonCoordinateMode(Enum):
+     """The projection to use when writing GeoJSON file."""
+
+     # Write the features as is.
+     PIXEL = "pixel"
+
+     # Write the features in CRS coordinates (i.e., a projection with x_resolution=1 and
+     # y_resolution=1).
+     CRS = "crs"
+
+     # Write in WGS84 (longitude, latitude) coordinates.
+     WGS84 = "wgs84"


- @VectorFormats.register("geojson")
  class GeojsonVectorFormat(VectorFormat):
      """A vector format that uses one big GeoJSON."""

      fname = "data.geojson"

-     def encode_vector(
-         self, path: UPath, projection: Projection, features: list[Feature]
-     ) -> None:
+     def __init__(
+         self, coordinate_mode: GeojsonCoordinateMode = GeojsonCoordinateMode.PIXEL
+     ):
+         """Create a new GeojsonVectorFormat.
+
+         Args:
+             coordinate_mode: the projection to use for coordinates written to the
+                 GeoJSON files. PIXEL means we write them as is, CRS means we just undo
+                 the resolution in the Projection so they are in CRS coordinates, and
+                 WGS84 means we always write longitude/latitude. When using PIXEL, the
+                 GeoJSON will not be readable by GIS tools since it relies on a custom
+                 encoding.
+         """
+         self.coordinate_mode = coordinate_mode
+
+     def encode_to_file(self, fname: UPath, features: list[Feature]) -> None:
+         """Encode vector data to a specific file.
+
+         Args:
+             fname: the file to write to
+             features: the vector data
+         """
+         fc: dict[str, Any] = {"type": "FeatureCollection"}
+
+         # Identify target projection and convert features.
+         # Also set the target projection in the FeatureCollection.
+         # For PIXEL mode, we need to use a custom encoding so the resolution is stored.
+         output_projection: Projection
+         if len(features) > 0 and self.coordinate_mode != GeojsonCoordinateMode.WGS84:
+             if self.coordinate_mode == GeojsonCoordinateMode.PIXEL:
+                 output_projection = features[0].geometry.projection
+                 fc["properties"] = output_projection.serialize()
+             elif self.coordinate_mode == GeojsonCoordinateMode.CRS:
+                 output_projection = Projection(
+                     features[0].geometry.projection.crs, 1, 1
+                 )
+                 fc["crs"] = {
+                     "type": "name",
+                     "properties": {
+                         "name": output_projection.crs.to_wkt(),
+                     },
+                 }
+         else:
+             # Either there are no features so we need to fallback to WGS84, or the
+             # coordinate mode is WGS84.
+             output_projection = WGS84_PROJECTION
+             fc["crs"] = {
+                 "type": "name",
+                 "properties": {
+                     "name": output_projection.crs.to_wkt(),
+                 },
+             }
+
+         fc["features"] = []
+         for feat in features:
+             feat = feat.to_projection(output_projection)
+             fc["features"].append(feat.to_geojson())
+
+         logger.debug(
+             "writing features to %s with coordinate mode %s",
+             fname,
+             self.coordinate_mode,
+         )
+         with open_atomic(fname, "w") as f:
+             json.dump(fc, f)
+
+     def encode_vector(self, path: UPath, features: list[Feature]) -> None:
          """Encodes vector data.

          Args:
              path: the directory to write to
-             projection: the projection of the raster data
              features: the vector data
          """
          path.mkdir(parents=True, exist_ok=True)
-         with (path / self.fname).open("w") as f:
-             json.dump(
-                 {
-                     "type": "FeatureCollection",
-                     "features": [feat.to_geojson() for feat in features],
-                     "properties": projection.serialize(),
-                 },
-                 f,
-             )
+         self.encode_to_file(path / self.fname, features)

-     def decode_vector(self, path: UPath, bounds: PixelBounds) -> list[Feature]:
-         """Decodes vector data.
+     def decode_from_file(self, fname: UPath) -> list[Feature]:
+         """Decodes vector data from a filename.

          Args:
-             path: the directory to read from
-             bounds: the bounds of the vector data to read
+             fname: the filename to read.

          Returns:
              the vector data
          """
-         with (path / self.fname).open("r") as f:
+         with fname.open() as f:
              fc = json.load(f)
+
+         # Detect the projection that the features are stored under.
          if "properties" in fc and "crs" in fc["properties"]:
+             # Means it uses our custom Projection encoding.
              projection = Projection.deserialize(fc["properties"])
+         elif "crs" in fc:
+             # Means it uses standard GeoJSON CRS encoding.
+             crs = CRS.from_string(fc["crs"]["properties"]["name"])
+             projection = Projection(crs, 1, 1)
          else:
+             # Otherwise it should be WGS84 (GeoJSONs created in rslearn should include
+             # the "crs" attribute, but maybe it was created externally).
              projection = WGS84_PROJECTION
+
          return [Feature.from_geojson(projection, feat) for feat in fc["features"]]

+     def decode_vector(
+         self, path: UPath, projection: Projection, bounds: PixelBounds
+     ) -> list[Feature]:
+         """Decodes vector data.
+
+         Args:
+             path: the directory to read from
+             projection: the projection to read the data in
+             bounds: the bounds to read under the given projection. Only features that
+                 intersect the bounds should be returned.
+
+         Returns:
+             the vector data
+         """
+         features = self.decode_from_file(path / self.fname)
+
+         # Re-project to the desired projection and clip to bounds.
+         dst_geom = STGeometry(projection, shapely.box(*bounds), None)
+         reprojected_geoms = safely_reproject_and_clip(
+             [feat.geometry for feat in features], dst_geom
+         )
+         reprojected_features = []
+         for feat, geom in zip(features, reprojected_geoms):
+             if geom is None:
+                 # None value means that it did not intersect the provided bounds.
+                 continue
+             reprojected_features.append(Feature(geom, feat.properties))
+
+         return reprojected_features
+
      @staticmethod
      def from_config(name: str, config: dict[str, Any]) -> "GeojsonVectorFormat":
          """Create a GeojsonVectorFormat from a config dict.
@@ -236,18 +402,7 @@ class GeojsonVectorFormat(VectorFormat):
          Returns:
              the GeojsonVectorFormat
          """
-         return GeojsonVectorFormat()
-
-
- def load_vector_format(config: VectorFormatConfig) -> VectorFormat:
-     """Loads a VectorFormat from a VectorFormatConfig.
-
-     Args:
-         config: the VectorFormatConfig configuration object specifying the
-             VectorFormat.
-
-     Returns:
-         the loaded VectorFormat implementation
-     """
-     cls = VectorFormats.get_class(config.name)
-     return cls.from_config(config.name, config.config_dict)
+         kwargs = {}
+         if "coordinate_mode" in config:
+             kwargs["coordinate_mode"] = GeojsonCoordinateMode(config["coordinate_mode"])
+         return GeojsonVectorFormat(**kwargs)
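
With the registry-based load_vector_format removed, each format now exposes a from_config static method, and GeojsonVectorFormat gained a coordinate_mode argument controlling how coordinates are written. A minimal usage sketch under those assumptions follows; the path, feature values, and config dict are hypothetical, and import locations follow this file's import hunk.

import shapely
from upath import UPath

from rslearn.const import WGS84_PROJECTION
from rslearn.utils.feature import Feature
from rslearn.utils.geometry import STGeometry
from rslearn.utils.vector_format import GeojsonCoordinateMode, GeojsonVectorFormat

# Two equivalent ways to construct the format: directly, or from a config dict
# (the enum accepts the string values "pixel", "crs", and "wgs84").
fmt = GeojsonVectorFormat(coordinate_mode=GeojsonCoordinateMode.WGS84)
fmt_from_config = GeojsonVectorFormat.from_config("geojson", {"coordinate_mode": "wgs84"})

features = [
    Feature(
        STGeometry(WGS84_PROJECTION, shapely.box(10, 20, 11, 21), None),
        {"name": "example"},
    )
]

out_dir = UPath("/tmp/geojson_layer")  # hypothetical output directory
fmt.encode_vector(out_dir, features)  # writes data.geojson with a standard "crs" member

# decode_vector re-projects to the requested projection and clips to the bounds via
# safely_reproject_and_clip; features that do not intersect the bounds are dropped.
decoded = fmt.decode_vector(out_dir, WGS84_PROJECTION, (0, 0, 64, 64))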