amberflow 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- amberflow/__init__.py +6 -0
- amberflow/artifacts/__init__.py +8 -0
- amberflow/artifacts/alchemical.py +388 -0
- amberflow/artifacts/artifactsdecorators.py +146 -0
- amberflow/artifacts/baseartifact.py +631 -0
- amberflow/artifacts/md.py +85 -0
- amberflow/artifacts/md.pyi +61 -0
- amberflow/artifacts/mdparameters.py +509 -0
- amberflow/artifacts/structure.py +363 -0
- amberflow/artifacts/topology.py +149 -0
- amberflow/artifacts/trajectory.py +94 -0
- amberflow/checkpoint.py +42 -0
- amberflow/data/cpptraj/autoimage +1 -0
- amberflow/data/cpptraj/go +1 -0
- amberflow/data/cpptraj/parm +1 -0
- amberflow/data/cpptraj/rms_first +1 -0
- amberflow/data/cpptraj/rms_reference_out +1 -0
- amberflow/data/cpptraj/strip +1 -0
- amberflow/data/cpptraj/trajin +1 -0
- amberflow/data/cpptraj/trajout +1 -0
- amberflow/data/mdin/mbar_lambda +1 -0
- amberflow/data/mdin/md +11 -0
- amberflow/data/mdin/md_icfe +55 -0
- amberflow/data/mdin/md_icfe_nmropt +58 -0
- amberflow/data/mdin/md_icfe_nmropt_varying +62 -0
- amberflow/data/mdin/md_icfe_varying +58 -0
- amberflow/data/mdin/md_restrained +14 -0
- amberflow/data/mdin/md_restrained_varying +17 -0
- amberflow/data/mdin/min +8 -0
- amberflow/data/mdin/min_icfe +35 -0
- amberflow/data/mdin/min_icfe_nmropt +39 -0
- amberflow/data/mdin/min_restrained +8 -0
- amberflow/data/mdin/ti_exch_mbar +59 -0
- amberflow/data/mdin/ti_exch_mbar_nmropt +62 -0
- amberflow/data/tleap/leaprc +6 -0
- amberflow/data/tleap/load_nonstandard +2 -0
- amberflow/data/tleap/load_pdb +1 -0
- amberflow/data/tleap/neutralize +2 -0
- amberflow/data/tleap/orthorhombic +5 -0
- amberflow/data/tleap/quit +1 -0
- amberflow/data/tleap/salt +5 -0
- amberflow/data/tleap/save_amberparm +1 -0
- amberflow/data/tleap/solvatebox +1 -0
- amberflow/data/tleap/solvateoct +1 -0
- amberflow/data/tleap/truncated_octahedron +5 -0
- amberflow/flows/__init__.py +3 -0
- amberflow/flows/flows.py +30 -0
- amberflow/pipeline.py +602 -0
- amberflow/primitives/__init__.py +6 -0
- amberflow/primitives/command.py +443 -0
- amberflow/primitives/executor.py +344 -0
- amberflow/primitives/log.py +65 -0
- amberflow/primitives/primitives.py +445 -0
- amberflow/primitives/units.py +19 -0
- amberflow/primitives/utils.py +105 -0
- amberflow/schedulers/__init__.py +1 -0
- amberflow/schedulers/schedulers.py +368 -0
- amberflow/worknodes/__init__.py +13 -0
- amberflow/worknodes/afeanalysis.py +441 -0
- amberflow/worknodes/alchemical.py +725 -0
- amberflow/worknodes/analysis.py +15 -0
- amberflow/worknodes/baseworknode.py +498 -0
- amberflow/worknodes/buildbox.py +685 -0
- amberflow/worknodes/cpptraj.py +288 -0
- amberflow/worknodes/generatetopology.py +334 -0
- amberflow/worknodes/helpernodes.py +334 -0
- amberflow/worknodes/md.py +325 -0
- amberflow/worknodes/mdtools.py +427 -0
- amberflow/worknodes/parametrization.py +90 -0
- amberflow/worknodes/worknodesdecorators.py +404 -0
- amberflow/worknodes/worknodeutils.py +116 -0
- amberflow-0.2.2.dist-info/METADATA +92 -0
- amberflow-0.2.2.dist-info/RECORD +75 -0
- amberflow-0.2.2.dist-info/WHEEL +4 -0
- amberflow-0.2.2.dist-info/licenses/LICENSE.txt +9 -0
amberflow/artifacts/alchemical.py
ADDED
|
@@ -0,0 +1,388 @@
|
|
|
1
|
+
from pathlib import Path
|
|
2
|
+
from typing import Optional, Sequence, Union, SupportsIndex, Iterator
|
|
3
|
+
|
|
4
|
+
import numpy as np
|
|
5
|
+
|
|
6
|
+
from amberflow.artifacts import fileartifact, BaseArtifact, BaseArtifactFile, BaseArtifactDir
|
|
7
|
+
from amberflow.primitives import filepath_t, FileHandle
|
|
8
|
+
|
|
9
|
+
# Public names exported by this module, grouped by kind of artifact.
__all__ = (
    # lambda schedule
    "LambdaSchedule",
    # per-lambda state-file collections (bases, restarts, trajectories, mdouts)
    "BaseStatesFile",
    "BaseRestartStatesFile",
    "BaseTrajectoryStatesFile",
    "ComplexProteinLigandRestartStates",
    "BinderLigandRestartStates",
    "ComplexNucleicAcidLigandRestartStates",
    "ComplexProteinLigandTrajectoryStatesNC",
    "BinderLigandTrajectoryStatesNC",
    "ComplexNucleicAcidLigandTrajectoryStates",
    "MdoutStates",
    # edgembar outputs
    "EdgeMBARhtml",
    "EdgeMBARxml",
    # .dat directories
    "Datdir",
    "TargetDatdir",
    "ReferenceDatdir",
)
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
class LambdaSchedule(BaseArtifact):
    """
    A class representing a schedule of lambda values for alchemical transformations.

    Lambda values are used in alchemical free energy calculations to define the
    intermediate states between two end states.

    The values are kept in a NumPy array (``self.lambdas``), rounded to
    ``decimals`` places. ``self.decimals`` is also the precision used by
    :meth:`get_formatted` and :meth:`formatted`.
    """

    tags: tuple = ("",)

    def __init__(self, lambdas: Sequence[float], decimals: int = 5) -> None:
        """
        Initialize a LambdaSchedule with a sequence of lambda values.

        Parameters
        ----------
        lambdas : Sequence[float]
            A sequence of lambda values between 0 and 1
        decimals : int, optional
            Number of decimal places to round lambda values to, by default 5.
            ``0`` disables rounding; values are then formatted with 20 decimals.
        """
        self.lambdas = np.array(lambdas)
        # Just setting a default large number (used when rounding is disabled).
        self.decimals = 20
        if decimals != 0:
            # Honour the requested precision (it was previously hard-coded to 5).
            self.lambdas = np.round(self.lambdas, decimals=decimals)
            self.decimals = decimals

    def __getitem__(self, index: Union[SupportsIndex, slice]) -> Union[float, "LambdaSchedule"]:
        """Return a single lambda as a float, or a new LambdaSchedule for a slice."""
        if isinstance(index, slice):
            # The values are already rounded: build without re-rounding, then keep
            # this schedule's formatting precision instead of the default one.
            sliced = type(self)(self.lambdas[index], decimals=0)
            sliced.decimals = self.decimals
            return sliced
        return float(self.lambdas[index])

    def __iter__(self) -> Iterator[float]:
        """Yield each lambda value as a Python float."""
        for x in self.lambdas:
            yield float(x)

    def get_formatted(self, index: SupportsIndex) -> str:
        """Return the lambda at ``index`` formatted with ``self.decimals`` decimals."""
        return f"{self[index]:.{self.decimals}f}"

    def formatted(self) -> Iterator[str]:
        """Yield every lambda formatted with ``self.decimals`` decimals."""
        for x in self.lambdas:
            yield f"{x:.{self.decimals}f}"

    def __contains__(self, item: float) -> bool:
        """Exact membership test against the (rounded) lambda values."""
        return item in self.lambdas

    def __repr__(self) -> str:
        return f"{type(self).__name__}(lambdas={self.lambdas.tolist()})"

    def __eq__(self, other: object) -> bool:
        """Two schedules are equal when their lambda arrays are element-wise equal."""
        if not isinstance(other, type(self)):
            # Return the sentinel (not call it) so Python can try the reflected
            # comparison; calling NotImplemented(...) raised a TypeError.
            return NotImplemented
        return bool(np.array_equal(self.lambdas, other.lambdas))

    def __len__(self) -> int:
        return len(self.lambdas)
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
class BaseStatesFile(BaseArtifact):
    """
    Base class for managing collections of state files in alchemical simulations.

    This class provides functionality to handle multiple files corresponding to different
    lambda states in an alchemical transformation. ``self.states`` maps each lambda
    value (float) to its file. The dict-like accessors (``keys``/``values``/``items``/
    ``get``/``[]``) delegate to it. Note that iterating the object yields the files
    (the values), not the lambda keys.
    """

    def __init__(
        self, filepath: filepath_t, *args, prefix, suffix, lambdas: Optional[Sequence[float]] = None, **kwargs
    ) -> None:
        """
        Initialize a BaseStatesFile object.

        Parameters
        ----------
        filepath : filepath_t
            Path to a representative file in the collection
        prefix : str
            Prefix for the filenames
        suffix : str
            Suffix (extension) for the filenames
        lambdas : Optional[Sequence[float]], optional
            Sequence of lambda values, by default None.

            When given, one file per lambda is expected next to ``filepath``. Its name
            is ``filepath``'s stem with the last ``_``-separated token replaced by the
            lambda value, plus ``suffix``.

            When None, the directory of ``filepath`` is scanned for
            ``<prefix>_*<suffix>`` files. Only those whose last ``_``-separated stem
            token parses as a number in [0, 1] are kept; other matches are skipped.
        """
        self.filepath = Path(FileHandle(filepath))
        self.name: str = self.filepath.stem[len(prefix) + 1:]
        super()._check_file(self.filepath, prefix, suffix)
        # From here on use the normalized Path: the raw `filepath` argument may not
        # be a Path (e.g. a str), which has no .stem / .with_name / .parent.
        if lambdas is not None:
            stem_wo_lambda = "_".join(self.filepath.stem.split("_")[:-1])
            # use FileHandle to ensure the files exist
            self.states = {
                float(clambda): FileHandle(self.filepath.with_name(f"{stem_wo_lambda}_{clambda}{suffix}"))
                for clambda in lambdas
            }
        else:
            # An empty prefix is turned into a wildcard for the glob.
            glob_prefix = prefix if prefix != "" else "*"
            self.states = {}
            for state in sorted(self.filepath.parent.glob(f"{glob_prefix}_*{suffix}")):
                try:
                    clambda = float(state.stem.split("_")[-1])
                except ValueError:
                    # Matches the pattern, but the last token is not a lambda value.
                    continue
                if 0 <= clambda <= 1:
                    self.states[clambda] = state

        self.nlambdas = len(self.states)

    def __getitem__(self, key: float) -> filepath_t:
        return self.states[key]

    def __iter__(self) -> Iterator:
        # Iterates the files (values), not the lambda keys.
        return iter(self.states.values())

    def __len__(self) -> int:
        return len(self.states)

    def __str__(self) -> str:
        return f"{self.__class__.__name__}(states={self.states})"

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(states={self.states})"

    def values(self):
        return self.states.values()

    def keys(self):
        return self.states.keys()

    def items(self):
        return self.states.items()

    def get(self, key, default=None):
        return self.states.get(key, default)

    @staticmethod
    def get_name(filepath: Path, prefix: str) -> str:
        """Return ``filepath``'s stem with its first ``len(prefix)`` characters removed."""
        # NOTE(review): unlike self.name (set in __init__), this keeps the "_"
        # that follows the prefix. Confirm which form callers expect.
        return filepath.stem[len(prefix):]

    def __fspath__(self) -> Union[str, bytes, Path]:
        return str(self.filepath)
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
class BaseRestartStatesFile(BaseStatesFile):
    """Common base for per-lambda restart-file collections (adds no behaviour of its own)."""
