rslearn 0.0.15__py3-none-any.whl → 0.0.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. rslearn/config/__init__.py +2 -10
  2. rslearn/config/dataset.py +420 -420
  3. rslearn/data_sources/__init__.py +8 -31
  4. rslearn/data_sources/aws_landsat.py +13 -24
  5. rslearn/data_sources/aws_open_data.py +21 -46
  6. rslearn/data_sources/aws_sentinel1.py +3 -14
  7. rslearn/data_sources/climate_data_store.py +21 -40
  8. rslearn/data_sources/copernicus.py +30 -91
  9. rslearn/data_sources/data_source.py +26 -0
  10. rslearn/data_sources/earthdaily.py +13 -38
  11. rslearn/data_sources/earthdata_srtm.py +14 -32
  12. rslearn/data_sources/eurocrops.py +5 -9
  13. rslearn/data_sources/gcp_public_data.py +46 -43
  14. rslearn/data_sources/google_earth_engine.py +31 -44
  15. rslearn/data_sources/local_files.py +91 -100
  16. rslearn/data_sources/openstreetmap.py +21 -51
  17. rslearn/data_sources/planet.py +12 -30
  18. rslearn/data_sources/planet_basemap.py +4 -25
  19. rslearn/data_sources/planetary_computer.py +58 -141
  20. rslearn/data_sources/usda_cdl.py +15 -26
  21. rslearn/data_sources/usgs_landsat.py +4 -29
  22. rslearn/data_sources/utils.py +9 -0
  23. rslearn/data_sources/worldcereal.py +47 -54
  24. rslearn/data_sources/worldcover.py +16 -14
  25. rslearn/data_sources/worldpop.py +15 -18
  26. rslearn/data_sources/xyz_tiles.py +11 -30
  27. rslearn/dataset/dataset.py +6 -6
  28. rslearn/dataset/manage.py +14 -20
  29. rslearn/dataset/materialize.py +9 -45
  30. rslearn/lightning_cli.py +377 -1
  31. rslearn/main.py +3 -3
  32. rslearn/models/concatenate_features.py +93 -0
  33. rslearn/models/olmoearth_pretrain/model.py +2 -5
  34. rslearn/tile_stores/__init__.py +0 -11
  35. rslearn/train/dataset.py +4 -12
  36. rslearn/train/prediction_writer.py +16 -32
  37. rslearn/train/tasks/classification.py +2 -1
  38. rslearn/utils/fsspec.py +20 -0
  39. rslearn/utils/jsonargparse.py +79 -0
  40. rslearn/utils/raster_format.py +1 -41
  41. rslearn/utils/vector_format.py +1 -38
  42. {rslearn-0.0.15.dist-info → rslearn-0.0.17.dist-info}/METADATA +58 -25
  43. {rslearn-0.0.15.dist-info → rslearn-0.0.17.dist-info}/RECORD +48 -49
  44. rslearn/data_sources/geotiff.py +0 -1
  45. rslearn/data_sources/raster_source.py +0 -23
  46. {rslearn-0.0.15.dist-info → rslearn-0.0.17.dist-info}/WHEEL +0 -0
  47. {rslearn-0.0.15.dist-info → rslearn-0.0.17.dist-info}/entry_points.txt +0 -0
  48. {rslearn-0.0.15.dist-info → rslearn-0.0.17.dist-info}/licenses/LICENSE +0 -0
  49. {rslearn-0.0.15.dist-info → rslearn-0.0.17.dist-info}/licenses/NOTICE +0 -0
  50. {rslearn-0.0.15.dist-info → rslearn-0.0.17.dist-info}/top_level.txt +0 -0
@@ -12,10 +12,8 @@ from lightning.pytorch.callbacks import BasePredictionWriter
12
12
  from upath import UPath
13
13
 
14
14
  from rslearn.config import (
15
+ LayerConfig,
15
16
  LayerType,
16
- RasterFormatConfig,
17
- RasterLayerConfig,
18
- VectorLayerConfig,
19
17
  )
20
18
  from rslearn.dataset import Dataset, Window
21
19
  from rslearn.log_utils import get_logger
