rslearn 0.0.18__py3-none-any.whl → 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -229,7 +229,13 @@ class SimpleTimeSeries(FeatureExtractor):

         # Now we can apply the underlying FeatureExtractor.
         # Its output must be a FeatureMaps.
-        encoder_output = self.encoder(batched_inputs)
+        assert batched_inputs is not None
+        encoder_output = self.encoder(
+            ModelContext(
+                inputs=batched_inputs,
+                metadatas=context.metadatas,
+            )
+        )
         if not isinstance(encoder_output, FeatureMaps):
             raise ValueError(
                 "output of underlying FeatureExtractor in SimpleTimeSeries must be a FeatureMaps"
@@ -9,7 +9,7 @@ import shapely
 import torch

 from rslearn.dataset import Window
-from rslearn.train.dataset import ModelDataset
+from rslearn.train.dataset import DataInput, ModelDataset
 from rslearn.train.model_context import SampleMetadata
 from rslearn.utils.geometry import PixelBounds, STGeometry

@@ -34,22 +34,28 @@ def get_window_patch_options(
     bottommost patches may extend beyond the provided bounds.
     """
     # We stride the patches by patch_size - overlap_size until the last patch.
+    # We handle the first patch with a special case to ensure it is always used.
     # We handle the last patch with a special case to ensure it does not exceed the
     # window bounds. Instead, it may overlap the previous patch.
-    cols = list(
+    cols = [bounds[0]] + list(
         range(
-            bounds[0],
+            bounds[0] + patch_size[0],
             bounds[2] - patch_size[0],
             patch_size[0] - overlap_size[0],
         )
-    ) + [bounds[2] - patch_size[0]]
-    rows = list(
+    )
+    rows = [bounds[1]] + list(
         range(
-            bounds[1],
+            bounds[1] + patch_size[1],
             bounds[3] - patch_size[1],
             patch_size[1] - overlap_size[1],
         )
-    ) + [bounds[3] - patch_size[1]]
+    )
+    # Add last patches only if the input is larger than one patch.
+    if bounds[2] - patch_size[0] > bounds[0]:
+        cols.append(bounds[2] - patch_size[0])
+    if bounds[3] - patch_size[1] > bounds[1]:
+        rows.append(bounds[3] - patch_size[1])

     patch_bounds: list[PixelBounds] = []
     for col in cols:
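A worked sketch of the revised patch enumeration above, with illustrative numbers (the variable names mirror the diff; rows is symmetric):

    # 100-wide bounds, 64-pixel patches, 16-pixel overlap.
    bounds = (0, 0, 100, 100)    # (x0, y0, x1, y1)
    patch_size = (64, 64)
    overlap_size = (16, 16)

    # First patch always included; the stride begins one patch in.
    cols = [bounds[0]] + list(
        range(bounds[0] + 64, bounds[2] - 64, 64 - 16)
    )                                          # range(64, 36, 48) is empty -> [0]
    # Last patch appended only when the window exceeds one patch.
    if bounds[2] - patch_size[0] > bounds[0]:  # 36 > 0
        cols.append(bounds[2] - patch_size[0])
    assert cols == [0, 36]
    # A 50-wide window (smaller than one patch) would yield cols == [0] only.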
@@ -62,13 +68,17 @@ def pad_slice_protect(
     raw_inputs: dict[str, Any],
     passthrough_inputs: dict[str, Any],
     patch_size: tuple[int, int],
+    inputs: dict[str, DataInput],
 ) -> tuple[dict[str, Any], dict[str, Any]]:
     """Pad tensors in-place by patch size to protect slicing near right/bottom edges.

+    The padding is scaled based on each input's resolution_factor.
+
     Args:
         raw_inputs: the raw inputs to pad.
         passthrough_inputs: the passthrough inputs to pad.
-        patch_size: the size of the patches to extract.
+        patch_size: the size of the patches to extract (at window resolution).
+        inputs: the DataInput definitions, used to get resolution_factor per input.

     Returns:
         a tuple of (raw_inputs, passthrough_inputs).
@@ -77,8 +87,14 @@ def pad_slice_protect(
        for input_name, value in list(d.items()):
            if not isinstance(value, torch.Tensor):
                continue
+           # Get resolution scale for this input
+           rf = inputs[input_name].resolution_factor
+           scale = rf.numerator / rf.denominator
+           # Scale the padding amount
+           scaled_pad_x = int(patch_size[0] * scale)
+           scaled_pad_y = int(patch_size[1] * scale)
            d[input_name] = torch.nn.functional.pad(
-               value, pad=(0, patch_size[0], 0, patch_size[1])
+               value, pad=(0, scaled_pad_x, 0, scaled_pad_y)
            )
    return raw_inputs, passthrough_inputs

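For example (a sketch with illustrative numbers; the ResolutionFactor class appears later in this diff), an input stored at half the window resolution receives half the padding:

    import torch

    from rslearn.utils.geometry import ResolutionFactor

    patch_size = (64, 64)
    rf = ResolutionFactor(denominator=2)   # factor 1/2
    scale = rf.numerator / rf.denominator  # 0.5
    value = torch.zeros(3, 32, 32)         # CHW tensor loaded at half resolution
    padded = torch.nn.functional.pad(
        value, pad=(0, int(64 * scale), 0, int(64 * scale))  # right/bottom only
    )
    assert padded.shape == (3, 64, 64)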
@@ -123,6 +139,7 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
         self.rank = rank
         self.world_size = world_size
         self.windows = self.dataset.get_dataset_examples()
+        self.inputs = dataset.inputs

     def set_name(self, name: str) -> None:
         """Sets dataset name.
@@ -235,8 +252,10 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):

            # For simplicity, pad tensors by patch size to ensure that any patch bounds
            # extending outside the window bounds will not have issues when we slice
-           # the tensors later.
-           pad_slice_protect(raw_inputs, passthrough_inputs, self.patch_size)
+           # the tensors later. Padding is scaled per-input based on resolution_factor.
+           pad_slice_protect(
+               raw_inputs, passthrough_inputs, self.patch_size, self.inputs
+           )

            # Now iterate over the patches and extract/yield the crops.
            # Note that, in case user is leveraging RslearnWriter, it is important that
@@ -258,15 +277,28 @@ class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
            )

            # Define a helper function to handle each input dict.
+           # Crop coordinates are scaled based on each input's resolution_factor.
            def crop_input_dict(d: dict[str, Any]) -> dict[str, Any]:
                cropped = {}
                for input_name, value in d.items():
                    if isinstance(value, torch.Tensor):
-                       # Crop the CHW tensor.
+                       # Get resolution scale for this input
+                       rf = self.inputs[input_name].resolution_factor
+                       scale = rf.numerator / rf.denominator
+                       # Scale the crop coordinates
+                       scaled_start = (
+                           int(start_offset[0] * scale),
+                           int(start_offset[1] * scale),
+                       )
+                       scaled_end = (
+                           int(end_offset[0] * scale),
+                           int(end_offset[1] * scale),
+                       )
+                       # Crop the CHW tensor with scaled coordinates.
                        cropped[input_name] = value[
                            :,
-                           start_offset[1] : end_offset[1],
-                           start_offset[0] : end_offset[0],
+                           scaled_start[1] : scaled_end[1],
+                           scaled_start[0] : scaled_end[0],
                        ].clone()
                    elif isinstance(value, list):
                        cropped[input_name] = [
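The crop scaling follows the same pattern in both dataset classes; a sketch with illustrative numbers (a 128x128 patch against an input loaded at resolution_factor 1/2):

    import torch

    value = torch.zeros(3, 64, 64)                 # CHW, loaded at factor 1/2
    start_offset, end_offset = (0, 0), (128, 128)  # patch at window resolution
    scale = 0.5
    scaled_start = (int(start_offset[0] * scale), int(start_offset[1] * scale))
    scaled_end = (int(end_offset[0] * scale), int(end_offset[1] * scale))
    crop = value[:, scaled_start[1]:scaled_end[1], scaled_start[0]:scaled_end[0]]
    assert crop.shape == (3, 64, 64)               # the whole tensor for this patch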
@@ -348,6 +380,7 @@ class InMemoryAllPatchesDataset(torch.utils.data.Dataset):
             round(self.patch_size[1] * overlap_ratio),
         )
         self.windows = self.dataset.get_dataset_examples()
+        self.inputs = dataset.inputs
         self.window_cache: dict[
             int, tuple[dict[str, Any], dict[str, Any], SampleMetadata]
         ] = {}
@@ -378,26 +411,41 @@
            return self.window_cache[index]

        raw_inputs, passthrough_inputs, metadata = self.dataset.get_raw_inputs(index)
-       pad_slice_protect(raw_inputs, passthrough_inputs, self.patch_size)
+       pad_slice_protect(raw_inputs, passthrough_inputs, self.patch_size, self.inputs)

        self.window_cache[index] = (raw_inputs, passthrough_inputs, metadata)
        return self.window_cache[index]

-   @staticmethod
    def _crop_input_dict(
+       self,
        d: dict[str, Any],
        start_offset: tuple[int, int],
        end_offset: tuple[int, int],
        cur_geom: STGeometry,
    ) -> dict[str, Any]:
-       """Crop a dictionary of inputs to the given bounds."""
+       """Crop a dictionary of inputs to the given bounds.
+
+       Crop coordinates are scaled based on each input's resolution_factor.
+       """
        cropped = {}
        for input_name, value in d.items():
            if isinstance(value, torch.Tensor):
+               # Get resolution scale for this input
+               rf = self.inputs[input_name].resolution_factor
+               scale = rf.numerator / rf.denominator
+               # Scale the crop coordinates
+               scaled_start = (
+                   int(start_offset[0] * scale),
+                   int(start_offset[1] * scale),
+               )
+               scaled_end = (
+                   int(end_offset[0] * scale),
+                   int(end_offset[1] * scale),
+               )
                cropped[input_name] = value[
                    :,
-                   start_offset[1] : end_offset[1],
-                   start_offset[0] : end_offset[0],
+                   scaled_start[1] : scaled_end[1],
+                   scaled_start[0] : scaled_end[0],
                ].clone()
            elif isinstance(value, list):
                cropped[input_name] = [
rslearn/train/dataset.py CHANGED
@@ -24,7 +24,7 @@ from rslearn.dataset.storage.file import FileWindowStorage
 from rslearn.dataset.window import Window, get_layer_and_group_from_dir_name
 from rslearn.log_utils import get_logger
 from rslearn.utils.feature import Feature
-from rslearn.utils.geometry import PixelBounds
+from rslearn.utils.geometry import PixelBounds, ResolutionFactor
 from rslearn.utils.mp import star_imap_unordered

 from .model_context import SampleMetadata
@@ -130,6 +130,10 @@ class DataInput:
     """Specification of a piece of data from a window that is needed for training.

     The DataInput includes which layer(s) the data can be obtained from for each window.
+
+    Note that this class is not a dataclass because jsonargparse does not play well
+    with dataclasses without enabling specialized options which we have not validated
+    will work with the rest of our code.
     """

     def __init__(
@@ -143,7 +147,9 @@
         dtype: DType = DType.FLOAT32,
         load_all_layers: bool = False,
         load_all_item_groups: bool = False,
-    ) -> None:
+        resolution_factor: ResolutionFactor = ResolutionFactor(),
+        resampling: Resampling = Resampling.nearest,
+    ):
         """Initialize a new DataInput.

         Args:
@@ -166,6 +172,11 @@
             are reading from. By default, we assume the specified layer name is of
             the form "{layer_name}.{group_idx}" and read that item group only. With
             this option enabled, we ignore the group_idx and read all item groups.
+            resolution_factor: controls the resolution at which raster data is loaded for training.
+                By default (factor=1), data is loaded at the window resolution.
+                E.g. for a 64x64 window at 10 m/pixel with resolution_factor=1/2,
+                the resulting tensor is 32x32 (covering the same geographic area at 20 m/pixel).
+            resampling: resampling method (default nearest neighbor).
         """
         self.data_type = data_type
         self.layers = layers
@@ -176,6 +187,8 @@
         self.dtype = dtype
         self.load_all_layers = load_all_layers
         self.load_all_item_groups = load_all_item_groups
+        self.resolution_factor = resolution_factor
+        self.resampling = resampling

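As a construction sketch (the layer and band names are hypothetical, and the bands parameter is assumed from the surrounding reading code; resolution_factor and resampling are the new arguments):

    from rasterio.enums import Resampling

    from rslearn.train.dataset import DataInput
    from rslearn.utils.geometry import ResolutionFactor

    # Load this raster at half the window resolution, with bilinear resampling.
    image_input = DataInput(
        data_type="raster",
        layers=["sentinel2"],          # hypothetical layer name
        bands=["B04", "B03", "B02"],   # hypothetical band list
        resolution_factor=ResolutionFactor(denominator=2),
        resampling=Resampling.bilinear,
    )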
  def read_raster_layer_for_data_input(
@@ -233,15 +246,23 @@ def read_raster_layer_for_data_input(
233
246
  + f"window {window.name} layer {layer_name} group {group_idx}"
234
247
  )
235
248
 
249
+ # Get the projection and bounds to read under (multiply window resolution # by
250
+ # the specified resolution factor).
251
+ final_projection = data_input.resolution_factor.multiply_projection(
252
+ window.projection
253
+ )
254
+ final_bounds = data_input.resolution_factor.multiply_bounds(bounds)
255
+
236
256
  image = torch.zeros(
237
- (len(needed_bands), bounds[3] - bounds[1], bounds[2] - bounds[0]),
257
+ (
258
+ len(needed_bands),
259
+ final_bounds[3] - final_bounds[1],
260
+ final_bounds[2] - final_bounds[0],
261
+ ),
238
262
  dtype=get_torch_dtype(data_input.dtype),
239
263
  )
240
264
 
241
265
  for band_set, src_indexes, dst_indexes in needed_sets_and_indexes:
242
- final_projection, final_bounds = band_set.get_final_projection_and_bounds(
243
- window.projection, bounds
244
- )
245
266
  if band_set.format is None:
246
267
  raise ValueError(f"No format specified for {layer_name}")
247
268
  raster_format = band_set.instantiate_raster_format()
@@ -249,44 +270,16 @@
            layer_name, band_set.bands, group_idx=group_idx
        )

-       # Previously we always read in the native projection of the data, and then
-       # zoom in or out (the resolution must be a power of two off) to match the
-       # window's resolution.
-       # However, this fails if the bounds are not multiples of the resolution factor.
-       # So we fallback to reading directly in the window projection if that is the
-       # case (which may be a bit slower).
-       is_bounds_zoomable = True
-       if band_set.zoom_offset < 0:
-           zoom_factor = 2 ** (-band_set.zoom_offset)
-           is_bounds_zoomable = (final_bounds[2] - final_bounds[0]) * zoom_factor == (
-               bounds[2] - bounds[0]
-           ) and (final_bounds[3] - final_bounds[1]) * zoom_factor == (
-               bounds[3] - bounds[1]
-           )
-
-       if is_bounds_zoomable:
-           src = raster_format.decode_raster(
-               raster_dir, final_projection, final_bounds
-           )
-
-           # Resize to patch size if needed.
-           # This is for band sets that are stored at a lower resolution.
-           # Here we assume that it is a multiple.
-           if src.shape[1:3] != image.shape[1:3]:
-               if src.shape[1] < image.shape[1]:
-                   factor = image.shape[1] // src.shape[1]
-                   src = src.repeat(repeats=factor, axis=1).repeat(
-                       repeats=factor, axis=2
-                   )
-               else:
-                   factor = src.shape[1] // image.shape[1]
-                   src = src[:, ::factor, ::factor]
-
-       else:
-           src = raster_format.decode_raster(
-               raster_dir, window.projection, bounds, resampling=Resampling.nearest
-           )
+       # TODO: previously we try to read based on band_set.zoom_offset when possible,
+       # and handle zooming in with torch.repeat (if resampling method is nearest
+       # neighbor). However, we have not benchmarked whether this actually improves
+       # data loading speed, so for simplicity, for now we let rasterio handle the
+       # resampling. If it really is much faster to handle it via torch, then it may
+       # make sense to bring back that functionality.

+       src = raster_format.decode_raster(
+           raster_dir, final_projection, final_bounds, resampling=Resampling.nearest
+       )
        image[dst_indexes, :, :] = torch.as_tensor(
            src[src_indexes, :, :].astype(data_input.dtype.get_numpy_dtype())
        )
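Worked numbers for the allocation above (a sketch; a 64x64 window read with resolution_factor 1/2, and a hypothetical band list):

    import torch

    needed_bands = ["B04", "B03", "B02"]   # hypothetical
    final_bounds = (0, 0, 32, 32)          # multiply_bounds((0, 0, 64, 64)) at 1/2
    image = torch.zeros(
        (
            len(needed_bands),
            final_bounds[3] - final_bounds[1],
            final_bounds[2] - final_bounds[0],
        )
    )
    assert image.shape == (3, 32, 32)      # rasterio resamples into this grid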
@@ -8,6 +8,7 @@ from torch.optim.lr_scheduler import (
     CosineAnnealingLR,
     CosineAnnealingWarmRestarts,
     LRScheduler,
+    MultiStepLR,
     ReduceLROnPlateau,
 )

@@ -50,6 +51,20 @@ class PlateauScheduler(SchedulerFactory):
         return ReduceLROnPlateau(optimizer, **self.get_kwargs())


+@dataclass
+class MultiStepScheduler(SchedulerFactory):
+    """Multi-step learning rate scheduler."""
+
+    milestones: list[int]
+    gamma: float | None = None
+    last_epoch: int | None = None
+
+    def build(self, optimizer: Optimizer) -> LRScheduler:
+        """Build the MultiStepLR scheduler."""
+        super().build(optimizer)
+        return MultiStepLR(optimizer, **self.get_kwargs())
+
+
 @dataclass
 class CosineAnnealingScheduler(SchedulerFactory):
     """Cosine annealing learning rate scheduler."""
@@ -0,0 +1,74 @@
+"""Resize transform."""
+
+from typing import Any
+
+import torch
+import torchvision
+from torchvision.transforms import InterpolationMode
+
+from .transform import Transform
+
+INTERPOLATION_MODES = {
+    "nearest": InterpolationMode.NEAREST,
+    "nearest_exact": InterpolationMode.NEAREST_EXACT,
+    "bilinear": InterpolationMode.BILINEAR,
+    "bicubic": InterpolationMode.BICUBIC,
+}
+
+
+class Resize(Transform):
+    """Resizes inputs to a target size."""
+
+    def __init__(
+        self,
+        target_size: tuple[int, int],
+        selectors: list[str] = [],
+        interpolation: str = "nearest",
+    ):
+        """Initialize a resize transform.
+
+        Args:
+            target_size: the (height, width) to resize to.
+            selectors: items to transform.
+            interpolation: the interpolation mode to use for resizing.
+                Must be one of "nearest", "nearest_exact", "bilinear", or "bicubic".
+        """
+        super().__init__()
+        self.target_size = target_size
+        self.selectors = selectors
+        self.interpolation = INTERPOLATION_MODES[interpolation]
+
+    def apply_resize(self, image: torch.Tensor) -> torch.Tensor:
+        """Apply resizing on the specified image.
+
+        If the image is 2D, it is unsqueezed to 3D and then squeezed
+        back after resizing.
+
+        Args:
+            image: the image to transform.
+        """
+        if image.dim() == 2:
+            image = image.unsqueeze(0)  # (H, W) -> (1, H, W)
+            result = torchvision.transforms.functional.resize(
+                image, self.target_size, self.interpolation
+            )
+            return result.squeeze(0)  # (1, H, W) -> (H, W)
+
+        return torchvision.transforms.functional.resize(
+            image, self.target_size, self.interpolation
+        )
+
+    def forward(
+        self, input_dict: dict[str, Any], target_dict: dict[str, Any]
+    ) -> tuple[dict[str, Any], dict[str, Any]]:
+        """Apply transform over the inputs and targets.
+
+        Args:
+            input_dict: the input
+            target_dict: the target
+
+        Returns:
+            transformed (input_dicts, target_dicts) tuple
+        """
+        self.apply_fn(self.apply_resize, input_dict, target_dict, self.selectors)
+        return input_dict, target_dict
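A usage sketch for the new transform (the module path and the "image" selector are assumptions, and Transform is assumed to be callable like an nn.Module):

    import torch

    from rslearn.train.transforms.resize import Resize  # assumed module path

    resize = Resize(target_size=(256, 256), selectors=["image"], interpolation="bilinear")
    input_dict = {"image": torch.zeros(3, 128, 128)}
    input_dict, target_dict = resize(input_dict, {})
    assert input_dict["image"].shape == (3, 256, 256)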
rslearn/utils/geometry.py CHANGED
@@ -116,6 +116,79 @@ class Projection:
 WGS84_PROJECTION = Projection(CRS.from_epsg(WGS84_EPSG), 1, 1)


+class ResolutionFactor:
+    """Multiplier for the resolution in a Projection.
+
+    The multiplier is either an integer x, or the inverse of an integer (1/x).
+
+    Factors greater than 1 decrease the projection units per pixel, making the
+    resolution finer (more pixels per projection unit). Factors less than 1 make
+    it coarser (fewer pixels).
+    """
+
+    def __init__(self, numerator: int = 1, denominator: int = 1):
+        """Create a new ResolutionFactor.
+
+        Args:
+            numerator: the numerator of the fraction.
+            denominator: the denominator of the fraction. If set, numerator must be 1.
+        """
+        if numerator != 1 and denominator != 1:
+            raise ValueError("one of numerator or denominator must be 1")
+        if not isinstance(numerator, int) or not isinstance(denominator, int):
+            raise ValueError("numerator and denominator must be integers")
+        if numerator < 1 or denominator < 1:
+            raise ValueError("numerator and denominator must be >= 1")
+        self.numerator = numerator
+        self.denominator = denominator
+
+    def multiply_projection(self, projection: Projection) -> Projection:
+        """Multiply the projection by this factor."""
+        if self.denominator > 1:
+            return Projection(
+                projection.crs,
+                projection.x_resolution * self.denominator,
+                projection.y_resolution * self.denominator,
+            )
+        else:
+            return Projection(
+                projection.crs,
+                projection.x_resolution // self.numerator,
+                projection.y_resolution // self.numerator,
+            )
+
+    def multiply_bounds(self, bounds: PixelBounds) -> PixelBounds:
+        """Multiply the bounds by this factor.
+
+        When coarsening, the width and height of the given bounds must be a multiple of
+        the denominator.
+        """
+        if self.denominator > 1:
+            # Verify the width and height are multiples of the denominator.
+            # Otherwise the new width and height is not an integer.
+            width = bounds[2] - bounds[0]
+            height = bounds[3] - bounds[1]
+            if width % self.denominator != 0 or height % self.denominator != 0:
+                raise ValueError(
+                    f"width {width} or height {height} is not a multiple of the resolution factor {self.denominator}"
+                )
+            # TODO: an offset could be introduced by bounds not being a multiple
+            # of the denominator -> will need to decide how to handle that.
+            return (
+                bounds[0] // self.denominator,
+                bounds[1] // self.denominator,
+                bounds[2] // self.denominator,
+                bounds[3] // self.denominator,
+            )
+        else:
+            return (
+                bounds[0] * self.numerator,
+                bounds[1] * self.numerator,
+                bounds[2] * self.numerator,
+                bounds[3] * self.numerator,
+            )
+
+
 class STGeometry:
     """A spatiotemporal geometry.

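A quick sketch of the new class's arithmetic (the CRS is illustrative):

    from rasterio.crs import CRS

    from rslearn.utils.geometry import Projection, ResolutionFactor

    half = ResolutionFactor(denominator=2)          # factor 1/2: coarser
    proj = Projection(CRS.from_epsg(32610), 10, -10)
    coarse = half.multiply_projection(proj)         # 10 m/pixel -> 20 m/pixel
    assert half.multiply_bounds((0, 0, 64, 64)) == (0, 0, 32, 32)
    # half.multiply_bounds((0, 0, 65, 65)) would raise ValueError (65 % 2 != 0).

    double = ResolutionFactor(numerator=2)          # factor 2: finer
    assert double.multiply_bounds((0, 0, 64, 64)) == (0, 0, 128, 128)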
@@ -8,6 +8,7 @@ from rasterio.crs import CRS
 from upath import UPath

 from rslearn.config.dataset import LayerConfig
+from rslearn.utils.geometry import ResolutionFactor

 if TYPE_CHECKING:
     from rslearn.data_sources.data_source import DataSourceContext
@@ -91,6 +92,68 @@ def data_source_context_deserializer(v: dict[str, Any]) -> "DataSourceContext":
     )


+def resolution_factor_serializer(v: ResolutionFactor) -> str:
+    """Serialize ResolutionFactor for jsonargparse.
+
+    Args:
+        v: the ResolutionFactor object.
+
+    Returns:
+        the ResolutionFactor encoded to string
+    """
+    if hasattr(v, "init_args"):
+        init_args = v.init_args
+        return f"{init_args.numerator}/{init_args.denominator}"
+
+    return f"{v.numerator}/{v.denominator}"
+
+
+def resolution_factor_deserializer(v: int | str | dict) -> ResolutionFactor:
+    """Deserialize ResolutionFactor for jsonargparse.
+
+    Args:
+        v: the encoded ResolutionFactor.
+
+    Returns:
+        the decoded ResolutionFactor object
+    """
+    # Handle already-instantiated ResolutionFactor
+    if isinstance(v, ResolutionFactor):
+        return v
+
+    # Handle Namespace from class_path syntax (used during config save/validation)
+    if hasattr(v, "init_args"):
+        init_args = v.init_args
+        return ResolutionFactor(
+            numerator=init_args.numerator,
+            denominator=init_args.denominator,
+        )
+
+    # Handle dict from class_path syntax in YAML config
+    if isinstance(v, dict) and "init_args" in v:
+        init_args = v["init_args"]
+        return ResolutionFactor(
+            numerator=init_args.get("numerator", 1),
+            denominator=init_args.get("denominator", 1),
+        )
+
+    if isinstance(v, int):
+        return ResolutionFactor(numerator=v)
+    elif isinstance(v, str):
+        parts = v.split("/")
+        if len(parts) == 1:
+            return ResolutionFactor(numerator=int(parts[0]))
+        elif len(parts) == 2:
+            return ResolutionFactor(
+                numerator=int(parts[0]),
+                denominator=int(parts[1]),
+            )
+        else:
+            raise ValueError("expected resolution factor to be of the form x or 1/x")
+    else:
+        raise ValueError("expected resolution factor to be str or int")
+
+
 def init_jsonargparse() -> None:
     """Initialize custom jsonargparse serializers."""
     global INITIALIZED
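For reference, the encodings the deserializer accepts (a sketch assuming the functions above are in scope):

    resolution_factor_deserializer(2)        # ResolutionFactor(numerator=2)
    resolution_factor_deserializer("2")      # ResolutionFactor(numerator=2)
    resolution_factor_deserializer("1/2")    # ResolutionFactor(numerator=1, denominator=2)
    # "1/2/3" raises ValueError: expected resolution factor to be of the form x or 1/x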
@@ -100,6 +163,9 @@
     jsonargparse.typing.register_type(
         datetime, datetime_serializer, datetime_deserializer
     )
+    jsonargparse.typing.register_type(
+        ResolutionFactor, resolution_factor_serializer, resolution_factor_deserializer
+    )

     from rslearn.data_sources.data_source import DataSourceContext

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: rslearn
-Version: 0.0.18
+Version: 0.0.19
 Summary: A library for developing remote sensing datasets and models
 Author: OlmoEarth Team
 License: Apache License