hpcflow-new2 0.2.0a188__py3-none-any.whl → 0.2.0a190__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (115)
  1. hpcflow/__pyinstaller/hook-hpcflow.py +8 -6
  2. hpcflow/_version.py +1 -1
  3. hpcflow/app.py +1 -0
  4. hpcflow/data/scripts/main_script_test_hdf5_in_obj.py +1 -1
  5. hpcflow/data/scripts/main_script_test_hdf5_out_obj.py +1 -1
  6. hpcflow/sdk/__init__.py +21 -15
  7. hpcflow/sdk/app.py +2133 -770
  8. hpcflow/sdk/cli.py +281 -250
  9. hpcflow/sdk/cli_common.py +6 -2
  10. hpcflow/sdk/config/__init__.py +1 -1
  11. hpcflow/sdk/config/callbacks.py +77 -42
  12. hpcflow/sdk/config/cli.py +126 -103
  13. hpcflow/sdk/config/config.py +578 -311
  14. hpcflow/sdk/config/config_file.py +131 -95
  15. hpcflow/sdk/config/errors.py +112 -85
  16. hpcflow/sdk/config/types.py +145 -0
  17. hpcflow/sdk/core/actions.py +1054 -994
  18. hpcflow/sdk/core/app_aware.py +24 -0
  19. hpcflow/sdk/core/cache.py +81 -63
  20. hpcflow/sdk/core/command_files.py +275 -185
  21. hpcflow/sdk/core/commands.py +111 -107
  22. hpcflow/sdk/core/element.py +724 -503
  23. hpcflow/sdk/core/enums.py +192 -0
  24. hpcflow/sdk/core/environment.py +74 -93
  25. hpcflow/sdk/core/errors.py +398 -51
  26. hpcflow/sdk/core/json_like.py +540 -272
  27. hpcflow/sdk/core/loop.py +380 -334
  28. hpcflow/sdk/core/loop_cache.py +160 -43
  29. hpcflow/sdk/core/object_list.py +370 -207
  30. hpcflow/sdk/core/parameters.py +728 -600
  31. hpcflow/sdk/core/rule.py +59 -41
  32. hpcflow/sdk/core/run_dir_files.py +33 -22
  33. hpcflow/sdk/core/task.py +1546 -1325
  34. hpcflow/sdk/core/task_schema.py +240 -196
  35. hpcflow/sdk/core/test_utils.py +126 -88
  36. hpcflow/sdk/core/types.py +387 -0
  37. hpcflow/sdk/core/utils.py +410 -305
  38. hpcflow/sdk/core/validation.py +82 -9
  39. hpcflow/sdk/core/workflow.py +1192 -1028
  40. hpcflow/sdk/core/zarr_io.py +98 -137
  41. hpcflow/sdk/demo/cli.py +46 -33
  42. hpcflow/sdk/helper/cli.py +18 -16
  43. hpcflow/sdk/helper/helper.py +75 -63
  44. hpcflow/sdk/helper/watcher.py +61 -28
  45. hpcflow/sdk/log.py +83 -59
  46. hpcflow/sdk/persistence/__init__.py +8 -31
  47. hpcflow/sdk/persistence/base.py +988 -586
  48. hpcflow/sdk/persistence/defaults.py +6 -0
  49. hpcflow/sdk/persistence/discovery.py +38 -0
  50. hpcflow/sdk/persistence/json.py +408 -153
  51. hpcflow/sdk/persistence/pending.py +158 -123
  52. hpcflow/sdk/persistence/store_resource.py +37 -22
  53. hpcflow/sdk/persistence/types.py +307 -0
  54. hpcflow/sdk/persistence/utils.py +14 -11
  55. hpcflow/sdk/persistence/zarr.py +477 -420
  56. hpcflow/sdk/runtime.py +44 -41
  57. hpcflow/sdk/submission/{jobscript_info.py → enums.py} +39 -12
  58. hpcflow/sdk/submission/jobscript.py +444 -404
  59. hpcflow/sdk/submission/schedulers/__init__.py +133 -40
  60. hpcflow/sdk/submission/schedulers/direct.py +97 -71
  61. hpcflow/sdk/submission/schedulers/sge.py +132 -126
  62. hpcflow/sdk/submission/schedulers/slurm.py +263 -268
  63. hpcflow/sdk/submission/schedulers/utils.py +7 -2
  64. hpcflow/sdk/submission/shells/__init__.py +14 -15
  65. hpcflow/sdk/submission/shells/base.py +102 -29
  66. hpcflow/sdk/submission/shells/bash.py +72 -55
  67. hpcflow/sdk/submission/shells/os_version.py +31 -30
  68. hpcflow/sdk/submission/shells/powershell.py +37 -29
  69. hpcflow/sdk/submission/submission.py +203 -257
  70. hpcflow/sdk/submission/types.py +143 -0
  71. hpcflow/sdk/typing.py +163 -12
  72. hpcflow/tests/conftest.py +8 -6
  73. hpcflow/tests/schedulers/slurm/test_slurm_submission.py +5 -2
  74. hpcflow/tests/scripts/test_main_scripts.py +60 -30
  75. hpcflow/tests/shells/wsl/test_wsl_submission.py +6 -4
  76. hpcflow/tests/unit/test_action.py +86 -75
  77. hpcflow/tests/unit/test_action_rule.py +9 -4
  78. hpcflow/tests/unit/test_app.py +13 -6
  79. hpcflow/tests/unit/test_cli.py +1 -1
  80. hpcflow/tests/unit/test_command.py +71 -54
  81. hpcflow/tests/unit/test_config.py +20 -15
  82. hpcflow/tests/unit/test_config_file.py +21 -18
  83. hpcflow/tests/unit/test_element.py +58 -62
  84. hpcflow/tests/unit/test_element_iteration.py +3 -1
  85. hpcflow/tests/unit/test_element_set.py +29 -19
  86. hpcflow/tests/unit/test_group.py +4 -2
  87. hpcflow/tests/unit/test_input_source.py +116 -93
  88. hpcflow/tests/unit/test_input_value.py +29 -24
  89. hpcflow/tests/unit/test_json_like.py +44 -35
  90. hpcflow/tests/unit/test_loop.py +65 -58
  91. hpcflow/tests/unit/test_object_list.py +17 -12
  92. hpcflow/tests/unit/test_parameter.py +16 -7
  93. hpcflow/tests/unit/test_persistence.py +48 -35
  94. hpcflow/tests/unit/test_resources.py +20 -18
  95. hpcflow/tests/unit/test_run.py +8 -3
  96. hpcflow/tests/unit/test_runtime.py +2 -1
  97. hpcflow/tests/unit/test_schema_input.py +23 -15
  98. hpcflow/tests/unit/test_shell.py +3 -2
  99. hpcflow/tests/unit/test_slurm.py +8 -7
  100. hpcflow/tests/unit/test_submission.py +39 -19
  101. hpcflow/tests/unit/test_task.py +352 -247
  102. hpcflow/tests/unit/test_task_schema.py +33 -20
  103. hpcflow/tests/unit/test_utils.py +9 -11
  104. hpcflow/tests/unit/test_value_sequence.py +15 -12
  105. hpcflow/tests/unit/test_workflow.py +114 -83
  106. hpcflow/tests/unit/test_workflow_template.py +0 -1
  107. hpcflow/tests/workflows/test_jobscript.py +2 -1
  108. hpcflow/tests/workflows/test_workflows.py +18 -13
  109. {hpcflow_new2-0.2.0a188.dist-info → hpcflow_new2-0.2.0a190.dist-info}/METADATA +2 -1
  110. hpcflow_new2-0.2.0a190.dist-info/RECORD +165 -0
  111. hpcflow/sdk/core/parallel.py +0 -21
  112. hpcflow_new2-0.2.0a188.dist-info/RECORD +0 -158
  113. {hpcflow_new2-0.2.0a188.dist-info → hpcflow_new2-0.2.0a190.dist-info}/LICENSE +0 -0
  114. {hpcflow_new2-0.2.0a188.dist-info → hpcflow_new2-0.2.0a190.dist-info}/WHEEL +0 -0
  115. {hpcflow_new2-0.2.0a188.dist-info → hpcflow_new2-0.2.0a190.dist-info}/entry_points.txt +0 -0
