rslearn 0.0.17__py3-none-any.whl → 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. rslearn/arg_parser.py +2 -9
  2. rslearn/config/__init__.py +2 -0
  3. rslearn/config/dataset.py +64 -20
  4. rslearn/dataset/add_windows.py +1 -1
  5. rslearn/dataset/dataset.py +34 -84
  6. rslearn/dataset/materialize.py +5 -5
  7. rslearn/dataset/storage/__init__.py +1 -0
  8. rslearn/dataset/storage/file.py +202 -0
  9. rslearn/dataset/storage/storage.py +140 -0
  10. rslearn/dataset/window.py +26 -80
  11. rslearn/lightning_cli.py +22 -11
  12. rslearn/main.py +12 -37
  13. rslearn/models/anysat.py +11 -9
  14. rslearn/models/attention_pooling.py +177 -0
  15. rslearn/models/clay/clay.py +8 -9
  16. rslearn/models/clip.py +18 -15
  17. rslearn/models/component.py +111 -0
  18. rslearn/models/concatenate_features.py +21 -11
  19. rslearn/models/conv.py +15 -8
  20. rslearn/models/croma.py +13 -8
  21. rslearn/models/detr/detr.py +25 -14
  22. rslearn/models/dinov3.py +11 -6
  23. rslearn/models/faster_rcnn.py +19 -9
  24. rslearn/models/feature_center_crop.py +12 -9
  25. rslearn/models/fpn.py +19 -8
  26. rslearn/models/galileo/galileo.py +23 -18
  27. rslearn/models/module_wrapper.py +26 -57
  28. rslearn/models/molmo.py +16 -14
  29. rslearn/models/multitask.py +102 -73
  30. rslearn/models/olmoearth_pretrain/model.py +135 -38
  31. rslearn/models/panopticon.py +8 -7
  32. rslearn/models/pick_features.py +18 -24
  33. rslearn/models/pooling_decoder.py +22 -14
  34. rslearn/models/presto/presto.py +16 -10
  35. rslearn/models/presto/single_file_presto.py +4 -10
  36. rslearn/models/prithvi.py +12 -8
  37. rslearn/models/resize_features.py +21 -7
  38. rslearn/models/sam2_enc.py +11 -9
  39. rslearn/models/satlaspretrain.py +15 -9
  40. rslearn/models/simple_time_series.py +37 -17
  41. rslearn/models/singletask.py +24 -17
  42. rslearn/models/ssl4eo_s12.py +15 -10
  43. rslearn/models/swin.py +22 -13
  44. rslearn/models/terramind.py +24 -7
  45. rslearn/models/trunk.py +6 -3
  46. rslearn/models/unet.py +18 -9
  47. rslearn/models/upsample.py +22 -9
  48. rslearn/train/all_patches_dataset.py +89 -37
  49. rslearn/train/dataset.py +105 -97
  50. rslearn/train/lightning_module.py +51 -32
  51. rslearn/train/model_context.py +54 -0
  52. rslearn/train/prediction_writer.py +111 -41
  53. rslearn/train/scheduler.py +15 -0
  54. rslearn/train/tasks/classification.py +34 -15
  55. rslearn/train/tasks/detection.py +24 -31
  56. rslearn/train/tasks/embedding.py +33 -29
  57. rslearn/train/tasks/multi_task.py +7 -7
  58. rslearn/train/tasks/per_pixel_regression.py +41 -19
  59. rslearn/train/tasks/regression.py +38 -21
  60. rslearn/train/tasks/segmentation.py +33 -15
  61. rslearn/train/tasks/task.py +3 -2
  62. rslearn/train/transforms/resize.py +74 -0
  63. rslearn/utils/geometry.py +73 -0
  64. rslearn/utils/jsonargparse.py +66 -0
  65. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/METADATA +1 -1
  66. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/RECORD +71 -66
  67. rslearn/dataset/index.py +0 -173
  68. rslearn/models/registry.py +0 -22
  69. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/WHEEL +0 -0
  70. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/entry_points.txt +0 -0
  71. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/licenses/LICENSE +0 -0
  72. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/licenses/NOTICE +0 -0
  73. {rslearn-0.0.17.dist-info → rslearn-0.0.19.dist-info}/top_level.txt +0 -0