|
|
168
|
+
|
|
169
|
+
|
|
170
|
+
class BaseTrajectoryStatesFile(BaseStatesFile):
    """Common base for per-lambda trajectory-file collections (adds no behaviour of its own)."""
|
|
172
|
+
|
|
173
|
+
|
|
174
|
+
@fileartifact
class ComplexProteinLigandRestartStates(BaseRestartStatesFile):
    """Per-lambda ``.rst7`` restart files with the ``complex`` prefix (protein-ligand system)."""

    prefix: str = "complex"
    suffix: str = ".rst7"
    tags: tuple = ("protein", "ligand", "alchemical")

    def __init__(
        self,
        filepath: filepath_t,
        *args,
        lambdas: Optional[Sequence[float]] = None,
        **kwargs,
    ) -> None:
        # Hand this class's file-naming scheme to the shared BaseStatesFile logic.
        naming = {"prefix": self.prefix, "suffix": self.suffix}
        super().__init__(filepath, *args, **naming, lambdas=lambdas, **kwargs)
|
|
188
|
+
|
|
189
|
+
|
|
190
|
+
@fileartifact
class BinderLigandRestartStates(BaseRestartStatesFile):
    """Per-lambda ``.rst7`` restart files with the ``binder`` prefix."""

    prefix: str = "binder"
    suffix: str = ".rst7"
    tags: tuple = ("ligand", "alchemical")

    def __init__(
        self,
        filepath: filepath_t,
        *args,
        lambdas: Optional[Sequence[float]] = None,
        **kwargs,
    ) -> None:
        cls = type(self)
        # Naming scheme comes from the class attributes above.
        super().__init__(filepath, *args, prefix=cls.prefix, suffix=cls.suffix, lambdas=lambdas, **kwargs)