@@ -2,14 +2,22 @@
2
2
  Models of data stores as resources.
3
3
  """
4
4
 
5
+ from __future__ import annotations
5
6
  from abc import ABC, abstractmethod
6
7
  import copy
7
8
  import json
8
- from pathlib import Path
9
- from typing import Callable, Union
9
+ from typing import Any, Callable, TYPE_CHECKING
10
10
 
11
11
  from hpcflow.sdk.core.utils import get_md5_hash
12
12
 
13
+ if TYPE_CHECKING:
14
+ from collections.abc import Mapping
15
+ from logging import Logger
16
+ from pathlib import Path
17
+ import zarr # type: ignore
18
+ from fsspec import AbstractFileSystem # type: ignore
19
+ from ..app import BaseApp
20
+
13
21
 
14
22
  class StoreResource(ABC):
15
23
  """Class to represent a persistent resource within which store data lives.
@@ -25,37 +33,37 @@ class StoreResource(ABC):
25
33
  The store name.
26
34
  """
27
35
 
28
- def __init__(self, app, name: str) -> None:
29
- self.app = app
36
+ def __init__(self, app: BaseApp, name: str) -> None:
37
+ self._app = app
30
38
  self.name = name
31
- self.data = {"read": None, "update": None}
39
+ self.data: dict[str, Any] = {"read": None, "update": None}
32
40
  self.hash = None
33
41
 
34
42
  def __repr__(self) -> str:
35
43
  return f"{self.__class__.__name__}(name={self.name!r})"
36
44
 
37
45
  @property
38
- def logger(self):
46
+ def logger(self) -> Logger:
39
47
  """
