rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- rslearn/arg_parser.py +31 -0
- rslearn/config/__init__.py +6 -12
- rslearn/config/dataset.py +520 -401
- rslearn/const.py +9 -15
- rslearn/data_sources/__init__.py +8 -23
- rslearn/data_sources/aws_landsat.py +242 -98
- rslearn/data_sources/aws_open_data.py +111 -151
- rslearn/data_sources/aws_sentinel1.py +131 -0
- rslearn/data_sources/climate_data_store.py +471 -0
- rslearn/data_sources/copernicus.py +884 -12
- rslearn/data_sources/data_source.py +43 -12
- rslearn/data_sources/earthdaily.py +484 -0
- rslearn/data_sources/earthdata_srtm.py +282 -0
- rslearn/data_sources/eurocrops.py +242 -0
- rslearn/data_sources/gcp_public_data.py +578 -222
- rslearn/data_sources/google_earth_engine.py +461 -135
- rslearn/data_sources/local_files.py +219 -150
- rslearn/data_sources/openstreetmap.py +51 -89
- rslearn/data_sources/planet.py +24 -60
- rslearn/data_sources/planet_basemap.py +275 -0
- rslearn/data_sources/planetary_computer.py +798 -0
- rslearn/data_sources/usda_cdl.py +195 -0
- rslearn/data_sources/usgs_landsat.py +115 -83
- rslearn/data_sources/utils.py +249 -61
- rslearn/data_sources/vector_source.py +1 -0
- rslearn/data_sources/worldcereal.py +449 -0
- rslearn/data_sources/worldcover.py +144 -0
- rslearn/data_sources/worldpop.py +153 -0
- rslearn/data_sources/xyz_tiles.py +150 -107
- rslearn/dataset/__init__.py +8 -2
- rslearn/dataset/add_windows.py +2 -2
- rslearn/dataset/dataset.py +40 -51
- rslearn/dataset/handler_summaries.py +131 -0
- rslearn/dataset/manage.py +313 -74
- rslearn/dataset/materialize.py +431 -107
- rslearn/dataset/remap.py +29 -4
- rslearn/dataset/storage/__init__.py +1 -0
- rslearn/dataset/storage/file.py +202 -0
- rslearn/dataset/storage/storage.py +140 -0
- rslearn/dataset/window.py +181 -44
- rslearn/lightning_cli.py +454 -0
- rslearn/log_utils.py +24 -0
- rslearn/main.py +384 -181
- rslearn/models/anysat.py +215 -0
- rslearn/models/attention_pooling.py +177 -0
- rslearn/models/clay/clay.py +231 -0
- rslearn/models/clay/configs/metadata.yaml +295 -0
- rslearn/models/clip.py +68 -0
- rslearn/models/component.py +111 -0
- rslearn/models/concatenate_features.py +103 -0
- rslearn/models/conv.py +63 -0
- rslearn/models/croma.py +306 -0
- rslearn/models/detr/__init__.py +5 -0
- rslearn/models/detr/box_ops.py +103 -0
- rslearn/models/detr/detr.py +504 -0
- rslearn/models/detr/matcher.py +107 -0
- rslearn/models/detr/position_encoding.py +114 -0
- rslearn/models/detr/transformer.py +429 -0
- rslearn/models/detr/util.py +24 -0
- rslearn/models/dinov3.py +177 -0
- rslearn/models/faster_rcnn.py +30 -28
- rslearn/models/feature_center_crop.py +53 -0
- rslearn/models/fpn.py +19 -8
- rslearn/models/galileo/__init__.py +5 -0
- rslearn/models/galileo/galileo.py +595 -0
- rslearn/models/galileo/single_file_galileo.py +1678 -0
- rslearn/models/module_wrapper.py +65 -0
- rslearn/models/molmo.py +69 -0
- rslearn/models/multitask.py +384 -28
- rslearn/models/olmoearth_pretrain/__init__.py +1 -0
- rslearn/models/olmoearth_pretrain/model.py +421 -0
- rslearn/models/olmoearth_pretrain/norm.py +86 -0
- rslearn/models/panopticon.py +170 -0
- rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
- rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
- rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
- rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
- rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
- rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
- rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
- rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
- rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
- rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
- rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
- rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
- rslearn/models/pick_features.py +17 -10
- rslearn/models/pooling_decoder.py +60 -7
- rslearn/models/presto/__init__.py +5 -0
- rslearn/models/presto/presto.py +297 -0
- rslearn/models/presto/single_file_presto.py +926 -0
- rslearn/models/prithvi.py +1147 -0
- rslearn/models/resize_features.py +59 -0
- rslearn/models/sam2_enc.py +13 -9
- rslearn/models/satlaspretrain.py +38 -18
- rslearn/models/simple_time_series.py +188 -77
- rslearn/models/singletask.py +24 -13
- rslearn/models/ssl4eo_s12.py +40 -30
- rslearn/models/swin.py +44 -32
- rslearn/models/task_embedding.py +250 -0
- rslearn/models/terramind.py +256 -0
- rslearn/models/trunk.py +139 -0
- rslearn/models/unet.py +68 -22
- rslearn/models/upsample.py +48 -0
- rslearn/models/use_croma.py +508 -0
- rslearn/template_params.py +26 -0
- rslearn/tile_stores/__init__.py +41 -18
- rslearn/tile_stores/default.py +409 -0
- rslearn/tile_stores/tile_store.py +236 -132
- rslearn/train/all_patches_dataset.py +530 -0
- rslearn/train/callbacks/adapters.py +53 -0
- rslearn/train/callbacks/freeze_unfreeze.py +348 -17
- rslearn/train/callbacks/gradients.py +129 -0
- rslearn/train/callbacks/peft.py +116 -0
- rslearn/train/data_module.py +444 -20
- rslearn/train/dataset.py +588 -235
- rslearn/train/lightning_module.py +192 -62
- rslearn/train/model_context.py +88 -0
- rslearn/train/optimizer.py +31 -0
- rslearn/train/prediction_writer.py +319 -84
- rslearn/train/scheduler.py +92 -0
- rslearn/train/tasks/classification.py +55 -28
- rslearn/train/tasks/detection.py +132 -76
- rslearn/train/tasks/embedding.py +120 -0
- rslearn/train/tasks/multi_task.py +28 -14
- rslearn/train/tasks/per_pixel_regression.py +291 -0
- rslearn/train/tasks/regression.py +161 -44
- rslearn/train/tasks/segmentation.py +428 -53
- rslearn/train/tasks/task.py +6 -5
- rslearn/train/transforms/__init__.py +1 -1
- rslearn/train/transforms/concatenate.py +54 -10
- rslearn/train/transforms/crop.py +29 -11
- rslearn/train/transforms/flip.py +18 -6
- rslearn/train/transforms/mask.py +78 -0
- rslearn/train/transforms/normalize.py +101 -17
- rslearn/train/transforms/pad.py +19 -7
- rslearn/train/transforms/resize.py +83 -0
- rslearn/train/transforms/select_bands.py +76 -0
- rslearn/train/transforms/sentinel1.py +75 -0
- rslearn/train/transforms/transform.py +89 -70
- rslearn/utils/__init__.py +2 -6
- rslearn/utils/array.py +8 -6
- rslearn/utils/feature.py +2 -2
- rslearn/utils/fsspec.py +90 -1
- rslearn/utils/geometry.py +347 -7
- rslearn/utils/get_utm_ups_crs.py +2 -3
- rslearn/utils/grid_index.py +5 -5
- rslearn/utils/jsonargparse.py +178 -0
- rslearn/utils/mp.py +4 -3
- rslearn/utils/raster_format.py +268 -116
- rslearn/utils/rtree_index.py +64 -17
- rslearn/utils/sqlite_index.py +7 -1
- rslearn/utils/vector_format.py +252 -97
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
- rslearn-0.0.21.dist-info/RECORD +167 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
- rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
- rslearn/data_sources/raster_source.py +0 -309
- rslearn/models/registry.py +0 -5
- rslearn/tile_stores/file.py +0 -242
- rslearn/utils/mgrs.py +0 -24
- rslearn/utils/utils.py +0 -22
- rslearn-0.0.1.dist-info/RECORD +0 -88
- /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
- {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
|
@@ -1,9 +1,9 @@
|
|
|
1
|
-
Metadata-Version: 2.
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
2
|
Name: rslearn
|
|
3
|
-
Version: 0.0.1
|
|
3
|
+
Version: 0.0.21
|
|
4
4
|
Summary: A library for developing remote sensing datasets and models
|
|
5
|
-
Author
|
|
6
|
-
License:
|
|
5
|
+
Author: OlmoEarth Team
|
|
6
|
+
License: Apache License
|
|
7
7
|
Version 2.0, January 2004
|
|
8
8
|
http://www.apache.org/licenses/
|
|
9
9
|
|
|
@@ -205,36 +205,62 @@ License: Apache License
|
|
|
205
205
|
See the License for the specific language governing permissions and
|
|
206
206
|
limitations under the License.
|
|
207
207
|
|
|
208
|
-
|
|
208
|
+
Project-URL: homepage, https://github.com/allenai/rslearn
|
|
209
|
+
Project-URL: issues, https://github.com/allenai/rslearn/issues
|
|
210
|
+
Project-URL: repository, https://github.com/allenai/rslearn
|
|
211
|
+
Requires-Python: >=3.11
|
|
209
212
|
Description-Content-Type: text/markdown
|
|
210
213
|
License-File: LICENSE
|
|
211
|
-
|
|
212
|
-
Requires-Dist:
|
|
213
|
-
Requires-Dist:
|
|
214
|
-
Requires-Dist:
|
|
215
|
-
Requires-Dist:
|
|
216
|
-
Requires-Dist:
|
|
217
|
-
Requires-Dist: Pillow
|
|
218
|
-
Requires-Dist: pyproj
|
|
219
|
-
Requires-Dist:
|
|
220
|
-
Requires-Dist:
|
|
221
|
-
Requires-Dist:
|
|
222
|
-
Requires-Dist:
|
|
223
|
-
Requires-Dist:
|
|
224
|
-
Requires-Dist:
|
|
225
|
-
Requires-Dist:
|
|
214
|
+
License-File: NOTICE
|
|
215
|
+
Requires-Dist: boto3>=1.39
|
|
216
|
+
Requires-Dist: fiona>=1.10
|
|
217
|
+
Requires-Dist: fsspec>=2025.10.0
|
|
218
|
+
Requires-Dist: jsonargparse>=4.35.0
|
|
219
|
+
Requires-Dist: lightning>=2.5.1.post0
|
|
220
|
+
Requires-Dist: Pillow>=11.3
|
|
221
|
+
Requires-Dist: pyproj>=3.7
|
|
222
|
+
Requires-Dist: python-dateutil>=2.9
|
|
223
|
+
Requires-Dist: pytimeparse>=1.1
|
|
224
|
+
Requires-Dist: rasterio>=1.4
|
|
225
|
+
Requires-Dist: shapely>=2.1
|
|
226
|
+
Requires-Dist: torch>=2.7.0
|
|
227
|
+
Requires-Dist: torchvision>=0.22.0
|
|
228
|
+
Requires-Dist: tqdm>=4.67
|
|
229
|
+
Requires-Dist: universal_pathlib>=0.2.6
|
|
226
230
|
Provides-Extra: extra
|
|
227
|
-
Requires-Dist:
|
|
228
|
-
Requires-Dist:
|
|
229
|
-
Requires-Dist:
|
|
230
|
-
Requires-Dist:
|
|
231
|
-
Requires-Dist:
|
|
232
|
-
Requires-Dist:
|
|
233
|
-
Requires-Dist:
|
|
234
|
-
Requires-Dist:
|
|
235
|
-
Requires-Dist:
|
|
236
|
-
Requires-Dist:
|
|
237
|
-
Requires-Dist:
|
|
231
|
+
Requires-Dist: accelerate>=1.10; extra == "extra"
|
|
232
|
+
Requires-Dist: cdsapi>=0.7.6; extra == "extra"
|
|
233
|
+
Requires-Dist: earthdaily[platform]>=1.0.7; extra == "extra"
|
|
234
|
+
Requires-Dist: earthengine-api>=1.6.3; extra == "extra"
|
|
235
|
+
Requires-Dist: einops>=0.8; extra == "extra"
|
|
236
|
+
Requires-Dist: fsspec[gcs,s3]; extra == "extra"
|
|
237
|
+
Requires-Dist: google-cloud-bigquery>=3.35; extra == "extra"
|
|
238
|
+
Requires-Dist: google-cloud-storage>=2.18; extra == "extra"
|
|
239
|
+
Requires-Dist: huggingface_hub>=0.34.4; extra == "extra"
|
|
240
|
+
Requires-Dist: netCDF4>=1.7.2; extra == "extra"
|
|
241
|
+
Requires-Dist: osmium>=4.0.2; extra == "extra"
|
|
242
|
+
Requires-Dist: planet>=3.1; extra == "extra"
|
|
243
|
+
Requires-Dist: planetary_computer>=1.0; extra == "extra"
|
|
244
|
+
Requires-Dist: pycocotools>=2.0; extra == "extra"
|
|
245
|
+
Requires-Dist: pystac_client>=0.9; extra == "extra"
|
|
246
|
+
Requires-Dist: rtree>=1.4; extra == "extra"
|
|
247
|
+
Requires-Dist: termcolor>=3.0; extra == "extra"
|
|
248
|
+
Requires-Dist: satlaspretrain_models>=0.3; extra == "extra"
|
|
249
|
+
Requires-Dist: scipy>=1.16; extra == "extra"
|
|
250
|
+
Requires-Dist: terratorch>=1.0.2; extra == "extra"
|
|
251
|
+
Requires-Dist: transformers>=4.55; extra == "extra"
|
|
252
|
+
Requires-Dist: wandb>=0.21; extra == "extra"
|
|
253
|
+
Requires-Dist: timm>=0.9.7; extra == "extra"
|
|
254
|
+
Provides-Extra: dev
|
|
255
|
+
Requires-Dist: interrogate>=1.7.0; extra == "dev"
|
|
256
|
+
Requires-Dist: mypy<2,>=1.17.1; extra == "dev"
|
|
257
|
+
Requires-Dist: pre-commit>=4.3.0; extra == "dev"
|
|
258
|
+
Requires-Dist: pytest>=8.0; extra == "dev"
|
|
259
|
+
Requires-Dist: pytest_httpserver; extra == "dev"
|
|
260
|
+
Requires-Dist: ruff>=0.12.9; extra == "dev"
|
|
261
|
+
Requires-Dist: pytest-dotenv; extra == "dev"
|
|
262
|
+
Requires-Dist: pytest-xdist; extra == "dev"
|
|
263
|
+
Dynamic: license-file
|
|
238
264
|
|
|
239
265
|
Overview
|
|
240
266
|
--------
|
|
@@ -254,10 +280,12 @@ rslearn helps with:
|
|
|
254
280
|
|
|
255
281
|
|
|
256
282
|
Quick links:
|
|
257
|
-
- [CoreConcepts](CoreConcepts.md) summarizes key concepts in rslearn, including
|
|
283
|
+
- [CoreConcepts](docs/CoreConcepts.md) summarizes key concepts in rslearn, including
|
|
258
284
|
datasets, windows, layers, and data sources.
|
|
259
|
-
- [Examples](Examples.md) contains more examples, including customizing different
|
|
285
|
+
- [Examples](docs/Examples.md) contains more examples, including customizing different
|
|
260
286
|
stages of rslearn with additional code.
|
|
287
|
+
- [DatasetConfig](docs/DatasetConfig.md) documents the dataset configuration file.
|
|
288
|
+
- [ModelConfig](docs/ModelConfig.md) documents the model configuration file.
|
|
261
289
|
|
|
262
290
|
|
|
263
291
|
Setup
|
|
@@ -265,9 +293,33 @@ Setup
|
|
|
265
293
|
|
|
266
294
|
rslearn requires Python 3.11+ (Python 3.12 is recommended).
|
|
267
295
|
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
296
|
+
```
|
|
297
|
+
git clone https://github.com/allenai/rslearn.git
|
|
298
|
+
cd rslearn
|
|
299
|
+
pip install .[extra]
|
|
300
|
+
```
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
Supported Data Sources
|
|
304
|
+
----------------------
|
|
305
|
+
|
|
306
|
+
rslearn supports ingesting raster and vector data from the following data sources. Even
|
|
307
|
+
if you don't plan to train models within rslearn, you can still use it to easily
|
|
308
|
+
download, crop, and re-project data based on spatiotemporal rectangles (windows) that
|
|
309
|
+
you define. See [Examples](docs/Examples.md) and [DatasetConfig](docs/DatasetConfig.md)
|
|
310
|
+
for how to set up these data sources.
|
|
311
|
+
|
|
312
|
+
- Sentinel-1
|
|
313
|
+
- Sentinel-2 L1C and L2A
|
|
314
|
+
- Landsat 8/9 OLI-TIRS
|
|
315
|
+
- National Agriculture Imagery Program
|
|
316
|
+
- OpenStreetMap
|
|
317
|
+
- Xyz (Slippy) Tiles (e.g., Mapbox tiles)
|
|
318
|
+
- Planet Labs (PlanetScope, SkySat)
|
|
319
|
+
- ESA WorldCover 2021
|
|
320
|
+
|
|
321
|
+
rslearn can also be used to easily mosaic, crop, and re-project any sets of local
|
|
322
|
+
raster and vector files you may have.
|
|
271
323
|
|
|
272
324
|
|
|
273
325
|
Example Usage
|
|
@@ -281,28 +333,27 @@ Let's start by defining a region of interest and obtaining Sentinel-2 images. Cr
|
|
|
281
333
|
directory `/path/to/dataset` and corresponding configuration file at
|
|
282
334
|
`/path/to/dataset/config.json` as follows:
|
|
283
335
|
|
|
284
|
-
|
|
285
|
-
|
|
286
|
-
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
|
|
336
|
+
```json
|
|
337
|
+
{
|
|
338
|
+
"layers": {
|
|
339
|
+
"sentinel2": {
|
|
340
|
+
"type": "raster",
|
|
341
|
+
"band_sets": [{
|
|
342
|
+
"dtype": "uint8",
|
|
343
|
+
"bands": ["R", "G", "B"]
|
|
344
|
+
}],
|
|
345
|
+
"data_source": {
|
|
346
|
+
"class_path": "rslearn.data_sources.gcp_public_data.Sentinel2",
|
|
347
|
+
"init_args": {
|
|
294
348
|
"index_cache_dir": "cache/sentinel2/",
|
|
295
|
-
"max_time_delta": "1d",
|
|
296
349
|
"sort_by": "cloud_cover",
|
|
297
350
|
"use_rtree_index": false
|
|
298
351
|
}
|
|
299
352
|
}
|
|
300
|
-
},
|
|
301
|
-
"tile_store": {
|
|
302
|
-
"name": "file",
|
|
303
|
-
"root_dir": "tiles"
|
|
304
353
|
}
|
|
305
354
|
}
|
|
355
|
+
}
|
|
356
|
+
```
|
|
306
357
|
|
|
307
358
|
Here, we have initialized an empty dataset and defined a raster layer called
|
|
308
359
|
`sentinel2`. Because it specifies a data source, it will be populated automatically. In
|
|
@@ -314,8 +365,10 @@ choosing the scenes with minimal cloud cover.
|
|
|
314
365
|
Next, let's create our spatiotemporal windows. These will correspond to training
|
|
315
366
|
examples.
|
|
316
367
|
|
|
317
|
-
|
|
318
|
-
|
|
368
|
+
```
|
|
369
|
+
export DATASET_PATH=/path/to/dataset
|
|
370
|
+
rslearn dataset add_windows --root $DATASET_PATH --group default --utm --resolution 10 --grid_size 128 --src_crs EPSG:4326 --box=-122.6901,47.2079,-121.4955,47.9403 --start 2024-06-01T00:00:00+00:00 --end 2024-08-01T00:00:00+00:00 --name seattle
|
|
371
|
+
```
|
|
319
372
|
|
|
320
373
|
This creates windows along a 128x128 grid in the specified projection (i.e.,
|
|
321
374
|
appropriate UTM zone for the location with 10 m/pixel resolution) covering the
|
|
@@ -327,9 +380,11 @@ We can now obtain the Sentinel-2 images by running prepare, ingest, and material
|
|
|
327
380
|
* Ingest: retrieve those items. This step populates the `tiles` directory within the dataset.
|
|
328
381
|
* Materialize: crop/mosaic the items to align with the windows. This populates the `layers` folder in each window directory.
|
|
329
382
|
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
383
|
+
```
|
|
384
|
+
rslearn dataset prepare --root $DATASET_PATH --workers 32 --batch-size 8
|
|
385
|
+
rslearn dataset ingest --root $DATASET_PATH --workers 32 --no-use-initial-job --jobs-per-process 1
|
|
386
|
+
rslearn dataset materialize --root $DATASET_PATH --workers 32 --no-use-initial-job
|
|
387
|
+
```
|
|
333
388
|
|
|
334
389
|
For ingestion, you may need to reduce the number of workers depending on the available
|
|
335
390
|
memory on your system.
|
|
@@ -337,32 +392,36 @@ memory on your system.
|
|
|
337
392
|
You should now be able to open the GeoTIFF images. Let's find the window that
|
|
338
393
|
corresponds to downtown Seattle:
|
|
339
394
|
|
|
340
|
-
|
|
341
|
-
|
|
342
|
-
|
|
343
|
-
|
|
344
|
-
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
|
|
351
|
-
|
|
352
|
-
|
|
353
|
-
|
|
354
|
-
|
|
355
|
-
|
|
356
|
-
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
395
|
+
```python
|
|
396
|
+
import shapely
|
|
397
|
+
from rslearn.const import WGS84_PROJECTION
|
|
398
|
+
from rslearn.dataset import Dataset
|
|
399
|
+
from rslearn.utils import Projection, STGeometry
|
|
400
|
+
from upath import UPath
|
|
401
|
+
|
|
402
|
+
# Define longitude and latitude for downtown Seattle.
|
|
403
|
+
downtown_seattle = shapely.Point(-122.333, 47.606)
|
|
404
|
+
|
|
405
|
+
# Iterate over the windows and find the closest one.
|
|
406
|
+
dataset = Dataset(path=UPath("/path/to/dataset"))
|
|
407
|
+
best_window_name = None
|
|
408
|
+
best_distance = None
|
|
409
|
+
for window in dataset.load_windows(workers=32):
|
|
410
|
+
shp = window.get_geometry().to_projection(WGS84_PROJECTION).shp
|
|
411
|
+
distance = shp.distance(downtown_seattle)
|
|
412
|
+
if best_distance is None or distance < best_distance:
|
|
413
|
+
best_window_name = window.name
|
|
414
|
+
best_distance = distance
|
|
415
|
+
|
|
416
|
+
print(best_window_name)
|
|
417
|
+
```
|
|
361
418
|
|
|
362
419
|
It should be `seattle_54912_-527360`, so let's open it in qgis (or your favorite GIS
|
|
363
420
|
software):
|
|
364
421
|
|
|
365
|
-
|
|
422
|
+
```
|
|
423
|
+
qgis $DATASET_PATH/windows/default/seattle_54912_-527360/layers/sentinel2/R_G_B/geotiff.tif
|
|
424
|
+
```
|
|
366
425
|
|
|
367
426
|
|
|
368
427
|
### Adding Land Cover Labels
|
|
@@ -372,152 +431,166 @@ the ESA WorldCover land cover map as labels.
|
|
|
372
431
|
|
|
373
432
|
Start by downloading the WorldCover data from https://worldcover2021.esa.int
|
|
374
433
|
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
434
|
+
```
|
|
435
|
+
wget https://worldcover2021.esa.int/data/archive/ESA_WorldCover_10m_2021_v200_60deg_macrotile_N30W180.zip
|
|
436
|
+
mkdir world_cover_tifs
|
|
437
|
+
unzip ESA_WorldCover_10m_2021_v200_60deg_macrotile_N30W180.zip -d world_cover_tifs/
|
|
438
|
+
```
|
|
378
439
|
|
|
379
440
|
It would require some work to write a script to re-project and crop these GeoTIFFs so
|
|
380
441
|
that they align with the windows we have previously defined (and the Sentinel-2 images
|
|
381
442
|
we have already ingested). We can use the LocalFiles data source to have rslearn
|
|
382
443
|
automate this process. Update the dataset `config.json` with a new layer:
|
|
383
444
|
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
445
|
+
```jsonc
|
|
446
|
+
"layers": {
|
|
447
|
+
"sentinel2": {
|
|
448
|
+
# ...
|
|
449
|
+
},
|
|
450
|
+
"worldcover": {
|
|
451
|
+
"type": "raster",
|
|
452
|
+
"band_sets": [{
|
|
453
|
+
"dtype": "uint8",
|
|
454
|
+
"bands": ["B1"]
|
|
455
|
+
}],
|
|
456
|
+
"resampling_method": "nearest",
|
|
457
|
+
"data_source": {
|
|
458
|
+
"class_path": "rslearn.data_sources.local_files.LocalFiles",
|
|
459
|
+
"init_args": {
|
|
397
460
|
"src_dir": "file:///path/to/world_cover_tifs/"
|
|
398
461
|
}
|
|
399
462
|
}
|
|
400
|
-
}
|
|
401
|
-
|
|
463
|
+
}
|
|
464
|
+
},
|
|
465
|
+
# ...
|
|
466
|
+
```
|
|
402
467
|
|
|
403
468
|
Repeat the materialize process so we populate the data for this new layer:
|
|
404
469
|
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
470
|
+
```
|
|
471
|
+
rslearn dataset prepare --root $DATASET_PATH --workers 32 --batch-size 8
|
|
472
|
+
rslearn dataset ingest --root $DATASET_PATH --workers 32 --no-use-initial-job --jobs-per-process 1
|
|
473
|
+
rslearn dataset materialize --root $DATASET_PATH --workers 32 --no-use-initial-job
|
|
474
|
+
```
|
|
408
475
|
|
|
409
476
|
We can visualize both the GeoTIFFs together in qgis:
|
|
410
477
|
|
|
411
|
-
|
|
478
|
+
```
|
|
479
|
+
qgis $DATASET_PATH/windows/default/seattle_54912_-527360/layers/*/*/geotiff.tif
|
|
480
|
+
```
|
|
412
481
|
|
|
413
482
|
|
|
414
483
|
### Training a Model
|
|
415
484
|
|
|
416
485
|
Create a model configuration file `land_cover_model.yaml`:
|
|
417
486
|
|
|
487
|
+
```yaml
|
|
488
|
+
model:
|
|
489
|
+
class_path: rslearn.train.lightning_module.RslearnLightningModule
|
|
490
|
+
init_args:
|
|
491
|
+
# This part defines the model architecture.
|
|
492
|
+
# Essentially we apply the SatlasPretrain Sentinel-2 backbone with a UNet decoder
|
|
493
|
+
# that terminates at a segmentation prediction head.
|
|
494
|
+
# The backbone outputs four feature maps at different scales, and the UNet uses
|
|
495
|
+
# these to compute a feature map at the input scale.
|
|
496
|
+
# Finally the segmentation head applies per-pixel softmax to compute the land
|
|
497
|
+
# cover class.
|
|
418
498
|
model:
|
|
419
|
-
class_path: rslearn.
|
|
499
|
+
class_path: rslearn.models.singletask.SingleTaskModel
|
|
420
500
|
init_args:
|
|
421
|
-
|
|
422
|
-
|
|
423
|
-
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
432
|
-
|
|
433
|
-
|
|
434
|
-
|
|
435
|
-
|
|
436
|
-
|
|
437
|
-
|
|
438
|
-
|
|
439
|
-
|
|
440
|
-
|
|
441
|
-
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
450
|
-
|
|
451
|
-
|
|
501
|
+
encoder:
|
|
502
|
+
- class_path: rslearn.models.satlaspretrain.SatlasPretrain
|
|
503
|
+
init_args:
|
|
504
|
+
model_identifier: "Sentinel2_SwinB_SI_RGB"
|
|
505
|
+
decoder:
|
|
506
|
+
- class_path: rslearn.models.unet.UNetDecoder
|
|
507
|
+
init_args:
|
|
508
|
+
in_channels: [[4, 128], [8, 256], [16, 512], [32, 1024]]
|
|
509
|
+
# We use 101 classes because the WorldCover classes are 10, 20, 30, 40
|
|
510
|
+
# 50, 60, 70, 80, 90, 95, 100.
|
|
511
|
+
# We could process the GeoTIFFs to collapse them to 0-10 (the 11 actual
|
|
512
|
+
# classes) but the model will quickly learn that the intermediate
|
|
513
|
+
# values are never used.
|
|
514
|
+
out_channels: 101
|
|
515
|
+
conv_layers_per_resolution: 2
|
|
516
|
+
- class_path: rslearn.train.tasks.segmentation.SegmentationHead
|
|
517
|
+
# Remaining parameters in RslearnLightningModule define different aspects of the
|
|
518
|
+
# training process like initial learning rate.
|
|
519
|
+
lr: 0.0001
|
|
520
|
+
data:
|
|
521
|
+
class_path: rslearn.train.data_module.RslearnDataModule
|
|
522
|
+
init_args:
|
|
523
|
+
path: ${DATASET_PATH}
|
|
524
|
+
# This defines the layers that should be read for each window.
|
|
525
|
+
# The key ("image" / "targets") is what the data will be called in the model,
|
|
526
|
+
# while the layers option specifies which layers will be read.
|
|
527
|
+
inputs:
|
|
528
|
+
image:
|
|
529
|
+
data_type: "raster"
|
|
530
|
+
layers: ["sentinel2"]
|
|
531
|
+
bands: ["R", "G", "B"]
|
|
532
|
+
passthrough: true
|
|
533
|
+
targets:
|
|
534
|
+
data_type: "raster"
|
|
535
|
+
layers: ["worldcover"]
|
|
536
|
+
bands: ["B1"]
|
|
537
|
+
is_target: true
|
|
538
|
+
task:
|
|
539
|
+
# Train for semantic segmentation.
|
|
540
|
+
# The remap option is only used when visualizing outputs during testing.
|
|
541
|
+
class_path: rslearn.train.tasks.segmentation.SegmentationTask
|
|
452
542
|
init_args:
|
|
453
|
-
|
|
454
|
-
|
|
455
|
-
|
|
456
|
-
|
|
457
|
-
|
|
458
|
-
|
|
459
|
-
|
|
460
|
-
|
|
461
|
-
|
|
462
|
-
|
|
463
|
-
|
|
464
|
-
|
|
465
|
-
data_type: "raster"
|
|
466
|
-
layers: ["worldcover"]
|
|
467
|
-
bands: ["B1"]
|
|
468
|
-
is_target: true
|
|
469
|
-
task:
|
|
470
|
-
# Train for semantic segmentation.
|
|
471
|
-
# The remap option is only used when visualizing outputs during testing.
|
|
472
|
-
class_path: rslearn.train.tasks.segmentation.SegmentationTask
|
|
543
|
+
num_classes: 101
|
|
544
|
+
remap_values: [[0, 1], [0, 255]]
|
|
545
|
+
batch_size: 8
|
|
546
|
+
num_workers: 32
|
|
547
|
+
# These define different options for different phases/splits, like training,
|
|
548
|
+
# validation, and testing.
|
|
549
|
+
# Here we use the same transform across splits except training where we add a
|
|
550
|
+
# flipping augmentation.
|
|
551
|
+
# For now we are using the same windows for training and validation.
|
|
552
|
+
default_config:
|
|
553
|
+
transforms:
|
|
554
|
+
- class_path: rslearn.train.transforms.normalize.Normalize
|
|
473
555
|
init_args:
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
|
|
477
|
-
|
|
478
|
-
|
|
479
|
-
|
|
480
|
-
|
|
481
|
-
|
|
482
|
-
|
|
483
|
-
default_config:
|
|
484
|
-
transforms:
|
|
485
|
-
- class_path: rslearn.train.transforms.normalize.Normalize
|
|
486
|
-
init_args:
|
|
487
|
-
mean: 0
|
|
488
|
-
std: 255
|
|
489
|
-
train_config:
|
|
490
|
-
transforms:
|
|
491
|
-
- class_path: rslearn.train.transforms.normalize.Normalize
|
|
492
|
-
init_args:
|
|
493
|
-
mean: 0
|
|
494
|
-
std: 255
|
|
495
|
-
- class_path: rslearn.train.transforms.flip.Flip
|
|
496
|
-
init_args:
|
|
497
|
-
image_selectors: ["image", "target/classes", "target/valid"]
|
|
498
|
-
groups: ["default"]
|
|
499
|
-
val_config:
|
|
500
|
-
groups: ["default"]
|
|
501
|
-
test_config:
|
|
502
|
-
groups: ["default"]
|
|
503
|
-
predict_config:
|
|
504
|
-
groups: ["predict"]
|
|
505
|
-
load_all_patches: true
|
|
506
|
-
skip_targets: true
|
|
507
|
-
patch_size: 512
|
|
508
|
-
trainer:
|
|
509
|
-
max_epochs: 10
|
|
510
|
-
callbacks:
|
|
511
|
-
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
|
556
|
+
mean: 0
|
|
557
|
+
std: 255
|
|
558
|
+
train_config:
|
|
559
|
+
transforms:
|
|
560
|
+
- class_path: rslearn.train.transforms.normalize.Normalize
|
|
561
|
+
init_args:
|
|
562
|
+
mean: 0
|
|
563
|
+
std: 255
|
|
564
|
+
- class_path: rslearn.train.transforms.flip.Flip
|
|
512
565
|
init_args:
|
|
513
|
-
|
|
514
|
-
|
|
515
|
-
|
|
516
|
-
|
|
566
|
+
image_selectors: ["image", "target/classes", "target/valid"]
|
|
567
|
+
groups: ["default"]
|
|
568
|
+
val_config:
|
|
569
|
+
groups: ["default"]
|
|
570
|
+
test_config:
|
|
571
|
+
groups: ["default"]
|
|
572
|
+
predict_config:
|
|
573
|
+
groups: ["predict"]
|
|
574
|
+
load_all_patches: true
|
|
575
|
+
skip_targets: true
|
|
576
|
+
patch_size: 512
|
|
577
|
+
trainer:
|
|
578
|
+
max_epochs: 10
|
|
579
|
+
callbacks:
|
|
580
|
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
|
581
|
+
init_args:
|
|
582
|
+
save_top_k: 1
|
|
583
|
+
save_last: true
|
|
584
|
+
monitor: val_accuracy
|
|
585
|
+
mode: max
|
|
586
|
+
dirpath: ./land_cover_model_checkpoints/
|
|
587
|
+
```
|
|
517
588
|
|
|
518
589
|
Now we can train the model:
|
|
519
590
|
|
|
520
|
-
|
|
591
|
+
```
|
|
592
|
+
rslearn model fit --config land_cover_model.yaml
|
|
593
|
+
```
|
|
521
594
|
|
|
522
595
|
|
|
523
596
|
### Apply the Model
|
|
@@ -528,22 +601,28 @@ windows along a grid, we just create one big window. This is because we are just
|
|
|
528
601
|
to run the prediction over the whole window rather than use different windows as
|
|
529
602
|
different training examples.
|
|
530
603
|
|
|
531
|
-
|
|
532
|
-
|
|
533
|
-
|
|
534
|
-
|
|
604
|
+
```
|
|
605
|
+
rslearn dataset add_windows --root $DATASET_PATH --group predict --utm --resolution 10 --src_crs EPSG:4326 --box=-122.712,45.477,-122.621,45.549 --start 2024-06-01T00:00:00+00:00 --end 2024-08-01T00:00:00+00:00 --name portland
|
|
606
|
+
rslearn dataset prepare --root $DATASET_PATH --workers 32 --batch-size 8
|
|
607
|
+
rslearn dataset ingest --root $DATASET_PATH --workers 32 --no-use-initial-job --jobs-per-process 1
|
|
608
|
+
rslearn dataset materialize --root $DATASET_PATH --workers 32 --no-use-initial-job
|
|
609
|
+
```
|
|
535
610
|
|
|
536
611
|
We also need to add an RslearnWriter to the trainer callbacks in the model
|
|
537
612
|
configuration file, as it will handle writing the outputs from the model to a GeoTIFF.
|
|
538
613
|
|
|
539
|
-
|
|
540
|
-
|
|
541
|
-
|
|
542
|
-
|
|
543
|
-
|
|
544
|
-
|
|
545
|
-
|
|
546
|
-
|
|
614
|
+
```yaml
|
|
615
|
+
trainer:
|
|
616
|
+
callbacks:
|
|
617
|
+
- class_path: lightning.pytorch.callbacks.ModelCheckpoint
|
|
618
|
+
...
|
|
619
|
+
- class_path: rslearn.train.prediction_writer.RslearnWriter
|
|
620
|
+
init_args:
|
|
621
|
+
# We need to include this argument, but it will be overridden with the dataset
|
|
622
|
+
# path from data.init_args.path.
|
|
623
|
+
path: placeholder
|
|
624
|
+
output_layer: output
|
|
625
|
+
```
|
|
547
626
|
|
|
548
627
|
Because of our `predict_config`, when we run `model predict` it will apply the model on
|
|
549
628
|
windows in the "predict" group, which is where we added the Portland window.
|
|
@@ -551,39 +630,46 @@ windows in the "predict" group, which is where we added the Portland window.
|
|
|
551
630
|
And it will be written in a new output_layer called "output". But we have to update the
|
|
552
631
|
dataset configuration so it specifies the layer:
|
|
553
632
|
|
|
554
|
-
|
|
555
|
-
|
|
556
|
-
|
|
557
|
-
|
|
558
|
-
},
|
|
559
|
-
"worldcover": {
|
|
560
|
-
...
|
|
561
|
-
},
|
|
562
|
-
"output": {
|
|
563
|
-
"type": "raster",
|
|
564
|
-
"band_sets": [{
|
|
565
|
-
"dtype": "uint8",
|
|
566
|
-
"bands": ["output"]
|
|
567
|
-
}]
|
|
568
|
-
}
|
|
633
|
+
```jsonc
|
|
634
|
+
"layers": {
|
|
635
|
+
"sentinel2": {
|
|
636
|
+
# ...
|
|
569
637
|
},
|
|
638
|
+
"worldcover": {
|
|
639
|
+
# ...
|
|
640
|
+
},
|
|
641
|
+
"output": {
|
|
642
|
+
"type": "raster",
|
|
643
|
+
"band_sets": [{
|
|
644
|
+
"dtype": "uint8",
|
|
645
|
+
"bands": ["output"]
|
|
646
|
+
}]
|
|
647
|
+
}
|
|
648
|
+
},
|
|
649
|
+
```
|
|
570
650
|
|
|
571
651
|
Now we can apply the model:
|
|
572
652
|
|
|
573
|
-
|
|
574
|
-
|
|
575
|
-
|
|
653
|
+
```
|
|
654
|
+
# Find the model checkpoint in the land_cover_model_checkpoints dir.
|
|
655
|
+
ls land_cover_model_checkpoints/last.ckpt
|
|
656
|
+
rslearn model predict --config land_cover_model.yaml --ckpt_path land_cover_model_checkpoints/last.ckpt
|
|
657
|
+
```
|
|
576
658
|
|
|
577
659
|
And visualize the Sentinel-2 image and output in qgis:
|
|
578
660
|
|
|
579
|
-
|
|
661
|
+
```
|
|
662
|
+
qgis $DATASET_PATH/windows/predict/portland/layers/*/*/geotiff.tif
|
|
663
|
+
```
|
|
580
664
|
|
|
581
665
|
|
|
582
666
|
### Defining Train and Validation Splits
|
|
583
667
|
|
|
584
668
|
We can visualize the logged metrics using Tensorboard:
|
|
585
669
|
|
|
586
|
-
|
|
670
|
+
```
|
|
671
|
+
tensorboard --logdir=lightning_logs/
|
|
672
|
+
```
|
|
587
673
|
|
|
588
674
|
However, because our training and validation data are identical, the validation metrics
|
|
589
675
|
are not meaningful.
|
|
@@ -597,57 +683,61 @@ We will use the second approach. The script below sets a "split" key in the opti
|
|
|
597
683
|
dict (which is stored in each window's `metadata.json` file) to "train" or "val"
|
|
598
684
|
based on the SHA-256 hash of the window name.
|
|
599
685
|
|
|
600
|
-
|
|
601
|
-
|
|
602
|
-
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
|
|
608
|
-
|
|
609
|
-
|
|
610
|
-
|
|
611
|
-
|
|
612
|
-
|
|
613
|
-
|
|
614
|
-
|
|
615
|
-
|
|
616
|
-
|
|
686
|
+
```python
|
|
687
|
+
import hashlib
|
|
688
|
+
import tqdm
|
|
689
|
+
from rslearn.dataset import Dataset, Window
|
|
690
|
+
from upath import UPath
|
|
691
|
+
|
|
692
|
+
ds_path = UPath("/path/to/dataset/")
|
|
693
|
+
dataset = Dataset(ds_path)
|
|
694
|
+
windows = dataset.load_windows(show_progress=True, workers=32)
|
|
695
|
+
for window in tqdm.tqdm(windows):
|
|
696
|
+
if hashlib.sha256(window.name.encode()).hexdigest()[0] in ["0", "1"]:
|
|
697
|
+
split = "val"
|
|
698
|
+
else:
|
|
699
|
+
split = "train"
|
|
700
|
+
if "split" in window.options and window.options["split"] == split:
|
|
701
|
+
continue
|
|
702
|
+
window.options["split"] = split
|
|
703
|
+
window.save()
|
|
704
|
+
```
|
|
617
705
|
|
|
618
706
|
Now we can update the model configuration file to use these splits:
|
|
619
707
|
|
|
620
|
-
|
|
621
|
-
|
|
622
|
-
|
|
623
|
-
|
|
624
|
-
|
|
625
|
-
|
|
626
|
-
|
|
627
|
-
|
|
628
|
-
|
|
629
|
-
|
|
630
|
-
|
|
631
|
-
|
|
632
|
-
|
|
633
|
-
|
|
634
|
-
|
|
635
|
-
|
|
636
|
-
|
|
637
|
-
|
|
638
|
-
|
|
639
|
-
|
|
640
|
-
|
|
641
|
-
|
|
642
|
-
|
|
643
|
-
|
|
644
|
-
|
|
645
|
-
|
|
646
|
-
|
|
647
|
-
|
|
648
|
-
|
|
649
|
-
|
|
650
|
-
|
|
708
|
+
```yaml
|
|
709
|
+
default_config:
|
|
710
|
+
transforms:
|
|
711
|
+
- class_path: rslearn.train.transforms.normalize.Normalize
|
|
712
|
+
init_args:
|
|
713
|
+
mean: 0
|
|
714
|
+
std: 255
|
|
715
|
+
train_config:
|
|
716
|
+
transforms:
|
|
717
|
+
- class_path: rslearn.train.transforms.normalize.Normalize
|
|
718
|
+
init_args:
|
|
719
|
+
mean: 0
|
|
720
|
+
std: 255
|
|
721
|
+
- class_path: rslearn.train.transforms.flip.Flip
|
|
722
|
+
init_args:
|
|
723
|
+
image_selectors: ["image", "target/classes", "target/valid"]
|
|
724
|
+
groups: ["default"]
|
|
725
|
+
tags:
|
|
726
|
+
split: train
|
|
727
|
+
val_config:
|
|
728
|
+
groups: ["default"]
|
|
729
|
+
tags:
|
|
730
|
+
split: val
|
|
731
|
+
test_config:
|
|
732
|
+
groups: ["default"]
|
|
733
|
+
tags:
|
|
734
|
+
split: val
|
|
735
|
+
predict_config:
|
|
736
|
+
groups: ["predict"]
|
|
737
|
+
load_all_patches: true
|
|
738
|
+
skip_targets: true
|
|
739
|
+
patch_size: 512
|
|
740
|
+
```
|
|
651
741
|
|
|
652
742
|
The `tags` option that we are adding here tells rslearn to only load windows with a
|
|
653
743
|
matching key and value in the window options.
|
|
@@ -655,28 +745,187 @@ matching key and value in the window options.
|
|
|
655
745
|
Previously, when we ran `model fit`, it showed the same number of windows for
|
|
656
746
|
training and validation:
|
|
657
747
|
|
|
658
|
-
|
|
659
|
-
|
|
748
|
+
```
|
|
749
|
+
got 4752 examples in split train
|
|
750
|
+
got 4752 examples in split val
|
|
751
|
+
```
|
|
660
752
|
|
|
661
753
|
With the updates, it should show different numbers like this:
|
|
662
754
|
|
|
663
|
-
|
|
664
|
-
|
|
755
|
+
```
|
|
756
|
+
got 4167 examples in split train
|
|
757
|
+
got 585 examples in split val
|
|
758
|
+
```
|
|
665
759
|
|
|
666
760
|
|
|
667
761
|
### Visualizing with `model test`
|
|
668
762
|
|
|
669
|
-
|
|
763
|
+
We can visualize the ground truth labels and model predictions in the test set using
|
|
764
|
+
the `model test` command:
|
|
765
|
+
|
|
766
|
+
```
|
|
767
|
+
mkdir ./vis
|
|
768
|
+
rslearn model test --config land_cover_model.yaml --ckpt_path land_cover_model_checkpoints/last.ckpt --model.init_args.visualize_dir=./vis/
|
|
769
|
+
```
|
|
770
|
+
|
|
771
|
+
This will produce PNGs in the vis directory. The visualizations are produced by the
|
|
772
|
+
`Task.visualize` function, so we could customize the visualization by subclassing
|
|
773
|
+
SegmentationTask and overriding the visualize function.
|
|
774
|
+
|
|
775
|
+
|
|
776
|
+
### Checkpoint and Logging Management
|
|
777
|
+
|
|
778
|
+
Above, we needed to configure the checkpoint directory in the model config (the
|
|
779
|
+
`dirpath` option under `lightning.pytorch.callbacks.ModelCheckpoint`), and explicitly
|
|
780
|
+
specify the checkpoint path when applying the model. Additionally, metrics are logged
|
|
781
|
+
to the local filesystem and are not well organized.
|
|
782
|
+
|
|
783
|
+
We can instead let rslearn automatically manage checkpoints, along with logging to
|
|
784
|
+
Weights & Biases. To do so, we add project_name, run_name, and management_dir options
|
|
785
|
+
to the model config. The project_name corresponds to the W&B project, and the run_name
|
|
786
|
+
corresponds to the W&B run name. The management_dir is a directory to store project data;
|
|
787
|
+
rslearn determines a per-run directory at `{management_dir}/{project_name}/{run_name}/`
|
|
788
|
+
and uses it to store checkpoints.
|
|
789
|
+
|
|
790
|
+
```yaml
|
|
791
|
+
model:
|
|
792
|
+
# ...
|
|
793
|
+
data:
|
|
794
|
+
# ...
|
|
795
|
+
trainer:
|
|
796
|
+
# ...
|
|
797
|
+
project_name: land_cover_model
|
|
798
|
+
run_name: version_00
|
|
799
|
+
# This sets the option via the MANAGEMENT_DIR environment variable.
|
|
800
|
+
management_dir: ${MANAGEMENT_DIR}
|
|
801
|
+
```
|
|
802
|
+
|
|
803
|
+
Now, set the `MANAGEMENT_DIR` environment variable and run `model fit`:
|
|
804
|
+
|
|
805
|
+
```
|
|
806
|
+
export MANAGEMENT_DIR=./project_data
|
|
807
|
+
rslearn model fit --config land_cover_model.yaml
|
|
808
|
+
```
|
|
809
|
+
|
|
810
|
+
The training and validation loss and accuracy metric should now be logged to W&B. The
|
|
811
|
+
accuracy metric is provided by SegmentationTask, and additional metrics can be enabled
|
|
812
|
+
by passing the relevant init_args to the task, e.g. mean IoU and F1:
|
|
813
|
+
|
|
814
|
+
```yaml
|
|
815
|
+
class_path: rslearn.train.tasks.segmentation.SegmentationTask
|
|
816
|
+
init_args:
|
|
817
|
+
num_classes: 101
|
|
818
|
+
remap_values: [[0, 1], [0, 255]]
|
|
819
|
+
enable_miou_metric: true
|
|
820
|
+
enable_f1_metric: true
|
|
821
|
+
```
|
|
822
|
+
|
|
823
|
+
When calling `model test` and `model predict` with management_dir set, rslearn will
|
|
824
|
+
automatically load the best checkpoint from the project directory, or raise an error if
|
|
825
|
+
no checkpoint exists. This behavior can be overridden with the
|
|
826
|
+
`--load_checkpoint_mode` and `--load_checkpoint_required` options (see `--help` for
|
|
827
|
+
details). Logging will be enabled during fit but not test/predict, and this can also
|
|
828
|
+
be overridden, using `--log_mode`.
|
|
670
829
|
|
|
671
830
|
|
|
672
831
|
### Inputting Multiple Sentinel-2 Images
|
|
673
832
|
|
|
674
|
-
|
|
833
|
+
Currently our model inputs a single Sentinel-2 image. However, for most tasks where
|
|
834
|
+
labels are not expected to change from week to week, we find that accuracy can be
|
|
835
|
+
significantly improved by inputting multiple images, regardless of the pre-trained
|
|
836
|
+
model used. Using multiple images makes the model more resilient to clouds and image
|
|
837
|
+
artifacts, and allows the model to synthesize information across different views that
|
|
838
|
+
may come from different seasons or weather conditions.
|
|
839
|
+
|
|
840
|
+
We first update our dataset configuration to obtain three images, by customizing the
|
|
841
|
+
query_config section. This can replace the sentinel2 layer:
|
|
842
|
+
|
|
843
|
+
```jsonc
|
|
844
|
+
"layers": {
|
|
845
|
+
"sentinel2_multi": {
|
|
846
|
+
"type": "raster",
|
|
847
|
+
"band_sets": [{
|
|
848
|
+
"dtype": "uint8",
|
|
849
|
+
"bands": ["R", "G", "B"]
|
|
850
|
+
}],
|
|
851
|
+
"data_source": {
|
|
852
|
+
"class_path": "rslearn.data_sources.gcp_public_data.Sentinel2",
|
|
853
|
+
"init_args": {
|
|
854
|
+
"index_cache_dir": "cache/sentinel2/",
|
|
855
|
+
"sort_by": "cloud_cover",
|
|
856
|
+
"use_rtree_index": false
|
|
857
|
+
},
|
|
858
|
+
"query_config": {
|
|
859
|
+
"max_matches": 3
|
|
860
|
+
}
|
|
861
|
+
}
|
|
862
|
+
},
|
|
863
|
+
"worldcover": {
|
|
864
|
+
// ...
|
|
865
|
+
},
|
|
866
|
+
"output": {
|
|
867
|
+
// ...
|
|
868
|
+
}
|
|
869
|
+
}
|
|
870
|
+
```
|
|
871
|
+
|
|
872
|
+
Repeat the steps from earlier to prepare, ingest, and materialize the dataset.
|
|
675
873
|
|
|
874
|
+
Now we update our model configuration file. First, we modify the model architecture to
|
|
875
|
+
be able to input an image time series. We use the SimpleTimeSeries model, which takes
|
|
876
|
+
an encoder that expects a single-image input, and applies that encoder to each image in
|
|
877
|
+
the time series. It then applies max temporal pooling to combine the per-image feature
|
|
878
|
+
maps extracted by the encoder.
|
|
676
879
|
|
|
677
|
-
|
|
880
|
+
Image time series in rslearn are currently stored as [T*C, H, W] tensors. So we pass
|
|
881
|
+
the `image_channels` to SimpleTimeSeries so it knows how to slice up the tensor to
|
|
882
|
+
recover the per-timestep images.
|
|
678
883
|
|
|
679
|
-
|
|
884
|
+
```yaml
|
|
885
|
+
model:
|
|
886
|
+
class_path: rslearn.train.lightning_module.RslearnLightningModule
|
|
887
|
+
init_args:
|
|
888
|
+
model:
|
|
889
|
+
class_path: rslearn.models.singletask.SingleTaskModel
|
|
890
|
+
init_args:
|
|
891
|
+
encoder:
|
|
892
|
+
- class_path: rslearn.models.simple_time_series.SimpleTimeSeries
|
|
893
|
+
init_args:
|
|
894
|
+
encoder:
|
|
895
|
+
class_path: rslearn.models.satlaspretrain.SatlasPretrain
|
|
896
|
+
init_args:
|
|
897
|
+
model_identifier: "Sentinel2_SwinB_SI_RGB"
|
|
898
|
+
image_channels: 3
|
|
899
|
+
decoder:
|
|
900
|
+
# ...
|
|
901
|
+
```
|
|
902
|
+
|
|
903
|
+
Next, we update the data module section so that the dataset loads the image time series
|
|
904
|
+
rather than a single image. The `load_all_layers` option tells the dataset to stack the
|
|
905
|
+
rasters from all of the layers specified, and also to ignore windows where any of those
|
|
906
|
+
layers are missing.
|
|
907
|
+
|
|
908
|
+
```yaml
|
|
909
|
+
data:
|
|
910
|
+
class_path: rslearn.train.data_module.RslearnDataModule
|
|
911
|
+
init_args:
|
|
912
|
+
path: # ...
|
|
913
|
+
inputs:
|
|
914
|
+
image:
|
|
915
|
+
data_type: "raster"
|
|
916
|
+
layers: ["sentinel2_multi", "sentinel2_multi.1", "sentinel2_multi.2"]
|
|
917
|
+
bands: ["R", "G", "B"]
|
|
918
|
+
passthrough: true
|
|
919
|
+
load_all_layers: true
|
|
920
|
+
targets:
|
|
921
|
+
# ...
|
|
922
|
+
```
|
|
923
|
+
|
|
924
|
+
Now we can train an updated model:
|
|
925
|
+
|
|
926
|
+
```
|
|
927
|
+
rslearn model fit --config land_cover_model.yaml
|
|
928
|
+
```
|
|
680
929
|
|
|
681
930
|
|
|
682
931
|
Contact
|