|
|
204
|
+
|
|
205
|
+
|
|
206
|
+
@fileartifact
class ComplexNucleicAcidLigandRestartStates(BaseRestartStatesFile):
    """Per-lambda ``.rst7`` restart files with the ``complex`` prefix (nucleic-acid/ligand system)."""

    prefix: str = "complex"
    suffix: str = ".rst7"
    tags: tuple = ("nucleicacid", "ligand", "alchemical")

    def __init__(
        self,
        filepath: filepath_t,
        *args,
        lambdas: Optional[Sequence[float]] = None,
        **kwargs,
    ) -> None:
        naming = {"prefix": self.prefix, "suffix": self.suffix}
        super().__init__(filepath, *args, **naming, lambdas=lambdas, **kwargs)
|
|
220
|
+
|
|
221
|
+
|
|
222
|
+
@fileartifact
class ComplexProteinLigandTrajectoryStatesNC(BaseTrajectoryStatesFile):
    """Per-lambda ``.nc`` trajectory files with the ``complex`` prefix (protein-ligand system)."""

    prefix: str = "complex"
    suffix: str = ".nc"
    tags: tuple = ("protein", "ligand", "alchemical")

    def __init__(
        self,
        filepath: filepath_t,
        *args,
        lambdas: Optional[Sequence[float]] = None,
        **kwargs,
    ) -> None:
        cls = type(self)
        super().__init__(filepath, *args, prefix=cls.prefix, suffix=cls.suffix, lambdas=lambdas, **kwargs)