40
48
  The logger.
41
49
  """
42
- return self.app.persistence_logger
50
+ return self._app.persistence_logger
43
51
 
44
52
  @abstractmethod
45
- def _load(self):
53
+ def _load(self) -> Any:
46
54
  pass
47
55
 
48
56
  @abstractmethod
49
- def _dump(self, data):
57
+ def _dump(self, data: dict | list):
50
58
  pass
51
59
 
52
- def open(self, action):
60
+ def open(self, action: str):
53
61
  """
54
62
  Open the store.
55
63
 
56
64
  Parameters
57
65
  ----------
58
- action: str
66
+ action:
59
67
  What we are opening the store for; typically either ``read`` or ``update``.
60
68
  """
61
69
  if action == "read":
@@ -80,17 +88,17 @@ class StoreResource(ABC):
80
88
  self.data[action] = data
81
89
 
82
90
  try:
83
- self.hash = get_md5_hash(data)
91
+ self.hash = get_md5_hash(data) # type: ignore
84
92
  except Exception:
85
93
  pass
86
94
 
87
- def close(self, action):
95
+ def close(self, action: str):
88
96
  """
89
97
  Close the store for a particular action.
90
98
 
91
99
  Parameters
92
100
  ----------
93
- action: str
101
+ action:
94
102
  What we are closing the store for.
95
103
  Should match a previous call to :py:meth:`close`.
96
104
  """
@@ -135,24 +143,31 @@ class JSONFileStoreResource(StoreResource):
135
143
  The filesystem that the JSON file resides within.
136
144
  """
137
145
 
138
- def __init__(self, app, name: str, filename: str, path: Union[str, Path], fs):
146
+ def __init__(
147
+ self,
148
+ app: BaseApp,
149
+ name: str,
150
+ filename: str,
151
+ path: str | Path,
152
+ fs: AbstractFileSystem,
153
+ ):
139
154
  self.filename = filename
