hpcflow-new2 0.2.0a188__py3-none-any.whl → 0.2.0a190__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115) hide show
  1. hpcflow/__pyinstaller/hook-hpcflow.py +8 -6
  2. hpcflow/_version.py +1 -1
  3. hpcflow/app.py +1 -0
  4. hpcflow/data/scripts/main_script_test_hdf5_in_obj.py +1 -1
  5. hpcflow/data/scripts/main_script_test_hdf5_out_obj.py +1 -1
  6. hpcflow/sdk/__init__.py +21 -15
  7. hpcflow/sdk/app.py +2133 -770
  8. hpcflow/sdk/cli.py +281 -250
  9. hpcflow/sdk/cli_common.py +6 -2
  10. hpcflow/sdk/config/__init__.py +1 -1
  11. hpcflow/sdk/config/callbacks.py +77 -42
  12. hpcflow/sdk/config/cli.py +126 -103
  13. hpcflow/sdk/config/config.py +578 -311
  14. hpcflow/sdk/config/config_file.py +131 -95
  15. hpcflow/sdk/config/errors.py +112 -85
  16. hpcflow/sdk/config/types.py +145 -0
  17. hpcflow/sdk/core/actions.py +1054 -994
  18. hpcflow/sdk/core/app_aware.py +24 -0
  19. hpcflow/sdk/core/cache.py +81 -63
  20. hpcflow/sdk/core/command_files.py +275 -185
  21. hpcflow/sdk/core/commands.py +111 -107
  22. hpcflow/sdk/core/element.py +724 -503
  23. hpcflow/sdk/core/enums.py +192 -0
  24. hpcflow/sdk/core/environment.py +74 -93
  25. hpcflow/sdk/core/errors.py +398 -51
  26. hpcflow/sdk/core/json_like.py +540 -272
  27. hpcflow/sdk/core/loop.py +380 -334
  28. hpcflow/sdk/core/loop_cache.py +160 -43
  29. hpcflow/sdk/core/object_list.py +370 -207
  30. hpcflow/sdk/core/parameters.py +728 -600
  31. hpcflow/sdk/core/rule.py +59 -41
  32. hpcflow/sdk/core/run_dir_files.py +33 -22
  33. hpcflow/sdk/core/task.py +1546 -1325
  34. hpcflow/sdk/core/task_schema.py +240 -196
  35. hpcflow/sdk/core/test_utils.py +126 -88
  36. hpcflow/sdk/core/types.py +387 -0
  37. hpcflow/sdk/core/utils.py +410 -305
  38. hpcflow/sdk/core/validation.py +82 -9
  39. hpcflow/sdk/core/workflow.py +1192 -1028
  40. hpcflow/sdk/core/zarr_io.py +98 -137
  41. hpcflow/sdk/demo/cli.py +46 -33
  42. hpcflow/sdk/helper/cli.py +18 -16
  43. hpcflow/sdk/helper/helper.py +75 -63
  44. hpcflow/sdk/helper/watcher.py +61 -28
  45. hpcflow/sdk/log.py +83 -59
  46. hpcflow/sdk/persistence/__init__.py +8 -31
  47. hpcflow/sdk/persistence/base.py +988 -586
  48. hpcflow/sdk/persistence/defaults.py +6 -0
  49. hpcflow/sdk/persistence/discovery.py +38 -0
  50. hpcflow/sdk/persistence/json.py +408 -153
  51. hpcflow/sdk/persistence/pending.py +158 -123
  52. hpcflow/sdk/persistence/store_resource.py +37 -22
  53. hpcflow/sdk/persistence/types.py +307 -0
  54. hpcflow/sdk/persistence/utils.py +14 -11
  55. hpcflow/sdk/persistence/zarr.py +477 -420
  56. hpcflow/sdk/runtime.py +44 -41
  57. hpcflow/sdk/submission/{jobscript_info.py → enums.py} +39 -12
  58. hpcflow/sdk/submission/jobscript.py +444 -404
  59. hpcflow/sdk/submission/schedulers/__init__.py +133 -40
  60. hpcflow/sdk/submission/schedulers/direct.py +97 -71
  61. hpcflow/sdk/submission/schedulers/sge.py +132 -126
  62. hpcflow/sdk/submission/schedulers/slurm.py +263 -268
  63. hpcflow/sdk/submission/schedulers/utils.py +7 -2
  64. hpcflow/sdk/submission/shells/__init__.py +14 -15
  65. hpcflow/sdk/submission/shells/base.py +102 -29
  66. hpcflow/sdk/submission/shells/bash.py +72 -55
  67. hpcflow/sdk/submission/shells/os_version.py +31 -30
  68. hpcflow/sdk/submission/shells/powershell.py +37 -29
  69. hpcflow/sdk/submission/submission.py +203 -257
  70. hpcflow/sdk/submission/types.py +143 -0
  71. hpcflow/sdk/typing.py +163 -12
  72. hpcflow/tests/conftest.py +8 -6
  73. hpcflow/tests/schedulers/slurm/test_slurm_submission.py +5 -2
  74. hpcflow/tests/scripts/test_main_scripts.py +60 -30
  75. hpcflow/tests/shells/wsl/test_wsl_submission.py +6 -4
  76. hpcflow/tests/unit/test_action.py +86 -75
  77. hpcflow/tests/unit/test_action_rule.py +9 -4
  78. hpcflow/tests/unit/test_app.py +13 -6
  79. hpcflow/tests/unit/test_cli.py +1 -1
  80. hpcflow/tests/unit/test_command.py +71 -54
  81. hpcflow/tests/unit/test_config.py +20 -15
  82. hpcflow/tests/unit/test_config_file.py +21 -18
  83. hpcflow/tests/unit/test_element.py +58 -62
  84. hpcflow/tests/unit/test_element_iteration.py +3 -1
  85. hpcflow/tests/unit/test_element_set.py +29 -19
  86. hpcflow/tests/unit/test_group.py +4 -2
  87. hpcflow/tests/unit/test_input_source.py +116 -93
  88. hpcflow/tests/unit/test_input_value.py +29 -24
  89. hpcflow/tests/unit/test_json_like.py +44 -35
  90. hpcflow/tests/unit/test_loop.py +65 -58
  91. hpcflow/tests/unit/test_object_list.py +17 -12
  92. hpcflow/tests/unit/test_parameter.py +16 -7
  93. hpcflow/tests/unit/test_persistence.py +48 -35
  94. hpcflow/tests/unit/test_resources.py +20 -18
  95. hpcflow/tests/unit/test_run.py +8 -3
  96. hpcflow/tests/unit/test_runtime.py +2 -1
  97. hpcflow/tests/unit/test_schema_input.py +23 -15
  98. hpcflow/tests/unit/test_shell.py +3 -2
  99. hpcflow/tests/unit/test_slurm.py +8 -7
  100. hpcflow/tests/unit/test_submission.py +39 -19
  101. hpcflow/tests/unit/test_task.py +352 -247
  102. hpcflow/tests/unit/test_task_schema.py +33 -20
  103. hpcflow/tests/unit/test_utils.py +9 -11
  104. hpcflow/tests/unit/test_value_sequence.py +15 -12
  105. hpcflow/tests/unit/test_workflow.py +114 -83
  106. hpcflow/tests/unit/test_workflow_template.py +0 -1
  107. hpcflow/tests/workflows/test_jobscript.py +2 -1
  108. hpcflow/tests/workflows/test_workflows.py +18 -13
  109. {hpcflow_new2-0.2.0a188.dist-info → hpcflow_new2-0.2.0a190.dist-info}/METADATA +2 -1
  110. hpcflow_new2-0.2.0a190.dist-info/RECORD +165 -0
  111. hpcflow/sdk/core/parallel.py +0 -21
  112. hpcflow_new2-0.2.0a188.dist-info/RECORD +0 -158
  113. {hpcflow_new2-0.2.0a188.dist-info → hpcflow_new2-0.2.0a190.dist-info}/LICENSE +0 -0
  114. {hpcflow_new2-0.2.0a188.dist-info → hpcflow_new2-0.2.0a190.dist-info}/WHEEL +0 -0
  115. {hpcflow_new2-0.2.0a188.dist-info → hpcflow_new2-0.2.0a190.dist-info}/entry_points.txt +0 -0
@@ -7,19 +7,23 @@ from __future__ import annotations
7
7
  import copy
8
8
  from contextlib import contextmanager
9
9
  from dataclasses import dataclass
10
- from datetime import datetime
11
10
  from pathlib import Path
11
+ from typing import Any, cast, TYPE_CHECKING
12
+ from typing_extensions import override
12
13
  import shutil
13
14
  import time
14
- from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union
15
15
 
16
16
  import numpy as np
17
- import zarr
18
- from fsspec.implementations.zip import ZipFileSystem
17
+ from numpy.ma.core import MaskedArray
18
+ import zarr # type: ignore
19
+ from zarr.errors import BoundsCheckError # type: ignore
20
+ from zarr.storage import DirectoryStore, FSStore # type: ignore
21
+ from fsspec.implementations.zip import ZipFileSystem # type: ignore
19
22
  from rich.console import Console
20
- from numcodecs import MsgPack, VLenArray, blosc, Blosc, Zstd
21
- from reretry import retry
23
+ from numcodecs import MsgPack, VLenArray, blosc, Blosc, Zstd # type: ignore
24
+ from reretry import retry # type: ignore
22
25
 