|
|
236
|
+
|
|
237
|
+
|
|
238
|
+
@fileartifact
class BinderLigandTrajectoryStatesNC(BaseTrajectoryStatesFile):
    """Per-lambda ``.nc`` trajectory files with the ``binder`` prefix."""

    prefix: str = "binder"
    suffix: str = ".nc"
    # Same tags as BinderLigandRestartStates. The previous value repeated "ligand"
    # (a copy-paste slip from the ("protein", "ligand", ...) variants).
    tags: tuple = ("ligand", "alchemical")

    def __init__(
        self,
        filepath: filepath_t,
        *args,
        lambdas: Optional[Sequence[float]] = None,
        **kwargs,
    ) -> None:
        super().__init__(filepath, *args, prefix=self.prefix, suffix=self.suffix, lambdas=lambdas, **kwargs)
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
@fileartifact
class ComplexNucleicAcidLigandTrajectoryStates(BaseTrajectoryStatesFile):
    """Per-lambda ``.nc`` trajectory files with the ``complex`` prefix (nucleic-acid/ligand system)."""

    prefix: str = "complex"
    suffix: str = ".nc"
    tags: tuple = ("nucleicacid", "ligand", "alchemical")

    def __init__(
        self,
        filepath: filepath_t,
        *args,
        lambdas: Optional[Sequence[float]] = None,
        **kwargs,
    ) -> None:
        naming = {"prefix": self.prefix, "suffix": self.suffix}
        super().__init__(filepath, *args, **naming, lambdas=lambdas, **kwargs)
