rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. rslearn/arg_parser.py +31 -0
  2. rslearn/config/__init__.py +6 -12
  3. rslearn/config/dataset.py +520 -401
  4. rslearn/const.py +9 -15
  5. rslearn/data_sources/__init__.py +8 -23
  6. rslearn/data_sources/aws_landsat.py +242 -98
  7. rslearn/data_sources/aws_open_data.py +111 -151
  8. rslearn/data_sources/aws_sentinel1.py +131 -0
  9. rslearn/data_sources/climate_data_store.py +471 -0
  10. rslearn/data_sources/copernicus.py +884 -12
  11. rslearn/data_sources/data_source.py +43 -12
  12. rslearn/data_sources/earthdaily.py +484 -0
  13. rslearn/data_sources/earthdata_srtm.py +282 -0
  14. rslearn/data_sources/eurocrops.py +242 -0
  15. rslearn/data_sources/gcp_public_data.py +578 -222
  16. rslearn/data_sources/google_earth_engine.py +461 -135
  17. rslearn/data_sources/local_files.py +219 -150
  18. rslearn/data_sources/openstreetmap.py +51 -89
  19. rslearn/data_sources/planet.py +24 -60
  20. rslearn/data_sources/planet_basemap.py +275 -0
  21. rslearn/data_sources/planetary_computer.py +798 -0
  22. rslearn/data_sources/usda_cdl.py +195 -0
  23. rslearn/data_sources/usgs_landsat.py +115 -83
  24. rslearn/data_sources/utils.py +249 -61
  25. rslearn/data_sources/vector_source.py +1 -0
  26. rslearn/data_sources/worldcereal.py +449 -0
  27. rslearn/data_sources/worldcover.py +144 -0
  28. rslearn/data_sources/worldpop.py +153 -0
  29. rslearn/data_sources/xyz_tiles.py +150 -107
  30. rslearn/dataset/__init__.py +8 -2
  31. rslearn/dataset/add_windows.py +2 -2
  32. rslearn/dataset/dataset.py +40 -51
  33. rslearn/dataset/handler_summaries.py +131 -0
  34. rslearn/dataset/manage.py +313 -74
  35. rslearn/dataset/materialize.py +431 -107
  36. rslearn/dataset/remap.py +29 -4
  37. rslearn/dataset/storage/__init__.py +1 -0
  38. rslearn/dataset/storage/file.py +202 -0
  39. rslearn/dataset/storage/storage.py +140 -0
  40. rslearn/dataset/window.py +181 -44
  41. rslearn/lightning_cli.py +454 -0
  42. rslearn/log_utils.py +24 -0
  43. rslearn/main.py +384 -181
  44. rslearn/models/anysat.py +215 -0
  45. rslearn/models/attention_pooling.py +177 -0
  46. rslearn/models/clay/clay.py +231 -0
  47. rslearn/models/clay/configs/metadata.yaml +295 -0
  48. rslearn/models/clip.py +68 -0
  49. rslearn/models/component.py +111 -0
  50. rslearn/models/concatenate_features.py +103 -0
  51. rslearn/models/conv.py +63 -0
  52. rslearn/models/croma.py +306 -0
  53. rslearn/models/detr/__init__.py +5 -0
  54. rslearn/models/detr/box_ops.py +103 -0
  55. rslearn/models/detr/detr.py +504 -0
  56. rslearn/models/detr/matcher.py +107 -0
  57. rslearn/models/detr/position_encoding.py +114 -0
  58. rslearn/models/detr/transformer.py +429 -0
  59. rslearn/models/detr/util.py +24 -0
  60. rslearn/models/dinov3.py +177 -0
  61. rslearn/models/faster_rcnn.py +30 -28
  62. rslearn/models/feature_center_crop.py +53 -0
  63. rslearn/models/fpn.py +19 -8
  64. rslearn/models/galileo/__init__.py +5 -0
  65. rslearn/models/galileo/galileo.py +595 -0
  66. rslearn/models/galileo/single_file_galileo.py +1678 -0
  67. rslearn/models/module_wrapper.py +65 -0
  68. rslearn/models/molmo.py +69 -0
  69. rslearn/models/multitask.py +384 -28
  70. rslearn/models/olmoearth_pretrain/__init__.py +1 -0
  71. rslearn/models/olmoearth_pretrain/model.py +421 -0
  72. rslearn/models/olmoearth_pretrain/norm.py +86 -0
  73. rslearn/models/panopticon.py +170 -0
  74. rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
  75. rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
  76. rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
  77. rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
  78. rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
  79. rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
  80. rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
  81. rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
  82. rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
  83. rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
  84. rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
  85. rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
  86. rslearn/models/pick_features.py +17 -10
  87. rslearn/models/pooling_decoder.py +60 -7
  88. rslearn/models/presto/__init__.py +5 -0
  89. rslearn/models/presto/presto.py +297 -0
  90. rslearn/models/presto/single_file_presto.py +926 -0
  91. rslearn/models/prithvi.py +1147 -0
  92. rslearn/models/resize_features.py +59 -0
  93. rslearn/models/sam2_enc.py +13 -9
  94. rslearn/models/satlaspretrain.py +38 -18
  95. rslearn/models/simple_time_series.py +188 -77
  96. rslearn/models/singletask.py +24 -13
  97. rslearn/models/ssl4eo_s12.py +40 -30
  98. rslearn/models/swin.py +44 -32
  99. rslearn/models/task_embedding.py +250 -0
  100. rslearn/models/terramind.py +256 -0
  101. rslearn/models/trunk.py +139 -0
  102. rslearn/models/unet.py +68 -22
  103. rslearn/models/upsample.py +48 -0
  104. rslearn/models/use_croma.py +508 -0
  105. rslearn/template_params.py +26 -0
  106. rslearn/tile_stores/__init__.py +41 -18
  107. rslearn/tile_stores/default.py +409 -0
  108. rslearn/tile_stores/tile_store.py +236 -132
  109. rslearn/train/all_patches_dataset.py +530 -0
  110. rslearn/train/callbacks/adapters.py +53 -0
  111. rslearn/train/callbacks/freeze_unfreeze.py +348 -17
  112. rslearn/train/callbacks/gradients.py +129 -0
  113. rslearn/train/callbacks/peft.py +116 -0
  114. rslearn/train/data_module.py +444 -20
  115. rslearn/train/dataset.py +588 -235
  116. rslearn/train/lightning_module.py +192 -62
  117. rslearn/train/model_context.py +88 -0
  118. rslearn/train/optimizer.py +31 -0
  119. rslearn/train/prediction_writer.py +319 -84
  120. rslearn/train/scheduler.py +92 -0
  121. rslearn/train/tasks/classification.py +55 -28
  122. rslearn/train/tasks/detection.py +132 -76
  123. rslearn/train/tasks/embedding.py +120 -0
  124. rslearn/train/tasks/multi_task.py +28 -14
  125. rslearn/train/tasks/per_pixel_regression.py +291 -0
  126. rslearn/train/tasks/regression.py +161 -44
  127. rslearn/train/tasks/segmentation.py +428 -53
  128. rslearn/train/tasks/task.py +6 -5
  129. rslearn/train/transforms/__init__.py +1 -1
  130. rslearn/train/transforms/concatenate.py +54 -10
  131. rslearn/train/transforms/crop.py +29 -11
  132. rslearn/train/transforms/flip.py +18 -6
  133. rslearn/train/transforms/mask.py +78 -0
  134. rslearn/train/transforms/normalize.py +101 -17
  135. rslearn/train/transforms/pad.py +19 -7
  136. rslearn/train/transforms/resize.py +83 -0
  137. rslearn/train/transforms/select_bands.py +76 -0
  138. rslearn/train/transforms/sentinel1.py +75 -0
  139. rslearn/train/transforms/transform.py +89 -70
  140. rslearn/utils/__init__.py +2 -6
  141. rslearn/utils/array.py +8 -6
  142. rslearn/utils/feature.py +2 -2
  143. rslearn/utils/fsspec.py +90 -1
  144. rslearn/utils/geometry.py +347 -7
  145. rslearn/utils/get_utm_ups_crs.py +2 -3
  146. rslearn/utils/grid_index.py +5 -5
  147. rslearn/utils/jsonargparse.py +178 -0
  148. rslearn/utils/mp.py +4 -3
  149. rslearn/utils/raster_format.py +268 -116
  150. rslearn/utils/rtree_index.py +64 -17
  151. rslearn/utils/sqlite_index.py +7 -1
  152. rslearn/utils/vector_format.py +252 -97
  153. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
  154. rslearn-0.0.21.dist-info/RECORD +167 -0
  155. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
  156. rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
  157. rslearn/data_sources/raster_source.py +0 -309
  158. rslearn/models/registry.py +0 -5
  159. rslearn/tile_stores/file.py +0 -242
  160. rslearn/utils/mgrs.py +0 -24
  161. rslearn/utils/utils.py +0 -22
  162. rslearn-0.0.1.dist-info/RECORD +0 -88
  163. /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
  164. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
  165. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
  166. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
@@ -1,9 +1,9 @@
- Metadata-Version: 2.1
+ Metadata-Version: 2.4
  Name: rslearn
- Version: 0.0.1
+ Version: 0.0.21
  Summary: A library for developing remote sensing datasets and models
- Author-email: Favyen Bastani <favyenb@allenai.org>, Yawen Zhang <yawenz@allenai.org>, Patrick Beukema <patrickb@allenai.org>, Henry Herzog <henryh@allenai.org>, Piper Wolters <piperw@allenai.org>
- License: Apache License
+ Author: OlmoEarth Team
+ License: Apache License
  Version 2.0, January 2004
  http://www.apache.org/licenses/
 
@@ -205,36 +205,62 @@ License: Apache License
  See the License for the specific language governing permissions and
  limitations under the License.
 
- Requires-Python: >=3.10
+ Project-URL: homepage, https://github.com/allenai/rslearn
+ Project-URL: issues, https://github.com/allenai/rslearn/issues
+ Project-URL: repository, https://github.com/allenai/rslearn
+ Requires-Python: >=3.11
  Description-Content-Type: text/markdown
  License-File: LICENSE
- Requires-Dist: boto3
- Requires-Dist: class-registry
- Requires-Dist: python-dateutil
- Requires-Dist: pytimeparse
- Requires-Dist: fiona
- Requires-Dist: fsspec[gcs,s3]
- Requires-Dist: Pillow
- Requires-Dist: pyproj
- Requires-Dist: rasterio
- Requires-Dist: shapely
- Requires-Dist: tqdm
- Requires-Dist: torch
- Requires-Dist: torchvision
- Requires-Dist: universal-pathlib
- Requires-Dist: lightning[pytorch-extra]
+ License-File: NOTICE
+ Requires-Dist: boto3>=1.39
+ Requires-Dist: fiona>=1.10
+ Requires-Dist: fsspec>=2025.10.0
+ Requires-Dist: jsonargparse>=4.35.0
+ Requires-Dist: lightning>=2.5.1.post0
+ Requires-Dist: Pillow>=11.3
+ Requires-Dist: pyproj>=3.7
+ Requires-Dist: python-dateutil>=2.9
+ Requires-Dist: pytimeparse>=1.1
+ Requires-Dist: rasterio>=1.4
+ Requires-Dist: shapely>=2.1
+ Requires-Dist: torch>=2.7.0
+ Requires-Dist: torchvision>=0.22.0
+ Requires-Dist: tqdm>=4.67
+ Requires-Dist: universal_pathlib>=0.2.6
  Provides-Extra: extra
- Requires-Dist: earthengine-api ; extra == 'extra'
- Requires-Dist: gcsfs ; extra == 'extra'
- Requires-Dist: google-cloud-storage ; extra == 'extra'
- Requires-Dist: mgrs ; extra == 'extra'
- Requires-Dist: osmium ; extra == 'extra'
- Requires-Dist: planet ; extra == 'extra'
- Requires-Dist: pycocotools ; extra == 'extra'
- Requires-Dist: rtree ; extra == 'extra'
- Requires-Dist: satlaspretrain-models ; extra == 'extra'
- Requires-Dist: scipy ; extra == 'extra'
- Requires-Dist: wandb ; extra == 'extra'
+ Requires-Dist: accelerate>=1.10; extra == "extra"
+ Requires-Dist: cdsapi>=0.7.6; extra == "extra"
+ Requires-Dist: earthdaily[platform]>=1.0.7; extra == "extra"
+ Requires-Dist: earthengine-api>=1.6.3; extra == "extra"
+ Requires-Dist: einops>=0.8; extra == "extra"
+ Requires-Dist: fsspec[gcs,s3]; extra == "extra"
+ Requires-Dist: google-cloud-bigquery>=3.35; extra == "extra"
+ Requires-Dist: google-cloud-storage>=2.18; extra == "extra"
+ Requires-Dist: huggingface_hub>=0.34.4; extra == "extra"
+ Requires-Dist: netCDF4>=1.7.2; extra == "extra"
+ Requires-Dist: osmium>=4.0.2; extra == "extra"
+ Requires-Dist: planet>=3.1; extra == "extra"
+ Requires-Dist: planetary_computer>=1.0; extra == "extra"
+ Requires-Dist: pycocotools>=2.0; extra == "extra"
+ Requires-Dist: pystac_client>=0.9; extra == "extra"
+ Requires-Dist: rtree>=1.4; extra == "extra"
+ Requires-Dist: termcolor>=3.0; extra == "extra"
+ Requires-Dist: satlaspretrain_models>=0.3; extra == "extra"
+ Requires-Dist: scipy>=1.16; extra == "extra"
+ Requires-Dist: terratorch>=1.0.2; extra == "extra"
+ Requires-Dist: transformers>=4.55; extra == "extra"
+ Requires-Dist: wandb>=0.21; extra == "extra"
+ Requires-Dist: timm>=0.9.7; extra == "extra"
+ Provides-Extra: dev
+ Requires-Dist: interrogate>=1.7.0; extra == "dev"
+ Requires-Dist: mypy<2,>=1.17.1; extra == "dev"
+ Requires-Dist: pre-commit>=4.3.0; extra == "dev"
+ Requires-Dist: pytest>=8.0; extra == "dev"
+ Requires-Dist: pytest_httpserver; extra == "dev"
+ Requires-Dist: ruff>=0.12.9; extra == "dev"
+ Requires-Dist: pytest-dotenv; extra == "dev"
+ Requires-Dist: pytest-xdist; extra == "dev"
+ Dynamic: license-file
 
  Overview
  --------
@@ -254,10 +280,12 @@ rslearn helps with:
 
 
  Quick links:
- - [CoreConcepts](CoreConcepts.md) summarizes key concepts in rslearn, including
+ - [CoreConcepts](docs/CoreConcepts.md) summarizes key concepts in rslearn, including
    datasets, windows, layers, and data sources.
- - [Examples](Examples.md) contains more examples, including customizing different
+ - [Examples](docs/Examples.md) contains more examples, including customizing different
    stages of rslearn with additional code.
+ - [DatasetConfig](docs/DatasetConfig.md) documents the dataset configuration file.
+ - [ModelConfig](docs/ModelConfig.md) documents the model configuration file.
 
 
  Setup
@@ -265,9 +293,33 @@ Setup
 
  rslearn requires Python 3.10+ (Python 3.12 is recommended).
 
- git clone https://github.com/allenai/rslearn.git
- cd rslearn
- pip install .[extra]
+ ```
+ git clone https://github.com/allenai/rslearn.git
+ cd rslearn
+ pip install .[extra]
+ ```
+
+
+ Supported Data Sources
+ ----------------------
+
+ rslearn supports ingesting raster and vector data from the following data sources. Even
+ if you don't plan to train models within rslearn, you can still use it to easily
+ download, crop, and re-project data based on spatiotemporal rectangles (windows) that
+ you define. See [Examples](docs/Examples.md) and [DatasetConfig](docs/DatasetConfig.md)
+ for how to set up these data sources.
+
+ - Sentinel-1
+ - Sentinel-2 L1C and L2A
+ - Landsat 8/9 OLI-TIRS
+ - National Agriculture Imagery Program
+ - OpenStreetMap
+ - Xyz (Slippy) Tiles (e.g., Mapbox tiles)
+ - Planet Labs (PlanetScope, SkySat)
+ - ESA WorldCover 2021
+
+ rslearn can also be used to easily mosaic, crop, and re-project any sets of local
+ raster and vector files you may have.
 
 
  Example Usage
@@ -281,28 +333,27 @@ Let's start by defining a region of interest and obtaining Sentinel-2 images. Cr
  directory `/path/to/dataset` and corresponding configuration file at
  `/path/to/dataset/config.json` as follows:
 
- {
-     "layers": {
-         "sentinel2": {
-             "type": "raster",
-             "band_sets": [{
-                 "dtype": "uint8",
-                 "bands": ["R", "G", "B"]
-             }],
-             "data_source": {
-                 "name": "rslearn.data_sources.gcp_public_data.Sentinel2",
+ ```json
+ {
+     "layers": {
+         "sentinel2": {
+             "type": "raster",
+             "band_sets": [{
+                 "dtype": "uint8",
+                 "bands": ["R", "G", "B"]
+             }],
+             "data_source": {
+                 "class_path": "rslearn.data_sources.gcp_public_data.Sentinel2",
+                 "init_args": {
                      "index_cache_dir": "cache/sentinel2/",
-                     "max_time_delta": "1d",
                      "sort_by": "cloud_cover",
                      "use_rtree_index": false
                  }
              }
-         },
-         "tile_store": {
-             "name": "file",
-             "root_dir": "tiles"
          }
      }
+ }
+ ```
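If you prefer to create the configuration programmatically, the file above is plain JSON; here is a minimal sketch using only the Python standard library (it just writes the config shown above):

```python
# Sketch: create the dataset directory and write the config.json from above.
import json
from pathlib import Path

config = {
    "layers": {
        "sentinel2": {
            "type": "raster",
            "band_sets": [{"dtype": "uint8", "bands": ["R", "G", "B"]}],
            "data_source": {
                "class_path": "rslearn.data_sources.gcp_public_data.Sentinel2",
                "init_args": {
                    "index_cache_dir": "cache/sentinel2/",
                    "sort_by": "cloud_cover",
                    "use_rtree_index": False,
                },
            },
        }
    }
}

ds_path = Path("/path/to/dataset")
ds_path.mkdir(parents=True, exist_ok=True)
(ds_path / "config.json").write_text(json.dumps(config, indent=4))
```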
 
  Here, we have initialized an empty dataset and defined a raster layer called
  `sentinel2`. Because it specifies a data source, it will be populated automatically. In
@@ -314,8 +365,10 @@ choosing the scenes with minimal cloud cover.
  Next, let's create our spatiotemporal windows. These will correspond to training
  examples.
 
- export DATASET_PATH=/path/to/dataset
- rslearn dataset add_windows --root $DATASET_PATH --group default --utm --resolution 10 --grid_size 128 --src_crs EPSG:4326 --box=-122.6901,47.2079,-121.4955,47.9403 --start 2024-06-01T00:00:00+00:00 --end 2024-08-01T00:00:00+00:00 --name seattle
+ ```
+ export DATASET_PATH=/path/to/dataset
+ rslearn dataset add_windows --root $DATASET_PATH --group default --utm --resolution 10 --grid_size 128 --src_crs EPSG:4326 --box=-122.6901,47.2079,-121.4955,47.9403 --start 2024-06-01T00:00:00+00:00 --end 2024-08-01T00:00:00+00:00 --name seattle
+ ```
 
  This creates windows along a 128x128 grid in the specified projection (i.e.,
  appropriate UTM zone for the location with 10 m/pixel resolution) covering the
@@ -327,9 +380,11 @@ We can now obtain the Sentinel-2 images by running prepare, ingest, and material
  * Ingest: retrieve those items. This step populates the `tiles` directory within the dataset.
  * Materialize: crop/mosaic the items to align with the windows. This populates the `layers` folder in each window directory.
 
- rslearn dataset prepare --root $DATASET_PATH --workers 32 --batch-size 8
- rslearn dataset ingest --root $DATASET_PATH --workers 32 --no-use-initial-job --jobs-per-process 1
- rslearn dataset materialize --root $DATASET_PATH --workers 32 --no-use-initial-job
+ ```
+ rslearn dataset prepare --root $DATASET_PATH --workers 32 --batch-size 8
+ rslearn dataset ingest --root $DATASET_PATH --workers 32 --no-use-initial-job --jobs-per-process 1
+ rslearn dataset materialize --root $DATASET_PATH --workers 32 --no-use-initial-job
+ ```
 
  For ingestion, you may need to reduce the number of workers depending on the available
  memory on your system.
@@ -337,32 +392,36 @@ memory on your system.
  You should now be able to open the GeoTIFF images. Let's find the window that
  corresponds to downtown Seattle:
 
- import shapely
- from rslearn.const import WGS84_PROJECTION
- from rslearn.dataset import Dataset
- from rslearn.utils import Projection, STGeometry
- from upath import UPath
-
- # Define longitude and latitude for downtown Seattle.
- downtown_seattle = shapely.Point(-122.333, 47.606)
-
- # Iterate over the windows and find the closest one.
- dataset = Dataset(path=UPath("/path/to/dataset"))
- best_window_name = None
- best_distance = None
- for window in dataset.load_windows(workers=32):
-     shp = window.get_geometry().to_projection(WGS84_PROJECTION).shp
-     distance = shp.distance(downtown_seattle)
-     if best_distance is None or distance < best_distance:
-         best_window_name = window.name
-         best_distance = distance
-
- print(best_window_name)
+ ```python
+ import shapely
+ from rslearn.const import WGS84_PROJECTION
+ from rslearn.dataset import Dataset
+ from rslearn.utils import Projection, STGeometry
+ from upath import UPath
+
+ # Define longitude and latitude for downtown Seattle.
+ downtown_seattle = shapely.Point(-122.333, 47.606)
+
+ # Iterate over the windows and find the closest one.
+ dataset = Dataset(path=UPath("/path/to/dataset"))
+ best_window_name = None
+ best_distance = None
+ for window in dataset.load_windows(workers=32):
+     shp = window.get_geometry().to_projection(WGS84_PROJECTION).shp
+     distance = shp.distance(downtown_seattle)
+     if best_distance is None or distance < best_distance:
+         best_window_name = window.name
+         best_distance = distance
+
+ print(best_window_name)
+ ```
 
  It should be `seattle_54912_-527360`, so let's open it in qgis (or your favorite GIS
  software):
 
- qgis $DATASET_PATH/windows/default/seattle_54912_-527360/layers/sentinel2/R_G_B/geotiff.tif
+ ```
+ qgis $DATASET_PATH/windows/default/seattle_54912_-527360/layers/sentinel2/R_G_B/geotiff.tif
+ ```
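A side note on the window name: with `--grid_size 128`, window origins appear to snap to multiples of 128 pixels in the projection's pixel grid, which is consistent with the name above (an observation from this example, not a documented guarantee):

```python
# The suffix of seattle_54912_-527360 looks like a pixel-space origin on the
# 128-pixel grid: both coordinates divide evenly by the grid size.
col, row = 54912, -527360
print(col // 128, row // 128)  # 429 -4120
print(col % 128, row % 128)    # 0 0
```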
 
 
  ### Adding Land Cover Labels
@@ -372,152 +431,166 @@ the ESA WorldCover land cover map as labels.
 
  Start by downloading the WorldCover data from https://worldcover2021.esa.int
 
- wget https://worldcover2021.esa.int/data/archive/ESA_WorldCover_10m_2021_v200_60deg_macrotile_N30W180.zip
- mkdir world_cover_tifs
- unzip ESA_WorldCover_10m_2021_v200_60deg_macrotile_N30W180.zip -d world_cover_tifs/
+ ```
+ wget https://worldcover2021.esa.int/data/archive/ESA_WorldCover_10m_2021_v200_60deg_macrotile_N30W180.zip
+ mkdir world_cover_tifs
+ unzip ESA_WorldCover_10m_2021_v200_60deg_macrotile_N30W180.zip -d world_cover_tifs/
+ ```
 
  It would require some work to write a script to re-project and crop these GeoTIFFs so
  that they align with the windows we have previously defined (and the Sentinel-2 images
  we have already ingested). We can use the LocalFiles data source to have rslearn
  automate this process. Update the dataset `config.json` with a new layer:
 
- "layers": {
-     "sentinel2": {
-         ...
-     },
-     "worldcover": {
-         "type": "raster",
-         "band_sets": [{
-             "dtype": "uint8",
-             "bands": ["B1"]
-         }],
-         "resampling_method": "nearest",
-         "data_source": {
-             "name": "rslearn.data_sources.local_files.LocalFiles",
+ ```jsonc
+ "layers": {
+     "sentinel2": {
+         # ...
+     },
+     "worldcover": {
+         "type": "raster",
+         "band_sets": [{
+             "dtype": "uint8",
+             "bands": ["B1"]
+         }],
+         "resampling_method": "nearest",
+         "data_source": {
+             "class_path": "rslearn.data_sources.local_files.LocalFiles",
+             "init_args": {
                  "src_dir": "file:///path/to/world_cover_tifs/"
              }
          }
-     },
-     ...
+     }
+ },
+ # ...
+ ```
 
  Repeat the materialize process so we populate the data for this new layer:
 
- rslearn dataset prepare --root $DATASET_PATH --workers 32 --batch-size 8
- rslearn dataset ingest --root $DATASET_PATH --workers 32 --no-use-initial-job --jobs-per-process 1
- rslearn dataset materialize --root $DATASET_PATH --workers 32 --no-use-initial-job
+ ```
+ rslearn dataset prepare --root $DATASET_PATH --workers 32 --batch-size 8
+ rslearn dataset ingest --root $DATASET_PATH --workers 32 --no-use-initial-job --jobs-per-process 1
+ rslearn dataset materialize --root $DATASET_PATH --workers 32 --no-use-initial-job
+ ```
 
  We can visualize both the GeoTIFFs together in qgis:
 
- qgis $DATASET_PATH/windows/default/seattle_54912_-527360/layers/*/*/geotiff.tif
+ ```
+ qgis $DATASET_PATH/windows/default/seattle_54912_-527360/layers/*/*/geotiff.tif
+ ```
 
 
  ### Training a Model
 
  Create a model configuration file `land_cover_model.yaml`:
 
+ ```yaml
+ model:
+   class_path: rslearn.train.lightning_module.RslearnLightningModule
+   init_args:
+     # This part defines the model architecture.
+     # Essentially we apply the SatlasPretrain Sentinel-2 backbone with a UNet decoder
+     # that terminates at a segmentation prediction head.
+     # The backbone outputs four feature maps at different scales, and the UNet uses
+     # these to compute a feature map at the input scale.
+     # Finally the segmentation head applies per-pixel softmax to compute the land
+     # cover class.
      model:
-   class_path: rslearn.train.lightning_module.RslearnLightningModule
+       class_path: rslearn.models.singletask.SingleTaskModel
        init_args:
-     # This part defines the model architecture.
-     # Essentially we apply the SatlasPretrain Sentinel-2 backbone with a UNet decoder
-     # that terminates at a segmentation prediction head.
-     # The backbone outputs four feature maps at different scales, and the UNet uses
-     # these to compute a feature map at the input scale.
-     # Finally the segmentation head applies per-pixel softmax to compute the land
-     # cover class.
-     model:
-       class_path: rslearn.models.singletask.SingleTaskModel
-       init_args:
-         encoder:
-           - class_path: rslearn.models.satlaspretrain.SatlasPretrain
-             init_args:
-               model_identifier: "Sentinel2_SwinB_SI_RGB"
-         decoder:
-           - class_path: rslearn.models.unet.UNetDecoder
-             init_args:
-               in_channels: [[4, 128], [8, 256], [16, 512], [32, 1024]]
-               # We use 101 classes because the WorldCover classes are 10, 20, 30, 40
-               # 50, 60, 70, 80, 90, 95, 100.
-               # We could process the GeoTIFFs to collapse them to 0-10 (the 11 actual
-               # classes) but the model will quickly learn that the intermediate
-               # values are never used.
-               out_channels: 101
-               conv_layers_per_resolution: 2
-           - class_path: rslearn.train.tasks.segmentation.SegmentationHead
-     # Remaining parameters in RslearnLightningModule define different aspects of the
-     # training process like initial learning rate.
-     lr: 0.0001
- data:
-   class_path: rslearn.train.data_module.RslearnDataModule
+         encoder:
+           - class_path: rslearn.models.satlaspretrain.SatlasPretrain
+             init_args:
+               model_identifier: "Sentinel2_SwinB_SI_RGB"
+         decoder:
+           - class_path: rslearn.models.unet.UNetDecoder
+             init_args:
+               in_channels: [[4, 128], [8, 256], [16, 512], [32, 1024]]
+               # We use 101 classes because the WorldCover classes are 10, 20, 30, 40
+               # 50, 60, 70, 80, 90, 95, 100.
+               # We could process the GeoTIFFs to collapse them to 0-10 (the 11 actual
+               # classes) but the model will quickly learn that the intermediate
+               # values are never used.
+               out_channels: 101
+               conv_layers_per_resolution: 2
+           - class_path: rslearn.train.tasks.segmentation.SegmentationHead
+     # Remaining parameters in RslearnLightningModule define different aspects of the
+     # training process like initial learning rate.
+     lr: 0.0001
+ data:
+   class_path: rslearn.train.data_module.RslearnDataModule
+   init_args:
+     path: ${DATASET_PATH}
+     # This defines the layers that should be read for each window.
+     # The key ("image" / "targets") is what the data will be called in the model,
+     # while the layers option specifies which layers will be read.
+     inputs:
+       image:
+         data_type: "raster"
+         layers: ["sentinel2"]
+         bands: ["R", "G", "B"]
+         passthrough: true
+       targets:
+         data_type: "raster"
+         layers: ["worldcover"]
+         bands: ["B1"]
+         is_target: true
+     task:
+       # Train for semantic segmentation.
+       # The remap option is only used when visualizing outputs during testing.
+       class_path: rslearn.train.tasks.segmentation.SegmentationTask
        init_args:
-     # Replace this with the dataset path.
-     path: /path/to/dataset/
-     # This defines the layers that should be read for each window.
-     # The key ("image" / "targets") is what the data will be called in the model,
-     # while the layers option specifies which layers will be read.
-     inputs:
-       image:
-         data_type: "raster"
-         layers: ["sentinel2"]
-         bands: ["R", "G", "B"]
-         passthrough: true
-       targets:
-         data_type: "raster"
-         layers: ["worldcover"]
-         bands: ["B1"]
-         is_target: true
-     task:
-       # Train for semantic segmentation.
-       # The remap option is only used when visualizing outputs during testing.
-       class_path: rslearn.train.tasks.segmentation.SegmentationTask
+         num_classes: 101
+         remap_values: [[0, 1], [0, 255]]
+     batch_size: 8
+     num_workers: 32
+     # These define different options for different phases/splits, like training,
+     # validation, and testing.
+     # Here we use the same transform across splits except training where we add a
+     # flipping augmentation.
+     # For now we are using the same windows for training and validation.
+     default_config:
+       transforms:
+         - class_path: rslearn.train.transforms.normalize.Normalize
            init_args:
-         num_classes: 101
-         remap_values: [[0, 1], [0, 255]]
-     batch_size: 8
-     num_workers: 32
-     # These define different options for different phases/splits, like training,
-     # validation, and testing.
-     # Here we use the same transform across splits except training where we add a
-     # flipping augmentation.
-     # For now we are using the same windows for training and validation.
-     default_config:
-       transforms:
-         - class_path: rslearn.train.transforms.normalize.Normalize
-           init_args:
-             mean: 0
-             std: 255
-     train_config:
-       transforms:
-         - class_path: rslearn.train.transforms.normalize.Normalize
-           init_args:
-             mean: 0
-             std: 255
-         - class_path: rslearn.train.transforms.flip.Flip
-           init_args:
-             image_selectors: ["image", "target/classes", "target/valid"]
-       groups: ["default"]
-     val_config:
-       groups: ["default"]
-     test_config:
-       groups: ["default"]
-     predict_config:
-       groups: ["predict"]
-       load_all_patches: true
-       skip_targets: true
-       patch_size: 512
- trainer:
-   max_epochs: 10
-   callbacks:
-     - class_path: lightning.pytorch.callbacks.ModelCheckpoint
+             mean: 0
+             std: 255
+     train_config:
+       transforms:
+         - class_path: rslearn.train.transforms.normalize.Normalize
+           init_args:
+             mean: 0
+             std: 255
+         - class_path: rslearn.train.transforms.flip.Flip
            init_args:
-         save_top_k: 1
-         save_last: true
-         monitor: val_accuracy
-         mode: max
+             image_selectors: ["image", "target/classes", "target/valid"]
+       groups: ["default"]
+     val_config:
+       groups: ["default"]
+     test_config:
+       groups: ["default"]
+     predict_config:
+       groups: ["predict"]
+       load_all_patches: true
+       skip_targets: true
+       patch_size: 512
+ trainer:
+   max_epochs: 10
+   callbacks:
+     - class_path: lightning.pytorch.callbacks.ModelCheckpoint
+       init_args:
+         save_top_k: 1
+         save_last: true
+         monitor: val_accuracy
+         mode: max
+         dirpath: ./land_cover_model_checkpoints/
+ ```
 
  Now we can train the model:
 
- rslearn model fit --config land_cover_model.yaml
+ ```
+ rslearn model fit --config land_cover_model.yaml
+ ```
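As an aside on the `out_channels: 101` setting above: the WorldCover band stores the class code itself (10, 20, ..., 95, 100), and the code is used directly as the class index, so the head needs logits for indices 0 through 100. A toy check (not rslearn code):

```python
# Toy check of the 101-channel choice: WorldCover codes serve as class indices.
worldcover_codes = [10, 20, 30, 40, 50, 60, 70, 80, 90, 95, 100]

out_channels = max(worldcover_codes) + 1  # logits must cover indices 0..100
assert out_channels == 101

# The alternative mentioned in the comments: collapse to 11 contiguous classes.
remap = {code: i for i, code in enumerate(worldcover_codes)}
print(out_channels, remap[95])  # 101 9
```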
 
 
  ### Apply the Model
@@ -528,22 +601,28 @@ windows along a grid, we just create one big window. This is because we are just
  to run the prediction over the whole window rather than use different windows as
  different training examples.
 
- rslearn dataset add_windows --root $DATASET_PATH --group predict --utm --resolution 10 --src_crs EPSG:4326 --box=-122.712,45.477,-122.621,45.549 --start 2024-06-01T00:00:00+00:00 --end 2024-08-01T00:00:00+00:00 --name portland
- rslearn dataset prepare --root $DATASET_PATH --workers 32 --batch-size 8
- rslearn dataset ingest --root $DATASET_PATH --workers 32 --no-use-initial-job --jobs-per-process 1
- rslearn dataset materialize --root $DATASET_PATH --workers 32 --no-use-initial-job
+ ```
+ rslearn dataset add_windows --root $DATASET_PATH --group predict --utm --resolution 10 --src_crs EPSG:4326 --box=-122.712,45.477,-122.621,45.549 --start 2024-06-01T00:00:00+00:00 --end 2024-08-01T00:00:00+00:00 --name portland
+ rslearn dataset prepare --root $DATASET_PATH --workers 32 --batch-size 8
+ rslearn dataset ingest --root $DATASET_PATH --workers 32 --no-use-initial-job --jobs-per-process 1
+ rslearn dataset materialize --root $DATASET_PATH --workers 32 --no-use-initial-job
+ ```
 
  We also need to add an RslearnWriter (from `rslearn.train.prediction_writer`) to the trainer
  callbacks in the model configuration file, as it will handle writing the outputs from the model to a GeoTIFF.
 
- trainer:
-   callbacks:
-     - class_path: lightning.pytorch.callbacks.ModelCheckpoint
-       ...
-     - class_path: rslearn.train.prediction_writer.RslearnWriter
-       init_args:
-         path: /path/to/dataset/
-         output_layer: output
+ ```yaml
+ trainer:
+   callbacks:
+     - class_path: lightning.pytorch.callbacks.ModelCheckpoint
+       ...
+     - class_path: rslearn.train.prediction_writer.RslearnWriter
+       init_args:
+         # We need to include this argument, but it will be overridden with the dataset
+         # path from data.init_args.path.
+         path: placeholder
+         output_layer: output
+ ```
 
  Because of our `predict_config`, when we run `model predict` it will apply the model on
  windows in the "predict" group, which is where we added the Portland window.
@@ -551,39 +630,46 @@ windows in the "predict" group, which is where we added the Portland window.
  And it will be written in a new output_layer called "output". But we have to update the
  dataset configuration so it specifies the layer:
 
-
- "layers": {
-     "sentinel2": {
-         ...
-     },
-     "worldcover": {
-         ...
-     },
-     "output": {
-         "type": "raster",
-         "band_sets": [{
-             "dtype": "uint8",
-             "bands": ["output"]
-         }]
-     }
+ ```jsonc
+ "layers": {
+     "sentinel2": {
+         # ...
      },
+     "worldcover": {
+         # ...
+     },
+     "output": {
+         "type": "raster",
+         "band_sets": [{
+             "dtype": "uint8",
+             "bands": ["output"]
+         }]
+     }
+ },
+ ```
 
  Now we can apply the model:
 
- # Find model checkpoint in lightning_logs dir.
- ls lightning_logs/*/checkpoints/last.ckpt
- rslearn model predict --config land_cover_model.yaml --ckpt_path lightning_logs/version_0/checkpoints/last.ckpt
+ ```
+ # Find the model checkpoint in the configured checkpoint dir.
+ ls land_cover_model_checkpoints/last.ckpt
+ rslearn model predict --config land_cover_model.yaml --ckpt_path land_cover_model_checkpoints/last.ckpt
+ ```
 
  And visualize the Sentinel-2 image and output in qgis:
 
- qgis $DATASET_PATH/windows/predict/portland/layers/*/*/geotiff.tif
+ ```
+ qgis $DATASET_PATH/windows/predict/portland/layers/*/*/geotiff.tif
+ ```
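The output layer is an ordinary GeoTIFF, so it can also be inspected without a GIS. A sketch using rasterio (the `output/output` path segment is an assumption based on the `layers/<layer>/<bands>/geotiff.tif` pattern seen earlier):

```python
# Sketch: read the predicted per-pixel class codes from the output GeoTIFF.
import numpy as np
import rasterio

# Hypothetical path following the layers/<layer>/<bands>/geotiff.tif pattern.
path = "/path/to/dataset/windows/predict/portland/layers/output/output/geotiff.tif"
with rasterio.open(path) as src:
    classes = src.read(1)  # single uint8 band of predicted class codes

print(classes.shape, np.unique(classes))
```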
 
 
  ### Defining Train and Validation Splits
 
  We can visualize the logged metrics using Tensorboard:
 
- tensorboard --logdir=lightning_logs/
+ ```
+ tensorboard --logdir=lightning_logs/
+ ```
 
  However, because our training and validation data are identical, the validation metrics
  are not meaningful.
@@ -597,57 +683,61 @@ We will use the second approach. The script below sets a "split" key in the opti
  dict (which is stored in each window's `metadata.json` file) to "train" or "val"
  based on the SHA-256 hash of the window name.
 
- import hashlib
- import tqdm
- from rslearn.dataset import Dataset, Window
- from upath import UPath
-
- ds_path = UPath("/path/to/dataset/")
- dataset = Dataset(ds_path)
- windows = dataset.load_windows(show_progress=True, workers=32)
- for window in tqdm.tqdm(windows):
-     if hashlib.sha256(window.name.encode()).hexdigest()[0] in ["0", "1"]:
-         split = "val"
-     else:
-         split = "train"
-     if "split" in window.options and window.options["split"] == split:
-         continue
-     window.options["split"] = split
-     window.save()
+ ```python
+ import hashlib
+ import tqdm
+ from rslearn.dataset import Dataset, Window
+ from upath import UPath
+
+ ds_path = UPath("/path/to/dataset/")
+ dataset = Dataset(ds_path)
+ windows = dataset.load_windows(show_progress=True, workers=32)
+ for window in tqdm.tqdm(windows):
+     if hashlib.sha256(window.name.encode()).hexdigest()[0] in ["0", "1"]:
+         split = "val"
+     else:
+         split = "train"
+     if "split" in window.options and window.options["split"] == split:
+         continue
+     window.options["split"] = split
+     window.save()
+ ```
 
  Now we can update the model configuration file to use these splits:
 
- default_config:
-   transforms:
-     - class_path: rslearn.train.transforms.normalize.Normalize
-       init_args:
-         mean: 0
-         std: 255
- train_config:
-   transforms:
-     - class_path: rslearn.train.transforms.normalize.Normalize
-       init_args:
-         mean: 0
-         std: 255
-     - class_path: rslearn.train.transforms.flip.Flip
-       init_args:
-         image_selectors: ["image", "target/classes", "target/valid"]
-   groups: ["default"]
-   tags:
-     split: train
- val_config:
-   groups: ["default"]
-   tags:
-     split: val
- test_config:
-   groups: ["default"]
-   tags:
-     split: val
- predict_config:
-   groups: ["predict"]
-   load_all_patches: true
-   skip_targets: true
-   patch_size: 512
+ ```yaml
+ default_config:
+   transforms:
+     - class_path: rslearn.train.transforms.normalize.Normalize
+       init_args:
+         mean: 0
+         std: 255
+ train_config:
+   transforms:
+     - class_path: rslearn.train.transforms.normalize.Normalize
+       init_args:
+         mean: 0
+         std: 255
+     - class_path: rslearn.train.transforms.flip.Flip
+       init_args:
+         image_selectors: ["image", "target/classes", "target/valid"]
+   groups: ["default"]
+   tags:
+     split: train
+ val_config:
+   groups: ["default"]
+   tags:
+     split: val
+ test_config:
+   groups: ["default"]
+   tags:
+     split: val
+ predict_config:
+   groups: ["predict"]
+   load_all_patches: true
+   skip_targets: true
+   patch_size: 512
+ ```
 
  The `tags` option that we are adding here tells rslearn to only load windows with a
  matching key and value in the window options.
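The expected split fraction follows directly from the hash: the first hex digit of SHA-256 is close to uniform over 16 values, and two of them ("0" and "1") select val, so roughly 2/16 = 12.5% of windows become validation; the 585-of-4752 count shown below is about 12.3%. A quick empirical check:

```python
# Empirical check of the SHA-256-based split fraction used above.
import hashlib

names = [f"window_{i}" for i in range(10000)]
n_val = sum(
    hashlib.sha256(name.encode()).hexdigest()[0] in ("0", "1") for name in names
)
print(n_val / len(names))  # close to 2/16 = 0.125
```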
@@ -655,28 +745,187 @@ matching key and value in the window options.
 
  Previously, when we ran `model fit`, it would show the same number of windows for
  training and validation:
 
- got 4752 examples in split train
- got 4752 examples in split val
+ ```
+ got 4752 examples in split train
+ got 4752 examples in split val
+ ```
 
  With the updates, it should show different numbers like this:
 
- got 4167 examples in split train
- got 585 examples in split val
+ ```
+ got 4167 examples in split train
+ got 585 examples in split val
+ ```
 
 
  ### Visualizing with `model test`
 
- Coming soon
+ We can visualize the ground truth labels and model predictions in the test set using
+ the `model test` command:
+
+ ```
+ mkdir ./vis
+ rslearn model test --config land_cover_model.yaml --ckpt_path land_cover_model_checkpoints/last.ckpt --model.init_args.visualize_dir=./vis/
+ ```
+
+ This will produce PNGs in the vis directory. The visualizations are produced by the
+ `Task.visualize` function, so we could customize the visualization by subclassing
+ SegmentationTask and overriding the visualize function.
+
+
+ ### Checkpoint and Logging Management
+
+ Above, we needed to configure the checkpoint directory in the model config (the
+ `dirpath` option under `lightning.pytorch.callbacks.ModelCheckpoint`), and explicitly
+ specify the checkpoint path when applying the model. Additionally, metrics are logged
+ to the local filesystem and not well organized.
+
+ We can instead let rslearn automatically manage checkpoints, along with logging to
+ Weights & Biases. To do so, we add project_name, run_name, and management_dir options
+ to the model config. The project_name corresponds to the W&B project, and the run_name
+ corresponds to the W&B run name. The management_dir is a directory for storing project
+ data; rslearn derives a per-run directory at `{management_dir}/{project_name}/{run_name}/`
+ and uses it to store checkpoints.
+
+ ```yaml
+ model:
+   # ...
+ data:
+   # ...
+ trainer:
+   # ...
+ project_name: land_cover_model
+ run_name: version_00
+ # This sets the option via the MANAGEMENT_DIR environment variable.
+ management_dir: ${MANAGEMENT_DIR}
+ ```
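With the values above, the run directory resolves as follows (simple path arithmetic per the `{management_dir}/{project_name}/{run_name}/` layout described; the exact files rslearn stores there are not shown in this diff):

```python
# Sketch: where checkpoints land under the management directory layout above.
from pathlib import Path

management_dir = Path("./project_data")  # matches the export below
run_dir = management_dir / "land_cover_model" / "version_00"
print(run_dir)  # project_data/land_cover_model/version_00
```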
+
+ Now, set the `MANAGEMENT_DIR` environment variable and run `model fit`:
+
+ ```
+ export MANAGEMENT_DIR=./project_data
+ rslearn model fit --config land_cover_model.yaml
+ ```
+
+ The training and validation loss and accuracy metric should now be logged to W&B. The
+ accuracy metric is provided by SegmentationTask, and additional metrics can be enabled
+ by passing the relevant init_args to the task, e.g. mean IoU and F1:
+
+ ```yaml
+ class_path: rslearn.train.tasks.segmentation.SegmentationTask
+ init_args:
+   num_classes: 101
+   remap_values: [[0, 1], [0, 255]]
+   enable_miou_metric: true
+   enable_f1_metric: true
+ ```
+
+ When calling `model test` and `model predict` with management_dir set, rslearn will
+ automatically load the best checkpoint from the project directory, or raise an error if
+ no checkpoint exists yet. This behavior can be overridden with the
+ `--load_checkpoint_mode` and `--load_checkpoint_required` options (see `--help` for
+ details). Logging is enabled during fit but not during test/predict; this too can be
+ overridden, using `--log_mode`.
 
 
  ### Inputting Multiple Sentinel-2 Images
 
- Coming soon
+ Currently our model inputs a single Sentinel-2 image. However, for most tasks where
+ labels are not expected to change from week to week, we find that accuracy can be
+ significantly improved by inputting multiple images, regardless of the pre-trained
+ model used. Multiple images make the model more resilient to clouds and image
+ artifacts, and allow the model to synthesize information across different views that
+ may come from different seasons or weather conditions.
+
+ We first update our dataset configuration to obtain three images by customizing the
+ query_config section. This can replace the sentinel2 layer:
+
+ ```jsonc
+ "layers": {
+     "sentinel2_multi": {
+         "type": "raster",
+         "band_sets": [{
+             "dtype": "uint8",
+             "bands": ["R", "G", "B"]
+         }],
+         "data_source": {
+             "class_path": "rslearn.data_sources.gcp_public_data.Sentinel2",
+             "init_args": {
+                 "index_cache_dir": "cache/sentinel2/",
+                 "sort_by": "cloud_cover",
+                 "use_rtree_index": false
+             },
+             "query_config": {
+                 "max_matches": 3
+             }
+         }
+     },
+     "worldcover": {
+         # ...
+     },
+     "output": {
+         # ...
+     }
+ }
+ ```
+
+ Repeat the steps from earlier to prepare, ingest, and materialize the dataset.
 
+ Now we update our model configuration file. First, we modify the model architecture to
+ be able to input an image time series. We use the SimpleTimeSeries model, which takes
+ an encoder that expects a single-image input, and applies that encoder on each image in
+ the time series. It then applies max temporal pooling to combine the per-image feature
+ maps extracted by the encoder.
 
- ### Logging to Weights & Biases
+ Image time series in rslearn are currently stored as [T*C, H, W] tensors. So we pass
+ the `image_channels` option to SimpleTimeSeries so it knows how to slice up the tensor
+ to recover the per-timestep images (see the sketch after the configuration below).
 
- Coming soon
+ ```yaml
+ model:
+   class_path: rslearn.train.lightning_module.RslearnLightningModule
+   init_args:
+     model:
+       class_path: rslearn.models.singletask.SingleTaskModel
+       init_args:
+         encoder:
+           - class_path: rslearn.models.simple_time_series.SimpleTimeSeries
+             init_args:
+               encoder:
+                 class_path: rslearn.models.satlaspretrain.SatlasPretrain
+                 init_args:
+                   model_identifier: "Sentinel2_SwinB_SI_RGB"
+               image_channels: 3
+         decoder:
+           # ...
+ ```
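The sketch promised above: a toy illustration (not rslearn's implementation) of how a [T*C, H, W] input is sliced into T images using `image_channels`, encoded per image, and combined with max temporal pooling:

```python
# Toy sketch of the SimpleTimeSeries idea: slice, encode per image, max-pool.
import torch

image_channels = 3
x = torch.rand(3 * image_channels, 64, 64)      # T=3 RGB images stacked: [T*C, H, W]
images = x.reshape(-1, image_channels, 64, 64)  # recover [T, C, H, W]

def encoder(batch: torch.Tensor) -> torch.Tensor:
    # Stand-in for the real single-image backbone; one feature map per image.
    return batch.mean(dim=1, keepdim=True)

features = encoder(images)           # [T, 1, H, W], encoder applied per image
pooled = features.max(dim=0).values  # max temporal pooling -> [1, H, W]
print(pooled.shape)                  # torch.Size([1, 64, 64])
```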
+
+ Next, we update the data module section so that the dataset loads the image time series
+ rather than a single image. The `load_all_layers` option tells the dataset to stack the
+ rasters from all of the layers specified, and also to ignore windows where any of those
+ layers are missing.
+
+ ```yaml
+ data:
+   class_path: rslearn.train.data_module.RslearnDataModule
+   init_args:
+     path: # ...
+     inputs:
+       image:
+         data_type: "raster"
+         layers: ["sentinel2_multi", "sentinel2_multi.1", "sentinel2_multi.2"]
+         bands: ["R", "G", "B"]
+         passthrough: true
+         load_all_layers: true
+       targets:
+         # ...
+ ```
+
+ Now we can train an updated model:
+
+ ```
+ rslearn model fit --config land_cover_model.yaml
+ ```
 
 
  Contact