140
155
  self.path = path
141
156
  self.fs = fs
142
157
  super().__init__(app, name)
143
158
 
144
159
  @property
145
- def _full_path(self):
160
+ def _full_path(self) -> str:
146
161
  return f"{self.path}/{self.filename}"
147
162
 
148
- def _load(self):
163
+ def _load(self) -> Any:
149
164
  self.logger.debug(f"{self!r}: loading JSON from file.")
150
165
  with self.fs.open(self._full_path, mode="rt") as fp:
151
166
  return json.load(fp)
152
167
 
153
- def _dump(self, data):
168
+ def _dump(self, data: Mapping | list):
154
169
  self.logger.debug(f"{self!r}: dumping JSON to file")
155
- if "runs" in data:
170
+ if isinstance(data, dict) and "runs" in data:
156
171
  self.logger.debug(f"...runs: {data['runs']}")
157
172
  with self.fs.open(self._full_path, mode="wt") as fp:
158
173
  json.dump(data, fp, indent=2)
@@ -172,16 +187,16 @@ class ZarrAttrsStoreResource(StoreResource):
172
187
  How to actually perform an open on the underlying resource.
173
188
  """
174
189
 
175
- def __init__(self, app, name: str, open_call: Callable):
190
+ def __init__(self, app: BaseApp, name: str, open_call: Callable[..., zarr.Group]):
176
191
  self.open_call = open_call
177
192
  super().__init__(app, name)
178
193
 
179
- def _load(self):
194
+ def _load(self) -> Any:
180
195
  self.logger.debug(f"{self!r}: loading Zarr attributes.")
181
196
  item = self.open_call(mode="r")
182
197
  return copy.deepcopy(item.attrs.asdict())
183
198
 
184
- def _dump(self, data):
199
+ def _dump(self, data: dict | list):
185
200
  self.logger.debug(f"{self!r}: dumping Zarr attributes.")
186
201
  item = self.open_call(mode="r+")
187
202
  item.attrs.put(data)
@@ -0,0 +1,307 @@
1
+ """
2
+ Types used in type-checking the persistence subsystem.
3
+ """
4
+ from __future__ import annotations
5
+ from typing import Any, Generic, TypeVar, TYPE_CHECKING
6
+ from typing_extensions import TypedDict, NotRequired, TypeAlias
7
+
8
+ if TYPE_CHECKING:
9
+ from .base import StoreTask, StoreElement, StoreElementIter, StoreEAR, StoreParameter
10
+ from ..core.json_like import JSONDocument
11
+ from ..core.parameters import ParameterValue
12
+ from ..core.types import IterableParam
13
+ from ..typing import DataIndex, ParamSource
14
+
15
+ #: Bound type variable: :class:`StoreTask`.
16
+ AnySTask = TypeVar("AnySTask", bound="StoreTask")
17
+ #: Bound type variable: :class:`StoreElement`.
18
+ AnySElement = TypeVar("AnySElement", bound="StoreElement")
19
+ #: Bound type variable: :class:`StoreElementITer`.
20
+ AnySElementIter = TypeVar("AnySElementIter", bound="StoreElementIter")
21
+ #: Bound type variable: :class:`StoreEAR`.
22
+ AnySEAR = TypeVar("AnySEAR", bound="StoreEAR")
23
+ #: Bound type variable: :class:`StoreParameter`.
24
+ AnySParameter = TypeVar("AnySParameter", bound="StoreParameter")
25
+ #: Type of possible stored parameters.
26
+ ParameterTypes: TypeAlias = (
27
+ "ParameterValue | list | tuple | set | dict | int | float | str | None | Any"
28
+ )
29
+
30
+
31
+ class File(TypedDict):
32
+ """
33
+ Descriptor for file metadata.
34
+ """
35
+
36
+ #: Whether to store the contents.
37
+ store_contents: bool
38
+ #: The path to the file.
39
+ path: str
40
+
41
+
42
+ class FileDescriptor(TypedDict):
43
+ """
44
+ Descriptor for file metadata.
45
+ """
46
+
47
+ #: Whether this is an input file.
48
+ is_input: bool
49
+ #: Whether to store the contents.
50
+ store_contents: bool
51
+ #: Where the file will go.
52
+ dst_path: str
53
+ #: The path to the file.
54
+ path: str | None
55
+ #: Whether to delete the file after processing.
56
+ clean_up: bool
57
+ # The contents of the file.
58
+ contents: NotRequired[str]
59
+
60
+
61
+ class LoopDescriptor(TypedDict):
62
+ """
63
+ Descriptor for loop metadata.
64
+ """
65
+
66
+ #: The parameters iterated over by the loop.
67
+ iterable_parameters: dict[str, IterableParam]
68
+ #: The template data from which the loop was created.
69
+ loop_template: NotRequired[dict[str, Any]]
70
+ #: The number of iterations generated by a loop.
71
+ #: Note that the type is really ``list[tuple[tuple[int, ...], int]]``
72
+ #: but the persistence implementations don't handle tuples usefully.
73
+ num_added_iterations: list[list[list[int] | int]]
74
+ #: The parents of the loop.
75
+ parents: list[str]
76
+
77
+
78
+ # TODO: This type looks familiar...
79
+ class StoreCreationInfo(TypedDict):
80
+ """
81
+ Information about the creation of the persistence store.
82
+ """
83
+
84
+ #: Information about the application.
85
+ app_info: dict[str, Any]
86
+ #: When the persistence store was created.
87
+ create_time: str
88
+ #: The unique identifier for for the store/workflow.
89
+ id: str
90
+
91
+
92
+ class ElemMeta(TypedDict):
93
+ """
94
+ The kwargs supported for a StoreElement.
95
+ """
96
+
97
+ #: The ID of the element.
98
+ id_: int
99
+ #: The index of the element.
100
+ index: int
101
+ #: The index of the element in its element set.
102
+ es_idx: int
103
+ #: The indices of the element in the sequences that contain it.
104
+ seq_idx: dict[str, int]
105
+ #: The indices of the element's sources.
106
+ src_idx: dict[str, int]
107
+ #: The task associated with the element.
108
+ task_ID: int
109
+ #: The iteration IDs.
110
+ iteration_IDs: list[int]
111
+
112
+
113
+ class IterMeta(TypedDict):
114
+ """
115
+ The kwargs supported for a StoreElementIter.
116
+ """
117
+
118
+ #: The index of the iteration.
119
+ data_idx: DataIndex
120
+ #: The EARs associated with the iteration.
121
+ EAR_IDs: dict[int, list[int]]
122
+ #: Whether the EARs have been initialised.
123
+ EARs_initialised: bool
124
+ #: The ID of the element.
125
+ element_ID: int
126
+ #: The loops containing the iteration.
127
+ loop_idx: dict[str, int]
128
+ #: The schema parameters being iterated over.
129
+ schema_parameters: list[str]
130
+
131
+
132
+ class RunMeta(TypedDict):
133
+ """
134
+ The kwargs supported for StoreEAR.
135
+ """
136
+
137
+ #: The ID of the EAR.
138
+ id_: int
139
+ #: The ID of the element iteration containing the EAR.
140
+ elem_iter_ID: int
141
+ #: The index of the action that generated the EAR.
142
+ action_idx: int
143
+ #: The commands that the EAR will run.
144
+ commands_idx: list[int]
145
+ #: The data handled by the EAR.
146
+ data_idx: DataIndex
147
+ #: Metadata about the EAR.
148
+ metadata: Metadata | None
149
+ #: When the EAR ended, if known.
150
+ end_time: NotRequired[str | None]
151
+ #: The exit code of the EAR, if known.
152
+ exit_code: int | None
153
+ #: When the EAR started, if known.
154
+ start_time: NotRequired[str | None]
155
+ #: Working directory snapshot at start.
156
+ snapshot_start: dict[str, Any] | None
157
+ #: Working directory snapshot at end.
158
+ snapshot_end: dict[str, Any] | None
159
+ #: The index of the EAR in the submissions.
160
+ submission_idx: int | None
161
+ #: Where the EAR is set to run.
162
+ run_hostname: str | None
163
+ #: Whether the EAR succeeded, if known.
164
+ success: bool | None
165
+ #: Whether the EAR was skipped.
166
+ skip: bool
167
+
168
+
169
+ class TaskMeta(TypedDict):
170
+ """
171
+ Information about a task.
172
+ """
173
+
174
+ #: The ID of the task.
175
+ id_: int
176
+ #: The index of the task in the workflow.
177
+ index: int
178
+ #: The elements in the task.
179
+ element_IDs: list[int]
180
+
181
+
182
+ class TemplateMeta(TypedDict): # FIXME: Incomplete, see WorkflowTemplate
183
+ """
184
+ Metadata about a workflow template.
185
+ """
186
+
187
+ #: Descriptors for loops.
188
+ loops: list[dict]
189
+ #: Descriptors for tasks.
190
+ tasks: list[dict]
191
+
192
+
193
+ class Metadata(TypedDict):
194
+ """
195
+ Workflow metadata.
196
+ """
197
+
198
+ #: Information about the store's creation.
199
+ creation_info: NotRequired[StoreCreationInfo]
200
+ #: Elements in the workflow.
201
+ elements: NotRequired[list[ElemMeta]]
202
+ #: Iterations in the workflow.
203
+ iters: NotRequired[list[IterMeta]]
204
+ #: Loops in the workflow.
205
+ loops: NotRequired[list[LoopDescriptor]]
206
+ #: The name of the workflow.
207
+ name: NotRequired[str]
208
+ #: The number of added tasks.
209
+ num_added_tasks: NotRequired[int]
210
+ #: The replacement workflow, if any.
211
+ replaced_workflow: NotRequired[str]
212
+ #: Element Action Runs in the workflow.
213
+ runs: NotRequired[list[RunMeta]]
214
+ #: Tasks in the workflow.
215
+ tasks: NotRequired[list[TaskMeta]]
216
+ #: The template that generated the workflow.
217
+ template: NotRequired[TemplateMeta]
218
+ #: Custom template components used.
219
+ template_components: NotRequired[dict[str, Any]]
220
+ #: Format for timestamps.
221
+ ts_fmt: NotRequired[str]
222
+ #: Format for timestamps used in naming.
223
+ ts_name_fmt: NotRequired[str]
224
+
225
+
226
+ class TypeLookup(TypedDict, total=False):
227
+ """
228
+ Information for looking up the type of a parameter.
229
+
230
+ Note
231
+ ----
232
+ Not a total typed dictionary.
233
+ """
234
+
235
+ #: Tuples involving the parameter.
236
+ tuples: list[list[int]]
237
+ #: Sets involving the parameter.
238
+ sets: list[list[int]]
239
+ #: Arrays involving the parameter.
240
+ arrays: list[list[list[int] | int]]
241
+ #: Masked arrays involving the parameter.
242
+ masked_arrays: list[list[int | list[int]]]
243
+
244
+
245
+ class EncodedStoreParameter(TypedDict):
246
+ """
247
+ The encoding of a :class:`StoreParameter`.
248
+ """
249
+
250
+ #: The parameter data.
251
+ data: Any
252
+ #: Information for looking up the type.
253
+ type_lookup: TypeLookup
254
+
255
+
256
+ class PersistenceCache(
257
+ TypedDict, Generic[AnySTask, AnySElement, AnySElementIter, AnySEAR, AnySParameter]
258
+ ):
259
+ """
260
+ Cache used internally by the persistence engine.
261
+ """
262
+
263
+ #: Tasks.
264
+ tasks: dict[int, AnySTask]
265
+ #: Elements.
266
+ elements: dict[int, AnySElement]
267
+ #: Element iterations.
268
+ element_iters: dict[int, AnySElementIter]
269
+ #: Element action runs.
270
+ EARs: dict[int, AnySEAR]
271
+ #: Parameter sources.
272
+ param_sources: dict[int, ParamSource]
273
+ #: Number of tasks.
274
+ num_tasks: int | None
275
+ #: Parameters.
276
+ parameters: dict[int, AnySParameter]
277
+ #: Number of element action runs.
278
+ num_EARs: int | None
279
+
280
+
281
+ class ZarrAttrsDict(TypedDict):
282
+ """
283
+ Zarr workflow attributes descriptor.
284
+ """
285
+
286
+ #: Workflow name.
287
+ name: str
288
+ #: Timestamp format.
289
+ ts_fmt: str
290
+ #: Timestamp format for names.
291
+ ts_name_fmt: str
292
+ #: Information about the creation of the workflow and persistent store.
293
+ creation_info: StoreCreationInfo
294
+ #: The template used to build the workflow.
295
+ template: TemplateMeta
296
+ #: Custom components used to build the workflow.
297
+ template_components: dict[str, Any]
298
+ #: Number of tasks added.
299
+ num_added_tasks: int
300
+ #: Tasks in the workflow.
301
+ tasks: list[dict[str, Any]]
302
+ #: Loops in the workflow.
303
+ loops: list[dict[str, Any]]
304
+ #: Submissions by the workflow.
305
+ submissions: list[JSONDocument]
306
+ #: Replacement workflow, if any.
307
+ replaced_workflow: NotRequired[str]
@@ -2,12 +2,22 @@
2
2
  Miscellaneous persistence-related helpers.
3
3
  """