|
|
268
|
+
|
|
269
|
+
|
|
270
|
+
@fileartifact
class MdoutStates(BaseStatesFile):
    """Per-lambda ``.mdout`` files; the empty prefix matches any leading name."""

    prefix: str = ""
    suffix: str = ".mdout"
    tags: tuple[str] = ("alchemical",)

    def __init__(self, filepath: filepath_t, *args, **kwargs) -> None:
        cls = type(self)
        super().__init__(filepath, *args, prefix=cls.prefix, suffix=cls.suffix, **kwargs)
|
|
278
|
+
|
|
279
|
+
|
|
280
|
+
@fileartifact
class EdgeMBARhtml(BaseArtifactFile):
    """Single ``.html`` file artifact (presumably the edgembar HTML report)."""

    prefix: str = ""
    suffix: str = ".html"
    tags: tuple[str] = ("",)

    def __init__(self, filepath: filepath_t, *args, **kwargs) -> None:
        naming = {"prefix": self.prefix, "suffix": self.suffix}
        super().__init__(filepath, *args, **naming, **kwargs)
|
|
288
|
+
|
|
289
|
+
|
|
290
|
+
@fileartifact
class EdgeMBARxml(BaseArtifactFile):
    """Single ``.xml`` file artifact (presumably the edgembar XML input/output)."""

    prefix: str = ""
    suffix: str = ".xml"
    tags: tuple[str] = ("",)

    def __init__(self, filepath: filepath_t, *args, **kwargs) -> None:
        naming = {"prefix": self.prefix, "suffix": self.suffix}
        super().__init__(filepath, *args, **naming, **kwargs)
