rslearn 0.0.1__py3-none-any.whl → 0.0.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. rslearn/arg_parser.py +31 -0
  2. rslearn/config/__init__.py +6 -12
  3. rslearn/config/dataset.py +520 -401
  4. rslearn/const.py +9 -15
  5. rslearn/data_sources/__init__.py +8 -23
  6. rslearn/data_sources/aws_landsat.py +242 -98
  7. rslearn/data_sources/aws_open_data.py +111 -151
  8. rslearn/data_sources/aws_sentinel1.py +131 -0
  9. rslearn/data_sources/climate_data_store.py +471 -0
  10. rslearn/data_sources/copernicus.py +884 -12
  11. rslearn/data_sources/data_source.py +43 -12
  12. rslearn/data_sources/earthdaily.py +484 -0
  13. rslearn/data_sources/earthdata_srtm.py +282 -0
  14. rslearn/data_sources/eurocrops.py +242 -0
  15. rslearn/data_sources/gcp_public_data.py +578 -222
  16. rslearn/data_sources/google_earth_engine.py +461 -135
  17. rslearn/data_sources/local_files.py +219 -150
  18. rslearn/data_sources/openstreetmap.py +51 -89
  19. rslearn/data_sources/planet.py +24 -60
  20. rslearn/data_sources/planet_basemap.py +275 -0
  21. rslearn/data_sources/planetary_computer.py +798 -0
  22. rslearn/data_sources/usda_cdl.py +195 -0
  23. rslearn/data_sources/usgs_landsat.py +115 -83
  24. rslearn/data_sources/utils.py +249 -61
  25. rslearn/data_sources/vector_source.py +1 -0
  26. rslearn/data_sources/worldcereal.py +449 -0
  27. rslearn/data_sources/worldcover.py +144 -0
  28. rslearn/data_sources/worldpop.py +153 -0
  29. rslearn/data_sources/xyz_tiles.py +150 -107
  30. rslearn/dataset/__init__.py +8 -2
  31. rslearn/dataset/add_windows.py +2 -2
  32. rslearn/dataset/dataset.py +40 -51
  33. rslearn/dataset/handler_summaries.py +131 -0
  34. rslearn/dataset/manage.py +313 -74
  35. rslearn/dataset/materialize.py +431 -107
  36. rslearn/dataset/remap.py +29 -4
  37. rslearn/dataset/storage/__init__.py +1 -0
  38. rslearn/dataset/storage/file.py +202 -0
  39. rslearn/dataset/storage/storage.py +140 -0
  40. rslearn/dataset/window.py +181 -44
  41. rslearn/lightning_cli.py +454 -0
  42. rslearn/log_utils.py +24 -0
  43. rslearn/main.py +384 -181
  44. rslearn/models/anysat.py +215 -0
  45. rslearn/models/attention_pooling.py +177 -0
  46. rslearn/models/clay/clay.py +231 -0
  47. rslearn/models/clay/configs/metadata.yaml +295 -0
  48. rslearn/models/clip.py +68 -0
  49. rslearn/models/component.py +111 -0
  50. rslearn/models/concatenate_features.py +103 -0
  51. rslearn/models/conv.py +63 -0
  52. rslearn/models/croma.py +306 -0
  53. rslearn/models/detr/__init__.py +5 -0
  54. rslearn/models/detr/box_ops.py +103 -0
  55. rslearn/models/detr/detr.py +504 -0
  56. rslearn/models/detr/matcher.py +107 -0
  57. rslearn/models/detr/position_encoding.py +114 -0
  58. rslearn/models/detr/transformer.py +429 -0
  59. rslearn/models/detr/util.py +24 -0
  60. rslearn/models/dinov3.py +177 -0
  61. rslearn/models/faster_rcnn.py +30 -28
  62. rslearn/models/feature_center_crop.py +53 -0
  63. rslearn/models/fpn.py +19 -8
  64. rslearn/models/galileo/__init__.py +5 -0
  65. rslearn/models/galileo/galileo.py +595 -0
  66. rslearn/models/galileo/single_file_galileo.py +1678 -0
  67. rslearn/models/module_wrapper.py +65 -0
  68. rslearn/models/molmo.py +69 -0
  69. rslearn/models/multitask.py +384 -28
  70. rslearn/models/olmoearth_pretrain/__init__.py +1 -0
  71. rslearn/models/olmoearth_pretrain/model.py +421 -0
  72. rslearn/models/olmoearth_pretrain/norm.py +86 -0
  73. rslearn/models/panopticon.py +170 -0
  74. rslearn/models/panopticon_data/sensors/drone.yaml +32 -0
  75. rslearn/models/panopticon_data/sensors/enmap.yaml +904 -0
  76. rslearn/models/panopticon_data/sensors/goes.yaml +9 -0
  77. rslearn/models/panopticon_data/sensors/himawari.yaml +9 -0
  78. rslearn/models/panopticon_data/sensors/intuition.yaml +606 -0
  79. rslearn/models/panopticon_data/sensors/landsat8.yaml +84 -0
  80. rslearn/models/panopticon_data/sensors/modis_terra.yaml +99 -0
  81. rslearn/models/panopticon_data/sensors/qb2_ge1.yaml +34 -0
  82. rslearn/models/panopticon_data/sensors/sentinel1.yaml +85 -0
  83. rslearn/models/panopticon_data/sensors/sentinel2.yaml +97 -0
  84. rslearn/models/panopticon_data/sensors/superdove.yaml +60 -0
  85. rslearn/models/panopticon_data/sensors/wv23.yaml +63 -0
  86. rslearn/models/pick_features.py +17 -10
  87. rslearn/models/pooling_decoder.py +60 -7
  88. rslearn/models/presto/__init__.py +5 -0
  89. rslearn/models/presto/presto.py +297 -0
  90. rslearn/models/presto/single_file_presto.py +926 -0
  91. rslearn/models/prithvi.py +1147 -0
  92. rslearn/models/resize_features.py +59 -0
  93. rslearn/models/sam2_enc.py +13 -9
  94. rslearn/models/satlaspretrain.py +38 -18
  95. rslearn/models/simple_time_series.py +188 -77
  96. rslearn/models/singletask.py +24 -13
  97. rslearn/models/ssl4eo_s12.py +40 -30
  98. rslearn/models/swin.py +44 -32
  99. rslearn/models/task_embedding.py +250 -0
  100. rslearn/models/terramind.py +256 -0
  101. rslearn/models/trunk.py +139 -0
  102. rslearn/models/unet.py +68 -22
  103. rslearn/models/upsample.py +48 -0
  104. rslearn/models/use_croma.py +508 -0
  105. rslearn/template_params.py +26 -0
  106. rslearn/tile_stores/__init__.py +41 -18
  107. rslearn/tile_stores/default.py +409 -0
  108. rslearn/tile_stores/tile_store.py +236 -132
  109. rslearn/train/all_patches_dataset.py +530 -0
  110. rslearn/train/callbacks/adapters.py +53 -0
  111. rslearn/train/callbacks/freeze_unfreeze.py +348 -17
  112. rslearn/train/callbacks/gradients.py +129 -0
  113. rslearn/train/callbacks/peft.py +116 -0
  114. rslearn/train/data_module.py +444 -20
  115. rslearn/train/dataset.py +588 -235
  116. rslearn/train/lightning_module.py +192 -62
  117. rslearn/train/model_context.py +88 -0
  118. rslearn/train/optimizer.py +31 -0
  119. rslearn/train/prediction_writer.py +319 -84
  120. rslearn/train/scheduler.py +92 -0
  121. rslearn/train/tasks/classification.py +55 -28
  122. rslearn/train/tasks/detection.py +132 -76
  123. rslearn/train/tasks/embedding.py +120 -0
  124. rslearn/train/tasks/multi_task.py +28 -14
  125. rslearn/train/tasks/per_pixel_regression.py +291 -0
  126. rslearn/train/tasks/regression.py +161 -44
  127. rslearn/train/tasks/segmentation.py +428 -53
  128. rslearn/train/tasks/task.py +6 -5
  129. rslearn/train/transforms/__init__.py +1 -1
  130. rslearn/train/transforms/concatenate.py +54 -10
  131. rslearn/train/transforms/crop.py +29 -11
  132. rslearn/train/transforms/flip.py +18 -6
  133. rslearn/train/transforms/mask.py +78 -0
  134. rslearn/train/transforms/normalize.py +101 -17
  135. rslearn/train/transforms/pad.py +19 -7
  136. rslearn/train/transforms/resize.py +83 -0
  137. rslearn/train/transforms/select_bands.py +76 -0
  138. rslearn/train/transforms/sentinel1.py +75 -0
  139. rslearn/train/transforms/transform.py +89 -70
  140. rslearn/utils/__init__.py +2 -6
  141. rslearn/utils/array.py +8 -6
  142. rslearn/utils/feature.py +2 -2
  143. rslearn/utils/fsspec.py +90 -1
  144. rslearn/utils/geometry.py +347 -7
  145. rslearn/utils/get_utm_ups_crs.py +2 -3
  146. rslearn/utils/grid_index.py +5 -5
  147. rslearn/utils/jsonargparse.py +178 -0
  148. rslearn/utils/mp.py +4 -3
  149. rslearn/utils/raster_format.py +268 -116
  150. rslearn/utils/rtree_index.py +64 -17
  151. rslearn/utils/sqlite_index.py +7 -1
  152. rslearn/utils/vector_format.py +252 -97
  153. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/METADATA +532 -283
  154. rslearn-0.0.21.dist-info/RECORD +167 -0
  155. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/WHEEL +1 -1
  156. rslearn-0.0.21.dist-info/licenses/NOTICE +115 -0
  157. rslearn/data_sources/raster_source.py +0 -309
  158. rslearn/models/registry.py +0 -5
  159. rslearn/tile_stores/file.py +0 -242
  160. rslearn/utils/mgrs.py +0 -24
  161. rslearn/utils/utils.py +0 -22
  162. rslearn-0.0.1.dist-info/RECORD +0 -88
  163. /rslearn/{data_sources/geotiff.py → py.typed} +0 -0
  164. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/entry_points.txt +0 -0
  165. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info/licenses}/LICENSE +0 -0
  166. {rslearn-0.0.1.dist-info → rslearn-0.0.21.dist-info}/top_level.txt +0 -0
