hpcflow-new2 0.2.0a50__py3-none-any.whl → 0.2.0a52__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. hpcflow/_version.py +1 -1
  2. hpcflow/sdk/__init__.py +1 -1
  3. hpcflow/sdk/api.py +1 -1
  4. hpcflow/sdk/app.py +20 -11
  5. hpcflow/sdk/cli.py +34 -59
  6. hpcflow/sdk/core/__init__.py +13 -1
  7. hpcflow/sdk/core/actions.py +235 -126
  8. hpcflow/sdk/core/command_files.py +32 -24
  9. hpcflow/sdk/core/element.py +110 -114
  10. hpcflow/sdk/core/errors.py +57 -0
  11. hpcflow/sdk/core/loop.py +18 -34
  12. hpcflow/sdk/core/parameters.py +5 -3
  13. hpcflow/sdk/core/task.py +135 -131
  14. hpcflow/sdk/core/task_schema.py +11 -4
  15. hpcflow/sdk/core/utils.py +110 -2
  16. hpcflow/sdk/core/workflow.py +964 -676
  17. hpcflow/sdk/data/template_components/environments.yaml +0 -44
  18. hpcflow/sdk/data/template_components/task_schemas.yaml +52 -10
  19. hpcflow/sdk/persistence/__init__.py +21 -33
  20. hpcflow/sdk/persistence/base.py +1340 -458
  21. hpcflow/sdk/persistence/json.py +424 -546
  22. hpcflow/sdk/persistence/pending.py +563 -0
  23. hpcflow/sdk/persistence/store_resource.py +131 -0
  24. hpcflow/sdk/persistence/utils.py +57 -0
  25. hpcflow/sdk/persistence/zarr.py +852 -841
  26. hpcflow/sdk/submission/jobscript.py +133 -112
  27. hpcflow/sdk/submission/shells/bash.py +62 -16
  28. hpcflow/sdk/submission/shells/powershell.py +87 -16
  29. hpcflow/sdk/submission/submission.py +59 -35
  30. hpcflow/tests/unit/test_element.py +4 -9
  31. hpcflow/tests/unit/test_persistence.py +218 -0
  32. hpcflow/tests/unit/test_task.py +11 -12
  33. hpcflow/tests/unit/test_utils.py +82 -0
  34. hpcflow/tests/unit/test_workflow.py +3 -1
  35. {hpcflow_new2-0.2.0a50.dist-info → hpcflow_new2-0.2.0a52.dist-info}/METADATA +3 -1
  36. {hpcflow_new2-0.2.0a50.dist-info → hpcflow_new2-0.2.0a52.dist-info}/RECORD +38 -34
  37. {hpcflow_new2-0.2.0a50.dist-info → hpcflow_new2-0.2.0a52.dist-info}/WHEEL +0 -0
  38. {hpcflow_new2-0.2.0a50.dist-info → hpcflow_new2-0.2.0a52.dist-info}/entry_points.txt +0 -0
@@ -1,34 +1,39 @@
  from __future__ import annotations
  from contextlib import contextmanager
+
  import copy
- from datetime import datetime
+ from dataclasses import dataclass
+ from datetime import datetime, timezone
  from pathlib import Path
-
- import shutil
- import time
- from typing import Any, Dict, Generator, Iterator, List, Optional, Tuple, Union
+ from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union
  import warnings
+
+ from fsspec.implementations.zip import ZipFileSystem
  import numpy as np
- import zarr
- from numcodecs import MsgPack
- from hpcflow.sdk import app
-
- from hpcflow.sdk.core.errors import WorkflowNotFoundError
- from hpcflow.sdk.core.utils import (
-     bisect_slice,
-     ensure_in,
-     get_in_container,
-     get_md5_hash,
-     get_relative_path,
-     set_in_container,
+ from hpcflow.sdk.core.errors import (
+     MissingParameterData,
+     MissingStoreEARError,
+     MissingStoreElementError,
+     MissingStoreElementIterationError,
+     MissingStoreTaskError,
  )
+ from hpcflow.sdk.core.utils import ensure_in, get_relative_path, set_in_container
  from hpcflow.sdk.persistence.base import (
-     PersistentStore,
      PersistentStoreFeatures,
-     remove_dir,
-     rename_dir,
+     PersistentStore,
+     StoreEAR,
+     StoreElement,
+     StoreElementIter,
+     StoreParameter,
+     StoreTask,
  )
- from hpcflow.sdk.typing import PathLike
+ from hpcflow.sdk.persistence.store_resource import ZarrAttrsStoreResource
+ from hpcflow.sdk.persistence.utils import ask_pw_on_auth_exc
+ from hpcflow.sdk.persistence.pending import CommitResourceMap
+
+ from numcodecs import MsgPack, VLenArray
+
+ import zarr


  def _encode_numpy_array(obj, type_lookup, path, root_group, arr_path):
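The two codec imports added above drive the new storage layout: `VLenArray(int)` backs a ragged tasks array (one variable-length row of element IDs per task), while `MsgPack` serialises the heterogeneous list rows of the element, iteration, and run arrays. A minimal illustrative sketch of such object arrays (not part of the package; assumes zarr 2.x and numcodecs):

import numpy as np
import zarr
from numcodecs import MsgPack, VLenArray

# ragged integer rows, e.g. one row of element IDs per task:
tasks = zarr.empty(0, dtype=object, object_codec=VLenArray(int))
tasks.resize(1)
tasks[0] = np.array([0, 1, 2])

# msgpack-encoded heterogeneous rows, e.g. one encoded element per row:
elems = zarr.empty(0, dtype=object, object_codec=MsgPack())
row = np.empty(1, dtype=object)
row[0] = [0, 0, 0, [], [], 0, [0]]  # shape of a ZarrStoreElement.encode row
elems.append(row)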
@@ -89,95 +94,316 @@ def _decode_masked_arrays(obj, type_lookup, path, arr_group, dataset_copy):
      return obj


- class ZarrPersistentStore(PersistentStore):
-     """An efficient storage backend using Zarr that supports parameter-setting from
-     multiple processes."""
+ def append_items_to_ragged_array(arr, items):
+     """Append an array to a Zarr ragged array.
+
+     I think `arr.append([item])` should work, but does not for some reason, so we do it
+     here by resizing and assignment."""
+     num = len(items)
+     arr.resize(len(arr) + num)
+     for idx, i in enumerate(items):
+         arr[-(num - idx)] = i
+
+
+ @dataclass
+ class ZarrStoreTask(StoreTask):
+     def encode(self) -> Tuple[int, np.ndarray, Dict]:
+         """Prepare store task data for the persistent store."""
+         wk_task = {"id_": self.id_, "element_IDs": np.array(self.element_IDs)}
+         task = {"id_": self.id_, **self.task_template}
+         return self.index, wk_task, task
+
+     @classmethod
+     def decode(cls, task_dat: Dict) -> ZarrStoreTask:
+         """Initialise a `StoreTask` from persistent task data"""
+         task_dat["element_IDs"] = task_dat["element_IDs"].tolist()
+         return super().decode(task_dat)
+
+
+ @dataclass
+ class ZarrStoreElement(StoreElement):
+     def encode(self, attrs: Dict) -> List:
+         """Prepare store elements data for the persistent store.
+
+         This method mutates `attrs`.
+         """
+         elem_enc = [
+             self.id_,
+             self.index,
+             self.es_idx,
+             [[ensure_in(k, attrs["seq_idx"]), v] for k, v in self.seq_idx.items()],
+             [[ensure_in(k, attrs["src_idx"]), v] for k, v in self.src_idx.items()],
+             self.task_ID,
+             self.iteration_IDs,
+         ]
+         return elem_enc
+
+     @classmethod
+     def decode(cls, elem_dat: List, attrs: Dict) -> ZarrStoreElement:
+         """Initialise a `StoreElement` from persistent element data"""
+         obj_dat = {
+             "id_": elem_dat[0],
+             "index": elem_dat[1],
+             "es_idx": elem_dat[2],
+             "seq_idx": {attrs["seq_idx"][k]: v for (k, v) in elem_dat[3]},
+             "src_idx": {attrs["src_idx"][k]: v for (k, v) in elem_dat[4]},
+             "task_ID": elem_dat[5],
+             "iteration_IDs": elem_dat[6],
+         }
+         return cls(is_pending=False, **obj_dat)
+
+
+ @dataclass
+ class ZarrStoreElementIter(StoreElementIter):
+     def encode(self, attrs: Dict) -> List:
+         """Prepare store element iteration data for the persistent store.
+
+         This method mutates `attrs`.
+         """
+         iter_enc = [
+             self.id_,
+             self.element_ID,
+             [[k, v] for k, v in self.EAR_IDs.items()] if self.EAR_IDs else None,
+             [
+                 [ensure_in(dk, attrs["parameter_paths"]), dv]
+                 for dk, dv in self.data_idx.items()
+             ],
+             [ensure_in(i, attrs["schema_parameters"]) for i in self.schema_parameters],
+             [[ensure_in(dk, attrs["loops"]), dv] for dk, dv in self.loop_idx.items()],
+         ]
+         return iter_enc
+
+     @classmethod
+     def decode(cls, iter_dat: List, attrs: Dict) -> StoreElementIter:
+         """Initialise a `StoreElementIter` from persistent element iteration data"""
+         obj_dat = {
+             "id_": iter_dat[0],
+             "element_ID": iter_dat[1],
+             "EAR_IDs": {i[0]: i[1] for i in iter_dat[2]} if iter_dat[2] else None,
+             "data_idx": {attrs["parameter_paths"][i[0]]: i[1] for i in iter_dat[3]},
+             "schema_parameters": [attrs["schema_parameters"][i] for i in iter_dat[4]],
+             "loop_idx": {i[0]: i[1] for i in iter_dat[5]},
+         }
+         return cls(is_pending=False, **obj_dat)
+
+
+ @dataclass
+ class ZarrStoreEAR(StoreEAR):
+     def encode(self, attrs: Dict, ts_fmt: str) -> List:
+         """Prepare store EAR data for the persistent store.
+
+         This method mutates `attrs`.
+         """
+         EAR_enc = [
+             self.id_,
+             self.elem_iter_ID,
+             self.action_idx,
+             [
+                 [ensure_in(dk, attrs["parameter_paths"]), dv]
+                 for dk, dv in self.data_idx.items()
+             ],
+             self.submission_idx,
+             self.skip,
+             self.success,
+             self._encode_datetime(self.start_time, ts_fmt),
+             self._encode_datetime(self.end_time, ts_fmt),
+             self.snapshot_start,
+             self.snapshot_end,
+             self.exit_code,
+             self.metadata,
+         ]
+         return EAR_enc
+
+     @classmethod
+     def decode(cls, EAR_dat: List, attrs: Dict, ts_fmt: str) -> ZarrStoreEAR:
+         """Initialise a `ZarrStoreEAR` from persistent EAR data"""
+         obj_dat = {
+             "id_": EAR_dat[0],
+             "elem_iter_ID": EAR_dat[1],
+             "action_idx": EAR_dat[2],
+             "data_idx": {attrs["parameter_paths"][i[0]]: i[1] for i in EAR_dat[3]},
+             "submission_idx": EAR_dat[4],
+             "skip": EAR_dat[5],
+             "success": EAR_dat[6],
+             "start_time": cls._decode_datetime(EAR_dat[7], ts_fmt),
+             "end_time": cls._decode_datetime(EAR_dat[8], ts_fmt),
+             "snapshot_start": EAR_dat[9],
+             "snapshot_end": EAR_dat[10],
+             "exit_code": EAR_dat[11],
+             "metadata": EAR_dat[12],
+         }
+         return cls(is_pending=False, **obj_dat)
+
+
+ @dataclass
+ class ZarrStoreParameter(StoreParameter):
+     _encoders = {  # keys are types
+         np.ndarray: _encode_numpy_array,
+         np.ma.core.MaskedArray: _encode_masked_array,
+     }
+     _decoders = {  # keys are keys in type_lookup
+         "arrays": _decode_numpy_arrays,
+         "masked_arrays": _decode_masked_arrays,
+     }
+
+     def encode(self, root_group: zarr.Group, arr_path: str) -> Dict[str, Any]:
+         return super().encode(root_group=root_group, arr_path=arr_path)
+
+     @classmethod
+     def decode(
+         cls,
+         id_: int,
+         data: Union[None, Dict],
+         source: Dict,
+         arr_group: zarr.Group,
+         path: Optional[List[str]] = None,
+         dataset_copy: bool = False,
+     ) -> Any:
+         return super().decode(
+             id_=id_,
+             data=data,
+             source=source,
+             path=path,
+             arr_group=arr_group,
+             dataset_copy=dataset_copy,
+         )


+ class ZarrPersistentStore(PersistentStore):
      _name = "zarr"
      _features = PersistentStoreFeatures(
+         create=True,
+         edit=True,
          jobscript_parallelism=True,
          EAR_parallelism=True,
          schedulers=True,
          submission=True,
      )

-     _param_grp_name = "parameter_data"
-     _elem_grp_name = "element_data"
+     _store_task_cls = ZarrStoreTask
+     _store_elem_cls = ZarrStoreElement
+     _store_iter_cls = ZarrStoreElementIter
+     _store_EAR_cls = ZarrStoreEAR
+     _store_param_cls = ZarrStoreParameter
+
+     _param_grp_name = "parameters"
      _param_base_arr_name = "base"
      _param_sources_arr_name = "sources"
      _param_user_arr_grp_name = "arrays"
      _param_data_arr_grp_name = lambda _, param_idx: f"param_{param_idx}"
-     _task_grp_name = lambda _, insert_ID: f"task_{insert_ID}"
-     _task_elem_arr_name = "elements"
-     _task_elem_iter_arr_name = "element_iters"
-     _task_EAR_times_arr_name = "EAR_times"
-
-     _parameter_encoders = {  # keys are types
-         np.ndarray: _encode_numpy_array,
-         np.ma.core.MaskedArray: _encode_masked_array,
-     }
-     _parameter_decoders = {  # keys are keys in type_lookup
-         "arrays": _decode_numpy_arrays,
-         "masked_arrays": _decode_masked_arrays,
-     }
-
-     def __init__(self, workflow: app.Workflow) -> None:
-         self._metadata = None  # cache used in `cached_load` context manager
-         super().__init__(workflow)
+     _task_arr_name = "tasks"
+     _elem_arr_name = "elements"
+     _iter_arr_name = "iters"
+     _EAR_arr_name = "runs"
+     _time_res = "us"  # microseconds; must not be smaller than micro!
+
+     _res_map = CommitResourceMap(commit_template_components=("attrs",))
+
+     def __init__(self, app, workflow, path, fs) -> None:
+         self._zarr_store = None  # assigned on first access to `zarr_store`
+         self._resources = {
+             "attrs": ZarrAttrsStoreResource(
+                 app, name="attrs", open_call=self._get_root_group
+             ),
+         }
+         super().__init__(app, workflow, path, fs)

-     @classmethod
-     def path_has_store(cls, path):
-         return path.joinpath(".zgroup").is_file()
+     @contextmanager
+     def cached_load(self) -> Iterator[Dict]:
+         """Context manager to cache the root attributes."""
+         with self.using_resource("attrs", "read") as attrs:
+             yield attrs

-     @property
-     def store_path(self):
-         return self.workflow_path
+     def remove_replaced_dir(self) -> None:
+         with self.using_resource("attrs", "update") as md:
+             if "replaced_workflow" in md:
+                 self.logger.debug("removing temporarily renamed pre-existing workflow.")
+                 self.remove_path(md["replaced_workflow"], self.fs)
+                 md["replaced_workflow"] = None

-     def exists(self) -> bool:
-         try:
-             self._get_root_group()
-         except zarr.errors.PathNotFoundError:
-             return False
-         return True
+     def reinstate_replaced_dir(self) -> None:
+         with self.using_resource("attrs", "read") as md:
+             if "replaced_workflow" in md:
+                 self.logger.debug(
+                     "reinstating temporarily renamed pre-existing workflow."
+                 )
+                 self.rename_path(md["replaced_workflow"], self.path, self.fs)

-     @property
-     def has_pending(self) -> bool:
-         """Returns True if there are pending changes that are not yet committed."""
-         return any(bool(v) for k, v in self._pending.items() if k != "element_attrs")
-
-     def _get_pending_dct(self) -> Dict:
-         dct = super()._get_pending_dct()
-         dct["element_attrs"] = {}  # keys are task indices
-         dct["element_iter_attrs"] = {}  # keys are task indices
-         dct["EAR_attrs"] = {}  # keys are task indices
-         dct["parameter_data"] = 0  # keep number of pending data rather than indices
-         return dct
+     @staticmethod
+     def _get_zarr_store(path: str, fs) -> zarr.storage.Store:
+         return zarr.storage.FSStore(url=path, fs=fs)

      @classmethod
      def write_empty_workflow(
          cls,
+         app,
          template_js: Dict,
          template_components_js: Dict,
-         workflow_path: Path,
-         replaced_dir: Path,
+         wk_path: str,
+         fs,
+         fs_path: str,
+         replaced_wk: str,
          creation_info: Dict,
      ) -> None:
-         metadata = {
+         attrs = {
+             "fs_path": fs_path,
              "creation_info": creation_info,
              "template": template_js,
              "template_components": template_components_js,
              "num_added_tasks": 0,
+             "tasks": [],
              "loops": [],
              "submissions": [],
          }
-         if replaced_dir:
-             metadata["replaced_dir"] = str(replaced_dir.name)
+         if replaced_wk:
+             attrs["replaced_workflow"] = replaced_wk

-         store = zarr.DirectoryStore(workflow_path)
+         store = cls._get_zarr_store(wk_path, fs)
          root = zarr.group(store=store, overwrite=False)
-         root.attrs.update(metadata)
+         root.attrs.update(attrs)
+
+         md = root.create_group("metadata")
+
+         tasks_arr = md.create_dataset(
+             name=cls._task_arr_name,
+             shape=0,
+             dtype=object,
+             object_codec=VLenArray(int),
+         )
+
+         elems_arr = md.create_dataset(
+             name=cls._elem_arr_name,
+             shape=0,
+             dtype=object,
+             object_codec=MsgPack(),
+             chunks=1000,
+         )
+         elems_arr.attrs.update({"seq_idx": [], "src_idx": []})
+
+         elem_iters_arr = md.create_dataset(
+             name=cls._iter_arr_name,
+             shape=0,
+             dtype=object,
+             object_codec=MsgPack(),
+             chunks=1000,
+         )
+         elem_iters_arr.attrs.update(
+             {
+                 "loops": [],
+                 "schema_parameters": [],
+                 "parameter_paths": [],
+             }
+         )
+
+         EARs_arr = md.create_dataset(
+             name=cls._EAR_arr_name,
+             shape=0,
+             dtype=object,
+             object_codec=MsgPack(),
+             chunks=1,  # single-chunk rows for multiprocess writing
+         )
+         EARs_arr.attrs.update({"parameter_paths": []})

-         root.create_group(name=cls._elem_grp_name)
          parameter_data = root.create_group(name=cls._param_grp_name)
          parameter_data.create_dataset(
              name=cls._param_base_arr_name,
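The `append_items_to_ragged_array` helper added in this hunk works around `Array.append` for variable-length rows by resizing and assigning. A small standalone usage sketch (assumes zarr 2.x; the helper body is copied from the diff above):

import numpy as np
import zarr
from numcodecs import VLenArray

def append_items_to_ragged_array(arr, items):
    # resize, then assign each new row into the tail positions
    num = len(items)
    arr.resize(len(arr) + num)
    for idx, i in enumerate(items):
        arr[-(num - idx)] = i

arr = zarr.empty(0, dtype=object, object_codec=VLenArray(int))
append_items_to_ragged_array(arr, [np.array([0, 1]), np.array([2, 3, 4])])
print(len(arr), arr[1])  # 2 [2 3 4]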
@@ -195,26 +421,288 @@ class ZarrPersistentStore(PersistentStore):
          )
          parameter_data.create_group(name=cls._param_user_arr_grp_name)

-     def load_metadata(self):
-         return self._metadata or self._load_metadata()
+     def _append_tasks(self, tasks: List[ZarrStoreTask]):
+         elem_IDs_arr = self._get_tasks_arr(mode="r+")
+         elem_IDs = []
+         with self.using_resource("attrs", "update") as attrs:
+             for i_idx, i in enumerate(tasks):
+                 idx, wk_task_i, task_i = i.encode()
+                 elem_IDs.append(wk_task_i.pop("element_IDs"))
+                 wk_task_i["element_IDs_idx"] = len(elem_IDs_arr) + i_idx
+
+                 attrs["tasks"].insert(idx, wk_task_i)
+                 attrs["template"]["tasks"].insert(idx, task_i)
+                 attrs["num_added_tasks"] += 1
+
+         # tasks array rows correspond to task IDs, and we assume `tasks` have sequentially
+         # increasing IDs.
+         append_items_to_ragged_array(arr=elem_IDs_arr, items=elem_IDs)
+
+     def _append_loops(self, loops: Dict[int, Dict]):
+         with self.using_resource("attrs", action="update") as attrs:
+             for loop_idx, loop in loops.items():
+                 attrs["loops"].append(
+                     {
+                         "num_added_iterations": loop["num_added_iterations"],
+                         "iterable_parameters": loop["iterable_parameters"],
+                     }
+                 )
+                 attrs["template"]["loops"].append(loop["loop_template"])
+
+     def _append_submissions(self, subs: Dict[int, Dict]):
+         with self.using_resource("attrs", action="update") as attrs:
+             for sub_idx, sub_i in subs.items():
+                 attrs["submissions"].append(sub_i)
+
+     def _append_task_element_IDs(self, task_ID: int, elem_IDs: List[int]):
+         # I don't think there's a way to "append" to an existing array in a zarr ragged
+         # array? So we have to build a new array from existing + new.
+         arr = self._get_tasks_arr(mode="r+")
+         elem_IDs_cur = arr[task_ID]
+         elem_IDs_new = np.concatenate((elem_IDs_cur, elem_IDs))
+         arr[task_ID] = elem_IDs_new
+
+     def _append_elements(self, elems: List[ZarrStoreElement]):
+         arr = self._get_elements_arr(mode="r+")
+         attrs_orig = arr.attrs.asdict()
+         attrs = copy.deepcopy(attrs_orig)
+         arr_add = np.empty((len(elems)), dtype=object)
+         arr_add[:] = [i.encode(attrs) for i in elems]
+         arr.append(arr_add)
+         if attrs != attrs_orig:
+             arr.attrs.put(attrs)
+
+     def _append_element_sets(self, task_id: int, es_js: List[Dict]):
+         task_idx = self._get_task_id_to_idx_map()[task_id]
+         with self.using_resource("attrs", "update") as attrs:
+             attrs["template"]["tasks"][task_idx]["element_sets"].extend(es_js)
+
+     def _append_elem_iter_IDs(self, elem_ID: int, iter_IDs: List[int]):
+         arr = self._get_elements_arr(mode="r+")
+         attrs = arr.attrs.asdict()
+         elem_dat = arr[elem_ID]
+         store_elem = ZarrStoreElement.decode(elem_dat, attrs)
+         store_elem = store_elem.append_iteration_IDs(iter_IDs)
+         arr[elem_ID] = store_elem.encode(attrs)  # attrs shouldn't be mutated (TODO: test!)
+
+     def _append_elem_iters(self, iters: List[ZarrStoreElementIter]):
+         arr = self._get_iters_arr(mode="r+")
+         attrs_orig = arr.attrs.asdict()
+         attrs = copy.deepcopy(attrs_orig)
+         arr_add = np.empty((len(iters)), dtype=object)
+         arr_add[:] = [i.encode(attrs) for i in iters]
+         arr.append(arr_add)
+         if attrs != attrs_orig:
+             arr.attrs.put(attrs)
+
+     def _append_elem_iter_EAR_IDs(self, iter_ID: int, act_idx: int, EAR_IDs: List[int]):
+         arr = self._get_iters_arr(mode="r+")
+         attrs = arr.attrs.asdict()
+         iter_dat = arr[iter_ID]
+         store_iter = ZarrStoreElementIter.decode(iter_dat, attrs)
+         store_iter = store_iter.append_EAR_IDs(pend_IDs={act_idx: EAR_IDs})
+         arr[iter_ID] = store_iter.encode(attrs)  # attrs shouldn't be mutated (TODO: test!)
+
+     def _append_submission_attempts(self, sub_attempts: Dict[int, List[int]]):
+         with self.using_resource("attrs", action="update") as attrs:
+             for sub_idx, attempts_i in sub_attempts.items():
+                 attrs["submissions"][sub_idx]["submission_attempts"].extend(attempts_i)
+
+     def _update_loop_index(self, iter_ID: int, loop_idx: Dict):
+         arr = self._get_iters_arr(mode="r+")
+         attrs = arr.attrs.asdict()
+         iter_dat = arr[iter_ID]
+         store_iter = ZarrStoreElementIter.decode(iter_dat, attrs)
+         store_iter = store_iter.update_loop_idx(loop_idx)
+         arr[iter_ID] = store_iter.encode(attrs)
+
+     def _update_loop_num_iters(self, index: int, num_iters: int):
+         with self.using_resource("attrs", action="update") as attrs:
+             attrs["loops"][index]["num_added_iterations"] = num_iters
+
+     def _append_EARs(self, EARs: List[ZarrStoreEAR]):
+         arr = self._get_EARs_arr(mode="r+")
+         attrs_orig = arr.attrs.asdict()
+         attrs = copy.deepcopy(attrs_orig)
+         arr_add = np.empty((len(EARs)), dtype=object)
+         arr_add[:] = [i.encode(attrs, self.ts_fmt) for i in EARs]
+         arr.append(arr_add)
+
+         if attrs != attrs_orig:
+             arr.attrs.put(attrs)
+
+     def _update_EAR_submission_index(self, EAR_id: int, sub_idx: int):
+         arr = self._get_EARs_arr(mode="r+")
+         attrs_orig = arr.attrs.asdict()
+         attrs = copy.deepcopy(attrs_orig)
+
+         EAR_i = self._get_persistent_EARs([EAR_id])[EAR_id]
+         EAR_i = EAR_i.update(submission_idx=sub_idx)
+         arr[EAR_id] = EAR_i.encode(attrs, self.ts_fmt)
+
+         if attrs != attrs_orig:
+             arr.attrs.put(attrs)
+
+     def _update_EAR_start(self, EAR_id: int, s_time: datetime, s_snap: Dict):
+         arr = self._get_EARs_arr(mode="r+")
+         attrs_orig = arr.attrs.asdict()
+         attrs = copy.deepcopy(attrs_orig)
+
+         EAR_i = self._get_persistent_EARs([EAR_id])[EAR_id]
+         EAR_i = EAR_i.update(
+             start_time=s_time,
+             snapshot_start=s_snap,
+         )
+         arr[EAR_id] = EAR_i.encode(attrs, self.ts_fmt)
+
+         if attrs != attrs_orig:
+             arr.attrs.put(attrs)
+
+     def _update_EAR_end(
+         self, EAR_id: int, e_time: datetime, e_snap: Dict, ext_code: int, success: bool
+     ):
+         arr = self._get_EARs_arr(mode="r+")
+         attrs_orig = arr.attrs.asdict()
+         attrs = copy.deepcopy(attrs_orig)
+
+         EAR_i = self._get_persistent_EARs([EAR_id])[EAR_id]
+         EAR_i = EAR_i.update(
+             end_time=e_time,
+             snapshot_end=e_snap,
+             exit_code=ext_code,
+             success=success,
+         )
+         arr[EAR_id] = EAR_i.encode(attrs, self.ts_fmt)
+
+         if attrs != attrs_orig:
+             arr.attrs.put(attrs)
+
+     def _update_EAR_skip(self, EAR_id: int):
+         arr = self._get_EARs_arr(mode="r+")
+         attrs_orig = arr.attrs.asdict()
+         attrs = copy.deepcopy(attrs_orig)
+
+         EAR_i = self._get_persistent_EARs([EAR_id])[EAR_id]
+         EAR_i = EAR_i.update(skip=True)
+         arr[EAR_id] = EAR_i.encode(attrs, self.ts_fmt)
+
+         if attrs != attrs_orig:
+             arr.attrs.put(attrs)
+
+     def _update_jobscript_version_info(self, vers_info: Dict):
+         with self.using_resource("attrs", action="update") as attrs:
+             for sub_idx, js_vers_info in vers_info.items():
+                 for js_idx, vers_info_i in js_vers_info.items():
+                     attrs["submissions"][sub_idx]["jobscripts"][js_idx][
+                         "version_info"
+                     ] = vers_info_i
+
+     def _update_jobscript_submit_time(self, sub_times: Dict):
+         with self.using_resource("attrs", action="update") as attrs:
+             for sub_idx, js_sub_times in sub_times.items():
+                 for js_idx, sub_time_i in js_sub_times.items():
+                     sub_time_fmt = sub_time_i.strftime(self.ts_fmt)
+                     attrs["submissions"][sub_idx]["jobscripts"][js_idx][
+                         "submit_time"
+                     ] = sub_time_fmt
+
+     def _update_jobscript_job_ID(self, job_IDs: Dict):
+         with self.using_resource("attrs", action="update") as attrs:
+             for sub_idx, js_job_IDs in job_IDs.items():
+                 for js_idx, job_ID_i in js_job_IDs.items():
+                     attrs["submissions"][sub_idx]["jobscripts"][js_idx][
+                         "scheduler_job_ID"
+                     ] = job_ID_i
+
+     def _append_parameters(self, params: List[ZarrStoreParameter]):
+         """Add new persistent parameters."""
+         base_arr = self._get_parameter_base_array(mode="r+")
+         src_arr = self._get_parameter_sources_array(mode="r+")
+         for param_i in params:
+             dat_i = param_i.encode(
+                 root_group=self._get_parameter_user_array_group(mode="r+"),
+                 arr_path=self._param_data_arr_grp_name(param_i.id_),
+             )
+             base_arr.append([dat_i])
+             src_arr.append([dict(sorted(param_i.source.items()))])

-     def _load_metadata(self):
-         return self._get_root_group(mode="r").attrs.asdict()
+     def _set_parameter_value(self, param_id: int, value: Any, is_file: bool):
+         """Set an unset persistent parameter."""

-     @contextmanager
-     def cached_load(self) -> Iterator[Dict]:
-         """Context manager to cache the root attributes (i.e. metadata)."""
-         if self._metadata:
-             yield
+         # the `decode` call in `_get_persistent_parameters` should be quick:
+         param_i = self._get_persistent_parameters([param_id])[param_id]
+         if is_file:
+             param_i = param_i.set_file(value)
          else:
-             try:
-                 self._metadata = self._load_metadata()
-                 yield
-             finally:
-                 self._metadata = None
+             param_i = param_i.set_data(value)
+         dat_i = param_i.encode(
+             root_group=self._get_parameter_user_array_group(mode="r+"),
+             arr_path=self._param_data_arr_grp_name(param_i.id_),
+         )
+
+         # no need to update sources array:
+         base_arr = self._get_parameter_base_array(mode="r+")
+         base_arr[param_id] = dat_i
+
+     def _update_parameter_source(self, param_id: int, src: Dict):
+         """Update the source of a persistent parameter."""
+
+         param_i = self._get_persistent_parameters([param_id])[param_id]
+         param_i = param_i.update_source(src)
+
+         # no need to update base array:
+         src_arr = self._get_parameter_sources_array(mode="r+")
+         src_arr[param_id] = param_i.source
+
+     def _update_template_components(self, tc: Dict):
+         with self.using_resource("attrs", "update") as md:
+             md["template_components"] = tc
+
+     def _get_num_persistent_tasks(self) -> int:
+         """Get the number of persistent tasks."""
+         return len(self._get_tasks_arr())
+
+     def _get_num_persistent_loops(self) -> int:
+         """Get the number of persistent loops."""
+         with self.using_resource("attrs", "read") as attrs:
+             return len(attrs["loops"])
+
+     def _get_num_persistent_submissions(self) -> int:
+         """Get the number of persistent submissions."""
+         with self.using_resource("attrs", "read") as attrs:
+             return len(attrs["submissions"])
+
+     def _get_num_persistent_elements(self) -> int:
+         """Get the number of persistent elements."""
+         return len(self._get_elements_arr())
+
+     def _get_num_persistent_elem_iters(self) -> int:
+         """Get the number of persistent element iterations."""
+         return len(self._get_iters_arr())
+
+     def _get_num_persistent_EARs(self) -> int:
+         """Get the number of persistent EARs."""
+         return len(self._get_EARs_arr())
+
+     def _get_num_persistent_parameters(self):
+         return len(self._get_parameter_base_array())
+
+     def _get_num_persistent_added_tasks(self):
+         with self.using_resource("attrs", "read") as attrs:
+             return attrs["num_added_tasks"]
+
+     @property
+     def zarr_store(self) -> zarr.storage.Store:
+         if self._zarr_store is None:
+             self._zarr_store = self._get_zarr_store(self.path, self.fs)
+         return self._zarr_store

      def _get_root_group(self, mode: str = "r") -> zarr.Group:
-         return zarr.open(self.workflow.path, mode=mode)
+         return zarr.open(self.zarr_store, mode=mode)

      def _get_parameter_group(self, mode: str = "r") -> zarr.Group:
          return self._get_root_group(mode=mode).get(self._param_grp_name)
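Every `_append_*`/`_update_*` method in this hunk follows one pattern: read the array's attributes (the shared lookup lists), decode a row, mutate it, re-encode it, and rewrite the attributes only if the lookups grew. A minimal round-trip sketch of that interning scheme, with a stand-in for `ensure_in` (assumed here to append a missing key and return its index):

def ensure_in(item, lst):
    # stand-in for hpcflow.sdk.core.utils.ensure_in (assumed behaviour)
    if item not in lst:
        lst.append(item)
    return lst.index(item)

attrs = {"parameter_paths": []}
data_idx = {"inputs.p1": 3, "outputs.p2": 7}

# encode: intern string keys as indices into the shared lookup list
enc = [[ensure_in(k, attrs["parameter_paths"]), v] for k, v in data_idx.items()]

# decode: map the indices back through the same lookup list
dec = {attrs["parameter_paths"][k]: v for k, v in enc}
assert dec == data_idx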
@@ -237,793 +725,316 @@ class ZarrPersistentStore(PersistentStore):
237
725
  self._param_data_arr_grp_name(parameter_idx)
238
726
  )
239
727
 
240
- def _get_element_group(self, mode: str = "r") -> zarr.Group:
241
- return self._get_root_group(mode=mode).get(self._elem_grp_name)
728
+ def _get_metadata_group(self, mode: str = "r") -> zarr.Group:
729
+ return self._get_root_group(mode=mode).get("metadata")
242
730
 
243
- def _get_task_group_path(self, insert_ID: int) -> str:
244
- return self._task_grp_name(insert_ID)
731
+ def _get_tasks_arr(self, mode: str = "r") -> zarr.Array:
732
+ return self._get_metadata_group(mode=mode).get(self._task_arr_name)
245
733
 
246
- def _get_task_group(self, insert_ID: int, mode: str = "r") -> zarr.Group:
247
- return self._get_element_group(mode=mode).get(self._task_grp_name(insert_ID))
734
+ def _get_elements_arr(self, mode: str = "r") -> zarr.Array:
735
+ return self._get_metadata_group(mode=mode).get(self._elem_arr_name)
248
736
 
249
- def _get_task_elements_array(self, insert_ID: int, mode: str = "r") -> zarr.Array:
250
- return self._get_task_group(insert_ID, mode=mode).get(self._task_elem_arr_name)
737
+ def _get_iters_arr(self, mode: str = "r") -> zarr.Array:
738
+ return self._get_metadata_group(mode=mode).get(self._iter_arr_name)
251
739
 
252
- def _get_task_elem_iters_array(self, insert_ID: int, mode: str = "r") -> zarr.Array:
253
- return self._get_task_group(insert_ID, mode=mode).get(
254
- self._task_elem_iter_arr_name
255
- )
740
+ def _get_EARs_arr(self, mode: str = "r") -> zarr.Array:
741
+ return self._get_metadata_group(mode=mode).get(self._EAR_arr_name)
256
742
 
257
- def _get_task_EAR_times_array(self, insert_ID: int, mode: str = "r") -> zarr.Array:
258
- return self._get_task_group(insert_ID, mode=mode).get(
259
- self._task_EAR_times_arr_name
743
+ @classmethod
744
+ def make_test_store_from_spec(
745
+ cls,
746
+ spec,
747
+ dir=None,
748
+ path="test_store",
749
+ overwrite=False,
750
+ ):
751
+ """Generate an store for testing purposes."""
752
+
753
+ path = Path(dir or "", path)
754
+ store = zarr.DirectoryStore(path)
755
+ root = zarr.group(store=store, overwrite=overwrite)
756
+ md = root.create_group("metadata")
757
+
758
+ tasks_arr = md.create_dataset(
759
+ name=cls._task_arr_name,
760
+ shape=0,
761
+ dtype=object,
762
+ object_codec=VLenArray(int),
260
763
  )
261
764
 
262
- def _get_task_element_attrs(self, task_idx: int, task_insert_ID: int) -> Dict:
263
- if task_idx in self._pending["element_attrs"]:
264
- attrs = self._pending["element_attrs"][task_idx]
265
- elif task_idx in self._pending["tasks"]:
266
- # the task is new and not yet committed
267
- attrs = self._get_element_array_empty_attrs()
268
- else:
269
- attrs = self._get_task_elements_array(task_insert_ID, mode="r").attrs
270
- attrs = attrs.asdict()
271
- return attrs
272
-
273
- def _get_task_element_iter_attrs(self, task_idx: int, task_insert_ID: int) -> Dict:
274
- if task_idx in self._pending["element_iter_attrs"]:
275
- attrs = self._pending["element_iter_attrs"][task_idx]
276
- elif task_idx in self._pending["tasks"]:
277
- # the task is new and not yet committed
278
- attrs = self._get_element_iter_array_empty_attrs()
279
- else:
280
- attrs = self._get_task_elem_iters_array(task_insert_ID, mode="r").attrs
281
- attrs = attrs.asdict()
282
- return attrs
283
-
284
- def add_elements(
285
- self,
286
- task_idx: int,
287
- task_insert_ID: int,
288
- elements: List[Dict],
289
- element_iterations: List[Dict],
290
- ) -> None:
291
- attrs_original = self._get_task_element_attrs(task_idx, task_insert_ID)
292
- elements, attrs = self._compress_elements(elements, attrs_original)
293
- if attrs != attrs_original:
294
- if task_idx not in self._pending["element_attrs"]:
295
- self._pending["element_attrs"][task_idx] = {}
296
- self._pending["element_attrs"][task_idx].update(attrs)
297
-
298
- iter_attrs_original = self._get_task_element_iter_attrs(task_idx, task_insert_ID)
299
- element_iters, iter_attrs = self._compress_element_iters(
300
- element_iterations, iter_attrs_original
765
+ elems_arr = md.create_dataset(
766
+ name=cls._elem_arr_name,
767
+ shape=0,
768
+ dtype=object,
769
+ object_codec=MsgPack(),
770
+ chunks=1000,
301
771
  )
302
- if iter_attrs != iter_attrs_original:
303
- if task_idx not in self._pending["element_iter_attrs"]:
304
- self._pending["element_iter_attrs"][task_idx] = {}
305
- self._pending["element_iter_attrs"][task_idx].update(iter_attrs)
306
-
307
- return super().add_elements(task_idx, task_insert_ID, elements, element_iters)
772
+ elems_arr.attrs.update({"seq_idx": [], "src_idx": []})
308
773
 
309
- def add_element_iterations(
310
- self,
311
- task_idx: int,
312
- task_insert_ID: int,
313
- element_iterations: List[Dict],
314
- element_iters_idx: Dict[int, List[int]],
315
- ) -> None:
316
- iter_attrs_original = self._get_task_element_iter_attrs(task_idx, task_insert_ID)
317
- element_iters, iter_attrs = self._compress_element_iters(
318
- element_iterations, iter_attrs_original
774
+ elem_iters_arr = md.create_dataset(
775
+ name=cls._iter_arr_name,
776
+ shape=0,
777
+ dtype=object,
778
+ object_codec=MsgPack(),
779
+ chunks=1000,
319
780
  )
320
- if iter_attrs != iter_attrs_original:
321
- if task_idx not in self._pending["element_iter_attrs"]:
322
- self._pending["element_iter_attrs"][task_idx] = {}
323
- self._pending["element_iter_attrs"][task_idx].update(iter_attrs)
324
-
325
- return super().add_element_iterations(
326
- task_idx,
327
- task_insert_ID,
328
- element_iters,
329
- element_iters_idx,
781
+ elem_iters_arr.attrs.update(
782
+ {
783
+ "loops": [],
784
+ "schema_parameters": [],
785
+ "parameter_paths": [],
786
+ }
330
787
  )
331
788
 
332
- def add_EARs(
333
- self,
334
- task_idx: int,
335
- task_insert_ID: int,
336
- element_iter_idx: int,
337
- EARs: Dict,
338
- param_src_updates: Dict,
339
- ) -> None:
340
- iter_attrs_original = self._get_task_element_iter_attrs(task_idx, task_insert_ID)
341
- EARs, iter_attrs = self._compress_EARs(EARs, iter_attrs_original)
342
- if iter_attrs != iter_attrs_original:
343
- if task_idx not in self._pending["element_iter_attrs"]:
344
- self._pending["element_iter_attrs"][task_idx] = {}
345
- self._pending["element_iter_attrs"][task_idx].update(iter_attrs)
346
-
347
- key = (task_idx, task_insert_ID, element_iter_idx)
348
- if key not in self._pending["EARs"]:
349
- self._pending["EARs"][key] = []
350
- self._pending["EARs"][key].extend(EARs)
351
- self._pending["parameter_source_updates"].update(param_src_updates)
352
- self.save()
353
-
354
- def _compress_elements(self, elements: List, attrs: Dict) -> Tuple[List, Dict]:
355
- """Split element data into lists of integers and lookup lists to effectively
356
- compress the data.
357
-
358
- See also: `_decompress_elements` for the inverse operation.
359
-
360
- """
361
-
362
- attrs = copy.deepcopy(attrs)
363
- compressed = []
364
- for elem in elements:
365
- seq_idx = [
366
- [ensure_in(k, attrs["seq_idx"]), v] for k, v in elem["seq_idx"].items()
367
- ]
368
- src_idx = [
369
- [ensure_in(k, attrs["src_idx"]), v] for k, v in elem["src_idx"].items()
370
- ]
371
- compressed.append(
372
- [
373
- elem["iterations_idx"],
374
- elem["es_idx"],
375
- seq_idx,
376
- src_idx,
377
- ]
378
- )
379
- return compressed, attrs
380
-
381
- def _compress_element_iters(
382
- self, element_iters: List, attrs: Dict
383
- ) -> Tuple[List, Dict]:
384
- """Split element iteration data into lists of integers and lookup lists to
385
- effectively compress the data.
386
-
387
- See also: `_decompress_element_iters` for the inverse operation.
388
-
389
- """
390
-
391
- attrs = copy.deepcopy(attrs)
392
- compressed = []
393
- for iter_i in element_iters:
394
- loop_idx = [
395
- [ensure_in(k, attrs["loops"]), v] for k, v in iter_i["loop_idx"].items()
396
- ]
397
- schema_params = [
398
- ensure_in(k, attrs["schema_parameters"])
399
- for k in iter_i["schema_parameters"]
400
- ]
401
- data_idx = [
402
- [ensure_in(dk, attrs["parameter_paths"]), dv]
403
- for dk, dv in iter_i["data_idx"].items()
404
- ]
405
-
406
- EARs, attrs = self._compress_EARs(iter_i["actions"], attrs)
407
- compact = [
408
- iter_i["global_idx"],
409
- data_idx,
410
- int(iter_i["EARs_initialised"]),
411
- schema_params,
412
- loop_idx,
413
- EARs,
414
- ]
415
- compressed.append(compact)
416
- return compressed, attrs
417
-
418
- def _compress_EARs(self, EARs: Dict, attrs: Dict) -> List:
419
- """Split EAR data into lists of integers and lookup lists to effectively compress
420
- the data.
421
-
422
- See also: `_decompress_EARs` for the inverse operation.
789
+ EARs_arr = md.create_dataset(
790
+ name=cls._EAR_arr_name,
791
+ shape=0,
792
+ dtype=object,
793
+ object_codec=MsgPack(),
794
+ chunks=1000,
795
+ )
796
+ EARs_arr.attrs.update({"parameter_paths": []})
797
+
798
+ tasks, elems, elem_iters, EARs = super().prepare_test_store_from_spec(spec)
799
+
800
+ path = Path(path).resolve()
801
+ tasks = [ZarrStoreTask(**i).encode() for i in tasks]
802
+ elements = [ZarrStoreElement(**i).encode(elems_arr.attrs.asdict()) for i in elems]
803
+ elem_iters = [
804
+ ZarrStoreElementIter(**i).encode(elem_iters_arr.attrs.asdict())
805
+ for i in elem_iters
806
+ ]
807
+ EARs = [ZarrStoreEAR(**i).encode(EARs_arr.attrs.asdict()) for i in EARs]
808
+
809
+ append_items_to_ragged_array(tasks_arr, tasks)
810
+
811
+ elem_arr_add = np.empty((len(elements)), dtype=object)
812
+ elem_arr_add[:] = elements
813
+ elems_arr.append(elem_arr_add)
814
+
815
+ iter_arr_add = np.empty((len(elem_iters)), dtype=object)
816
+ iter_arr_add[:] = elem_iters
817
+ elem_iters_arr.append(iter_arr_add)
818
+
819
+ EAR_arr_add = np.empty((len(EARs)), dtype=object)
820
+ EAR_arr_add[:] = EARs
821
+ EARs_arr.append(EAR_arr_add)
822
+
823
+ return cls(path)
824
+
825
+ def _get_persistent_template_components(self):
826
+ with self.using_resource("attrs", "read") as attrs:
827
+ return attrs["template_components"]
828
+
829
+ def _get_persistent_template(self):
830
+ with self.using_resource("attrs", "read") as attrs:
831
+ return attrs["template"]
832
+
833
+ def _get_persistent_tasks(
834
+ self, id_lst: Optional[Iterable[int]] = None
835
+ ) -> Dict[int, ZarrStoreTask]:
836
+ with self.using_resource("attrs", action="read") as attrs:
837
+ task_dat = {}
838
+ elem_IDs = []
839
+ for idx, i in enumerate(attrs["tasks"]):
840
+ i = copy.deepcopy(i)
841
+ elem_IDs.append(i.pop("element_IDs_idx"))
842
+ if id_lst is None or i["id_"] in id_lst:
843
+ task_dat[i["id_"]] = {**i, "index": idx}
844
+ if task_dat:
845
+ try:
846
+ elem_IDs_arr_dat = self._get_tasks_arr().get_coordinate_selection(
847
+ elem_IDs
848
+ )
849
+ except zarr.errors.BoundsCheckError:
850
+ raise MissingStoreTaskError(elem_IDs) from None # TODO: not an ID list
423
851
 
424
- """
425
- attrs = copy.deepcopy(attrs)
426
- compressed = []
427
- for act_idx, runs in EARs.items():
428
- act_run_i = [
429
- act_idx,
430
- [
431
- [
432
- run["index"], # TODO: is this needed?
433
- -1
434
- if run["metadata"]["submission_idx"] is None
435
- else run["metadata"]["submission_idx"],
436
- -1
437
- if run["metadata"]["success"] is None
438
- else int(run["metadata"]["success"]),
439
- [
440
- [ensure_in(dk, attrs["parameter_paths"]), dv]
441
- for dk, dv in run["data_idx"].items()
442
- ],
443
- ]
444
- for run in runs
445
- ],
446
- ]
447
- compressed.append(act_run_i)
448
- return compressed, attrs
449
-
450
- def _decompress_elements(self, elements: List, attrs: Dict) -> List:
451
- out = []
452
- for elem in elements:
453
- elem_i = {
454
- "iterations_idx": elem[0],
455
- "es_idx": elem[1],
456
- "seq_idx": {attrs["seq_idx"][k]: v for (k, v) in elem[2]},
457
- "src_idx": {attrs["src_idx"][k]: v for (k, v) in elem[3]},
852
+ return {
853
+ id_: ZarrStoreTask.decode({**i, "element_IDs": elem_IDs_arr_dat[id_]})
854
+ for idx, (id_, i) in enumerate(task_dat.items())
458
855
  }
459
- out.append(elem_i)
460
- return out
461
-
462
- def _decompress_element_iters(self, element_iters: List, attrs: Dict) -> List:
463
- out = []
464
- for iter_i in element_iters:
465
- iter_i_decomp = {
466
- "global_idx": iter_i[0],
467
- "data_idx": {attrs["parameter_paths"][k]: v for (k, v) in iter_i[1]},
468
- "EARs_initialised": bool(iter_i[2]),
469
- "schema_parameters": [attrs["schema_parameters"][k] for k in iter_i[3]],
470
- "loop_idx": {attrs["loops"][k]: v for (k, v) in iter_i[4]},
471
- "actions": self._decompress_EARs(iter_i[5], attrs),
856
+ else:
857
+ return {}
858
+
859
+ def _get_persistent_loops(self, id_lst: Optional[Iterable[int]] = None):
860
+ with self.using_resource("attrs", "read") as attrs:
861
+ loop_dat = {
862
+ idx: i
863
+ for idx, i in enumerate(attrs["loops"])
864
+ if id_lst is None or idx in id_lst
472
865
  }
473
- out.append(iter_i_decomp)
474
- return out
866
+ return loop_dat
475
867
 
476
- def _decompress_EARs(self, EARs: List, attrs: Dict) -> List:
477
- out = {
478
- act_idx: [
868
+ def _get_persistent_submissions(self, id_lst: Optional[Iterable[int]] = None):
869
+ with self.using_resource("attrs", "read") as attrs:
870
+ subs_dat = copy.deepcopy(
479
871
  {
480
- "index": run[0],
481
- "metadata": {
482
- "submission_idx": None if run[1] == -1 else run[1],
483
- "success": None if run[2] == -1 else bool(run[2]),
484
- },
485
- "data_idx": {attrs["parameter_paths"][k]: v for (k, v) in run[3]},
872
+ idx: i
873
+ for idx, i in enumerate(attrs["submissions"])
874
+ if id_lst is None or idx in id_lst
486
875
  }
487
- for run in runs
488
- ]
489
- for (act_idx, runs) in EARs
490
- }
491
- return out
492
-
493
- @staticmethod
494
- def _get_element_array_empty_attrs() -> Dict:
495
- return {"seq_idx": [], "src_idx": []}
496
-
497
- @staticmethod
498
- def _get_element_iter_array_empty_attrs() -> Dict:
499
- return {
500
- "loops": [],
501
- "schema_parameters": [],
502
- "parameter_paths": [],
503
- }
504
-
505
- def _get_zarr_store(self):
506
- return self._get_root_group().store
507
-
508
- def _remove_pending_parameter_data(self) -> None:
509
- """Delete pending parameter data from disk."""
510
- base = self._get_parameter_base_array(mode="r+")
511
- sources = self._get_parameter_sources_array(mode="r+")
512
- for param_idx in range(self._pending["parameter_data"], 0, -1):
513
- grp = self._get_parameter_data_array_group(param_idx - 1)
514
- if grp:
515
- zarr.storage.rmdir(store=self._get_zarr_store(), path=grp.path)
516
- base.resize(base.size - self._pending["parameter_data"])
517
- sources.resize(sources.size - self._pending["parameter_data"])
518
-
519
- def reject_pending(self) -> None:
520
- if self._pending["parameter_data"]:
521
- self._remove_pending_parameter_data()
522
- super().reject_pending()
523
-
524
- def commit_pending(self) -> None:
525
- md = self.load_metadata()
526
-
527
- # merge new tasks:
528
- for task_idx, task_js in self._pending["template_tasks"].items():
529
- md["template"]["tasks"].insert(task_idx, task_js) # TODO should be index?
530
-
531
- # write new workflow tasks to disk:
532
- for task_idx, _ in self._pending["tasks"].items():
533
- insert_ID = self._pending["template_tasks"][task_idx]["insert_ID"]
534
- task_group = self._get_element_group(mode="r+").create_group(
535
- self._get_task_group_path(insert_ID)
536
- )
537
- element_arr = task_group.create_dataset(
538
- name=self._task_elem_arr_name,
539
- shape=0,
540
- dtype=object,
541
- object_codec=MsgPack(),
542
- chunks=1000, # TODO: check this is a sensible size with many elements
543
876
  )
544
- element_arr.attrs.update(self._get_element_array_empty_attrs())
545
- element_iters_arr = task_group.create_dataset(
546
- name=self._task_elem_iter_arr_name,
547
- shape=0,
548
- dtype=object,
549
- object_codec=MsgPack(),
550
- chunks=1000, # TODO: check this is a sensible size with many elements
551
- )
552
- element_iters_arr.attrs.update(self._get_element_iter_array_empty_attrs())
553
- with warnings.catch_warnings():
554
- warnings.simplefilter("ignore", DeprecationWarning)
555
- # zarr (2.14.2, at least) compares the fill value to zero, which, due to
556
- # this numpy bug https://github.com/numpy/numpy/issues/13548, issues a
557
- # DeprecationWarning. This bug is fixed in numpy 1.25
558
- # (https://github.com/numpy/numpy/pull/22707), which has a minimum python
559
- # version of 3.9. So for now, we will suppress it.
560
- EAR_times_arr = task_group.create_dataset(
561
- name=self._task_EAR_times_arr_name,
562
- shape=(0, 2),
563
- dtype="M8[us]", # microsecond resolution
564
- fill_value=np.datetime64("NaT"),
565
- chunks=(1, None), # single-chunk for multiprocess writing
566
- )
877
+ # cast jobscript submit-times and jobscript `task_elements` keys:
878
+ for sub_idx, sub in subs_dat.items():
879
+ for js_idx, js in enumerate(sub["jobscripts"]):
880
+ if js["submit_time"]:
881
+ subs_dat[sub_idx]["jobscripts"][js_idx][
882
+ "submit_time"
883
+ ] = datetime.strptime(js["submit_time"], self.ts_fmt)
884
+
885
+ for key in list(js["task_elements"].keys()):
886
+ subs_dat[sub_idx]["jobscripts"][js_idx]["task_elements"][
887
+ int(key)
888
+ ] = subs_dat[sub_idx]["jobscripts"][js_idx]["task_elements"].pop(
889
+ key
890
+ )
891
+
892
+ return subs_dat
893
+
894
+ def _get_persistent_elements(
895
+ self, id_lst: Iterable[int]
896
+ ) -> Dict[int, ZarrStoreElement]:
897
+ arr = self._get_elements_arr()
898
+ attrs = arr.attrs.asdict()
899
+ try:
900
+ elem_arr_dat = arr.get_coordinate_selection(list(id_lst))
901
+ except zarr.errors.BoundsCheckError:
902
+ raise MissingStoreElementError(id_lst) from None
903
+ elem_dat = dict(zip(id_lst, elem_arr_dat))
904
+ return {k: ZarrStoreElement.decode(v, attrs) for k, v in elem_dat.items()}
905
+
906
+ def _get_persistent_element_iters(
907
+ self, id_lst: Iterable[int]
908
+ ) -> Dict[int, StoreElementIter]:
909
+ arr = self._get_iters_arr()
910
+ attrs = arr.attrs.asdict()
911
+ try:
912
+ iter_arr_dat = arr.get_coordinate_selection(list(id_lst))
913
+ except zarr.errors.BoundsCheckError:
914
+ raise MissingStoreElementIterationError(id_lst) from None
915
+ iter_dat = dict(zip(id_lst, iter_arr_dat))
916
+ return {k: ZarrStoreElementIter.decode(v, attrs) for k, v in iter_dat.items()}
567
917
 
568
- md["num_added_tasks"] += 1
569
-
570
- # merge new template components:
571
- self._merge_pending_template_components(md["template_components"])
572
-
573
- # merge new element sets:
574
- for task_idx, es_js in self._pending["element_sets"].items():
575
- md["template"]["tasks"][task_idx]["element_sets"].extend(es_js)
576
-
577
- # write new elements to disk:
578
- for (task_idx, insert_ID), elements in self._pending["elements"].items():
579
- elem_arr = self._get_task_elements_array(insert_ID, mode="r+")
580
- elem_arr_add = np.empty((len(elements)), dtype=object)
581
- elem_arr_add[:] = elements
582
- elem_arr.append(elem_arr_add)
583
- if task_idx in self._pending["element_attrs"]:
584
- elem_arr.attrs.put(self._pending["element_attrs"][task_idx])
585
-
586
- for (_, insert_ID), iters_idx in self._pending["element_iterations_idx"].items():
587
- elem_arr = self._get_task_elements_array(insert_ID, mode="r+")
588
- for elem_idx, iters_idx_i in iters_idx.items():
589
- elem_dat = elem_arr[elem_idx]
590
- elem_dat[0] += iters_idx_i
591
- elem_arr[elem_idx] = elem_dat
592
-
593
- # commit new element iterations:
594
- for (task_idx, insert_ID), element_iters in self._pending[
595
- "element_iterations"
596
- ].items():
597
- elem_iter_arr = self._get_task_elem_iters_array(insert_ID, mode="r+")
598
- elem_iter_arr_add = np.empty(len(element_iters), dtype=object)
599
- elem_iter_arr_add[:] = element_iters
600
- elem_iter_arr.append(elem_iter_arr_add)
601
- if task_idx in self._pending["element_iter_attrs"]:
602
- elem_iter_arr.attrs.put(self._pending["element_iter_attrs"][task_idx])
603
-
604
- # commit new element iteration loop indices:
605
- for (_, insert_ID, iters_idx_i), loop_idx_i in self._pending["loop_idx"].items():
606
- elem_iter_arr = self._get_task_elem_iters_array(insert_ID, mode="r+")
607
- iter_dat = elem_iter_arr[iters_idx_i]
608
- iter_dat[4].extend(loop_idx_i)
609
- elem_iter_arr[iters_idx_i] = iter_dat
610
-
611
- # commit new element iteration EARs:
612
- for (_, insert_ID, iters_idx_i), actions_i in self._pending["EARs"].items():
613
- elem_iter_arr = self._get_task_elem_iters_array(insert_ID, mode="r+")
614
- iter_dat = elem_iter_arr[iters_idx_i]
615
- iter_dat[5].extend(actions_i)
616
- iter_dat[2] = int(True) # EARs_initialised
617
- elem_iter_arr[iters_idx_i] = iter_dat
618
-
619
- EAR_times_arr = self._get_task_EAR_times_array(insert_ID, mode="r+")
620
- new_shape = (EAR_times_arr.shape[0] + len(actions_i), EAR_times_arr.shape[1])
621
- EAR_times_arr.resize(new_shape)
622
-
623
- # commit new EAR submission indices:
624
- for (ins_ID, it_idx, act_idx, rn_idx), sub_idx in self._pending[
625
- "EAR_submission_idx"
626
- ].items():
627
- elem_iter_arr = self._get_task_elem_iters_array(ins_ID, mode="r+")
628
- iter_dat = elem_iter_arr[it_idx]
629
- iter_dat[5][act_idx][1][rn_idx][1] = sub_idx
630
- elem_iter_arr[it_idx] = iter_dat
631
-
632
- # commit new EAR start times:
633
- for (ins_ID, it_idx, act_idx, rn_idx), start in self._pending[
634
- "EAR_start_times"
635
- ].items():
636
- elem_iter_arr = self._get_task_elem_iters_array(ins_ID, mode="r+")
637
- iter_dat = elem_iter_arr[it_idx]
638
- for act_idx_i, runs in iter_dat[5]:
639
- if act_idx_i == act_idx:
640
- EAR_idx = runs[rn_idx][0]
641
- EAR_times_arr = self._get_task_EAR_times_array(ins_ID, mode="r+")
642
- EAR_times_arr[EAR_idx, 0] = start
643
-
644
- # commit new EAR end times:
645
- for (ins_ID, it_idx, act_idx, rn_idx), end in self._pending[
646
- "EAR_end_times"
647
- ].items():
648
- elem_iter_arr = self._get_task_elem_iters_array(ins_ID, mode="r+")
649
- iter_dat = elem_iter_arr[it_idx]
650
- for act_idx_i, runs in iter_dat[5]:
651
- if act_idx_i == act_idx:
652
- EAR_idx = runs[rn_idx][0]
653
- EAR_times_arr = self._get_task_EAR_times_array(ins_ID, mode="r+")
654
- EAR_times_arr[EAR_idx, 1] = end
655
-
656
- # commit new loops:
657
- md["template"]["loops"].extend(self._pending["template_loops"])
658
-
659
- # commit new workflow loops:
660
- md["loops"].extend(self._pending["loops"])
661
-
662
- for loop_idx, num_added_iters in self._pending["loops_added_iters"].items():
663
- md["loops"][loop_idx]["num_added_iterations"] = num_added_iters
664
-
665
- # commit new submissions:
666
- md["submissions"].extend(self._pending["submissions"])
667
-
668
- # commit new submission attempts:
669
- for sub_idx, attempts_i in self._pending["submission_attempts"].items():
670
- md["submissions"][sub_idx]["submission_attempts"].extend(attempts_i)
671
-
672
- # commit new jobscript version info:
673
- for sub_idx, js_vers_info in self._pending["jobscript_version_info"].items():
674
- for js_idx, vers_info in js_vers_info.items():
675
- md["submissions"][sub_idx]["jobscripts"][js_idx][
676
- "version_info"
677
- ] = vers_info
678
-
679
- # commit new jobscript job IDs:
680
- for sub_idx, job_IDs in self._pending["jobscript_job_IDs"].items():
681
- for js_idx, job_ID in job_IDs.items():
682
- md["submissions"][sub_idx]["jobscripts"][js_idx][
683
- "scheduler_job_ID"
684
- ] = job_ID
685
-
686
- # commit new jobscript submit times:
687
- for sub_idx, js_submit_times in self._pending["jobscript_submit_times"].items():
688
- for js_idx, submit_time in js_submit_times.items():
689
- md["submissions"][sub_idx]["jobscripts"][js_idx][
690
- "submit_time"
691
- ] = submit_time.strftime(self.ts_fmt)
692
-
693
- # note: parameter sources are committed immediately with parameter data, so there
694
- # is no need to add elements to the parameter sources array.
695
-
696
- sources = self._get_parameter_sources_array(mode="r+")
697
- for param_idx, src_update in self._pending["parameter_source_updates"].items():
698
- src = sources[param_idx]
699
- src.update(src_update)
700
- src = dict(sorted(src.items()))
701
- sources[param_idx] = src
702
-
703
- if self._pending["remove_replaced_dir_record"]:
704
- del md["replaced_dir"]
705
-
706
- # TODO: maybe clear pending keys individually, so if there is an error we can
707
- # retry/continue with committing?
708
-
709
- # commit updated metadata:
710
- self._get_root_group(mode="r+").attrs.put(md)
711
- self.clear_pending()
712
-
713
- def _get_persistent_template_components(self) -> Dict:
714
- return self.load_metadata()["template_components"]
715
-
716
- def get_template(self) -> Dict:
717
- # No need to consider pending; this is called once per Workflow object
718
- return self.load_metadata()["template"]
719
-
720
- def get_loops(self) -> List[Dict]:
721
- # No need to consider pending; this is called once per Workflow object
722
- return self.load_metadata()["loops"]
723
-
724
- def get_submissions(self) -> List[Dict]:
725
- # No need to consider pending; this is called once per Workflow object
726
- subs = copy.deepcopy(self.load_metadata()["submissions"])
727
-
728
- # cast jobscript submit-times and jobscript `task_elements` keys:
729
- for sub_idx, sub in enumerate(subs):
730
- for js_idx, js in enumerate(sub["jobscripts"]):
731
- if js["submit_time"]:
732
- subs[sub_idx]["jobscripts"][js_idx][
733
- "submit_time"
734
- ] = datetime.strptime(js["submit_time"], self.ts_fmt)
735
- for key in list(js["task_elements"].keys()):
736
- subs[sub_idx]["jobscripts"][js_idx]["task_elements"][int(key)] = subs[
737
- sub_idx
738
- ]["jobscripts"][js_idx]["task_elements"].pop(key)
918
+ def _get_persistent_EARs(self, id_lst: Iterable[int]) -> Dict[int, ZarrStoreEAR]:
919
+ arr = self._get_EARs_arr()
920
+ attrs = arr.attrs.asdict()
739
921
 
740
- return subs
922
+ try:
923
+ EAR_arr_dat = arr.get_coordinate_selection(list(id_lst))
924
+ except zarr.errors.BoundsCheckError:
925
+ raise MissingStoreEARError(id_lst) from None
741
926
 
742
- def get_num_added_tasks(self) -> int:
743
- return self.load_metadata()["num_added_tasks"] + len(self._pending["tasks"])
927
+ EAR_dat = dict(zip(id_lst, EAR_arr_dat))
744
928
 
745
- def get_all_tasks_metadata(self) -> List[Dict]:
746
- out = []
747
- for _, grp in self._get_element_group().groups():
748
- out.append(
749
- {
750
- "num_elements": len(grp.get(self._task_elem_arr_name)),
751
- "num_element_iterations": len(grp.get(self._task_elem_iter_arr_name)),
752
- "num_EARs": len(grp.get(self._task_EAR_times_arr_name)),
753
- }
754
- )
755
- return out
929
+ iters = {
930
+ k: ZarrStoreEAR.decode(EAR_dat=v, attrs=attrs, ts_fmt=self.ts_fmt)
931
+ for k, v in EAR_dat.items()
932
+ }
933
+ return iters
756
934
 
-    def get_task_elements(
-        self,
-        task_idx: int,
-        task_insert_ID: int,
-        selection: slice,
-        keep_iterations_idx: bool = False,
-    ) -> List:
-        task = self.workflow.tasks[task_idx]
-        num_pers = task._num_elements
-        num_iter_pers = task._num_element_iterations
-        pers_slice, pend_slice = bisect_slice(selection, num_pers)
-        pers_range = range(pers_slice.start, pers_slice.stop, pers_slice.step)
-
-        elem_iter_arr = None
-        EAR_times_arr = None
-        if len(pers_range):
-            elem_arr = self._get_task_elements_array(task_insert_ID)
-            elem_iter_arr = self._get_task_elem_iters_array(task_insert_ID)
-            EAR_times_arr = self._get_task_EAR_times_array(task_insert_ID)
-            try:
-                elements = list(elem_arr[pers_slice])
-            except zarr.errors.NegativeStepError:
-                elements = [elem_arr[idx] for idx in pers_range]
-        else:
-            elements = []
-
-        key = (task_idx, task_insert_ID)
-        if key in self._pending["elements"]:
-            elements += self._pending["elements"][key][pend_slice]
-
-        # add iterations:
-        sel_range = range(selection.start, selection.stop, selection.step)
-        iterations = {}
-        for element_idx, element in zip(sel_range, elements):
-            # find which iterations to add:
-            iters_idx = element[0]
-
-            # include pending iterations:
-            if key in self._pending["element_iterations_idx"]:
-                iters_idx += self._pending["element_iterations_idx"][key][element_idx]
-
-            # populate new iterations list:
-            for iter_idx_i in iters_idx:
-                if iter_idx_i + 1 > num_iter_pers:
-                    i_pending = iter_idx_i - num_iter_pers
-                    iter_i = copy.deepcopy(
-                        self._pending["element_iterations"][key][i_pending]
-                    )
-                else:
-                    iter_i = elem_iter_arr[iter_idx_i]
-
-                # include pending EARs:
-                EARs_key = (task_idx, task_insert_ID, iter_idx_i)
-                if EARs_key in self._pending["EARs"]:
-                    iter_i[5].extend(self._pending["EARs"][EARs_key])
-                    # if there are pending EARs then EARs must be initialised:
-                    iter_i[2] = int(True)
-
-                # include pending loops:
-                loop_idx_key = (task_idx, task_insert_ID, iter_idx_i)
-                if loop_idx_key in self._pending["loop_idx"]:
-                    iter_i[4].extend(self._pending["loop_idx"][loop_idx_key])
-
-                iterations[iter_idx_i] = iter_i
-
-        elements = self._decompress_elements(elements, self._get_task_element_attrs(*key))
-
-        iters_k, iters_v = zip(*iterations.items())
-        attrs = self._get_task_element_iter_attrs(*key)
-        iters_v = self._decompress_element_iters(iters_v, attrs)
-        elem_iters = dict(zip(iters_k, iters_v))
-
-        for elem_idx, elem_i in zip(sel_range, elements):
-            elem_i["index"] = elem_idx
-
-            # populate iterations
-            elem_i["iterations"] = [elem_iters[i] for i in elem_i["iterations_idx"]]
-
-            # add EAR start/end times from separate array:
-            for iter_idx_i, iter_i in zip(elem_i["iterations_idx"], elem_i["iterations"]):
-                iter_i["index"] = iter_idx_i
-                for act_idx, runs in iter_i["actions"].items():
-                    for run_idx in range(len(runs)):
-                        run = iter_i["actions"][act_idx][run_idx]
-                        EAR_idx = run["index"]
-                        start_time = None
-                        end_time = None
-                        try:
-                            EAR_times = EAR_times_arr[EAR_idx]
-                            start_time, end_time = EAR_times
-                            start_time = None if np.isnat(start_time) else start_time
-                            end_time = None if np.isnat(end_time) else end_time
-                            # TODO: cast to native datetime types
-                        except (TypeError, zarr.errors.BoundsCheckError):
-                            pass
-                        run["metadata"]["start_time"] = start_time
-                        run["metadata"]["end_time"] = end_time
-
-                        # update pending submission indices:
-                        key = (task_insert_ID, iter_idx_i, act_idx, run_idx)
-                        if key in self._pending["EAR_submission_idx"]:
-                            sub_idx = self._pending["EAR_submission_idx"][key]
-                            run["metadata"]["submission_idx"] = sub_idx
-
-            if not keep_iterations_idx:
-                del elem_i["iterations_idx"]
-
-        return elements
-
-    def _encode_parameter_data(
+    def _get_persistent_parameters(
         self,
-        obj: Any,
-        root_group: zarr.Group,
-        arr_path: str,
-        path: List = None,
-        type_lookup: Optional[Dict] = None,
-    ) -> Dict[str, Any]:
-        return super()._encode_parameter_data(
-            obj=obj,
-            path=path,
-            type_lookup=type_lookup,
-            root_group=root_group,
-            arr_path=arr_path,
-        )
+        id_lst: Iterable[int],
+        dataset_copy: Optional[bool] = False,
+    ) -> Dict[int, ZarrStoreParameter]:
+        base_arr = self._get_parameter_base_array(mode="r")
+        src_arr = self._get_parameter_sources_array(mode="r")

-    def _decode_parameter_data(
-        self,
-        data: Union[None, Dict],
-        arr_group: zarr.Group,
-        path: Optional[List[str]] = None,
-        dataset_copy=False,
-    ) -> Any:
-        return super()._decode_parameter_data(
-            data=data,
-            path=path,
-            arr_group=arr_group,
-            dataset_copy=dataset_copy,
-        )
-
-    def _add_parameter_data(self, data: Any, source: Dict) -> int:
-        base_arr = self._get_parameter_base_array(mode="r+")
-        sources = self._get_parameter_sources_array(mode="r+")
-        idx = base_arr.size
-
-        if data is not None:
-            data = self._encode_parameter_data(
-                obj=data["data"],
-                root_group=self._get_parameter_user_array_group(mode="r+"),
-                arr_path=self._param_data_arr_grp_name(idx),
+        try:
+            param_arr_dat = base_arr.get_coordinate_selection(list(id_lst))
+            src_arr_dat = src_arr.get_coordinate_selection(list(id_lst))
+        except zarr.errors.BoundsCheckError:
+            raise MissingParameterData(id_lst) from None
+
+        param_dat = dict(zip(id_lst, param_arr_dat))
+        src_dat = dict(zip(id_lst, src_arr_dat))
+
+        params = {
+            k: ZarrStoreParameter.decode(
+                id_=k,
+                data=v,
+                source=src_dat[k],
+                arr_group=self._get_parameter_data_array_group(k),
+                dataset_copy=dataset_copy,
             )
+            for k, v in param_dat.items()
+        }

-        base_arr.append([data])
-        sources.append([dict(sorted(source.items()))])
-        self._pending["parameter_data"] += 1
-        self.save()
-
-        return idx
-
-    def set_parameter(self, index: int, data: Any) -> None:
-        """Set the value of a pre-allocated parameter."""
-
-        if self.is_parameter_set(index):
-            raise RuntimeError(f"Parameter at index {index} is already set!")
-
-        base_arr = self._get_parameter_base_array(mode="r+")
-        base_arr[index] = self._encode_parameter_data(
-            obj=data,
-            root_group=self._get_parameter_user_array_group(mode="r+"),
-            arr_path=self._param_data_arr_grp_name(index),
-        )
+        return params
 
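One subtlety in `_get_persistent_parameters`: `id_lst` is consumed more than once (two coordinate selections plus two `zip` calls), so the argument must be a re-iterable sequence such as a list or tuple. A one-shot generator would be exhausted after the first pass and silently yield empty results:

    ids = (i for i in [0, 1, 2])  # a generator, not a list
    print(list(ids))  # [0, 1, 2]
    print(list(ids))  # [] -- exhausted; dict(zip(ids, ...)) would then be empty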
-    def _get_parameter_data(self, index: int) -> Any:
-        return self._get_parameter_base_array(mode="r")[index]
+    def _get_persistent_param_sources(self, id_lst: Iterable[int]) -> Dict[int, Dict]:
+        src_arr = self._get_parameter_sources_array(mode="r")
+        try:
+            src_arr_dat = src_arr.get_coordinate_selection(list(id_lst))
+        except zarr.errors.BoundsCheckError:
+            raise MissingParameterData(id_lst) from None

-    def get_parameter_data(self, index: int) -> Tuple[bool, Any]:
-        data = self._get_parameter_data(index)
-        is_set = False if data is None else True
-        data = self._decode_parameter_data(
-            data=data,
-            arr_group=self._get_parameter_data_array_group(index),
-        )
-        return (is_set, data)
-
-    def get_parameter_source(self, index: int) -> Dict:
-        src = self._get_parameter_sources_array(mode="r")[index]
-        if index in self._pending["parameter_source_updates"]:
-            src.update(self._pending["parameter_source_updates"][index])
-            src = dict(sorted(src.items()))
-
-        return src
-
-    def get_all_parameter_data(self) -> Dict[int, Any]:
-        max_key = self._get_parameter_base_array(mode="r").size - 1
-        out = {}
-        for idx in range(max_key + 1):
-            out[idx] = self.get_parameter_data(idx)
-        return out
-
-    def is_parameter_set(self, index: int) -> bool:
-        return self._get_parameter_data(index) is not None
-
-    def check_parameters_exist(
-        self, indices: Union[int, List[int]]
-    ) -> Union[bool, List[bool]]:
-        is_multi = True
-        if not isinstance(indices, (list, tuple)):
-            is_multi = False
-            indices = [indices]
-        base = self._get_parameter_base_array(mode="r")
-        idx_range = range(base.size)
-        exists = [i in idx_range for i in indices]
-        if not is_multi:
-            exists = exists[0]
-        return exists
-
-    def _init_task_loop(
-        self,
-        task_idx: int,
-        task_insert_ID: int,
-        element_sel: slice,
-        name: str,
-    ) -> None:
-        """Initialise the zeroth iteration of a named loop for a specified task."""
+        return dict(zip(id_lst, src_arr_dat))

-        elements = self.get_task_elements(
-            task_idx=task_idx,
-            task_insert_ID=task_insert_ID,
-            selection=element_sel,
-            keep_iterations_idx=True,
+    def _get_persistent_parameter_set_status(
+        self, id_lst: Iterable[int]
+    ) -> Dict[int, bool]:
+        base_arr = self._get_parameter_base_array(mode="r")
+        try:
+            param_arr_dat = base_arr.get_coordinate_selection(list(id_lst))
+        except zarr.errors.BoundsCheckError:
+            raise MissingParameterData(id_lst) from None
+
+        return dict(zip(id_lst, [i is not None for i in param_arr_dat]))
+
+    def _get_persistent_parameter_IDs(self) -> List[int]:
+        # we assume the row index is equivalent to ID, might need to revisit in future
+        base_arr = self._get_parameter_base_array(mode="r")
+        return list(range(len(base_arr)))
+
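Two conventions from the methods above are worth spelling out: a `None` row in the base parameter array means allocated-but-not-yet-set, and a parameter's ID is simply its row index, so IDs are dense and append-only. A toy restatement with plain lists standing in for the zarr arrays:

    rows = [None, {"data": 42}, None]  # base-array rows; None means unset
    set_status = {i: r is not None for i, r in enumerate(rows)}
    print(set_status)              # {0: False, 1: True, 2: False}
    print(list(range(len(rows))))  # all persistent IDs: [0, 1, 2]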
+    def get_creation_info(self):
+        with self.using_resource("attrs", action="read") as attrs:
+            return attrs["creation_info"]
+
+    def get_fs_path(self):
+        with self.using_resource("attrs", action="read") as attrs:
+            return attrs["fs_path"]
+
+    def to_zip(self):
+        # TODO: need to update `fs_path` in the new store (because this used to get
+        # `Workflow.name`), but can't seem to open `dst_zarr_store` below:
+        src_zarr_store = self.zarr_store
+        new_fs_path = f"{self.workflow.fs_path}.zip"
+        zfs, _ = ask_pw_on_auth_exc(
+            ZipFileSystem,
+            fo=new_fs_path,
+            mode="w",
+            target_options={},
+            add_pw_to="target_options",
         )
+        dst_zarr_store = zarr.storage.FSStore(url="", fs=zfs)
+        zarr.convenience.copy_store(src_zarr_store, dst_zarr_store)
+        del zfs  # ZipFileSystem remains open for instance lifetime
+        return new_fs_path
 
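`to_zip` archives a workflow by copying every key of the live store into a zip-backed store via `zarr.convenience.copy_store`; the fsspec `ZipFileSystem` route (wrapped in `ask_pw_on_auth_exc`) handles remote targets and password prompts uniformly. The same round trip can be exercised with zarr's built-in `ZipStore`; a sketch, where `archive.zip` is an illustrative path:

    import zarr

    src = zarr.storage.MemoryStore()
    zarr.group(store=src).create_dataset("x", data=[1, 2, 3])

    dst = zarr.storage.ZipStore("archive.zip", mode="w")
    zarr.convenience.copy_store(src, dst)  # key-for-key copy of the hierarchy
    dst.close()  # unlike the ZipFileSystem above, ZipStore must be closed explicitly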
-        attrs_original = self._get_task_element_iter_attrs(task_idx, task_insert_ID)
-        attrs = copy.deepcopy(attrs_original)
-        for element in elements:
-            for iter_idx, iter_i in zip(element["iterations_idx"], element["iterations"]):
-                if name in (attrs["loops"][k] for k in iter_i["loop_idx"]):
-                    raise ValueError(f"Loop {name!r} already initialised!")

-                key = (task_idx, task_insert_ID, iter_idx)
-                if key not in self._pending["loop_idx"]:
-                    self._pending["loop_idx"][key] = []
+class ZarrZipPersistentStore(ZarrPersistentStore):
+    """A store designed mainly as an archive format that can be uploaded to e.g. Zenodo."""

-                self._pending["loop_idx"][key].append(
-                    [ensure_in(name, attrs["loops"]), 0]
-                )
+    _name = "zip"
+    _features = PersistentStoreFeatures(
+        create=False,
+        edit=False,
+        jobscript_parallelism=False,
+        EAR_parallelism=False,
+        schedulers=False,
+        submission=False,
+    )
 
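`PersistentStoreFeatures` (from `hpcflow.sdk.persistence.base`) is a record of capability flags: the zip store declares that it cannot be created directly or edited, and supports neither scheduling nor submission. The pattern, roughly, as a simplified stand-in (the field defaults here are assumptions for illustration):

    from dataclasses import dataclass

    @dataclass
    class StoreFeatures:  # simplified stand-in for PersistentStoreFeatures
        create: bool = True
        edit: bool = True
        jobscript_parallelism: bool = True
        EAR_parallelism: bool = True
        schedulers: bool = True
        submission: bool = True

    # an archive-only store switches everything off, as above:
    zip_features = StoreFeatures(
        create=False,
        edit=False,
        jobscript_parallelism=False,
        EAR_parallelism=False,
        schedulers=False,
        submission=False,
    )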
-        if attrs != attrs_original:
-            if task_idx not in self._pending["element_iter_attrs"]:
-                self._pending["element_iter_attrs"][task_idx] = {}
-            self._pending["element_iter_attrs"][task_idx].update(attrs)
+    # TODO: enforce read-only nature

-    def remove_replaced_dir(self) -> None:
-        md = self.load_metadata()
-        if "replaced_dir" in md:
-            remove_dir(Path(md["replaced_dir"]))
-            self._pending["remove_replaced_dir_record"] = True
-            self.save()
+    def to_zip(self):
+        raise NotImplementedError("Already a zip store!")

-    def reinstate_replaced_dir(self) -> None:
-        print(f"reinstate replaced directory!")
-        md = self.load_metadata()
-        if "replaced_dir" in md:
-            rename_dir(Path(md["replaced_dir"]), self.workflow_path)
-
-    def copy(self, path: PathLike = None) -> None:
-        shutil.copytree(self.workflow_path, path)
+    def copy(self, path=None) -> str:
+        # not sure how to do this.
+        raise NotImplementedError()

-    def is_modified_on_disk(self) -> bool:
-        if self._metadata:
-            return get_md5_hash(self._load_metadata()) != get_md5_hash(self._metadata)
-        else:
-            # nothing to compare to
-            return False
+    def delete_no_confirm(self) -> None:
+        # `ZipFileSystem.rm()` does not seem to be implemented.
+        raise NotImplementedError()
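With `ZipFileSystem.rm()` unimplemented, deletion, like the other mutating operations here, is simply refused, which fits the archive-only intent of the class. Reading the archive back is the supported direction; a sketch, where `workflow.zip` is an assumed file name:

    from fsspec.implementations.zip import ZipFileSystem

    zfs = ZipFileSystem(fo="workflow.zip", mode="r")  # read-only by design
    print(zfs.ls("/"))  # top-level entries of the archived zarr hierarchy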