@@ -25,9 +23,8 @@ from rslearn.utils.geometry import PixelBounds
25
23
  from rslearn.utils.raster_format import (
26
24
  RasterFormat,
27
25
  adjust_projection_and_bounds_for_array,
28
- load_raster_format,
29
26
  )
30
- from rslearn.utils.vector_format import VectorFormat, load_vector_format
27
+ from rslearn.utils.vector_format import VectorFormat
31
28
 
32
29
  from .lightning_module import RslearnLightningModule
33
30
  from .tasks.task import Task
@@ -150,7 +147,7 @@ class RslearnWriter(BasePredictionWriter):
150
147
  selector: list[str] | None = None,
151
148
  merger: PatchPredictionMerger | None = None,
152
149
  output_path: str | Path | None = None,
153
- layer_config: RasterLayerConfig | VectorLayerConfig | None = None,
150
+ layer_config: LayerConfig | None = None,
154
151
  ):
155
152
  """Create a new RslearnWriter.
156
153
 
@@ -178,43 +175,31 @@ class RslearnWriter(BasePredictionWriter):
178
175
  )
179
176
 
180
177
  # Handle dataset and layer config
181
- self.layer_config: RasterLayerConfig | VectorLayerConfig
178
+ self.layer_config: LayerConfig
182
179
  if layer_config:
183
180
  self.layer_config = layer_config
184
- self.dataset = None if self.output_path else Dataset(self.path)
185
181
  else:
186
- self.dataset = Dataset(self.path)
187
- if self.output_layer not in self.dataset.layers:
182
+ dataset = Dataset(self.path)
183
+ if self.output_layer not in dataset.layers:
188
184
  raise KeyError(
189
185
  f"Output layer '{self.output_layer}' not found in dataset layers."
190
186
  )
191
- raw_layer_config = self.dataset.layers[self.output_layer]
192
- # Type narrowing to ensure compatibility
193
- if isinstance(raw_layer_config, (RasterLayerConfig | VectorLayerConfig)):
194
- self.layer_config = raw_layer_config
195
- else:
196
- raise ValueError(
197
- f"Layer config must be RasterLayerConfig or VectorLayerConfig, got {type(raw_layer_config)}"
198
- )
187
+ self.layer_config = dataset.layers[self.output_layer]
199
188
 
200
189
  self.format: RasterFormat | VectorFormat
201
- if self.layer_config.layer_type == LayerType.RASTER:
202
- assert isinstance(self.layer_config, RasterLayerConfig)
190
+ if self.layer_config.type == LayerType.RASTER:
203
191
  band_cfg = self.layer_config.band_sets[0]
204
- self.format = load_raster_format(
205
- RasterFormatConfig(band_cfg.format["name"], band_cfg.format)
206
- )
207
- elif self.layer_config.layer_type == LayerType.VECTOR:
208
- assert isinstance(self.layer_config, VectorLayerConfig)
209
- self.format = load_vector_format(self.layer_config.format)
192
+ self.format = band_cfg.instantiate_raster_format()
193
+ elif self.layer_config.type == LayerType.VECTOR:
194
+ self.format = self.layer_config.instantiate_vector_format()
210
195
  else:
211
- raise ValueError(f"invalid layer type {self.layer_config.layer_type}")
196
+ raise ValueError(f"invalid layer type {self.layer_config.type}")
212
197
 
213
198
  if merger is not None:
214
199
  self.merger = merger
215
- elif self.layer_config.layer_type == LayerType.RASTER:
200
+ elif self.layer_config.type == LayerType.RASTER:
216
201
  self.merger = RasterMerger()
217
- elif self.layer_config.layer_type == LayerType.VECTOR:
202
+ elif self.layer_config.type == LayerType.VECTOR:
218
203
  self.merger = VectorMerger()
219
204
 
220
205
  # Map from window name to pending data to write.
@@ -337,8 +322,7 @@ class RslearnWriter(BasePredictionWriter):
337
322
  logger.debug(f"Merging and writing for window {window.name}")
338
323
  merged_output = self.merger.merge(window, pending_output)
339
324
 
340
- if self.layer_config.layer_type == LayerType.RASTER:
341
- assert isinstance(self.layer_config, RasterLayerConfig)
325
+ if self.layer_config.type == LayerType.RASTER:
342
326
  raster_dir = window.get_raster_dir(
343
327
  self.output_layer, self.layer_config.band_sets[0].bands
344
328
  )
@@ -351,7 +335,7 @@ class RslearnWriter(BasePredictionWriter):
351
335
  )
352
336
  self.format.encode_raster(raster_dir, projection, bounds, merged_output)
353
337
 
354
- elif self.layer_config.layer_type == LayerType.VECTOR:
338
+ elif self.layer_config.type == LayerType.VECTOR:
355
339
  layer_dir = window.get_layer_dir(self.output_layer)
356
340
  assert isinstance(self.format, VectorFormat)
357
341
  self.format.encode_vector(layer_dir, merged_output)
@@ -26,7 +26,7 @@ class ClassificationTask(BasicTask):
26
26
  def __init__(
27
27
  self,
28
28
  property_name: str,
29
- classes: list, # TODO: Should this be a list of str or int or can it be both?
29
+ classes: list[str],
30
30
  filters: list[tuple[str, str]] = [],
31
31
  read_class_id: bool = False,
32
32
  allow_invalid: bool = False,
@@ -176,6 +176,7 @@ class ClassificationTask(BasicTask):
176
176
  # For multiclass classification or when using the default threshold
177
177
  class_idx = probs.argmax().item()
178
178
 
179
+ value: str | int
179
180
  if not self.read_class_id:
180
181
  value = self.classes[class_idx] # type: ignore
181
182
  else:
rslearn/utils/fsspec.py CHANGED
@@ -156,3 +156,23 @@ def open_rasterio_upath_writer(
156
156
  with path.open("wb") as f:
157
157
  with rasterio.open(f, "w", **kwargs) as raster:
158
158
  yield raster
159
+
160
+
161
def get_relative_suffix(base_dir: "UPath", fname: "UPath") -> str:
    """Get the suffix of fname relative to base_dir.

    Containment is checked on path-component boundaries, not raw string
    prefixes. A sibling entry that merely shares a name prefix with base_dir
    (e.g. ``/a/bc/x`` against base ``/a/b``) is rejected instead of producing a
    bogus suffix such as ``c/x``.

    Args:
        base_dir: the base directory. Only its ``.path`` string is used.
        fname: a filename within the base directory. Only its ``.path`` string
            is used.

    Returns:
        the suffix on base_dir that would yield the given filename, with one
        leading "/" removed. Returns "" when fname.path equals base_dir.path.

    Raises:
        ValueError: if fname is not located within base_dir.
    """
    base = base_dir.path
    path = fname.path
    # Slicing never raises, so compute the candidate suffix up front and then
    # validate it.
    suffix = path[len(base) :]
    # The string prefix match only counts if it ends on a component boundary.
    # That holds when:
    #   - the base is empty (root-relative),
    #   - the base already ends with "/",
    #   - the path is exactly the base, or
    #   - the next character after the base is "/".
    within_base = path.startswith(base) and (
        not base or base.endswith("/") or suffix == "" or suffix.startswith("/")
    )
    if not within_base:
        raise ValueError(
            f"filename {fname.path} must start with base directory {base_dir.path}"
        )
    if suffix.startswith("/"):
        suffix = suffix[1:]
    return suffix
@@ -1,7 +1,18 @@
1
1
  """Custom serialization for jsonargparse."""
2
2
 
3
+ from datetime import datetime
4
+ from typing import TYPE_CHECKING, Any
5
+
3
6
  import jsonargparse
4
7
  from rasterio.crs import CRS
8
+ from upath import UPath
9
+
10
+ from rslearn.config.dataset import LayerConfig
11
+
12
+ if TYPE_CHECKING:
13
+ from rslearn.data_sources.data_source import DataSourceContext
14
+
15
+ INITIALIZED = False
5
16
 
6
17
 
7
18
  def crs_serializer(v: CRS) -> str:
@@ -28,6 +39,74 @@ def crs_deserializer(v: str) -> CRS:
28
39
  return CRS.from_string(v)
29
40
 
30
41
 
42
def datetime_serializer(v: datetime) -> str:
    """Encode a datetime as a string so jsonargparse can serialize it.

    Args:
        v: the datetime to encode.

    Returns:
        the ISO 8601 text produced by ``v.isoformat()``.
    """
    encoded = v.isoformat()
    return encoded
52
+
53
+
54
def datetime_deserializer(v: str) -> datetime:
    """Decode an ISO 8601 string back into a datetime for jsonargparse.

    Args:
        v: the string form, as produced by ``datetime.isoformat()``.

    Returns:
        the datetime parsed by ``datetime.fromisoformat``.
    """
    decoded = datetime.fromisoformat(v)
    return decoded
64
+
65
+
66
def data_source_context_serializer(v: "DataSourceContext") -> dict[str, Any]:
    """Encode a DataSourceContext as a plain dict for jsonargparse.

    Args:
        v: the context to encode.

    Returns:
        a dict with two keys:
        - "ds_path": ``str(v.ds_path)``, or None when ds_path is None.
        - "layer_config": ``v.layer_config.model_dump(mode="json")``, or None
          when layer_config is None.
    """
    encoded: dict[str, Any] = {"ds_path": None, "layer_config": None}
    if v.ds_path is not None:
        encoded["ds_path"] = str(v.ds_path)
    if v.layer_config is not None:
        encoded["layer_config"] = v.layer_config.model_dump(mode="json")
    return encoded
77
+
78
+
79
def data_source_context_deserializer(v: dict[str, Any]) -> "DataSourceContext":
    """Rebuild a DataSourceContext from its jsonargparse dict form.

    Args:
        v: a dict with "ds_path" and "layer_config" keys; either value may be
            None.

    Returns:
        a DataSourceContext whose ds_path is a UPath (or None) and whose
        layer_config comes from ``LayerConfig.model_validate`` (or is None).
    """
    # Imported inside the function to avoid a cyclic dependency.
    from rslearn.data_sources.data_source import DataSourceContext

    raw_path = v["ds_path"]
    ds_path = None if raw_path is None else UPath(raw_path)

    raw_layer_config = v["layer_config"]
    layer_config = (
        None
        if raw_layer_config is None
        else LayerConfig.model_validate(raw_layer_config)
    )

    return DataSourceContext(ds_path=ds_path, layer_config=layer_config)
92
+
93
+
31
94
def init_jsonargparse() -> None:
    """Register rslearn's custom jsonargparse serializers (idempotent).

    Registers serializer/deserializer pairs for CRS, datetime and
    DataSourceContext. The module-level INITIALIZED flag is set only after all
    three registrations succeed, so later calls return immediately.
    """
    global INITIALIZED
    if INITIALIZED:
        return

    register = jsonargparse.typing.register_type
    register(CRS, crs_serializer, crs_deserializer)
    register(datetime, datetime_serializer, datetime_deserializer)

    # Imported here rather than at module level; the module only references
    # DataSourceContext under TYPE_CHECKING, presumably to avoid an import
    # cycle with rslearn.data_sources.
    from rslearn.data_sources.data_source import DataSourceContext

    register(
        DataSourceContext,
        data_source_context_serializer,
        data_source_context_deserializer,
    )

    INITIALIZED = True
@@ -2,8 +2,7 @@
2
2
 
3
3
  import hashlib
4
4
  import json
5
- from collections.abc import Callable
6
- from typing import Any, BinaryIO, TypeVar
5
+ from typing import Any, BinaryIO
7
6
 
8
7
  import affine
9
8
  import numpy as np
@@ -14,34 +13,12 @@ from rasterio.crs import CRS
14
13
  from rasterio.enums import Resampling
15
14
  from upath import UPath
16
15
 
17
- from rslearn.config import RasterFormatConfig
18
16
  from rslearn.const import TILE_SIZE
19
17
  from rslearn.log_utils import get_logger
20
18
  from rslearn.utils.fsspec import open_rasterio_upath_reader, open_rasterio_upath_writer
21
19
 
22
20
  from .geometry import PixelBounds, Projection
23
21
 
24
- _RasterFormatT = TypeVar("_RasterFormatT", bound="RasterFormat")
25
-
26
-
27
- class _RasterFormatRegistry(dict[str, type["RasterFormat"]]):
28
- """Registry for RasterFormat classes."""
29
-
30
- def register(
31
- self, name: str
32
- ) -> Callable[[type[_RasterFormatT]], type[_RasterFormatT]]:
33
- """Decorator to register a raster format class."""
34
-
35
- def decorator(cls: type[_RasterFormatT]) -> type[_RasterFormatT]:
36
- self[name] = cls
37
- return cls
38
-
39
- return decorator
40
-
41
-
42
- RasterFormats = _RasterFormatRegistry()
43
-
44
-
45
22
  logger = get_logger(__name__)
46
23
 
47
24
 
@@ -219,7 +196,6 @@ class RasterFormat:
219
196
  raise NotImplementedError
220
197
 
221
198
 
222
- @RasterFormats.register("image_tile")
223
199
  class ImageTileRasterFormat(RasterFormat):
224
200
  """A RasterFormat that stores data in image tiles corresponding to grid cells.