rslearn/main.py CHANGED
@@ -1,39 +1,56 @@
1
1
  """Entrypoint for the rslearn command-line interface."""
2
2
 
3
3
  import argparse
4
- import logging
5
4
  import multiprocessing
6
5
  import random
7
6
  import sys
7
+ import time
8
8
  from collections.abc import Callable
9
- from datetime import datetime, timezone
10
- from pathlib import Path
9
+ from datetime import UTC, datetime, timedelta
10
+ from typing import Any, TypeVar
11
11
 
12
12
  import tqdm
13
- import wandb
14
- from lightning.pytorch.cli import LightningCLI
15
13
  from rasterio.crs import CRS
16
14
  from upath import UPath
17
15
 
18
16
  from rslearn.config import LayerConfig
19
17
  from rslearn.const import WGS84_EPSG
20
- from rslearn.data_sources import Item, data_source_from_config
21
- from rslearn.dataset import Dataset, Window
18
+ from rslearn.data_sources import Item
19
+ from rslearn.dataset import Dataset, Window, WindowLayerData
22
20
  from rslearn.dataset.add_windows import add_windows_from_box, add_windows_from_file
23
- from rslearn.dataset.manage import materialize_dataset_windows, prepare_dataset_windows
24
- from rslearn.tile_stores import get_tile_store_for_layer
25
- from rslearn.train.data_module import RslearnDataModule
26
- from rslearn.train.lightning_module import RslearnLightningModule
21
+ from rslearn.dataset.handler_summaries import (
22
+ ErrorOutcome,
23
+ IngestCounts,
24
+ IngestDatasetJobsSummary,
25
+ LayerIngestSummary,
26
+ MaterializeDatasetWindowsSummary,
27
+ PrepareDatasetWindowsSummary,
28
+ UnknownIngestCounts,
29
+ )
30
+ from rslearn.dataset.manage import (
31
+ AttemptsCounter,
32
+ materialize_dataset_windows,
33
+ prepare_dataset_windows,
34
+ retry,
35
+ )
36
+ from rslearn.dataset.storage.file import FileWindowStorage
37
+ from rslearn.log_utils import get_logger
38
+ from rslearn.tile_stores import get_tile_store_with_layer
27
39
  from rslearn.utils import Projection, STGeometry