rslearn/arg_parser.py CHANGED
@@ -1,6 +1,5 @@
1
1
  """Custom Lightning ArgumentParser with environment variable substitution support."""
2
2
 
3
- import os
4
3
  from typing import Any
5
4
 
6
5
  from jsonargparse import Namespace
@@ -21,11 +20,7 @@ class RslearnArgumentParser(LightningArgumentParser):
21
20
  def parse_string(
22
21
  self,
23
22
  cfg_str: str,
24
- cfg_path: str | os.PathLike = "",
25
- ext_vars: dict | None = None,
26
- env: bool | None = None,
27
- defaults: bool = True,
28
- with_meta: bool | None = None,
23
+ *args: Any,
29
24
  **kwargs: Any,
30
25
  ) -> Namespace:
31
26
  """Pre-processes string for environment variable substitution before parsing."""
@@ -33,6 +28,4 @@ class RslearnArgumentParser(LightningArgumentParser):
33
28
  substituted_cfg_str = substitute_env_vars_in_string(cfg_str)
34
29
 
35
30
  # Call the parent method with the substituted config
36
- return super().parse_string(
37
- substituted_cfg_str, cfg_path, ext_vars, env, defaults, with_meta, **kwargs
38
- )
31
+ return super().parse_string(substituted_cfg_str, *args, **kwargs)
@@ -10,6 +10,7 @@ from .dataset import (
10
10
  LayerType,
11
11
  QueryConfig,
12
12
  SpaceMode,
13
+ StorageConfig,
13
14
  TimeMode,
14
15
  )
15
16
 
@@ -23,5 +24,6 @@ __all__ = [
23
24
  "LayerType",
24
25
  "QueryConfig",
25
26
  "SpaceMode",
27
+ "StorageConfig",
26
28
  "TimeMode",
27
29
  ]
rslearn/config/dataset.py CHANGED
@@ -25,12 +25,13 @@ from rasterio.enums import Resampling
25
25
  from upath import UPath
26
26
 
27
27
  from rslearn.log_utils import get_logger
28
- from rslearn.utils import PixelBounds, Projection
28
+ from rslearn.utils.geometry import PixelBounds, Projection, ResolutionFactor
29
29
  from rslearn.utils.raster_format import RasterFormat
30
30
  from rslearn.utils.vector_format import VectorFormat
31
31
 
32
32
  if TYPE_CHECKING:
33
33
  from rslearn.data_sources.data_source import DataSource
34
+ from rslearn.dataset.storage.storage import WindowStorageFactory
34
35
 
35
36
  logger = get_logger("__name__")
36
37
 
@@ -132,7 +133,11 @@ class BandSetConfig(BaseModel):
132
133
  bands.
133
134
  """
134
135
 
135
- dtype: DType = Field(description="Pixel value type to store the data under")
136
+ model_config = ConfigDict(extra="forbid")
137
+
138
+ dtype: DType = Field(
139
+ description="Pixel value type to store the data under. This is used during dataset materialize and model predict."
140
+ )
136
141
  bands: list[str] = Field(
137
142
  default_factory=lambda: [],
138
143
  description="List of band names in this BandSetConfig. One of bands or num_bands must be set.",
@@ -210,22 +215,12 @@ class BandSetConfig(BaseModel):
210
215
  Returns:
211
216
  tuple of updated projection and bounds with zoom offset applied
212
217
  """
213
- if self.zoom_offset == 0:
214
- return projection, bounds
215
- projection = Projection(
216
- projection.crs,
217
- projection.x_resolution / (2**self.zoom_offset),
218
- projection.y_resolution / (2**self.zoom_offset),
219
- )
220
- if self.zoom_offset > 0:
221
- zoom_factor = 2**self.zoom_offset
222
- bounds = tuple(x * zoom_factor for x in bounds) # type: ignore
218
+ if self.zoom_offset >= 0:
219
+ factor = ResolutionFactor(numerator=2**self.zoom_offset)
223
220
  else:
224
- bounds = tuple(
225
- x // (2 ** (-self.zoom_offset))
226
- for x in bounds # type: ignore
227
- )
228
- return projection, bounds
221
+ factor = ResolutionFactor(denominator=2 ** (-self.zoom_offset))
222
+
223
+ return (factor.multiply_projection(projection), factor.multiply_bounds(bounds))
229
224
 
230
225
  @field_validator("format", mode="before")
231
226
  @classmethod
@@ -329,7 +324,7 @@ class TimeMode(StrEnum):
329
324
  class QueryConfig(BaseModel):
330
325
  """A configuration for querying items in a data source."""
331
326
 
332
- model_config = ConfigDict(frozen=True)
327
+ model_config = ConfigDict(frozen=True, extra="forbid")
333
328
 
334
329
  space_mode: SpaceMode = Field(
335
330
  default=SpaceMode.MOSAIC,
@@ -363,7 +358,7 @@ class QueryConfig(BaseModel):
363
358
  class DataSourceConfig(BaseModel):
364
359
  """Configuration for a DataSource in a dataset layer."""
365
360
 
366
- model_config = ConfigDict(frozen=True)
361
+ model_config = ConfigDict(frozen=True, extra="forbid")
367
362
 
368
363
  class_path: str = Field(description="Class path for the data source.")
369
364
  init_args: dict[str, Any] = Field(
@@ -469,7 +464,7 @@ class CompositingMethod(StrEnum):
469
464
  class LayerConfig(BaseModel):
470
465
  """Configuration of a layer in a dataset."""
471
466
 
472
- model_config = ConfigDict(frozen=True)
467
+ model_config = ConfigDict(frozen=True, extra="forbid")
473
468
 
474
469
  type: LayerType = Field(description="The LayerType (raster or vector).")
475
470
  data_source: DataSourceConfig | None = Field(
@@ -592,11 +587,60 @@ class LayerConfig(BaseModel):
592
587
  return vector_format
593
588
 
594
589
 
590
class StorageConfig(BaseModel):
    """Configuration for the WindowStorageFactory (window metadata storage backend)."""

    model_config = ConfigDict(frozen=True, extra="forbid")

    class_path: str = Field(
        default="rslearn.dataset.storage.file.FileWindowStorageFactory",
        description="Class path for the WindowStorageFactory.",
    )
    init_args: dict[str, Any] = Field(
        default_factory=lambda: {},
        description="jsonargparse init args for the WindowStorageFactory.",
    )

    def instantiate_window_storage_factory(self) -> "WindowStorageFactory":
        """Instantiate the WindowStorageFactory specified by this config."""
        # Imported lazily to avoid circular imports at module load time.
        from rslearn.dataset.storage.storage import WindowStorageFactory
        from rslearn.utils.jsonargparse import init_jsonargparse

        init_jsonargparse()
        arg_parser = jsonargparse.ArgumentParser()
        arg_parser.add_argument("--wsf", type=WindowStorageFactory)
        spec = {
            "wsf": {
                "class_path": self.class_path,
                "init_args": self.init_args,
            }
        }
        parsed = arg_parser.parse_object(spec)
        return arg_parser.instantiate_classes(parsed).wsf
622
+
623
+
595
624
  class DatasetConfig(BaseModel):
596
625
  """Overall dataset configuration."""
597
626
 
627
+ model_config = ConfigDict(extra="forbid")
628
+
598
629
  layers: dict[str, LayerConfig] = Field(description="Layers in the dataset.")
599
630
  tile_store: dict[str, Any] = Field(
600
631
  default={"class_path": "rslearn.tile_stores.default.DefaultTileStore"},
601
632
  description="jsonargparse configuration for the TileStore.",
602
633
  )
634
+ storage: StorageConfig = Field(
635
+ default_factory=lambda: StorageConfig(),
636
+ description="jsonargparse configuration for the WindowStorageFactory.",
637
+ )
638
+
639
+ @field_validator("layers", mode="after")
640
+ @classmethod
641
+ def layer_names_validator(cls, v: dict[str, LayerConfig]) -> dict[str, LayerConfig]:
642
+ """Ensure layer names don't contain periods, since we use periods to distinguish different materialized groups within a layer."""
643
+ for layer_name in v.keys():
644
+ if "." in layer_name:
645
+ raise ValueError(f"layer names must not contain periods: {layer_name}")
646
+ return v
@@ -131,7 +131,7 @@ def add_windows_from_geometries(
131
131
  f"_{time_range[0].isoformat()}_{time_range[1].isoformat()}"
132
132
  )
133
133
  window = Window(
134
- path=dataset.path / "windows" / group / cur_window_name,
134
+ storage=dataset.storage,
135
135
  group=group,
136
136
  name=cur_window_name,
137
137
  projection=cur_projection,
@@ -1,9 +1,8 @@
1
1
  """rslearn dataset class."""
2
2
 
3
3
  import json
4
- import multiprocessing
4
+ from typing import Any
5
5
 
6
- import tqdm
7
6
  from upath import UPath
8
7
 
9
8
  from rslearn.config import DatasetConfig
@@ -11,7 +10,6 @@ from rslearn.log_utils import get_logger
11
10
  from rslearn.template_params import substitute_env_vars_in_string
12
11
  from rslearn.tile_stores import TileStore, load_tile_store
13
12
 
14
- from .index import DatasetIndex
15
13
  from .window import Window
16
14
 
17
15
  logger = get_logger(__name__)
@@ -25,7 +23,7 @@ class Dataset:
25
23
  .. code-block:: none
26
24
 
27
25
  dataset/
28
- config.json
26
+ config.json # optional, if config provided as runtime object
29
27
  windows/
30
28
  group1/
31
29
  epsg:3857_10_623565_1528020/
@@ -42,106 +40,58 @@ class Dataset:
42
40
  materialize.
43
41
  """
44
42
 
45
- def __init__(self, path: UPath, disabled_layers: list[str] = []) -> None:
43
+ def __init__(
44
+ self,
45
+ path: UPath,
46
+ disabled_layers: list[str] = [],
47
+ dataset_config: DatasetConfig | None = None,
48
+ ) -> None:
46
49
  """Initializes a new Dataset.
47
50
 
48
51
  Args:
49
52
  path: the root directory of the dataset
50
53
  disabled_layers: list of layers to disable
54
+ dataset_config: optional dataset configuration to use instead of loading from the dataset directory
51
55
  """
52
56
  self.path = path
53
57
 
54
- # Load dataset configuration.
55
- with (self.path / "config.json").open("r") as f:
56
- config_content = f.read()
57
- config_content = substitute_env_vars_in_string(config_content)
58
- config = DatasetConfig.model_validate(json.loads(config_content))
59
-
60
- self.layers = {}
61
- for layer_name, layer_config in config.layers.items():
62
- # Layer names must not contain period, since we use period to
63
- # distinguish different materialized groups within a layer.
64
- assert "." not in layer_name, "layer names must not contain periods"
65
- if layer_name in disabled_layers:
66
- logger.warning(f"Layer {layer_name} is disabled")
67
- continue
68
- self.layers[layer_name] = layer_config
69
-
70
- self.tile_store_config = config.tile_store
71
-
72
- def _get_index(self) -> DatasetIndex | None:
73
- index_fname = self.path / DatasetIndex.FNAME
74
- if not index_fname.exists():
75
- return None
76
- return DatasetIndex.load_index(self.path)
58
+ if dataset_config is None:
59
+ # Load dataset configuration from the dataset directory.
60
+ with (self.path / "config.json").open("r") as f:
61
+ config_content = f.read()
62
+ config_content = substitute_env_vars_in_string(config_content)
63
+ dataset_config = DatasetConfig.model_validate(
64
+ json.loads(config_content)
65
+ )
66
+
67
+ self.layers = {}
68
+ for layer_name, layer_config in dataset_config.layers.items():
69
+ if layer_name in disabled_layers:
70
+ logger.warning(f"Layer {layer_name} is disabled")
71
+ continue
72
+ self.layers[layer_name] = layer_config
73
+
74
+ self.tile_store_config = dataset_config.tile_store
75
+ self.storage = (
76
+ dataset_config.storage.instantiate_window_storage_factory().get_storage(
77
+ self.path
78
+ )
79
+ )
77
80
 
78
81
  def load_windows(
79
82
  self,
80
83
  groups: list[str] | None = None,
81
84
  names: list[str] | None = None,
82
- show_progress: bool = False,
83
- workers: int = 0,
84
- no_index: bool = False,
85
+ **kwargs: Any,
85
86
  ) -> list[Window]:
86
87
  """Load the windows in the dataset.
87
88
 
88
89
  Args:
89
90
  groups: an optional list of groups to filter loading
90
91
  names: an optional list of window names to filter loading
91
- show_progress: whether to show tqdm progress bar
92
- workers: number of parallel workers, default 0 (use main thread only to load windows)
93
- no_index: don't use the dataset index even if it exists.
92
+ kwargs: optional keyword arguments to pass to WindowStorage.get_windows.
94
93
  """
95
- # Load from index if it exists.
96
- # We never use the index if names is set since loading the index will likely be
97
- # slower than loading a few windows.
98
- if not no_index and names is None:
99
- dataset_index = self._get_index()
100
- if dataset_index is not None:
101
- return dataset_index.get_windows(groups=groups, names=names)
102
-
103
- # Avoid directory does not exist errors later.
104
- if not (self.path / "windows").exists():
105
- return []
106
-
107
- window_dirs = []
108
- if not groups:
109
- groups = []
110
- for p in (self.path / "windows").iterdir():
111
- groups.append(p.name)
112
- for group in groups:
113
- group_dir = self.path / "windows" / group
114
- if not group_dir.exists():
115
- logger.warning(
116
- f"Skipping group directory {group_dir} since it does not exist"
117
- )
118
- continue
119
- if names:
120
- cur_names = names
121
- else:
122
- cur_names = []
123
- for p in group_dir.iterdir():
124
- cur_names.append(p.name)
125
-
126
- for window_name in cur_names:
127
- window_dir = group_dir / window_name
128
- window_dirs.append(window_dir)
129
-
130
- if workers == 0:
131
- windows = [Window.load(window_dir) for window_dir in window_dirs]
132
- else:
133
- p = multiprocessing.Pool(workers)
134
- outputs = p.imap_unordered(Window.load, window_dirs)
135
- if show_progress:
136
- outputs = tqdm.tqdm(
137
- outputs, total=len(window_dirs), desc="Loading windows"
138
- )
139
- windows = []
140
- for window in outputs:
141
- windows.append(window)
142
- p.close()
143
-
144
- return windows
94
+ return self.storage.get_windows(groups=groups, names=names, **kwargs)
145
95
 
146
96
  def get_tile_store(self) -> TileStore:
147
97
  """Get the tile store associated with this dataset.
@@ -161,7 +161,7 @@ def build_first_valid_composite(
161
161
  nodata_vals: list[Any],
162
162
  bands: list[str],
163
163
  bounds: PixelBounds,
164
- band_dtype: Any,
164
+ band_dtype: npt.DTypeLike,
165
165
  tile_store: TileStoreWithLayer,
166
166
  projection: Projection,
167
167
  remapper: Remapper | None,
@@ -233,7 +233,7 @@ def read_and_stack_raster_windows(
233
233
  projection: Projection,
234
234
  nodata_vals: list[Any],
235
235
  remapper: Remapper | None,
236
- band_dtype: Any,
236
+ band_dtype: npt.DTypeLike,
237
237
  resampling_method: Resampling = Resampling.bilinear,
238
238
  ) -> npt.NDArray[np.generic]:
239
239
  """Create a stack of extent aligned raster windows.
@@ -326,7 +326,7 @@ def build_mean_composite(
326
326
  nodata_vals: list[Any],
327
327
  bands: list[str],
328
328
  bounds: PixelBounds,
329
- band_dtype: Any,
329
+ band_dtype: npt.DTypeLike,
330
330
  tile_store: TileStoreWithLayer,
331
331
  projection: Projection,
332
332
  remapper: Remapper | None,
@@ -383,7 +383,7 @@ def build_median_composite(
383
383
  nodata_vals: list[Any],
384
384
  bands: list[str],
385
385
  bounds: PixelBounds,
386
- band_dtype: Any,
386
+ band_dtype: npt.DTypeLike,
387
387
  tile_store: TileStoreWithLayer,
388
388
  projection: Projection,
389
389
  remapper: Remapper | None,
@@ -471,7 +471,7 @@ def build_composite(
471
471
  nodata_vals=nodata_vals,
472
472
  bands=band_cfg.bands,
473
473
  bounds=bounds,
474
- band_dtype=band_cfg.dtype.value,
474
+ band_dtype=band_cfg.dtype.get_numpy_dtype(),
475
475
  tile_store=tile_store,
476
476
  projection=projection,
477
477
  resampling_method=layer_cfg.resampling_method.get_rasterio_resampling(),
@@ -0,0 +1 @@
1
+ """Storage backends for rslearn window metadata."""
@@ -0,0 +1,202 @@
1
+ """The default file-based window storage backend."""
2
+
3
+ import json
4
+ import multiprocessing
5
+
6
+ import tqdm
7
+ from typing_extensions import override
8
+ from upath import UPath
9
+
10
+ from rslearn.dataset.window import (
11
+ LAYERS_DIRECTORY_NAME,
12
+ Window,
13
+ WindowLayerData,
14
+ get_layer_and_group_from_dir_name,
15
+ get_window_layer_dir,
16
+ )
17
+ from rslearn.log_utils import get_logger
18
+ from rslearn.utils.fsspec import open_atomic
19
+ from rslearn.utils.mp import star_imap_unordered
20
+
21
+ from .storage import WindowStorage, WindowStorageFactory
22
+
23
+ logger = get_logger(__name__)
24
+
25
+
26
def load_window(storage: "FileWindowStorage", window_dir: UPath) -> Window:
    """Load the window from its directory by reading metadata.json.

    Args:
        storage: the underlying FileWindowStorage.
        window_dir: the path where the window is stored.

    Returns:
        the window object.
    """
    with (window_dir / "metadata.json").open() as fh:
        window_metadata = json.load(fh)
    return Window.from_metadata(storage, window_metadata)
40
+
41
+
42
class FileWindowStorage(WindowStorage):
    """The default file-backed window storage.

    Window metadata lives under <dataset path>/windows/<group>/<name>/:
    metadata.json for the window itself, items.json for per-layer item data, and a
    "completed" sentinel file inside each materialized layer directory.
    """

    def __init__(self, path: UPath):
        """Create a new FileWindowStorage.

        Args:
            path: the path to the dataset.
        """
        self.path = path

    @override
    def get_window_root(self, group: str, name: str) -> UPath:
        """Get the directory where the given window's files are stored."""
        return Window.get_window_root(self.path, group, name)

    @override
    def get_windows(
        self,
        groups: list[str] | None = None,
        names: list[str] | None = None,
        show_progress: bool = False,
        workers: int = 0,
    ) -> list["Window"]:
        """Load the windows in the dataset.

        Args:
            groups: an optional list of groups to filter loading
            names: an optional list of window names to filter loading
            show_progress: whether to show tqdm progress bar
            workers: number of parallel workers, default 0 (use main thread only to
                load windows)
        """
        # Avoid directory does not exist errors later.
        if not (self.path / "windows").exists():
            return []

        window_dirs = []
        if not groups:
            groups = [p.name for p in (self.path / "windows").iterdir()]
        for group in groups:
            group_dir = self.path / "windows" / group
            if not group_dir.exists():
                logger.warning(
                    f"Skipping group directory {group_dir} since it does not exist"
                )
                continue
            # NOTE(review): when names is set we do not verify each name exists in
            # this group; a missing window will fail later in load_window.
            if names:
                cur_names = names
            else:
                cur_names = [p.name for p in group_dir.iterdir()]
            window_dirs.extend(group_dir / window_name for window_name in cur_names)

        if workers == 0:
            return [load_window(self, window_dir) for window_dir in window_dirs]

        # Use the pool as a context manager so worker processes are cleaned up even
        # if loading a window raises (previously the pool leaked on error).
        with multiprocessing.Pool(workers) as p:
            outputs = star_imap_unordered(
                p,
                load_window,
                [
                    dict(storage=self, window_dir=window_dir)
                    for window_dir in window_dirs
                ],
            )
            if show_progress:
                outputs = tqdm.tqdm(
                    outputs, total=len(window_dirs), desc="Loading windows"
                )
            windows = list(outputs)
        return windows

    @override
    def create_or_update_window(self, window: Window) -> None:
        """Write the window's metadata.json, creating its directory if needed."""
        window_path = self.get_window_root(window.group, window.name)
        window_path.mkdir(parents=True, exist_ok=True)
        metadata_path = window_path / "metadata.json"
        logger.debug(f"Saving window metadata to {metadata_path}")
        # open_atomic prevents readers from observing a partially-written file.
        with open_atomic(metadata_path, "w") as f:
            json.dump(window.get_metadata(), f)

    @override
    def get_layer_datas(self, group: str, name: str) -> dict[str, "WindowLayerData"]:
        """Load per-layer item data (items.json) for the given window.

        Returns:
            a map from layer name to WindowLayerData; empty if items.json is absent.
        """
        window_path = self.get_window_root(group, name)
        items_fname = window_path / "items.json"
        if not items_fname.exists():
            return {}

        with items_fname.open() as f:
            layer_datas = [
                WindowLayerData.deserialize(layer_data) for layer_data in json.load(f)
            ]

        return {layer_data.layer_name: layer_data for layer_data in layer_datas}

    @override
    def save_layer_datas(
        self, group: str, name: str, layer_datas: dict[str, "WindowLayerData"]
    ) -> None:
        """Save per-layer item data to the window's items.json."""
        window_path = self.get_window_root(group, name)
        json_data = [layer_data.serialize() for layer_data in layer_datas.values()]
        items_fname = window_path / "items.json"
        logger.info(f"Saving window items to {items_fname}")
        with open_atomic(items_fname, "w") as f:
            json.dump(json_data, f)

    @override
    def list_completed_layers(self, group: str, name: str) -> list[tuple[str, int]]:
        """List (layer_name, group_idx) pairs whose layers are marked completed."""
        window_path = self.get_window_root(group, name)
        layers_directory = window_path / LAYERS_DIRECTORY_NAME
        if not layers_directory.exists():
            return []

        completed_layers = []
        for layer_dir in layers_directory.iterdir():
            layer_name, group_idx = get_layer_and_group_from_dir_name(layer_dir.name)
            if not self.is_layer_completed(group, name, layer_name, group_idx):
                continue
            completed_layers.append((layer_name, group_idx))

        return completed_layers

    @override
    def is_layer_completed(
        self, group: str, name: str, layer_name: str, group_idx: int = 0
    ) -> bool:
        """Whether the given layer has its "completed" sentinel file."""
        window_path = self.get_window_root(group, name)
        layer_dir = get_window_layer_dir(
            window_path,
            layer_name,
            group_idx,
        )
        return (layer_dir / "completed").exists()

    @override
    def mark_layer_completed(
        self, group: str, name: str, layer_name: str, group_idx: int = 0
    ) -> None:
        """Mark the given layer completed by touching its sentinel file."""
        window_path = self.get_window_root(group, name)
        layer_dir = get_window_layer_dir(window_path, layer_name, group_idx)
        # We assume the directory exists because the layer should be materialized
        # before being marked completed.
        (layer_dir / "completed").touch()
194
+
195
+
196
class FileWindowStorageFactory(WindowStorageFactory):
    """Factory that produces FileWindowStorage instances."""

    @override
    def get_storage(self, ds_path: UPath) -> FileWindowStorage:
        """Return a FileWindowStorage rooted at the given dataset path."""
        storage = FileWindowStorage(ds_path)
        return storage