rslearn 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. rslearn/config/dataset.py +22 -13
  2. rslearn/data_sources/__init__.py +8 -0
  3. rslearn/data_sources/aws_landsat.py +27 -18
  4. rslearn/data_sources/aws_open_data.py +41 -42
  5. rslearn/data_sources/copernicus.py +148 -2
  6. rslearn/data_sources/data_source.py +17 -10
  7. rslearn/data_sources/gcp_public_data.py +177 -100
  8. rslearn/data_sources/geotiff.py +1 -0
  9. rslearn/data_sources/google_earth_engine.py +17 -15
  10. rslearn/data_sources/local_files.py +59 -32
  11. rslearn/data_sources/openstreetmap.py +27 -23
  12. rslearn/data_sources/planet.py +10 -9
  13. rslearn/data_sources/planet_basemap.py +303 -0
  14. rslearn/data_sources/raster_source.py +23 -13
  15. rslearn/data_sources/usgs_landsat.py +56 -27
  16. rslearn/data_sources/utils.py +13 -6
  17. rslearn/data_sources/vector_source.py +1 -0
  18. rslearn/data_sources/xyz_tiles.py +8 -9
  19. rslearn/dataset/add_windows.py +1 -1
  20. rslearn/dataset/dataset.py +16 -5
  21. rslearn/dataset/manage.py +9 -4
  22. rslearn/dataset/materialize.py +26 -5
  23. rslearn/dataset/window.py +5 -0
  24. rslearn/log_utils.py +24 -0
  25. rslearn/main.py +123 -59
  26. rslearn/models/clip.py +62 -0
  27. rslearn/models/conv.py +56 -0
  28. rslearn/models/faster_rcnn.py +2 -19
  29. rslearn/models/fpn.py +1 -1
  30. rslearn/models/module_wrapper.py +43 -0
  31. rslearn/models/molmo.py +65 -0
  32. rslearn/models/multitask.py +1 -1
  33. rslearn/models/pooling_decoder.py +4 -2
  34. rslearn/models/satlaspretrain.py +4 -7
  35. rslearn/models/simple_time_series.py +61 -55
  36. rslearn/models/ssl4eo_s12.py +9 -9
  37. rslearn/models/swin.py +22 -21
  38. rslearn/models/unet.py +4 -2
  39. rslearn/models/upsample.py +35 -0
  40. rslearn/tile_stores/file.py +6 -3
  41. rslearn/tile_stores/tile_store.py +19 -7
  42. rslearn/train/callbacks/freeze_unfreeze.py +3 -3
  43. rslearn/train/data_module.py +5 -4
  44. rslearn/train/dataset.py +79 -36
  45. rslearn/train/lightning_module.py +15 -11
  46. rslearn/train/prediction_writer.py +22 -11
  47. rslearn/train/tasks/classification.py +9 -8
  48. rslearn/train/tasks/detection.py +94 -37
  49. rslearn/train/tasks/multi_task.py +1 -1
  50. rslearn/train/tasks/regression.py +8 -4
  51. rslearn/train/tasks/segmentation.py +23 -19
  52. rslearn/train/transforms/__init__.py +1 -1
  53. rslearn/train/transforms/concatenate.py +6 -2
  54. rslearn/train/transforms/crop.py +6 -2
  55. rslearn/train/transforms/flip.py +5 -1
  56. rslearn/train/transforms/normalize.py +9 -5
  57. rslearn/train/transforms/pad.py +1 -1
  58. rslearn/train/transforms/transform.py +3 -3
  59. rslearn/utils/__init__.py +4 -5
  60. rslearn/utils/array.py +2 -2
  61. rslearn/utils/feature.py +1 -1
  62. rslearn/utils/fsspec.py +70 -1
  63. rslearn/utils/geometry.py +155 -3
  64. rslearn/utils/grid_index.py +5 -5
  65. rslearn/utils/mp.py +4 -3
  66. rslearn/utils/raster_format.py +81 -73
  67. rslearn/utils/rtree_index.py +64 -17
  68. rslearn/utils/sqlite_index.py +7 -1
  69. rslearn/utils/utils.py +11 -3
  70. rslearn/utils/vector_format.py +113 -17
  71. {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/METADATA +32 -27
  72. rslearn-0.0.2.dist-info/RECORD +94 -0
  73. {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/WHEEL +1 -1
  74. rslearn/utils/mgrs.py +0 -24
  75. rslearn-0.0.1.dist-info/RECORD +0 -88
  76. {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/LICENSE +0 -0
  77. {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/entry_points.txt +0 -0
  78. {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/top_level.txt +0 -0
rslearn/train/data_module.py CHANGED
@@ -15,7 +15,7 @@ from .dataset import DataInput, ModelDataset, RetryDataset, SplitConfig
 
 def collate_fn(
     batch: list[tuple[dict[str, Any], dict[str, Any]]],
-) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
+) -> tuple:
     """Collate batch of training examples.
 
     We just make list of the inputs and another of the targets.
@@ -48,7 +48,7 @@ class RslearnDataModule(L.LightningDataModule):
         val_config: SplitConfig = SplitConfig(),
         test_config: SplitConfig = SplitConfig(),
         predict_config: SplitConfig = SplitConfig(),
-    ):
+    ) -> None:
         """Initialize a new RslearnDataModule.
 
         Args:
@@ -79,7 +79,7 @@ class RslearnDataModule(L.LightningDataModule):
             "predict": default_config.update(predict_config),
         }
 
-    def setup(self, stage: str):
+    def setup(self, stage: str) -> None:
         """Set up datasets and samplers.
 
         Args:
@@ -106,12 +106,13 @@ class RslearnDataModule(L.LightningDataModule):
 
     def _get_dataloader(self, split: str) -> DataLoader[dict[str, torch.Tensor]]:
         dataset = self.datasets[split]
+        persistent_workers = self.num_workers > 0
        kwargs = dict(
            dataset=dataset,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            collate_fn=collate_fn,
-            persistent_workers=True,
+            persistent_workers=persistent_workers,
        )
        sampler_factory = self.split_configs[split].sampler
        if sampler_factory:
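Note on the `persistent_workers` change above: PyTorch's `DataLoader` rejects `persistent_workers=True` when `num_workers=0`, so the flag is now derived from the worker count instead of being hard-coded. A minimal standalone sketch of the same pattern (the toy dataset and `make_loader` helper are illustrative, not rslearn code):

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.arange(8).float())

def make_loader(num_workers: int) -> DataLoader:
    # persistent_workers=True with num_workers=0 raises a ValueError,
    # so only enable it when worker processes are actually used.
    return DataLoader(
        dataset,
        batch_size=2,
        num_workers=num_workers,
        persistent_workers=num_workers > 0,
    )

loader = make_loader(num_workers=0)  # would fail if persistent_workers were hard-coded True
```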
rslearn/train/dataset.py CHANGED
@@ -1,5 +1,6 @@
 """Default Dataset for rslearn."""
 
+import hashlib
 import multiprocessing
 import os
 import random
@@ -47,7 +48,9 @@ class SamplerFactory:
 class RandomSamplerFactory(SamplerFactory):
     """A sampler factory for RandomSampler."""
 
-    def __init__(self, replacement: bool = False, num_samples: int | None = None):
+    def __init__(
+        self, replacement: bool = False, num_samples: int | None = None
+    ) -> None:
         """Initialize a RandomSamplerFactory.
 
         Args:
@@ -75,7 +78,9 @@ class RandomSamplerFactory(SamplerFactory):
 class WeightedRandomSamplerFactory(SamplerFactory):
     """A sampler factory for WeightedRandomSampler."""
 
-    def __init__(self, option_key: str, num_samples: int, replacement: bool = True):
+    def __init__(
+        self, option_key: str, num_samples: int, replacement: bool = True
+    ) -> None:
         """Initialize a WeightedRandomSamplerFactory.
 
         Args:
@@ -119,7 +124,7 @@ class DataInput:
         passthrough: bool = False,
         is_target: bool = False,
         dtype: DType = DType.FLOAT32,
-    ):
+    ) -> None:
         """Initialize a new DataInput.
 
         Args:
@@ -157,7 +162,7 @@ class SplitConfig:
         overlap_ratio: float | None = None,
         load_all_patches: bool | None = None,
         skip_targets: bool | None = None,
-    ):
+    ) -> None:
         """Initialize a new SplitConfig.
 
         Args:
@@ -242,7 +247,7 @@ class SplitConfig:
         return True if self.skip_targets is True else False
 
 
-def check_window(inputs: dict[str, DataInput], window: Window) -> bool:
+def check_window(inputs: dict[str, DataInput], window: Window) -> Window | None:
     """Verify that the window has the required layers based on the specified inputs.
 
     Args:
@@ -254,7 +259,7 @@ def check_window(inputs: dict[str, DataInput], window: Window) -> bool:
     """
 
     # Make sure window has all the needed layers.
-    def is_any_layer_available(data_input):
+    def is_any_layer_available(data_input: DataInput) -> bool:
        for layer_name in data_input.layers:
            completed_fname = window.path / "layers" / layer_name / "completed"
            if completed_fname.exists():
@@ -285,7 +290,7 @@ class ModelDataset(torch.utils.data.Dataset):
         inputs: dict[str, DataInput],
         task: Task,
         workers: int,
-    ):
+    ) -> None:
         """Instantiate a new ModelDataset.
 
         Args:
@@ -347,37 +352,53 @@ class ModelDataset(torch.utils.data.Dataset):
 
         # Eliminate windows that are missing either a requisite input layer, or missing
         # all target layers.
-        p = multiprocessing.Pool(workers)
-        outputs = star_imap_unordered(
-            p,
-            check_window,
-            [
-                dict(
-                    inputs=self.inputs,
-                    window=window,
-                )
-                for window in windows
-            ],
-        )
         new_windows = []
-        for window in tqdm.tqdm(
-            outputs, total=len(windows), desc="Checking available layers in windows"
-        ):
-            if window is None:
-                continue
-            new_windows.append(window)
-        p.close()
+        if workers == 0:
+            for window in windows:
+                if check_window(self.inputs, window) is None:
+                    continue
+                new_windows.append(window)
+        else:
+            p = multiprocessing.Pool(workers)
+            outputs = star_imap_unordered(
+                p,
+                check_window,
+                [
+                    dict(
+                        inputs=self.inputs,
+                        window=window,
+                    )
+                    for window in windows
+                ],
+            )
+            for window in tqdm.tqdm(
+                outputs, total=len(windows), desc="Checking available layers in windows"
+            ):
+                if window is None:
+                    continue
+                new_windows.append(window)
+            p.close()
         windows = new_windows
 
+        # Sort the windows to ensure that the dataset is consistent across GPUs.
+        # Inconsistent ordering can lead to a subset of windows being processed during
+        # "model test" / "model predict" when using multiple GPUs.
+        # We use a hash so that functionality like num_samples limit gets a random
+        # subset of windows (with respect to the hash function choice).
+        windows.sort(
+            key=lambda window: hashlib.sha256(window.name.encode()).hexdigest()
+        )
+
         # Limit windows to num_samples if requested.
         if split_config.num_samples:
-            # TODO: use hash of window names so this is deterministic and not arbitrarily ordered according to load_windows
+            # The windows are sorted by hash of window name so this distribution should
+            # be representative of the population.
            windows = windows[0 : split_config.num_samples]
 
-        self.windows = windows
+        self.windows: list = windows
 
         # If we're loading all patches, we need to include the patch details.
-        if split_config.get_load_all_patches():
+        if split_config.get_load_all_patches() and self.patch_size is not None:
            patches = []
            overlap_size = int(
                self.patch_size[0] * split_config.overlap_ratio
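The new sort above keys each window by the SHA-256 digest of its name. That makes the ordering identical on every GPU and every run, regardless of how `load_windows` happened to enumerate the windows, while still behaving like a fixed pseudo-random shuffle so the `num_samples` cutoff takes a representative subset. A standalone sketch of the idea (the window names are made up):

```python
import hashlib

window_names = ["window_paris_3", "window_nairobi_1", "window_seattle_7", "window_lima_2"]

# Sorting by a hash of the name is deterministic across machines and runs,
# but the resulting order is unrelated to insertion order, so taking a
# prefix acts like drawing a fixed random subset.
ordered = sorted(window_names, key=lambda name: hashlib.sha256(name.encode()).hexdigest())
num_samples = 2
subset = ordered[:num_samples]
print(subset)
```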
@@ -386,6 +407,8 @@ class ModelDataset(torch.utils.data.Dataset):
             )
             for window in self.windows:
                 cur_patches = []
+                if window is None:
+                    raise ValueError("Window is None in load_all_patches")
                 for col in range(
                     window.bounds[0],
                     window.bounds[2],
@@ -412,7 +435,9 @@ class ModelDataset(torch.utils.data.Dataset):
         """Returns the dataset length."""
         return len(self.windows)
 
-    def __getitem__(self, idx) -> tuple[dict[str, Any], dict[str, Any]]:
+    def __getitem__(
+        self, idx: int
+    ) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
         """Read one training example.
 
         Args:
@@ -429,7 +454,7 @@ class ModelDataset(torch.utils.data.Dataset):
             window, bounds, (patch_idx, num_patches) = window
         elif self.patch_size:
 
-            def get_patch_range(n_patch, n_window):
+            def get_patch_range(n_patch: int, n_window: int) -> list[int]:
                 if n_patch > n_window:
                     # Select arbitrary range containing the entire window.
                     # Basically arbitrarily padding the window to get to patch size.
@@ -459,7 +484,7 @@ class ModelDataset(torch.utils.data.Dataset):
             bounds = window.bounds
 
         # Read the inputs and targets.
-        def read_input(data_input: DataInput):
+        def read_input(data_input: DataInput) -> torch.Tensor:
             # First enumerate all options of individual layers to read.
             layer_options = []
             for layer_name in data_input.layers:
@@ -473,7 +498,13 @@ class ModelDataset(torch.utils.data.Dataset):
             # the options, as well as picking multiple for series inputs.
             layer = random.choice(layer_options)
             layer_dir = window.path / "layers" / layer
-            layer_config = self.dataset.layers[layer]
+
+            # The model config may reference a specific group within a layer, like
+            # "image.2" in a dataset that has a layer "image" with max_matches > 1.
+            # So we need to split off the period. Layer names should not contain
+            # period.
+            layer_ds_key = layer.split(".")[0]
+            layer_config = self.dataset.layers[layer_ds_key]
 
             if data_input.data_type == "raster":
                 assert isinstance(layer_config, RasterLayerConfig)
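The `layer.split(".")[0]` change lets a model config reference one group of a multi-match layer, e.g. `"image.2"` when the dataset layer `"image"` has `max_matches > 1`: the group directory on disk keeps the full name, but the layer configuration is looked up under the base name. A small sketch of that mapping (the layer names and config dict are hypothetical):

```python
# Hypothetical dataset layer configs keyed by base layer name.
dataset_layers = {"image": {"type": "raster"}, "label": {"type": "vector"}}

def resolve_layer_config(layer_ref: str) -> dict:
    # "image.2" refers to a specific group on disk (layers/image.2/), but its
    # configuration lives under the base layer name "image".
    base_name = layer_ref.split(".")[0]
    return dataset_layers[base_name]

assert resolve_layer_config("image.2") is dataset_layers["image"]
assert resolve_layer_config("label") is dataset_layers["label"]
```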
@@ -481,6 +512,8 @@ class ModelDataset(torch.utils.data.Dataset):
                 # See what different sets of bands we need to read to get all the
                 # configured bands.
                 needed_bands = data_input.bands
+                if needed_bands is None:
+                    raise ValueError(f"No bands specified for {layer}")
                 needed_band_indexes = {}
                 for i, band in enumerate(needed_bands):
                     needed_band_indexes[band] = i
@@ -488,6 +521,8 @@ class ModelDataset(torch.utils.data.Dataset):
                 for band_set in layer_config.band_sets:
                     needed_src_indexes = []
                     needed_dst_indexes = []
+                    if band_set.bands is None:
+                        continue
                     for i, band in enumerate(band_set.bands):
                         if band not in needed_band_indexes:
                             continue
@@ -514,12 +549,20 @@ class ModelDataset(torch.utils.data.Dataset):
                     _, final_bounds = band_set.get_final_projection_and_bounds(
                         window.projection, bounds
                     )
+                    if band_set.format is None:
+                        raise ValueError(f"No format specified for {layer}")
                     raster_format = load_raster_format(
                         RasterFormatConfig(band_set.format["name"], band_set.format)
                     )
+                    if band_set.bands is None:
+                        # Raising Error as It is unclear the intended behavior here.
+                        raise ValueError("No bands specified for band set")
                     cur_path = layer_dir / "_".join(band_set.bands)
+                    if final_bounds is None:
+                        raise ValueError("Final bounds are None")
                     src = raster_format.decode_raster(cur_path, final_bounds)
-
+                    if src is None:
+                        raise ValueError(f"Source is None for {data_input}")
                     # Resize to patch size if needed.
                     # This is for band sets that are stored at a lower resolution.
                     # Here we assume that it is a multiple.
@@ -594,7 +637,7 @@ class RetryDataset(torch.utils.data.Dataset):
 
     def __init__(
         self, dataset: torch.utils.data.Dataset, retries: int = 3, delay: float = 5
-    ):
+    ) -> None:
         """Create a new RetryDataset.
 
         Args:
@@ -606,7 +649,7 @@ class RetryDataset(torch.utils.data.Dataset):
         self.retries = retries
         self.delay = delay
 
-    def __len__(self):
+    def __len__(self) -> int:
         """Return length of the dataset."""
         return len(self.dataset)
 
rslearn/train/lightning_module.py CHANGED
@@ -49,7 +49,7 @@ class RestoreConfig:
         """Returns the state dict configured in this RestoreConfig."""
         print(f"loading state dict from {self.restore_path}")
         with self.restore_path.open("rb") as f:
-            state_dict = torch.load(f)
+            state_dict = torch.load(f, map_location="cpu")
         for k in self.selector:
             state_dict = state_dict[k]
 
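The `map_location="cpu"` argument makes the restore independent of the device the checkpoint was saved from, so a state dict written on a GPU machine can still be read on a CPU-only node; the tensors are copied onto the model's device later when the state dict is loaded. A minimal sketch (the checkpoint file name is made up):

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 2)
torch.save(model.state_dict(), "restore.ckpt")  # hypothetical checkpoint path

# Without map_location, tensors are restored onto the device they were saved
# from, which fails if that device (e.g. cuda:0) is unavailable here.
with open("restore.ckpt", "rb") as f:
    state_dict = torch.load(f, map_location="cpu")

model.load_state_dict(state_dict, strict=False)
```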
@@ -124,6 +124,7 @@ class RslearnLightningModule(L.LightningModule):
         self.plateau_min_lr = plateau_min_lr
         self.plateau_cooldown = plateau_cooldown
         self.visualize_dir = visualize_dir
+        self.restore_config = restore_config
 
         if print_parameters:
             for name, param in self.named_parameters():
@@ -132,8 +133,19 @@ class RslearnLightningModule(L.LightningModule):
         if print_model:
             print(self.model)
 
-        if restore_config:
-            state_dict = restore_config.get_state_dict()
+        self.epochs = 0
+
+        metrics = self.task.get_metrics()
+        self.val_metrics = metrics.clone(prefix="val_")
+        self.test_metrics = metrics.clone(prefix="test_")
+
+        self.schedulers: dict = {}
+
+    def on_fit_start(self) -> None:
+        """Called when the fit begins."""
+        # Only restore if doing a fresh fit.
+        if self.trainer.ckpt_path is None and self.restore_config:
+            state_dict = self.restore_config.get_state_dict()
             missing_keys, unexpected_keys = self.model.load_state_dict(
                 state_dict, strict=False
             )
@@ -142,14 +154,6 @@ class RslearnLightningModule(L.LightningModule):
                 f"warning: restore yielded missing_keys={missing_keys} and unexpected_keys={unexpected_keys}"
             )
 
-        self.epochs = 0
-
-        metrics = self.task.get_metrics()
-        self.val_metrics = metrics.clone(prefix="val_")
-        self.test_metrics = metrics.clone(prefix="test_")
-
-        self.schedulers = {}
-
     def configure_optimizers(self) -> OptimizerLRSchedulerConfig:
         """Initialize the optimizer and learning rate scheduler.
 
rslearn/train/prediction_writer.py CHANGED
@@ -8,7 +8,12 @@ from lightning.pytorch import LightningModule, Trainer
 from lightning.pytorch.callbacks import BasePredictionWriter
 from upath import UPath
 
-from rslearn.config import LayerType, RasterFormatConfig
+from rslearn.config import (
+    LayerType,
+    RasterFormatConfig,
+    RasterLayerConfig,
+    VectorLayerConfig,
+)
 from rslearn.dataset import Dataset
 from rslearn.utils.array import copy_spatial_array
 from rslearn.utils.raster_format import load_raster_format
@@ -20,17 +25,14 @@ from .lightning_module import RslearnLightningModule
 class PatchPredictionMerger:
     """Base class for merging predictions from multiple patches."""
 
-    def merge(
-        self, outputs: Sequence[Any], metadatas: Sequence[Any]
-    ) -> tuple[Sequence[Any], Sequence[Any]]:
-        """Merge the outputs and metadatas.
+    def merge(self, outputs: Sequence[Any]) -> tuple[Sequence[Any]]:
+        """Merge the outputs.
 
         Args:
             outputs: the outputs to process.
-            metadatas: the metadatas to process.
 
         Returns:
-            the merged outputs and metadatas.
+            the merged outputs.
         """
         raise NotImplementedError
 
@@ -57,6 +59,7 @@ class RslearnWriter(BasePredictionWriter):
             output_layer: which layer to write the outputs under.
             path_options: additional options for path to pass to fsspec
             selector: keys to access the desired output in the output dict if needed.
+                e.g ["key1", "key2"] gets output["key1"]["key2"]
            merger: merger to use to merge outputs from overlapped patches.
        """
        super().__init__(write_interval="batch")
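The docstring addition above spells out how `selector` is applied: the keys are followed in order, so `["key1", "key2"]` turns `output` into `output["key1"]["key2"]`. A standalone sketch of that traversal, mirroring the `for k in self.selector` loop in `write_on_batch_end` (the prediction dict is made up):

```python
from typing import Any

def apply_selector(output: dict[str, Any], selector: list[str]) -> Any:
    # Walk down the nested dict one key at a time; an empty selector
    # returns the output unchanged.
    for key in selector:
        output = output[key]
    return output

prediction = {"detect": {"segment": [[0.1, 0.9], [0.8, 0.2]]}}  # hypothetical output dict
print(apply_selector(prediction, ["detect", "segment"]))
```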
@@ -65,13 +68,16 @@ class RslearnWriter(BasePredictionWriter):
         self.path = UPath(path, **path_options)
         self.dataset = Dataset(self.path)
         self.layer_config = self.dataset.layers[self.output_layer]
-
+        # TODO: This is a bit of a hack to get the type checker to be happy.
+        self.format: Any
         if self.layer_config.layer_type == LayerType.RASTER:
+            assert isinstance(self.layer_config, RasterLayerConfig)
             band_cfg = self.layer_config.band_sets[0]
             self.format = load_raster_format(
                 RasterFormatConfig(band_cfg.format["name"], band_cfg.format)
             )
         elif self.layer_config.layer_type == LayerType.VECTOR:
+            assert isinstance(self.layer_config, VectorLayerConfig)
             self.format = load_vector_format(self.layer_config.format)
         else:
             raise ValueError(f"invalid layer type {self.layer_config.layer_type}")
@@ -81,7 +87,7 @@ class RslearnWriter(BasePredictionWriter):
         # Map from window name to pending data to write.
         # This is used when windows are split up into patches, so the data from all the
         # patches of each window need to be reconstituted.
-        self.pending_outputs = {}
+        self.pending_outputs: dict[str, Any] = {}
 
     def write_on_batch_end(
         self,
@@ -92,7 +98,7 @@ class RslearnWriter(BasePredictionWriter):
         batch: Any,
         batch_idx: int,
         dataloader_idx: int,
-    ):
+    ) -> None:
         """Write a batch of predictions into the rslearn dataset.
 
         Args:
@@ -112,6 +118,8 @@ class RslearnWriter(BasePredictionWriter):
         ]
 
         for output, metadata in zip(outputs, metadatas):
+            if not isinstance(output, dict):
+                raise ValueError(f"Unsupported output type {type(output)}")
             for k in self.selector:
                 output = output[k]
 
@@ -120,7 +128,9 @@ class RslearnWriter(BasePredictionWriter):
             window_bounds = metadata["window_bounds"]
 
             if self.layer_config.layer_type == LayerType.RASTER:
-                if window_name not in self.pending_outputs:
+                if window_name not in self.pending_outputs and isinstance(
+                    output, np.ndarray
+                ):
                     self.pending_outputs[window_name] = np.zeros(
                         (
                             output.shape[0],
@@ -167,6 +177,7 @@ class RslearnWriter(BasePredictionWriter):
         )
 
         if self.layer_config.layer_type == LayerType.RASTER:
+            assert isinstance(self.layer_config, RasterLayerConfig)
             band_dir = layer_dir / "_".join(self.layer_config.band_sets[0].bands)
             self.format.encode_raster(
                 band_dir, metadata["projection"], window_bounds, pending_output
rslearn/train/tasks/classification.py CHANGED
@@ -26,8 +26,8 @@ class ClassificationTask(BasicTask):
     def __init__(
         self,
         property_name: str,
-        classes: list[str],
-        filters: list[tuple[str, str]] | None = None,
+        classes: list,  # TODO: Should this be a list of str or int or can it be both?
+        filters: list[tuple[str, str]] = [],
         read_class_id: bool = False,
         allow_invalid: bool = False,
         skip_unknown_categories: bool = False,
@@ -37,7 +37,7 @@ class ClassificationTask(BasicTask):
         f1_metric_kwargs: dict[str, Any] = {},
         positive_class: str | None = None,
         positive_class_threshold: float = 0.5,
-        **kwargs,
+        **kwargs: Any,
     ):
         """Initialize a new ClassificationTask.
 
@@ -95,9 +95,6 @@ class ClassificationTask(BasicTask):
         else:
             self.positive_class_id = self.classes.index(self.positive_class)
 
-        if not self.filters:
-            self.filters = []
-
     def process_inputs(
         self,
         raw_inputs: dict[str, torch.Tensor | list[Feature]],
@@ -120,6 +117,8 @@ class ClassificationTask(BasicTask):
 
         data = raw_inputs["targets"]
         for feat in data:
+            if feat.properties is None:
+                continue
             for property_name, property_value in self.filters:
                 if feat.properties.get(property_name) != property_value:
                     continue
@@ -178,7 +177,7 @@ class ClassificationTask(BasicTask):
         class_idx = probs.argmax()
 
         if not self.read_class_id:
-            value = self.classes[class_idx]
+            value = self.classes[class_idx]  # type: ignore
         else:
             value = class_idx
 
@@ -192,7 +191,7 @@ class ClassificationTask(BasicTask):
                 self.property_name: value,
             },
         )
-        if self.prob_property:
+        if self.prob_property is not None and feature.properties is not None:
            feature.properties[self.prob_property] = probs.tolist()
        return [feature]
 
@@ -215,6 +214,8 @@ class ClassificationTask(BasicTask):
         image = super().visualize(input_dict, target_dict, output)["image"]
         image = Image.fromarray(image)
         draw = ImageDraw.Draw(image)
+        if target_dict is None:
+            raise ValueError("target_dict is required for visualization")
         target_class = self.classes[target_dict["class"]]
         output_class = self.classes[output.argmax()]
         text = f"Label: {target_class}\nOutput: {output_class}"