28
40
 
29
- logging.basicConfig()
41
+ logger = get_logger(__name__)
42
+
30
43
  handler_registry = {}
31
44
 
45
+ ItemType = TypeVar("ItemType", bound="Item")
46
+
47
+ MULTIPROCESSING_CONTEXT = "forkserver"
32
48
 
33
- def register_handler(category, command):
49
+
50
+ def register_handler(category: Any, command: str) -> Callable:
34
51
  """Register a new handler for a command."""
35
52
 
36
- def decorator(f):
53
+ def decorator(f: Callable) -> Callable:
37
54
  handler_registry[(category, command)] = f
38
55
  return f
39
56
 
@@ -47,7 +64,7 @@ def parse_time(time_str: str) -> datetime:
47
64
  """
48
65
  ts = datetime.fromisoformat(time_str)
49
66
  if not ts.tzinfo:
50
- ts = ts.replace(tzinfo=timezone.utc)
67
+ ts = ts.replace(tzinfo=UTC)
51
68
  return ts
52
69
 
53
70
 
@@ -60,8 +77,13 @@ def parse_time_range(
60
77
  return (parse_time(start), parse_time(end))
61
78
 
62
79
 
80
+ def parse_disabled_layers(disabled_layers: str) -> list[str]:
81
+ """Parse the disabled layers string."""
82
+ return disabled_layers.split(",") if disabled_layers else []
83
+
84
+
63
85
  @register_handler("dataset", "add_windows")
64
- def add_windows():
86
+ def add_windows() -> None:
65
87
  """Handler for the rslearn dataset add_windows command."""
66
88
  parser = argparse.ArgumentParser(
67
89
  prog="rslearn dataset add_windows",
@@ -156,7 +178,13 @@ def add_windows():
156
178
  )
157
179
  args = parser.parse_args(args=sys.argv[3:])
158
180
 
159
- def parse_projection(crs_str, resolution, x_res, y_res, default_crs=None):
181
+ def parse_projection(
182
+ crs_str: str | None,
183
+ resolution: float | None,
184
+ x_res: float,
185
+ y_res: float,
186
+ default_crs: CRS | None = None,
187
+ ) -> Projection | None:
160
188
  if not crs_str:
161
189
  if default_crs:
162
190
  crs = default_crs
@@ -197,7 +225,8 @@ def add_windows():
197
225
  box = [float(value) for value in args.box.split(",")]
198
226
 
199
227
  windows = add_windows_from_box(
200
- box=box,
228
+ # TODO: we should have an object for box
229
+ box=box, # type: ignore
201
230
  src_projection=parse_projection(
202
231
  args.src_crs, args.src_resolution, args.src_x_res, args.src_y_res
203
232
  ),
@@ -210,10 +239,10 @@ def add_windows():
210
239
  else:
211
240
  raise Exception("one of box or fname must be specified")
212
241
 
213
- print(f"created {len(windows)} windows")
242
+ logger.info(f"created {len(windows)} windows")
214
243
 
215
244
 
216
- def add_apply_on_windows_args(parser: argparse.ArgumentParser):
245
+ def add_apply_on_windows_args(parser: argparse.ArgumentParser) -> None:
217
246
  """Add arguments for handlers that use the apply_on_windows helper.
218
247
 
219
248
  Args:
@@ -223,10 +252,14 @@ def add_apply_on_windows_args(parser: argparse.ArgumentParser):
223
252
  "--root", type=str, required=True, help="Dataset root directory"
224
253
  )
225
254
  parser.add_argument(
226
- "--group", type=str, default=None, help="Only prepare windows in this group"
255
+ "--group",
256
+ type=str,
257
+ nargs="*",
258
+ default=None,
259
+ help="Only prepare windows in these groups",
227
260
  )
228
261
  parser.add_argument(
229
- "--window", type=str, default=None, help="Only prepare this window"
262
+ "--window", type=str, nargs="*", default=None, help="Only prepare these windows"
230
263
  )
231
264
  parser.add_argument(
232
265
  "--workers",
@@ -234,6 +267,12 @@ def add_apply_on_windows_args(parser: argparse.ArgumentParser):
234
267
  default=0,
235
268
  help="Number of worker processes (default 0 to use main process only)",
236
269
  )
