rslearn 0.0.11__py3-none-any.whl → 0.0.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (34)
  1. rslearn/config/dataset.py +23 -4
  2. rslearn/data_sources/planetary_computer.py +52 -0
  3. rslearn/dataset/handler_summaries.py +1 -0
  4. rslearn/dataset/manage.py +16 -2
  5. rslearn/models/anysat.py +5 -1
  6. rslearn/models/dinov3.py +6 -1
  7. rslearn/models/feature_center_crop.py +50 -0
  8. rslearn/models/olmoearth_pretrain/model.py +88 -27
  9. rslearn/models/prithvi.py +9 -1
  10. rslearn/train/lightning_module.py +0 -3
  11. rslearn/train/prediction_writer.py +25 -8
  12. rslearn/train/tasks/classification.py +2 -2
  13. rslearn/train/tasks/detection.py +5 -5
  14. rslearn/train/tasks/embedding.py +116 -0
  15. rslearn/train/tasks/per_pixel_regression.py +5 -4
  16. rslearn/train/tasks/regression.py +5 -5
  17. rslearn/train/transforms/pad.py +3 -3
  18. rslearn/utils/raster_format.py +38 -0
  19. {rslearn-0.0.11.dist-info → rslearn-0.0.13.dist-info}/METADATA +3 -2
  20. {rslearn-0.0.11.dist-info → rslearn-0.0.13.dist-info}/RECORD +25 -31
  21. rslearn-0.0.13.dist-info/licenses/NOTICE +115 -0
  22. rslearn/models/copernicusfm.py +0 -228
  23. rslearn/models/copernicusfm_src/__init__.py +0 -1
  24. rslearn/models/copernicusfm_src/aurora/area.py +0 -50
  25. rslearn/models/copernicusfm_src/aurora/fourier.py +0 -134
  26. rslearn/models/copernicusfm_src/dynamic_hypernetwork.py +0 -523
  27. rslearn/models/copernicusfm_src/flexivit/patch_embed.py +0 -260
  28. rslearn/models/copernicusfm_src/flexivit/utils.py +0 -69
  29. rslearn/models/copernicusfm_src/model_vit.py +0 -348
  30. rslearn/models/copernicusfm_src/util/pos_embed.py +0 -216
  31. {rslearn-0.0.11.dist-info → rslearn-0.0.13.dist-info}/WHEEL +0 -0
  32. {rslearn-0.0.11.dist-info → rslearn-0.0.13.dist-info}/entry_points.txt +0 -0
  33. {rslearn-0.0.11.dist-info → rslearn-0.0.13.dist-info}/licenses/LICENSE +0 -0
  34. {rslearn-0.0.11.dist-info → rslearn-0.0.13.dist-info}/top_level.txt +0 -0
rslearn/config/dataset.py CHANGED
@@ -125,7 +125,8 @@ class BandSetConfig:
         self,
         config_dict: dict[str, Any],
         dtype: DType,
-        bands: list[str],
+        bands: list[str] | None = None,
+        num_bands: int | None = None,
         format: dict[str, Any] | None = None,
         zoom_offset: int = 0,
         remap: dict[str, Any] | None = None,
@@ -137,7 +138,10 @@ class BandSetConfig:
         Args:
             config_dict: the config dict used to configure this BandSetConfig
             dtype: the pixel value type to store tiles in
-            bands: list of band names in this BandSetConfig
+            bands: list of band names in this BandSetConfig. One of bands or num_bands
+                must be set.
+            num_bands: the number of bands in this band set. The bands will be named
+                B00, B01, B02, etc.
             format: the format to store tiles in, defaults to geotiff
             zoom_offset: store images at a resolution higher or lower than the window
                 resolution. This enables keeping source data at its native resolution,
@@ -155,6 +159,14 @@ class BandSetConfig:
                 materialization when creating mosaics, to determine which parts of the
                 source images should be copied.
         """
+        if (bands is None and num_bands is None) or (
+            bands is not None and num_bands is not None
+        ):
+            raise ValueError("exactly one of bands and num_bands must be set")
+        if bands is None:
+            assert num_bands is not None
+            bands = [f"B{idx}" for idx in range(num_bands)]
+
         if class_names is not None and len(bands) != len(class_names):
             raise ValueError(
                 f"the number of class lists ({len(class_names)}) does not match the number of bands ({len(bands)})"
@@ -187,9 +199,16 @@ class BandSetConfig:
         kwargs = dict(
             config_dict=config,
             dtype=DType(config["dtype"]),
-            bands=config["bands"],
         )
-        for k in ["format", "zoom_offset", "remap", "class_names", "nodata_vals"]:
+        for k in [
+            "bands",
+            "num_bands",
+            "format",
+            "zoom_offset",
+            "remap",
+            "class_names",
+            "nodata_vals",
+        ]:
            if k in config:
                kwargs[k] = config[k]
        return BandSetConfig(**kwargs)  # type: ignore
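
The new num_bands option is shorthand for band sets whose band names do not matter; exactly one of bands or num_bands must be given, otherwise __init__ raises ValueError. A minimal sketch of the expansion the new __init__ performs (the count of 3 is illustrative):

    # num_bands expands to generated band names per the new __init__ logic.
    num_bands = 3
    bands = [f"B{idx}" for idx in range(num_bands)]  # ["B0", "B1", "B2"]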
rslearn/data_sources/planetary_computer.py CHANGED
@@ -827,3 +827,55 @@ class Sentinel1(PlanetaryComputer):
                 kwargs[k] = d[k]
 
         return Sentinel1(**kwargs)
+
+
+class Naip(PlanetaryComputer):
+    """A data source for NAIP data on Microsoft Planetary Computer.
+
+    See https://planetarycomputer.microsoft.com/dataset/naip.
+    """
+
+    COLLECTION_NAME = "naip"
+    ASSET_BANDS = {"image": ["R", "G", "B", "NIR"]}
+
+    def __init__(
+        self,
+        **kwargs: Any,
+    ):
+        """Initialize a new Naip instance.
+
+        Args:
+            band_names: list of bands to try to ingest.
+            kwargs: additional arguments to pass to PlanetaryComputer.
+        """
+        super().__init__(
+            collection_name=self.COLLECTION_NAME,
+            asset_bands=self.ASSET_BANDS,
+            **kwargs,
+        )
+
+    @staticmethod
+    def from_config(config: RasterLayerConfig, ds_path: UPath) -> "Naip":
+        """Creates a new Naip instance from a configuration dictionary."""
+        if config.data_source is None:
+            raise ValueError("config.data_source is required")
+        d = config.data_source.config_dict
+        kwargs = {}
+
+        if "timeout_seconds" in d:
+            kwargs["timeout"] = timedelta(seconds=d["timeout_seconds"])
+
+        if "cache_dir" in d:
+            kwargs["cache_dir"] = join_upath(ds_path, d["cache_dir"])
+
+        simple_optionals = [
+            "query",
+            "sort_by",
+            "sort_ascending",
+            "max_items_per_client",
+        ]
+        for k in simple_optionals:
+            if k in d:
+                kwargs[k] = d[k]
+
+        return Naip(**kwargs)
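
For reference, a hedged sketch of the optional keys that Naip.from_config reads from the layer's data_source config dict; the key names come from from_config above, while the values shown are illustrative, not defaults:

    data_source_config = {
        "timeout_seconds": 30,            # becomes timeout=timedelta(seconds=30)
        "cache_dir": "cache/naip",        # joined with the dataset path via join_upath
        "query": {},                      # passed through to PlanetaryComputer
        "sort_by": "datetime",            # illustrative value
        "sort_ascending": True,
        "max_items_per_client": 10,
    }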
rslearn/dataset/handler_summaries.py CHANGED
@@ -20,6 +20,7 @@ class LayerPrepareSummary:
     # Counts
     windows_prepared: int
     windows_skipped: int
+    windows_rejected: int
     get_items_attempts: int
 
 
rslearn/dataset/manage.py CHANGED
@@ -118,6 +118,7 @@ def prepare_dataset_windows(
                     duration_seconds=time.monotonic() - layer_start_time,
                     windows_prepared=0,
                     windows_skipped=len(windows),
+                    windows_rejected=0,
                     get_items_attempts=0,
                 )
             )
@@ -141,6 +142,7 @@ def prepare_dataset_windows(
                     duration_seconds=time.monotonic() - layer_start_time,
                     windows_prepared=0,
                     windows_skipped=len(windows),
+                    windows_rejected=0,
                     get_items_attempts=0,
                 )
             )
@@ -181,6 +183,9 @@ def prepare_dataset_windows(
            attempts_counter=attempts_counter,
        )
 
+        windows_prepared = 0
+        windows_rejected = 0
+        min_matches = data_source_cfg.query_config.min_matches
        for window, result in zip(needed_windows, results):
            layer_datas = window.load_layer_datas()
            layer_datas[layer_name] = WindowLayerData(
@@ -191,13 +196,22 @@ def prepare_dataset_windows(
            )
            window.save_layer_datas(layer_datas)
 
+            # If result is empty and min_matches > 0, window was rejected due to min_matches
+            if len(result) == 0 and min_matches > 0:
+                windows_rejected += 1
+            else:
+                windows_prepared += 1
+
+        windows_skipped = len(windows) - len(needed_windows)
+
        layer_summaries.append(
            LayerPrepareSummary(
                layer_name=layer_name,
                data_source_name=data_source_cfg.name,
                duration_seconds=time.monotonic() - layer_start_time,
-                windows_prepared=len(needed_windows),  # we assume all have succeeded
-                windows_skipped=len(windows) - len(needed_windows),
+                windows_prepared=windows_prepared,
+                windows_skipped=windows_skipped,
+                windows_rejected=windows_rejected,
                get_items_attempts=attempts_counter.value,
            )
        )
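
The prepare summary now distinguishes windows rejected by min_matches (an empty get-items result while min_matches > 0) from windows that were prepared. A small worked sketch of the bookkeeping with hypothetical counts:

    # Hypothetical run: 10 windows total, 6 needed preparation, min_matches > 0,
    # and 2 of the 6 came back with no items.
    windows_skipped = 10 - 6   # 4 windows did not need preparation
    windows_rejected = 2       # empty results while min_matches > 0
    windows_prepared = 6 - 2   # 4 windows counted as prepared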
rslearn/models/anysat.py CHANGED
@@ -1,4 +1,8 @@
-"""AnySat model."""
+"""AnySat model.
+
+This code loads the AnySat model from torch hub. See
+https://github.com/gastruc/AnySat for applicable license and copyright information.
+"""
 
 from typing import Any
 
rslearn/models/dinov3.py CHANGED
@@ -1,4 +1,9 @@
-"""DinoV3 model."""
+"""DinoV3 model.
+
+This code loads the DINOv3 model. You must obtain the model separately from Meta to use
+it. See https://github.com/facebookresearch/dinov3 for applicable license and copyright
+information.
+"""
 
 from enum import StrEnum
 from pathlib import Path
rslearn/models/feature_center_crop.py ADDED
@@ -0,0 +1,50 @@
+"""Apply center cropping on a feature map."""
+
+from typing import Any
+
+import torch
+
+
+class FeatureCenterCrop(torch.nn.Module):
+    """Apply center cropping on the input feature maps."""
+
+    def __init__(
+        self,
+        sizes: list[tuple[int, int]],
+    ) -> None:
+        """Create a new FeatureCenterCrop.
+
+        Only the center of each feature map will be retained and passed to the next
+        module.
+
+        Args:
+            sizes: a list of (height, width) tuples, with one tuple for each input
+                feature map.
+        """
+        super().__init__()
+        self.sizes = sizes
+
+    def forward(
+        self, features: list[torch.Tensor], inputs: list[dict[str, Any]]
+    ) -> list[torch.Tensor]:
+        """Apply center cropping on the feature maps.
+
+        Args:
+            features: list of feature maps at different resolutions.
+            inputs: original inputs (ignored).
+
+        Returns:
+            center cropped feature maps.
+        """
+        new_features = []
+        for i, feat in enumerate(features):
+            height, width = self.sizes[i]
+            if feat.shape[2] < height or feat.shape[3] < width:
+                raise ValueError(
+                    "feature map is smaller than the desired height and width"
+                )
+            start_h = feat.shape[2] // 2 - height // 2
+            start_w = feat.shape[3] // 2 - width // 2
+            feat = feat[:, :, start_h : start_h + height, start_w : start_w + width]
+            new_features.append(feat)
+        return new_features
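
A minimal usage sketch of the new module; the tensor shape is illustrative:

    import torch
    from rslearn.models.feature_center_crop import FeatureCenterCrop

    crop = FeatureCenterCrop(sizes=[(16, 16)])
    feat = torch.randn(2, 128, 32, 32)           # (batch, channels, height, width)
    cropped = crop([feat], inputs=[])            # inputs are ignored
    assert cropped[0].shape == (2, 128, 16, 16)  # center 16x16 region retained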
rslearn/models/olmoearth_pretrain/model.py CHANGED
@@ -9,6 +9,11 @@ from einops import rearrange
 from olmo_core.config import Config
 from olmo_core.distributed.checkpoint import load_model_and_optim_state
 from olmoearth_pretrain.data.constants import Modality
+from olmoearth_pretrain.model_loader import (
+    ModelID,
+    load_model_from_id,
+    load_model_from_path,
+)
 from olmoearth_pretrain.nn.flexihelios import Encoder, TokensAndMasks
 from olmoearth_pretrain.train.masking import MaskedOlmoEarthSample, MaskValue
 from upath import UPath
@@ -31,54 +36,115 @@ AUTOCAST_DTYPE_MAP = {
     "float32": torch.float32,
 }
 
+EMBEDDING_SIZES = {
+    ModelID.OLMOEARTH_V1_NANO: 128,
+    ModelID.OLMOEARTH_V1_TINY: 192,
+    ModelID.OLMOEARTH_V1_BASE: 768,
+    ModelID.OLMOEARTH_V1_LARGE: 1024,
+}
+
 
 class OlmoEarth(torch.nn.Module):
     """A wrapper to support the OlmoEarth model."""
 
     def __init__(
         self,
-        # TODO: we should accept model ID instead of checkpoint_path once we are closer
-        # to being ready for release.
-        checkpoint_path: str,
-        selector: list[str | int] = [],
+        patch_size: int,
+        model_id: ModelID | None = None,
+        model_path: str | None = None,
+        checkpoint_path: str | None = None,
+        selector: list[str | int] = ["encoder"],
         forward_kwargs: dict[str, Any] = {},
         random_initialization: bool = False,
         embedding_size: int | None = None,
-        patch_size: int | None = None,
         autocast_dtype: str | None = "bfloat16",
     ):
        """Create a new OlmoEarth model.
 
        Args:
-            checkpoint_path: the checkpoint directory to load. It should contain
-                config.json file as well as model_and_optim folder.
+            patch_size: token spatial patch size to use.
+            model_id: the model ID to load. One of model_id or model_path or checkpoint_path must be
+                set.
+            model_path: the path to load the model from. One of model_id or model_path or checkpoint_path must be
+                set. Same structure as the HF-hosted `model_id` models: bundle with a config.json and weights.pth.
+            checkpoint_path: the checkpoint directory to load from, if model_id or model_path is not
+                set. It should contain a distributed checkpoint with a config.json file as well as model_and_optim
+                folder.
            selector: an optional sequence of attribute names or list indices to select
-                the sub-module that should be applied on the input images.
+                the sub-module that should be applied on the input images. Defaults to
+                ["encoder"] to select only the transformer encoder.
            forward_kwargs: additional arguments to pass to forward pass besides the
                MaskedOlmoEarthSample.
            random_initialization: whether to skip loading the checkpoint so the
                weights are randomly initialized. In this case, the checkpoint is only
                used to define the model architecture.
            embedding_size: optional embedding size to report via
-                get_backbone_channels.
-            patch_size: optional patch size to report via get_backbone_channels.
+                get_backbone_channels (if model_id is not set).
            autocast_dtype: which dtype to use for autocasting, or set None to disable.
        """
+        if (
+            sum(
+                [
+                    model_id is not None,
+                    model_path is not None,
+                    checkpoint_path is not None,
+                ]
+            )
+            != 1
+        ):
+            raise ValueError(
+                "exactly one of model_id, model_path, or checkpoint_path must be set"
+            )
+
        super().__init__()
-        _checkpoint_path = UPath(checkpoint_path)
+        self.patch_size = patch_size
        self.forward_kwargs = forward_kwargs
        self.embedding_size = embedding_size
-        self.patch_size = patch_size
 
        if autocast_dtype is not None:
            self.autocast_dtype = AUTOCAST_DTYPE_MAP[autocast_dtype]
        else:
            self.autocast_dtype = None
 
+        if model_id is not None:
+            # Load from Hugging Face.
+            model = load_model_from_id(model_id, load_weights=not random_initialization)
+            if self.embedding_size is None and model_id in EMBEDDING_SIZES:
+                self.embedding_size = EMBEDDING_SIZES[model_id]
+
+        elif model_path is not None:
+            # Load from path.
+            model = load_model_from_path(
+                UPath(model_path), load_weights=not random_initialization
+            )
+
+        else:
+            # Load the distributed model checkpoint by path through Olmo Core
+            model = self._load_model_from_checkpoint(
+                UPath(checkpoint_path), random_initialization
+            )
+
+        # Select just the portion of the model that we actually want to use.
+        for part in selector:
+            if isinstance(part, str):
+                model = getattr(model, part)
+            else:
+                model = model[part]
+        self.model = model
+
+    def _load_model_from_checkpoint(
+        self, checkpoint_upath: UPath, random_initialization: bool
+    ) -> torch.nn.Module:
+        """Load the OlmoEarth pre-trained model from a distributed checkpoint folder.
+
+        The folder should contain config.json as well as the model_and_optim folder
+        that contains the distributed checkpoint. This is the format produced by
+        pre-training runs in olmoearth_pretrain.
+        """
        # Load the model config and initialize it.
        # We avoid loading the train module here because it depends on running within
        # olmo_core.
-        with (_checkpoint_path / "config.json").open() as f:
+        with (checkpoint_upath / "config.json").open() as f:
            config_dict = json.load(f)
        model_config = Config.from_dict(config_dict["model"])
 
@@ -86,22 +152,14 @@ class OlmoEarth(torch.nn.Module):
 
         # Load the checkpoint.
         if not random_initialization:
-            train_module_dir = _checkpoint_path / "model_and_optim"
+            train_module_dir = checkpoint_upath / "model_and_optim"
             if train_module_dir.exists():
                 load_model_and_optim_state(str(train_module_dir), model)
                 logger.info(f"loaded OlmoEarth encoder from {train_module_dir}")
             else:
                 logger.info(f"could not find OlmoEarth encoder at {train_module_dir}")
-        else:
-            logger.info("skipping loading OlmoEarth encoder")
 
-        # Select just the portion of the model that we actually want to use.
-        for part in selector:
-            if isinstance(part, str):
-                model = getattr(model, part)
-            else:
-                model = model[part]
-        self.model = model
+        return model
 
     def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
         """Compute feature maps from the OlmoEarth backbone.
@@ -167,13 +225,16 @@ class OlmoEarth(torch.nn.Module):
         if isinstance(self.model, Encoder):
             # Encoder has a fast_pass argument to indicate mask is not needed.
             tokens_and_masks = self.model(
-                sample, fast_pass=True, **self.forward_kwargs
+                sample,
+                fast_pass=True,
+                patch_size=self.patch_size,
+                **self.forward_kwargs,
             )["tokens_and_masks"]
         else:
             # Other models like STEncoder do not have this option supported.
-            tokens_and_masks = self.model(sample, **self.forward_kwargs)[
-                "tokens_and_masks"
-            ]
+            tokens_and_masks = self.model(
+                sample, patch_size=self.patch_size, **self.forward_kwargs
+            )["tokens_and_masks"]
 
         # Apply temporal/modality pooling so we just have one feature per patch.
         features = []
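
A hedged sketch of the two main loading paths introduced here; the patch_size value and checkpoint path are illustrative, and exactly one of model_id, model_path, or checkpoint_path may be set:

    from olmoearth_pretrain.model_loader import ModelID
    from rslearn.models.olmoearth_pretrain.model import OlmoEarth

    # Load a released model by ID; embedding_size is filled from EMBEDDING_SIZES.
    model = OlmoEarth(patch_size=8, model_id=ModelID.OLMOEARTH_V1_BASE)

    # Or load a distributed pre-training checkpoint directory instead.
    model = OlmoEarth(patch_size=8, checkpoint_path="/path/to/checkpoint")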
rslearn/models/prithvi.py CHANGED
@@ -1,4 +1,12 @@
-"""Prithvi V2."""
+"""Prithvi V2.
+
+This code is adapted from https://github.com/NASA-IMPACT/Prithvi-WxC
+
+The code is released under:
+
+MIT License
+Copyright (c) 2024 Inter Agency Implementation and Advanced Concepts
+"""
 
 import json
 import logging
rslearn/train/lightning_module.py CHANGED
@@ -94,7 +94,6 @@ class RslearnLightningModule(L.LightningModule):
         restore_config: RestoreConfig | None = None,
         print_parameters: bool = False,
         print_model: bool = False,
-        strict_loading: bool = True,
         # Deprecated options.
         lr: float = 1e-3,
         plateau: bool = False,
@@ -118,7 +117,6 @@ class RslearnLightningModule(L.LightningModule):
             print_parameters: whether to print the list of model parameters after model
                 initialization
             print_model: whether to print the model after model initialization
-            strict_loading: whether to strictly load the model parameters.
             lr: deprecated.
             plateau: deprecated.
             plateau_factor: deprecated.
@@ -132,7 +130,6 @@ class RslearnLightningModule(L.LightningModule):
         self.visualize_dir = visualize_dir
         self.metrics_file = metrics_file
         self.restore_config = restore_config
-        self.strict_loading = strict_loading
 
         self.scheduler_factory: SchedulerFactory | None = None
         if scheduler:
rslearn/train/prediction_writer.py CHANGED
@@ -22,7 +22,11 @@ from rslearn.log_utils import get_logger
 from rslearn.utils.array import copy_spatial_array
 from rslearn.utils.feature import Feature
 from rslearn.utils.geometry import PixelBounds
-from rslearn.utils.raster_format import RasterFormat, load_raster_format
+from rslearn.utils.raster_format import (
+    RasterFormat,
+    adjust_projection_and_bounds_for_array,
+    load_raster_format,
+)
 from rslearn.utils.vector_format import VectorFormat, load_vector_format
 
 from .lightning_module import RslearnLightningModule
@@ -68,15 +72,18 @@ class VectorMerger(PatchPredictionMerger):
 class RasterMerger(PatchPredictionMerger):
     """Merger for raster data that copies the rasters to the output."""
 
-    def __init__(self, padding: int | None = None):
+    def __init__(self, padding: int | None = None, downsample_factor: int = 1):
         """Create a new RasterMerger.
 
         Args:
             padding: the padding around the individual patch outputs to remove. This is
                 typically used when leveraging overlapping patches. Portions of outputs
                 at the border of the window will still be retained.
+            downsample_factor: the factor by which the rasters output by the task are
+                lower in resolution relative to the window resolution.
         """
         self.padding = padding
+        self.downsample_factor = downsample_factor
 
     def merge(
         self, window: Window, outputs: Sequence[PendingPatchOutput]
@@ -87,8 +94,8 @@ class RasterMerger(PatchPredictionMerger):
         merged_image = np.zeros(
             (
                 num_channels,
-                window.bounds[3] - window.bounds[1],
-                window.bounds[2] - window.bounds[0],
+                (window.bounds[3] - window.bounds[1]) // self.downsample_factor,
+                (window.bounds[2] - window.bounds[0]) // self.downsample_factor,
             ),
             dtype=dtype,
         )
@@ -104,7 +111,10 @@ class RasterMerger(PatchPredictionMerger):
             # If the output is not on the left or top boundary, then we should apply
             # the padding (if set).
             src = output.output
-            src_offset = (output.bounds[0], output.bounds[1])
+            src_offset = (
+                output.bounds[0] // self.downsample_factor,
+                output.bounds[1] // self.downsample_factor,
+            )
             if self.padding is not None and output.bounds[0] != window.bounds[0]:
                 src = src[:, :, self.padding :]
                 src_offset = (src_offset[0] + self.padding, src_offset[1])
@@ -116,7 +126,10 @@ class RasterMerger(PatchPredictionMerger):
                 src=src,
                 dst=merged_image,
                 src_offset=src_offset,
-                dst_offset=(window.bounds[0], window.bounds[1]),
+                dst_offset=(
+                    window.bounds[0] // self.downsample_factor,
+                    window.bounds[1] // self.downsample_factor,
+                ),
             )
 
         return merged_image
@@ -330,9 +343,13 @@ class RslearnWriter(BasePredictionWriter):
                 self.output_layer, self.layer_config.band_sets[0].bands
             )
             assert isinstance(self.format, RasterFormat)
-            self.format.encode_raster(
-                raster_dir, window.projection, window.bounds, merged_output
+
+            # In case the merged_output is at a different resolution than the window,
+            # get adjusted projection and bounds for writing it.
+            projection, bounds = adjust_projection_and_bounds_for_array(
+                window.projection, window.bounds, merged_output
             )
+            self.format.encode_raster(raster_dir, projection, bounds, merged_output)
 
         elif self.layer_config.layer_type == LayerType.VECTOR:
             layer_dir = window.get_layer_dir(self.output_layer)
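
A short hedged sketch of the new downsample handling in RasterMerger; the padding and factor values are illustrative:

    from rslearn.train.prediction_writer import RasterMerger

    # Task outputs are 4x lower resolution than the window; strip 16 pixels of
    # overlap padding (in output resolution) from each patch before merging.
    merger = RasterMerger(padding=16, downsample_factor=4)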
rslearn/train/tasks/classification.py CHANGED
@@ -49,8 +49,8 @@ class ClassificationTask(BasicTask):
                 features with matching properties.
             read_class_id: whether to read an integer class ID instead of the class
                 name.
-            allow_invalid: instead of throwing error when no regression label is found
-                at a window, simply mark the example invalid for this task
+            allow_invalid: instead of throwing error when no classification label is
+                found at a window, simply mark the example invalid for this task
             skip_unknown_categories: whether to skip examples with categories that are
                 not passed via classes, instead of throwing error
             prob_property: when predicting, write probabilities in addition to class ID
rslearn/train/tasks/detection.py CHANGED
@@ -72,11 +72,11 @@ class DetectionTask(BasicTask):
         f1_metric_kwargs: dict[str, Any] = {},
         **kwargs: Any,
     ) -> None:
-        """Initialize a new SegmentationTask.
+        """Initialize a new DetectionTask.
 
         Args:
-            property_name: the property from which to extract the class name. The class
-                is read from the first matching feature.
+            property_name: the property from which to extract the class name. Features
+                without this property name are ignored.
             classes: a list of class names.
             filters: optional list of (property_name, property_value) to only consider
                 features with matching properties.
@@ -86,8 +86,8 @@ class DetectionTask(BasicTask):
                 not passed via classes, instead of throwing error
             skip_empty_examples: whether to skip examples with zero labels.
             colors: optional colors for each class
-            box_size: force all boxes to be this size, centered at the centroid of the
-                geometry. Required for Point geometries.
+            box_size: force all boxes to be two times this size, centered at the
+                centroid of the geometry. Required for Point geometries.
             clip_boxes: whether to clip boxes to the image bounds.
             exclude_by_center: before optionally clipping boxes, exclude boxes if the
                 center is outside the image bounds.