rslearn 0.0.18__py3-none-any.whl → 0.0.20__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
  Files changed (55)
  1. rslearn/arg_parser.py +2 -9
  2. rslearn/config/dataset.py +15 -16
  3. rslearn/dataset/dataset.py +28 -22
  4. rslearn/lightning_cli.py +22 -11
  5. rslearn/main.py +1 -1
  6. rslearn/models/anysat.py +35 -33
  7. rslearn/models/attention_pooling.py +177 -0
  8. rslearn/models/clip.py +5 -2
  9. rslearn/models/component.py +12 -0
  10. rslearn/models/croma.py +11 -3
  11. rslearn/models/dinov3.py +2 -1
  12. rslearn/models/faster_rcnn.py +2 -1
  13. rslearn/models/galileo/galileo.py +58 -31
  14. rslearn/models/module_wrapper.py +6 -1
  15. rslearn/models/molmo.py +4 -2
  16. rslearn/models/olmoearth_pretrain/model.py +206 -51
  17. rslearn/models/olmoearth_pretrain/norm.py +5 -3
  18. rslearn/models/panopticon.py +3 -1
  19. rslearn/models/presto/presto.py +45 -15
  20. rslearn/models/prithvi.py +9 -7
  21. rslearn/models/sam2_enc.py +3 -1
  22. rslearn/models/satlaspretrain.py +4 -1
  23. rslearn/models/simple_time_series.py +43 -17
  24. rslearn/models/ssl4eo_s12.py +19 -14
  25. rslearn/models/swin.py +3 -1
  26. rslearn/models/terramind.py +5 -4
  27. rslearn/train/all_patches_dataset.py +96 -28
  28. rslearn/train/dataset.py +102 -53
  29. rslearn/train/model_context.py +35 -1
  30. rslearn/train/scheduler.py +15 -0
  31. rslearn/train/tasks/classification.py +8 -2
  32. rslearn/train/tasks/detection.py +3 -2
  33. rslearn/train/tasks/multi_task.py +2 -3
  34. rslearn/train/tasks/per_pixel_regression.py +14 -5
  35. rslearn/train/tasks/regression.py +8 -2
  36. rslearn/train/tasks/segmentation.py +13 -4
  37. rslearn/train/tasks/task.py +2 -2
  38. rslearn/train/transforms/concatenate.py +45 -5
  39. rslearn/train/transforms/crop.py +22 -8
  40. rslearn/train/transforms/flip.py +13 -5
  41. rslearn/train/transforms/mask.py +11 -2
  42. rslearn/train/transforms/normalize.py +46 -15
  43. rslearn/train/transforms/pad.py +15 -3
  44. rslearn/train/transforms/resize.py +83 -0
  45. rslearn/train/transforms/select_bands.py +11 -2
  46. rslearn/train/transforms/sentinel1.py +18 -3
  47. rslearn/utils/geometry.py +73 -0
  48. rslearn/utils/jsonargparse.py +66 -0
  49. {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/METADATA +1 -1
  50. {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/RECORD +55 -53
  51. {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/WHEEL +0 -0
  52. {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/entry_points.txt +0 -0
  53. {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/licenses/LICENSE +0 -0
  54. {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/licenses/NOTICE +0 -0
  55. {rslearn-0.0.18.dist-info → rslearn-0.0.20.dist-info}/top_level.txt +0 -0
@@ -5,6 +5,8 @@ from typing import Any
5
5
  import torch
6
6
  import torchvision
7
7
 
8
+ from rslearn.train.model_context import RasterImage
9
+
8
10
  from .transform import Transform, read_selector
9
11
 
10
12
 
@@ -69,7 +71,9 @@ class Crop(Transform):
69
71
  "remove_from_top": remove_from_top,
70
72
  }
71
73
 
72
- def apply_image(self, image: torch.Tensor, state: dict[str, Any]) -> torch.Tensor:
74
+ def apply_image(
75
+ self, image: RasterImage | torch.Tensor, state: dict[str, Any]
76
+ ) -> RasterImage | torch.Tensor:
73
77
  """Apply the sampled state on the specified image.
74
78
 
75
79
  Args:
@@ -80,13 +84,23 @@ class Crop(Transform):
80
84
  crop_size = state["crop_size"] * image.shape[-1] // image_shape[1]
81
85
  remove_from_left = state["remove_from_left"] * image.shape[-1] // image_shape[1]
82
86
  remove_from_top = state["remove_from_top"] * image.shape[-2] // image_shape[0]
83
- return torchvision.transforms.functional.crop(
84
- image,
85
- top=remove_from_top,
86
- left=remove_from_left,
87
- height=crop_size,
88
- width=crop_size,
89
- )
87
+ if isinstance(image, RasterImage):
88
+ image.image = torchvision.transforms.functional.crop(
89
+ image.image,
90
+ top=remove_from_top,
91
+ left=remove_from_left,
92
+ height=crop_size,
93
+ width=crop_size,
94
+ )
95
+ else:
96
+ image = torchvision.transforms.functional.crop(
97
+ image,
98
+ top=remove_from_top,
99
+ left=remove_from_left,
100
+ height=crop_size,
101
+ width=crop_size,
102
+ )
103
+ return image
90
104
 
91
105
  def apply_boxes(self, boxes: Any, state: dict[str, bool]) -> torch.Tensor:
92
106
  """Apply the sampled state on the specified image.
@@ -4,6 +4,8 @@ from typing import Any
4
4
 
5
5
  import torch
6
6
 
7
+ from rslearn.train.model_context import RasterImage
8
+
7
9
  from .transform import Transform
8
10
 
9
11
 
@@ -48,17 +50,23 @@ class Flip(Transform):
48
50
  "vertical": vertical,
49
51
  }
50
52
 
51
def apply_image(
    self, image: RasterImage | torch.Tensor, state: dict[str, bool]
) -> RasterImage | torch.Tensor:
    """Apply the sampled flip state on the specified image.

    Args:
        image: the image to transform. Either a plain tensor, or a RasterImage
            whose ``image`` tensor attribute is replaced by the flipped tensor.
        state: the sampled state, with boolean "horizontal" and "vertical" keys.

    Returns:
        the flipped tensor, or the same RasterImage object with its ``image``
        attribute updated. Inputs of any other type are returned unchanged.
    """
    # Horizontal flips reverse the last axis (columns), vertical flips the
    # second-to-last axis (rows). Flipping both axes in one call is equivalent to
    # flipping them one after another.
    dims = []
    if state["horizontal"]:
        dims.append(-1)
    if state["vertical"]:
        dims.append(-2)
    if not dims:
        return image

    if isinstance(image, torch.Tensor):
        return torch.flip(image, dims=dims)
    if isinstance(image, RasterImage):
        image.image = torch.flip(image.image, dims=dims)
    return image
63
71
 
64
72
  def apply_boxes(
@@ -2,6 +2,7 @@
2
2
 
3
3
  import torch
4
4
 
5
+ from rslearn.train.model_context import RasterImage
5
6
  from rslearn.train.transforms.transform import Transform, read_selector
6
7
 
7
8
 
@@ -31,7 +32,9 @@ class Mask(Transform):
31
32
  self.mask_selector = mask_selector
32
33
  self.mask_value = mask_value
33
34
 
34
- def apply_image(self, image: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
35
+ def apply_image(
36
+ self, image: torch.Tensor | RasterImage, mask: torch.Tensor | RasterImage
37
+ ) -> torch.Tensor | RasterImage:
35
38
  """Apply the mask on the image.
36
39
 
37
40
  Args:
@@ -42,6 +45,9 @@ class Mask(Transform):
42
45
  masked image
43
46
  """
44
47
  # Tile the mask to have same number of bands as the image.
48
+ if isinstance(mask, RasterImage):
49
+ mask = mask.image
50
+
45
51
  if image.shape[0] != mask.shape[0]:
46
52
  if mask.shape[0] != 1:
47
53
  raise ValueError(
@@ -49,7 +55,10 @@ class Mask(Transform):
49
55
  )
50
56
  mask = mask.repeat(image.shape[0], 1, 1)
51
57
 
52
- image[mask == 0] = self.mask_value
58
+ if isinstance(image, torch.Tensor):
59
+ image[mask == 0] = self.mask_value
60
+ else:
61
+ image.image[mask == 0] = self.mask_value
53
62
  return image
54
63
 
55
64
  def forward(self, input_dict: dict, target_dict: dict) -> tuple[dict, dict]:
@@ -4,6 +4,8 @@ from typing import Any
4
4
 
5
5
  import torch
6
6
 
7
+ from rslearn.train.model_context import RasterImage
8
+
7
9
  from .transform import Transform
8
10
 
9
11
 
@@ -55,7 +57,9 @@ class Normalize(Transform):
55
57
  self.bands = torch.tensor(bands) if bands is not None else None
56
58
  self.num_bands = num_bands
57
59
 
58
- def apply_image(self, image: torch.Tensor) -> torch.Tensor:
60
+ def apply_image(
61
+ self, image: torch.Tensor | RasterImage
62
+ ) -> torch.Tensor | RasterImage:
59
63
  """Normalize the specified image.
60
64
 
61
65
  Args:
@@ -63,7 +67,7 @@ class Normalize(Transform):
63
67
  """
64
68
 
65
69
  def _repeat_mean_and_std(
66
- image_channels: int, num_bands: int | None
70
+ image_channels: int, num_bands: int | None, is_raster_image: bool
67
71
  ) -> tuple[torch.Tensor, torch.Tensor]:
68
72
  """Get mean and std tensor that are suitable for applying on the image."""
69
73
  # We only need to repeat the tensor if both of these are true:
@@ -74,9 +78,16 @@ class Normalize(Transform):
74
78
  if num_bands is None:
75
79
  return self.mean, self.std
76
80
  num_images = image_channels // num_bands
77
- return self.mean.repeat(num_images)[:, None, None], self.std.repeat(
78
- num_images
79
- )[:, None, None]
81
+ if is_raster_image:
82
+ # add an extra T dimension, CTHW
83
+ return self.mean.repeat(num_images)[
84
+ :, None, None, None
85
+ ], self.std.repeat(num_images)[:, None, None, None]
86
+ else:
87
+ # add an extra T dimension, CTHW
88
+ return self.mean.repeat(num_images)[:, None, None], self.std.repeat(
89
+ num_images
90
+ )[:, None, None]
80
91
 
81
92
  if self.bands is not None:
82
93
  # User has provided band indices to normalize.
@@ -96,20 +107,40 @@ class Normalize(Transform):
96
107
  # We use len(self.bands) here because that is how many bands per timestep
97
108
  # we are actually processing with the mean/std.
98
109
  mean, std = _repeat_mean_and_std(
99
- image_channels=len(band_indices), num_bands=len(self.bands)
110
+ image_channels=len(band_indices),
111
+ num_bands=len(self.bands),
112
+ is_raster_image=isinstance(image, RasterImage),
100
113
  )
101
- image[band_indices] = (image[band_indices] - mean) / std
102
- if self.valid_min is not None:
103
- image[band_indices] = torch.clamp(
104
- image[band_indices], min=self.valid_min, max=self.valid_max
105
- )
114
+ if isinstance(image, torch.Tensor):
115
+ image[band_indices] = (image[band_indices] - mean) / std
116
+ if self.valid_min is not None:
117
+ image[band_indices] = torch.clamp(
118
+ image[band_indices], min=self.valid_min, max=self.valid_max
119
+ )
120
+ else:
121
+ image.image[band_indices] = (image.image[band_indices] - mean) / std
122
+ if self.valid_min is not None:
123
+ image.image[band_indices] = torch.clamp(
124
+ image.image[band_indices],
125
+ min=self.valid_min,
126
+ max=self.valid_max,
127
+ )
106
128
  else:
107
129
  mean, std = _repeat_mean_and_std(
108
- image_channels=image.shape[0], num_bands=self.num_bands
130
+ image_channels=image.shape[0],
131
+ num_bands=self.num_bands,
132
+ is_raster_image=isinstance(image, RasterImage),
109
133
  )
110
- image = (image - mean) / std
111
- if self.valid_min is not None:
112
- image = torch.clamp(image, min=self.valid_min, max=self.valid_max)
134
+ if isinstance(image, torch.Tensor):
135
+ image = (image - mean) / std
136
+ if self.valid_min is not None:
137
+ image = torch.clamp(image, min=self.valid_min, max=self.valid_max)
138
+ else:
139
+ image.image = (image.image - mean) / std
140
+ if self.valid_min is not None:
141
+ image.image = torch.clamp(
142
+ image.image, min=self.valid_min, max=self.valid_max
143
+ )
113
144
  return image
114
145
 
115
146
  def forward(
@@ -5,6 +5,8 @@ from typing import Any
5
5
  import torch
6
6
  import torchvision
7
7
 
8
+ from rslearn.train.model_context import RasterImage
9
+
8
10
  from .transform import Transform
9
11
 
10
12
 
@@ -48,7 +50,9 @@ class Pad(Transform):
48
50
  """
49
51
  return {"size": torch.randint(low=self.size[0], high=self.size[1], size=())}
50
52
 
51
- def apply_image(self, image: torch.Tensor, state: dict[str, bool]) -> torch.Tensor:
53
+ def apply_image(
54
+ self, image: RasterImage | torch.Tensor, state: dict[str, bool]
55
+ ) -> RasterImage | torch.Tensor:
52
56
  """Apply the sampled state on the specified image.
53
57
 
54
58
  Args:
@@ -101,8 +105,16 @@ class Pad(Transform):
101
105
  horizontal_pad = (horizontal_half, horizontal_extra - horizontal_half)
102
106
  vertical_pad = (vertical_half, vertical_extra - vertical_half)
103
107
 
104
- image = apply_padding(image, True, horizontal_pad[0], horizontal_pad[1])
105
- image = apply_padding(image, False, vertical_pad[0], vertical_pad[1])
108
+ if isinstance(image, RasterImage):
109
+ image.image = apply_padding(
110
+ image.image, True, horizontal_pad[0], horizontal_pad[1]
111
+ )
112
+ image.image = apply_padding(
113
+ image.image, False, vertical_pad[0], vertical_pad[1]
114
+ )
115
+ else:
116
+ image = apply_padding(image, True, horizontal_pad[0], horizontal_pad[1])
117
+ image = apply_padding(image, False, vertical_pad[0], vertical_pad[1])
106
118
  return image
107
119
 
108
120
  def apply_boxes(self, boxes: Any, state: dict[str, bool]) -> torch.Tensor:
@@ -0,0 +1,83 @@
1
+ """Resize transform."""
2
+
3
+ from typing import Any
4
+
5
+ import torch
6
+ import torchvision
7
+ from torchvision.transforms import InterpolationMode
8
+
9
+ from rslearn.train.model_context import RasterImage
10
+
11
+ from .transform import Transform
12
+
13
INTERPOLATION_MODES = {
    "nearest": InterpolationMode.NEAREST,
    "nearest_exact": InterpolationMode.NEAREST_EXACT,
    "bilinear": InterpolationMode.BILINEAR,
    "bicubic": InterpolationMode.BICUBIC,
}


class Resize(Transform):
    """Resizes inputs to a target size."""

    def __init__(
        self,
        target_size: tuple[int, int],
        selectors: list[str] | None = None,
        interpolation: str = "nearest",
    ):
        """Initialize a resize transform.

        Args:
            target_size: the (height, width) to resize to.
            selectors: items to transform. Defaults to an empty list.
            interpolation: the interpolation mode to use for resizing.
                Must be one of "nearest", "nearest_exact", "bilinear", or "bicubic".

        Raises:
            ValueError: if ``interpolation`` is not one of the supported modes.
        """
        super().__init__()
        if interpolation not in INTERPOLATION_MODES:
            raise ValueError(
                f"unknown interpolation {interpolation!r}, expected one of "
                f"{sorted(INTERPOLATION_MODES)}"
            )
        self.target_size = target_size
        # A fresh list per instance: a shared mutable default would be aliased by
        # every Resize constructed without explicit selectors.
        self.selectors = selectors if selectors is not None else []
        self.interpolation = INTERPOLATION_MODES[interpolation]

    def apply_resize(
        self, image: torch.Tensor | RasterImage
    ) -> torch.Tensor | RasterImage:
        """Apply resizing on the specified image.

        A 2D tensor is unsqueezed to 3D before resizing and squeezed back
        afterwards. For a RasterImage, its ``image`` tensor attribute is replaced
        by the resized tensor and the same RasterImage object is returned.

        Args:
            image: the image to transform.

        Returns:
            the resized tensor, or the updated RasterImage.
        """
        if isinstance(image, torch.Tensor):
            if image.dim() == 2:
                # (H, W) -> (1, H, W) for torchvision, then back to (H, W).
                resized = torchvision.transforms.functional.resize(
                    image.unsqueeze(0),
                    self.target_size,
                    interpolation=self.interpolation,
                )
                return resized.squeeze(0)
            return torchvision.transforms.functional.resize(
                image, self.target_size, interpolation=self.interpolation
            )

        image.image = torchvision.transforms.functional.resize(
            image.image, self.target_size, interpolation=self.interpolation
        )
        return image

    def forward(
        self, input_dict: dict[str, Any], target_dict: dict[str, Any]
    ) -> tuple[dict[str, Any], dict[str, Any]]:
        """Apply transform over the inputs and targets.

        Args:
            input_dict: the input
            target_dict: the target

        Returns:
            transformed (input_dicts, target_dicts) tuple
        """
        self.apply_fn(self.apply_resize, input_dict, target_dict, self.selectors)
        return input_dict, target_dict
@@ -2,6 +2,8 @@
2
2
 
3
3
  from typing import Any
4
4
 
5
+ from rslearn.train.model_context import RasterImage
6
+
5
7
  from .transform import Transform, read_selector, write_selector
6
8
 
7
9
 
@@ -49,6 +51,10 @@ class SelectBands(Transform):
49
51
  if self.num_bands_per_timestep is not None
50
52
  else image.shape[0]
51
53
  )
54
+ if isinstance(image, RasterImage):
55
+ assert num_bands_per_timestep == image.shape[0], (
56
+ "Expect a seperate dimension for timesteps in RasterImages."
57
+ )
52
58
 
53
59
  if image.shape[0] % num_bands_per_timestep != 0:
54
60
  raise ValueError(
@@ -62,6 +68,9 @@ class SelectBands(Transform):
62
68
  [(start_channel_idx + band_idx) for band_idx in self.band_indices]
63
69
  )
64
70
 
65
- result = image[wanted_bands]
66
- write_selector(input_dict, target_dict, self.output_selector, result)
71
+ if isinstance(image, RasterImage):
72
+ image.image = image.image[wanted_bands]
73
+ else:
74
+ image = image[wanted_bands]
75
+ write_selector(input_dict, target_dict, self.output_selector, image)
67
76
  return input_dict, target_dict
@@ -4,6 +4,8 @@ from typing import Any
4
4
 
5
5
  import torch
6
6
 
7
+ from rslearn.train.model_context import RasterImage
8
+
7
9
  from .transform import Transform
8
10
 
9
11
 
@@ -31,18 +33,31 @@ class Sentinel1ToDecibels(Transform):
31
33
  self.from_decibels = from_decibels
32
34
  self.epsilon = epsilon
33
35
 
34
def apply_image(
    self, image: torch.Tensor | RasterImage
) -> torch.Tensor | RasterImage:
    """Convert the specified image between linear scale and decibels.

    When ``self.from_decibels`` is set, values are mapped from decibels to linear
    scale as ``10 ** (x / 10)``. Otherwise they are mapped from linear scale to
    decibels as ``10 * log10(x)``. Inputs are first clamped to at least
    ``self.epsilon`` so the logarithm stays finite.

    Args:
        image: the image to transform. Either a tensor, or a RasterImage whose
            ``image`` tensor attribute is converted.

    Returns:
        the converted tensor, or the same RasterImage object with its ``image``
        attribute replaced by the converted tensor.
    """
    is_tensor = isinstance(image, torch.Tensor)
    values = image if is_tensor else image.image

    if self.from_decibels:
        # Decibels to linear scale.
        converted = torch.pow(10.0, values / 10.0)
    else:
        # Linear scale to decibels.
        converted = 10 * torch.log10(torch.clamp(values, min=self.epsilon))

    if is_tensor:
        return converted
    image.image = converted
    return image
61
 
47
62
  def forward(
48
63
  self, input_dict: dict[str, Any], target_dict: dict[str, Any]
rslearn/utils/geometry.py CHANGED
@@ -116,6 +116,79 @@ class Projection:
116
116
  WGS84_PROJECTION = Projection(CRS.from_epsg(WGS84_EPSG), 1, 1)
117
117
 
118
118
 
119
class ResolutionFactor:
    """Multiplier for the resolution in a Projection.

    The multiplier is either an integer x, or the inverse of an integer (1/x).

    A factor greater than 1 divides the projection units per pixel, giving a finer
    resolution (more pixels per projection unit). A factor less than 1 multiplies
    the projection units per pixel, giving a coarser resolution (fewer pixels).
    """

    def __init__(self, numerator: int = 1, denominator: int = 1):
        """Create a new ResolutionFactor.

        Args:
            numerator: the numerator of the fraction.
            denominator: the denominator of the fraction. If set, numerator must be 1.

        Raises:
            ValueError: if either value is not an integer or is less than 1, or if
                both differ from 1.
        """
        # Check types first so non-integer inputs get the type error rather than
        # a misleading message from the value comparisons below.
        if not isinstance(numerator, int) or not isinstance(denominator, int):
            raise ValueError("numerator and denominator must be integers")
        if numerator != 1 and denominator != 1:
            raise ValueError("one of numerator or denominator must be 1")
        if numerator < 1 or denominator < 1:
            raise ValueError("numerator and denominator must be >= 1")
        self.numerator = numerator
        self.denominator = denominator

    def __repr__(self) -> str:
        """Return a debug representation of this factor."""
        return (
            f"ResolutionFactor(numerator={self.numerator}, "
            f"denominator={self.denominator})"
        )

    @staticmethod
    def _divide_resolution(resolution: float, divisor: int) -> float:
        """Divide a resolution by a positive integer without truncating it.

        When the division is exact, the quotient keeps the type that floor
        division produces (an evenly divisible int resolution stays an int).
        Otherwise true division is used. Plain floor division would truncate
        sub-unit or non-exact results (0.5 // 2 == 0.0). Because it rounds
        toward negative infinity, it would also treat negative resolutions
        differently from positive ones (-10 // 3 == -4, while 10 // 3 == 3).
        """
        quotient, remainder = divmod(resolution, divisor)
        if remainder == 0:
            return quotient
        return resolution / divisor

    def multiply_projection(self, projection: Projection) -> Projection:
        """Multiply the projection by this factor.

        Coarsening multiplies both resolutions by the denominator; refining
        divides them by the numerator.
        """
        if self.denominator > 1:
            return Projection(
                projection.crs,
                projection.x_resolution * self.denominator,
                projection.y_resolution * self.denominator,
            )
        return Projection(
            projection.crs,
            self._divide_resolution(projection.x_resolution, self.numerator),
            self._divide_resolution(projection.y_resolution, self.numerator),
        )

    def multiply_bounds(self, bounds: PixelBounds) -> PixelBounds:
        """Multiply the bounds by this factor.

        When coarsening, the width and height of the given bounds must be a multiple of
        the denominator.

        Raises:
            ValueError: when coarsening and the width or height is not a multiple
                of the denominator.
        """
        if self.denominator > 1:
            # Verify the width and height are multiples of the denominator.
            # Otherwise the new width and height is not an integer.
            width = bounds[2] - bounds[0]
            height = bounds[3] - bounds[1]
            if width % self.denominator != 0 or height % self.denominator != 0:
                raise ValueError(
                    f"width {width} or height {height} is not a multiple of the resolution factor {self.denominator}"
                )
            # TODO: an offset could be introduced by bounds not being a multiple
            # of the denominator -> will need to decide how to handle that.
            return (
                bounds[0] // self.denominator,
                bounds[1] // self.denominator,
                bounds[2] // self.denominator,
                bounds[3] // self.denominator,
            )
        return (
            bounds[0] * self.numerator,
            bounds[1] * self.numerator,
            bounds[2] * self.numerator,
            bounds[3] * self.numerator,
        )
+
191
+
119
192
  class STGeometry:
120
193
  """A spatiotemporal geometry.
121
194
 
@@ -8,6 +8,7 @@ from rasterio.crs import CRS
8
8
  from upath import UPath
9
9
 
10
10
  from rslearn.config.dataset import LayerConfig
11
+ from rslearn.utils.geometry import ResolutionFactor
11
12
 
12
13
  if TYPE_CHECKING:
13
14
  from rslearn.data_sources.data_source import DataSourceContext
@@ -91,6 +92,68 @@ def data_source_context_deserializer(v: dict[str, Any]) -> "DataSourceContext":
91
92
  )
92
93
 
93
94
 
95
def resolution_factor_serializer(v: ResolutionFactor) -> str:
    """Serialize ResolutionFactor for jsonargparse.

    The output uses the "numerator/denominator" string form, e.g. "2/1" or "1/4".

    Args:
        v: the ResolutionFactor object, or a namespace-like object that carries the
            constructor arguments under ``init_args`` (class_path syntax).

    Returns:
        the ResolutionFactor encoded to string
    """
    if hasattr(v, "init_args"):
        init_args = v.init_args
        # Arguments left unspecified under init_args fall back to the
        # ResolutionFactor constructor default of 1.
        numerator = getattr(init_args, "numerator", 1)
        denominator = getattr(init_args, "denominator", 1)
        return f"{numerator}/{denominator}"

    return f"{v.numerator}/{v.denominator}"
109
+
110
+
111
def resolution_factor_deserializer(v: Any) -> ResolutionFactor:
    """Deserialize ResolutionFactor for jsonargparse.

    Accepted encodings:
    - an existing ResolutionFactor (returned as-is);
    - a namespace-like object with ``init_args`` (class_path syntax);
    - a dict with an "init_args" key (class_path syntax in YAML);
    - an int x, meaning x/1;
    - a string "x" or "x/y".

    In both class_path forms, an unspecified numerator or denominator defaults
    to 1, as in the ResolutionFactor constructor.

    Args:
        v: the encoded ResolutionFactor.

    Returns:
        the decoded ResolutionFactor object

    Raises:
        ValueError: if the value has an unsupported type or string format, or
            if the decoded numerator/denominator is invalid.
    """
    # Handle already-instantiated ResolutionFactor
    if isinstance(v, ResolutionFactor):
        return v

    # Handle Namespace from class_path syntax (used during config save/validation)
    if hasattr(v, "init_args"):
        init_args = v.init_args
        return ResolutionFactor(
            numerator=getattr(init_args, "numerator", 1),
            denominator=getattr(init_args, "denominator", 1),
        )

    # Handle dict from class_path syntax in YAML config. An empty "init_args:"
    # entry in YAML yields None, which is treated as no arguments.
    if isinstance(v, dict) and "init_args" in v:
        init_args = v["init_args"] or {}
        return ResolutionFactor(
            numerator=init_args.get("numerator", 1),
            denominator=init_args.get("denominator", 1),
        )

    if isinstance(v, int):
        return ResolutionFactor(numerator=v)
    if not isinstance(v, str):
        raise ValueError("expected resolution factor to be str or int")

    parts = v.split("/")
    if len(parts) == 1:
        return ResolutionFactor(numerator=int(parts[0]))
    if len(parts) == 2:
        return ResolutionFactor(
            numerator=int(parts[0]),
            denominator=int(parts[1]),
        )
    raise ValueError("expected resolution factor to be of the form x or 1/x")
+
156
+
94
157
  def init_jsonargparse() -> None:
95
158
  """Initialize custom jsonargparse serializers."""
96
159
  global INITIALIZED
@@ -100,6 +163,9 @@ def init_jsonargparse() -> None:
100
163
  jsonargparse.typing.register_type(
101
164
  datetime, datetime_serializer, datetime_deserializer
102
165
  )
166
+ jsonargparse.typing.register_type(
167
+ ResolutionFactor, resolution_factor_serializer, resolution_factor_deserializer
168
+ )
103
169
 
104
170
  from rslearn.data_sources.data_source import DataSourceContext
105
171
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: rslearn
3
- Version: 0.0.18
3
+ Version: 0.0.20
4
4
  Summary: A library for developing remote sensing datasets and models
5
5
  Author: OlmoEarth Team
6
6
  License: Apache License