225
201
 
@@ -468,7 +444,6 @@ class ImageTileRasterFormat(RasterFormat):
468
444
  )
469
445
 
470
446
 
471
- @RasterFormats.register("geotiff")
472
447
  class GeotiffRasterFormat(RasterFormat):
473
448
  """A raster format that uses one big, tiled GeoTIFF with small block size."""
474
449
 
@@ -623,7 +598,6 @@ class GeotiffRasterFormat(RasterFormat):
623
598
  return GeotiffRasterFormat(**kwargs)
624
599
 
625
600
 
626
- @RasterFormats.register("single_image")
627
601
  class SingleImageRasterFormat(RasterFormat):
628
602
  """A raster format that produces a single image called image.png/jpg.
629
603
 
@@ -775,17 +749,3 @@ class SingleImageRasterFormat(RasterFormat):
775
749
  if "format" in config:
776
750
  kwargs["format"] = config["format"]
777
751
  return SingleImageRasterFormat(**kwargs)
778
-
779
-
780
- def load_raster_format(config: RasterFormatConfig) -> RasterFormat:
781
- """Loads a RasterFormat from a RasterFormatConfig.
782
-
783
- Args:
784
- config: the RasterFormatConfig configuration object specifying the
785
- RasterFormat.
786
-
787
- Returns:
788
- the loaded RasterFormat implementation
789
- """
790
- cls = RasterFormats[config.name]
791
- return cls.from_config(config.name, config.config_dict)
@@ -1,15 +1,13 @@
1
1
  """Classes for writing vector data to a UPath."""
2
2
 
3
3
  import json
4
- from collections.abc import Callable
5
4
  from enum import Enum
6
- from typing import Any, TypeVar
5
+ from typing import Any
7
6
 
8
7
  import shapely
9
8
  from rasterio.crs import CRS
10
9
  from upath import UPath
11
10
 
12
- from rslearn.config import VectorFormatConfig
13
11
  from rslearn.const import WGS84_PROJECTION
14
12
  from rslearn.log_utils import get_logger
15
13
  from rslearn.utils.fsspec import open_atomic
@@ -18,25 +16,6 @@ from .feature import Feature
18
16
  from .geometry import PixelBounds, Projection, STGeometry, safely_reproject_and_clip
19
17
 
20
18
  logger = get_logger(__name__)
21
- _VectorFormatT = TypeVar("_VectorFormatT", bound="VectorFormat")
22
-
23
-
24
- class _VectorFormatRegistry(dict[str, type["VectorFormat"]]):
25
- """Registry for VectorFormat classes."""
26
-
27
- def register(
28
- self, name: str
29
- ) -> Callable[[type[_VectorFormatT]], type[_VectorFormatT]]:
30
- """Decorator to register a vector format class."""
31
-
32
- def decorator(cls: type[_VectorFormatT]) -> type[_VectorFormatT]:
33
- self[name] = cls
34
- return cls
35
-
36
- return decorator
37
-
38
-
39
- VectorFormats = _VectorFormatRegistry()
40
19
 
41
20
 
42
21
  class VectorFormat:
@@ -85,7 +64,6 @@ class VectorFormat:
85
64
  raise NotImplementedError
86
65
 
87
66
 
88
- @VectorFormats.register("tile")
89
67
  class TileVectorFormat(VectorFormat):
90
68
  """TileVectorFormat stores data in GeoJSON files corresponding to grid cells.