270
+ parser.add_argument(
271
+ "--load-workers",
272
+ type=int,
273
+ default=None,
274
+ help="Number of workers for loading windows (defaults to --workers)",
275
+ )
237
276
  parser.add_argument(
238
277
  "--batch-size",
239
278
  type=int,
@@ -255,25 +294,31 @@ def add_apply_on_windows_args(parser: argparse.ArgumentParser):
255
294
 
256
295
 
257
296
  def apply_on_windows(
258
- f: Callable[[list[Window]], None],
297
+ f: Callable[[list[Window]], Any],
259
298
  dataset: Dataset,
260
- group: str | None = None,
261
- window: str | None = None,
299
+ group: str | list[str] | None = None,
300
+ names: list[str] | None = None,
262
301
  workers: int = 0,
302
+ load_workers: int | None = None,
263
303
  batch_size: int = 1,
264
304
  jobs_per_process: int | None = None,
265
305
  use_initial_job: bool = True,
266
- ):
306
+ ) -> None:
267
307
  """A helper to apply a function on windows in a dataset.
268
308
 
269
309
  Args:
270
310
  f: the function to apply on lists of windows.
271
311
  dataset: the dataset.
272
312
  group: optional, only apply on windows in this group.
273
- window: optional, only apply on windows with this name.
313
+ names: optional, only apply on windows with these names.
274
314
  workers: the number of parallel workers to use, default 0 (main thread only).
315
+ load_workers: optional different number of workers to use for loading the
316
+ windows. If set, workers controls the number of workers to process the
317
+ jobs, while load_workers controls the number of workers to use for reading
318
+ windows from the rslearn dataset. Workers is only passed if the window
319
+ storage is FileWindowStorage.
275
320
  batch_size: if workers > 0, the maximum number of windows to pass to the
276
- function. If workers == 0, all windows are always passed.
321
+ function.
277
322
  jobs_per_process: optional, terminate processes after they have handled this
278
323
  many jobs. This is useful if there is a memory leak in a dependency.
279
324
  use_initial_job: if workers > 0, by default, an initial job is run on the first
@@ -284,30 +329,33 @@ def apply_on_windows(
284
329
  if hasattr(f, "set_dataset"):
285
330
  f.set_dataset(dataset)
286
331
 
287
- groups = None
288
- names = None
289
- if group:
332
+ # Handle group. It can be None (load all groups) or list of groups. But it can also
333
+ # just be group name, in which case we must convert to list.
334
+ groups: list[str] | None
335
+ if isinstance(group, str):
290
336
  groups = [group]
291
- if window:
292
- names = [window]
293
- windows = dataset.load_windows(
294
- groups=groups, names=names, workers=workers, show_progress=True
295
- )
296
- print(f"found {len(windows)} windows")
337
+ else:
338
+ groups = group
339
+
340
+ # Load the windows. We pass workers and show_progress if it is FileWindowStorage.
341
+ kwargs: dict[str, Any] = {}
342
+ if isinstance(dataset.storage, FileWindowStorage):
343
+ if load_workers is None:
344
+ load_workers = workers
345
+ kwargs["workers"] = load_workers
346
+ kwargs["show_progress"] = True
347
+ windows = dataset.load_windows(groups=groups, names=names, **kwargs)
348
+ logger.info(f"found {len(windows)} windows")
297
349
 
298
350
  if hasattr(f, "get_jobs"):
299
- jobs = f.get_jobs(windows, workers)
300
- print(f"got {len(jobs)} jobs")
351
+ jobs = f.get_jobs(windows, load_workers)
352
+ logger.info(f"got {len(jobs)} jobs")
301
353
  else:
302
354
  jobs = windows
303
355
 
304
- if workers == 0:
305
- f(jobs)
306
- return
307
-
308
356
  random.shuffle(jobs)
309
357
 
310
- if use_initial_job:
358
+ if use_initial_job and len(jobs) > 0:
311
359
  # Apply directly on first window to get any initialization out of the way.
312
360
  f([jobs[0]])
313
361
  jobs = jobs[1:]
@@ -316,41 +364,59 @@ def apply_on_windows(
316
364
  for i in range(0, len(jobs), batch_size):
317
365
  batches.append(jobs[i : i + batch_size])
318
366
 
319
- p = multiprocessing.Pool(processes=workers, maxtasksperchild=jobs_per_process)
320
- outputs = p.imap_unordered(f, batches)
321
- for _ in tqdm.tqdm(outputs, total=len(batches)):
322
- pass
323
- p.close()
367
+ num_batches = len(batches)
368
+ if workers == 0:
369
+ # Process batches sequentially but with same error handling as parallel
370
+ for batch in tqdm.tqdm(batches, total=num_batches):
371
+ f(batch)
372
+ else:
373
+ # Process batches in parallel
374
+ p = multiprocessing.Pool(processes=workers, maxtasksperchild=jobs_per_process)
375
+ outputs = p.imap_unordered(f, batches)
376
+ for _ in tqdm.tqdm(outputs, total=num_batches):
377
+ pass
378
+ p.close()
324
379
 
325
380
 
326
- def apply_on_windows_args(f: Callable[[list[Window]], None], args: argparse.Namespace):
381
+ def apply_on_windows_args(f: Callable[..., Any], args: argparse.Namespace) -> None:
327
382
  """Call apply_on_windows with arguments passed via command-line interface."""
328
- dataset = Dataset(UPath(args.root))
383
+ dataset = Dataset(UPath(args.root), disabled_layers=args.disabled_layers)
329
384
  apply_on_windows(
330
- f,
331
- dataset,
332
- args.group,
333
- args.window,
334
- args.workers,
335
- args.batch_size,
336
- args.jobs_per_process,
337
- args.use_initial_job,
385
+ f=f,
386
+ dataset=dataset,
387
+ group=args.group,
388
+ names=args.window,
389
+ workers=args.workers,
390
+ load_workers=args.load_workers,
391
+ batch_size=args.batch_size,
392
+ jobs_per_process=args.jobs_per_process,
393
+ use_initial_job=args.use_initial_job,
338
394
  )
339
395
 
340
396
 
341
397
  class PrepareHandler:
342
398
  """apply_on_windows handler for the rslearn dataset prepare command."""
343
399
 
344
- def __init__(self, force: bool):
400
+ def __init__(
401
+ self,
402
+ force: bool,
403
+ retry_max_attempts: int = 0,
404
+ retry_backoff: timedelta = timedelta(minutes=1),
405
+ ) -> None:
345
406
  """Initialize a new PrepareHandler.