26
+ from hpcflow.sdk.typing import hydrate
23
27
  from hpcflow.sdk.core.errors import (
24
28
  MissingParameterData,
25
29
  MissingStoreEARError,
@@ -38,18 +42,47 @@ from hpcflow.sdk.persistence.base import (
38
42
  StoreParameter,
39
43
  StoreTask,
40
44
  )
45
+ from hpcflow.sdk.persistence.types import (
46
+ LoopDescriptor,
47
+ StoreCreationInfo,
48
+ TemplateMeta,
49
+ ZarrAttrsDict,
50
+ )
41
51
  from hpcflow.sdk.persistence.store_resource import ZarrAttrsStoreResource
42
52
  from hpcflow.sdk.persistence.utils import ask_pw_on_auth_exc
43
53
  from hpcflow.sdk.persistence.pending import CommitResourceMap
44
54
  from hpcflow.sdk.persistence.base import update_param_source_dict
45
55
  from hpcflow.sdk.log import TimeIt
46
56
 
57
+ if TYPE_CHECKING:
58
+ from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
59
+ from datetime import datetime
60
+ from fsspec import AbstractFileSystem # type: ignore
61
+ from logging import Logger
62
+ from typing import ClassVar
63
+ from typing_extensions import Self, TypeAlias
64
+ from numpy.typing import NDArray
65
+ from zarr import Array, Group # type: ignore
66
+ from zarr.attrs import Attributes # type: ignore
67
+ from zarr.storage import Store # type: ignore
68
+ from .types import TypeLookup
69
+ from ..app import BaseApp
70
+ from ..core.json_like import JSONed, JSONDocument
71
+ from ..typing import ParamSource, PathLike
72
+
73
+
74
+ #: List of any (Zarr-serializable) value.
75
+ ListAny: TypeAlias = "list[Any]"
76
+ #: Zarr attribute mapping context.
77
+ ZarrAttrs: TypeAlias = "dict[str, list[str]]"
78
+ _JS: TypeAlias = "dict[str, list[dict[str, dict]]]"
79
+
47
80
 
48
81
  blosc.use_threads = False # hpcflow is a multiprocess program in general
49
82
 
50
83
 
51
84
  @TimeIt.decorator
52
- def _zarr_get_coord_selection(arr, selection, logger):
85
+ def _zarr_get_coord_selection(arr: Array, selection: Any, logger: Logger):
53
86
  @retry(
54
87
  RuntimeError,
55
88
  tries=10,
@@ -59,53 +92,84 @@ def _zarr_get_coord_selection(arr, selection, logger):
59
92
  logger=logger,
60
93
  )
61
94
  @TimeIt.decorator
62
- def _inner(arr, selection):
95
+ def _inner(arr: Array, selection: Any):
63
96
  return arr.get_coordinate_selection(selection)
64
97
 
65
98
  return _inner(arr, selection)
66
99
 
67
100
 
68
- def _encode_numpy_array(obj, type_lookup, path, root_group, arr_path):
101
+ def _encode_numpy_array(
102
+ obj: NDArray,
103
+ type_lookup: TypeLookup,
104
+ path: list[int],
105
+ root_group: Group,
106
+ arr_path: list[int],
107
+ ) -> int:
69
108
  # Might need to generate new group:
70
109
  param_arr_group = root_group.require_group(arr_path)
71
- names = [int(i.split("arr_")[1]) for i in param_arr_group.keys()]
72
- if not names:
73
- new_idx = 0
74
- else:
75
- new_idx = max(names) + 1
110
+ new_idx = (
111
+ max((int(i.removeprefix("arr_")) for i in param_arr_group.keys()), default=-1) + 1
112
+ )
76
113
  param_arr_group.create_dataset(name=f"arr_{new_idx}", data=obj)
77
114
  type_lookup["arrays"].append([path, new_idx])
78
115
 
79
116
  return len(type_lookup["arrays"]) - 1
80
117
 
81
118
 
82
- def _decode_numpy_arrays(obj, type_lookup, path, arr_group, dataset_copy):
83
- for arr_path, arr_idx in type_lookup["arrays"]:
119
+ def _decode_numpy_arrays(
120
+ obj: dict | None,
121
+ type_lookup: TypeLookup,
122
+ path: list[int],
123
+ arr_group: Group,
124
+ dataset_copy: bool,
125
+ ):
126
+ # Yuck! Type lies! Zarr's internal types are not modern Python types.
127
+ arrays = cast("Iterable[tuple[list[int], int]]", type_lookup.get("arrays", []))
128
+ obj_: dict | NDArray | None = obj
129
+ for arr_path, arr_idx in arrays:
84
130
  try:
85
131
  rel_path = get_relative_path(arr_path, path)
86
132
  except ValueError:
87
133
  continue
88
134
 
89
- dataset = arr_group.get(f"arr_{arr_idx}")
135
+ dataset: NDArray = arr_group.get(f"arr_{arr_idx}")
90
136
  if dataset_copy:
91
137
  dataset = dataset[:]
92
138
 
93
139
  if rel_path:
94
- set_in_container(obj, rel_path, dataset)
140
+ set_in_container(obj_, rel_path, dataset)
95
141
  else:
96
- obj = dataset
142
+ obj_ = dataset
97
143
 
98
- return obj
144
+ return obj_
99
145
 
100
146
 
101
- def _encode_masked_array(obj, type_lookup, path, root_group, arr_path):
147
+ def _encode_masked_array(
148
+ obj: MaskedArray,
149
+ type_lookup: TypeLookup,
150
+ path: list[int],
151
+ root_group: Group,
152
+ arr_path: list[int],
153
+ ):
102
154
  data_idx = _encode_numpy_array(obj.data, type_lookup, path, root_group, arr_path)
103
155
  mask_idx = _encode_numpy_array(obj.mask, type_lookup, path, root_group, arr_path)
104
156
  type_lookup["masked_arrays"].append([path, [data_idx, mask_idx]])
105
157
 
106
158
 
107
- def _decode_masked_arrays(obj, type_lookup, path, arr_group, dataset_copy):
108
- for arr_path, (data_idx, mask_idx) in type_lookup["masked_arrays"]:
159
+ def _decode_masked_arrays(
160
+ obj: dict,
161
+ type_lookup: TypeLookup,
162
+ path: list[int],
163
+ arr_group: Group,
164
+ dataset_copy: bool,
165
+ ):
166
+ # Yuck! Type lies! Zarr's internal types are not modern Python types.
167
+ masked_arrays = cast(
168
+ "Iterable[tuple[list[int], tuple[int, int]]]",
169
+ type_lookup.get("masked_arrays", []),
170
+ )
171
+ obj_: dict | MaskedArray = obj
172
+ for arr_path, (data_idx, mask_idx) in masked_arrays:
109
173
  try:
110
174
  rel_path = get_relative_path(arr_path, path)
111
175
  except ValueError:
@@ -113,17 +177,17 @@ def _decode_masked_arrays(obj, type_lookup, path, arr_group, dataset_copy):
113
177
 
114
178
  data = arr_group.get(f"arr_{data_idx}")
115
179
  mask = arr_group.get(f"arr_{mask_idx}")
116
- dataset = np.ma.core.MaskedArray(data=data, mask=mask)
180
+ dataset: MaskedArray = MaskedArray(data=data, mask=mask)
117
181
 
118
182
  if rel_path:
119
- set_in_container(obj, rel_path, dataset)
183
+ set_in_container(obj_, rel_path, dataset)
120
184
  else:
121
- obj = dataset
185
+ obj_ = dataset
122
186
 
123
- return obj
187
+ return obj_
124
188
 
125
189
 
126
- def append_items_to_ragged_array(arr, items):
190
+ def append_items_to_ragged_array(arr: Array, items: Sequence[int]):
127
191
  """Append an array to a Zarr ragged array.
128
192
 
129
193
  I think `arr.append([item])` should work, but does not for some reason, so we do it
@@ -135,36 +199,39 @@ def append_items_to_ragged_array(arr, items):
135
199
 
136
200
 
137
201
  @dataclass
138
- class ZarrStoreTask(StoreTask):
202
+ class ZarrStoreTask(StoreTask[dict]):
139
203
  """
140
204
  Represents a task in a Zarr persistent store.
141
205
  """
142
206
 
143
- def encode(self) -> Tuple[int, np.ndarray, Dict]:
207
+ @override
208
+ def encode(self) -> tuple[int, dict, dict[str, Any]]:
144
209
  """Prepare store task data for the persistent store."""
145
210
  wk_task = {"id_": self.id_, "element_IDs": np.array(self.element_IDs)}
146
- task = {"id_": self.id_, **self.task_template}
211
+ task = {"id_": self.id_, **(self.task_template or {})}
147
212
  return self.index, wk_task, task
148
213
 
214
+ @override
149
215
  @classmethod
150
- def decode(cls, task_dat: Dict) -> ZarrStoreTask:
216
+ def decode(cls, task_dat: dict) -> Self:
151
217
  """Initialise a `StoreTask` from persistent task data"""
152
218
  task_dat["element_IDs"] = task_dat["element_IDs"].tolist()
153
- return super().decode(task_dat)
219
+ return cls(is_pending=False, **task_dat)
154
220
 
155
221
 
156
222
  @dataclass
157
- class ZarrStoreElement(StoreElement):
223
+ class ZarrStoreElement(StoreElement[ListAny, ZarrAttrs]):
158
224
  """
159
225
  Represents an element in a Zarr persistent store.
160
226
  """
161
227
 
162
- def encode(self, attrs: Dict) -> List:
228
+ @override
229
+ def encode(self, attrs: ZarrAttrs) -> ListAny:
163
230
  """Prepare store elements data for the persistent store.
164
231
 
165
232
  This method mutates `attrs`.
166
233
  """
167
- elem_enc = [
234
+ return [
168
235
  self.id_,
169
236
  self.index,
170
237
  self.es_idx,
@@ -173,10 +240,10 @@ class ZarrStoreElement(StoreElement):
173
240
  self.task_ID,
174
241
  self.iteration_IDs,
175
242
  ]
176
- return elem_enc
177
243
 
244
+ @override
178
245
  @classmethod
179
- def decode(cls, elem_dat: List, attrs: Dict) -> ZarrStoreElement:
246
+ def decode(cls, elem_dat: ListAny, attrs: ZarrAttrs) -> Self:
180
247
  """Initialise a `StoreElement` from persistent element data"""
181
248
  obj_dat = {
182
249
  "id_": elem_dat[0],
@@ -191,21 +258,22 @@ class ZarrStoreElement(StoreElement):
191
258
 
192
259
 
193
260
  @dataclass
194
- class ZarrStoreElementIter(StoreElementIter):
261
+ class ZarrStoreElementIter(StoreElementIter[ListAny, ZarrAttrs]):
195
262
  """
196
263
  Represents an element iteration in a Zarr persistent store.
197
264
  """
198
265
 
199
- def encode(self, attrs: Dict) -> List:
266
+ @override
267
+ def encode(self, attrs: ZarrAttrs) -> ListAny:
200
268
  """Prepare store element iteration data for the persistent store.
201
269
 
202
270
  This method mutates `attrs`.
203
271
  """
204
- iter_enc = [
272
+ return [
205
273
  self.id_,
206
274
  self.element_ID,
207
275
  int(self.EARs_initialised),
208
- [[k, v] for k, v in self.EAR_IDs.items()] if self.EAR_IDs else None,
276
+ [[ek, ev] for ek, ev in self.EAR_IDs.items()] if self.EAR_IDs else None,
209
277
  [
210
278
  [ensure_in(dk, attrs["parameter_paths"]), dv]
211
279
  for dk, dv in self.data_idx.items()
@@ -213,11 +281,11 @@ class ZarrStoreElementIter(StoreElementIter):
213
281
  [ensure_in(i, attrs["schema_parameters"]) for i in self.schema_parameters],
214
282
  [[ensure_in(dk, attrs["loops"]), dv] for dk, dv in self.loop_idx.items()],
215
283
  ]
216
- return iter_enc
217
284
 
285
+ @override
218
286
  @classmethod
219
- def decode(cls, iter_dat: List, attrs: Dict) -> StoreElementIter:
220
- """Initialise a `StoreElementIter` from persistent element iteration data"""
287
+ def decode(cls, iter_dat: ListAny, attrs: ZarrAttrs) -> Self:
288
+ """Initialise a `ZarrStoreElementIter` from persistent element iteration data"""
221
289
  obj_dat = {
222
290
  "id_": iter_dat[0],
223
291
  "element_ID": iter_dat[1],
@@ -231,17 +299,18 @@ class ZarrStoreElementIter(StoreElementIter):
231
299
 
232
300
 
233
301
  @dataclass
234
- class ZarrStoreEAR(StoreEAR):
302
+ class ZarrStoreEAR(StoreEAR[ListAny, ZarrAttrs]):
235
303
  """
236
304
  Represents an element action run in a Zarr persistent store.
237
305
  """
238
306
 
239
- def encode(self, attrs: Dict, ts_fmt: str) -> Tuple[List, Tuple[np.datetime64]]:
307
+ @override
308
+ def encode(self, ts_fmt: str, attrs: ZarrAttrs) -> ListAny:
240
309
  """Prepare store EAR data for the persistent store.
241
310
 
242
311
  This method mutates `attrs`.
243
312
  """
244
- EAR_enc = [
313
+ return [
245
314
  self.id_,
246
315
  self.elem_iter_ID,
247
316
  self.action_idx,
@@ -261,10 +330,10 @@ class ZarrStoreEAR(StoreEAR):
261
330
  self.run_hostname,
262
331
  self.commands_idx,
263
332
  ]
264
- return EAR_enc
265
333
 
334
+ @override
266
335
  @classmethod
267
- def decode(cls, EAR_dat: List, attrs: Dict, ts_fmt: str) -> ZarrStoreEAR:
336
+ def decode(cls, EAR_dat: ListAny, ts_fmt: str, attrs: ZarrAttrs) -> Self:
268
337
  """Initialise a `ZarrStoreEAR` from persistent EAR data"""
269
338
  obj_dat = {
270
339
  "id_": EAR_dat[0],
@@ -287,50 +356,37 @@ class ZarrStoreEAR(StoreEAR):
287
356
 
288
357
 
289
358
  @dataclass
359
+ @hydrate
290
360
  class ZarrStoreParameter(StoreParameter):
291
361
  """
292
362
  Represents a parameter in a Zarr persistent store.
293
363
  """
294
364
 
295
- _encoders = { # keys are types
365
+ _encoders: ClassVar[dict[type, Callable]] = { # keys are types
296
366
  np.ndarray: _encode_numpy_array,
297
- np.ma.core.MaskedArray: _encode_masked_array,
367
+ MaskedArray: _encode_masked_array,
298
368
  }
299
- _decoders = { # keys are keys in type_lookup
369
+ _decoders: ClassVar[dict[str, Callable]] = { # keys are keys in type_lookup
300
370
  "arrays": _decode_numpy_arrays,
301
371
  "masked_arrays": _decode_masked_arrays,
302
372
  }
303
373
 
304
- def encode(self, root_group: zarr.Group, arr_path: str) -> Dict[str, Any]:
305
- return super().encode(root_group=root_group, arr_path=arr_path)
306
374
 
307
- @classmethod
308
- def decode(
309
- cls,
310
- id_: int,
311
- data: Union[None, Dict],
312
- source: Dict,
313
- arr_group: zarr.Group,
314
- path: Optional[List[str]] = None,
315
- dataset_copy: bool = False,
316
- ) -> Any:
317
- return super().decode(
318
- id_=id_,
319
- data=data,
320
- source=source,
321
- path=path,
322
- arr_group=arr_group,
323
- dataset_copy=dataset_copy,
324
- )
325
-
326
-
327
- class ZarrPersistentStore(PersistentStore):
375
+ class ZarrPersistentStore(
376
+ PersistentStore[
377
+ ZarrStoreTask,
378
+ ZarrStoreElement,
379
+ ZarrStoreElementIter,
380
+ ZarrStoreEAR,
381
+ ZarrStoreParameter,
382
+ ]
383
+ ):
328
384
  """
329
385
  A persistent store implemented using Zarr.
330
386
  """
331
387
 
332
- _name = "zarr"
333
- _features = PersistentStoreFeatures(
388
+ _name: ClassVar[str] = "zarr"
389
+ _features: ClassVar[PersistentStoreFeatures] = PersistentStoreFeatures(
334
390
  create=True,
335
391
  edit=True,
336
392
  jobscript_parallelism=True,
@@ -339,26 +395,42 @@ class ZarrPersistentStore(PersistentStore):
339
395
  submission=True,
340
396
  )
341
397
 
342
- _store_task_cls = ZarrStoreTask
343
- _store_elem_cls = ZarrStoreElement
344
- _store_iter_cls = ZarrStoreElementIter
345
- _store_EAR_cls = ZarrStoreEAR
346
- _store_param_cls = ZarrStoreParameter
347
-
348
- _param_grp_name = "parameters"
349
- _param_base_arr_name = "base"
350
- _param_sources_arr_name = "sources"
351
- _param_user_arr_grp_name = "arrays"
352
- _param_data_arr_grp_name = lambda _, param_idx: f"param_{param_idx}"
353
- _task_arr_name = "tasks"
354
- _elem_arr_name = "elements"
355
- _iter_arr_name = "iters"
356
- _EAR_arr_name = "runs"
357
- _time_res = "us" # microseconds; must not be smaller than micro!
358
-
359
- _res_map = CommitResourceMap(commit_template_components=("attrs",))
360
-
361
- def __init__(self, app, workflow, path, fs) -> None:
398
+ @classmethod
399
+ def _store_task_cls(cls) -> type[ZarrStoreTask]:
400
+ return ZarrStoreTask
401
+
402
+ @classmethod
403
+ def _store_elem_cls(cls) -> type[ZarrStoreElement]:
404
+ return ZarrStoreElement
405
+
406
+ @classmethod
407
+ def _store_iter_cls(cls) -> type[ZarrStoreElementIter]:
408
+ return ZarrStoreElementIter
409
+
410
+ @classmethod
411
+ def _store_EAR_cls(cls) -> type[ZarrStoreEAR]:
412
+ return ZarrStoreEAR
413
+
414
+ @classmethod
415
+ def _store_param_cls(cls) -> type[ZarrStoreParameter]:
416
+ return ZarrStoreParameter
417
+
418
+ _param_grp_name: ClassVar[str] = "parameters"
419
+ _param_base_arr_name: ClassVar[str] = "base"
420
+ _param_sources_arr_name: ClassVar[str] = "sources"
421
+ _param_user_arr_grp_name: ClassVar[str] = "arrays"
422
+ _param_data_arr_grp_name: ClassVar = lambda _, param_idx: f"param_{param_idx}"
423
+ _task_arr_name: ClassVar[str] = "tasks"
424
+ _elem_arr_name: ClassVar[str] = "elements"
425
+ _iter_arr_name: ClassVar[str] = "iters"
426
+ _EAR_arr_name: ClassVar[str] = "runs"
427
+ _time_res: ClassVar[str] = "us" # microseconds; must not be smaller than micro!
428
+
429
+ _res_map: ClassVar[CommitResourceMap] = CommitResourceMap(
430
+ commit_template_components=("attrs",)
431
+ )
432
+
433
+ def __init__(self, app, workflow, path: str | Path, fs: AbstractFileSystem) -> None:
362
434
  self._zarr_store = None # assigned on first access to `zarr_store`
363
435
  self._resources = {
364
436
  "attrs": ZarrAttrsStoreResource(
@@ -368,10 +440,10 @@ class ZarrPersistentStore(PersistentStore):
368
440
  super().__init__(app, workflow, path, fs)
369
441
 
370
442
  @contextmanager
371
- def cached_load(self) -> Iterator[Dict]:
443
+ def cached_load(self) -> Iterator[None]:
372
444
  """Context manager to cache the root attributes."""
373
445
  with self.using_resource("attrs", "read") as attrs:
374
- yield attrs
446
+ yield
375
447
 
376
448
  def remove_replaced_dir(self) -> None:
377
449
  """
@@ -380,8 +452,8 @@ class ZarrPersistentStore(PersistentStore):
380
452
  with self.using_resource("attrs", "update") as md:
381
453
  if "replaced_workflow" in md:
382
454
  self.logger.debug("removing temporarily renamed pre-existing workflow.")
383
- self.remove_path(md["replaced_workflow"], self.fs)
384
- md["replaced_workflow"] = None
455
+ self.remove_path(md["replaced_workflow"])
456
+ del md["replaced_workflow"]
385
457
 
386
458
  def reinstate_replaced_dir(self) -> None:
387
459
  """
@@ -392,32 +464,38 @@ class ZarrPersistentStore(PersistentStore):
392
464
  self.logger.debug(
393
465
  "reinstating temporarily renamed pre-existing workflow."
394
466
  )
395
- self.rename_path(md["replaced_workflow"], self.path, self.fs)
467
+ self.rename_path(
468
+ md["replaced_workflow"],
469
+ self.path,
470
+ )
396
471
 
397
472
  @staticmethod
398
- def _get_zarr_store(path: str, fs) -> zarr.storage.Store:
399
- return zarr.storage.FSStore(url=path, fs=fs)
473
+ def _get_zarr_store(path: str | Path, fs: AbstractFileSystem) -> Store:
474
+ return FSStore(url=str(path), fs=fs)
475
+
476
+ _CODEC: ClassVar = MsgPack()
400
477
 
401
478
  @classmethod
402
479
  def write_empty_workflow(
403
480
  cls,
404
- app,
405
- template_js: Dict,
406
- template_components_js: Dict,
481
+ app: BaseApp,
482
+ *,
483
+ template_js: TemplateMeta,
484
+ template_components_js: dict[str, Any],
407
485
  wk_path: str,
408
- fs,
486
+ fs: AbstractFileSystem,
409
487
  name: str,
410
- replaced_wk: str,
488
+ replaced_wk: str | None,
411
489
  ts_fmt: str,
412
490
  ts_name_fmt: str,
413
- creation_info: Dict,
414
- compressor: Optional[Union[str, None]] = "blosc",
415
- compressor_kwargs: Optional[Dict[str, Any]] = None,
491
+ creation_info: StoreCreationInfo,
492
+ compressor: str | None = "blosc",
493
+ compressor_kwargs: dict[str, Any] | None = None,
416
494
  ) -> None:
417
495
  """
418
496
  Write an empty persistent workflow.
419
497
  """
420
- attrs = {
498
+ attrs: ZarrAttrsDict = {
421
499
  "name": name,
422
500
  "ts_fmt": ts_fmt,
423
501
  "ts_name_fmt": ts_name_fmt,
@@ -459,7 +537,7 @@ class ZarrPersistentStore(PersistentStore):
459
537
  name=cls._elem_arr_name,
460
538
  shape=0,
461
539
  dtype=object,
462
- object_codec=MsgPack(),
540
+ object_codec=cls._CODEC,
463
541
  chunks=1000,
464
542
  compressor=cmp,
465
543
  )
@@ -469,7 +547,7 @@ class ZarrPersistentStore(PersistentStore):
469
547
  name=cls._iter_arr_name,
470
548
  shape=0,
471
549
  dtype=object,
472
- object_codec=MsgPack(),
550
+ object_codec=cls._CODEC,
473
551
  chunks=1000,
474
552
  compressor=cmp,
475
553
  )
@@ -485,18 +563,18 @@ class ZarrPersistentStore(PersistentStore):
485
563
  name=cls._EAR_arr_name,
486
564
  shape=0,
487
565
  dtype=object,
488
- object_codec=MsgPack(),
566
+ object_codec=cls._CODEC,
489
567
  chunks=1, # single-chunk rows for multiprocess writing
490
568
  compressor=cmp,
491
569
  )
492
- EARs_arr.attrs.update({"parameter_paths": []})
570
+ EARs_arr.attrs["parameter_paths"] = []
493
571
 
494
572
  parameter_data = root.create_group(name=cls._param_grp_name)
495
573
  parameter_data.create_dataset(
496
574
  name=cls._param_base_arr_name,
497
575
  shape=0,
498
576
  dtype=object,
499
- object_codec=MsgPack(),
577
+ object_codec=cls._CODEC,
500
578
  chunks=1,
501
579
  compressor=cmp,
502
580
  write_empty_chunks=False,
@@ -506,15 +584,15 @@ class ZarrPersistentStore(PersistentStore):
506
584
  name=cls._param_sources_arr_name,
507
585
  shape=0,
508
586
  dtype=object,
509
- object_codec=MsgPack(),
587
+ object_codec=cls._CODEC,
510
588
  chunks=1000, # TODO: check this is a sensible size with many parameters
511
589
  compressor=cmp,
512
590
  )
513
591
  parameter_data.create_group(name=cls._param_user_arr_grp_name)
514
592
 
515
- def _append_tasks(self, tasks: List[ZarrStoreTask]):
593
+ def _append_tasks(self, tasks: Iterable[ZarrStoreTask]):
516
594
  elem_IDs_arr = self._get_tasks_arr(mode="r+")
517
- elem_IDs = []
595
+ elem_IDs: list[int] = []
518
596
  with self.using_resource("attrs", "update") as attrs:
519
597
  for i_idx, i in enumerate(tasks):
520
598
  idx, wk_task_i, task_i = i.encode()
@@ -529,9 +607,9 @@ class ZarrPersistentStore(PersistentStore):
529
607
  # increasing IDs.
530
608
  append_items_to_ragged_array(arr=elem_IDs_arr, items=elem_IDs)
531
609
 
532
- def _append_loops(self, loops: Dict[int, Dict]):
610
+ def _append_loops(self, loops: dict[int, LoopDescriptor]):
533
611
  with self.using_resource("attrs", action="update") as attrs:
534
- for loop_idx, loop in loops.items():
612
+ for loop in loops.values():
535
613
  attrs["loops"].append(
536
614
  {
537
615
  "num_added_iterations": loop["num_added_iterations"],
@@ -541,12 +619,11 @@ class ZarrPersistentStore(PersistentStore):
541
619
  )
542
620
  attrs["template"]["loops"].append(loop["loop_template"])
543
621
 
544
- def _append_submissions(self, subs: Dict[int, Dict]):
622
+ def _append_submissions(self, subs: dict[int, JSONDocument]):
545
623
  with self.using_resource("attrs", action="update") as attrs:
546
- for sub_idx, sub_i in subs.items():
547
- attrs["submissions"].append(sub_i)
624
+ attrs["submissions"].extend(subs.values())
548
625
 
549
- def _append_task_element_IDs(self, task_ID: int, elem_IDs: List[int]):
626
+ def _append_task_element_IDs(self, task_ID: int, elem_IDs: list[int]):
550
627
  # I don't think there's a way to "append" to an existing array in a zarr ragged
551
628
  # array? So we have to build a new array from existing + new.
552
629
  arr = self._get_tasks_arr(mode="r+")
@@ -554,169 +631,161 @@ class ZarrPersistentStore(PersistentStore):
554
631
  elem_IDs_new = np.concatenate((elem_IDs_cur, elem_IDs))
555
632
  arr[task_ID] = elem_IDs_new
556
633
 
557
- def _append_elements(self, elems: List[ZarrStoreElement]):
558
- arr = self._get_elements_arr(mode="r+")
559
- attrs_orig = arr.attrs.asdict()
634
+ @staticmethod
635
+ def __as_dict(attrs: Attributes) -> ZarrAttrs:
636
+ """
637
+ Type thunk to work around incomplete typing in zarr.
638
+ """
639
+ return cast("ZarrAttrs", attrs.asdict())
640
+
641
+ @contextmanager
642
+ def __mutate_attrs(self, arr: Array) -> Iterator[ZarrAttrs]:
643
+ attrs_orig = self.__as_dict(arr.attrs)
560
644
  attrs = copy.deepcopy(attrs_orig)
561
- arr_add = np.empty((len(elems)), dtype=object)
562
- arr_add[:] = [i.encode(attrs) for i in elems]
563
- arr.append(arr_add)
645
+ yield attrs
564
646
  if attrs != attrs_orig:
565
647
  arr.attrs.put(attrs)
566
648
 
567
- def _append_element_sets(self, task_id: int, es_js: List[Dict]):
649
+ def _append_elements(self, elems: Sequence[ZarrStoreElement]):
650
+ arr = self._get_elements_arr(mode="r+")
651
+ with self.__mutate_attrs(arr) as attrs:
652
+ arr_add = np.empty((len(elems)), dtype=object)
653
+ arr_add[:] = [elem.encode(attrs) for elem in elems]
654
+ arr.append(arr_add)
655
+
656
+ def _append_element_sets(self, task_id: int, es_js: Sequence[Mapping]):
568
657
  task_idx = task_idx = self._get_task_id_to_idx_map()[task_id]
569
658
  with self.using_resource("attrs", "update") as attrs:
570
659
  attrs["template"]["tasks"][task_idx]["element_sets"].extend(es_js)
571
660
 
572
- def _append_elem_iter_IDs(self, elem_ID: int, iter_IDs: List[int]):
661
+ def _append_elem_iter_IDs(self, elem_ID: int, iter_IDs: Iterable[int]):
573
662
  arr = self._get_elements_arr(mode="r+")
574
- attrs = arr.attrs.asdict()
575
- elem_dat = arr[elem_ID]
663
+ attrs = self.__as_dict(arr.attrs)
664
+ elem_dat = cast("list", arr[elem_ID])
576
665
  store_elem = ZarrStoreElement.decode(elem_dat, attrs)
577
666
  store_elem = store_elem.append_iteration_IDs(iter_IDs)
578
- arr[elem_ID] = store_elem.encode(
579
- attrs
580
- ) # attrs shouldn't be mutated (TODO: test!)
667
+ arr[elem_ID] = store_elem.encode(attrs)
668
+ # attrs shouldn't be mutated (TODO: test!)
581
669
 
582
- def _append_elem_iters(self, iters: List[ZarrStoreElementIter]):
670
+ def _append_elem_iters(self, iters: Sequence[ZarrStoreElementIter]):
583
671
  arr = self._get_iters_arr(mode="r+")
584
- attrs_orig = arr.attrs.asdict()
585
- attrs = copy.deepcopy(attrs_orig)
586
- arr_add = np.empty((len(iters)), dtype=object)
587
- arr_add[:] = [i.encode(attrs) for i in iters]
588
- arr.append(arr_add)
589
- if attrs != attrs_orig:
590
- arr.attrs.put(attrs)
672
+ with self.__mutate_attrs(arr) as attrs:
673
+ arr_add = np.empty((len(iters)), dtype=object)
674
+ arr_add[:] = [i.encode(attrs) for i in iters]
675
+ arr.append(arr_add)
591
676
 
592
- def _append_elem_iter_EAR_IDs(self, iter_ID: int, act_idx: int, EAR_IDs: List[int]):
677
+ def _append_elem_iter_EAR_IDs(
678
+ self, iter_ID: int, act_idx: int, EAR_IDs: Sequence[int]
679
+ ):
593
680
  arr = self._get_iters_arr(mode="r+")
594
- attrs = arr.attrs.asdict()
595
- iter_dat = arr[iter_ID]
681
+ attrs = self.__as_dict(arr.attrs)
682
+ iter_dat = cast("list", arr[iter_ID])
596
683
  store_iter = ZarrStoreElementIter.decode(iter_dat, attrs)
597
684
  store_iter = store_iter.append_EAR_IDs(pend_IDs={act_idx: EAR_IDs})
598
- arr[iter_ID] = store_iter.encode(
599
- attrs
600
- ) # attrs shouldn't be mutated (TODO: test!)
685
+ arr[iter_ID] = store_iter.encode(attrs)
686
+ # attrs shouldn't be mutated (TODO: test!)
601
687
 
602
688
  def _update_elem_iter_EARs_initialised(self, iter_ID: int):
603
689
  arr = self._get_iters_arr(mode="r+")
604
- attrs = arr.attrs.asdict()
605
- iter_dat = arr[iter_ID]
690
+ attrs = self.__as_dict(arr.attrs)
691
+ iter_dat = cast("list", arr[iter_ID])
606
692
  store_iter = ZarrStoreElementIter.decode(iter_dat, attrs)
607
693
  store_iter = store_iter.set_EARs_initialised()
608
- arr[iter_ID] = store_iter.encode(
609
- attrs
610
- ) # attrs shouldn't be mutated (TODO: test!)
694
+ arr[iter_ID] = store_iter.encode(attrs)
695
+ # attrs shouldn't be mutated (TODO: test!)
611
696
 
612
- def _append_submission_parts(self, sub_parts: Dict[int, Dict[str, List[int]]]):
697
+ def _append_submission_parts(self, sub_parts: dict[int, dict[str, list[int]]]):
613
698
  with self.using_resource("attrs", action="update") as attrs:
614
699
  for sub_idx, sub_i_parts in sub_parts.items():
700
+ sub = cast("dict", attrs["submissions"][sub_idx])
615
701
  for dt_str, parts_j in sub_i_parts.items():
616
- attrs["submissions"][sub_idx]["submission_parts"][dt_str] = parts_j
702
+ sub["submission_parts"][dt_str] = parts_j
617
703
 
618
- def _update_loop_index(self, iter_ID: int, loop_idx: Dict):
704
+ def _update_loop_index(self, iter_ID: int, loop_idx: Mapping[str, int]):
619
705
  arr = self._get_iters_arr(mode="r+")
620
- attrs = arr.attrs.asdict()
621
- iter_dat = arr[iter_ID]
706
+ attrs = self.__as_dict(arr.attrs)
707
+ iter_dat = cast("list", arr[iter_ID])
622
708
  store_iter = ZarrStoreElementIter.decode(iter_dat, attrs)
623
709
  store_iter = store_iter.update_loop_idx(loop_idx)
624
710
  arr[iter_ID] = store_iter.encode(attrs)
625
711
 
626
- def _update_loop_num_iters(self, index: int, num_iters: int):
712
+ def _update_loop_num_iters(self, index: int, num_iters: list[list[list[int] | int]]):
627
713
  with self.using_resource("attrs", action="update") as attrs:
628
714
  attrs["loops"][index]["num_added_iterations"] = num_iters
629
715
 
630
- def _update_loop_parents(self, index: int, parents: List[str]):
716
+ def _update_loop_parents(self, index: int, parents: list[str]):
631
717
  with self.using_resource("attrs", action="update") as attrs:
632
718
  attrs["loops"][index]["parents"] = parents
633
719
 
634
- def _append_EARs(self, EARs: List[ZarrStoreEAR]):
720
+ def _append_EARs(self, EARs: Sequence[ZarrStoreEAR]):
635
721
  arr = self._get_EARs_arr(mode="r+")
636
- attrs_orig = arr.attrs.asdict()
637
- attrs = copy.deepcopy(attrs_orig)
638
- arr_add = np.empty((len(EARs)), dtype=object)
639
- arr_add[:] = [i.encode(attrs, self.ts_fmt) for i in EARs]
640
- arr.append(arr_add)
641
-
642
- if attrs != attrs_orig:
643
- arr.attrs.put(attrs)
722
+ with self.__mutate_attrs(arr) as attrs:
723
+ arr_add = np.empty((len(EARs)), dtype=object)
724
+ arr_add[:] = [ear.encode(self.ts_fmt, attrs) for ear in EARs]
725
+ arr.append(arr_add)
644
726
 
645
727
  @TimeIt.decorator
646
- def _update_EAR_submission_indices(self, sub_indices: Dict[int:int]):
647
- EAR_IDs = list(sub_indices.keys())
728
+ def _update_EAR_submission_indices(self, sub_indices: Mapping[int, int]):
729
+ EAR_IDs = list(sub_indices)
648
730
  EARs = self._get_persistent_EARs(EAR_IDs)
649
731
 
650
732
  arr = self._get_EARs_arr(mode="r+")
651
- attrs_orig = arr.attrs.asdict()
652
- attrs = copy.deepcopy(attrs_orig)
653
-
654
- encoded_EARs = []
655
- for EAR_ID_i, sub_idx_i in sub_indices.items():
656
- new_EAR_i = EARs[EAR_ID_i].update(submission_idx=sub_idx_i)
657
- # seems to be a Zarr bug that prevents `set_coordinate_selection` with an
658
- # object array, so set one-by-one:
659
- arr[EAR_ID_i] = new_EAR_i.encode(attrs, self.ts_fmt)
660
-
661
- if attrs != attrs_orig:
662
- arr.attrs.put(attrs)
663
-
664
- def _update_EAR_start(self, EAR_id: int, s_time: datetime, s_snap: Dict, s_hn: str):
733
+ with self.__mutate_attrs(arr) as attrs:
734
+ for EAR_ID_i, sub_idx_i in sub_indices.items():
735
+ new_EAR_i = EARs[EAR_ID_i].update(submission_idx=sub_idx_i)
736
+ # seems to be a Zarr bug that prevents `set_coordinate_selection` with an
737
+ # object array, so set one-by-one:
738
+ arr[EAR_ID_i] = new_EAR_i.encode(self.ts_fmt, attrs)
739
+
740
+ def _update_EAR_start(
741
+ self, EAR_id: int, s_time: datetime, s_snap: dict[str, Any], s_hn: str
742
+ ):
665
743
  arr = self._get_EARs_arr(mode="r+")
666
- attrs_orig = arr.attrs.asdict()
667
- attrs = copy.deepcopy(attrs_orig)
668
-
669
- EAR_i = self._get_persistent_EARs([EAR_id])[EAR_id]
670
- EAR_i = EAR_i.update(
671
- start_time=s_time,
672
- snapshot_start=s_snap,
673
- run_hostname=s_hn,
674
- )
675
- arr[EAR_id] = EAR_i.encode(attrs, self.ts_fmt)
676
-
677
- if attrs != attrs_orig:
678
- arr.attrs.put(attrs)
744
+ with self.__mutate_attrs(arr) as attrs:
745
+ EAR_i = self._get_persistent_EARs([EAR_id])[EAR_id]
746
+ EAR_i = EAR_i.update(
747
+ start_time=s_time,
748
+ snapshot_start=s_snap,
749
+ run_hostname=s_hn,
750
+ )
751
+ arr[EAR_id] = EAR_i.encode(self.ts_fmt, attrs)
679
752
 
680
753
  def _update_EAR_end(
681
- self, EAR_id: int, e_time: datetime, e_snap: Dict, ext_code: int, success: bool
754
+ self,
755
+ EAR_id: int,
756
+ e_time: datetime,
757
+ e_snap: dict[str, Any],
758
+ ext_code: int,
759
+ success: bool,
682
760
  ):
683
761
  arr = self._get_EARs_arr(mode="r+")
684
- attrs_orig = arr.attrs.asdict()
685
- attrs = copy.deepcopy(attrs_orig)
686
-
687
- EAR_i = self._get_persistent_EARs([EAR_id])[EAR_id]
688
- EAR_i = EAR_i.update(
689
- end_time=e_time,
690
- snapshot_end=e_snap,
691
- exit_code=ext_code,
692
- success=success,
693
- )
694
- arr[EAR_id] = EAR_i.encode(attrs, self.ts_fmt)
695
-
696
- if attrs != attrs_orig:
697
- arr.attrs.put(attrs)
762
+ with self.__mutate_attrs(arr) as attrs:
763
+ EAR_i = self._get_persistent_EARs([EAR_id])[EAR_id]
764
+ EAR_i = EAR_i.update(
765
+ end_time=e_time,
766
+ snapshot_end=e_snap,
767
+ exit_code=ext_code,
768
+ success=success,
769
+ )
770
+ arr[EAR_id] = EAR_i.encode(self.ts_fmt, attrs)
698
771
 
699
772
  def _update_EAR_skip(self, EAR_id: int):
700
773
  arr = self._get_EARs_arr(mode="r+")
701
- attrs_orig = arr.attrs.asdict()
702
- attrs = copy.deepcopy(attrs_orig)
703
-
704
- EAR_i = self._get_persistent_EARs([EAR_id])[EAR_id]
705
- EAR_i = EAR_i.update(skip=True)
706
- arr[EAR_id] = EAR_i.encode(attrs, self.ts_fmt)
707
-
708
- if attrs != attrs_orig:
709
- arr.attrs.put(attrs)
774
+ with self.__mutate_attrs(arr) as attrs:
775
+ EAR_i = self._get_persistent_EARs([EAR_id])[EAR_id]
776
+ EAR_i = EAR_i.update(skip=True)
777
+ arr[EAR_id] = EAR_i.encode(self.ts_fmt, attrs)
710
778
 
711
- def _update_js_metadata(self, js_meta: Dict):
779
+ def _update_js_metadata(self, js_meta: dict[int, dict[int, dict[str, Any]]]):
712
780
  with self.using_resource("attrs", action="update") as attrs:
713
781
  for sub_idx, all_js_md in js_meta.items():
782
+ sub = cast(
783
+ "dict[str, list[dict[str, Any]]]", attrs["submissions"][sub_idx]
784
+ )
714
785
  for js_idx, js_meta_i in all_js_md.items():
715
- attrs["submissions"][sub_idx]["jobscripts"][js_idx].update(
716
- **js_meta_i
717
- )
786
+ sub["jobscripts"][js_idx].update(**js_meta_i)
718
787
 
719
- def _append_parameters(self, params: List[ZarrStoreParameter]):
788
+ def _append_parameters(self, params: Sequence[StoreParameter]):
720
789
  """Add new persistent parameters."""
721
790
  base_arr = self._get_parameter_base_array(mode="r+", write_empty_chunks=False)
722
791
  src_arr = self._get_parameter_sources_array(mode="r+")
@@ -725,8 +794,8 @@ class ZarrPersistentStore(PersistentStore):
725
794
  )
726
795
 
727
796
  param_encode_root_group = self._get_parameter_user_array_group(mode="r+")
728
- param_enc = []
729
- src_enc = []
797
+ param_enc: list[dict[str, Any] | int] = []
798
+ src_enc: list[dict] = []
730
799
  for param_i in params:
731
800
  dat_i = param_i.encode(
732
801
  root_group=param_encode_root_group,
@@ -741,16 +810,15 @@ class ZarrPersistentStore(PersistentStore):
741
810
  f"PersistentStore._append_parameters: finished adding {len(params)} parameters."
742
811
  )
743
812
 
744
- def _set_parameter_values(self, set_parameters: Dict[int, Tuple[Any, bool]]):
813
+ def _set_parameter_values(self, set_parameters: dict[int, tuple[Any, bool]]):
745
814
  """Set multiple unset persistent parameters."""
746
815
 
747
- param_ids = list(set_parameters.keys())
816
+ param_ids = list(set_parameters)
748
817
  # the `decode` call in `_get_persistent_parameters` should be quick:
749
818
  params = self._get_persistent_parameters(param_ids)
750
- new_data = []
819
+ new_data: list[dict[str, Any] | int] = []
751
820
  param_encode_root_group = self._get_parameter_user_array_group(mode="r+")
752
821
  for param_id, (value, is_file) in set_parameters.items():
753
-
754
822
  param_i = params[param_id]
755
823
  if is_file:
756
824
  param_i = param_i.set_file(value)
@@ -768,19 +836,19 @@ class ZarrPersistentStore(PersistentStore):
768
836
  base_arr = self._get_parameter_base_array(mode="r+")
769
837
  base_arr.set_coordinate_selection(param_ids, new_data)
770
838
 
771
- def _update_parameter_sources(self, sources: Dict[int, Dict]):
839
+ def _update_parameter_sources(self, sources: Mapping[int, ParamSource]):
772
840
  """Update the sources of multiple persistent parameters."""
773
841
 
774
- param_ids = list(sources.keys())
842
+ param_ids = list(sources)
775
843
  src_arr = self._get_parameter_sources_array(mode="r+")
776
844
  existing_sources = src_arr.get_coordinate_selection(param_ids)
777
- new_sources = []
778
- for idx, source_i in enumerate(sources.values()):
779
- new_src_i = update_param_source_dict(existing_sources[idx], source_i)
780
- new_sources.append(new_src_i)
845
+ new_sources = [
846
+ update_param_source_dict(cast("ParamSource", existing_sources[idx]), source_i)
847
+ for idx, source_i in enumerate(sources.values())
848
+ ]
781
849
  src_arr.set_coordinate_selection(param_ids, new_sources)
782
850
 
783
- def _update_template_components(self, tc: Dict):
851
+ def _update_template_components(self, tc: dict[str, Any]):
784
852
  with self.using_resource("attrs", "update") as md:
785
853
  md["template_components"] = tc
786
854
 
@@ -832,46 +900,48 @@ class ZarrPersistentStore(PersistentStore):
832
900
  return attrs["num_added_tasks"]
833
901
 
834
902
  @property
835
- def zarr_store(self) -> zarr.storage.Store:
903
+ def zarr_store(self) -> Store:
836
904
  """
837
905
  The underlying store object.
838
906
  """
839
907
  if self._zarr_store is None:
908
+ assert self.fs is not None
840
909
  self._zarr_store = self._get_zarr_store(self.path, self.fs)
841
910
  return self._zarr_store
842
911
 
843
- def _get_root_group(self, mode: str = "r", **kwargs) -> zarr.Group:
912
+ def _get_root_group(self, mode: str = "r", **kwargs) -> Group:
844
913
  return zarr.open(self.zarr_store, mode=mode, **kwargs)
845
914
 
846
- def _get_parameter_group(self, mode: str = "r", **kwargs) -> zarr.Group:
915
+ def _get_parameter_group(self, mode: str = "r", **kwargs) -> Group:
847
916
  return self._get_root_group(mode=mode, **kwargs).get(self._param_grp_name)
848
917
 
849
- def _get_parameter_base_array(self, mode: str = "r", **kwargs) -> zarr.Array:
918
+ def _get_parameter_base_array(self, mode: str = "r", **kwargs) -> Array:
850
919
  path = f"{self._param_grp_name}/{self._param_base_arr_name}"
851
920
  return zarr.open(self.zarr_store, mode=mode, path=path, **kwargs)
852
921
 
853
- def _get_parameter_sources_array(self, mode: str = "r") -> zarr.Array:
922
+ def _get_parameter_sources_array(self, mode: str = "r") -> Array:
854
923
  return self._get_parameter_group(mode=mode).get(self._param_sources_arr_name)
855
924
 
856
- def _get_parameter_user_array_group(self, mode: str = "r") -> zarr.Group:
925
+ def _get_parameter_user_array_group(self, mode: str = "r") -> Group:
857
926
  return self._get_parameter_group(mode=mode).get(self._param_user_arr_grp_name)
858
927
 
859
928
  def _get_parameter_data_array_group(
860
929
  self,
861
930
  parameter_idx: int,
862
931
  mode: str = "r",
863
- ) -> zarr.Group:
932
+ ) -> Group:
864
933
  return self._get_parameter_user_array_group(mode=mode).get(
865
934
  self._param_data_arr_grp_name(parameter_idx)
866
935
  )
867
936
 
868
- def _get_array_group_and_dataset(self, mode: str, param_id: int, data_path):
937
+ def _get_array_group_and_dataset(
938
+ self, mode: str, param_id: int, data_path: list[int]
939
+ ):
869
940
  base_dat = self._get_parameter_base_array(mode="r")[param_id]
870
- arr_idx = None
871
941
  for arr_dat_path, arr_idx in base_dat["type_lookup"]["arrays"]:
872
942
  if arr_dat_path == data_path:
873
943
  break
874
- if arr_idx is None:
944
+ else:
875
945
  raise ValueError(
876
946
  f"Could not find array path {data_path} in the base data for parameter "
877
947
  f"ID {param_id}."
@@ -881,19 +951,19 @@ class ZarrPersistentStore(PersistentStore):
881
951
  )
882
952
  return group, f"arr_{arr_idx}"
883
953
 
884
- def _get_metadata_group(self, mode: str = "r") -> zarr.Group:
954
+ def _get_metadata_group(self, mode: str = "r") -> Group:
885
955
  return self._get_root_group(mode=mode).get("metadata")
886
956
 
887
- def _get_tasks_arr(self, mode: str = "r") -> zarr.Array:
957
+ def _get_tasks_arr(self, mode: str = "r") -> Array:
888
958
  return self._get_metadata_group(mode=mode).get(self._task_arr_name)
889
959
 
890
- def _get_elements_arr(self, mode: str = "r") -> zarr.Array:
960
+ def _get_elements_arr(self, mode: str = "r") -> Array:
891
961
  return self._get_metadata_group(mode=mode).get(self._elem_arr_name)
892
962
 
893
- def _get_iters_arr(self, mode: str = "r") -> zarr.Array:
963
+ def _get_iters_arr(self, mode: str = "r") -> Array:
894
964
  return self._get_metadata_group(mode=mode).get(self._iter_arr_name)
895
965
 
896
- def _get_EARs_arr(self, mode: str = "r") -> zarr.Array:
966
+ def _get_EARs_arr(self, mode: str = "r") -> Array:
897
967
  return self._get_metadata_group(mode=mode).get(self._EAR_arr_name)
898
968
 
899
969
  @classmethod
@@ -905,10 +975,10 @@ class ZarrPersistentStore(PersistentStore):
905
975
  overwrite=False,
906
976
  ):
907
977
  """Generate an store for testing purposes."""
978
+ ts_fmt = "FIXME"
908
979
 
909
980
  path = Path(dir or "", path)
910
- store = zarr.DirectoryStore(path)
911
- root = zarr.group(store=store, overwrite=overwrite)
981
+ root = zarr.group(store=DirectoryStore(path), overwrite=overwrite)
912
982
  md = root.create_group("metadata")
913
983
 
914
984
  tasks_arr = md.create_dataset(
@@ -922,7 +992,7 @@ class ZarrPersistentStore(PersistentStore):
922
992
  name=cls._elem_arr_name,
923
993
  shape=0,
924
994
  dtype=object,
925
- object_codec=MsgPack(),
995
+ object_codec=cls._CODEC,
926
996
  chunks=1000,
927
997
  )
928
998
  elems_arr.attrs.update({"seq_idx": [], "src_idx": []})
@@ -931,7 +1001,7 @@ class ZarrPersistentStore(PersistentStore):
931
1001
  name=cls._iter_arr_name,
932
1002
  shape=0,
933
1003
  dtype=object,
934
- object_codec=MsgPack(),
1004
+ object_codec=cls._CODEC,
935
1005
  chunks=1000,
936
1006
  )
937
1007
  elem_iters_arr.attrs.update(
@@ -946,12 +1016,12 @@ class ZarrPersistentStore(PersistentStore):
946
1016
  name=cls._EAR_arr_name,
947
1017
  shape=0,
948
1018
  dtype=object,
949
- object_codec=MsgPack(),
1019
+ object_codec=cls._CODEC,
950
1020
  chunks=1000,
951
1021
  )
952
- EARs_arr.attrs.update({"parameter_paths": []})
1022
+ EARs_arr.attrs["parameter_paths"] = []
953
1023
 
954
- tasks, elems, elem_iters, EARs = super().prepare_test_store_from_spec(spec)
1024
+ tasks, elems, elem_iters, EARs_ = super().prepare_test_store_from_spec(spec)
955
1025
 
956
1026
  path = Path(path).resolve()
957
1027
  tasks = [ZarrStoreTask(**i).encode() for i in tasks]
@@ -960,21 +1030,13 @@ class ZarrPersistentStore(PersistentStore):
960
1030
  ZarrStoreElementIter(**i).encode(elem_iters_arr.attrs.asdict())
961
1031
  for i in elem_iters
962
1032
  ]
963
- EARs = [ZarrStoreEAR(**i).encode(EARs_arr.attrs.asdict()) for i in EARs]
1033
+ EARs = [ZarrStoreEAR(**i).encode(ts_fmt, EARs_arr.attrs.asdict()) for i in EARs_]
964
1034
 
965
1035
  append_items_to_ragged_array(tasks_arr, tasks)
966
1036
 
967
- elem_arr_add = np.empty((len(elements)), dtype=object)
968
- elem_arr_add[:] = elements
969
- elems_arr.append(elem_arr_add)
970
-
971
- iter_arr_add = np.empty((len(elem_iters)), dtype=object)
972
- iter_arr_add[:] = elem_iters
973
- elem_iters_arr.append(iter_arr_add)
974
-
975
- EAR_arr_add = np.empty((len(EARs)), dtype=object)
976
- EAR_arr_add[:] = EARs
977
- EARs_arr.append(EAR_arr_add)
1037
+ elems_arr.append(np.fromiter(elements, dtype=object))
1038
+ elem_iters_arr.append(np.fromiter(elem_iters, dtype=object))
1039
+ EARs_arr.append(np.fromiter(EARs, dtype=object))
978
1040
 
979
1041
  return cls(path)
980
1042
 
@@ -982,17 +1044,18 @@ class ZarrPersistentStore(PersistentStore):
982
1044
  with self.using_resource("attrs", "read") as attrs:
983
1045
  return attrs["template_components"]
984
1046
 
985
- def _get_persistent_template(self):
1047
+ def _get_persistent_template(self) -> dict[str, JSONed]:
986
1048
  with self.using_resource("attrs", "read") as attrs:
987
- return attrs["template"]
1049
+ return cast("dict[str, JSONed]", attrs["template"])
988
1050
 
989
1051
  @TimeIt.decorator
990
- def _get_persistent_tasks(self, id_lst: Iterable[int]) -> Dict[int, ZarrStoreTask]:
1052
+ def _get_persistent_tasks(self, id_lst: Iterable[int]) -> dict[int, ZarrStoreTask]:
991
1053
  tasks, id_lst = self._get_cached_persistent_tasks(id_lst)
992
1054
  if id_lst:
993
1055
  with self.using_resource("attrs", action="read") as attrs:
994
- task_dat = {}
995
- elem_IDs = []
1056
+ task_dat: dict[int, dict[str, Any]] = {}
1057
+ elem_IDs: list[int] = []
1058
+ i: dict[str, Any]
996
1059
  for idx, i in enumerate(attrs["tasks"]):
997
1060
  i = copy.deepcopy(i)
998
1061
  elem_IDs.append(i.pop("element_IDs_idx"))
@@ -1003,65 +1066,62 @@ class ZarrPersistentStore(PersistentStore):
1003
1066
  elem_IDs_arr_dat = self._get_tasks_arr().get_coordinate_selection(
1004
1067
  elem_IDs
1005
1068
  )
1006
- except zarr.errors.BoundsCheckError:
1069
+ except BoundsCheckError:
1007
1070
  raise MissingStoreTaskError(
1008
1071
  elem_IDs
1009
1072
  ) from None # TODO: not an ID list
1010
1073
 
1011
1074
  new_tasks = {
1012
1075
  id_: ZarrStoreTask.decode({**i, "element_IDs": elem_IDs_arr_dat[id_]})
1013
- for idx, (id_, i) in enumerate(task_dat.items())
1076
+ for id_, i in task_dat.items()
1014
1077
  }
1015
- else:
1016
- new_tasks = {}
1017
- self.task_cache.update(new_tasks)
1018
- tasks.update(new_tasks)
1078
+ self.task_cache.update(new_tasks)
1079
+ tasks.update(new_tasks)
1019
1080
  return tasks
1020
1081
 
1021
1082
  @TimeIt.decorator
1022
- def _get_persistent_loops(self, id_lst: Optional[Iterable[int]] = None):
1083
+ def _get_persistent_loops(
1084
+ self, id_lst: Iterable[int] | None = None
1085
+ ) -> dict[int, LoopDescriptor]:
1023
1086
  with self.using_resource("attrs", "read") as attrs:
1024
- loop_dat = {
1025
- idx: i
1087
+ return {
1088
+ idx: cast("LoopDescriptor", i)
1026
1089
  for idx, i in enumerate(attrs["loops"])
1027
1090
  if id_lst is None or idx in id_lst
1028
1091
  }
1029
- return loop_dat
1030
1092
 
1031
1093
  @TimeIt.decorator
1032
- def _get_persistent_submissions(self, id_lst: Optional[Iterable[int]] = None):
1094
+ def _get_persistent_submissions(self, id_lst: Iterable[int] | None = None):
1033
1095
  self.logger.debug("loading persistent submissions from the zarr store")
1096
+ ids = set(id_lst or ())
1034
1097
  with self.using_resource("attrs", "read") as attrs:
1035
1098
  subs_dat = copy.deepcopy(
1036
1099
  {
1037
1100
  idx: i
1038
1101
  for idx, i in enumerate(attrs["submissions"])
1039
- if id_lst is None or idx in id_lst
1102
+ if id_lst is None or idx in ids
1040
1103
  }
1041
1104
  )
1042
1105
  # cast jobscript submit-times and jobscript `task_elements` keys:
1043
- for sub_idx, sub in subs_dat.items():
1044
- for js_idx, js in enumerate(sub["jobscripts"]):
1045
- for key in list(js["task_elements"].keys()):
1046
- subs_dat[sub_idx]["jobscripts"][js_idx]["task_elements"][
1047
- int(key)
1048
- ] = subs_dat[sub_idx]["jobscripts"][js_idx]["task_elements"].pop(
1049
- key
1050
- )
1106
+ for sub in subs_dat.values():
1107
+ for js in cast("_JS", sub)["jobscripts"]:
1108
+ task_elements = js["task_elements"]
1109
+ for key in list(task_elements):
1110
+ task_elements[int(key)] = task_elements.pop(key)
1051
1111
 
1052
1112
  return subs_dat
1053
1113
 
1054
1114
  @TimeIt.decorator
1055
1115
  def _get_persistent_elements(
1056
1116
  self, id_lst: Iterable[int]
1057
- ) -> Dict[int, ZarrStoreElement]:
1117
+ ) -> dict[int, ZarrStoreElement]:
1058
1118
  elems, id_lst = self._get_cached_persistent_elements(id_lst)
1059
1119
  if id_lst:
1060
1120
  arr = self._get_elements_arr()
1061
1121
  attrs = arr.attrs.asdict()
1062
1122
  try:
1063
1123
  elem_arr_dat = arr.get_coordinate_selection(id_lst)
1064
- except zarr.errors.BoundsCheckError:
1124
+ except BoundsCheckError:
1065
1125
  raise MissingStoreElementError(id_lst) from None
1066
1126
  elem_dat = dict(zip(id_lst, elem_arr_dat))
1067
1127
  new_elems = {
@@ -1074,14 +1134,14 @@ class ZarrPersistentStore(PersistentStore):
1074
1134
  @TimeIt.decorator
1075
1135
  def _get_persistent_element_iters(
1076
1136
  self, id_lst: Iterable[int]
1077
- ) -> Dict[int, ZarrStoreElementIter]:
1137
+ ) -> dict[int, ZarrStoreElementIter]:
1078
1138
  iters, id_lst = self._get_cached_persistent_element_iters(id_lst)
1079
1139
  if id_lst:
1080
1140
  arr = self._get_iters_arr()
1081
1141
  attrs = arr.attrs.asdict()
1082
1142
  try:
1083
1143
  iter_arr_dat = arr.get_coordinate_selection(id_lst)
1084
- except zarr.errors.BoundsCheckError:
1144
+ except BoundsCheckError:
1085
1145
  raise MissingStoreElementIterationError(id_lst) from None
1086
1146
  iter_dat = dict(zip(id_lst, iter_arr_dat))
1087
1147
  new_iters = {
@@ -1092,7 +1152,7 @@ class ZarrPersistentStore(PersistentStore):
1092
1152
  return iters
1093
1153
 
1094
1154
  @TimeIt.decorator
1095
- def _get_persistent_EARs(self, id_lst: Iterable[int]) -> Dict[int, ZarrStoreEAR]:
1155
+ def _get_persistent_EARs(self, id_lst: Iterable[int]) -> dict[int, ZarrStoreEAR]:
1096
1156
  runs, id_lst = self._get_cached_persistent_EARs(id_lst)
1097
1157
  if id_lst:
1098
1158
  arr = self._get_EARs_arr()
@@ -1100,11 +1160,11 @@ class ZarrPersistentStore(PersistentStore):
1100
1160
  try:
1101
1161
  self.logger.debug(f"_get_persistent_EARs: {id_lst=}")
1102
1162
  EAR_arr_dat = _zarr_get_coord_selection(arr, id_lst, self.logger)
1103
- except zarr.errors.BoundsCheckError:
1163
+ except BoundsCheckError:
1104
1164
  raise MissingStoreEARError(id_lst) from None
1105
1165
  EAR_dat = dict(zip(id_lst, EAR_arr_dat))
1106
1166
  new_runs = {
1107
- k: ZarrStoreEAR.decode(EAR_dat=v, attrs=attrs, ts_fmt=self.ts_fmt)
1167
+ k: ZarrStoreEAR.decode(EAR_dat=v, ts_fmt=self.ts_fmt, attrs=attrs)
1108
1168
  for k, v in EAR_dat.items()
1109
1169
  }
1110
1170
  self.EAR_cache.update(new_runs)
@@ -1114,11 +1174,8 @@ class ZarrPersistentStore(PersistentStore):
1114
1174
 
1115
1175
  @TimeIt.decorator
1116
1176
  def _get_persistent_parameters(
1117
- self,
1118
- id_lst: Iterable[int],
1119
- dataset_copy: Optional[bool] = False,
1120
- ) -> Dict[int, ZarrStoreParameter]:
1121
-
1177
+ self, id_lst: Iterable[int], *, dataset_copy: bool = False, **kwargs
1178
+ ) -> dict[int, ZarrStoreParameter]:
1122
1179
  params, id_lst = self._get_cached_persistent_parameters(id_lst)
1123
1180
  if id_lst:
1124
1181
  base_arr = self._get_parameter_base_array(mode="r")
@@ -1127,7 +1184,7 @@ class ZarrPersistentStore(PersistentStore):
1127
1184
  try:
1128
1185
  param_arr_dat = base_arr.get_coordinate_selection(list(id_lst))
1129
1186
  src_arr_dat = src_arr.get_coordinate_selection(list(id_lst))
1130
- except zarr.errors.BoundsCheckError:
1187
+ except BoundsCheckError:
1131
1188
  raise MissingParameterData(id_lst) from None
1132
1189
 
1133
1190
  param_dat = dict(zip(id_lst, param_arr_dat))
@@ -1149,13 +1206,15 @@ class ZarrPersistentStore(PersistentStore):
1149
1206
  return params
1150
1207
 
1151
1208
  @TimeIt.decorator
1152
- def _get_persistent_param_sources(self, id_lst: Iterable[int]) -> Dict[int, Dict]:
1209
+ def _get_persistent_param_sources(
1210
+ self, id_lst: Iterable[int]
1211
+ ) -> dict[int, ParamSource]:
1153
1212
  sources, id_lst = self._get_cached_persistent_param_sources(id_lst)
1154
1213
  if id_lst:
1155
1214
  src_arr = self._get_parameter_sources_array(mode="r")
1156
1215
  try:
1157
1216
  src_arr_dat = src_arr.get_coordinate_selection(list(id_lst))
1158
- except zarr.errors.BoundsCheckError:
1217
+ except BoundsCheckError:
1159
1218
  raise MissingParameterData(id_lst) from None
1160
1219
  new_sources = dict(zip(id_lst, src_arr_dat))
1161
1220
  self.param_sources_cache.update(new_sources)
@@ -1164,16 +1223,16 @@ class ZarrPersistentStore(PersistentStore):
1164
1223
 
1165
1224
  def _get_persistent_parameter_set_status(
1166
1225
  self, id_lst: Iterable[int]
1167
- ) -> Dict[int, bool]:
1226
+ ) -> dict[int, bool]:
1168
1227
  base_arr = self._get_parameter_base_array(mode="r")
1169
1228
  try:
1170
1229
  param_arr_dat = base_arr.get_coordinate_selection(list(id_lst))
1171
- except zarr.errors.BoundsCheckError:
1230
+ except BoundsCheckError:
1172
1231
  raise MissingParameterData(id_lst) from None
1173
1232
 
1174
1233
  return dict(zip(id_lst, [i is not None for i in param_arr_dat]))
1175
1234
 
1176
- def _get_persistent_parameter_IDs(self) -> List[int]:
1235
+ def _get_persistent_parameter_IDs(self) -> list[int]:
1177
1236
  # we assume the row index is equivalent to ID, might need to revisit in future
1178
1237
  base_arr = self._get_parameter_base_array(mode="r")
1179
1238
  return list(range(len(base_arr)))
@@ -1208,11 +1267,11 @@ class ZarrPersistentStore(PersistentStore):
1208
1267
 
1209
1268
  def zip(
1210
1269
  self,
1211
- path=".",
1212
- log=None,
1213
- overwrite=False,
1214
- include_execute=False,
1215
- include_rechunk_backups=False,
1270
+ path: str = ".",
1271
+ log: str | None = None,
1272
+ overwrite: bool = False,
1273
+ include_execute: bool = False,
1274
+ include_rechunk_backups: bool = False,
1216
1275
  ):
1217
1276
  """
1218
1277
  Convert the persistent store to zipped form.
@@ -1224,69 +1283,66 @@ class ZarrPersistentStore(PersistentStore):
1224
1283
  directory, the zip file will be created within this directory. Otherwise,
1225
1284
  this path is assumed to be the full file path to the new zip file.
1226
1285
  """
1227
- console = Console()
1228
- status = console.status(f"Zipping workflow {self.workflow.name!r}...")
1229
- status.start()
1230
-
1231
- # TODO: this won't work for remote file systems
1232
- dst_path = Path(path).resolve()
1233
- if dst_path.is_dir():
1234
- dst_path = dst_path.joinpath(self.workflow.name).with_suffix(".zip")
1235
-
1236
- if not overwrite and dst_path.exists():
1237
- status.stop()
1238
- raise FileExistsError(
1239
- f"File at path already exists: {dst_path!r}. Pass `overwrite=True` to "
1240
- f"overwrite the existing file."
1241
- )
1286
+ with Console().status(f"Zipping workflow {self.workflow.name!r}..."):
1287
+ # TODO: this won't work for remote file systems
1288
+ dst_path = Path(path).resolve()
1289
+ if dst_path.is_dir():
1290
+ dst_path = dst_path.joinpath(self.workflow.name).with_suffix(".zip")
1291
+
1292
+ if not overwrite and dst_path.exists():
1293
+ raise FileExistsError(
1294
+ f"File at path already exists: {dst_path!r}. Pass `overwrite=True` to "
1295
+ f"overwrite the existing file."
1296
+ )
1242
1297
 
1243
- dst_path = str(dst_path)
1298
+ dst_path_s = str(dst_path)
1244
1299
 
1245
- src_zarr_store = self.zarr_store
1246
- zfs, _ = ask_pw_on_auth_exc(
1247
- ZipFileSystem,
1248
- fo=dst_path,
1249
- mode="w",
1250
- target_options={},
1251
- add_pw_to="target_options",
1252
- )
1253
- dst_zarr_store = zarr.storage.FSStore(url="", fs=zfs)
1254
- excludes = []
1255
- if not include_execute:
1256
- excludes.append("execute")
1257
- if not include_rechunk_backups:
1258
- excludes.append("runs.bak")
1259
- excludes.append("base.bak")
1260
-
1261
- zarr.convenience.copy_store(
1262
- src_zarr_store,
1263
- dst_zarr_store,
1264
- excludes=excludes or None,
1265
- log=log,
1266
- )
1267
- del zfs # ZipFileSystem remains open for instance lifetime
1268
- status.stop()
1269
- return dst_path
1300
+ src_zarr_store = self.zarr_store
1301
+ zfs, _ = ask_pw_on_auth_exc(
1302
+ ZipFileSystem,
1303
+ fo=dst_path_s,
1304
+ mode="w",
1305
+ target_options={},
1306
+ add_pw_to="target_options",
1307
+ )
1308
+ dst_zarr_store = FSStore(url="", fs=zfs)
1309
+ excludes = []
1310
+ if not include_execute:
1311
+ excludes.append("execute")
1312
+ if not include_rechunk_backups:
1313
+ excludes.append("runs.bak")
1314
+ excludes.append("base.bak")
1315
+
1316
+ zarr.copy_store(
1317
+ src_zarr_store,
1318
+ dst_zarr_store,
1319
+ excludes=excludes or None,
1320
+ log=log,
1321
+ )
1322
+ del zfs # ZipFileSystem remains open for instance lifetime
1323
+ return dst_path_s
1324
+
1325
+ def unzip(self, path: str = ".", log: str | None = None):
1326
+ raise ValueError("Not a zip store!")
1270
1327
 
1271
1328
  def _rechunk_arr(
1272
1329
  self,
1273
- arr,
1274
- chunk_size: Optional[int] = None,
1275
- backup: Optional[bool] = True,
1276
- status: Optional[bool] = True,
1277
- ):
1330
+ arr: Array,
1331
+ chunk_size: int | None = None,
1332
+ backup: bool = True,
1333
+ status: bool = True,
1334
+ ) -> Array:
1278
1335
  arr_path = Path(self.workflow.path) / arr.path
1279
1336
  arr_name = arr.path.split("/")[-1]
1280
1337
 
1281
1338
  if status:
1282
- console = Console()
1283
- status = console.status("Rechunking...")
1284
- status.start()
1339
+ s = Console().status("Rechunking...")
1340
+ s.start()
1285
1341
  backup_time = None
1286
1342
 
1287
1343
  if backup:
1288
1344
  if status:
1289
- status.update("Backing up...")
1345
+ s.update("Backing up...")
1290
1346
  backup_path = arr_path.with_suffix(".bak")
1291
1347
  if backup_path.is_dir():
1292
1348
  pass
@@ -1300,16 +1356,16 @@ class ZarrPersistentStore(PersistentStore):
1300
1356
  arr_rc_path = arr_path.with_suffix(".rechunked")
1301
1357
  arr = zarr.open(arr_path)
1302
1358
  if status:
1303
- status.update("Creating new array...")
1359
+ s.update("Creating new array...")
1304
1360
  arr_rc = zarr.create(
1305
1361
  store=arr_rc_path,
1306
1362
  shape=arr.shape,
1307
1363
  chunks=arr.shape if chunk_size is None else chunk_size,
1308
1364
  dtype=object,
1309
- object_codec=MsgPack(),
1365
+ object_codec=self._CODEC,
1310
1366
  )
1311
1367
  if status:
1312
- status.update("Copying data...")
1368
+ s.update("Copying data...")
1313
1369
  data = np.empty(shape=arr.shape, dtype=object)
1314
1370
  bad_data = []
1315
1371
  for idx in range(len(arr)):
@@ -1318,24 +1374,23 @@ class ZarrPersistentStore(PersistentStore):
1318
1374
  except RuntimeError:
1319
1375
  # blosc decompression errors
1320
1376
  bad_data.append(idx)
1321
- pass
1322
1377
  arr_rc[:] = data
1323
1378
 
1324
1379
  arr_rc.attrs.put(arr.attrs.asdict())
1325
1380
 
1326
1381
  if status:
1327
- status.update("Deleting old array...")
1382
+ s.update("Deleting old array...")
1328
1383
  shutil.rmtree(arr_path)
1329
1384
 
1330
1385
  if status:
1331
- status.update("Moving new array into place...")
1386
+ s.update("Moving new array into place...")
1332
1387
  shutil.move(arr_rc_path, arr_path)
1333
1388
 
1334
1389
  toc = time.perf_counter()
1335
1390
  rechunk_time = toc - tic
1336
1391
 
1337
1392
  if status:
1338
- status.stop()
1393
+ s.stop()
1339
1394
 
1340
1395
  if backup_time:
1341
1396
  print(f"Time to backup {arr_name}: {backup_time:.1f} s")
@@ -1349,10 +1404,10 @@ class ZarrPersistentStore(PersistentStore):
1349
1404
 
1350
1405
  def rechunk_parameter_base(
1351
1406
  self,
1352
- chunk_size: Optional[int] = None,
1353
- backup: Optional[bool] = True,
1354
- status: Optional[bool] = True,
1355
- ):
1407
+ chunk_size: int | None = None,
1408
+ backup: bool = True,
1409
+ status: bool = True,
1410
+ ) -> Array:
1356
1411
  """
1357
1412
  Rechunk the parameter data to be stored more efficiently.
1358
1413
  """
@@ -1361,10 +1416,10 @@ class ZarrPersistentStore(PersistentStore):
1361
1416
 
1362
1417
  def rechunk_runs(
1363
1418
  self,
1364
- chunk_size: Optional[int] = None,
1365
- backup: Optional[bool] = True,
1366
- status: Optional[bool] = True,
1367
- ):
1419
+ chunk_size: int | None = None,
1420
+ backup: bool = True,
1421
+ status: bool = True,
1422
+ ) -> Array:
1368
1423
  """
1369
1424
  Rechunk the run data to be stored more efficiently.
1370
1425
  """
@@ -1381,8 +1436,8 @@ class ZarrZipPersistentStore(ZarrPersistentStore):
1381
1436
  Archive format persistent stores cannot be updated without being unzipped first.
1382
1437
  """
1383
1438
 
1384
- _name = "zip"
1385
- _features = PersistentStoreFeatures(
1439
+ _name: ClassVar[str] = "zip"
1440
+ _features: ClassVar[PersistentStoreFeatures] = PersistentStoreFeatures(
1386
1441
  create=False,
1387
1442
  edit=False,
1388
1443
  jobscript_parallelism=False,
@@ -1393,10 +1448,17 @@ class ZarrZipPersistentStore(ZarrPersistentStore):
1393
1448
 
1394
1449
  # TODO: enforce read-only nature
1395
1450
 
1396
- def zip(self):
1451
+ def zip(
1452
+ self,
1453
+ path: str = ".",
1454
+ log: str | None = None,
1455
+ overwrite: bool = False,
1456
+ include_execute: bool = False,
1457
+ include_rechunk_backups: bool = False,
1458
+ ):
1397
1459
  raise ValueError("Already a zip store!")
1398
1460
 
1399
- def unzip(self, path=".", log=None):
1461
+ def unzip(self, path: str = ".", log: str | None = None) -> str:
1400
1462
  """
1401
1463
  Expand the persistent store.
1402
1464
 
@@ -1409,28 +1471,23 @@ class ZarrZipPersistentStore(ZarrPersistentStore):
1409
1471
 
1410
1472
  """
1411
1473
 
1412
- console = Console()
1413
- status = console.status(f"Unzipping workflow {self.workflow.name!r}...")
1414
- status.start()
1415
-
1416
- # TODO: this won't work for remote file systems
1417
- dst_path = Path(path).resolve()
1418
- if dst_path.is_dir():
1419
- dst_path = dst_path.joinpath(self.workflow.name)
1474
+ with Console().status(f"Unzipping workflow {self.workflow.name!r}..."):
1475
+ # TODO: this won't work for remote file systems
1476
+ dst_path = Path(path).resolve()
1477
+ if dst_path.is_dir():
1478
+ dst_path = dst_path.joinpath(self.workflow.name)
1420
1479
 
1421
- if dst_path.exists():
1422
- status.stop()
1423
- raise FileExistsError(f"Directory at path already exists: {dst_path!r}.")
1480
+ if dst_path.exists():
1481
+ raise FileExistsError(f"Directory at path already exists: {dst_path!r}.")
1424
1482
 
1425
- dst_path = str(dst_path)
1483
+ dst_path_s = str(dst_path)
1426
1484
 
1427
- src_zarr_store = self.zarr_store
1428
- dst_zarr_store = zarr.storage.FSStore(url=dst_path)
1429
- zarr.convenience.copy_store(src_zarr_store, dst_zarr_store, log=log)
1430
- status.stop()
1431
- return dst_path
1485
+ src_zarr_store = self.zarr_store
1486
+ dst_zarr_store = FSStore(url=dst_path_s)
1487
+ zarr.copy_store(src_zarr_store, dst_zarr_store, log=log)
1488
+ return dst_path_s
1432
1489
 
1433
- def copy(self, path=None) -> str:
1490
+ def copy(self, path: PathLike = None) -> Path:
1434
1491
  # not sure how to do this.
1435
1492
  raise NotImplementedError()
1436
1493
 
@@ -1441,8 +1498,8 @@ class ZarrZipPersistentStore(ZarrPersistentStore):
1441
1498
  def _rechunk_arr(
1442
1499
  self,
1443
1500
  arr,
1444
- chunk_size: Optional[int] = None,
1445
- backup: Optional[bool] = True,
1446
- status: Optional[bool] = True,
1447
- ):
1501
+ chunk_size: int | None = None,
1502
+ backup: bool = True,
1503
+ status: bool = True,
1504
+ ) -> Array:
1448
1505
  raise NotImplementedError