4
4
 
5
+ from __future__ import annotations
5
6
  from getpass import getpass
7
+ from typing import TYPE_CHECKING
6
8
 
7
9
  from hpcflow.sdk.core.errors import WorkflowNotFoundError
8
10
 
11
+ if TYPE_CHECKING:
12
+ from typing import Callable, TypeVar
13
+ from fsspec import AbstractFileSystem # type: ignore
9
14
 
10
- def ask_pw_on_auth_exc(f, *args, add_pw_to=None, **kwargs):
15
+ T = TypeVar("T")
16
+
17
+
18
+ def ask_pw_on_auth_exc(
19
+ f: Callable[..., T], *args, add_pw_to: str | None = None, **kwargs
20
+ ) -> tuple[T, str | None]:
11
21
  """
12
22
  Run the given function on the given arguments and add a password if the function
13
23
  fails with an SSHException.
@@ -24,19 +34,14 @@ def ask_pw_on_auth_exc(f, *args, add_pw_to=None, **kwargs):
24
34
  if not add_pw_to:
25
35
  kwargs["password"] = pw
26
36
  else:
27
- kwargs[add_pw_to]["password"] = pw
37
+ kwargs[add_pw_to] = {**kwargs[add_pw_to], "password": pw}
28
38
 
29
39
  out = f(*args, **kwargs)
30
40
 
31
- if not add_pw_to:
32
- del kwargs["password"]
33
- else:
34
- del kwargs[add_pw_to]["password"]
35
-
36
41
  return out, pw
37
42
 
38
43
 
39
- def infer_store(path: str, fs) -> str:
44
+ def infer_store(path: str, fs: AbstractFileSystem) -> str:
40
45
  """Identify the store type using the path and file system parsed by fsspec.
41
46
 
42
47
  Parameters
@@ -63,8 +68,6 @@ def infer_store(path: str, fs) -> str:
63
68
  elif fs.glob(f"{path}/metadata.json"):
64
69
  store_fmt = "json"
65
70
  else:
66
- raise WorkflowNotFoundError(
67
- f"Cannot infer a store format at path {path!r} with file system {fs!r}."
68
- )
71
+ raise WorkflowNotFoundError(path, fs)
69
72
 
70
73
  return store_fmt