91
69
 
@@ -275,7 +253,6 @@ class GeojsonCoordinateMode(Enum):
275
253
  WGS84 = "wgs84"
276
254
 
277
255
 
278
- @VectorFormats.register("geojson")
279
256
  class GeojsonVectorFormat(VectorFormat):
280
257
  """A vector format that uses one big GeoJSON."""
281
258
 
@@ -429,17 +406,3 @@ class GeojsonVectorFormat(VectorFormat):
429
406
  if "coordinate_mode" in config:
430
407
  kwargs["coordinate_mode"] = GeojsonCoordinateMode(config["coordinate_mode"])
431
408
  return GeojsonVectorFormat(**kwargs)
432
-
433
-
434
- def load_vector_format(config: VectorFormatConfig) -> VectorFormat:
435
- """Loads a VectorFormat from a VectorFormatConfig.
436
-
437
- Args:
438
- config: the VectorFormatConfig configuration object specifying the
439
- VectorFormat.
440
-
441
- Returns:
442
- the loaded VectorFormat implementation
443
- """
444
- cls = VectorFormats[config.name]
445
- return cls.from_config(config.name, config.config_dict)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: rslearn
3
- Version: 0.0.15
3
+ Version: 0.0.17
4
4
  Summary: A library for developing remote sensing datasets and models