346
407
 
347
408
  Args:
348
409
  force: force prepare
410
+ retry_max_attempts: set greater than zero to retry for this many attempts in
411
+ case of error.
412
+ retry_backoff: how long to wait before retrying (see retry).
349
413
  """
350
414
  self.force = force
351
- self.dataset = None
415
+ self.dataset: Dataset | None = None
416
+ self.retry_max_attempts = retry_max_attempts
417
+ self.retry_backoff = retry_backoff
352
418
 
353
- def set_dataset(self, dataset: Dataset):
419
+ def set_dataset(self, dataset: Dataset) -> None:
354
420
  """Captures the dataset from apply_on_windows_args.
355
421
 
356
422
  Args:
@@ -358,13 +424,22 @@ class PrepareHandler:
358
424
  """
359
425
  self.dataset = dataset
360
426
 
361
- def __call__(self, windows: list[Window]):
427
+ def __call__(self, windows: list[Window]) -> PrepareDatasetWindowsSummary:
362
428
  """Prepares the windows from apply_on_windows."""
363
- prepare_dataset_windows(self.dataset, windows, self.force)
429
+ logger.info(f"Running prepare on {len(windows)} windows")
430
+ if self.dataset is None:
431
+ raise ValueError("dataset not set")
432
+ return prepare_dataset_windows(
433
+ self.dataset,
434
+ windows,
435
+ self.force,
436
+ retry_max_attempts=self.retry_max_attempts,
437
+ retry_backoff=self.retry_backoff,
438
+ )
364
439
 
365
440
 
366
441
  @register_handler("dataset", "prepare")
367
- def dataset_prepare():
442
+ def dataset_prepare() -> None:
368
443
  """Handler for the rslearn dataset prepare command."""
369
444
  parser = argparse.ArgumentParser(
370
445
  prog="rslearn dataset prepare",
@@ -377,14 +452,38 @@ def dataset_prepare():
377
452
  action=argparse.BooleanOptionalAction,
378
453
  help="Prepare windows even if they were previously prepared",
379
454
  )
455
+ parser.add_argument(
456
+ "--disabled-layers",
457
+ type=parse_disabled_layers,
458
+ default="",
459
+ help="List of layers to disable e.g 'layer1,layer2'",
460
+ )
461
+ parser.add_argument(
462
+ "--retry-max-attempts",
463
+ type=int,
464
+ default=0,
465
+ help="Retry for this many attempts",
466
+ )
467
+ parser.add_argument(
468
+ "--retry-backoff-seconds",
469
+ type=int,
470
+ default=0,
471
+ help="Backoff time (seconds) between retries",
472
+ )
380
473
  add_apply_on_windows_args(parser)
381
474
  args = parser.parse_args(args=sys.argv[3:])
382
475
 
383
- fn = PrepareHandler(args.force)
476
+ fn = PrepareHandler(
477
+ args.force,
478
+ retry_max_attempts=args.retry_max_attempts,
479
+ retry_backoff=timedelta(seconds=args.retry_backoff_seconds),
480
+ )
384
481
  apply_on_windows_args(fn, args)
385
482
 
386
483
 
387
- def _load_window_layer_datas(window: Window):
484
+ def _load_window_layer_datas(
485
+ window: Window,
486
+ ) -> tuple[Window, dict[str, WindowLayerData]]:
388
487
  # Helper for IngestHandler to use with multiprocessing.
389
488
  return window, window.load_layer_datas()
390
489
 
@@ -392,11 +491,19 @@ def _load_window_layer_datas(window: Window):
392
491
  class IngestHandler:
393
492
  """apply_on_windows handler for the rslearn dataset ingest command."""
394
493
 
395
- def __init__(self):
494
+ def __init__(
495
+ self,
496
+ ignore_errors: bool = False,
497
+ retry_max_attempts: int = 0,
498
+ retry_backoff: timedelta = timedelta(minutes=1),
499
+ ) -> None:
396
500
  """Initialize a new IngestHandler."""
397
- self.dataset = None
501
+ self.dataset: Dataset | None = None
502
+ self.ignore_errors = ignore_errors
503
+ self.retry_max_attempts = retry_max_attempts
504
+ self.retry_backoff = retry_backoff
398
505
 
399
- def set_dataset(self, dataset: Dataset):
506
+ def set_dataset(self, dataset: Dataset) -> None:
400
507
  """Captures the dataset from apply_on_windows_args.
401
508
 
402
509
  Args:
@@ -404,21 +511,32 @@ class IngestHandler:
404
511
  """
405
512
  self.dataset = dataset
406
513
 
407
- def __call__(self, jobs: list[tuple[str, LayerConfig, Item, list[STGeometry]]]):
514
+ def __call__(
515
+ self, jobs: list[tuple[str, LayerConfig, Item, list[STGeometry]]]
516
+ ) -> IngestDatasetJobsSummary:
408
517
  """Ingest the specified items.
409
518
 
410
519
  The items are computed from list of windows via IngestHandler.get_jobs.
411
520
 
412
521
  Args:
413
- jobs: list of (layer_name, item, geometries) tuples to ingest.
522
+ jobs: list of (layer_name, layer_cfg, item, geometries) tuples to ingest.
523
+
524
+ Returns:
525
+ summary of the ingest jobs operation fit for telemetry purposes.
414
526
  """
527
+ start_time = time.monotonic()
528
+ layer_summaries: list[LayerIngestSummary] = []
529
+
530
+ logger.info(f"Running ingest for {len(jobs)} jobs")
415
531
  import gc
416
532
 
533
+ if self.dataset is None:
534
+ raise ValueError("dataset not set")
417
535
  tile_store = self.dataset.get_tile_store()
418
536
 
419
537
  # Group jobs by layer name.
420
- jobs_by_layer = {}
421
- configs_by_layer = {}
538
+ jobs_by_layer: dict = {}
539
+ configs_by_layer: dict = {}
422
540
  for layer_name, layer_cfg, item, geometries in jobs:
423
541
  if layer_name not in jobs_by_layer:
424
542
  jobs_by_layer[layer_name] = []
@@ -426,24 +544,81 @@ class IngestHandler:
426
544
  configs_by_layer[layer_name] = layer_cfg
427
545
 
428
546
  for layer_name, items_and_geometries in jobs_by_layer.items():
429
- cur_tile_store = get_tile_store_for_layer(tile_store, layer_name, layer_cfg)
547
+ layer_tile_store = get_tile_store_with_layer(
548
+ tile_store, layer_name, layer_cfg
549
+ )
430
550
  layer_cfg = self.dataset.layers[layer_name]
431
- data_source = data_source_from_config(layer_cfg, self.dataset.path)
551
+ data_source = layer_cfg.instantiate_data_source(self.dataset.path)
432
552
 
553
+ attempts_counter = AttemptsCounter()
554
+ ingest_counts: IngestCounts | UnknownIngestCounts
433
555
  try:
434
- data_source.ingest(
435
- tile_store=cur_tile_store,
436
- items=[item for item, _ in items_and_geometries],
437
- geometries=[geometries for _, geometries in items_and_geometries],
556
+ retry(
557
+ lambda: data_source.ingest(
558
+ tile_store=layer_tile_store,
559
+ items=[item for item, _ in items_and_geometries],
560
+ geometries=[
561
+ geometries for _, geometries in items_and_geometries
562
+ ],
563
+ ),
564
+ retry_max_attempts=self.retry_max_attempts,
565
+ retry_backoff=self.retry_backoff,
566
+ attempts_counter=attempts_counter,
567
+ )
568
+ ingest_counts = IngestCounts(
569
+ items_ingested=len(items_and_geometries),
570
+ geometries_ingested=sum(
571
+ len(geometries) for _, geometries in items_and_geometries
572
+ ),
438
573
  )
439
574
  except Exception as e:
440
- print(
575
+ if not self.ignore_errors:
576
+ raise
577
+
578
+ ingest_counts = UnknownIngestCounts(
579
+ items_attempted=len(items_and_geometries),
580
+ geometries_attempted=sum(
581
+ len(geometries) for _, geometries in items_and_geometries
582
+ ),
583
+ )
584
+ logger.error(
441
585
  "warning: got error while ingesting "
442
586
  + f"{len(items_and_geometries)} items: {e}"
443
587
  )
444
588
 
589
+ layer_summaries.append(
590
+ LayerIngestSummary(
591
+ layer_name=layer_name,
592
+ data_source_name=getattr(layer_cfg.data_source, "name", "N/A"),
593
+ duration_seconds=time.monotonic() - start_time,
594
+ ingest_counts=ingest_counts,
595
+ ingest_attempts=attempts_counter.value,
596
+ )
597
+ )
598
+
445
599
  gc.collect()
446
600
 
601
+ return IngestDatasetJobsSummary(
602
+ duration_seconds=time.monotonic() - start_time,
603
+ num_jobs=len(jobs),
604
+ layer_summaries=layer_summaries,
605
+ )
606
+
607
+ def _load_layer_data_for_windows(
608
+ self, windows: list[Window], workers: int
609
+ ) -> list[tuple[Window, dict[str, WindowLayerData]]]:
610
+ if workers == 0:
611
+ return [(_load_window_layer_datas(window)) for window in windows]
612
+ p = multiprocessing.Pool(workers)
613
+ outputs = p.imap_unordered(_load_window_layer_datas, windows)
614
+ windows_and_layer_datas = []
615
+ for window, layer_datas in tqdm.tqdm(
616
+ outputs, total=len(windows), desc="Loading window layer datas"
617
+ ):
618
+ windows_and_layer_datas.append((window, layer_datas))
619
+ p.close()
620
+ return windows_and_layer_datas
621
+
447
622
  def get_jobs(
448
623
  self, windows: list[Window], workers: int
449
624
  ) -> list[tuple[str, LayerConfig, Item, list[STGeometry]]]:
@@ -455,17 +630,12 @@ class IngestHandler:
455
630
  This makes sure that jobs are grouped by item rather than by window, which
456
631
  makes sense because there's no reason to ingest the same item twice.
457
632
  """
633
+ if self.dataset is None:
634
+ raise ValueError("dataset not set")
458
635
  # TODO: avoid duplicating ingest_dataset_windows...
459
636
 
460
637
  # Load layer datas of each window.
461
- p = multiprocessing.Pool(workers)
462
- outputs = p.imap_unordered(_load_window_layer_datas, windows)
463
- windows_and_layer_datas = []
464
- for window, layer_datas in tqdm.tqdm(
465
- outputs, total=len(windows), desc="Loading window layer datas"
466
- ):
467
- windows_and_layer_datas.append((window, layer_datas))
468
- p.close()
638
+ windows_and_layer_datas = self._load_layer_data_for_windows(windows, workers)
469
639
 
470
640
  jobs: list[tuple[str, LayerConfig, Item, list[STGeometry]]] = []
471
641
  for layer_name, layer_cfg in self.dataset.layers.items():
@@ -474,9 +644,9 @@ class IngestHandler:
474
644
  if not layer_cfg.data_source.ingest:
475
645
  continue
476
646
 
477
- data_source = data_source_from_config(layer_cfg, self.dataset.path)
647
+ data_source = layer_cfg.instantiate_data_source(self.dataset.path)
478
648
 
479
- geometries_by_item = {}
649
+ geometries_by_item: dict = {}
480
650
  for window, layer_datas in windows_and_layer_datas:
481
651
  if layer_name not in layer_datas:
482
652
  continue
@@ -484,7 +654,9 @@ class IngestHandler:
484
654
  layer_data = layer_datas[layer_name]
485
655
  for group in layer_data.serialized_item_groups:
486
656
  for serialized_item in group:
487
- item = data_source.deserialize_item(serialized_item)
657
+ item = data_source.deserialize_item( # type: ignore
658
+ serialized_item
659
+ )
488
660
  if item not in geometries_by_item:
489
661
  geometries_by_item[item] = []
490
662
  geometries_by_item[item].append(geometry)
@@ -492,32 +664,69 @@ class IngestHandler:
492
664
  for item, geometries in geometries_by_item.items():
493
665
  jobs.append((layer_name, layer_cfg, item, geometries))
494
666
 
495
- print(f"computed {len(jobs)} ingest jobs from {len(windows)} windows")
667
+ logger.info(f"computed {len(jobs)} ingest jobs from {len(windows)} windows")
496
668
  return jobs
497
669
 
498
670
 
499
671
  @register_handler("dataset", "ingest")
500
- def dataset_ingest():
672
+ def dataset_ingest() -> None:
501
673
  """Handler for the rslearn dataset ingest command."""
502
674
  parser = argparse.ArgumentParser(
503
675
  prog="rslearn dataset ingest",
504
676
  description="rslearn dataset ingest: ingest items in retrieved data sources",
505
677
  )
678
+ parser.add_argument(
679
+ "--disabled-layers",
680
+ type=parse_disabled_layers,
681
+ default="",
682
+ help="List of layers to disable e.g 'layer1,layer2'",
683
+ )
684
+ parser.add_argument(
685
+ "--ignore-errors",
686
+ type=bool,
687
+ default=False,
688
+ help="Ignore ingestion errors in individual jobs",
689
+ action=argparse.BooleanOptionalAction,
690
+ )
691
+ parser.add_argument(
692
+ "--retry-max-attempts",
693
+ type=int,
694
+ default=0,
695
+ help="Retry for this many attempts",
696
+ )
697
+ parser.add_argument(
698
+ "--retry-backoff-seconds",
699
+ type=int,
700
+ default=0,
701
+ help="Backoff time (seconds) between retries",
702
+ )
506
703
  add_apply_on_windows_args(parser)
507
704
  args = parser.parse_args(args=sys.argv[3:])
508
705
 
509
- fn = IngestHandler()
706
+ fn = IngestHandler(
707
+ ignore_errors=args.ignore_errors,
708
+ retry_max_attempts=args.retry_max_attempts,
709
+ retry_backoff=timedelta(seconds=args.retry_backoff_seconds),
710
+ )
510
711
  apply_on_windows_args(fn, args)
511
712
 
512
713
 
513
714
  class MaterializeHandler:
514
715
  """apply_on_windows handler for the rslearn dataset materialize command."""
515
716
 
516
- def __init__(self):
717
+ def __init__(
718
+ self,
719
+ ignore_errors: bool = False,
720
+ retry_max_attempts: int = 0,
721
+ retry_backoff: timedelta = timedelta(minutes=1),
722
+ ) -> None:
517
723
  """Initialize a MaterializeHandler."""
518
- self.dataset = None
724
+ self.dataset: Dataset | None = None
725
+ self.ignore_errors = ignore_errors
726
+ self.retry_max_attempts = retry_max_attempts
727
+ self.retry_backoff = retry_backoff
519
728
 
520
- def set_dataset(self, dataset: Dataset):
729
+ def set_dataset(self, dataset: Dataset) -> None:
521
730
  """Captures the dataset from apply_on_windows_args.
522
731
 
523
732
  Args:
@@ -525,13 +734,31 @@ class MaterializeHandler:
525
734
  """
526
735
  self.dataset = dataset
527
736
 
528
- def __call__(self, windows: list[Window]):
737
+ def __call__(
738
+ self, windows: list[Window]
739
+ ) -> MaterializeDatasetWindowsSummary | ErrorOutcome:
529
740
  """Materializes the windows from apply_on_windows."""
530
- materialize_dataset_windows(self.dataset, windows)
741
+ logger.info(f"Running Materialize with {len(windows)} windows")
742
+ start_time = time.monotonic()
743
+ if self.dataset is None:
744
+ raise ValueError("dataset not set")
745
+ try:
746
+ return materialize_dataset_windows(
747
+ self.dataset,
748
+ windows,
749
+ retry_max_attempts=self.retry_max_attempts,
750
+ retry_backoff=self.retry_backoff,
751
+ )
752
+ except Exception as e:
753
+ if not self.ignore_errors:
754
+ logger.error(f"Error materializing windows: {e}")
755
+ raise
756
+ logger.warning(f"Ignoring error while materializing windows: {e}")
757
+ return ErrorOutcome(duration_seconds=time.monotonic() - start_time)
531
758
 
532
759
 
533
760
  @register_handler("dataset", "materialize")
534
- def dataset_materialize():
761
+ def dataset_materialize() -> None:
535
762
  """Handler for the rslearn dataset materialize command."""
536
763
  parser = argparse.ArgumentParser(
537
764
  prog="rslearn dataset materialize",
@@ -540,110 +767,87 @@ def dataset_materialize():
540
767
  + "materialize data from retrieved data sources"
541
768
  ),
542
769
  )
770
+ parser.add_argument(
771
+ "--disabled-layers",
772
+ type=parse_disabled_layers,
773
+ default="",
774
+ help="List of layers to disable e.g 'layer1,layer2'",
775
+ )
776
+ parser.add_argument(
777
+ "--ignore-errors",
778
+ type=bool,
779
+ default=False,
780
+ help="Ignore errors in individual jobs",
781
+ action=argparse.BooleanOptionalAction,
782
+ )
783
+ parser.add_argument(
784
+ "--retry-max-attempts",
785
+ type=int,
786
+ default=0,
787
+ help="Retry for this many attempts",
788
+ )
789
+ parser.add_argument(
790
+ "--retry-backoff-seconds",
791
+ type=int,
792
+ default=0,
793
+ help="Backoff time (seconds) between retries",
794
+ )
543
795
  add_apply_on_windows_args(parser)