|
|
298
|
+
|
|
299
|
+
@fileartifact
class Datdir(BaseArtifactDir):
    """
    Directory holding the per-window ``dvdl*.dat`` / ``efep*.dat`` files of one run.

    The directory is ``<filepath>/<edge>/<environment>/<stage>/t<trial>``. An empty
    ``stage`` component is dropped by :class:`pathlib.Path`, which
    :meth:`get_path_template` mirrors.
    """

    prefix: str = ""
    suffix: str = ""
    tags: tuple[str] = ("",)

    def __init__(
        self,
        filepath: filepath_t,
        *args,
        edge: str = "sysname",
        environment: str = "com",
        stage: str = "vdw",
        trial: int = 1,
        states: Optional[Sequence[float]] = None,
        makedir=False,
        **kwargs,
    ) -> None:
        """
        Parameters
        ----------
        filepath : filepath_t
            Root directory; kept as ``self.parent_filepath``.
        edge, environment, stage : str
            Path components below the root.
        trial : int
            Trial number; the last path component is ``t<trial>``.
        states : Optional[Sequence[float]]
            Lambda states. Their count is used by :meth:`is_valid` when
            ``nlambdas`` is not given.
        makedir : bool
            Create the directory (with parents) if it does not exist.
        **kwargs
            Only ``prefix`` and ``suffix`` are read. They default to the class
            attributes, and other keys are ignored.
        """
        self.edge = edge
        self.environment = environment
        self.stage = stage
        self.trial = f"t{trial}"
        self.states = tuple(states) if states is not None else None
        self.parent_filepath = Path(filepath)
        new_filepath = Path(filepath, self.edge, self.environment, self.stage, self.trial)
        if makedir:
            new_filepath.mkdir(parents=True, exist_ok=True)
        # Subclasses pass prefix/suffix explicitly. When Datdir is used directly,
        # fall back to its class attributes instead of forwarding None.
        super().__init__(
            new_filepath,
            *args,
            prefix=kwargs.get("prefix", self.prefix),
            suffix=kwargs.get("suffix", self.suffix),
        )

    def is_valid(self, nlambdas: Optional[int] = None, remlog: bool = True, mbar: bool = False) -> bool:
        """
        Check if the directory contains the expected number of lambda states and required files.

        Parameters
        ----------
        nlambdas : int, optional
            Number of lambda states expected. Defaults to ``len(self.states)``.
        remlog : bool, optional
            Whether to require at least one .yaml file (default: True).
        mbar : bool, optional
            Whether to use MBAR file counting logic (default: False). Set it to True only if you're sure your
            run had valid MBAR Energy values for all windows against all windows.

        Returns
        -------
        bool
            True if the directory is valid, False otherwise.

        Raises
        ------
        ValueError
            If ``nlambdas`` is None and ``self.states`` is unset or empty.
        """
        if nlambdas is None:
            if not (states := getattr(self, "states", False)):
                raise ValueError("The `states` attribute must be set before calling is_valid() without `nlambdas`.")
            nlambdas = len(states)
        # First, check that the directory actually exists
        if not self.filepath.is_dir():
            return False
        # Path objects are always truthy, so any() is True iff the glob matches something.
        if remlog and not any(self.filepath.glob("*.yaml")):
            return False
        # One dvdl file per lambda window is required.
        dvdl_count = sum(1 for _ in self.filepath.glob("dvdl*.dat"))
        if dvdl_count < nlambdas:
            return False
        efep_count = sum(1 for _ in self.filepath.glob("efep*.dat"))
        if mbar:
            # Every window evaluated against every window.
            efep_expected = nlambdas * nlambdas
        else:
            # BAR: 3 dat files for each window, except the first and last windows which have 2 each.
            efep_expected = (nlambdas - 2) * 3 + 4
        return efep_count >= efep_expected

    def get_path_template(self) -> str:
        """
        Return a ``str.format`` template for efep ``.dat`` paths under ``parent_filepath``.

        Placeholders are ``{edge}``, ``{env}``, ``{trial}``, ``{traj}``, ``{ene}``, plus
        ``{stage}`` unless the stage is empty.
        """
        if self.stage == "":
            return str(self.parent_filepath / r"{edge}/{env}/{trial}/efep_{traj}_{ene}.dat")
        return str(self.parent_filepath / r"{edge}/{env}/{stage}/{trial}/efep_{traj}_{ene}.dat")