5
5
  Author: OlmoEarth Team
6
6
  License: Apache License
@@ -343,10 +343,12 @@ directory `/path/to/dataset` and corresponding configuration file at
343
343
  "bands": ["R", "G", "B"]
344
344
  }],
345
345
  "data_source": {
346
- "name": "rslearn.data_sources.gcp_public_data.Sentinel2",
347
- "index_cache_dir": "cache/sentinel2/",
348
- "sort_by": "cloud_cover",
349
- "use_rtree_index": false
346
+ "class_path": "rslearn.data_sources.gcp_public_data.Sentinel2",
347
+ "init_args": {
348
+ "index_cache_dir": "cache/sentinel2/",
349
+ "sort_by": "cloud_cover",
350
+ "use_rtree_index": false
351
+ }
350
352
  }
351
353
  }
352
354
  }
@@ -453,8 +455,10 @@ automate this process. Update the dataset `config.json` with a new layer:
453
455
  }],
454
456
  "resampling_method": "nearest",
455
457
  "data_source": {
456
- "name": "rslearn.data_sources.local_files.LocalFiles",
457
- "src_dir": "file:///path/to/world_cover_tifs/"
458
+ "class_path": "rslearn.data_sources.local_files.LocalFiles",
459
+ "init_args": {
460
+ "src_dir": "file:///path/to/world_cover_tifs/"
461
+ }
458
462
  }