544
796
  args = parser.parse_args(args=sys.argv[3:])
545
-
546
- fn = MaterializeHandler()
547
- apply_on_windows_args(fn, args)
548
-
549
-
550
- class RslearnLightningCLI(LightningCLI):
551
- """LightningCLI that links data.tasks to model.tasks."""
552
-
553
- def add_arguments_to_parser(self, parser) -> None:
554
- """Link data.tasks to model.tasks.
555
-
556
- Args:
557
- parser: the argument parser
558
- """
559
- parser.link_arguments(
560
- "data.init_args.task", "model.init_args.task", apply_on="instantiate"
561
- )
562
- parser.add_argument(
563
- "--wandb_run_id",
564
- default="",
565
- type=str,
566
- help="W&B run ID to load checkpoint from",
567
- )
568
- parser.add_argument(
569
- "--wandb_resume",
570
- default=False,
571
- type=bool,
572
- help="Whether to resume from specified wandb_run_id",
573
- )
574
-
575
- def before_instantiate_classes(self):
576
- """Called before Lightning class initialization.
577
-
578
- Sets up wandb_run_id / wandb_resume arguments.
579
- """
580
- subcommand = self.config.subcommand
581
- c = self.config[subcommand]
582
-
583
- if c.wandb_run_id:
584
- api = wandb.Api()
585
- artifact_id = (
586
- f"{c.trainer.logger.init_args.project}/model-{c.wandb_run_id}:latest"
587
- )
588
- print(f"restoring from artifact {artifact_id} on wandb")
589
- artifact = api.artifact(artifact_id, type="model")
590
- artifact_dir = artifact.download()
591
- c.ckpt_path = str(Path(artifact_dir) / "model.ckpt")
592
-
593
- if c.wandb_resume:
594
- c.trainer.logger.init_args.id = c.wandb_run_id
595
-
596
- # If there is a RslearnPredictionWriter, set its path.
597
- prediction_writer_callback = None
598
- if "callbacks" in c.trainer:
599
- for existing_callback in c.trainer.callbacks:
600
- if (
601
- existing_callback.class_path
602
- == "rslearn.train.prediction_writer.RslearnWriter"
603
- ):
604
- prediction_writer_callback = existing_callback
605
- if prediction_writer_callback:
606
- prediction_writer_callback.init_args.path = c.data.init_args.path
607
-
608
-
609
- def model_handler():
610
- """Handler for any rslearn model X commands."""
611
- RslearnLightningCLI(
612
- model_class=RslearnLightningModule,
613
- datamodule_class=RslearnDataModule,
614
- args=sys.argv[2:],
615
- subclass_mode_model=True,
616
- subclass_mode_data=True,
617
- save_config_kwargs={"overwrite": True},
797
+ fn = MaterializeHandler(
798
+ ignore_errors=args.ignore_errors,
799
+ retry_max_attempts=args.retry_max_attempts,
800
+ retry_backoff=timedelta(seconds=args.retry_backoff_seconds),
618
801
  )
802
+ apply_on_windows_args(fn, args)
619
803
 
620
804
 
621
805
  @register_handler("model", "fit")
622
- def model_fit():
806
+ def model_fit() -> None:
623
807
  """Handler for rslearn model fit."""
808
+ from .lightning_cli import model_handler
809
+
624
810
  model_handler()
625
811
 
626
812
 
627
813
  @register_handler("model", "validate")
628
- def model_validate():
814
+ def model_validate() -> None:
629
815
  """Handler for rslearn model validate."""
816
+ from .lightning_cli import model_handler
817
+
630
818
  model_handler()
631
819
 
632
820
 
633
821
  @register_handler("model", "test")
634
- def model_test():
822
+ def model_test() -> None:
635
823
  """Handler for rslearn model test."""
824
+ from .lightning_cli import model_handler
825
+
636
826
  model_handler()
637
827
 
638
828
 
639
829
  @register_handler("model", "predict")
640
- def model_predict():
830
+ def model_predict() -> None:
641
831
  """Handler for rslearn model predict."""
832
+ from .lightning_cli import model_handler
833
+
642
834
  model_handler()
643
835
 
644
836
 
645
- def main():
837
+ def main() -> None:
646
838
  """CLI entrypoint."""
839
+ try:
840
+ multiprocessing.set_start_method(MULTIPROCESSING_CONTEXT)
841
+ except RuntimeError as e:
842
+ logger.error(
843
+ f"Multiprocessing context already set to {multiprocessing.get_context()}: "
844
+ + f"ignoring {e}"
845
+ )
846
+ except Exception as e:
847
+ logger.error(f"Failed to set multiprocessing context: {e}")
848
+ raise
849
+ finally:
850
+ logger.info(f"Using multiprocessing context: {multiprocessing.get_context()}")
647
851
  parser = argparse.ArgumentParser(description="rslearn")
648
852
  parser.add_argument(
649
853
  "category", help="Command category: dataset, annotate, or model"
@@ -653,12 +857,11 @@ def main():
653
857
 
654
858
  handler = handler_registry.get((args.category, args.command))
655
859
  if handler is None:
656
- print(f"Unknown command: {args.category} {args.command}", file=sys.stderr)
860
+ logger.error(f"Unknown command: {args.category} {args.command}")
657
861
  sys.exit(1)
658
862
 
659
863
  handler()
660
864
 
661
865
 
662
866
  if __name__ == "__main__":
663
- multiprocessing.set_start_method("forkserver")
664
867
  main()