|
|
363
|
+
|
|
364
|
+
@fileartifact
class TargetDatdir(Datdir):
    """Datdir tagged ``target``. It also records the first ``boresch*.yaml`` found in the directory."""

    prefix: str = ""
    suffix: str = ""
    tags: tuple[str] = ("target",)

    def __init__(
        self,
        filepath: filepath_t,
        *args,
        edge: str = "sysname",
        environment: str = "aq",
        stage: str = "vdw",
        trial: int = 1,
        states: Optional[Sequence[float]] = None,
        makedir=False,
        **kwargs,
    ) -> None:
        super().__init__(
            filepath,
            *args,
            edge=edge,
            environment=environment,
            stage=stage,
            trial=trial,
            states=states,
            makedir=makedir,
            prefix=self.prefix,
            suffix=self.suffix,
            **kwargs,
        )
        # First matching restraint file, or None if the directory has none.
        self.boresch_restraints = next(iter(self.filepath.glob("boresch*.yaml")), None)
|
|
378
|
+
|
|
379
|
+
@fileartifact
class ReferenceDatdir(Datdir):
    """Datdir tagged ``reference`` (default environment ``com``)."""

    prefix: str = ""
    suffix: str = ""
    tags: tuple[str] = ("reference",)

    def __init__(
        self,
        filepath: filepath_t,
        *args,
        edge: str = "sysname",
        environment: str = "com",
        stage: str = "vdw",
        trial: int = 1,
        states: Optional[Sequence[float]] = None,
        makedir=False,
        **kwargs,
    ) -> None:
        layout = dict(edge=edge, environment=environment, stage=stage, trial=trial, states=states, makedir=makedir)
        naming = dict(prefix=self.prefix, suffix=self.suffix)
        super().__init__(filepath, *args, **layout, **naming, **kwargs)
|
|
@@ -0,0 +1,146 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
import inspect
|
|
3
|
+
import shutil
|
|
4
|
+
from pathlib import Path
|
|
5
|
+
|
|
6
|
+
from amberflow.artifacts import BaseArtifact
|
|
7
|
+
from amberflow.primitives import ArtifactError
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def userartifact(cls: type) -> type:
    """Class decorator that flags ``cls`` as a user artifact.

    Sets the class attribute ``_is_user_artifact = True`` and returns the same class.
    """
    cls._is_user_artifact = True
    return cls
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
def _copy_to(self, dest: Path):
    """
    Copies associated files to a destination and returns a new artifact instance.

    Every instance attribute whose name starts with ``filepath`` is copied into
    ``dest``. Each distinct source file is copied only once. The returned deep copy
    has those attributes pointing at the copies; ``self`` is left unchanged.

    Args:
        self (BaseArtifact): The artifact instance.
        dest (Path | str): The destination directory path.

    Returns:
        A new instance of the artifact class with updated filepath attributes.

    Raises:
        TypeError: If ``dest`` is not an existing directory.
        ArtifactError: If an attribute starting with 'filepath' is not a Path object.
        FileNotFoundError: If an attribute starting with 'filepath' points
            to a non-existent file.
        RuntimeError: If file copying fails (e.g., permissions); the OSError is chained.
    """
    dest = Path(dest)
    if not dest.is_dir():
        raise TypeError(f"{dest} is not a valid directory")

    copied = {}  # original Path -> destination Path (so each file is copied once)
    new_paths = {}  # attribute name -> destination Path
    for attr_name, original_path in self.__dict__.items():
        if not attr_name.startswith("filepath"):
            continue
        if not isinstance(original_path, Path):
            raise ArtifactError(
                f"Attribute '{attr_name}' starts with 'filepath' but is not a Path object (type: {type(original_path)})."
            )
        if not original_path.is_file():
            raise FileNotFoundError(f"Attribute '{attr_name}' points to non-existent file: {original_path}")

        if original_path not in copied:
            destination_path = dest / original_path.name
            try:
                shutil.copy2(original_path, destination_path)
            except OSError as e:
                err_msg = f"Failed to copy {original_path} to {destination_path}: {e}"
                raise RuntimeError(err_msg) from e
            copied[original_path] = destination_path
        new_paths[attr_name] = copied[original_path]

    new_instance = copy.deepcopy(self)
    # Update the filepath attributes on the new instance
    for attr_name, new_filepath in new_paths.items():
        setattr(new_instance, attr_name, new_filepath)
    return new_instance