459
463
  }
460
464
  },
@@ -516,8 +520,7 @@ model:
516
520
  data:
517
521
  class_path: rslearn.train.data_module.RslearnDataModule
518
522
  init_args:
519
- # Replace this with the dataset path.
520
- path: /path/to/dataset/
523
+ path: ${DATASET_PATH}
521
524
  # This defines the layers that should be read for each window.
522
525
  # The key ("image" / "targets") is what the data will be called in the model,
523
526
  # while the layers option specifies which layers will be read.
@@ -615,7 +618,9 @@ trainer:
615
618
  ...
616
619
  - class_path: rslearn.train.prediction_writer.RslearnWriter
617
620
  init_args:
618
- path: /path/to/dataset/
621
+ # We need to include this argument, but it will be overridden with the dataset
622
+ # path from data.init_args.path.
623
+ path: placeholder
619
624
  output_layer: output
620
625
  ```
621
626
 
@@ -768,24 +773,43 @@ This will produce PNGs in the vis directory. The visualizations are produced by
768
773
  SegmentationTask and overriding the visualize function.
769
774
 
770
775
 
771
- ### Logging to Weights & Biases
776
+ ### Checkpoint and Logging Management
777
+
778
+ Above, we needed to configure the checkpoint directory in the model config (the
779
+ `dirpath` option under `lightning.pytorch.callbacks.ModelCheckpoint`), and explicitly
780
+ specify the checkpoint path when applying the model. Additionally, metrics are logged
781
+ to the local filesystem and not well organized.
772
782
 
773
- We can log to W&B by setting the logger under trainer in the model configuration file:
783
+ We can instead let rslearn automatically manage checkpoints, along with logging to
784
+ Weights & Biases. To do so, we add project_name, run_name, and management_dir options
785
+ to the model config. The project_name corresponds to the W&B project, and the run name
786
+ corresponds to the W&B name. The management_dir is a directory to store project data;
787
+ rslearn determines a per-project directory at `{management_dir}/{project_name}/{run_name}/`
788
+ and uses it to store checkpoints.
774
789
 
775
790
  ```yaml
