rslearn 0.0.12__py3-none-any.whl → 0.0.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rslearn/config/dataset.py CHANGED
@@ -8,7 +8,6 @@ from typing import Any
 import numpy as np
 import numpy.typing as npt
 import pytimeparse
-import torch
 from rasterio.enums import Resampling
 
 from rslearn.utils import PixelBounds, Projection
@@ -49,15 +48,6 @@ class DType(Enum):
             return np.float32
         raise ValueError(f"unable to handle numpy dtype {self}")
 
-    def get_torch_dtype(self) -> torch.dtype:
-        """Returns pytorch dtype object corresponding to this DType."""
-        if self == DType.INT32:
-            return torch.int32
-        elif self == DType.FLOAT32:
-            return torch.float32
-        else:
-            raise ValueError(f"unable to handle torch dtype {self}")
-
 
 RESAMPLING_METHODS = {
     "nearest": Resampling.nearest,
@@ -125,7 +115,8 @@ class BandSetConfig:
         self,
         config_dict: dict[str, Any],
         dtype: DType,
-        bands: list[str],
+        bands: list[str] | None = None,
+        num_bands: int | None = None,
         format: dict[str, Any] | None = None,
         zoom_offset: int = 0,
         remap: dict[str, Any] | None = None,
@@ -137,7 +128,10 @@ class BandSetConfig:
         Args:
             config_dict: the config dict used to configure this BandSetConfig
             dtype: the pixel value type to store tiles in
-            bands: list of band names in this BandSetConfig
+            bands: list of band names in this BandSetConfig. One of bands or num_bands
+                must be set.
+            num_bands: the number of bands in this band set. The bands will be named
+                B00, B01, B02, etc.
             format: the format to store tiles in, defaults to geotiff
             zoom_offset: store images at a resolution higher or lower than the window
                 resolution. This enables keeping source data at its native resolution,
@@ -155,6 +149,14 @@ class BandSetConfig:
                 materialization when creating mosaics, to determine which parts of the
                 source images should be copied.
         """
+        if (bands is None and num_bands is None) or (
+            bands is not None and num_bands is not None
+        ):
+            raise ValueError("exactly one of bands and num_bands must be set")
+        if bands is None:
+            assert num_bands is not None
+            bands = [f"B{idx}" for idx in range(num_bands)]
+
         if class_names is not None and len(bands) != len(class_names):
             raise ValueError(
                 f"the number of class lists ({len(class_names)}) does not match the number of bands ({len(bands)})"
@@ -187,9 +189,16 @@ class BandSetConfig:
         kwargs = dict(
             config_dict=config,
             dtype=DType(config["dtype"]),
-            bands=config["bands"],
         )
-        for k in ["format", "zoom_offset", "remap", "class_names", "nodata_vals"]:
+        for k in [
+            "bands",
+            "num_bands",
+            "format",
+            "zoom_offset",
+            "remap",
+            "class_names",
+            "nodata_vals",
+        ]:
             if k in config:
                 kwargs[k] = config[k]
         return BandSetConfig(**kwargs)  # type: ignore
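
A minimal sketch (not part of the diff) of how the new num_bands argument might be used in place of an explicit band list, based on the constructor shown above; the empty config_dict is a placeholder.

# Hypothetical example: declare a band set by count; band names are then
# generated from the index as in __init__ above. Passing both (or neither)
# of bands/num_bands raises the ValueError added in this version.
band_set = BandSetConfig(
    config_dict={},
    dtype=DType.FLOAT32,
    num_bands=3,
)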
@@ -827,3 +827,55 @@ class Sentinel1(PlanetaryComputer):
             kwargs[k] = d[k]
 
         return Sentinel1(**kwargs)
+
+
+class Naip(PlanetaryComputer):
+    """A data source for NAIP data on Microsoft Planetary Computer.
+
+    See https://planetarycomputer.microsoft.com/dataset/naip.
+    """
+
+    COLLECTION_NAME = "naip"
+    ASSET_BANDS = {"image": ["R", "G", "B", "NIR"]}
+
+    def __init__(
+        self,
+        **kwargs: Any,
+    ):
+        """Initialize a new Naip instance.
+
+        Args:
+            band_names: list of bands to try to ingest.
+            kwargs: additional arguments to pass to PlanetaryComputer.
+        """
+        super().__init__(
+            collection_name=self.COLLECTION_NAME,
+            asset_bands=self.ASSET_BANDS,
+            **kwargs,
+        )
+
+    @staticmethod
+    def from_config(config: RasterLayerConfig, ds_path: UPath) -> "Naip":
+        """Creates a new Naip instance from a configuration dictionary."""
+        if config.data_source is None:
+            raise ValueError("config.data_source is required")
+        d = config.data_source.config_dict
+        kwargs = {}
+
+        if "timeout_seconds" in d:
+            kwargs["timeout"] = timedelta(seconds=d["timeout_seconds"])
+
+        if "cache_dir" in d:
+            kwargs["cache_dir"] = join_upath(ds_path, d["cache_dir"])
+
+        simple_optionals = [
+            "query",
+            "sort_by",
+            "sort_ascending",
+            "max_items_per_client",
+        ]
+        for k in simple_optionals:
+            if k in d:
+                kwargs[k] = d[k]
+
+        return Naip(**kwargs)
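
A hedged usage sketch (not from the diff): constructing the new Naip source directly, assuming PlanetaryComputer accepts the timeout keyword argument that Naip.from_config forwards above.

from datetime import timedelta

# Hypothetical example; the keyword name is taken from the kwargs that
# Naip.from_config builds, not from a documented PlanetaryComputer signature.
naip = Naip(timeout=timedelta(seconds=30))
# NAIP imagery is exposed as a single "image" asset with bands R, G, B, NIR.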
@@ -20,6 +20,7 @@ class LayerPrepareSummary:
     # Counts
     windows_prepared: int
     windows_skipped: int
+    windows_rejected: int
    get_items_attempts: int
 
 
rslearn/dataset/manage.py CHANGED
@@ -118,6 +118,7 @@ def prepare_dataset_windows(
                 duration_seconds=time.monotonic() - layer_start_time,
                 windows_prepared=0,
                 windows_skipped=len(windows),
+                windows_rejected=0,
                 get_items_attempts=0,
             )
         )
@@ -141,6 +142,7 @@
                 duration_seconds=time.monotonic() - layer_start_time,
                 windows_prepared=0,
                 windows_skipped=len(windows),
+                windows_rejected=0,
                 get_items_attempts=0,
             )
         )
@@ -181,6 +183,9 @@
         attempts_counter=attempts_counter,
     )
 
+    windows_prepared = 0
+    windows_rejected = 0
+    min_matches = data_source_cfg.query_config.min_matches
     for window, result in zip(needed_windows, results):
         layer_datas = window.load_layer_datas()
         layer_datas[layer_name] = WindowLayerData(
@@ -191,13 +196,22 @@
         )
         window.save_layer_datas(layer_datas)
 
+        # If result is empty and min_matches > 0, window was rejected due to min_matches
+        if len(result) == 0 and min_matches > 0:
+            windows_rejected += 1
+        else:
+            windows_prepared += 1
+
+    windows_skipped = len(windows) - len(needed_windows)
+
     layer_summaries.append(
         LayerPrepareSummary(
             layer_name=layer_name,
             data_source_name=data_source_cfg.name,
             duration_seconds=time.monotonic() - layer_start_time,
-            windows_prepared=len(needed_windows),  # we assume all have succeeded
-            windows_skipped=len(windows) - len(needed_windows),
+            windows_prepared=windows_prepared,
+            windows_skipped=windows_skipped,
+            windows_rejected=windows_rejected,
             get_items_attempts=attempts_counter.value,
         )
     )
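
An illustrative sketch of how the new counters classify windows, mirroring the logic added above; the results list here is made up.

# Hypothetical data: one window matched nothing, two matched items.
results = [[], ["item_a"], ["item_b", "item_c"]]
min_matches = 1

windows_rejected = sum(1 for r in results if len(r) == 0 and min_matches > 0)
windows_prepared = len(results) - windows_rejected
assert (windows_rejected, windows_prepared) == (1, 2)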
@@ -0,0 +1,67 @@
+"""LightningCLI for rslearn."""
+
+import sys
+
+from lightning.pytorch.cli import LightningArgumentParser, LightningCLI
+
+from rslearn.arg_parser import RslearnArgumentParser
+from rslearn.train.data_module import RslearnDataModule
+from rslearn.train.lightning_module import RslearnLightningModule
+
+
+class RslearnLightningCLI(LightningCLI):
+    """LightningCLI that links data.tasks to model.tasks and supports environment variables."""
+
+    def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
+        """Link data.tasks to model.tasks.
+
+        Args:
+            parser: the argument parser
+        """
+        # Link data.tasks to model.tasks
+        parser.link_arguments(
+            "data.init_args.task", "model.init_args.task", apply_on="instantiate"
+        )
+
+    def before_instantiate_classes(self) -> None:
+        """Called before Lightning class initialization.
+
+        Sets the dataset path for any configured RslearnPredictionWriter callbacks.
+        """
+        subcommand = self.config.subcommand
+        c = self.config[subcommand]
+
+        # If there is a RslearnPredictionWriter, set its path.
+        prediction_writer_callback = None
+        if "callbacks" in c.trainer:
+            for existing_callback in c.trainer.callbacks:
+                if (
+                    existing_callback.class_path
+                    == "rslearn.train.prediction_writer.RslearnWriter"
+                ):
+                    prediction_writer_callback = existing_callback
+        if prediction_writer_callback:
+            prediction_writer_callback.init_args.path = c.data.init_args.path
+
+        # Disable the sampler replacement, since the rslearn data module will set the
+        # sampler as needed.
+        c.trainer.use_distributed_sampler = False
+
+        # For predict, make sure that return_predictions is False.
+        # Otherwise all the predictions would be stored in memory which can lead to
+        # high memory consumption.
+        if subcommand == "predict":
+            c.return_predictions = False
+
+
+def model_handler() -> None:
+    """Handler for any rslearn model X commands."""
+    RslearnLightningCLI(
+        model_class=RslearnLightningModule,
+        datamodule_class=RslearnDataModule,
+        args=sys.argv[2:],
+        subclass_mode_model=True,
+        subclass_mode_data=True,
+        save_config_kwargs={"overwrite": True},
+        parser_class=RslearnArgumentParser,
+    )
rslearn/main.py CHANGED
@@ -10,11 +10,9 @@ from datetime import UTC, datetime, timedelta
 from typing import Any, TypeVar
 
 import tqdm
-from lightning.pytorch.cli import LightningArgumentParser, LightningCLI
 from rasterio.crs import CRS
 from upath import UPath
 
-from rslearn.arg_parser import RslearnArgumentParser
 from rslearn.config import LayerConfig
 from rslearn.const import WGS84_EPSG
 from rslearn.data_sources import Item, data_source_from_config
@@ -38,8 +36,6 @@ from rslearn.dataset.manage import (
 )
 from rslearn.log_utils import get_logger
 from rslearn.tile_stores import get_tile_store_with_layer
-from rslearn.train.data_module import RslearnDataModule
-from rslearn.train.lightning_module import RslearnLightningModule
 from rslearn.utils import Projection, STGeometry
 
 logger = get_logger(__name__)
@@ -831,85 +827,35 @@ def dataset_build_index() -> None:
     index.save_index(ds_path)
 
 
-class RslearnLightningCLI(LightningCLI):
-    """LightningCLI that links data.tasks to model.tasks and supports environment variables."""
-
-    def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:
-        """Link data.tasks to model.tasks.
-
-        Args:
-            parser: the argument parser
-        """
-        # Link data.tasks to model.tasks
-        parser.link_arguments(
-            "data.init_args.task", "model.init_args.task", apply_on="instantiate"
-        )
-
-    def before_instantiate_classes(self) -> None:
-        """Called before Lightning class initialization.
-
-        Sets the dataset path for any configured RslearnPredictionWriter callbacks.
-        """
-        subcommand = self.config.subcommand
-        c = self.config[subcommand]
-
-        # If there is a RslearnPredictionWriter, set its path.
-        prediction_writer_callback = None
-        if "callbacks" in c.trainer:
-            for existing_callback in c.trainer.callbacks:
-                if (
-                    existing_callback.class_path
-                    == "rslearn.train.prediction_writer.RslearnWriter"
-                ):
-                    prediction_writer_callback = existing_callback
-        if prediction_writer_callback:
-            prediction_writer_callback.init_args.path = c.data.init_args.path
-
-        # Disable the sampler replacement, since the rslearn data module will set the
-        # sampler as needed.
-        c.trainer.use_distributed_sampler = False
-
-        # For predict, make sure that return_predictions is False.
-        # Otherwise all the predictions would be stored in memory which can lead to
-        # high memory consumption.
-        if subcommand == "predict":
-            c.return_predictions = False
-
-
-def model_handler() -> None:
-    """Handler for any rslearn model X commands."""
-    RslearnLightningCLI(
-        model_class=RslearnLightningModule,
-        datamodule_class=RslearnDataModule,
-        args=sys.argv[2:],
-        subclass_mode_model=True,
-        subclass_mode_data=True,
-        save_config_kwargs={"overwrite": True},
-        parser_class=RslearnArgumentParser,
-    )
-
-
 @register_handler("model", "fit")
 def model_fit() -> None:
     """Handler for rslearn model fit."""
+    from .lightning_cli import model_handler
+
     model_handler()
 
 
 @register_handler("model", "validate")
 def model_validate() -> None:
     """Handler for rslearn model validate."""
+    from .lightning_cli import model_handler
+
     model_handler()
 
 
 @register_handler("model", "test")
 def model_test() -> None:
     """Handler for rslearn model test."""
+    from .lightning_cli import model_handler
+
     model_handler()
 
 
 @register_handler("model", "predict")
 def model_predict() -> None:
     """Handler for rslearn model predict."""
+    from .lightning_cli import model_handler
+
     model_handler()
 
 
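The handlers now import model_handler lazily, so dataset commands no longer pay the cost of importing Lightning. A rough sketch of what the wrapper ends up doing for a fit command, assuming the new module lives at rslearn.lightning_cli (inferred from the relative import above), that an rslearn console script populates sys.argv, and with a made-up config filename.

import sys

# Hypothetical invocation; in practice the rslearn CLI dispatches to
# model_handler(), which hands sys.argv[2:] ("fit --config ...") to
# RslearnLightningCLI.
sys.argv = ["rslearn", "model", "fit", "--config", "experiment.yaml"]

from rslearn.lightning_cli import model_handler

model_handler()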
@@ -40,6 +40,7 @@ EMBEDDING_SIZES = {
     ModelID.OLMOEARTH_V1_NANO: 128,
     ModelID.OLMOEARTH_V1_TINY: 192,
     ModelID.OLMOEARTH_V1_BASE: 768,
+    ModelID.OLMOEARTH_V1_LARGE: 1024,
 }
 
 
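
For reference, a trivial sketch of looking up the newly registered size; the module that defines ModelID and EMBEDDING_SIZES is not named in this diff, so the import is left out.

# ModelID and EMBEDDING_SIZES come from the (unnamed) module this hunk edits.
assert EMBEDDING_SIZES[ModelID.OLMOEARTH_V1_LARGE] == 1024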