def _change_base_dir(self, old_base: Path, new_base: Path) -> Path:
    """
    Changes the base directory of the artifact's filepath.

    Args:
        self (BaseArtifact): The artifact instance.
        old_base (Path): The current base directory path.
        new_base (Path): The new base directory path.

    Returns:
        The updated ``self.filepath``.

    Raises:
        ValueError: If ``self.filepath`` is not located under ``old_base``
            (raised by ``Path.relative_to``).
    """
    self.filepath = Path(new_base, Path(self.filepath).relative_to(old_base))
    return self.filepath


def _check_init_signature(cls: type) -> None:
    """Raise TypeError unless ``cls.__init__`` is ``(self, filepath, <optional params>...)``."""
    init_method = getattr(cls, "__init__", None)
    if not init_method or not inspect.isfunction(init_method):
        raise TypeError(f"Class {cls.__name__} must have an __init__ method.")

    try:
        params = list(inspect.signature(init_method).parameters.values())
    except ValueError as err:
        # Handle cases like built-in types that might not have inspectable signatures
        raise TypeError(f"Could not inspect the signature of {cls.__name__}.__init__") from err

    if len(params) < 2:
        raise TypeError(
            f"{cls.__name__}.__init__ must accept at least one positional argument named 'filepath' after 'self'."
        )

    filepath_param = params[1]
    if filepath_param.name != "filepath":
        raise TypeError(
            f"{cls.__name__}.__init__ first argument after 'self' must be named 'filepath', not '{filepath_param.name}'."
        )

    # 'filepath' must be required (no default value) and positional/keyword capable
    is_positional_or_kw = filepath_param.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
    has_no_default = filepath_param.default is inspect.Parameter.empty
    if not (is_positional_or_kw and has_no_default):
        raise TypeError(
            f"{cls.__name__}.__init__ parameter 'filepath' must be a required positional argument (no default value). Found kind={filepath_param.kind}, default={filepath_param.default!r}"
        )

    # Every later parameter must have a default unless it is *args / **kwargs
    # (keyword-only parameters without a default are rejected too).
    for i, param in enumerate(params[2:], start=2):
        if param.default is inspect.Parameter.empty and param.kind not in (
            inspect.Parameter.VAR_POSITIONAL,
            inspect.Parameter.VAR_KEYWORD,
        ):
            raise TypeError(
                f"{cls.__name__}.__init__ parameter '{param.name}' (at index {i}) "
                f"must be optional (have a default, be keyword-only, *args, or **kwargs). "
                f"Found kind={param.kind} with no default value."
            )


def fileartifact(cls: type) -> type:
    """
    Class decorator to enforce proper file-based Artifacts:
        - The class must have attributes prefix, suffix, and tags (own or inherited).
        - The class must inherit from BaseArtifact.
        - __init__ must have a specific signature:
            - First argument after self must be 'filepath' (required, positional).
            - All subsequent parameters must have defaults, or be *args / **kwargs.
    And to help the user it injects:
        - copy_to(self, dest): copy the 'filepath*' files into ``dest`` and return
          an updated deep copy of the artifact.
        - change_base_dir(self, old_base, new_base): re-root ``self.filepath``.

    Raises:
        ArtifactError: If a required class attribute or the BaseArtifact base is missing.
        TypeError: If ``__init__`` does not have the required signature.
    """
    for attr, kind in (("prefix", "str"), ("suffix", "str"), ("tags", "tuple")):
        # hasattr() follows the MRO, so inherited attributes count.
        if not hasattr(cls, attr):
            raise ArtifactError(f"User artifact {cls} must have a {attr} ({kind}) attribute")
    if not issubclass(cls, BaseArtifact):
        raise ArtifactError(f"User artifact {cls} must inherit from BaseArtifact")

    _check_init_signature(cls)

    # Inject the helper methods shared by all file based artifacts.
    setattr(cls, "copy_to", _copy_to)
    setattr(cls, "change_base_dir", _change_base_dir)
    return cls
|