791
+ model:
792
+ # ...
793
+ data:
794
+ # ...
776
795
  trainer:
777
796
  # ...
778
- logger:
779
- class_path: lightning.pytorch.loggers.WandbLogger
780
- init_args:
781
- project: land_cover_model
782
- name: version_00
797
+ project_name: land_cover_model
798
+ run_name: version_00
799
+ # This sets the option via the MANAGEMENT_DIR environment variable.
800
+ management_dir: ${MANAGEMENT_DIR}
783
801
  ```
784
802
 
785
- Now, runs with this model configuration should show on W&B. For `model fit` runs,
786
- the training and validation loss and accuracy metric will be logged. The accuracy
787
- metric is provided by SegmentationTask, and additional metrics can be enabled by
788
- passing the relevant init_args to the task, e.g. mean IoU and F1:
803
+ Now, set the `MANAGEMENT_DIR` environment variable and run `model fit`:
804
+
805
+ ```
806
+ export MANAGEMENT_DIR=./project_data
807
+ rslearn model fit --config land_cover_model.yaml
808
+ ```
809
+
810
+ The training and validation loss and accuracy metric should now be logged to W&B. The
811
+ accuracy metric is provided by SegmentationTask, and additional metrics can be enabled
812
+ by passing the relevant init_args to the task, e.g. mean IoU and F1:
789
813
 
790
814
  ```yaml
791
815
  class_path: rslearn.train.tasks.segmentation.SegmentationTask
@@ -796,6 +820,13 @@ passing the relevant init_args to the task, e.g. mean IoU and F1:
796
820
  enable_f1_metric: true
797
821
  ```
798
822
 
823
+ When calling `model test` and `model predict` with management_dir set, rslearn will
824
+ automatically load the best checkpoint from the project directory, or raise an error if
825
+ no checkpoint exists. This behavior can be overridden with the
826
+ `--load_checkpoint_mode` and `--load_checkpoint_required` options (see `--help` for
827
+ details). Logging is enabled during fit but not during test/predict; this can also
828
+ be overridden using `--log_mode`.
829
+
799
830
 
800
831
  ### Inputting Multiple Sentinel-2 Images
801
832
 
@@ -818,10 +849,12 @@ query_config section. This can replace the sentinel2 layer:
818
849
  "bands": ["R", "G", "B"]
819
850
  }],
820
851
  "data_source": {
821
- "name": "rslearn.data_sources.gcp_public_data.Sentinel2",
822
- "index_cache_dir": "cache/sentinel2/",
823
- "sort_by": "cloud_cover",
824
- "use_rtree_index": false,
852
+ "class_path": "rslearn.data_sources.gcp_public_data.Sentinel2",
853
+ "init_args": {
854
+ "index_cache_dir": "cache/sentinel2/",
855
+ "sort_by": "cloud_cover",
856
+ "use_rtree_index": false
857
+ },
825
858
  "query_config": {
826
859
  "max_matches": 3
827
860
  }