rslearn 0.0.1__py3-none-any.whl → 0.0.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. rslearn/config/dataset.py +22 -13
  2. rslearn/data_sources/__init__.py +8 -0
  3. rslearn/data_sources/aws_landsat.py +27 -18
  4. rslearn/data_sources/aws_open_data.py +41 -42
  5. rslearn/data_sources/copernicus.py +148 -2
  6. rslearn/data_sources/data_source.py +17 -10
  7. rslearn/data_sources/gcp_public_data.py +177 -100
  8. rslearn/data_sources/geotiff.py +1 -0
  9. rslearn/data_sources/google_earth_engine.py +17 -15
  10. rslearn/data_sources/local_files.py +59 -32
  11. rslearn/data_sources/openstreetmap.py +27 -23
  12. rslearn/data_sources/planet.py +10 -9
  13. rslearn/data_sources/planet_basemap.py +303 -0
  14. rslearn/data_sources/raster_source.py +23 -13
  15. rslearn/data_sources/usgs_landsat.py +56 -27
  16. rslearn/data_sources/utils.py +13 -6
  17. rslearn/data_sources/vector_source.py +1 -0
  18. rslearn/data_sources/xyz_tiles.py +8 -9
  19. rslearn/dataset/add_windows.py +1 -1
  20. rslearn/dataset/dataset.py +16 -5
  21. rslearn/dataset/manage.py +9 -4
  22. rslearn/dataset/materialize.py +26 -5
  23. rslearn/dataset/window.py +5 -0
  24. rslearn/log_utils.py +24 -0
  25. rslearn/main.py +123 -59
  26. rslearn/models/clip.py +62 -0
  27. rslearn/models/conv.py +56 -0
  28. rslearn/models/faster_rcnn.py +2 -19
  29. rslearn/models/fpn.py +1 -1
  30. rslearn/models/module_wrapper.py +43 -0
  31. rslearn/models/molmo.py +65 -0
  32. rslearn/models/multitask.py +1 -1
  33. rslearn/models/pooling_decoder.py +4 -2
  34. rslearn/models/satlaspretrain.py +4 -7
  35. rslearn/models/simple_time_series.py +61 -55
  36. rslearn/models/ssl4eo_s12.py +9 -9
  37. rslearn/models/swin.py +22 -21
  38. rslearn/models/unet.py +4 -2
  39. rslearn/models/upsample.py +35 -0
  40. rslearn/tile_stores/file.py +6 -3
  41. rslearn/tile_stores/tile_store.py +19 -7
  42. rslearn/train/callbacks/freeze_unfreeze.py +3 -3
  43. rslearn/train/data_module.py +5 -4
  44. rslearn/train/dataset.py +79 -36
  45. rslearn/train/lightning_module.py +15 -11
  46. rslearn/train/prediction_writer.py +22 -11
  47. rslearn/train/tasks/classification.py +9 -8
  48. rslearn/train/tasks/detection.py +94 -37
  49. rslearn/train/tasks/multi_task.py +1 -1
  50. rslearn/train/tasks/regression.py +8 -4
  51. rslearn/train/tasks/segmentation.py +23 -19
  52. rslearn/train/transforms/__init__.py +1 -1
  53. rslearn/train/transforms/concatenate.py +6 -2
  54. rslearn/train/transforms/crop.py +6 -2
  55. rslearn/train/transforms/flip.py +5 -1
  56. rslearn/train/transforms/normalize.py +9 -5
  57. rslearn/train/transforms/pad.py +1 -1
  58. rslearn/train/transforms/transform.py +3 -3
  59. rslearn/utils/__init__.py +4 -5
  60. rslearn/utils/array.py +2 -2
  61. rslearn/utils/feature.py +1 -1
  62. rslearn/utils/fsspec.py +70 -1
  63. rslearn/utils/geometry.py +155 -3
  64. rslearn/utils/grid_index.py +5 -5
  65. rslearn/utils/mp.py +4 -3
  66. rslearn/utils/raster_format.py +81 -73
  67. rslearn/utils/rtree_index.py +64 -17
  68. rslearn/utils/sqlite_index.py +7 -1
  69. rslearn/utils/utils.py +11 -3
  70. rslearn/utils/vector_format.py +113 -17
  71. {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/METADATA +32 -27
  72. rslearn-0.0.2.dist-info/RECORD +94 -0
  73. {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/WHEEL +1 -1
  74. rslearn/utils/mgrs.py +0 -24
  75. rslearn-0.0.1.dist-info/RECORD +0 -88
  76. {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/LICENSE +0 -0
  77. {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/entry_points.txt +0 -0
  78. {rslearn-0.0.1.dist-info → rslearn-0.0.2.dist-info}/top_level.txt +0 -0
rslearn/models/molmo.py ADDED
@@ -0,0 +1,65 @@
+"""Molmo model."""
+
+from typing import Any
+
+import torch
+from transformers import AutoModelForCausalLM, AutoProcessor
+
+
+class Molmo(torch.nn.Module):
+    """Molmo image encoder."""
+
+    def __init__(
+        self,
+        model_name: str,
+    ):
+        """Instantiate a new Molmo instance.
+
+        Args:
+            model_name: the model name like "allenai/Molmo-7B-D-0924".
+        """
+        super().__init__()
+
+        self.processor = AutoProcessor.from_pretrained(
+            model_name,
+            trust_remote_code=True,
+            torch_dtype="auto",
+            device_map="cpu",
+        )
+        model = AutoModelForCausalLM.from_pretrained(
+            model_name,
+            trust_remote_code=True,
+            torch_dtype="auto",
+            device_map="cpu",
+        )
+        self.encoder = model.model.vision_backbone
+
+    def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
+        """Compute outputs from the backbone.
+
+        Inputs:
+            inputs: input dicts that must include "image" key containing the image to
+                process. The images should have values 0-255.
+
+        Returns:
+            list of feature maps. Molmo produces features at one scale, so the list
+                contains a single Bx24x24x2048 tensor.
+        """
+        device = inputs[0]["image"].device
+        molmo_inputs_list = []
+        # Process each one so we can isolate just the full image without any crops.
+        for inp in inputs:
+            image = inp["image"].cpu().numpy().transpose(1, 2, 0)
+            processed = self.processor.process(
+                images=[image],
+                text="",
+            )
+            molmo_inputs_list.append(processed["images"][0])
+        molmo_inputs: torch.Tensor = torch.stack(molmo_inputs_list, dim=0).unsqueeze(1)
+
+        image_features, _ = self.encoder.encode_image(molmo_inputs.to(device))
+
+        # 576x2048 -> 24x24x2048
+        return [
+            image_features[:, 0, :, :].reshape(-1, 24, 24, 2048).permute(0, 3, 1, 2)
+        ]
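
For orientation, a hedged usage sketch of the new encoder (not part of the release; the model name follows the docstring example and the random tensor stands in for a real 0-255 image). Note that after the final permute the returned map is channels-first, (B, 2048, 24, 24), even though the docstring describes it as Bx24x24x2048:

import torch

from rslearn.models.molmo import Molmo

# Sketch only: instantiating this downloads the full Molmo checkpoint.
encoder = Molmo(model_name="allenai/Molmo-7B-D-0924")
# Each input dict carries a CxHxW image tensor with values in 0-255.
inputs = [{"image": torch.randint(0, 256, (3, 512, 512), dtype=torch.uint8)}]
features = encoder(inputs)
print(features[0].shape)  # torch.Size([1, 2048, 24, 24])
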
rslearn/models/multitask.py CHANGED
@@ -44,7 +44,7 @@ class MultiTaskModel(torch.nn.Module):
             tuple (outputs, loss_dict) from the last module.
         """
         features = self.encoder(inputs)
-        outputs = [{} for _ in inputs]
+        outputs: list[dict[str, Any]] = [{} for _ in inputs]
         losses = {}
         for name, decoder in self.decoders.items():
             cur = features
rslearn/models/pooling_decoder.py CHANGED
@@ -21,7 +21,7 @@ class PoolingDecoder(torch.nn.Module):
         num_fc_layers: int = 0,
         conv_channels: int = 128,
         fc_channels: int = 512,
-    ):
+    ) -> None:
         """Initialize a PoolingDecoder.

         Args:
@@ -57,7 +57,9 @@ class PoolingDecoder(torch.nn.Module):

         self.output_layer = torch.nn.Linear(prev_channels, out_channels)

-    def forward(self, features: list[torch.Tensor], inputs: list[dict[str, Any]]):
+    def forward(
+        self, features: list[torch.Tensor], inputs: list[dict[str, Any]]
+    ) -> torch.Tensor:
         """Compute flat output vector from multi-scale feature map.

         Args:
rslearn/models/satlaspretrain.py CHANGED
@@ -13,7 +13,7 @@ class SatlasPretrain(torch.nn.Module):
         self,
         model_identifier: str,
         fpn: bool = False,
-    ):
+    ) -> None:
         """Instantiate a new SatlasPretrain instance.

         Args:
@@ -25,7 +25,7 @@ class SatlasPretrain(torch.nn.Module):
         super().__init__()
         weights_manager = satlaspretrain_models.Weights()
         self.model = weights_manager.get_pretrained_model(
-            model_identifier=model_identifier, fpn=fpn
+            model_identifier=model_identifier, fpn=fpn, device="cpu"
         )

         if "SwinB" in model_identifier:
@@ -50,20 +50,17 @@ class SatlasPretrain(torch.nn.Module):
             [32, 2048],
         ]

-    def forward(
-        self, inputs: list[dict[str, Any]], targets: list[dict[str, Any]] = None
-    ):
+    def forward(self, inputs: list[dict[str, Any]]) -> list[torch.Tensor]:
         """Compute feature maps from the SatlasPretrain backbone.

         Inputs:
             inputs: input dicts that must include "image" key containing the image to
                 process.
-            targets: target dicts that are ignored
         """
         images = torch.stack([inp["image"] for inp in inputs], dim=0)
         return self.model(images)

-    def get_backbone_channels(self):
+    def get_backbone_channels(self) -> list:
         """Returns the output channels of this model when used as a backbone.

         The output channels is a list of (downsample_factor, depth) that corresponds
rslearn/models/simple_time_series.py CHANGED
@@ -24,7 +24,7 @@ class SimpleTimeSeries(torch.nn.Module):
         op: str = "max",
         groups: list[list[int]] | None = None,
         num_layers: int | None = None,
-    ):
+    ) -> None:
         """Create a new SimpleTimeSeries.

         Args:
@@ -55,63 +55,69 @@ class SimpleTimeSeries(torch.nn.Module):
         else:
             self.num_groups = 1

-        if self.op == "convrnn":
-            rnn_kernel_size = 3
-            self.rnn_layers = []
-            for _, count in out_channels:
-                cur_layer = [
-                    torch.nn.Sequential(
-                        torch.nn.Conv2d(
-                            2 * count, count, rnn_kernel_size, padding="same"
-                        ),
-                        torch.nn.ReLU(inplace=True),
-                    )
-                ]
-                for _ in range(num_layers - 1):
-                    cur_layer.append(
+        if self.op in ["convrnn", "conv3d", "conv1d"]:
+            if num_layers is None:
+                raise ValueError(f"num_layers must be specified for {self.op} op")
+
+            if self.op == "convrnn":
+                rnn_kernel_size = 3
+                rnn_layers = []
+                for _, count in out_channels:
+                    cur_layer = [
                         torch.nn.Sequential(
                             torch.nn.Conv2d(
-                                count, count, rnn_kernel_size, padding="same"
+                                2 * count, count, rnn_kernel_size, padding="same"
                             ),
                             torch.nn.ReLU(inplace=True),
                         )
-                    )
-                cur_layer = torch.nn.Sequential(*cur_layer)
-                self.rnn_layers.append(cur_layer)
-            self.rnn_layers = torch.nn.ModuleList(self.rnn_layers)
-
-        elif self.op == "conv3d":
-            self.conv3d_layers = []
-            for _, count in out_channels:
-                cur_layer = [
-                    torch.nn.Sequential(
-                        torch.nn.Conv3d(count, count, 3, padding=1, stride=(2, 1, 1)),
-                        torch.nn.ReLU(inplace=True),
-                    )
-                    for _ in range(num_layers)
-                ]
-                cur_layer = torch.nn.Sequential(*cur_layer)
-                self.conv3d_layers.append(cur_layer)
-            self.conv3d_layers = torch.nn.ModuleList(self.conv3d_layers)
-
-        elif self.op == "conv1d":
-            self.conv1d_layers = []
-            for _, count in out_channels:
-                cur_layer = [
-                    torch.nn.Sequential(
-                        torch.nn.Conv1d(count, count, 3, padding=1, stride=2),
-                        torch.nn.ReLU(inplace=True),
-                    )
-                    for _ in range(num_layers)
-                ]
-                cur_layer = torch.nn.Sequential(*cur_layer)
-                self.conv1d_layers.append(cur_layer)
-            self.conv1d_layers = torch.nn.ModuleList(self.conv1d_layers)
+                    ]
+                    for _ in range(num_layers - 1):
+                        cur_layer.append(
+                            torch.nn.Sequential(
+                                torch.nn.Conv2d(
+                                    count, count, rnn_kernel_size, padding="same"
+                                ),
+                                torch.nn.ReLU(inplace=True),
+                            )
+                        )
+                    cur_layer = torch.nn.Sequential(*cur_layer)
+                    rnn_layers.append(cur_layer)
+                self.rnn_layers = torch.nn.ModuleList(rnn_layers)
+
+            elif self.op == "conv3d":
+                conv3d_layers = []
+                for _, count in out_channels:
+                    cur_layer = [
+                        torch.nn.Sequential(
+                            torch.nn.Conv3d(
+                                count, count, 3, padding=1, stride=(2, 1, 1)
+                            ),
+                            torch.nn.ReLU(inplace=True),
+                        )
+                        for _ in range(num_layers)
+                    ]
+                    cur_layer = torch.nn.Sequential(*cur_layer)
+                    conv3d_layers.append(cur_layer)
+                self.conv3d_layers = torch.nn.ModuleList(conv3d_layers)
+
+            elif self.op == "conv1d":
+                conv1d_layers = []
+                for _, count in out_channels:
+                    cur_layer = [
+                        torch.nn.Sequential(
+                            torch.nn.Conv1d(count, count, 3, padding=1, stride=2),
+                            torch.nn.ReLU(inplace=True),
+                        )
+                        for _ in range(num_layers)
+                    ]
+                    cur_layer = torch.nn.Sequential(*cur_layer)
+                    conv1d_layers.append(cur_layer)
+                self.conv1d_layers = torch.nn.ModuleList(conv1d_layers)

         else:
             assert self.op in ["max", "mean"]

-    def get_backbone_channels(self):
+    def get_backbone_channels(self) -> list:
         """Returns the output channels of this model when used as a backbone.

         The output channels is a list of (downsample_factor, depth) that corresponds
@@ -129,14 +135,14 @@ class SimpleTimeSeries(torch.nn.Module):
         return out_channels

     def forward(
-        self, inputs: list[dict[str, Any]], targets: list[dict[str, Any]] = None
-    ):
+        self,
+        inputs: list[dict[str, Any]],
+    ) -> list[torch.Tensor]:
         """Compute outputs from the backbone.

         Inputs:
             inputs: input dicts that must include "image" key containing the image time
                 series to process (with images concatenated on the channel dimension).
-            targets: target dicts that are ignored unless
         """
         # First get features of each image.
         # To do so, we need to split up each grouped image into its component images (which have had their channels stacked).
@@ -171,13 +177,13 @@ class SimpleTimeSeries(torch.nn.Module):
         for feature_idx in range(len(all_features)):
             aggregated_features = []
             for group in groups:
-                group_features = []
+                group_features_list = []
                 for image_idx in group:
-                    group_features.append(
+                    group_features_list.append(
                         all_features[feature_idx][:, image_idx, :, :, :]
                     )
                 # Resulting group features are (depth, batch, C, height, width).
-                group_features = torch.stack(group_features, dim=0)
+                group_features = torch.stack(group_features_list, dim=0)

                 if self.op == "max":
                     group_features = torch.amax(group_features, dim=0)
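
The constructor refactor above nests the convrnn/conv3d/conv1d branches under one guard that fails fast when num_layers is missing, instead of crashing later on num_layers - 1. A reduced, self-contained illustration of that pattern (a stand-in, not the rslearn class, whose full constructor signature is not shown in this diff):

import torch

class TemporalAggregator(torch.nn.Module):
    """Stand-in mirroring the SimpleTimeSeries num_layers check."""

    def __init__(self, op: str = "max", num_layers: int | None = None) -> None:
        super().__init__()
        if op in ["convrnn", "conv3d", "conv1d"]:
            # Fail fast, as in 0.0.2, rather than raising a TypeError later.
            if num_layers is None:
                raise ValueError(f"num_layers must be specified for {op} op")
        self.op = op

TemporalAggregator(op="max")  # fine: num_layers is unused
TemporalAggregator(op="conv3d", num_layers=2)  # fine
try:
    TemporalAggregator(op="conv3d")
except ValueError as err:
    print(err)  # num_layers must be specified for conv3d op
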
rslearn/models/ssl4eo_s12.py CHANGED
@@ -14,7 +14,7 @@ class Ssl4eoS12(torch.nn.Module):
         backbone_ckpt_path: str,
         arch: str = "resnet50",
         output_layers: list[int] = [0, 1, 2, 3],
-    ):
+    ) -> None:
         """Instantiate a new Swin instance.

         Args:
@@ -51,7 +51,7 @@ class Ssl4eoS12(torch.nn.Module):
             f"warning: got missing_keys={missing_keys}, unexpected_keys={unexpected_keys} when loading SSL4EO-S12 state dict"
         )

-    def get_backbone_channels(self):
+    def get_backbone_channels(self) -> list[tuple[int, int]]:
         """Returns the output channels of this model when used as a backbone.

         The output channels is a list of (downsample_factor, depth) that corresponds
@@ -65,16 +65,17 @@
         """
         if self.arch == "resnet50":
             all_out_channels = [
-                [4, 256],
-                [8, 512],
-                [16, 1024],
-                [32, 2048],
+                (4, 256),
+                (8, 512),
+                (16, 1024),
+                (32, 2048),
             ]
         return [all_out_channels[idx] for idx in self.output_layers]

     def forward(
-        self, inputs: list[dict[str, Any]], targets: list[dict[str, Any]] = None
-    ):
+        self,
+        inputs: list[dict[str, Any]],
+    ) -> list[torch.Tensor]:
         """Compute outputs from the backbone.

         If output_layers is set, then the outputs are multi-scale feature maps;
@@ -84,7 +85,6 @@
         Inputs:
             inputs: input dicts that must include "image" key containing the image to
                 process.
-            targets: target dicts that are ignored unless
         """
         x = torch.stack([inp["image"] for inp in inputs], dim=0)
         x = self.model.conv1(x)
rslearn/models/swin.py CHANGED
@@ -28,7 +28,7 @@ class Swin(torch.nn.Module):
         input_channels: int = 3,
         output_layers: list[int] | None = None,
         num_outputs: int = 1000,
-    ):
+    ) -> None:
         """Instantiate a new Swin instance.

         Args:
@@ -89,7 +89,7 @@
         if num_outputs != self.model.head.out_features:
             self.model.head = torch.nn.Linear(self.model.head.in_features, num_outputs)

-    def get_backbone_channels(self):
+    def get_backbone_channels(self) -> list[tuple[int, int]]:
         """Returns the output channels of this model when used as a backbone.

         The output channels is a list of (downsample_factor, depth) that corresponds
@@ -105,31 +105,33 @@

         if self.arch in ["swin_b", "swin_v2_b"]:
             all_out_channels = [
-                [4, 128],
-                [4, 128],
-                [8, 256],
-                [8, 256],
-                [16, 512],
-                [16, 512],
-                [32, 1024],
-                [32, 1024],
+                (4, 128),
+                (4, 128),
+                (4, 128),
+                (8, 256),
+                (8, 256),
+                (16, 512),
+                (16, 512),
+                (32, 1024),
+                (32, 1024),
             ]
         elif self.arch in ["swin_s", "swin_v2_s", "swin_t", "swin_v2_t"]:
             all_out_channels = [
-                [4, 96],
-                [4, 96],
-                [8, 192],
-                [8, 192],
-                [16, 384],
-                [16, 384],
-                [32, 768],
-                [32, 768],
+                (4, 96),
+                (4, 96),
+                (8, 192),
+                (8, 192),
+                (16, 384),
+                (16, 384),
+                (32, 768),
+                (32, 768),
             ]
         return [all_out_channels[idx] for idx in self.output_layers]

     def forward(
-        self, inputs: list[dict[str, Any]], targets: list[dict[str, Any]] = None
-    ):
+        self,
+        inputs: list[dict[str, Any]],
+    ) -> list[torch.Tensor]:
         """Compute outputs from the backbone.

         If output_layers is set, then the outputs are multi-scale feature maps;
@@ -139,7 +141,6 @@
         Inputs:
             inputs: input dicts that must include "image" key containing the image to
                 process.
-            targets: target dicts that are ignored unless
         """
         images = torch.stack([inp["image"] for inp in inputs], dim=0)

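The get_backbone_channels change in swin.py does two things: each entry becomes a tuple, matching the new list[tuple[int, int]] annotation, and the swin_b table gains a third (4, 128) entry. A standalone sketch of how callers unpack the spec (values copied from the diff):

# (downsample_factor, depth) pairs for swin_b / swin_v2_b in 0.0.2.
swin_b_channels: list[tuple[int, int]] = [
    (4, 128), (4, 128), (4, 128),
    (8, 256), (8, 256),
    (16, 512), (16, 512),
    (32, 1024), (32, 1024),
]
for downsample_factor, depth in swin_b_channels:
    print(f"1/{downsample_factor} resolution -> {depth} channels")
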
rslearn/models/unet.py CHANGED
@@ -19,7 +19,7 @@ class UNetDecoder(torch.nn.Module):
         out_channels: int,
         conv_layers_per_resolution: int = 1,
         kernel_size: int = 3,
-    ):
+    ) -> None:
         """Initialize a UNetDecoder.

         Args:
@@ -110,7 +110,9 @@
             layers.append(torch.nn.Sequential(*cur_layers))
         self.layers = torch.nn.ModuleList(layers)

-    def forward(self, in_features: list[torch.Tensor], inputs: list[dict[str, Any]]):
+    def forward(
+        self, in_features: list[torch.Tensor], inputs: list[dict[str, Any]]
+    ) -> torch.Tensor:
         """Compute output from multi-scale feature map.

         Args:
rslearn/models/upsample.py ADDED
@@ -0,0 +1,35 @@
+"""An upsampling layer."""
+
+import torch
+
+
+class Upsample(torch.nn.Module):
+    """Upsamples each input feature map by the same factor."""
+
+    def __init__(
+        self,
+        scale_factor: int,
+        mode: str = "bilinear",
+    ):
+        """Initialize an Upsample.
+
+        Args:
+            scale_factor: the upsampling factor, e.g. 2 to double the size.
+            mode: "nearest" or "bilinear".
+        """
+        super().__init__()
+        self.layer = torch.nn.Upsample(scale_factor=scale_factor, mode=mode)
+
+    def forward(
+        self, features: list[torch.Tensor], inputs: list[torch.Tensor]
+    ) -> list[torch.Tensor]:
+        """Compute flat output vector from multi-scale feature map.
+
+        Args:
+            features: list of feature maps at different resolutions.
+            inputs: original inputs (ignored).
+
+        Returns:
+            flat feature vector
+        """
+        return [self.layer(feat_map) for feat_map in features]
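
A brief usage sketch of the new layer (not part of the release): one shared torch.nn.Upsample is applied to every feature map, and the second forward argument is ignored:

import torch

from rslearn.models.upsample import Upsample

up = Upsample(scale_factor=2, mode="bilinear")
features = [torch.randn(1, 64, 16, 16), torch.randn(1, 128, 8, 8)]
outputs = up(features, inputs=[])  # inputs are unused
print([tuple(f.shape) for f in outputs])  # [(1, 64, 32, 32), (1, 128, 16, 16)]
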
@@ -11,7 +11,6 @@ from rslearn.utils import Feature, PixelBounds, Projection
11
11
  from rslearn.utils.fsspec import open_atomic
12
12
  from rslearn.utils.raster_format import (
13
13
  GeotiffRasterFormat,
14
- RasterFormat,
15
14
  load_raster_format,
16
15
  )
17
16
  from rslearn.utils.vector_format import (
@@ -30,7 +29,7 @@ class FileTileStoreLayer(TileStoreLayer):
30
29
  self,
31
30
  path: UPath,
32
31
  projection: Projection | None = None,
33
- raster_format: RasterFormat = GeotiffRasterFormat(),
32
+ raster_format: GeotiffRasterFormat = GeotiffRasterFormat(),
34
33
  vector_format: VectorFormat = GeojsonVectorFormat(),
35
34
  ):
36
35
  """Creates a new FileTileStoreLayer.
@@ -70,6 +69,8 @@ class FileTileStoreLayer(TileStoreLayer):
70
69
  bounds: the bounds of the raster
71
70
  array: the raster data
72
71
  """
72
+ if self.projection is None:
73
+ raise ValueError("need a projection to encode and write raster data")
73
74
  self.raster_format.encode_raster(self.path, self.projection, bounds, array)
74
75
 
75
76
  def get_raster_bounds(self) -> PixelBounds:
@@ -93,6 +94,8 @@ class FileTileStoreLayer(TileStoreLayer):
93
94
  Args:
94
95
  data: the vector data
95
96
  """
97
+ if self.projection is None:
98
+ raise ValueError("need a projection to encode and write vector data")
96
99
  self.vector_format.encode_vector(self.path, self.projection, data)
97
100
 
98
101
  def get_metadata(self) -> LayerMetadata:
@@ -125,7 +128,7 @@ class FileTileStore(TileStore):
125
128
  def __init__(
126
129
  self,
127
130
  path: UPath,
128
- raster_format: RasterFormat = GeotiffRasterFormat(),
131
+ raster_format: GeotiffRasterFormat = GeotiffRasterFormat(),
129
132
  vector_format: VectorFormat = GeojsonVectorFormat(),
130
133
  ):
131
134
  """Initialize a new FileTileStore.
rslearn/tile_stores/tile_store.py CHANGED
@@ -1,5 +1,6 @@
 """Base class for tile stores."""

+from abc import ABC, abstractmethod
 from datetime import datetime
 from typing import Any

@@ -52,12 +53,13 @@ class LayerMetadata:
         )


-class TileStoreLayer:
+class TileStoreLayer(ABC):
     """An abstract class for a layer in a tile store.

     The layer can store one or more raster and vector datas.
     """

+    @abstractmethod
     def read_raster(self, bounds: PixelBounds) -> npt.NDArray[Any] | None:
         """Read raster data from the store.

@@ -67,8 +69,9 @@
         Returns:
             the raster data
         """
-        raise NotImplementedError
+        pass

+    @abstractmethod
     def write_raster(self, bounds: PixelBounds, array: npt.NDArray[Any]) -> None:
         """Write raster data to the store.

@@ -76,8 +79,9 @@
             bounds: the bounds of the raster
             array: the raster data
         """
-        raise NotImplementedError
+        pass

+    @abstractmethod
     def read_vector(self, bounds: PixelBounds) -> list[Feature]:
         """Read vector data from the store.

@@ -87,20 +91,23 @@
         Returns:
             the vector data
         """
-        raise NotImplementedError
+        pass

+    @abstractmethod
     def write_vector(self, data: list[Feature]) -> None:
         """Save vector tiles to the store.

         Args:
             data: the vector data
         """
-        raise NotImplementedError
+        pass

+    @abstractmethod
     def get_metadata(self) -> LayerMetadata:
         """Get the LayerMetadata associated with this layer."""
-        raise NotImplementedError
+        pass

+    @abstractmethod
     def set_property(self, key: str, value: Any) -> None:
         """Set a property in the metadata for this layer.

@@ -108,7 +115,12 @@
             key: the property key
             value: the property value
         """
-        raise NotImplementedError
+        pass
+
+    @abstractmethod
+    def get_raster_bounds(self) -> PixelBounds:
+        """Get the bounds of the raster data in the store."""
+        pass


 class TileStore:
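
Because TileStoreLayer is now a real ABC, an incomplete subclass fails at instantiation rather than when the missing method is first called; this also enforces the newly declared get_raster_bounds. A reduced stand-in sketch (two methods instead of the full interface) showing the effect:

from abc import ABC, abstractmethod

class Layer(ABC):
    """Stand-in for TileStoreLayer: abstract methods replace the old
    raise-NotImplementedError bodies."""

    @abstractmethod
    def read_raster(self, bounds): ...

    @abstractmethod
    def get_raster_bounds(self): ...

class PartialLayer(Layer):
    def read_raster(self, bounds):
        return None
    # get_raster_bounds is not implemented.

try:
    PartialLayer()
except TypeError as err:
    print(err)  # Can't instantiate abstract class PartialLayer ...
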
rslearn/train/callbacks/freeze_unfreeze.py CHANGED
@@ -14,7 +14,7 @@ class FreezeUnfreeze(BaseFinetuning):
         module_selector: list[str | int],
         unfreeze_at_epoch: int | None = None,
         unfreeze_lr_factor: float = 1,
-    ):
+    ) -> None:
         """Creates a new FreezeUnfreeze.

         Args:
@@ -40,7 +40,7 @@ class FreezeUnfreeze(BaseFinetuning):
             target_module = getattr(target_module, k)
         return target_module

-    def freeze_before_training(self, pl_module: LightningModule):
+    def freeze_before_training(self, pl_module: LightningModule) -> None:
         """Freeze the model at the beginning of training.

         Args:
@@ -51,7 +51,7 @@ class FreezeUnfreeze(BaseFinetuning):

     def finetune_function(
         self, pl_module: LightningModule, current_epoch: int, optimizer: Optimizer
-    ):
+    ) -> None:
         """Check whether we should unfreeze the model on each epoch.

         Args:
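
For context, a hedged sketch of wiring the callback into a Lightning Trainer (the selector path, epoch, and factor values are hypothetical, and the import path assumes Lightning 2.x):

from lightning.pytorch import Trainer

from rslearn.train.callbacks.freeze_unfreeze import FreezeUnfreeze

# Hypothetical: freeze the attribute chain model.encoder, unfreeze at epoch 5.
callback = FreezeUnfreeze(
    module_selector=["model", "encoder"],
    unfreeze_at_epoch=5,
    unfreeze_lr_factor=10,  # assumed: scales the LR for newly unfrozen params
)
trainer = Trainer(max_epochs=20, callbacks=[callback])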