rslearn 0.0.13__py3-none-any.whl → 0.0.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
rslearn/train/dataset.py CHANGED
@@ -1,7 +1,6 @@
1
1
  """Default Dataset for rslearn."""
2
2
 
3
3
  import hashlib
4
- import itertools
5
4
  import json
6
5
  import multiprocessing
7
6
  import os
@@ -9,10 +8,8 @@ import random
9
8
  import tempfile
10
9
  import time
11
10
  import uuid
12
- from collections.abc import Iterable, Iterator
13
11
  from typing import Any
14
12
 
15
- import shapely
16
13
  import torch
17
14
  import tqdm
18
15
  from rasterio.warp import Resampling
@@ -29,7 +26,7 @@ from rslearn.dataset.window import Window, get_layer_and_group_from_dir_name
29
26
  from rslearn.log_utils import get_logger
30
27
  from rslearn.train.tasks import Task
31
28
  from rslearn.utils.feature import Feature
32
- from rslearn.utils.geometry import PixelBounds, STGeometry
29
+ from rslearn.utils.geometry import PixelBounds
33
30
  from rslearn.utils.mp import star_imap_unordered
34
31
  from rslearn.utils.raster_format import load_raster_format
35
32
  from rslearn.utils.vector_format import load_vector_format
@@ -39,70 +36,14 @@ from .transforms import Sequential
39
36
  logger = get_logger(__name__)
40
37
 
41
38
 
42
- def get_window_patch_options(
43
- patch_size: tuple[int, int],
44
- overlap_size: tuple[int, int],
45
- bounds: PixelBounds,
46
- ) -> list[PixelBounds]:
47
- """Get the bounds of each patch within the overall bounds.
48
-
49
- Args:
50
- patch_size: the size of the patches to extract.
51
- overlap_size: the size of the overlap between patches.
52
- bounds: the window bounds to divide up into smaller patches.
53
-
54
- Returns:
55
- a list of patch bounds within the overall bounds. The rightmost and
56
- bottommost patches may extend beyond the provided bounds.
57
- """
58
- # We stride the patches by patch_size - overlap_size until the last patch.
59
- # We handle the last patch with a special case to ensure it does not exceed the
60
- # window bounds. Instead, it may overlap the previous patch.
61
- cols = list(
62
- range(
63
- bounds[0],
64
- bounds[2] - patch_size[0],
65
- patch_size[0] - overlap_size[0],
66
- )
67
- ) + [bounds[2] - patch_size[0]]
68
- rows = list(
69
- range(
70
- bounds[1],
71
- bounds[3] - patch_size[1],
72
- patch_size[1] - overlap_size[1],
73
- )
74
- ) + [bounds[3] - patch_size[1]]
75
-
76
- patch_bounds: list[PixelBounds] = []
77
- for col in cols:
78
- for row in rows:
79
- patch_bounds.append((col, row, col + patch_size[0], row + patch_size[1]))
80
- return patch_bounds
81
-
82
-
83
- def pad_slice_protect(
84
- raw_inputs: dict[str, Any],
85
- passthrough_inputs: dict[str, Any],
86
- patch_size: tuple[int, int],
87
- ) -> tuple[dict[str, Any], dict[str, Any]]:
88
- """Pad tensors in-place by patch size to protect slicing near right/bottom edges.
89
-
90
- Args:
91
- raw_inputs: the raw inputs to pad.
92
- passthrough_inputs: the passthrough inputs to pad.
93
- patch_size: the size of the patches to extract.
94
-
95
- Returns:
96
- a tuple of (raw_inputs, passthrough_inputs).
97
- """
98
- for d in [raw_inputs, passthrough_inputs]:
99
- for input_name, value in list(d.items()):
100
- if not isinstance(value, torch.Tensor):
101
- continue
102
- d[input_name] = torch.nn.functional.pad(
103
- value, pad=(0, patch_size[0], 0, patch_size[1])
104
- )
105
- return raw_inputs, passthrough_inputs
39
+ def get_torch_dtype(dtype: DType) -> torch.dtype:
40
+ """Convert rslearn DType to torch dtype."""
41
+ if dtype == DType.INT32:
42
+ return torch.int32
43
+ elif dtype == DType.FLOAT32:
44
+ return torch.float32
45
+ else:
46
+ raise ValueError(f"unable to handle {dtype} as a torch dtype")
106
47
 
107
48
 
108
49
  class SamplerFactory:
@@ -296,7 +237,7 @@ def read_raster_layer_for_data_input(
296
237
 
297
238
  image = torch.zeros(
298
239
  (len(needed_bands), bounds[3] - bounds[1], bounds[2] - bounds[0]),
299
- dtype=data_input.dtype.get_torch_dtype(),
240
+ dtype=get_torch_dtype(data_input.dtype),
300
241
  )
301
242
 
302
243
  for band_set, src_indexes, dst_indexes in needed_sets_and_indexes:
@@ -893,383 +834,6 @@ class ModelDataset(torch.utils.data.Dataset):
893
834
  self.name = name
894
835
 
895
836
 
896
- class IterableAllPatchesDataset(torch.utils.data.IterableDataset):
897
- """This wraps a ModelDataset to iterate over all patches in that dataset.
898
-
899
- This should be used when SplitConfig.load_all_patches is enabled. The ModelDataset
900
- is configured with no patch size (load entire windows), and the dataset is wrapped
901
- in an AllPatchesDataset.
902
-
903
- Similar to DistributedSampler, we add extra samples at each rank to ensure
904
- consistent number of batches across all ranks.
905
- """
906
-
907
- def __init__(
908
- self,
909
- dataset: ModelDataset,
910
- patch_size: tuple[int, int],
911
- overlap_ratio: float = 0.0,
912
- rank: int = 0,
913
- world_size: int = 1,
914
- ):
915
- """Create a new IterableAllPatchesDataset.
916
-
917
- Args:
918
- dataset: the ModelDataset to wrap.
919
- patch_size: the size of the patches to extract.
920
- overlap_ratio: whether to include overlap between the patches. Note that
921
- the right/bottom-most patches may still overlap since we ensure that
922
- all patches are contained in the window bounds.
923
- rank: the global rank of this train worker process.
924
- world_size: the total number of train worker processes.
925
- """
926
- super().__init__()
927
- self.dataset = dataset
928
- self.patch_size = patch_size
929
- self.overlap_size = (
930
- round(self.patch_size[0] * overlap_ratio),
931
- round(self.patch_size[1] * overlap_ratio),
932
- )
933
- self.rank = rank
934
- self.world_size = world_size
935
- self.windows = self.dataset.get_dataset_examples()
936
-
937
- def set_name(self, name: str) -> None:
938
- """Sets dataset name.
939
-
940
- Args:
941
- name: dataset name
942
- """
943
- self.dataset.set_name(name)
944
-
945
- def get_window_num_patches(self, bounds: PixelBounds) -> int:
946
- """Get the number of patches for these bounds.
947
-
948
- This corresponds to the length of the list returned by get_patch_options.
949
- """
950
- num_cols = (
951
- len(
952
- range(
953
- bounds[0],
954
- bounds[2] - self.patch_size[0],
955
- self.patch_size[0] - self.overlap_size[0],
956
- )
957
- )
958
- + 1
959
- )
960
- num_rows = (
961
- len(
962
- range(
963
- bounds[1],
964
- bounds[3] - self.patch_size[1],
965
- self.patch_size[1] - self.overlap_size[1],
966
- )
967
- )
968
- + 1
969
- )
970
- return num_cols * num_rows
971
-
972
- def _get_worker_iteration_data(self) -> tuple[Iterable[int], int]:
973
- """Get the windows we should iterate over.
974
-
975
- This is split both by training worker (self.rank) and data loader worker (via
976
- get_worker_info).
977
-
978
- We also compute the total number of samples that each data loader worker should
979
- yield. This is important for DDP to ensure that all ranks see the same number
980
- of batches.
981
-
982
- Returns:
983
- a tuple (window_ids, num_samples_per_worker).
984
- """
985
- # Figure out the total number of data loader workers and our worker ID.
986
- worker_info = torch.utils.data.get_worker_info()
987
- if worker_info is None:
988
- worker_id = 0
989
- num_workers = 1
990
- else:
991
- worker_id = worker_info.id
992
- num_workers = worker_info.num_workers
993
- global_worker_id = self.rank * num_workers + worker_id
994
- global_num_workers = self.world_size * num_workers
995
-
996
- # Split up the windows evenly among the workers.
997
- # We compute this for all workers since we will need to see the maximum number
998
- # of samples under this assignment across workers.
999
- window_indexes = range(len(self.windows))
1000
- windows_by_worker = [
1001
- window_indexes[cur_rank :: self.world_size][cur_worker_id::num_workers]
1002
- for cur_rank in range(self.world_size)
1003
- for cur_worker_id in range(num_workers)
1004
- ]
1005
-
1006
- # Now compute the maximum number of samples across workers.
1007
- max_num_patches = 0
1008
- for worker_windows in windows_by_worker:
1009
- worker_num_patches = 0
1010
- for window_id in worker_windows:
1011
- worker_num_patches += self.get_window_num_patches(
1012
- self.windows[window_id].bounds
1013
- )
1014
- max_num_patches = max(max_num_patches, worker_num_patches)
1015
-
1016
- # Each worker needs at least one window, otherwise it won't be able to pad.
1017
- # Unless there are zero windows total, which is fine.
1018
- # Previously we would address this by borrowing the windows from another
1019
- # worker, but this causes issues with RslearnWriter: if we yield the same
1020
- # window from parallel workers, it may end up writing an empty output for that
1021
- # window in the end.
1022
- # So now we raise an error instead, and require the number of workers to be
1023
- # less than the number of windows.
1024
- if len(windows_by_worker[global_worker_id]) == 0 and max_num_patches > 0:
1025
- raise ValueError(
1026
- f"the number of workers {global_num_workers} must be <= the number of windows {len(self.windows)}"
1027
- )
1028
-
1029
- return (windows_by_worker[global_worker_id], max_num_patches)
1030
-
1031
- def __iter__(
1032
- self,
1033
- ) -> Iterator[tuple[dict[str, Any], dict[str, Any], dict[str, Any]]]:
1034
- """Iterate over all patches in each element of the underlying ModelDataset."""
1035
- # Iterate over the window IDs until we have returned enough samples.
1036
- window_ids, num_samples_needed = self._get_worker_iteration_data()
1037
- num_samples_returned = 0
1038
-
1039
- for iteration_idx in itertools.count():
1040
- for window_id in window_ids:
1041
- raw_inputs, passthrough_inputs, metadata = self.dataset.get_raw_inputs(
1042
- window_id
1043
- )
1044
- bounds = metadata["bounds"]
1045
-
1046
- # For simplicity, pad tensors by patch size to ensure that any patch bounds
1047
- # extending outside the window bounds will not have issues when we slice
1048
- # the tensors later.
1049
- pad_slice_protect(raw_inputs, passthrough_inputs, self.patch_size)
1050
-
1051
- # Now iterate over the patches and extract/yield the crops.
1052
- # Note that, in case user is leveraging RslearnWriter, it is important that
1053
- # the patch_idx be increasing (as we iterate) within one window.
1054
- patches = get_window_patch_options(
1055
- self.patch_size, self.overlap_size, bounds
1056
- )
1057
- for patch_idx, patch_bounds in enumerate(patches):
1058
- cur_geom = STGeometry(
1059
- metadata["projection"], shapely.box(*patch_bounds), None
1060
- )
1061
- start_offset = (
1062
- patch_bounds[0] - bounds[0],
1063
- patch_bounds[1] - bounds[1],
1064
- )
1065
- end_offset = (
1066
- patch_bounds[2] - bounds[0],
1067
- patch_bounds[3] - bounds[1],
1068
- )
1069
-
1070
- # Define a helper function to handle each input dict.
1071
- def crop_input_dict(d: dict[str, Any]) -> dict[str, Any]:
1072
- cropped = {}
1073
- for input_name, value in d.items():
1074
- if isinstance(value, torch.Tensor):
1075
- # Crop the CHW tensor.
1076
- cropped[input_name] = value[
1077
- :,
1078
- start_offset[1] : end_offset[1],
1079
- start_offset[0] : end_offset[0],
1080
- ].clone()
1081
- elif isinstance(value, list):
1082
- cropped[input_name] = [
1083
- feat
1084
- for feat in value
1085
- if cur_geom.intersects(feat.geometry)
1086
- ]
1087
- else:
1088
- raise ValueError(
1089
- "got input that is neither tensor nor feature list"
1090
- )
1091
- return cropped
1092
-
1093
- cur_raw_inputs = crop_input_dict(raw_inputs)
1094
- cur_passthrough_inputs = crop_input_dict(passthrough_inputs)
1095
-
1096
- # Adjust the metadata as well.
1097
- cur_metadata = metadata.copy()
1098
- cur_metadata["bounds"] = patch_bounds
1099
- cur_metadata["patch_idx"] = patch_idx
1100
- cur_metadata["num_patches"] = len(patches)
1101
-
1102
- # Now we can compute input and target dicts via the task.
1103
- input_dict, target_dict = self.dataset.task.process_inputs(
1104
- cur_raw_inputs,
1105
- metadata=cur_metadata,
1106
- load_targets=not self.dataset.split_config.get_skip_targets(),
1107
- )
1108
- input_dict.update(cur_passthrough_inputs)
1109
- input_dict, target_dict = self.dataset.transforms(
1110
- input_dict, target_dict
1111
- )
1112
- input_dict["dataset_source"] = self.dataset.name
1113
-
1114
- if num_samples_returned < num_samples_needed:
1115
- yield input_dict, target_dict, cur_metadata
1116
- num_samples_returned += 1
1117
- else:
1118
- assert iteration_idx > 0
1119
-
1120
- if num_samples_returned >= num_samples_needed:
1121
- break
1122
-
1123
- def get_dataset_examples(self) -> list[Window]:
1124
- """Returns a list of windows in this dataset."""
1125
- return self.dataset.get_dataset_examples()
1126
-
1127
-
1128
- class InMemoryAllPatchesDataset(torch.utils.data.Dataset):
1129
- """This wraps a ModelDataset to iterate over all patches in that dataset.
1130
-
1131
- This should be used when SplitConfig.load_all_patches is enabled.
1132
-
1133
- This is a simpler version of IterableAllPatchesDataset that caches all windows in memory.
1134
- This is useful for small datasets that fit in memory.
1135
- """
1136
-
1137
- def __init__(
1138
- self,
1139
- dataset: ModelDataset,
1140
- patch_size: tuple[int, int],
1141
- overlap_ratio: float = 0.0,
1142
- ):
1143
- """Create a new InMemoryAllPatchesDataset.
1144
-
1145
- Args:
1146
- dataset: the ModelDataset to wrap.
1147
- patch_size: the size of the patches to extract.
1148
- overlap_ratio: whether to include overlap between the patches. Note that
1149
- the right/bottom-most patches may still overlap since we ensure that
1150
- all patches are contained in the window bounds.
1151
- """
1152
- super().__init__()
1153
- self.dataset = dataset
1154
- self.patch_size = patch_size
1155
- self.overlap_size = (
1156
- round(self.patch_size[0] * overlap_ratio),
1157
- round(self.patch_size[1] * overlap_ratio),
1158
- )
1159
- self.windows = self.dataset.get_dataset_examples()
1160
- self.window_cache: dict[
1161
- int, tuple[dict[str, Any], dict[str, Any], dict[str, Any]]
1162
- ] = {}
1163
-
1164
- # Precompute the batch boundaries for each window
1165
- self.patches = []
1166
- for window_id, window in enumerate(self.windows):
1167
- patch_bounds = get_window_patch_options(
1168
- self.patch_size, self.overlap_size, window.bounds
1169
- )
1170
- for i, patch_bound in enumerate(patch_bounds):
1171
- self.patches.append((window_id, patch_bound, (i, len(patch_bounds))))
1172
-
1173
- def get_raw_inputs(
1174
- self, index: int
1175
- ) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
1176
- """Get the raw inputs for a single patch. Retrieve from cache if possible.
1177
-
1178
- Also crops/pads the tensors by patch size to protect slicing near right/bottom edges.
1179
-
1180
- Args:
1181
- index: the index of the patch.
1182
-
1183
- Returns:
1184
- a tuple of (raw_inputs, passthrough_inputs, metadata).
1185
- """
1186
- if index in self.window_cache:
1187
- return self.window_cache[index]
1188
-
1189
- raw_inputs, passthrough_inputs, metadata = self.dataset.get_raw_inputs(index)
1190
- pad_slice_protect(raw_inputs, passthrough_inputs, self.patch_size)
1191
-
1192
- self.window_cache[index] = (raw_inputs, passthrough_inputs, metadata)
1193
- return self.window_cache[index]
1194
-
1195
- @staticmethod
1196
- def _crop_input_dict(
1197
- d: dict[str, Any],
1198
- start_offset: tuple[int, int],
1199
- end_offset: tuple[int, int],
1200
- cur_geom: STGeometry,
1201
- ) -> dict[str, Any]:
1202
- """Crop a dictionary of inputs to the given bounds."""
1203
- cropped = {}
1204
- for input_name, value in d.items():
1205
- if isinstance(value, torch.Tensor):
1206
- cropped[input_name] = value[
1207
- :,
1208
- start_offset[1] : end_offset[1],
1209
- start_offset[0] : end_offset[0],
1210
- ].clone()
1211
- elif isinstance(value, list):
1212
- cropped[input_name] = [
1213
- feat for feat in value if cur_geom.intersects(feat.geometry)
1214
- ]
1215
- else:
1216
- raise ValueError("got input that is neither tensor nor feature list")
1217
- return cropped
1218
-
1219
- def __len__(self) -> int:
1220
- """Return the total number of patches in the dataset."""
1221
- return len(self.patches)
1222
-
1223
- def __getitem__(
1224
- self, index: int
1225
- ) -> tuple[dict[str, Any], dict[str, Any], dict[str, Any]]:
1226
- """Return (input_dict, target_dict, metadata) for a single flattened patch."""
1227
- (window_id, patch_bounds, (patch_idx, num_patches)) = self.patches[index]
1228
- raw_inputs, passthrough_inputs, metadata = self.get_raw_inputs(window_id)
1229
- bounds = metadata["bounds"]
1230
-
1231
- cur_geom = STGeometry(metadata["projection"], shapely.box(*patch_bounds), None)
1232
- start_offset = (patch_bounds[0] - bounds[0], patch_bounds[1] - bounds[1])
1233
- end_offset = (patch_bounds[2] - bounds[0], patch_bounds[3] - bounds[1])
1234
-
1235
- cur_raw_inputs = self._crop_input_dict(
1236
- raw_inputs, start_offset, end_offset, cur_geom
1237
- )
1238
- cur_passthrough_inputs = self._crop_input_dict(
1239
- passthrough_inputs, start_offset, end_offset, cur_geom
1240
- )
1241
-
1242
- # Adjust the metadata as well.
1243
- cur_metadata = metadata.copy()
1244
- cur_metadata["bounds"] = patch_bounds
1245
- cur_metadata["patch_idx"] = patch_idx
1246
- cur_metadata["num_patches"] = num_patches
1247
-
1248
- # Now we can compute input and target dicts via the task.
1249
- input_dict, target_dict = self.dataset.task.process_inputs(
1250
- cur_raw_inputs,
1251
- metadata=cur_metadata,
1252
- load_targets=not self.dataset.split_config.get_skip_targets(),
1253
- )
1254
- input_dict.update(cur_passthrough_inputs)
1255
- input_dict, target_dict = self.dataset.transforms(input_dict, target_dict)
1256
- input_dict["dataset_source"] = self.dataset.name
1257
-
1258
- return input_dict, target_dict, cur_metadata
1259
-
1260
- def get_dataset_examples(self) -> list[Window]:
1261
- """Returns a list of windows in this dataset."""
1262
- return self.dataset.get_dataset_examples()
1263
-
1264
- def set_name(self, name: str) -> None:
1265
- """Sets dataset name.
1266
-
1267
- Args:
1268
- name: dataset name
1269
- """
1270
- self.dataset.set_name(name)
1271
-
1272
-
1273
837
  class RetryDataset(torch.utils.data.Dataset):
1274
838
  """A dataset wrapper that retries getitem upon encountering error."""
1275
839
 
rslearn/utils/array.py CHANGED
@@ -1,14 +1,16 @@
1
1
  """Array util functions."""
2
2
 
3
- from typing import Any
3
+ from typing import TYPE_CHECKING, Any
4
4
 
5
5
  import numpy.typing as npt
6
- import torch
6
+
7
+ if TYPE_CHECKING:
8
+ import torch
7
9
 
8
10
 
9
11
  def copy_spatial_array(
10
- src: torch.Tensor | npt.NDArray[Any],
11
- dst: torch.Tensor | npt.NDArray[Any],
12
+ src: "torch.Tensor | npt.NDArray[Any]",
13
+ dst: "torch.Tensor | npt.NDArray[Any]",
12
14
  src_offset: tuple[int, int],
13
15
  dst_offset: tuple[int, int],
14
16
  ) -> None:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: rslearn
3
- Version: 0.0.13
3
+ Version: 0.0.15
4
4
  Summary: A library for developing remote sensing datasets and models
5
5
  Author: OlmoEarth Team
6
6
  License: Apache License
@@ -1,12 +1,13 @@
1
1
  rslearn/__init__.py,sha256=fFmAen3vxZyosEfPbG0W46IttujYGVxzrGkJ0YutmmY,73
2
2
  rslearn/arg_parser.py,sha256=GNlJncO6Ck_dCNrcg7z_SSG61I-2gKn3Ix2tAxIk9CI,1428
3
3
  rslearn/const.py,sha256=FUCfsvFAs-QarEDJ0grdy0C1HjUjLpNFYGo5I2Vpc5Y,449
4
+ rslearn/lightning_cli.py,sha256=io1Agb2fr-fUu9yOODNJhP8-vJp_v9UbJJA2hkLubKA,2435
4
5
  rslearn/log_utils.py,sha256=unD9gShiuO7cx5Nnq8qqVQ4qrbOOwFVgcHxN5bXuiAo,941
5
- rslearn/main.py,sha256=fLYmm2ZsUTCaJBKZvxu3pc4fB2thaf-p2Qv0AifDlXM,31292
6
+ rslearn/main.py,sha256=JMNMhAHqpb9bDUoKzj6kN659Ft_-gZv_rKUieJcJNwI,29087
6
7
  rslearn/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
7
8
  rslearn/template_params.py,sha256=Vop0Ha-S44ctCa9lvSZRjrMETznJZlR5y_gJrVIwrPg,791
8
9
  rslearn/config/__init__.py,sha256=Bhf2VVncdMYRC8Wfb4GsJJ13OAJYNCO_ODLSNTmBOHM,638
9
- rslearn/config/dataset.py,sha256=VpXUGKCr45kzE-W27rgF4tPQuyICfwQkJTxb2z9aXQM,21685
10
+ rslearn/config/dataset.py,sha256=lIuFgJG0Hz7nxacFIpbwOyNJqjlkOlaMfWt91Chjb_M,21338
10
11
  rslearn/data_sources/__init__.py,sha256=8_7Pi3agKsatNoxXw74-U5G-QAP-rbdfcH8EkZfJbH4,1449
11
12
  rslearn/data_sources/aws_landsat.py,sha256=GA9H04KagBDm-N37jFdh_aHCX2ZneVdnqT1SNOyAwTs,20829
12
13
  rslearn/data_sources/aws_open_data.py,sha256=nU_D5cqc-wibxq4uyUNb0z-XD0Puf1gZ8v5FMiMAN5w,30258
@@ -39,7 +40,7 @@ rslearn/dataset/add_windows.py,sha256=pwCEvwLE1jQCoqQxw6CJ-sP46ayWppFa2hGYIB6VVk
39
40
  rslearn/dataset/dataset.py,sha256=bjf9nI55j-MF0bIQWSNPjNbpfqnLK4jy-96TAcwO0MM,5214
40
41
  rslearn/dataset/handler_summaries.py,sha256=wI99RDk5erCWkzl1A7Uc4chatQ9KWIr4F_0Hxr9Co6s,2607
41
42
  rslearn/dataset/index.py,sha256=Wni5m6h4gisRB54fPLnCfUrRTEsJ5EvwS0fs9sYc2wg,6025
42
- rslearn/dataset/manage.py,sha256=owelBiBqvoIQYLhFMDK4ULzcoGBNE27JV8kl68jf3wg,18563
43
+ rslearn/dataset/manage.py,sha256=IURlbCtm9a5f4d52AXfte1yyodlf6MgjfYn3__GdIL4,19062
43
44
  rslearn/dataset/materialize.py,sha256=-z47svc_JqGhzkp8kq5Hd9fykWNqFEUCQezo887TWBw,22056
44
45
  rslearn/dataset/remap.py,sha256=6MaImsY02GNACpvRM81RvWmjZWRfAHxo_R3Ox6XLF6A,2723
45
46
  rslearn/dataset/window.py,sha256=I5RqZ12jlIXhohw4qews1x_I4tSDpml709DZRtLiN24,12546
@@ -47,7 +48,7 @@ rslearn/models/__init__.py,sha256=_vWoF9d2Slah8-6XhYhdU4SRsy_CNxXjCGQTD2yvu3Q,22
47
48
  rslearn/models/anysat.py,sha256=3Oh2gWxicVdUzOjevBEZf0PuolmCy0KC5Ad7JY-0Plc,7949
48
49
  rslearn/models/clip.py,sha256=u5aqYnVB4Jag7o1h8EzPDAc1t2BAHeALA9FcUwP5tfo,2238
49
50
  rslearn/models/conv.py,sha256=fWyByeswIOKKzyPmP3erYUlZaKEV0huWHA4CyKTBbfY,1703
50
- rslearn/models/croma.py,sha256=cOazTp3l2PNJltKrmPqD5Gy4pi3CI03-X9G4T10cX2k,9529
51
+ rslearn/models/croma.py,sha256=n7yunpT7lo8vWWaOpx4yt8jZSXjgWqfgZcZKFW5zuEQ,10591
51
52
  rslearn/models/dinov3.py,sha256=9k9kNlXCorQQwKjLGptooANd48TUBsITQ1e4fUomlM4,6337
52
53
  rslearn/models/faster_rcnn.py,sha256=uaxX6-E1f0BibaA9sorEg3be83C7kTdTc39pC5jRqwE,8286
53
54
  rslearn/models/feature_center_crop.py,sha256=24eOrvLEGGVWPw7kPHyUes5HtYNAX7GZ_NpqDGMILEY,1553
@@ -62,18 +63,18 @@ rslearn/models/prithvi.py,sha256=AIzcO5xk1ggR0MjbfhIzqPVgUKFN7odxygmgyAelfW8,401
62
63
  rslearn/models/registry.py,sha256=yCcrOvLkbn07Xtln1j7hAB_kmGw0MGsiR2TloJq9Bmk,504
63
64
  rslearn/models/resize_features.py,sha256=asKXWrLHIBrU6GaAV0Ory9YuK7IK104XjhkB4ljzI3A,1289
64
65
  rslearn/models/sam2_enc.py,sha256=gNlPokr7eNxO2KvnzDMXNxYM2WRO0YkQPjR4110n6cw,3508
65
- rslearn/models/satlaspretrain.py,sha256=YpjXl-uClhTZMDmyhN64Fg3AszzT-ymZgJB0fO9RyoY,2419
66
+ rslearn/models/satlaspretrain.py,sha256=b6FR_il6MnWU4UpB9OxInZSK9n0IS0PcQuLrWH4YD8g,3046
66
67
  rslearn/models/simple_time_series.py,sha256=oTg_akabYFBExJu7JCpbuM211-ZgQS4WerG2nEYrIZY,12774
67
68
  rslearn/models/singletask.py,sha256=z4vN9Yvzz0I-U4KJdVZxLJK2ZV-MIv9tzwCGcOWoUPY,1604
68
69
  rslearn/models/ssl4eo_s12.py,sha256=sOGEHcDo-rNdmEyoLu2AVEqfxRM_cv6zpfAmyn5c6tw,3553
69
70
  rslearn/models/swin.py,sha256=bMlGePXMFou4A_YSUZzjHgN9NniGXaCWdGQ31xHDKis,5511
70
71
  rslearn/models/task_embedding.py,sha256=Z6sf61BLCtvdrdnvjh8500b-KiFp3GeWbT4mOqpaCKk,9100
71
- rslearn/models/terramind.py,sha256=kipar8sMaHJJ3b8vIgL0-s4qhHcA0Vb854vmlZ9cWh4,7524
72
+ rslearn/models/terramind.py,sha256=5POVk_y29LlbVswa6ojd9gdB70iO41yB9Y2aqVY4WdQ,8327
72
73
  rslearn/models/trunk.py,sha256=H1QPQGAKsmocq3OiF66GW8MQI4LffupTDrgZR4Ta7QM,4708
73
74
  rslearn/models/unet.py,sha256=WUgLgvvlgV8l_6MIDBl6aX1HNOkb24DfnVRIyYXHCjo,6865
74
75
  rslearn/models/upsample.py,sha256=3kWbyWZIk56JJxj8en9pieitbrk3XnbIsTKlEkiDQQY,938
75
76
  rslearn/models/use_croma.py,sha256=OSBqMuLp-pDtqPNWAVBfmX4wckmyYCKtUDdGCjJk_K8,17966
76
- rslearn/models/clay/clay.py,sha256=5RO5H8EM0tKjCwWMQ4xDkKkUCwKpm2K_Yw1alnhvVhU,7773
77
+ rslearn/models/clay/clay.py,sha256=29CGCOysx9duEX4Y6LUNHXck_sHjCFrlV4w8CP_hKmI,8460
77
78
  rslearn/models/clay/configs/metadata.yaml,sha256=rZTFh4Yb9htEfbQNOPl4HTbFogEhzwIRqFzG-1uT01Y,4652
78
79
  rslearn/models/detr/__init__.py,sha256=GGAnTIhyuvl34IRrJ_4gXjm_01OlM5rbQQ3c3TGfbK8,84
79
80
  rslearn/models/detr/box_ops.py,sha256=ORCF6EwMpMBB_VgQT05SjR47dCR2rN2gPhL_gsuUWJs,3236
@@ -107,8 +108,9 @@ rslearn/tile_stores/__init__.py,sha256=o_tWVKu6UwFzZbO9jn_3cmIDqc_Q3qDd6tA9If0T_
107
108
  rslearn/tile_stores/default.py,sha256=PYaDNvBxhJTDKJGw0EjDTSE1OKajR7_iJpMbOjj-mE8,15054
108
109
  rslearn/tile_stores/tile_store.py,sha256=9AeYduDYPp_Ia2NMlq6osptpz_AFGIOQcLJrqZ_m-z0,10469
109
110
  rslearn/train/__init__.py,sha256=fnJyY4aHs5zQqbDKSfXsJZXY_M9fbTsf7dRYaPwZr2M,30
110
- rslearn/train/data_module.py,sha256=K-nQgnOZn-KGq_G2pVOQFtWRrlWih0212i_bkXZ2bEE,23515
111
- rslearn/train/dataset.py,sha256=YiskNlYYcKqZxyw0Xzop1RGLbjMc-oK_rmhrSMVbTQg,51857
111
+ rslearn/train/all_patches_dataset.py,sha256=xFJ96HU3CodrUBzXTsgrmEShosKH79T2SxI0xDVSH3Q,18217
112
+ rslearn/train/data_module.py,sha256=pgut8rEWHIieZ7RR8dUvhtlNqk0egEdznYF3tCvqdHg,23552
113
+ rslearn/train/dataset.py,sha256=8F3bpus25g_NG0-CwMCuznwKxOvBDClNBCOEvDbMyN8,34312
112
114
  rslearn/train/lightning_module.py,sha256=ZLBiId3secUlVs2yzkN-mwVv4rMdh5TkdZYl4vv_Cw0,14466
113
115
  rslearn/train/optimizer.py,sha256=EKSqkmERalDA0bF32Gey7n6z69KLyaUWKlRsGJfKBmE,927
114
116
  rslearn/train/prediction_writer.py,sha256=mDvREwEB5k5_tNuBnYIvAGnxS3sYFWQYvV07V3UEe2k,14106
@@ -138,7 +140,7 @@ rslearn/train/transforms/select_bands.py,sha256=uDfD9G8Z4VTt88QZsjj1FB20QEmzSefh
138
140
  rslearn/train/transforms/sentinel1.py,sha256=FrLaYZs2AjqWQCun8DTFtgo1l0xLxqaFKtDNIehtpDg,1913
139
141
  rslearn/train/transforms/transform.py,sha256=n1Qzqix2dVvej-Q7iPzHeOQbqH79IBlvqPoymxhNVpE,4446
140
142
  rslearn/utils/__init__.py,sha256=GNvdTUmXakiEMnLdje7k1fe5aC7SFVqP757kbpN6Fzw,558
141
- rslearn/utils/array.py,sha256=JwZi7o0uj-dftREzJmqrRVR2joIwBikm3Er9KeHVIZU,2402
143
+ rslearn/utils/array.py,sha256=RC7ygtPnQwU6Lb9kwORvNxatJcaJ76JPsykQvndAfes,2444
142
144
  rslearn/utils/feature.py,sha256=lsg0WThZDJzo1mrbaL04dXYI5G3x-n5FG9aEjj7uUaI,1649
143
145
  rslearn/utils/fsspec.py,sha256=9QwN46heBhjUnth3qFeRNE3W6Wlr6dM3twYVswPnS9o,5300
144
146
  rslearn/utils/geometry.py,sha256=oZllq1aBFcDewTTDYAMnTeP1xR0EdB5Xz3ILmfASo-8,18455
@@ -152,10 +154,10 @@ rslearn/utils/spatial_index.py,sha256=eomJAUgzmjir8j9HZnSgQoJHwN9H0wGTjmJkMkLLfs
152
154
  rslearn/utils/sqlite_index.py,sha256=YGOJi66544e6JNtfSft6YIlHklFdSJO2duxQ4TJ2iu4,2920
153
155
  rslearn/utils/time.py,sha256=2ilSLG94_sxLP3y5RSV5L5CG8CoND_dbdzYEHVtN-I8,387
154
156
  rslearn/utils/vector_format.py,sha256=EIChYCL6GLOILS2TO2JBkca1TuaWsSubWv6iRS3P2ds,16139
155
- rslearn-0.0.13.dist-info/licenses/LICENSE,sha256=_99ZWPoLdlUbqZoSC5DF4ihiNwl5rTEmBaq2fACecdg,11352
156
- rslearn-0.0.13.dist-info/licenses/NOTICE,sha256=wLPr6rwV_jCg-xEknNGwhnkfRfuoOE9MZ-lru2yZyLI,5070
157
- rslearn-0.0.13.dist-info/METADATA,sha256=44oDmbvkIrjJ0unVNaYeO5OypD6RavmG7l5HUz9Re48,36319
158
- rslearn-0.0.13.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
159
- rslearn-0.0.13.dist-info/entry_points.txt,sha256=doTBQ57NT7nq-dgYGgTTw6mafcGWb_4PWYtYR4rGm50,46
160
- rslearn-0.0.13.dist-info/top_level.txt,sha256=XDKo90WBH8P9RQumHxo0giLJsoufT4r9odv-WE6Ahk4,8
161
- rslearn-0.0.13.dist-info/RECORD,,
157
+ rslearn-0.0.15.dist-info/licenses/LICENSE,sha256=_99ZWPoLdlUbqZoSC5DF4ihiNwl5rTEmBaq2fACecdg,11352
158
+ rslearn-0.0.15.dist-info/licenses/NOTICE,sha256=wLPr6rwV_jCg-xEknNGwhnkfRfuoOE9MZ-lru2yZyLI,5070
159
+ rslearn-0.0.15.dist-info/METADATA,sha256=HRkJjQfvxCEosmCBvLcLd9nZnXjbmfAgPIknMy_ORBo,36319
160
+ rslearn-0.0.15.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
161
+ rslearn-0.0.15.dist-info/entry_points.txt,sha256=doTBQ57NT7nq-dgYGgTTw6mafcGWb_4PWYtYR4rGm50,46
162
+ rslearn-0.0.15.dist-info/top_level.txt,sha256=XDKo90WBH8P9RQumHxo0giLJsoufT4r9odv-WE6Ahk4,8
163
+ rslearn-0.0.15.dist-info/RECORD,,