westpa 2022.13 (westpa-2022.13-cp312-cp312-manylinux2014_aarch64.manylinux_2_17_aarch64.manylinux_2_28_aarch64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (162):
  1. westpa/__init__.py +14 -0
  2. westpa/_version.py +21 -0
  3. westpa/analysis/__init__.py +5 -0
  4. westpa/analysis/core.py +749 -0
  5. westpa/analysis/statistics.py +27 -0
  6. westpa/analysis/trajectories.py +369 -0
  7. westpa/cli/__init__.py +0 -0
  8. westpa/cli/core/__init__.py +0 -0
  9. westpa/cli/core/w_fork.py +152 -0
  10. westpa/cli/core/w_init.py +230 -0
  11. westpa/cli/core/w_run.py +77 -0
  12. westpa/cli/core/w_states.py +212 -0
  13. westpa/cli/core/w_succ.py +99 -0
  14. westpa/cli/core/w_truncate.py +68 -0
  15. westpa/cli/tools/__init__.py +0 -0
  16. westpa/cli/tools/ploterr.py +506 -0
  17. westpa/cli/tools/plothist.py +706 -0
  18. westpa/cli/tools/w_assign.py +597 -0
  19. westpa/cli/tools/w_bins.py +166 -0
  20. westpa/cli/tools/w_crawl.py +119 -0
  21. westpa/cli/tools/w_direct.py +557 -0
  22. westpa/cli/tools/w_dumpsegs.py +94 -0
  23. westpa/cli/tools/w_eddist.py +506 -0
  24. westpa/cli/tools/w_fluxanl.py +376 -0
  25. westpa/cli/tools/w_ipa.py +832 -0
  26. westpa/cli/tools/w_kinavg.py +127 -0
  27. westpa/cli/tools/w_kinetics.py +96 -0
  28. westpa/cli/tools/w_multi_west.py +414 -0
  29. westpa/cli/tools/w_ntop.py +213 -0
  30. westpa/cli/tools/w_pdist.py +515 -0
  31. westpa/cli/tools/w_postanalysis_matrix.py +82 -0
  32. westpa/cli/tools/w_postanalysis_reweight.py +53 -0
  33. westpa/cli/tools/w_red.py +491 -0
  34. westpa/cli/tools/w_reweight.py +780 -0
  35. westpa/cli/tools/w_select.py +226 -0
  36. westpa/cli/tools/w_stateprobs.py +111 -0
  37. westpa/cli/tools/w_timings.py +113 -0
  38. westpa/cli/tools/w_trace.py +599 -0
  39. westpa/core/__init__.py +0 -0
  40. westpa/core/_rc.py +673 -0
  41. westpa/core/binning/__init__.py +55 -0
  42. westpa/core/binning/_assign.c +36018 -0
  43. westpa/core/binning/_assign.cpython-312-aarch64-linux-gnu.so +0 -0
  44. westpa/core/binning/_assign.pyx +370 -0
  45. westpa/core/binning/assign.py +454 -0
  46. westpa/core/binning/binless.py +96 -0
  47. westpa/core/binning/binless_driver.py +54 -0
  48. westpa/core/binning/binless_manager.py +189 -0
  49. westpa/core/binning/bins.py +47 -0
  50. westpa/core/binning/mab.py +506 -0
  51. westpa/core/binning/mab_driver.py +54 -0
  52. westpa/core/binning/mab_manager.py +197 -0
  53. westpa/core/data_manager.py +1761 -0
  54. westpa/core/extloader.py +74 -0
  55. westpa/core/h5io.py +1079 -0
  56. westpa/core/kinetics/__init__.py +24 -0
  57. westpa/core/kinetics/_kinetics.c +45174 -0
  58. westpa/core/kinetics/_kinetics.cpython-312-aarch64-linux-gnu.so +0 -0
  59. westpa/core/kinetics/_kinetics.pyx +815 -0
  60. westpa/core/kinetics/events.py +147 -0
  61. westpa/core/kinetics/matrates.py +156 -0
  62. westpa/core/kinetics/rate_averaging.py +266 -0
  63. westpa/core/progress.py +218 -0
  64. westpa/core/propagators/__init__.py +54 -0
  65. westpa/core/propagators/executable.py +592 -0
  66. westpa/core/propagators/loaders.py +196 -0
  67. westpa/core/reweight/__init__.py +14 -0
  68. westpa/core/reweight/_reweight.c +36899 -0
  69. westpa/core/reweight/_reweight.cpython-312-aarch64-linux-gnu.so +0 -0
  70. westpa/core/reweight/_reweight.pyx +439 -0
  71. westpa/core/reweight/matrix.py +126 -0
  72. westpa/core/segment.py +119 -0
  73. westpa/core/sim_manager.py +839 -0
  74. westpa/core/states.py +359 -0
  75. westpa/core/systems.py +93 -0
  76. westpa/core/textio.py +74 -0
  77. westpa/core/trajectory.py +603 -0
  78. westpa/core/we_driver.py +910 -0
  79. westpa/core/wm_ops.py +43 -0
  80. westpa/core/yamlcfg.py +298 -0
  81. westpa/fasthist/__init__.py +34 -0
  82. westpa/fasthist/_fasthist.c +38755 -0
  83. westpa/fasthist/_fasthist.cpython-312-aarch64-linux-gnu.so +0 -0
  84. westpa/fasthist/_fasthist.pyx +222 -0
  85. westpa/mclib/__init__.py +271 -0
  86. westpa/mclib/__main__.py +28 -0
  87. westpa/mclib/_mclib.c +34610 -0
  88. westpa/mclib/_mclib.cpython-312-aarch64-linux-gnu.so +0 -0
  89. westpa/mclib/_mclib.pyx +226 -0
  90. westpa/oldtools/__init__.py +4 -0
  91. westpa/oldtools/aframe/__init__.py +35 -0
  92. westpa/oldtools/aframe/atool.py +75 -0
  93. westpa/oldtools/aframe/base_mixin.py +26 -0
  94. westpa/oldtools/aframe/binning.py +178 -0
  95. westpa/oldtools/aframe/data_reader.py +560 -0
  96. westpa/oldtools/aframe/iter_range.py +200 -0
  97. westpa/oldtools/aframe/kinetics.py +117 -0
  98. westpa/oldtools/aframe/mcbs.py +153 -0
  99. westpa/oldtools/aframe/output.py +39 -0
  100. westpa/oldtools/aframe/plotting.py +88 -0
  101. westpa/oldtools/aframe/trajwalker.py +126 -0
  102. westpa/oldtools/aframe/transitions.py +469 -0
  103. westpa/oldtools/cmds/__init__.py +0 -0
  104. westpa/oldtools/cmds/w_ttimes.py +361 -0
  105. westpa/oldtools/files.py +34 -0
  106. westpa/oldtools/miscfn.py +23 -0
  107. westpa/oldtools/stats/__init__.py +4 -0
  108. westpa/oldtools/stats/accumulator.py +35 -0
  109. westpa/oldtools/stats/edfs.py +129 -0
  110. westpa/oldtools/stats/mcbs.py +96 -0
  111. westpa/tools/__init__.py +33 -0
  112. westpa/tools/binning.py +472 -0
  113. westpa/tools/core.py +340 -0
  114. westpa/tools/data_reader.py +159 -0
  115. westpa/tools/dtypes.py +31 -0
  116. westpa/tools/iter_range.py +198 -0
  117. westpa/tools/kinetics_tool.py +343 -0
  118. westpa/tools/plot.py +283 -0
  119. westpa/tools/progress.py +17 -0
  120. westpa/tools/selected_segs.py +154 -0
  121. westpa/tools/wipi.py +751 -0
  122. westpa/trajtree/__init__.py +4 -0
  123. westpa/trajtree/_trajtree.c +17829 -0
  124. westpa/trajtree/_trajtree.cpython-312-aarch64-linux-gnu.so +0 -0
  125. westpa/trajtree/_trajtree.pyx +130 -0
  126. westpa/trajtree/trajtree.py +117 -0
  127. westpa/westext/__init__.py +0 -0
  128. westpa/westext/adaptvoronoi/__init__.py +3 -0
  129. westpa/westext/adaptvoronoi/adaptVor_driver.py +214 -0
  130. westpa/westext/hamsm_restarting/__init__.py +3 -0
  131. westpa/westext/hamsm_restarting/example_overrides.py +35 -0
  132. westpa/westext/hamsm_restarting/restart_driver.py +1165 -0
  133. westpa/westext/stringmethod/__init__.py +11 -0
  134. westpa/westext/stringmethod/fourier_fitting.py +69 -0
  135. westpa/westext/stringmethod/string_driver.py +253 -0
  136. westpa/westext/stringmethod/string_method.py +306 -0
  137. westpa/westext/weed/BinCluster.py +180 -0
  138. westpa/westext/weed/ProbAdjustEquil.py +100 -0
  139. westpa/westext/weed/UncertMath.py +247 -0
  140. westpa/westext/weed/__init__.py +10 -0
  141. westpa/westext/weed/weed_driver.py +192 -0
  142. westpa/westext/wess/ProbAdjust.py +101 -0
  143. westpa/westext/wess/__init__.py +6 -0
  144. westpa/westext/wess/wess_driver.py +217 -0
  145. westpa/work_managers/__init__.py +57 -0
  146. westpa/work_managers/core.py +396 -0
  147. westpa/work_managers/environment.py +134 -0
  148. westpa/work_managers/mpi.py +318 -0
  149. westpa/work_managers/processes.py +201 -0
  150. westpa/work_managers/serial.py +28 -0
  151. westpa/work_managers/threads.py +79 -0
  152. westpa/work_managers/zeromq/__init__.py +20 -0
  153. westpa/work_managers/zeromq/core.py +635 -0
  154. westpa/work_managers/zeromq/node.py +131 -0
  155. westpa/work_managers/zeromq/work_manager.py +526 -0
  156. westpa/work_managers/zeromq/worker.py +320 -0
  157. westpa-2022.13.dist-info/METADATA +179 -0
  158. westpa-2022.13.dist-info/RECORD +162 -0
  159. westpa-2022.13.dist-info/WHEEL +7 -0
  160. westpa-2022.13.dist-info/entry_points.txt +30 -0
  161. westpa-2022.13.dist-info/licenses/LICENSE +21 -0
  162. westpa-2022.13.dist-info/top_level.txt +1 -0
@@ -0,0 +1,603 @@
1
+ import numpy as np
2
+ import os
3
+ from functools import cache
4
+
5
+ from mdtraj import Trajectory
6
+
7
+
8
def parseResidueAtoms(residue, map):
    """Parse all atoms from residue. Taken from OpenMM 8.2.0.

    For every ``<Atom>`` child of ``residue``, each of its attribute values (the
    canonical ``name`` and any aliases) is mapped to the canonical atom name and
    written into ``map`` in place. Existing entries with the same key are overwritten.
    Returns None.
    """
    for atom_elem in residue.findall('Atom'):
        canonical_name = atom_elem.attrib['name']
        # Every attribute value (including 'name' itself) is an accepted spelling.
        map.update({alias: canonical_name for alias in atom_elem.attrib.values()})
15
+
16
+
17
@cache
def loadNameReplacementTables():
    """Load the list of atom and residue name replacements. Taken from OpenMM 8.2.0.

    Reads ``data/pdbNames.xml`` from the ``westpa`` package (a table mapping many
    residue/atom spellings onto PDB 3.0 names) and returns two dicts:

    - residue-name replacements: alias -> canonical residue name (built from each
      ``<Residue>``'s ``name`` and ``alt*`` attributes);
    - atom-name replacements: canonical residue name -> {atom alias -> canonical atom name},
      seeded from the generic 'All' table, or the 'Protein'/'Nucleic' table selected by the
      residue's ``type`` attribute (each already merged with 'All').

    The result is cached after the first call, so callers share the same dict objects.

    NOTE(review): confirm that ``westpa/data/pdbNames.xml`` is shipped with the package
    (it does not appear in the 2022.13 wheel's file list); otherwise opening it fails.
    """

    # Imported lazily because nothing else in the module needs them.
    try:
        from importlib.resources import files
    except ImportError:
        from importlib_resources import files

    import xml.etree.ElementTree as etree
    from copy import copy

    root = etree.parse(files('westpa') / 'data/pdbNames.xml').getroot()
    residue_elems = root.findall('Residue')

    # Generic atom tables, keyed by the pseudo-residue that defines them.
    generic_tables = {'All': {}, 'Protein': {}, 'Nucleic': {}}
    for elem in residue_elems:
        target = generic_tables.get(elem.attrib['name'])
        if target is not None:
            parseResidueAtoms(elem, target)

    all_atoms = generic_tables['All']
    protein_atoms = generic_tables['Protein']
    nucleic_atoms = generic_tables['Nucleic']
    # Entries from 'All' take precedence over the class-specific ones.
    protein_atoms.update(all_atoms)
    nucleic_atoms.update(all_atoms)
    base_by_type = {'Protein': protein_atoms, 'Nucleic': nucleic_atoms}

    residue_replacements = {}
    atom_replacements = {}
    for elem in residue_elems:
        canonical = elem.attrib['name']
        for attr_key, alias in elem.attrib.items():
            if attr_key == 'name' or attr_key.startswith('alt'):
                residue_replacements[alias] = canonical

        # Missing or unrecognized 'type' falls back to the generic 'All' table.
        residue_atoms = copy(base_by_type.get(elem.attrib.get('type'), all_atoms))
        parseResidueAtoms(elem, residue_atoms)
        atom_replacements[canonical] = residue_atoms

    return residue_replacements, atom_replacements
66
+
67
+
68
def convert_mdanalysis_top_to_mdtraj(universe):
    """Convert a MDAnalysis Universe object's topology to a ``mdtraj.Topology`` object.

    One mdtraj chain is created per MDAnalysis segment, followed by residues and atoms.
    Residue names are normalized with the residue-name table from
    ``loadNameReplacementTables``. Atom names are normalized with the atom-name table of
    the (normalized) residue. Names without a replacement entry are kept unchanged.

    Bonds are copied from ``universe.bonds`` when the universe has bond data. Otherwise
    (``NoDataError``, e.g. a PDB topology) ``Topology.create_standard_bonds`` is used.

    Parameters
    ----------
    universe : MDAnalysis.Universe
        The universe whose topology is converted.

    Returns
    -------
    mdtraj.Topology
        The converted topology.
    """

    from mdtraj import Topology
    from mdtraj.core.element import get_by_symbol
    from MDAnalysis.exceptions import NoDataError

    top = Topology()  # Empty topology object
    residueNameReplacements, atomNameReplacements = loadNameReplacementTables()

    # Add in all the chains (called segments in MDAnalysis)
    for _ in universe.segments:
        top.add_chain()
    all_chains = list(top.chains)

    # Add in all the residues
    for residue in universe.residues:
        resname = residueNameReplacements.get(residue.resname, residue.resname)
        top.add_residue(name=resname, chain=all_chains[residue.segindex], resSeq=residue.resid)
    all_residues = list(top.residues)

    # Add in all the atoms. The atom-name tables are keyed by the *normalized* residue name
    # (and map atom aliases to canonical names), so normalize the residue name first.
    for atom, res_index in zip(universe.atoms, universe.atoms.resindices):
        resname = residueNameReplacements.get(atom.resname, atom.resname)
        atom_table = atomNameReplacements.get(resname, {})
        atomname = atom_table.get(atom.name, atom.name)
        top.add_atom(name=atomname, element=get_by_symbol(atom.element), residue=all_residues[res_index])
    all_atoms = list(top.atoms)

    # Add in all the bonds. Depending on the topology type (e.g., pdb), there might not be bond information.
    try:
        bond_indices = universe.bonds._bix
    except NoDataError:
        top.create_standard_bonds()
    else:
        for first, second in bond_indices:
            top.add_bond(all_atoms[first], all_atoms[second])

    return top
114
+
115
+
116
class WESTTrajectory(Trajectory):
    """A subclass of ``mdtraj.Trajectory`` that contains the trajectory of atom coordinates with
    pointers denoting the iteration number and segment index of each frame.

    Besides the ``mdtraj.Trajectory`` data, each frame carries:

    - ``iter_labels``: iteration index of the frame;
    - ``seg_labels``: segment index of the frame;
    - ``parent_ids``: parent id of the frame;
    - ``pcoords``: progress coordinates, stored as an ``(n_frames, ndim)`` array.

    Labels given as ``None`` default to ``0`` and ``pcoords`` given as ``None`` defaults to
    ``0.0``; scalars are broadcast to every frame.
    """

    def __init__(
        self,
        coordinates,
        topology=None,
        time=None,
        iter_labels=None,
        seg_labels=None,
        pcoords=None,
        parent_ids=None,
        unitcell_lengths=None,
        unitcell_angles=None,
    ):
        """Create a WESTTrajectory.

        ``coordinates`` is either an xyz array or an existing ``mdtraj.Trajectory``. In the
        latter case its xyz, topology, time and unit cell are reused, unless topology/time/unit
        cell are passed explicitly.
        """
        if isinstance(coordinates, Trajectory):
            xyz = coordinates.xyz
            topology = coordinates.topology if topology is None else topology
            time = coordinates.time if time is None else time
            unitcell_lengths = coordinates.unitcell_lengths if unitcell_lengths is None else unitcell_lengths
            unitcell_angles = coordinates.unitcell_angles if unitcell_angles is None else unitcell_angles
        else:
            xyz = coordinates

        super().__init__(xyz, topology, time, unitcell_lengths, unitcell_angles)
        # Cached (max_iter, max_seg, max_frames_per_segment) used by tuple-key ``slice``;
        # reset whenever iteration/segment labels change.
        self._shape = None
        self.iter_labels = iter_labels
        self.seg_labels = seg_labels
        self.pcoords = pcoords
        self.parent_ids = parent_ids

    def _string_summary_basic(self):
        """Basic summary of WESTTrajectory in string form."""

        unitcell_str = 'and unitcells' if self._have_unitcell else 'without unitcells'
        value = "%s with %d frames, %d atoms, %d residues, %s" % (
            self.__class__.__name__,
            self.n_frames,
            self.n_atoms,
            self.n_residues,
            unitcell_str,
        )
        return value

    def _check_labels(self, value):
        """Validate a per-frame label array.

        ``None`` becomes ``0``. Scalars are broadcast to an int array of length ``n_frames``.
        Array-likes (lists, tuples, arrays) must have shape ``(n_frames,)``.

        Raises
        ------
        ValueError
            If an array-like value does not have shape ``(n_frames,)``.
        """
        if value is None:
            value = 0
        elif not np.isscalar(value):
            value = np.asarray(value)

        if np.isscalar(value):
            value = np.array([value] * self.n_frames, dtype=int)
        elif value.shape != (self.n_frames,):
            raise ValueError('Wrong shape. Got %s, should be %s' % (value.shape, (self.n_frames,)))

        return value

    def _check_pcoords(self, value):
        """Validate progress coordinates and return them as an ``(n_frames, ndim)`` array.

        ``None`` becomes ``0.0``. A scalar becomes an ``(n_frames, 1)`` float array. A 1-D
        array is treated as a single pcoord vector and tiled over all frames. A 2-D array
        must have ``n_frames`` rows.

        Raises
        ------
        ValueError
            If the array is not 1-D/2-D, or a 2-D array has the wrong number of rows.
        """
        if value is None:
            value = 0.0
        elif not np.isscalar(value):
            value = np.asarray(value)

        if np.isscalar(value):
            value = np.array([(value,)] * self.n_frames, dtype=float)

        if value.ndim == 1:
            value = np.tile(value, (self.n_frames, 1))
        elif value.ndim != 2:
            raise ValueError('pcoords must be a 2-D array')
        elif value.shape[0] != self.n_frames:
            raise ValueError('Wrong length. Got %s, should be %s' % (value.shape[0], self.n_frames))

        return value

    def iter_label_values(self):
        """Yield each distinct iteration label once, in order of first appearance."""
        seen = set()
        for i in self.iter_labels:
            if i not in seen:
                seen.add(i)
                yield i

    def seg_label_values(self, iteration=None):
        """Yield each distinct segment label of ``iteration`` once, in order of first appearance."""
        seen = set()
        for j in self.seg_labels[self.iter_labels == iteration]:
            if j not in seen:
                seen.add(j)
                yield j

    @property
    def label_values(self):
        """Generator of distinct ``(iteration, segment)`` label pairs."""
        for i in self.iter_label_values():
            for j in self.seg_label_values(i):
                yield i, j

    def _iter_blocks(self):
        """Yield ``(iteration, segment, mask)``, where ``mask`` is a boolean array selecting
        that pair's frames."""
        for i, j in self.label_values:
            IandJ = np.logical_and(self.iter_labels == i, self.seg_labels == j)
            yield i, j, IandJ

    @property
    def iter_labels(self):
        """Iteration index corresponding to each frame

        Returns
        -------
        iter_labels : np.ndarray, shape=(n_frames,)
            The iteration index corresponding to each frame
        """

        return self._iters

    @iter_labels.setter
    def iter_labels(self, value):
        """Set the iteration index corresponding to each frame"""

        self._iters = self._check_labels(value)
        self._shape = None

    @property
    def seg_labels(self):
        """
        Segment index corresponding to each frame

        Returns
        -------
        seg_labels : np.ndarray, shape=(n_frames,)
            The segment index corresponding to each frame
        """

        return self._segs

    @seg_labels.setter
    def seg_labels(self, value):
        """Set the segment index corresponding to each frame"""

        self._segs = self._check_labels(value)
        self._shape = None

    @property
    def pcoords(self):
        """Progress coordinates, shape ``(n_frames, ndim)``."""
        return self._pcoords

    @pcoords.setter
    def pcoords(self, value):
        self._pcoords = self._check_pcoords(value)

    @property
    def parent_ids(self):
        """Parent id corresponding to each frame, shape ``(n_frames,)``."""
        return self._parent_ids

    @parent_ids.setter
    def parent_ids(self, value):
        self._parent_ids = self._check_labels(value)

    def join(self, other, check_topology=True, discard_overlapping_frames=False):
        """
        Join two ``Trajectory``s. This overrides ``mdtraj.Trajectory.join``
        so that it also handles WESTPA pointers. Please see
        ``mdtraj.Trajectory.join``'s documentation for more details.

        Frames coming from plain ``mdtraj.Trajectory`` objects (without WESTPA metadata) get
        integer iteration/segment/parent labels of -1 and progress coordinates of 0.0.
        """

        if isinstance(other, Trajectory):
            other = [other]

        new_traj = super().join(
            other, check_topology=check_topology, discard_overlapping_frames=discard_overlapping_frames
        )

        trajectories = [self] + list(other)
        if discard_overlapping_frames:
            # Drop the duplicated last frame of each piece so the metadata lines up with the
            # joined frames (presumably the same 2e-3 overlap test mdtraj applies — confirm).
            for i in range(len(trajectories) - 1):
                x0 = trajectories[i].xyz[-1]
                x1 = trajectories[i + 1].xyz[0]

                if np.all(np.abs(x1 - x0) < 2e-3):
                    trajectories[i] = trajectories[i][:-1]

        n_pcoord_dims = self.pcoords.shape[-1]
        iter_labels = []
        seg_labels = []
        parent_ids = []
        pcoords = []

        for t in trajectories:
            n = len(t)
            # Integer -1 defaults keep the concatenated label arrays integer-typed, so they
            # can still be used as indices by tuple-key ``slice``.
            iter_labels.append(t.iter_labels if hasattr(t, "iter_labels") else np.full(n, -1, dtype=int))
            seg_labels.append(t.seg_labels if hasattr(t, "seg_labels") else np.full(n, -1, dtype=int))
            parent_ids.append(t.parent_ids if hasattr(t, "parent_ids") else np.full(n, -1, dtype=int))
            pcoords.append(
                t.pcoords if hasattr(t, "pcoords") else np.zeros((n, n_pcoord_dims), dtype=float)  # default pcoord: 0.0
            )

        return WESTTrajectory(
            new_traj,
            iter_labels=np.concatenate(iter_labels),
            seg_labels=np.concatenate(seg_labels),
            pcoords=np.concatenate(pcoords),
            parent_ids=np.concatenate(parent_ids),
        )

    def slice(self, key, copy=True):
        """
        Slice the ``Trajectory``. This overrides ``mdtraj.Trajectory.slice``
        so that it also handles WESTPA pointers. Please see
        ``mdtraj.Trajectory.slice``'s documentation for more details.

        A tuple ``key`` indexes an ``(iteration, segment, frame-within-segment)`` grid.
        The selected frames are gathered in grid order, and empty grid cells are skipped.
        """

        if isinstance(key, tuple):
            if self._shape is None:
                max_iter = self.iter_labels.max()
                max_seg = self.seg_labels.max()
                max_n_trajs = max((block.sum() for _, _, block in self._iter_blocks()), default=0)
                self._shape = (max_iter, max_seg, max_n_trajs)
            else:
                max_iter, max_seg, max_n_trajs = self._shape

            # M[i, j, k] = index of the k-th frame of segment j in iteration i (-1 = no frame).
            M = np.full((max_iter + 1, max_seg + 1, max_n_trajs), -1, dtype=int)
            for i, j, block in self._iter_blocks():
                frame_indices = np.flatnonzero(block)
                M[i, j, : len(frame_indices)] = frame_indices

            # np.ravel also turns a single selected cell (numpy scalar) into a 1-D array.
            selected_indices = np.ravel(M[key])
            key = selected_indices[selected_indices != -1]

        iters = self.iter_labels[key]
        segs = self.seg_labels[key]
        pcoords = self.pcoords[key, :]
        parent_ids = self.parent_ids[key]

        traj = super().slice(key, copy)
        traj.iter_labels = iters
        traj.seg_labels = segs
        traj.pcoords = pcoords
        traj.parent_ids = parent_ids

        return traj
394
+
395
+
396
def get_extension(filename):
    """Return the format extension of ``filename`` (e.g. ``'.pdb'``).

    For gzip files the inner extension is kept as well, so ``'seg.pdb.gz'`` gives
    ``'.pdb.gz'``. Returns ``''`` when there is no extension. Case is preserved.
    """
    stem, suffix = os.path.splitext(filename)
    if suffix != '.gz':
        return suffix

    # Gzip archive: prepend whatever extension sits underneath '.gz'.
    inner_suffix = os.path.splitext(stem)[1]
    return inner_suffix + suffix
406
+
407
+
408
def find_top_traj_file(folder, eligible_top, eligible_traj):
    """
    A general (reusable) function for identifying and returning the appropriate
    file names in ``folder`` which are topology and trajectory. Useful when writing custom loaders.
    Note that it's possible that the topology_file and trajectory_file are identical.

    Hidden entries (names starting with '.') and non-files are ignored. Extensions are
    compared lower-cased. A file whose extension is eligible as both topology and
    trajectory is used for both only when the folder holds a single non-hidden entry.

    Parameters
    ----------
    folder : str or os.Pathlike
        A string or Pathlike to the folder to search.

    eligible_top : list of strings
        A list of accepted topology file extensions.

    eligible_traj : list of strings
        A list of accepted trajectory file extensions.


    Returns
    -------
    top_file : str or None
        Path to topology file, or None if none was found.

    traj_file : str
        Path to trajectory file

    Raises
    ------
    ValueError
        If no trajectory file is found.
    """
    candidates = [name for name in os.listdir(folder) if not name.startswith('.')]
    single_entry = len(candidates) < 2

    top_name = None
    traj_name = None

    for name in candidates:
        if not os.path.isfile(os.path.join(folder, name)):
            continue

        ext = get_extension(name).lower()
        is_top = ext in eligible_top
        is_traj = ext in eligible_traj

        # A lone file in a format that carries both topology and coordinates serves as both.
        if single_entry and is_top and is_traj:
            top_name = traj_name = name

        # Topology claims an eligible file first. This assumes the topology file is listed
        # before the trajectory; note that os.listdir does not guarantee any order.
        if is_top and top_name is None:
            top_name = name
        elif is_traj and traj_name is None:
            traj_name = name

        if top_name is not None and traj_name is not None:
            break

    if traj_name is None:
        raise ValueError('trajectory file not found')

    traj_path = os.path.join(folder, traj_name)
    top_path = None if top_name is None else os.path.join(folder, top_name)
    return top_path, traj_path
471
+
472
+
473
@cache
def mdtraj_supported_extensions():
    """Return ``(TOPOLOGY_EXTS, TRAJECTORY_EXTS)``: extension lists understood by mdtraj.

    Side effect: registers the Amber restart extensions ``.rst`` and ``.ncrst`` (loader and
    file object) in mdtraj's ``FormatRegistry``. The function is cached, so this happens once.

    ``TRAJECTORY_EXTS`` is every extension with a registered loader. ``TOPOLOGY_EXTS`` is
    mdtraj's topology extension list without the HDF5-family extensions (``.h5``, ``.hdf5``,
    ``.lh5``). Presumably this makes ``find_top_traj_file`` treat an HDF5 file as the
    trajectory, whose embedded topology mdtraj can read; confirm.
    """
    from mdtraj import FormatRegistry, formats as mdformats
    from mdtraj.core.trajectory import _TOPOLOGY_EXTS

    FormatRegistry.loaders['.rst'] = mdformats.amberrst.load_restrt
    FormatRegistry.fileobjects['.rst'] = mdformats.AmberRestartFile
    FormatRegistry.loaders['.ncrst'] = mdformats.amberrst.load_ncrestrt
    FormatRegistry.fileobjects['.ncrst'] = mdformats.AmberRestartFile

    TRAJECTORY_EXTS = list(FormatRegistry.loaders.keys())

    # Filter rather than list.remove(): remove() raises ValueError if an mdtraj version
    # no longer lists one of these extensions (e.g. '.lh5').
    hdf5_exts = {".h5", ".hdf5", ".lh5"}
    TOPOLOGY_EXTS = [ext for ext in _TOPOLOGY_EXTS if ext not in hdf5_exts]

    return TOPOLOGY_EXTS, TRAJECTORY_EXTS
490
+
491
+
492
@cache
def mdanalysis_supported_extensions():
    """Return ``(TOPOLOGY_EXTS, TRAJECTORY_EXTS)`` for MDAnalysis.

    The extensions are lower-cased, dot-prefixed and sorted. They are built from the
    ``format`` attribute of every class in MDAnalysis' parser and reader registries. A
    ``format`` may be a single string or a sequence (list/tuple) of strings; both are
    handled. The result is cached after the first call.
    """
    import MDAnalysis as mda

    def _extensions(registry):
        # Collect '.<format>' for every format string declared by the registered classes.
        exts = set()
        for cls in registry.values():
            formats = cls.format
            if isinstance(formats, str):
                formats = [formats]
            exts.update(f'.{fmt.lower()}' for fmt in formats)
        return sorted(exts)

    TRAJECTORY_EXTS = _extensions(mda._READERS)
    TOPOLOGY_EXTS = _extensions(mda._PARSERS)

    return TOPOLOGY_EXTS, TRAJECTORY_EXTS
503
+
504
+
505
def load_mdtraj(folder):
    """
    Load the trajectory stored in ``folder`` with ``mdtraj`` and return it as a
    ``mdtraj.Trajectory``.

    The folder must contain a trajectory file with an extension ``mdtraj`` recognizes. A
    topology file is also needed unless the trajectory itself carries topology data
    (e.g., HDF5 format).
    """
    from mdtraj import load as load_traj

    top_exts, traj_exts = mdtraj_supported_extensions()
    top_path, traj_path = find_top_traj_file(folder, top_exts, traj_exts)

    # top_path is None when no topology file was found; mdtraj then relies on the
    # trajectory file itself for the topology.
    return load_traj(traj_path, top=top_path)
523
+
524
+
525
def load_netcdf(folder):
    """
    Load netcdf file from ``folder`` using ``scipy.io`` and return a ``WESTTrajectory``
    (a ``mdtraj.Trajectory`` subclass). The folder should contain an Amber trajectory file
    with extension `.nc`, `.ncdf` or `.ncrst`.

    Note coordinates and box lengths are all divided by 10 to change from Angstroms to nanometers.
    Box information (``cell_lengths``/``cell_angles``) is optional and is read only when present.

    Raises
    ------
    ValueError
        If no file with an accepted extension is found in ``folder``.
    KeyError
        If the file lacks the required ``coordinates`` or ``time`` variable.
    """

    from scipy.io import netcdf_file

    _, traj_file = find_top_traj_file(folder, [], ['.nc', '.ncdf', '.ncrst'])

    # Extracting these datasets
    datasets = {'coordinates': None, 'cell_lengths': None, 'cell_angles': None, 'time': None}
    convert = ['coordinates', 'cell_lengths']  # Length-based datasets that need to be converted from Å to nm
    optional = ['cell_lengths', 'cell_angles']

    with netcdf_file(traj_file) as rootgrp:
        for key in datasets:
            if key in optional and key not in rootgrp.variables:
                continue  # e.g. non-periodic runs: leave the unit cell as None
            # Copy so the array does not keep referencing the file's data after it is closed.
            data = rootgrp.variables[key][()].copy()
            if key in convert:
                data = data / 10  # From Å to nm
            datasets[key] = data

    map_dataset = {
        'coordinates': datasets['coordinates'],
        'unitcell_lengths': datasets['cell_lengths'],
        'unitcell_angles': datasets['cell_angles'],
        'time': datasets['time'],
    }

    return WESTTrajectory(**map_dataset)
559
+
560
+
561
def load_mdanalysis(folder):
    """
    Load a file from ``folder`` using ``MDAnalysis`` and return a ``WESTTrajectory``
    (a ``mdtraj.Trajectory`` subclass). The folder should contain a trajectory and a topology file
    (with a recognizable extension) that is supported by ``MDAnalysis``. The topology file is
    optional if the trajectory file contains topology data (e.g., H5MD format); in that case the
    trajectory file alone is passed to ``MDAnalysis.Universe``.

    Coordinates and box lengths are converted from Angstroms to nanometers.
    """

    import MDAnalysis as mda

    TOPOLOGY_EXTS, TRAJECTORY_EXTS = mdanalysis_supported_extensions()

    top_file, traj_file = find_top_traj_file(folder, TOPOLOGY_EXTS, TRAJECTORY_EXTS)

    # mda.Universe needs a topology argument. Without a separate topology file,
    # let MDAnalysis read the topology from the trajectory file itself.
    if top_file is None:
        u = mda.Universe(traj_file)
    else:
        u = mda.Universe(top_file, traj_file)

    tot_frames = len(u.trajectory)
    coords = np.zeros((tot_frames, len(u.atoms), 3))
    time = np.zeros(tot_frames)

    # Periodic Boundary Conditions.
    # NOTE(review): ``periodic`` may not be defined by every MDAnalysis reader (confirm);
    # if it is missing, fall back to whether the current timestep carries box dimensions.
    periodic = getattr(u.trajectory, 'periodic', None)
    if periodic is None:
        periodic = u.trajectory.ts.dimensions is not None
    cell_lengths = np.zeros((tot_frames, 3)) if periodic else None
    cell_angles = np.zeros((tot_frames, 3)) if periodic else None

    for iframe, frame in enumerate(u.trajectory):
        coords[iframe] = frame._pos
        time[iframe] = frame.time

        if periodic:
            cell_lengths[iframe] = frame.dimensions[:3]
            cell_angles[iframe] = frame.dimensions[3:]

    # Length-based datasets need converting. ``convert`` returns the converted values rather
    # than modifying its input, so the results must be assigned back.
    coords = mda.units.convert(coords, 'angstrom', 'nanometer')
    if periodic:
        cell_lengths = mda.units.convert(cell_lengths, 'angstrom', 'nanometer')

    # NOTE(review): no topology is attached here, although convert_mdanalysis_top_to_mdtraj in
    # this module can build one from ``u``; confirm this is intended.
    traj = WESTTrajectory(coordinates=coords, unitcell_lengths=cell_lengths, unitcell_angles=cell_angles, time=time)

    return traj