hpcflow-new2 0.2.0a176__py3-none-any.whl → 0.2.0a178__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -25,6 +25,7 @@ from hpcflow.sdk.core import (
     ABORT_EXIT_CODE,
 )
 from hpcflow.sdk.core.actions import EARStatus
+from hpcflow.sdk.core.loop_cache import LoopCache
 from hpcflow.sdk.log import TimeIt
 from hpcflow.sdk.persistence import store_cls_from_str, DEFAULT_STORE_FORMAT
 from hpcflow.sdk.persistence.base import TEMPLATE_COMP_TYPES, AnySEAR
@@ -41,6 +42,7 @@ from hpcflow.sdk.submission.schedulers.direct import DirectScheduler
 from hpcflow.sdk.typing import PathLike
 from hpcflow.sdk.core.json_like import ChildObjectSpec, JSONLike
 from .utils import (
+    nth_key,
     read_JSON_file,
     read_JSON_string,
     read_YAML_str,
@@ -625,19 +627,28 @@ class Workflow:
             )
             with wk._store.cached_load():
                 with wk.batch_update(is_workflow_creation=True):
-                    for idx, task in enumerate(template.tasks):
+                    with wk._store.cache_ctx():
+                        for idx, task in enumerate(template.tasks):
+                            if status:
+                                status.update(
+                                    f"Adding task {idx + 1}/{len(template.tasks)} "
+                                    f"({task.name!r})..."
+                                )
+                            wk._add_task(task)
                         if status:
                             status.update(
-                                f"Adding task {idx + 1}/{len(template.tasks)} "
-                                f"({task.name!r})..."
+                                f"Preparing to add {len(template.loops)} loops..."
                             )
-                        wk._add_task(task)
-                    for idx, loop in enumerate(template.loops):
-                        if status:
-                            status.update(
-                                f"Adding loop {idx + 1}/" f"{len(template.loops)}..."
-                            )
-                        wk._add_loop(loop)
+                        if template.loops:
+                            # TODO: if loop with non-initialisable actions, will fail
+                            cache = LoopCache.build(workflow=wk, loops=template.loops)
+                            for idx, loop in enumerate(template.loops):
+                                if status:
+                                    status.update(
+                                        f"Adding loop {idx + 1}/"
+                                        f"{len(template.loops)} ({loop.name!r})"
+                                    )
+                                wk._add_loop(loop, cache=cache, status=status)
         except Exception:
             if status:
                 status.stop()
@@ -1101,7 +1112,7 @@ class Workflow:

     @TimeIt.decorator
     def _add_empty_loop(
-        self, loop: app.Loop
+        self, loop: app.Loop, cache: LoopCache
     ) -> Tuple[app.WorkflowLoop, List[app.ElementIteration]]:
         """Add a new loop (zeroth iterations only) to the workflow."""

@@ -1114,15 +1125,15 @@
         self.template._add_empty_loop(loop_c)

         # all these element iterations will be initialised for the new loop:
-        iters = self.get_element_iterations_of_tasks(loop_c.task_insert_IDs)
-        iter_IDs = [i.id_ for i in iters]
+        iter_IDs = cache.get_iter_IDs(loop_c)
+        iter_loop_idx = cache.get_iter_loop_indices(iter_IDs)

         # create and insert a new WorkflowLoop:
         new_loop = self.app.WorkflowLoop.new_empty_loop(
             index=new_index,
             workflow=self,
             template=loop_c,
-            iterations=iters,
+            iter_loop_idx=iter_loop_idx,
         )
         self.loops.add_object(new_loop)
         wk_loop = self.loops[new_index]
@@ -1144,15 +1155,28 @@ class Workflow:

         self._pending["loops"].append(new_index)

+        # update cache loop indices:
+        cache.update_loop_indices(new_loop_name=loop_c.name, iter_IDs=iter_IDs)
+
         return wk_loop

     @TimeIt.decorator
-    def _add_loop(self, loop: app.Loop) -> None:
-        new_wk_loop = self._add_empty_loop(loop)
+    def _add_loop(
+        self, loop: app.Loop, cache: Optional[Dict] = None, status: Optional[Any] = None
+    ) -> None:
+        if not cache:
+            cache = LoopCache.build(workflow=self, loops=[loop])
+        new_wk_loop = self._add_empty_loop(loop, cache)
         if loop.num_iterations is not None:
             # fixed number of iterations, so add remaining N > 0 iterations:
-            for _ in range(loop.num_iterations - 1):
-                new_wk_loop.add_iteration()
+            if status:
+                status_prev = status.status
+            for iter_idx in range(loop.num_iterations - 1):
+                if status:
+                    status.update(
+                        f"{status_prev}: iteration {iter_idx + 2}/{loop.num_iterations}."
+                    )
+                new_wk_loop.add_iteration(cache=cache)

     def add_loop(self, loop: app.Loop) -> None:
         """Add a loop to a subset of workflow tasks."""
@@ -1326,6 +1350,7 @@ class Workflow:
             iters.append(iter_i)
         return iters

+    @TimeIt.decorator
     def get_elements_from_IDs(self, id_lst: Iterable[int]) -> List[app.Element]:
         """Return element objects from a list of IDs."""

@@ -1334,6 +1359,7 @@ class Workflow:
         task_IDs = [i.task_ID for i in store_elems]
         store_tasks = self._store.get_tasks_by_IDs(task_IDs)

+        element_idx_by_task = defaultdict(set)
         index_paths = []
         for el, tk in zip(store_elems, store_tasks):
             elem_idx = tk.element_IDs.index(el.id_)
@@ -1343,15 +1369,23 @@ class Workflow:
                     "task_idx": tk.index,
                 }
             )
+            element_idx_by_task[tk.index].add(elem_idx)
+
+        elements_by_task = {}
+        for task_idx, elem_idx in element_idx_by_task.items():
+            task = self.tasks[task_idx]
+            elements_by_task[task_idx] = dict(
+                zip(elem_idx, task.elements[list(elem_idx)])
+            )

         objs = []
         for idx_dat in index_paths:
-            task = self.tasks[idx_dat["task_idx"]]
-            elem = task.elements[idx_dat["elem_idx"]]
+            elem = elements_by_task[idx_dat["task_idx"]][idx_dat["elem_idx"]]
             objs.append(elem)

         return objs

+    @TimeIt.decorator
     def get_element_iterations_from_IDs(
         self, id_lst: Iterable[int]
     ) -> List[app.ElementIteration]:
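Note (generic sketch, not hpcflow code): the change above replaces one `task.elements[elem_idx]` lookup per requested ID with a single fancy-indexed access per task. The same grouping idea in plain, self-contained Python:

from collections import defaultdict

# (task_idx, elem_idx) pairs to resolve, in request order:
pairs = [(0, 2), (1, 0), (0, 5), (1, 3)]

def fetch_elements(task_idx, elem_indices):
    # stands in for `self.tasks[task_idx].elements[list(elem_indices)]`,
    # i.e. one batched store read per task
    return [f"task{task_idx}-elem{i}" for i in elem_indices]

# group the requested element indices by task:
by_task = defaultdict(set)
for task_idx, elem_idx in pairs:
    by_task[task_idx].add(elem_idx)

# one batched fetch per task, keyed back by element index:
lookup = {
    task_idx: dict(zip(sorted(idx), fetch_elements(task_idx, sorted(idx))))
    for task_idx, idx in by_task.items()
}

# resolve in the original request order:
results = [lookup[t][e] for t, e in pairs]
assert results == ["task0-elem2", "task1-elem0", "task0-elem5", "task1-elem3"]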
@@ -1365,6 +1399,8 @@ class Workflow:
         task_IDs = [i.task_ID for i in store_elems]
         store_tasks = self._store.get_tasks_by_IDs(task_IDs)

+        element_idx_by_task = defaultdict(set)
+
         index_paths = []
         for it, el, tk in zip(store_iters, store_elems, store_tasks):
             iter_idx = el.iteration_IDs.index(it.id_)
@@ -1376,11 +1412,18 @@ class Workflow:
                     "task_idx": tk.index,
                 }
             )
+            element_idx_by_task[tk.index].add(elem_idx)
+
+        elements_by_task = {}
+        for task_idx, elem_idx in element_idx_by_task.items():
+            task = self.tasks[task_idx]
+            elements_by_task[task_idx] = dict(
+                zip(elem_idx, task.elements[list(elem_idx)])
+            )

         objs = []
         for idx_dat in index_paths:
-            task = self.tasks[idx_dat["task_idx"]]
-            elem = task.elements[idx_dat["elem_idx"]]
+            elem = elements_by_task[idx_dat["task_idx"]][idx_dat["elem_idx"]]
             iter_ = elem.iterations[idx_dat["iter_idx"]]
             objs.append(iter_)

@@ -1653,7 +1696,14 @@ class Workflow:

         return wk

-    def zip(self, path=".", log=None, overwrite=False) -> str:
+    def zip(
+        self,
+        path=".",
+        log=None,
+        overwrite=False,
+        include_execute=False,
+        include_rechunk_backups=False,
+    ) -> str:
         """
         Parameters
         ----------
@@ -1662,7 +1712,13 @@ class Workflow:
             directory, the zip file will be created within this directory. Otherwise,
             this path is assumed to be the full file path to the new zip file.
         """
-        return self._store.zip(path=path, log=log, overwrite=overwrite)
+        return self._store.zip(
+            path=path,
+            log=log,
+            overwrite=overwrite,
+            include_execute=include_execute,
+            include_rechunk_backups=include_rechunk_backups,
+        )

     def unzip(self, path=".", log=None) -> str:
         """
@@ -2900,6 +2956,34 @@ class Workflow:
             final_runs[loop_name].append(final[0])
         return dict(final_runs)

+    def rechunk_runs(
+        self,
+        chunk_size: Optional[int] = None,
+        backup: Optional[bool] = True,
+        status: Optional[bool] = True,
+    ):
+        self._store.rechunk_runs(chunk_size=chunk_size, backup=backup, status=status)
+
+    def rechunk_parameter_base(
+        self,
+        chunk_size: Optional[int] = None,
+        backup: Optional[bool] = True,
+        status: Optional[bool] = True,
+    ):
+        self._store.rechunk_parameter_base(
+            chunk_size=chunk_size, backup=backup, status=status
+        )
+
+    def rechunk(
+        self,
+        chunk_size: Optional[int] = None,
+        backup: Optional[bool] = True,
+        status: Optional[bool] = True,
+    ):
+        """Rechunk metadata/runs and parameters/base arrays."""
+        self.rechunk_runs(chunk_size=chunk_size, backup=backup, status=status)
+        self.rechunk_parameter_base(chunk_size=chunk_size, backup=backup, status=status)
+

 @dataclass
 class WorkflowBlueprint:
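Note (usage sketch): the three new `Workflow` methods delegate to the persistent store. Method and argument names come from the diff; the values are illustrative. With `chunk_size=None` each array is rewritten as a single chunk (see `_rechunk_arr` further down).

wk.rechunk()                                           # runs + parameter base, one chunk each
wk.rechunk_runs(chunk_size=1000, backup=True)          # runs array only, 1000 items per chunk
wk.rechunk_parameter_base(backup=False, status=False)  # no .bak copy, no console status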
@@ -716,6 +716,11 @@ class PersistentStore(ABC):
         """Cache for number of persistent tasks."""
         return self._cache["num_tasks"]

+    @property
+    def num_EARs_cache(self):
+        """Cache for total number of persistent EARs."""
+        return self._cache["num_EARs"]
+
     @property
     def param_sources_cache(self):
         """Cache for persistent parameter sources."""
@@ -730,6 +735,10 @@ class PersistentStore(ABC):
     def num_tasks_cache(self, value):
         self._cache["num_tasks"] = value

+    @num_EARs_cache.setter
+    def num_EARs_cache(self, value):
+        self._cache["num_EARs"] = value
+
     def _reset_cache(self):
         self._cache = {
             "tasks": {},
@@ -739,6 +748,7 @@ class PersistentStore(ABC):
             "param_sources": {},
             "num_tasks": None,
             "parameters": {},
+            "num_EARs": None,
         }

     @contextlib.contextmanager
@@ -873,6 +883,7 @@ class PersistentStore(ABC):
         """Get the total number of persistent and pending element iterations."""
         return self._get_num_persistent_elem_iters() + len(self._pending.add_elem_iters)

+    @TimeIt.decorator
     def _get_num_total_EARs(self):
         """Get the total number of persistent and pending EARs."""
         return self._get_num_persistent_EARs() + len(self._pending.add_EARs)
@@ -1296,9 +1307,11 @@ class PersistentStore(ABC):
             self.save()

     @TimeIt.decorator
-    def update_param_source(self, param_id: int, source: Dict, save: bool = True) -> None:
-        self.logger.debug(f"Updating parameter ID {param_id!r} source to {source!r}.")
-        self._pending.update_param_sources[param_id] = source
+    def update_param_source(
+        self, param_sources: Dict[int, Dict], save: bool = True
+    ) -> None:
+        self.logger.debug(f"Updating parameter sources with {param_sources!r}.")
+        self._pending.update_param_sources.update(param_sources)
         if save:
             self.save()

@@ -303,12 +303,13 @@ class JSONPersistentStore(PersistentStore):

     def _get_num_persistent_tasks(self) -> int:
         """Get the number of persistent tasks."""
-        if self.num_tasks_cache is not None:
+        if self.use_cache and self.num_tasks_cache is not None:
             num = self.num_tasks_cache
         else:
             with self.using_resource("metadata", action="read") as md:
                 num = len(md["tasks"])
-            self.num_tasks_cache = num
+        if self.use_cache and self.num_tasks_cache is None:
+            self.num_tasks_cache = num
         return num

     def _get_num_persistent_loops(self) -> int:
@@ -333,8 +334,14 @@ class JSONPersistentStore(PersistentStore):

     def _get_num_persistent_EARs(self) -> int:
         """Get the number of persistent EARs."""
-        with self.using_resource("metadata", action="read") as md:
-            return len(md["runs"])
+        if self.use_cache and self.num_EARs_cache is not None:
+            num = self.num_EARs_cache
+        else:
+            with self.using_resource("metadata", action="read") as md:
+                num = len(md["runs"])
+        if self.use_cache and self.num_EARs_cache is None:
+            self.num_EARs_cache = num
+        return num

     def _get_num_persistent_parameters(self):
         with self.using_resource("parameters", "read") as params:
@@ -275,6 +275,7 @@ class PendingChanges:
             EAR_ids = list(self.add_EARs.keys())
             self.logger.debug(f"commit: adding pending EARs with IDs: {EAR_ids!r}")
             self.store._append_EARs(EARs)
+            self.store.num_EARs_cache = None  # invalidate cache
             # pending start/end times/snapshots, submission indices, and skips that belong
             # to pending EARs are now committed (accounted for in `get_EARs` above):
             self.set_EAR_submission_indices = {
@@ -408,6 +409,7 @@ class PendingChanges:
     @TimeIt.decorator
     def commit_loop_indices(self) -> None:
         """Make pending update to element iteration loop indices persistent."""
+        # TODO: batch up
         for iter_ID, loop_idx in self.update_loop_indices.items():
             self.logger.debug(
                 f"commit: updating loop indices of iteration ID {iter_ID!r} with "
@@ -5,6 +5,8 @@ from contextlib import contextmanager
 from dataclasses import dataclass
 from datetime import datetime
 from pathlib import Path
+import shutil
+import time
 from typing import Any, Dict, Iterable, Iterator, List, Optional, Tuple, Union

 import numpy as np
@@ -774,9 +776,16 @@ class ZarrPersistentStore(PersistentStore):
         """Get the number of persistent element iterations."""
         return len(self._get_iters_arr())

+    @TimeIt.decorator
     def _get_num_persistent_EARs(self) -> int:
         """Get the number of persistent EARs."""
-        return len(self._get_EARs_arr())
+        if self.use_cache and self.num_EARs_cache is not None:
+            num = self.num_EARs_cache
+        else:
+            num = len(self._get_EARs_arr())
+        if self.use_cache and self.num_EARs_cache is None:
+            self.num_EARs_cache = num
+        return num

     def _get_num_persistent_parameters(self):
         return len(self._get_parameter_base_array())
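Note (generic sketch, not hpcflow code): both the JSON and Zarr stores now use the same read-through pattern for the EAR count, guarded by `use_cache`, with the cache invalidated when new EARs are committed (see `commit_EARs` above). A minimal standalone illustration of that pattern:

class CountedStore:
    """Read-through count cache, mirroring the `num_EARs_cache` pattern."""

    def __init__(self, items, use_cache=True):
        self._items = list(items)
        self.use_cache = use_cache
        self.num_cache = None  # analogous to `num_EARs_cache`

    def _count_from_store(self):
        return len(self._items)  # stands in for an expensive store read

    def get_num(self):
        if self.use_cache and self.num_cache is not None:
            num = self.num_cache  # cache hit: no store read
        else:
            num = self._count_from_store()
        if self.use_cache and self.num_cache is None:
            self.num_cache = num  # populate on first miss
        return num

    def append(self, item):
        self._items.append(item)
        self.num_cache = None  # invalidate, as `commit_EARs` does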
@@ -1145,7 +1154,14 @@ class ZarrPersistentStore(PersistentStore):
         with self.using_resource("attrs", action="read") as attrs:
             return attrs["name"]

-    def zip(self, path=".", log=None, overwrite=False):
+    def zip(
+        self,
+        path=".",
+        log=None,
+        overwrite=False,
+        include_execute=False,
+        include_rechunk_backups=False,
+    ):
         """
         Parameters
         ----------
@@ -1181,16 +1197,120 @@ class ZarrPersistentStore(PersistentStore):
             add_pw_to="target_options",
         )
         dst_zarr_store = zarr.storage.FSStore(url="", fs=zfs)
+        excludes = []
+        if not include_execute:
+            excludes.append("execute")
+        if not include_rechunk_backups:
+            excludes.append("runs.bak")
+            excludes.append("base.bak")
+
         zarr.convenience.copy_store(
             src_zarr_store,
             dst_zarr_store,
-            excludes="execute",
+            excludes=excludes or None,
             log=log,
         )
         del zfs  # ZipFileSystem remains open for instance lifetime
         status.stop()
         return dst_path

+    def _rechunk_arr(
+        self,
+        arr,
+        chunk_size: Optional[int] = None,
+        backup: Optional[bool] = True,
+        status: Optional[bool] = True,
+    ):
+        arr_path = Path(self.workflow.path) / arr.path
+        arr_name = arr.path.split("/")[-1]
+
+        if status:
+            console = Console()
+            status = console.status("Rechunking...")
+            status.start()
+        backup_time = None
+
+        if backup:
+            if status:
+                status.update("Backing up...")
+            backup_path = arr_path.with_suffix(".bak")
+            if backup_path.is_dir():
+                pass
+            else:
+                tic = time.perf_counter()
+                shutil.copytree(arr_path, backup_path)
+                toc = time.perf_counter()
+                backup_time = toc - tic
+
+        tic = time.perf_counter()
+        arr_rc_path = arr_path.with_suffix(".rechunked")
+        arr = zarr.open(arr_path)
+        if status:
+            status.update("Creating new array...")
+        arr_rc = zarr.create(
+            store=arr_rc_path,
+            shape=arr.shape,
+            chunks=arr.shape if chunk_size is None else chunk_size,
+            dtype=object,
+            object_codec=MsgPack(),
+        )
+        if status:
+            status.update("Copying data...")
+        data = np.empty(shape=arr.shape, dtype=object)
+        bad_data = []
+        for idx in range(len(arr)):
+            try:
+                data[idx] = arr[idx]
+            except RuntimeError:
+                # blosc decompression errors
+                bad_data.append(idx)
+                pass
+        arr_rc[:] = data
+
+        arr_rc.attrs.put(arr.attrs.asdict())
+
+        if status:
+            status.update("Deleting old array...")
+        shutil.rmtree(arr_path)
+
+        if status:
+            status.update("Moving new array into place...")
+        shutil.move(arr_rc_path, arr_path)
+
+        toc = time.perf_counter()
+        rechunk_time = toc - tic
+
+        if status:
+            status.stop()
+
+        if backup_time:
+            print(f"Time to backup {arr_name}: {backup_time:.1f} s")
+
+        print(f"Time to rechunk and move {arr_name}: {rechunk_time:.1f} s")
+
+        if bad_data:
+            print(f"Bad data at {arr_name} indices: {bad_data}.")
+
+        return arr_rc
+
+    def rechunk_parameter_base(
+        self,
+        chunk_size: Optional[int] = None,
+        backup: Optional[bool] = True,
+        status: Optional[bool] = True,
+    ):
+        arr = self._get_parameter_base_array()
+        return self._rechunk_arr(arr, chunk_size, backup, status)
+
+    def rechunk_runs(
+        self,
+        chunk_size: Optional[int] = None,
+        backup: Optional[bool] = True,
+        status: Optional[bool] = True,
+    ):
+        arr = self._get_EARs_arr()
+        return self._rechunk_arr(arr, chunk_size, backup, status)
+
+

 class ZarrZipPersistentStore(ZarrPersistentStore):
     """A store designed mainly as an archive format that can be uploaded to data
@@ -1250,3 +1370,12 @@ class ZarrZipPersistentStore(ZarrPersistentStore):
     def delete_no_confirm(self) -> None:
         # `ZipFileSystem.rm()` does not seem to be implemented.
         raise NotImplementedError()
+
+    def _rechunk_arr(
+        self,
+        arr,
+        chunk_size: Optional[int] = None,
+        backup: Optional[bool] = True,
+        status: Optional[bool] = True,
+    ):
+        raise NotImplementedError
@@ -1,4 +1,6 @@
-from importlib import resources
+from pathlib import Path
+import numpy as np
+import zarr
 import pytest
 from hpcflow.sdk.core.test_utils import make_test_data_YAML_workflow
 from hpcflow.sdk.persistence.base import StoreEAR, StoreElement, StoreElementIter
@@ -239,3 +241,118 @@ def test_make_zarr_store_no_compressor(null_config, tmp_path):
         store="zarr",
         store_kwargs={"compressor": None},
     )
+
+
+@pytest.mark.integration
+def test_zarr_rechunk_data_equivalent(null_config, tmp_path):
+    t1 = hf.Task(
+        schema=hf.task_schemas.test_t1_conditional_OS,
+        inputs={"p1": 101},
+        repeats=3,
+    )
+    wk = hf.Workflow.from_template_data(
+        tasks=[t1],
+        template_name="test_run_rechunk",
+        workflow_name="test_run_rechunk",
+        path=tmp_path,
+    )
+    wk.submit(wait=True, status=False, add_to_known=False)
+    wk.rechunk_runs(backup=True, status=False, chunk_size=None)  # None -> one chunk
+
+    arr = wk._store._get_EARs_arr()
+    assert arr.chunks == arr.shape
+
+    bak_path = (Path(wk.path) / arr.path).with_suffix(".bak")
+    arr_bak = zarr.open(bak_path)
+
+    assert arr_bak.chunks == (1,)
+
+    # check backup and new runs data are equal:
+    assert np.all(arr[:] == arr_bak[:])
+
+    # check attributes are equal:
+    assert arr.attrs.asdict() == arr_bak.attrs.asdict()
+
+
+@pytest.mark.integration
+def test_zarr_rechunk_data_equivalent_custom_chunk_size(null_config, tmp_path):
+    t1 = hf.Task(
+        schema=hf.task_schemas.test_t1_conditional_OS,
+        inputs={"p1": 101},
+        repeats=3,
+    )
+    wk = hf.Workflow.from_template_data(
+        tasks=[t1],
+        template_name="test_run_rechunk",
+        workflow_name="test_run_rechunk",
+        path=tmp_path,
+    )
+    wk.submit(wait=True, status=False, add_to_known=False)
+    wk.rechunk_runs(backup=True, status=False, chunk_size=2)
+
+    arr = wk._store._get_EARs_arr()
+    assert arr.chunks == (2,)
+
+    bak_path = (Path(wk.path) / arr.path).with_suffix(".bak")
+    arr_bak = zarr.open(bak_path)
+
+    assert arr_bak.chunks == (1,)
+
+    # check backup and new runs data are equal:
+    assert np.all(arr[:] == arr_bak[:])
+
+
+@pytest.mark.integration
+def test_zarr_rechunk_data_no_backup_load_runs(null_config, tmp_path):
+    t1 = hf.Task(
+        schema=hf.task_schemas.test_t1_conditional_OS,
+        inputs={"p1": 101},
+        repeats=3,
+    )
+    wk = hf.Workflow.from_template_data(
+        tasks=[t1],
+        template_name="test_run_rechunk",
+        workflow_name="test_run_rechunk",
+        path=tmp_path,
+    )
+    wk.submit(wait=True, status=False, add_to_known=False)
+    wk.rechunk_runs(backup=False, status=False)
+
+    arr = wk._store._get_EARs_arr()
+
+    bak_path = (Path(wk.path) / arr.path).with_suffix(".bak")
+    assert not bak_path.is_file()
+
+    # check we can load runs:
+    runs = wk._store._get_persistent_EARs(id_lst=list(range(wk.num_EARs)))
+    run_ID = []
+    for i in runs.values():
+        run_ID.append(i.id_)
+
+
+@pytest.mark.integration
+def test_zarr_rechunk_data_no_backup_load_parameter_base(null_config, tmp_path):
+    t1 = hf.Task(
+        schema=hf.task_schemas.test_t1_conditional_OS,
+        inputs={"p1": 101},
+        repeats=3,
+    )
+    wk = hf.Workflow.from_template_data(
+        tasks=[t1],
+        template_name="test_run_rechunk",
+        workflow_name="test_run_rechunk",
+        path=tmp_path,
+    )
+    wk.submit(wait=True, status=False, add_to_known=False)
+    wk.rechunk_parameter_base(backup=False, status=False)
+
+    arr = wk._store._get_parameter_base_array()
+
+    bak_path = (Path(wk.path) / arr.path).with_suffix(".bak")
+    assert not bak_path.is_file()
+
+    # check we can load parameters:
+    params = wk.get_all_parameters()
+    param_IDs = []
+    for i in params:
+        param_IDs.append(i.id_)
@@ -13,6 +13,8 @@ from hpcflow.sdk.core.utils import (
     get_nested_indices,
     is_fsspec_url,
     linspace_rect,
+    nth_key,
+    nth_value,
     process_string_nodes,
    replace_items,
     check_valid_py_identifier,
@@ -556,3 +558,22 @@ def test_dict_values_process_flat_single_item_lists():
         "b": [4],
         "c": [5],
     }
+
+
+def test_nth_key():
+    dct = {"a": 1, "b": 2}
+    assert [nth_key(dct, i) for i in range(len(dct))] == ["a", "b"]
+
+
+def test_nth_value():
+    dct = {"a": 1, "b": 2}
+    assert [nth_value(dct, i) for i in range(len(dct))] == [1, 2]
+
+
+def test_nth_key_raises():
+    dct = {"a": 1, "b": 2}
+    with pytest.raises(Exception):
+        nth_key(dct, 2)
+
+    with pytest.raises(Exception):
+        nth_key(dct, -1)
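Note (assumption): the tests above pin down the behaviour of the new `nth_key`/`nth_value` helpers, but their implementation is not shown in this diff. One implementation consistent with the tests is:

from itertools import islice

def nth_key(dct, n):
    # raises for out-of-range or negative `n`, matching test_nth_key_raises
    return next(islice(iter(dct), n, n + 1))

def nth_value(dct, n):
    return dct[nth_key(dct, n)]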